use std::marker::PhantomData;
use crate::crypto::PublicKey;
use crate::error::Error::{self, *};
use crate::message::SignedMessage;
use crate::rbac::PolicyCondition;
use super::{PolicyAccessToken, ToTokenStr};
/// Validates signed access tokens of type `A` against a single trusted public key.
///
/// The authority never stores an `A`; it only produces them from token strings.
/// The marker is therefore `PhantomData<fn() -> A>`, so the authority does not
/// inherit `A`'s auto-trait (`Send`/`Sync`) or drop-check constraints.
pub struct ValidationAuthority<A> {
    /// Key used to verify the signature of every incoming token.
    public_key: PublicKey,
    _p: PhantomData<fn() -> A>,
}
impl<A: PolicyAccessToken> ValidationAuthority<A> {
    /// Creates an authority that verifies token signatures with `public_key`.
    pub fn new(public_key: PublicKey) -> Self {
        Self {
            public_key,
            _p: PhantomData,
        }
    }

    /// Turns an encoded token string into a validated access token.
    ///
    /// The checks run in this order and stop at the first failure:
    /// 1. decode the signed message (`BadSignedMessageEncoding`);
    /// 2. verify its signature with `self.public_key` (`SignatureVerificationFail`);
    /// 3. parse the signed payload as `A` (`BadAccessTokenEncoding`);
    /// 4. reject tokens whose `is_expired()` is true (`ExpiredAccessToken`).
    fn decode_verify_check_expiration(&self, token: &str) -> Result<A, Error> {
        let signed_message = SignedMessage::decode(token).ok_or(BadSignedMessageEncoding)?;
        // Refuse to interpret the payload at all until the signature is trusted.
        if !signed_message.verify(&self.public_key) {
            return Err(SignatureVerificationFail);
        }
        let access_token =
            A::from_bytes(signed_message.message()).map_err(|_| BadAccessTokenEncoding)?;
        if access_token.is_expired() {
            Err(ExpiredAccessToken)
        } else {
            Ok(access_token)
        }
    }

    /// Validates `token` and checks that its policies satisfy `condition`.
    ///
    /// On success, returns the decoded access token.
    ///
    /// # Errors
    ///
    /// * `Unauthorized` if `token` yields no token string.
    /// * `BadSignedMessageEncoding`, `SignatureVerificationFail`,
    ///   `BadAccessTokenEncoding` or `ExpiredAccessToken` if validation fails
    ///   (checked in that order).
    /// * `Forbidden` if the token is valid but its policies do not satisfy `condition`.
    pub fn enforce(
        &self,
        condition: impl AsRef<PolicyCondition<A::Policy>>,
        token: impl ToTokenStr,
    ) -> Result<A, Error> {
        // Reuse the single validation path, then the single policy-check path,
        // so `enforce` cannot drift from `to_access_enforcer` + `AccessEnforcer::enforce`.
        let enforcer = self.to_access_enforcer(token)?;
        enforcer.enforce(condition)?;
        Ok(enforcer.into_access_token())
    }

    /// Validates `token` once and wraps it in an [`AccessEnforcer`].
    ///
    /// The enforcer can then check several policy conditions without re-decoding
    /// or re-verifying the token.
    ///
    /// NOTE: expiry is evaluated only here. The returned enforcer does not re-check it.
    ///
    /// # Errors
    ///
    /// * `Unauthorized` if `token` yields no token string.
    /// * `BadSignedMessageEncoding`, `SignatureVerificationFail`,
    ///   `BadAccessTokenEncoding` or `ExpiredAccessToken` if validation fails.
    pub fn to_access_enforcer(&self, token: impl ToTokenStr) -> Result<AccessEnforcer<A>, Error> {
        let token = token.to_token_str().ok_or(Unauthorized)?;
        self.decode_verify_check_expiration(token)
            .map(AccessEnforcer::new)
    }
}
/// Holds a decoded access token so that policy conditions can be checked against it.
///
/// Enforcers obtained from `ValidationAuthority::to_access_enforcer` wrap tokens
/// that passed the signature and expiry checks. `AccessEnforcer::new` performs no
/// validation of its own.
#[derive(Clone, Debug)]
pub struct AccessEnforcer<A> {
    access_token: A,
}
impl<A: PolicyAccessToken> AccessEnforcer<A> {
pub fn new(access_token: A) -> Self {
Self { access_token }
}
pub fn into_access_token(self) -> A {
self.access_token
}
pub fn enforce(&self, condition: impl AsRef<PolicyCondition<A::Policy>>) -> Result<&A, Error> {
if condition.as_ref().satisfy(self.access_token.policies()) {
Ok(&self.access_token)
} else {
Err(Forbidden)
}
}
}
/// Asserts that `$exp` evaluates to `Err($err)`. On failure, panics with the
/// value that was actually found.
///
/// `$exp` is evaluated exactly once and bound by reference. The value printed
/// on failure is therefore the same value that was tested, and side effects in
/// `$exp` are not repeated.
#[cfg(test)]
macro_rules! assert_auth_error {
    ($exp:expr, $err:path) => {
        match &$exp {
            result => assert!(
                matches!(result, Err($err)),
                concat!("Expect Err(", stringify!($err), ") but found {:?}"),
                result
            ),
        }
    };
}
#[cfg(test)]
mod tests {
    use crate::crypto::tests::{get_test_private_key, get_test_public_key};
    use crate::crypto::PrivateKey;
    use crate::rbac::test_helpers::TestPolicy::{Policy1, Policy2};
    use crate::token::test_utils::TestAccessToken;

