use std::time::Duration;
use axum::body::Body;
use axum::http::{Request, StatusCode};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use serde::Serialize;
use crate::config::RateLimitConfig;
use crate::state::AppState;
use gradatum_warden::{WardenConfig, WardenLayer};
/// Builds the rate-limiting `WardenLayer` described by `cfg`.
///
/// Returns `None` when `cfg.enabled` is false. Otherwise the layer is built
/// with:
/// - `cfg.per_minute` as the per-minute rate;
/// - `cfg.burst` as the burst size;
/// - `cfg.exempt_localhost` as the loopback bypass flag;
/// - empty IP allow/deny lists.
///
/// # Panics
///
/// Panics if `WardenLayer::new` rejects the configuration. Per the panic
/// message, this happens when `per_minute` or `burst` is 0.
/// NOTE(review): the message assumes `RateLimitConfig::default()` values
/// (60, 10); an operator-supplied 0 would abort startup — confirm intended.
pub fn build_warden_layer(cfg: &RateLimitConfig) -> Option<WardenLayer> {
    cfg.enabled.then(|| {
        let warden_cfg = WardenConfig {
            enabled: true,
            rate_limit_per_minute: cfg.per_minute,
            rate_limit_burst: cfg.burst,
            bypass_loopback: cfg.exempt_localhost,
            ip_allow: Vec::new(),
            ip_deny: Vec::new(),
        };
        WardenLayer::new(warden_cfg).expect(
            "config warden invalide — per_minute et burst doivent être > 0, \
             garantis par RateLimitConfig::default() (60, 10)",
        )
    })
}
/// Maximum time `auth_middleware` waits for the revocation store to answer.
/// When it is exceeded, the request fails closed with a 401.
const REVOCATION_CHECK_TIMEOUT: Duration = Duration::from_millis(200);
/// JSON body of the rejections produced by this module's middleware.
/// It serializes as `{"error": "<message>"}`.
#[derive(Serialize)]
struct MiddlewareError {
    /// Reason for the rejection, sent verbatim to the client.
    error: &'static str,
}
fn unauthorized(msg: &'static str) -> Response {
(StatusCode::UNAUTHORIZED, axum::Json(MiddlewareError { error: msg })).into_response()
}
/// Authentication middleware.
///
/// It derives the request's trust context from the `Authorization` header.
/// For a verified JWT, it checks the token's `jti` against the revocation
/// store before calling the next layer.
///
/// Requests without a usable bearer token are not rejected here. They
/// continue with the unauthenticated trust context inserted into the request
/// extensions.
///
/// Returns `401 Unauthorized` in three cases:
/// - the token is revoked;
/// - the revocation store returns an error (fail-closed);
/// - the store does not answer within `REVOCATION_CHECK_TIMEOUT` (fail-closed).
pub async fn auth_middleware(
    axum::extract::State(state): axum::extract::State<AppState>,
    mut request: Request<Body>,
    next: Next,
) -> Response {
    let (trust, maybe_jti) = extract_trust(&state, &request);

    // Only verified tokens carry a `jti`; anything else skips the store lookup.
    if let Some(jti) = maybe_jti.as_deref() {
        if let Some(rejection) = revocation_rejection(&state, jti).await {
            return rejection;
        }
    }

    request.extensions_mut().insert(trust);
    next.run(request).await
}

