use crate::crypto::{PublicKey, SignedMessage};
use crate::error::Error::{self, *};
use super::AccessToken;
/// Checks signed access tokens against one trusted public key.
///
/// Build it with `TokenValidator::new`, then call `validate` (always
/// enforces expiry) or `validate_config` (caller-chosen checks).
pub struct TokenValidator {
    /// Key every incoming signed message is verified against.
    public_key: PublicKey,
}
/// Options controlling which optional checks token validation performs.
///
/// Signature verification and payload parsing are always performed; this
/// struct only toggles the additional checks listed below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ValidationConfig {
    /// When `true`, tokens whose `is_expired()` reports `true` are rejected.
    pub check_expiration: bool,
}
impl Default for ValidationConfig {
fn default() -> Self {
Self {
check_expiration: true,
}
}
}
impl TokenValidator {
    /// Creates a validator that verifies token signatures with `public_key`.
    pub fn new(public_key: PublicKey) -> Self {
        Self { public_key }
    }

    /// Decodes, verifies and parses `token`, always enforcing expiration.
    ///
    /// Equivalent to [`TokenValidator::validate_config`] with
    /// `check_expiration: true`. The flag is spelled out here on purpose
    /// (instead of using `ValidationConfig::default()`) so that this entry
    /// point keeps rejecting expired tokens even if the default is changed.
    ///
    /// # Errors
    ///
    /// Same as [`TokenValidator::validate_config`].
    pub fn validate<A: AccessToken, T: AsRef<[u8]>>(&self, token: T) -> Result<A, Error> {
        self.validate_config(
            token,
            ValidationConfig {
                check_expiration: true,
            },
        )
    }

    /// Decodes, verifies and parses `token`, applying the checks in `config`.
    ///
    /// Steps, in order: decode the signed envelope, verify its signature
    /// against this validator's public key, parse the payload as `A`, and,
    /// when `config.check_expiration` is set, reject expired tokens.
    ///
    /// # Errors
    ///
    /// - `InvalidSignedMessage` if `token` cannot be decoded as a signed message.
    /// - `SignatureVerificationFail` if the signature does not verify.
    /// - `InvalidAccessToken` if the payload cannot be parsed as `A`
    ///   (the underlying parse error is discarded).
    /// - `ExpiredAccessToken` if expiration checking is enabled and the
    ///   parsed token reports itself as expired.
    pub fn validate_config<A: AccessToken, T: AsRef<[u8]>>(
        &self,
        token: T,
        config: ValidationConfig,
    ) -> Result<A, Error> {
        let signed_message = SignedMessage::decode(token).ok_or(InvalidSignedMessage)?;
        // Verify before parsing so unauthenticated bytes never reach `A::from_bytes`.
        if !signed_message.verify(&self.public_key) {
            return Err(SignatureVerificationFail);
        }
        let access_token =
            A::from_bytes(signed_message.message()).map_err(|_| InvalidAccessToken)?;
        if config.check_expiration && access_token.is_expired() {
            return Err(ExpiredAccessToken);
        }
        Ok(access_token)
    }
}
/// Asserts that `$exp` evaluates to `Err($err)`, printing the actual value on failure.
///
/// `$exp` is evaluated exactly once and only borrowed afterwards. It may
/// therefore be any expression, including a call that consumes its
/// arguments, and the failure message shows the very value that was tested.
#[cfg(test)]
macro_rules! assert_auth_error {
    ($exp:expr, $err:path) => {{
        let result = &$exp;
        assert!(
            matches!(result, Err($err)),
            concat!("Expect Err(", stringify!($err), ") but found {:?}"),
            result
        );
    }};
}
#[cfg(test)]
mod tests {
    use crate::crypto::tests::{get_test_private_key, get_test_public_key};
    use crate::crypto::PrivateKey;
    use crate::rbac::test_helpers::TestPermission::{Permission1, Permission2};
    use crate::token::test_utils::TestAccessToken;

    use super::*;

    type ValidateResult = Result<TestAccessToken, Error>;

    /// Signs `token` with `private_key` and returns the encoded signed message.
    fn create_access_token_with_key(token: TestAccessToken, private_key: &PrivateKey) -> String {
        SignedMessage::create(token.to_bytes(), private_key).encode()
    }

    /// Signs `token` with the shared test private key.
    fn create_access_token(token: TestAccessToken) -> String {
        let private_key = PrivateKey::from_base64(&get_test_private_key()).unwrap();
        create_access_token_with_key(token, &private_key)
    }

    /// Builds a validator that trusts the shared test public key.
    fn make_validator() -> TokenValidator {
        TokenValidator::new(PublicKey::from_base64(&get_test_public_key()).unwrap())
    }

    #[test]
    fn test_no_token() {
        let validator = make_validator();
        let x: ValidateResult = validator.validate("");
        assert_auth_error!(x, InvalidSignedMessage);
    }

    #[test]
    fn test_bad_token() {
        let validator = make_validator();
        let x: ValidateResult = validator.validate("123");
        assert_auth_error!(x, InvalidSignedMessage);
    }

    #[test]
    fn test_sign_by_other_keys() {
        // A key that is not the shared test key, so the signature must not verify.
        let private_key_other =
            PrivateKey::from_base64("B1H3hDtRa0K0XxPC2tjD8uj2Tx3i9RlsQ7jSpl4OOIY").unwrap();
        let validator = make_validator();
        let token = TestAccessToken::new(vec![Permission1, Permission2].into(), false);
        let access_token = create_access_token_with_key(token, &private_key_other);
        let x: ValidateResult = validator.validate(access_token);
        assert_auth_error!(x, SignatureVerificationFail);
    }

    #[test]
    fn test_valid_access_token() {
        let validator = make_validator();
        let token = create_access_token(TestAccessToken::new(vec![Permission1].into(), false));
        let x: ValidateResult = validator.validate(token);
        assert!(x.is_ok(), "Expect Ok(..) but found {:?}", x);
    }

    #[test]
    fn test_expired_access_token() {
        let validator = make_validator();
        let token = create_access_token(TestAccessToken::new(vec![Permission1].into(), true));
        let x: ValidateResult = validator.validate(token);
        assert_auth_error!(x, ExpiredAccessToken);
    }

    #[test]
    fn test_expired_access_token_accepted_without_expiration_check() {
        let validator = make_validator();
        let token = create_access_token(TestAccessToken::new(vec![Permission1].into(), true));
        let x: ValidateResult = validator.validate_config(
            token,
            ValidationConfig {
                check_expiration: false,
            },
        );
        assert!(x.is_ok(), "Expect Ok(..) but found {:?}", x);
    }

    #[test]
    fn test_default_validation_config() {
        let config = ValidationConfig::default();
        assert!(config.check_expiration, "default must check expiration");
    }
}