use std::fmt::Display;
use std::sync::Arc;
use super::GateExt;
use crate::accounts::Account;
use crate::authz::access_hierarchy::AccessHierarchy;
use crate::authz::access_policy::AccessPolicy;
use crate::authz::authorization_service::AuthorizationService;
use crate::codecs::Codec;
use crate::codecs::jwt::validation_result::JwtValidationResult;
use crate::codecs::jwt::validation_service::JwtValidationService;
use crate::codecs::jwt::{JwtClaims, RegisteredClaims};
use uuid::Uuid;
/// Mode configuration (the `M` parameter) for a [`BearerGate`] that validates JWTs.
///
/// `Debug` is implemented by hand and prints only `optional`, so neither the
/// policy nor `R`/`G` need to implement `Debug`.
#[derive(Clone)]
pub struct JwtConfig<R, G>
where
R: AccessHierarchy + Eq + Display,
G: Eq,
{
/// Access policy applied in strict (non-optional) mode; `new_with_codec`
/// initialises it to `AccessPolicy::deny_all()`.
policy: AccessPolicy<R, G>,
/// When `true`, the JWT runtime never reports a rejection: it yields either
/// `JwtOptionalAuthorized` (valid token) or `JwtOptionalAnonymous`.
optional: bool,
}
impl<R, G> std::fmt::Debug for JwtConfig<R, G>
where
    R: AccessHierarchy + Eq + Display,
    G: Eq,
{
    /// Renders `JwtConfig { optional: .., .. }`.
    ///
    /// The policy is left out (shown as `..`), which is why this impl needs no
    /// `Debug` bound on `R` or `G`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("JwtConfig");
        builder.field("optional", &self.optional);
        builder.finish_non_exhaustive()
    }
}
/// Mode configuration for a [`BearerGate`] that accepts one shared static token.
///
/// `Debug` is implemented by hand so the secret token never ends up in logs,
/// panic messages or the derived `Debug` output of the enclosing gate; only
/// the `optional` flag is printed.
#[derive(Clone)]
pub struct StaticTokenConfig {
    /// The expected bearer token. Secret: never print it.
    token: String,
    /// When `true`, the static runtime reports
    /// `BearerEvaluation::StaticOptionalAuthorized { matched }` instead of
    /// `StaticAuthorized` / `StaticDenied`.
    optional: bool,
}

