use crate::accounts::Account;
use crate::authz::{AccessHierarchy, AccessPolicy, AuthorizationService};
use crate::codecs::Codec;
use crate::codecs::jwt::{JwtClaims, JwtValidationResult, JwtValidationService, RegisteredClaims};
use std::convert::Infallible;
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
#[cfg(all(feature = "audit-logging", feature = "prometheus"))]
use std::time::Instant;
#[cfg(feature = "audit-logging")]
use crate::audit;
#[cfg(all(feature = "audit-logging", feature = "prometheus"))]
use crate::audit::prometheus_metrics::{JwtValidationOutcome, observe_jwt_validation_latency};
#[cfg(all(feature = "audit-logging", feature = "prometheus"))]
use crate::audit::{AuthzOutcome, observe_authz_latency};
use crate::cookie_template::CookieTemplate;
use axum::{body::Body, extract::Request, http::Response};
use axum_extra::extract::cookie::{Cookie, CookieJar};
use http::StatusCode;
use tower::Service;
use tracing::{debug, trace, warn};
/// Tower [`Service`] middleware that gates requests on a JWT carried in a cookie.
///
/// The cookie is located by the name held in `cookie_template`, and its value is
/// checked by `jwt_validation_service`. In the default (enforcing) mode a request is
/// forwarded to `inner` only when the token is valid and `authorization_service`
/// authorizes the decoded account; otherwise a `401 Unauthorized` response is returned.
///
/// When `install_optional_extensions` is set, the gate never rejects. Instead it inserts
/// `Option<Account<R, G>>` and `Option<RegisteredClaims>` request extensions (`None` when
/// the cookie is missing or the token is not valid) and always forwards the request.
#[derive(Debug, Clone)]
pub struct CookieGateService<C, R, G, S>
where
    C: Codec,
    R: AccessHierarchy + Eq + std::fmt::Display,
    G: Eq + Clone,
{
    /// The wrapped service that receives every request that passes the gate.
    inner: S,
    /// Evaluates the configured access policy against the decoded account.
    authorization_service: AuthorizationService<R, G>,
    /// Validates the cookie's token (including its issuer) and yields its claims.
    jwt_validation_service: JwtValidationService<C>,
    /// Supplies the name of the authentication cookie to look up.
    cookie_template: CookieTemplate,
    /// `true` selects optional mode: never reject, install `Option<_>` extensions instead.
    install_optional_extensions: bool,
}
impl<C, R, G, S> CookieGateService<C, R, G, S>
where
    C: Codec,
    R: AccessHierarchy + Eq + std::fmt::Display,
    G: Eq + Clone,
{
    /// Creates an enforcing gate around `inner`.
    ///
    /// Requests are forwarded only when the auth cookie (named by `cookie_template`)
    /// carries a token that `codec` accepts for `issuer` and whose account satisfies
    /// `policy`. All other requests are answered with `401 Unauthorized`.
    pub fn new(
        inner: S,
        issuer: &str,
        policy: AccessPolicy<R, G>,
        codec: Arc<C>,
        cookie_template: CookieTemplate,
    ) -> Self {
        let authorization_service = AuthorizationService::new(policy);
        let jwt_validation_service = JwtValidationService::new(codec, issuer);

        Self {
            inner,
            authorization_service,
            jwt_validation_service,
            cookie_template,
            install_optional_extensions: false,
        }
    }

    /// Creates a gate in optional mode around `inner`.
    ///
    /// No request is rejected. Each request receives `Option<Account<R, G>>` and
    /// `Option<RegisteredClaims>` extensions describing the (possibly absent or
    /// invalid) cookie token. The access policy is set to `AccessPolicy::deny_all()`
    /// only as a placeholder; it is not consulted in this mode.
    pub fn new_with_optional_extensions(
        inner: S,
        issuer: &str,
        codec: Arc<C>,
        cookie_template: CookieTemplate,
    ) -> Self {
        let mut gate = Self::new(inner, issuer, AccessPolicy::deny_all(), codec, cookie_template);
        gate.install_optional_extensions = true;
        gate
    }
}
impl<C, R, G, S> CookieGateService<C, R, G, S>
where
    C: Codec,
    R: AccessHierarchy + Eq + std::fmt::Display,
    G: Eq + Clone,
{
    /// Returns the authentication cookie carried by `req`, if present.
    ///
    /// The cookie is looked up by the name configured in the gate's `CookieTemplate`.
    /// The returned cookie is an owned clone (`'static`), so holding it keeps neither
    /// `self` nor `req` borrowed.
    pub fn auth_cookie(&self, req: &Request<Body>) -> Option<Cookie<'static>> {
        CookieJar::from_headers(req.headers())
            .get(self.cookie_template.cookie_name_ref())
            .cloned()
    }

    /// Builds the plain-text `401 Unauthorized` response returned for rejected requests.
    ///
    /// Constructed directly rather than through `Response::builder()`, so there is no
    /// fallible step and nothing to unwrap.
    fn unauthorized() -> Response<Body> {
        let mut response = Response::new(Body::from("Unauthorized"));
        *response.status_mut() = StatusCode::UNAUTHORIZED;
        response
    }
}
impl<C, R, G, S> Service<Request<Body>> for CookieGateService<C, R, G, S>
where
    S: Service<Request<Body>, Response = Response<Body>, Error = Infallible> + Send + 'static,
    S::Future: Send + 'static,
    Account<R, G>: Clone,
    C: Codec<Payload = JwtClaims<Account<R, G>>>,
    R: AccessHierarchy + Eq + std::fmt::Display + Sync + Send + 'static,
    G: Eq + Clone + Sync + Send + 'static,
{
    type Response = Response<Body>;
    type Error = Infallible;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Gates a single request.
    ///
    /// * Optional mode (`install_optional_extensions`): never rejects. Inserts
    ///   `Option<Account<R, G>>` and `Option<RegisteredClaims>` extensions (`Some` only
    ///   for a cookie holding a valid token) and forwards to the inner service.
    /// * Enforcing mode: answers `401 Unauthorized` when the policy denies all access,
    ///   the auth cookie is missing, the token is invalid or has the wrong issuer, or
    ///   the account is not authorized. Otherwise inserts the `Account<R, G>` and
    ///   `RegisteredClaims` extensions and forwards to the inner service.
    ///
    /// NOTE(review): the audit spans below are entered only for the synchronous part of
    /// `call`; their guards are dropped when `call` returns, so the returned future
    /// (including the inner service's future) is not instrumented with them.
    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
        // Built lazily so that only rejected requests pay for boxing the 401 future.
        let reject = || -> Self::Future { Box::pin(async { Ok(Self::unauthorized()) }) };

        #[cfg(feature = "audit-logging")]
        let _audit_request_span =
            audit::request_span(req.method().as_str(), req.uri().path(), None);
        #[cfg(feature = "audit-logging")]
        let _audit_request_enter = _audit_request_span.enter();

        if self.install_optional_extensions {
            let mut opt_account: Option<Account<R, G>> = None;
            let mut opt_reg_claims: Option<RegisteredClaims> = None;
            if let Some(auth_cookie) = self.auth_cookie(&req) {
                trace!("axum-gate (optional) cookie: {auth_cookie:#?}");
                let cookie_value = auth_cookie.value_trimmed();
                if let JwtValidationResult::Valid(jwt) =
                    self.jwt_validation_service.validate_token(cookie_value)
                {
                    opt_account = Some(jwt.custom_claims.clone());
                    opt_reg_claims = Some(jwt.registered_claims.clone());
                } else {
                    debug!("Optional mode: invalid or mismatched JWT, installing None extensions");
                }
            } else {
                trace!("Optional mode: no auth cookie present; installing None extensions");
            }
            // Always install both extensions so handlers can extract them unconditionally.
            req.extensions_mut().insert(opt_account);
            req.extensions_mut().insert(opt_reg_claims);
            return Box::pin(self.inner.call(req));
        }

        if self.authorization_service.policy_denies_all_access() {
            debug!("Denying access because roles, groups or permissions are empty.");
            #[cfg(feature = "audit-logging")]
            {
                audit::denied(None, "policy_denies_all");
            }
            return reject();
        }

        let Some(auth_cookie) = self.auth_cookie(&req) else {
            #[cfg(feature = "audit-logging")]
            {
                audit::denied(None, "missing_cookie");
            }
            return reject();
        };
        trace!("axum-gate cookie: {auth_cookie:#?}");
        let cookie_value = auth_cookie.value_trimmed();

        #[cfg(all(feature = "audit-logging", feature = "prometheus"))]
        let jwt_validation_start = Instant::now();
        let jwt = match self.jwt_validation_service.validate_token(cookie_value) {
            JwtValidationResult::Valid(jwt) => {
                #[cfg(all(feature = "audit-logging", feature = "prometheus"))]
                observe_jwt_validation_latency(jwt_validation_start, JwtValidationOutcome::Valid);
                jwt
            }
            JwtValidationResult::InvalidToken => {
                #[cfg(all(feature = "audit-logging", feature = "prometheus"))]
                observe_jwt_validation_latency(
                    jwt_validation_start,
                    JwtValidationOutcome::InvalidToken,
                );
                debug!("JWT token validation failed");
                #[cfg(feature = "audit-logging")]
                {
                    audit::jwt_invalid_token("validation_failed");
                }
                return reject();
            }
            JwtValidationResult::InvalidIssuer { expected, actual } => {
                #[cfg(all(feature = "audit-logging", feature = "prometheus"))]
                observe_jwt_validation_latency(
                    jwt_validation_start,
                    JwtValidationOutcome::InvalidIssuer,
                );
                warn!(
                    "JWT issuer validation failed. Expected: '{}', Actual: '{}', Account: {}",
                    expected, actual, "unknown"
                );
                #[cfg(feature = "audit-logging")]
                {
                    audit::jwt_invalid_issuer(&expected, &actual);
                }
                return reject();
            }
        };
        debug!("Logged in with id: {}", jwt.custom_claims.account_id);

        #[cfg(feature = "audit-logging")]
        let _authz_span = audit::authorization_span(Some(&jwt.custom_claims.account_id), None);
        #[cfg(feature = "audit-logging")]
        let _authz_enter = _authz_span.enter();
        #[cfg(all(feature = "audit-logging", feature = "prometheus"))]
        let authz_start = Instant::now();

        let account = &jwt.custom_claims;
        let is_authorized = self.authorization_service.is_authorized(account);
        if !is_authorized {
            #[cfg(feature = "audit-logging")]
            {
                audit::denied(Some(&jwt.custom_claims.account_id), "policy_denied");
            }
            #[cfg(all(feature = "audit-logging", feature = "prometheus"))]
            observe_authz_latency(authz_start, AuthzOutcome::Denied);
            return reject();
        }
        #[cfg(feature = "audit-logging")]
        {
            audit::authorized(&jwt.custom_claims.account_id, None);
        }
        #[cfg(all(feature = "audit-logging", feature = "prometheus"))]
        observe_authz_latency(authz_start, AuthzOutcome::Authorized);

        req.extensions_mut().insert(jwt.custom_claims.clone());
        req.extensions_mut().insert(jwt.registered_claims.clone());
        Box::pin(self.inner.call(req))
    }
}