    use super::*;

    /// Signs `token` with `private_key` and returns the encoded signed message.
    fn create_access_token_with_key(token: TestAccessToken, private_key: &PrivateKey) -> String {
        SignedMessage::create(token.to_bytes(), private_key).encode()
    }

    /// Signs `token` with the shared test private key.
    fn create_access_token(token: TestAccessToken) -> String {
        let private_key = PrivateKey::from_base64(&get_test_private_key()).unwrap();
        create_access_token_with_key(token, &private_key)
    }

    /// Builds an authority that trusts the shared test public key.
    fn make_va() -> ValidationAuthority<TestAccessToken> {
        ValidationAuthority::new(PublicKey::from_base64(&get_test_public_key()).unwrap())
    }

    #[test]
    fn test_no_token() {
        let va = make_va();
        let x = va.enforce(PolicyCondition::Nil, None::<&str>);
        assert_auth_error!(x, Unauthorized);
    }

    #[test]
    fn test_bad_token() {
        let va = make_va();
        let x = va.enforce(PolicyCondition::Nil, Some("123"));
        assert_auth_error!(x, BadSignedMessageEncoding);
    }

    #[test]
    fn test_sign_by_other_keys() {
        // A key that `make_va` does not trust. Its public half is presumably
        // "uneKfdOZUuupqMK7q1KwPFluM9zxpdIlyNntF4V1Dgs" (unused here).
        let private_key_other =
            PrivateKey::from_base64("B1H3hDtRa0K0XxPC2tjD8uj2Tx3i9RlsQ7jSpl4OOIY").unwrap();
        let va = make_va();
        let token = TestAccessToken::new(vec![Policy1, Policy2].into(), false);
        let access_token = create_access_token_with_key(token, &private_key_other);
        let x = va.enforce(PolicyCondition::Nil, Some(access_token));
        assert_auth_error!(x, SignatureVerificationFail);
    }

    #[test]
    fn test_expired_token() {
        let va = make_va();
        let token = create_access_token(TestAccessToken::new(vec![Policy1].into(), true));
        let x = va.enforce(PolicyCondition::Nil, Some(token));
        assert_auth_error!(x, ExpiredAccessToken);
    }

    #[test]
    fn test_forbidden() {
        let va = make_va();
        let token = create_access_token(TestAccessToken::new(vec![].into(), false));
        let x = va.enforce(PolicyCondition::Contains(Policy1), Some(token));
        assert_auth_error!(x, Forbidden);
    }

    #[test]
    fn test_authorized() {
        let va = make_va();
        let token = create_access_token(TestAccessToken::new(vec![Policy1].into(), false));
        let x = va.enforce(PolicyCondition::Contains(Policy1), Some(token));
        assert!(x.is_ok(), "Expect Ok but found {:?}", x);
    }

    #[test]
    fn test_access_enforcer() {
        let va = make_va();
        let token = create_access_token(TestAccessToken::new(vec![Policy1].into(), false));
        let enforcer = va.to_access_enforcer(Some(token)).unwrap();
        assert!(enforcer.enforce(PolicyCondition::Contains(Policy1)).is_ok());
        assert_auth_error!(enforcer.enforce(PolicyCondition::Contains(Policy2)), Forbidden);
    }
}