impl std::fmt::Debug for StaticTokenConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Deliberately omit `token`; `finish_non_exhaustive` renders it as `..`.
        f.debug_struct("StaticTokenConfig")
            .field("optional", &self.optional)
            .finish_non_exhaustive()
    }
}
/// Builder for a bearer-token gate.
///
/// `C` is the codec, `R`/`G` the role and group types, and `M` the mode
/// configuration. This file provides builders for [`JwtConfig`] and
/// [`StaticTokenConfig`] modes, each with a `runtime()` method that produces
/// the per-request evaluator.
///
/// NOTE(review): the derived `Debug` prints `mode`, so the mode type's `Debug`
/// decides whether secrets (e.g. a static token) appear in debug output —
/// confirm every mode type redacts them.
#[derive(Clone, Debug)]
pub struct BearerGate<C, R, G, M>
where
C: Codec,
R: AccessHierarchy + Eq + Display,
G: Eq,
{
/// Issuer passed to `JwtValidationService::new` when a JWT runtime is built;
/// the static-token runtime does not use it.
issuer: String,
/// Codec shared (via `Arc::clone`) with each JWT runtime.
codec: Arc<C>,
/// Mode-specific configuration.
mode: M,
/// Carries `R` and `G` in the type without storing values of them.
_phantom: std::marker::PhantomData<(R, G)>,
}
impl<C, R, G> BearerGate<C, R, G, JwtConfig<R, G>>
where
    C: Codec,
    R: AccessHierarchy + Eq + Display,
    G: Eq + Clone,
{
    /// Creates a JWT-mode gate for `issuer` backed by `codec`.
    ///
    /// The initial policy is `AccessPolicy::deny_all()` and optional mode is
    /// off; adjust with [`with_policy`](Self::with_policy),
    /// [`require_login`](Self::require_login) or
    /// [`allow_anonymous_with_optional_user`](Self::allow_anonymous_with_optional_user).
    pub fn new_with_codec(issuer: &str, codec: Arc<C>) -> Self
    where
        R: Default,
    {
        Self {
            issuer: issuer.to_owned(),
            codec,
            mode: JwtConfig {
                policy: AccessPolicy::deny_all(),
                optional: false,
            },
            _phantom: std::marker::PhantomData,
        }
    }

    /// Replaces the access policy used in strict (non-optional) mode.
    pub fn with_policy(self, policy: AccessPolicy<R, G>) -> Self {
        let mut gate = self;
        gate.mode.policy = policy;
        gate
    }

    /// Turns on optional mode: the JWT runtime then reports
    /// `JwtOptionalAnonymous` instead of a rejection when no valid token is
    /// presented. The policy is kept but not consulted on that path.
    pub fn allow_anonymous_with_optional_user(self) -> Self {
        Self {
            mode: JwtConfig {
                optional: true,
                ..self.mode
            },
            ..self
        }
    }

    /// Sets the policy to `AccessPolicy::require_role_or_supervisor(R::default())`.
    ///
    /// Presumably this admits any account holding the default role or one that
    /// supervises it — confirm against `AccessPolicy`.
    pub fn require_login(self) -> Self
    where
        R: Default,
    {
        self.with_policy(AccessPolicy::require_role_or_supervisor(R::default()))
    }

    /// Switches to static-token mode, keeping the issuer and codec.
    ///
    /// The JWT policy and the `optional` flag are *not* carried over: the
    /// returned gate starts with `optional = false`, so call
    /// `allow_anonymous_with_optional_user` afterwards if needed.
    pub fn with_static_token(
        self,
        token: impl Into<String>,
    ) -> BearerGate<C, R, G, StaticTokenConfig> {
        let Self { issuer, codec, .. } = self;
        BearerGate {
            issuer,
            codec,
            mode: StaticTokenConfig {
                token: token.into(),
                optional: false,
            },
            _phantom: std::marker::PhantomData,
        }
    }

    /// Issuer this gate validates JWTs against.
    pub fn issuer(&self) -> &str {
        self.issuer.as_str()
    }

    /// Shared codec handle.
    pub fn codec(&self) -> &Arc<C> {
        &self.codec
    }

    /// Currently configured access policy.
    pub fn policy(&self) -> &AccessPolicy<R, G> {
        &self.mode.policy
    }

    /// Whether optional (anonymous-allowed) mode is enabled.
    pub fn is_optional(&self) -> bool {
        self.mode.optional
    }
}
impl<C, R, G> BearerGate<C, R, G, StaticTokenConfig>
where
C: Codec,
R: AccessHierarchy + Eq + Display,
G: Eq + Clone,
{
pub fn allow_anonymous_with_optional_user(mut self) -> Self {
self.mode.optional = true;
self
}
pub fn issuer(&self) -> &str {
&self.issuer
}
pub fn codec(&self) -> &Arc<C> {
&self.codec
}
pub fn token(&self) -> &str {
&self.mode.token
}
pub fn is_optional(&self) -> bool {
self.mode.optional
}
}
/// Converts a configured [`BearerGate`] into an adapter-defined value.
///
/// Every implementor also gets `crate::gate::adapter::GateAdapter` for the
/// same gate type through the blanket impl in this module.
pub trait BearerGateAdapter<C, R, G, M>
where
    C: Codec,
    R: AccessHierarchy + Eq + Display,
    G: Eq,
{
    /// Value produced by [`adapt`](Self::adapt).
    type Output;

    /// Consumes `gate` and builds the adapter's output from it.
    fn adapt(&self, gate: BearerGate<C, R, G, M>) -> Self::Output;
}
/// Outcome of evaluating a request's bearer token with a JWT or static-token runtime.
///
/// `Jwt*` variants come from `JwtBearerRuntime::evaluate`; `Static*` variants
/// come from `StaticTokenRuntime::evaluate`.
#[derive(Debug, Clone)]
pub enum BearerEvaluation<R, G>
where
R: AccessHierarchy + Eq + Display + Clone,
G: Eq + Clone,
{
/// Optional JWT mode: no token, or the token did not validate.
JwtOptionalAnonymous,
/// Optional JWT mode: the token validated. The access policy is not
/// consulted on this path.
JwtOptionalAuthorized {
/// Account taken from the token's custom claims.
account: Account<R, G>,
/// Registered (standard) claims of the token.
registered_claims: RegisteredClaims,
},
/// Strict JWT mode: no token was supplied.
JwtMissingToken,
/// Strict JWT mode: the validation service reported `InvalidToken`.
JwtInvalidToken,
/// Strict JWT mode: the validation service reported an issuer mismatch.
JwtInvalidIssuer {
/// Issuer the service expected.
expected: String,
/// Issuer found in the token.
actual: String,
},
/// Strict JWT mode: the policy denies all access. This is checked before
/// the token is looked at, so it is reported even when a token is present.
JwtDenyAllPolicy,
/// Strict JWT mode: the token validated but the policy rejected the account.
JwtPolicyDenied {
/// Id of the rejected account.
account_id: Uuid,
},
/// Strict JWT mode: the token validated and the policy accepted the account.
JwtAuthorized {
/// Account taken from the token's custom claims.
account: Account<R, G>,
/// Registered (standard) claims of the token.
registered_claims: RegisteredClaims,
},
/// Static mode (not optional): the presented token matched.
StaticAuthorized,
/// Static mode (not optional): the token was missing or did not match.
StaticDenied,
/// Static optional mode: always produced. `matched` tells whether a token
/// was presented and equal to the expected one.
StaticOptionalAuthorized {
matched: bool,
},
}
/// Per-request evaluator for a JWT-mode [`BearerGate`], built by its `runtime()` method.
#[derive(Clone, Debug)]
pub struct JwtBearerRuntime<C, R, G>
where
C: Codec,
R: AccessHierarchy + Eq + Display + Clone,
G: Eq + Clone,
{
/// Wraps the configured access policy; used only in strict mode.
authorization_service: AuthorizationService<R, G>,
/// Decodes and validates tokens (built from the codec and issuer).
jwt_validation_service: JwtValidationService<C>,
/// Selects the anonymous-allowed evaluation path.
optional: bool,
}
impl<C, R, G> JwtBearerRuntime<C, R, G>
where
    C: Codec<Payload = JwtClaims<Account<R, G>>>,
    R: AccessHierarchy + Eq + Display + Clone,
    G: Eq + Clone,
{
    /// Assembles a runtime from its parts.
    ///
    /// `policy` is wrapped in an [`AuthorizationService`], and `codec` and
    /// `issuer` are handed to [`JwtValidationService::new`]. `optional` selects
    /// the anonymous-allowed path in [`evaluate`](Self::evaluate).
    pub fn new(issuer: &str, policy: AccessPolicy<R, G>, codec: Arc<C>, optional: bool) -> Self {
        let authorization_service = AuthorizationService::new(policy);
        let jwt_validation_service = JwtValidationService::new(codec, issuer);
        Self {
            authorization_service,
            jwt_validation_service,
            optional,
        }
    }

    /// Evaluates the raw bearer token (if any) presented with a request.
    ///
    /// In optional mode the result is `JwtOptionalAuthorized` for a token that
    /// validates and `JwtOptionalAnonymous` otherwise.
    ///
    /// In strict mode the checks run in this order:
    /// 1. the policy denies all access → `JwtDenyAllPolicy` (even with a token);
    /// 2. no token → `JwtMissingToken`;
    /// 3. validation fails → `JwtInvalidToken` or `JwtInvalidIssuer`;
    /// 4. validation succeeds → `JwtAuthorized` if the policy accepts the
    ///    account, otherwise `JwtPolicyDenied` with the account id.
    pub fn evaluate(&self, token: Option<&str>) -> BearerEvaluation<R, G> {
        if self.optional {
            return self.evaluate_optional(token);
        }
        if self.authorization_service.policy_denies_all_access() {
            return BearerEvaluation::JwtDenyAllPolicy;
        }
        let Some(raw) = token else {
            return BearerEvaluation::JwtMissingToken;
        };
        match self.jwt_validation_service.validate_token(raw) {
            JwtValidationResult::InvalidToken => BearerEvaluation::JwtInvalidToken,
            JwtValidationResult::InvalidIssuer { expected, actual } => {
                BearerEvaluation::JwtInvalidIssuer { expected, actual }
            }
            JwtValidationResult::Valid(jwt) => {
                let account = jwt.custom_claims;
                if !self.authorization_service.is_authorized(&account) {
                    return BearerEvaluation::JwtPolicyDenied {
                        account_id: account.account_id,
                    };
                }
                BearerEvaluation::JwtAuthorized {
                    account,
                    registered_claims: jwt.registered_claims,
                }
            }
        }
    }

    /// Optional-mode evaluation.
    ///
    /// A token that validates yields `JwtOptionalAuthorized` carrying its
    /// claims. A missing token or any validation failure yields
    /// `JwtOptionalAnonymous`.
    ///
    /// NOTE(review): the access policy is never consulted here, so a valid
    /// token's account is attached even if the policy would reject it —
    /// confirm this is intended.
    fn evaluate_optional(&self, token: Option<&str>) -> BearerEvaluation<R, G> {
        match token.map(|raw| self.jwt_validation_service.validate_token(raw)) {
            Some(JwtValidationResult::Valid(jwt)) => BearerEvaluation::JwtOptionalAuthorized {
                account: jwt.custom_claims,
                registered_claims: jwt.registered_claims,
            },
            _ => BearerEvaluation::JwtOptionalAnonymous,
        }
    }
}
#[derive(Clone, Debug)]
pub struct StaticTokenRuntime<R, G>
where
R: AccessHierarchy + Eq + Display + Clone,
G: Eq + Clone,
{
token: String,
optional: bool,
_phantom: std::marker::PhantomData<(R, G)>,
}
impl<R, G> StaticTokenRuntime<R, G>
where
    R: AccessHierarchy + Eq + Display + Clone,
    G: Eq + Clone,
{
    /// Creates a runtime that expects `token`; `optional` selects the
    /// `StaticOptionalAuthorized` reporting mode.
    pub fn new(token: impl Into<String>, optional: bool) -> Self {
        let expected: String = token.into();
        Self {
            token: expected,
            optional,
            _phantom: std::marker::PhantomData,
        }
    }

    /// Returns whether `presented` equals the expected token. `None` never matches.
    ///
    /// Uses `subtle`'s `ConstantTimeEq` for byte slices. That impl short-circuits
    /// when the lengths differ, so only the token's *contents* are compared in
    /// constant time; its length is not hidden.
    fn presented_token_matches(&self, presented: Option<&str>) -> bool {
        use subtle::ConstantTimeEq as _;
        match presented {
            Some(candidate) => bool::from(candidate.as_bytes().ct_eq(self.token.as_bytes())),
            None => false,
        }
    }

    /// Evaluates the presented token.
    ///
    /// Optional mode always yields `StaticOptionalAuthorized { matched }`.
    /// Otherwise the result is `StaticAuthorized` on a match and `StaticDenied`
    /// for a missing or mismatched token.
    pub fn evaluate(&self, token: Option<&str>) -> BearerEvaluation<R, G> {
        let matched = self.presented_token_matches(token);
        if self.optional {
            BearerEvaluation::StaticOptionalAuthorized { matched }
        } else if matched {
            BearerEvaluation::StaticAuthorized
        } else {
            BearerEvaluation::StaticDenied
        }
    }
}
impl<C, R, G> BearerGate<C, R, G, JwtConfig<R, G>>
where
    C: Codec<Payload = JwtClaims<Account<R, G>>>,
    R: AccessHierarchy + Eq + Display + Clone,
    G: Eq + Clone,
{
    /// Builds a [`JwtBearerRuntime`] from this gate's configuration.
    ///
    /// The policy is cloned and the codec `Arc` is shared, so the returned
    /// runtime borrows nothing from `self`.
    ///
    /// No `R: Default` bound is required: the runtime only consumes the
    /// already-configured policy, and `JwtBearerRuntime::new` does not need it.
    pub fn runtime(&self) -> JwtBearerRuntime<C, R, G> {
        JwtBearerRuntime::new(
            &self.issuer,
            self.mode.policy.clone(),
            Arc::clone(&self.codec),
            self.mode.optional,
        )
    }
}
impl<C, R, G> BearerGate<C, R, G, StaticTokenConfig>
where
C: Codec,
R: AccessHierarchy + Eq + Display,
G: Eq + Clone,
{
pub fn runtime(&self) -> StaticTokenRuntime<R, G> {
StaticTokenRuntime::new(self.mode.token.clone(), self.mode.optional)
}
}
/// Implements the parent module's [`GateExt`] trait for every `BearerGate`.
///
/// The impl body is empty, so only `GateExt`'s provided items apply.
impl<C, R, G, M> GateExt for super::bearer::BearerGate<C, R, G, M>
where
    C: Codec,
    R: AccessHierarchy + Eq + Display,
    G: Eq,
{
}
/// Blanket bridge from [`BearerGateAdapter`] to the generic
/// `crate::gate::adapter::GateAdapter` for the same bearer-gate type.
///
/// The impl forwards to [`BearerGateAdapter::adapt`] and reuses its `Output`.
impl<C, R, G, M, A> crate::gate::adapter::GateAdapter<BearerGate<C, R, G, M>> for A
where
    A: BearerGateAdapter<C, R, G, M>,
    C: Codec,
    R: AccessHierarchy + Eq + Display,
    G: Eq,
{
    type Output = <A as BearerGateAdapter<C, R, G, M>>::Output;

    fn adapt(&self, gate: BearerGate<C, R, G, M>) -> Self::Output {
        // Fully qualified so the call cannot be mistaken for `GateAdapter::adapt`.
        <A as BearerGateAdapter<C, R, G, M>>::adapt(self, gate)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::accounts::Account;
    use crate::codecs::jsonwebtoken::crypto::rust_crypto::DEFAULT_PROVIDER as JWT_CRYPTO_PROVIDER;
    use crate::codecs::jwt::{JsonWebToken, JwtClaims, RegisteredClaims};
    use crate::groups::Group;
    use crate::roles::Role;
    use chrono::Utc;

    /// Installs the JWT crypto provider; the result (e.g. "already installed") is ignored.
    fn install_jwt_crypto_provider() {
        let _ = JWT_CRYPTO_PROVIDER.install_default();
    }

    /// Builds a static-token runtime expecting `secret`.
    ///
    /// When `optional` is set, the gate is switched to anonymous-allowed mode
    /// before the runtime is built.
    fn static_runtime(secret: &str, optional: bool) -> StaticTokenRuntime<Role, Group> {
        install_jwt_crypto_provider();
        let codec = Arc::new(JsonWebToken::<JwtClaims<Account<Role, Group>>>::default());
        let gate =
            BearerGate::<_, Role, Group, JwtConfig<Role, Group>>::new_with_codec("issuer", codec)
                .with_static_token(secret);
        let gate = if optional {
            gate.allow_anonymous_with_optional_user()
        } else {
            gate
        };
        gate.runtime()
    }

    #[test]
    fn jwt_runtime_authorizes_when_policy_allows() -> Result<(), Box<dyn std::error::Error>> {
        install_jwt_crypto_provider();
        let codec = Arc::new(JsonWebToken::<JwtClaims<Account<Role, Group>>>::default());
        let gate = BearerGate::<_, Role, Group, JwtConfig<Role, Group>>::new_with_codec(
            "issuer",
            Arc::clone(&codec),
        )
        .require_login();

        let account = Account::<Role, Group>::new("user");
        let expires_at = Utc::now().timestamp() as u64 + 60;
        let claims = JwtClaims::new(account.clone(), RegisteredClaims::new("issuer", expires_at));
        let token = codec
            .encode(&claims)
            .map_err(|e| format!("encode jwt: {e}"))
            .and_then(|bytes| String::from_utf8(bytes).map_err(|e| format!("utf-8 decode: {e}")))?;

        match gate.runtime().evaluate(Some(&token)) {
            BearerEvaluation::JwtAuthorized {
                account: acc,
                registered_claims,
            } => {
                assert_eq!(acc.user_id, account.user_id);
                assert_eq!(registered_claims.issuer, "issuer");
                Ok(())
            }
            other => Err(format!("expected JwtAuthorized, got {other:?}").into()),
        }
    }

    #[test]
    fn static_runtime_matches_token() {
        let runtime = static_runtime("secret-token", false);
        assert!(matches!(
            runtime.evaluate(Some("secret-token")),
            BearerEvaluation::StaticAuthorized
        ));
        assert!(matches!(
            runtime.evaluate(Some("wrong-token")),
            BearerEvaluation::StaticDenied
        ));
        assert!(matches!(
            runtime.evaluate(None),
            BearerEvaluation::StaticDenied
        ));
    }

    #[test]
    fn static_runtime_rejects_last_byte_different_token() {
        let runtime = static_runtime("secret-tokenX", false);
        assert!(matches!(
            runtime.evaluate(Some("secret-tokenY")),
            BearerEvaluation::StaticDenied
        ));
        assert!(matches!(
            runtime.evaluate(Some("secret-tokenX")),
            BearerEvaluation::StaticAuthorized
        ));
    }

    #[test]
    fn static_optional_runtime_constant_time_comparison() {
        let runtime = static_runtime("my-secret", true);
        assert!(matches!(
            runtime.evaluate(Some("my-secret")),
            BearerEvaluation::StaticOptionalAuthorized { matched: true }
        ));
        assert!(matches!(
            runtime.evaluate(Some("my-secreX")),
            BearerEvaluation::StaticOptionalAuthorized { matched: false }
        ));
        assert!(matches!(
            runtime.evaluate(None),
            BearerEvaluation::StaticOptionalAuthorized { matched: false }
        ));
    }
}