use std::sync::Arc;
use axum::{
body::Body,
extract::State,
http::{Request, StatusCode, header},
middleware::Next,
response::{IntoResponse, Response},
};
use fraiseql_core::security::{AuthenticatedUser, OidcValidator};
/// Shared state handed to [`oidc_auth_middleware`] through axum's `State`
/// extractor.
///
/// Only holds the validator used to check incoming access tokens. Because the
/// validator sits behind an `Arc`, cloning this state only bumps a reference
/// count.
#[derive(Clone)]
pub struct OidcAuthState {
    /// Validator consulted for every request that carries a token.
    pub validator: Arc<OidcValidator>,
}

impl OidcAuthState {
    /// Wraps an already-configured validator as middleware state.
    #[must_use]
    pub const fn new(validator: Arc<OidcValidator>) -> Self {
        Self {
            validator,
        }
    }
}
/// Authenticated caller, stored in the request extensions by
/// [`oidc_auth_middleware`] once a token validates successfully.
///
/// Downstream handlers read it from the extensions to learn who made the
/// request. It is absent when authentication is optional and no token was sent.
#[derive(Debug, Clone)]
pub struct AuthUser(pub AuthenticatedUser);
/// Returns the access token carried in the `__Host-access_token` cookie, if any.
///
/// Every `Cookie` header field is scanned, not just the first one, because
/// HTTP/2 clients may split the cookie list across several fields
/// (RFC 9113 §8.2.3). The first cookie with the expected name wins.
/// Surrounding double quotes (an RFC 6265 quoted cookie-value) are stripped.
///
/// An empty value is reported as `None`, so a blanked-out cookie is treated
/// the same as a missing one instead of being handed to the token validator.
fn extract_access_token_cookie(headers: &axum::http::HeaderMap) -> Option<String> {
    const COOKIE_PREFIX: &str = "__Host-access_token=";

    headers
        .get_all(header::COOKIE)
        .iter()
        // Cookie fields that are not visible ASCII are skipped, not rejected.
        .filter_map(|value| value.to_str().ok())
        .flat_map(|cookies| cookies.split(';'))
        .find_map(|pair| pair.trim().strip_prefix(COOKIE_PREFIX))
        .map(|value| value.trim_matches('"'))
        .filter(|token| !token.is_empty())
        .map(str::to_owned)
}
#[allow(clippy::cognitive_complexity)] pub async fn oidc_auth_middleware(
State(auth_state): State<OidcAuthState>,
mut request: Request<Body>,
next: Next,
) -> Response {
let token_string: Option<String> = {
let auth_header = request
.headers()
.get(header::AUTHORIZATION)
.and_then(|value| value.to_str().ok());
match auth_header {
Some(header_value) => {
if !header_value.starts_with("Bearer ") {
tracing::debug!("Invalid Authorization header format");
return (
StatusCode::UNAUTHORIZED,
[(
header::WWW_AUTHENTICATE,
"Bearer error=\"invalid_request\"".to_string(),
)],
"Invalid Authorization header format",
)
.into_response();
}
Some(header_value[7..].to_owned())
},
None => extract_access_token_cookie(request.headers()),
}
};
match token_string {
None => {
if auth_state.validator.is_required() {
tracing::debug!("Authentication required but no token found (header or cookie)");
return (
StatusCode::UNAUTHORIZED,
[(
header::WWW_AUTHENTICATE,
format!("Bearer realm=\"{}\"", auth_state.validator.issuer()),
)],
"Authentication required",
)
.into_response();
}
next.run(request).await
},
Some(token) => {
match auth_state.validator.validate_token(&token).await {
Ok(user) => {
tracing::debug!(
user_id = %user.user_id,
scopes = ?user.scopes,
"User authenticated successfully"
);
request.extensions_mut().insert(AuthUser(user));
next.run(request).await
},
Err(e) => {
tracing::debug!(error = %e, "Token validation failed");
let (www_authenticate, body) = match &e {
fraiseql_core::security::SecurityError::TokenExpired { .. } => (
"Bearer error=\"invalid_token\", error_description=\"Token has expired\"",
"Token has expired",
),
fraiseql_core::security::SecurityError::InvalidToken => (
"Bearer error=\"invalid_token\", error_description=\"Token is invalid\"",
"Token is invalid",
),
_ => ("Bearer error=\"invalid_token\"", "Invalid or expired token"),
};
(
StatusCode::UNAUTHORIZED,
[(header::WWW_AUTHENTICATE, www_authenticate.to_string())],
body,
)
.into_response()
},
}
},
}
}
#[cfg(test)]
mod tests {
    // Test fixtures are hand-written literals; a panic on a bad fixture is
    // exactly the failure we want.
    #![allow(clippy::unwrap_used)]

    use super::*;

    /// Builds a header map containing one `Cookie` field with the given raw value.
    fn cookie_headers(raw: &str) -> axum::http::HeaderMap {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(header::COOKIE, raw.parse().unwrap());
        headers
    }

    #[test]
    fn test_auth_user_clone() {
        use chrono::Utc;

        let original = AuthUser(AuthenticatedUser {
            user_id: "user123".to_string(),
            scopes: vec!["read".to_string()],
            expires_at: Utc::now(),
            extra_claims: std::collections::HashMap::new(),
        });
        let copy = original.clone();
        assert_eq!(original.0.user_id, copy.0.user_id);
    }

    #[test]
    fn test_oidc_auth_state_clone() {
        fn requires_clone<T: Clone>() {}
        requires_clone::<OidcAuthState>();
    }

    #[test]
    fn test_cookie_fallback_extracts_token() {
        let headers =
            cookie_headers("__Host-access_token=my.jwt.token; Path=/; SameSite=Strict");
        assert_eq!(extract_access_token_cookie(&headers).as_deref(), Some("my.jwt.token"));
    }

    #[test]
    fn test_cookie_fallback_strips_rfc6265_quotes() {
        let headers = cookie_headers("__Host-access_token=\"my.jwt.token\"");
        assert_eq!(extract_access_token_cookie(&headers).as_deref(), Some("my.jwt.token"));
    }

    #[test]
    fn test_cookie_fallback_absent_returns_none() {
        let headers = cookie_headers("session=abc; other=xyz");
        assert!(extract_access_token_cookie(&headers).is_none());
    }

    #[test]
    fn test_cookie_fallback_no_cookie_header_returns_none() {
        let headers = axum::http::HeaderMap::new();
        assert!(extract_access_token_cookie(&headers).is_none());
    }

    #[test]
    fn test_cookie_fallback_multiple_cookies_finds_correct_one() {
        let headers = cookie_headers("session=abc; __Host-access_token=correct.token; csrf=xyz");
        assert_eq!(extract_access_token_cookie(&headers).as_deref(), Some("correct.token"));
    }
}