/// Queries the revocation store for `jti`, bounded by `REVOCATION_CHECK_TIMEOUT`.
///
/// Returns `None` only when the store positively answers "not revoked" in
/// time. Every other outcome is a 401: a revoked token, a store error, or a
/// timeout.
async fn revocation_rejection(state: &AppState, jti: &str) -> Option<Response> {
    let lookup = tokio::time::timeout(REVOCATION_CHECK_TIMEOUT, state.revocation.is_revoked(jti));
    match lookup.await {
        Ok(Ok(false)) => None,
        Ok(Ok(true)) => {
            tracing::debug!(jti = %jti, "token JWT révoqué — 401");
            Some(unauthorized("token révoqué"))
        }
        Ok(Err(e)) => {
            tracing::error!(
                err = %e,
                "revocation store error — fail-closed (401)"
            );
            Some(unauthorized("erreur de vérification du token — réessayer plus tard"))
        }
        Err(_elapsed) => {
            tracing::error!(
                timeout_ms = REVOCATION_CHECK_TIMEOUT.as_millis(),
                "revocation check timeout — fail-closed (401)"
            );
            Some(unauthorized("erreur de vérification du token — réessayer plus tard"))
        }
    }
}
/// Derives the request's `TrustContext` from its `Authorization` header.
///
/// Returns `(TrustContext::Unauthenticated, None)` in any of these cases:
/// - the header is absent;
/// - the header is not visible ASCII;
/// - the header does not use the `Bearer` scheme;
/// - the token fails `state.jwt.verify`.
///
/// On success, returns a `TrustContext::BearerToken` built from the verified
/// claims, together with the token's `jti` so the caller can check revocation.
fn extract_trust(
    state: &AppState,
    request: &Request<Body>,
) -> (gradatum_core::trust::TrustContext, Option<String>) {
    use gradatum_core::trust::TrustContext;

    let Some(header_value) = request.headers().get(axum::http::header::AUTHORIZATION) else {
        return (TrustContext::Unauthenticated, None);
    };
    // `HeaderValue::to_str` fails on any byte outside visible ASCII, not only
    // on invalid UTF-8, so the log message says so.
    let Ok(raw) = header_value.to_str() else {
        tracing::debug!("Authorization header contient des octets non-ASCII visibles — ignoré");
        return (TrustContext::Unauthenticated, None);
    };
    let Some(token) = parse_bearer_token(raw) else {
        return (TrustContext::Unauthenticated, None);
    };
    match state.jwt.verify(token) {
        Ok(claims) => {
            tracing::debug!(
                sub = %claims.sub,
                tenant = %claims.tenant_id,
                "JWT vérifié avec succès"
            );
            let trust = TrustContext::BearerToken {
                kid: state.jwt.kid().to_string(),
                aud: claims.aud,
                sub: claims.sub,
                scopes: claims.scopes,
                tenant_id: claims.tenant_id,
            };
            // `jti` is a distinct field, so it can be moved out instead of cloned.
            (trust, Some(claims.jti))
        }
        Err(e) => {
            tracing::debug!(err = %e, "JWT invalide — TrustContext::Unauthenticated");
            (TrustContext::Unauthenticated, None)
        }
    }
}

/// Extracts the credential from an `Authorization` value that uses the
/// `Bearer` scheme.
///
/// The scheme name is matched case-insensitively (RFC 7235 §2.1). It may be
/// followed by one or more spaces (RFC 6750 §2.1, `1*SP`).
///
/// Returns `None` for any other scheme, a missing separator, or an empty
/// credential.
fn parse_bearer_token(raw: &str) -> Option<&str> {
    let (scheme, rest) = raw.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = rest.trim_start_matches(' ');
    (!token.is_empty()).then_some(token)
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::{Duration, SystemTime};
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use axum::routing::get;
    use axum::Router;
    use tower::ServiceExt;
    use gradatum_auth::jwt::TokenScope;
    use gradatum_auth::revocation::{RevocationError, RevocationStore};
    use crate::state::AppState;

    /// Fresh application state with its default revocation store.
    fn make_state() -> AppState {
        AppState::new()
    }

    /// Fresh application state whose revocation store is replaced by `store`.
    fn make_state_with_revocation(store: Arc<dyn RevocationStore>) -> AppState {
        let mut state = AppState::new();
        state.revocation = store;
        state
    }

    /// Signs a token for `sub` with the state's JWT key.
    /// Returns the token together with its `jti`.
    fn sign_token(state: &AppState, sub: &str) -> (String, String) {
        let token = state
            .jwt
            .sign(sub, &["read".to_string()], TokenScope::Service, "main")
            .expect("sign doit réussir avec une clé éphémère valide");
        let claims = state.jwt.verify(&token).expect("verify immédiat ne peut pas échouer");
        (token, claims.jti)
    }

    /// Handler that always succeeds. Any non-200 status therefore comes from
    /// the middleware.
    async fn handler_ok() -> StatusCode {
        StatusCode::OK
    }

    /// Router exposing `GET /test` behind `auth_middleware`.
    fn test_router(state: AppState) -> Router {
        Router::new()
            .route("/test", get(handler_ok))
            .layer(axum::middleware::from_fn_with_state(
                state.clone(),
                crate::middleware::auth_middleware,
            ))
            .with_state(state)
    }

    /// Sends `GET /test`, attaching `Authorization: Bearer <token>` when
    /// `bearer` is set, and returns the response status.
    async fn send_request(router: Router, bearer: Option<&str>) -> StatusCode {
        let mut builder = Request::builder().method("GET").uri("/test");
        if let Some(token) = bearer {
            builder = builder.header("Authorization", format!("Bearer {token}"));
        }
        let req = builder.body(Body::empty()).expect("request builder invariant");
        let resp = router.oneshot(req).await.expect("handler ne doit pas paniquer");
        resp.status()
    }

    #[tokio::test]
    async fn test_valid_token_not_revoked_passes() {
        let state = make_state();
        let (token, _jti) = sign_token(&state, "user-test");
        let router = test_router(state);
        let status = send_request(router, Some(&token)).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_revoked_token_returns_401() {
        let state = make_state();
        let (token, jti) = sign_token(&state, "user-revoked");
        let exp = SystemTime::now() + Duration::from_secs(86400);
        state
            .revocation
            .revoke(&jti, exp)
            .await
            .expect("revoke doit réussir sur InMemoryRevocationStore");
        let router = test_router(state);
        let status = send_request(router, Some(&token)).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    /// The middleware does not reject anonymous requests. Authorization is
    /// left to the handlers.
    #[tokio::test]
    async fn test_missing_header_passes_through() {
        let router = test_router(make_state());
        let status = send_request(router, None).await;
        assert_eq!(status, StatusCode::OK);
    }

    /// A token that fails verification yields no `jti`, so the revocation
    /// store is never queried. The failing store proves it: a lookup would
    /// have produced a 401.
    #[tokio::test]
    async fn test_invalid_token_skips_revocation_lookup() {
        let state = make_state_with_revocation(Arc::new(AlwaysErrorStore));
        let router = test_router(state);
        let status = send_request(router, Some("not-a-jwt")).await;
        assert_eq!(status, StatusCode::OK);
    }

    /// Revocation store whose lookups always fail.
    struct AlwaysErrorStore;
    #[async_trait::async_trait]
    impl RevocationStore for AlwaysErrorStore {
        async fn is_revoked(&self, _jti: &str) -> Result<bool, RevocationError> {
            Err(RevocationError::Sqlite(sqlx::Error::RowNotFound))
        }
        async fn revoke(&self, _jti: &str, _exp: SystemTime) -> Result<(), RevocationError> {
            Ok(())
        }
        async fn gc(&self) -> Result<usize, RevocationError> {
            Ok(0)
        }
    }

    #[tokio::test]
    async fn test_store_error_returns_401_fail_closed() {
        let error_store = Arc::new(AlwaysErrorStore);
        let state = make_state_with_revocation(error_store);
        let (token, _jti) = sign_token(&state, "user-store-err");
        let router = test_router(state);
        let status = send_request(router, Some(&token)).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    /// Revocation store that answers "not revoked" only after the middleware's
    /// timeout has elapsed. The delay is derived from the constant so the test
    /// keeps exercising the timeout path if the constant changes.
    struct SlowStore;
    #[async_trait::async_trait]
    impl RevocationStore for SlowStore {
        async fn is_revoked(&self, _jti: &str) -> Result<bool, RevocationError> {
            tokio::time::sleep(super::REVOCATION_CHECK_TIMEOUT + Duration::from_millis(100)).await;
            Ok(false)
        }
        async fn revoke(&self, _jti: &str, _exp: SystemTime) -> Result<(), RevocationError> {
            Ok(())
        }
        async fn gc(&self) -> Result<usize, RevocationError> {
            Ok(0)
        }
    }

    #[tokio::test]
    async fn test_timeout_returns_401_fail_closed() {
        let slow_store = Arc::new(SlowStore);
        let state = make_state_with_revocation(slow_store);
        let (token, _jti) = sign_token(&state, "user-slow");
        let router = test_router(state);
        let status = send_request(router, Some(&token)).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }
}