use std::sync::Arc;
use axum::{
body::Body,
extract::State,
http::{Request, StatusCode, header},
middleware::Next,
response::{IntoResponse, Response},
};
use fraiseql_core::security::{AuthenticatedUser, OidcValidator};
/// Shared state for [`oidc_auth_middleware`].
///
/// Holds the OIDC validator that the middleware uses to check bearer tokens and to
/// ask whether authentication is mandatory. Cloning is cheap: the only field is an
/// [`Arc`], so a clone just bumps the reference count.
#[derive(Clone)]
pub struct OidcAuthState {
    /// Validator consulted for every request that passes through the middleware.
    pub validator: Arc<OidcValidator>,
}

impl OidcAuthState {
    /// Wraps an already-configured validator in middleware state.
    #[must_use]
    pub fn new(validator: Arc<OidcValidator>) -> Self {
        OidcAuthState { validator }
    }
}
/// Identity of a successfully authenticated caller.
///
/// [`oidc_auth_middleware`] inserts this into the request extensions after the
/// bearer token has been validated, so downstream handlers can read it from there.
#[derive(Debug, Clone)]
pub struct AuthUser(pub AuthenticatedUser);
/// Builds a `401 Unauthorized` response carrying a `WWW-Authenticate` challenge.
///
/// `challenge` is used verbatim as the header value; `message` becomes the body.
fn unauthorized(challenge: String, message: &'static str) -> Response {
    (
        StatusCode::UNAUTHORIZED,
        [(header::WWW_AUTHENTICATE, challenge)],
        message,
    )
        .into_response()
}

/// The `401 invalid_request` response used for any malformed `Authorization` header.
fn invalid_request() -> Response {
    unauthorized(
        "Bearer error=\"invalid_request\"".to_string(),
        "Invalid Authorization header format",
    )
}

/// Extracts the credential from an `Authorization` header value that uses the
/// `Bearer` scheme.
///
/// The scheme name is matched case-insensitively (RFC 7235 §2.1). Any run of
/// spaces between the scheme and the token is accepted (RFC 6750 §2.1:
/// `"Bearer" 1*SP b64token`).
///
/// Returns `None` if there is no space, the scheme is not `Bearer`, or the token
/// is empty.
fn bearer_token(header_value: &str) -> Option<&str> {
    let (scheme, rest) = header_value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = rest.trim();
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}

/// Axum middleware that authenticates requests with an OIDC bearer token.
///
/// Outcomes:
/// - **No `Authorization` header:**
///   - If `validator.is_required()`, responds `401` with the challenge
///     `Bearer realm="<issuer>"`.
///   - Otherwise forwards the request without an [`AuthUser`].
/// - **Malformed header:** responds `401` with `error="invalid_request"`. A header
///   is malformed if it is not visible ASCII, does not use the `Bearer` scheme, or
///   carries an empty token.
/// - **Token accepted by the validator:** inserts [`AuthUser`] into the request
///   extensions and forwards the request.
/// - **Token rejected by the validator:** responds `401` with
///   `error="invalid_token"`, plus a description for expired or invalid tokens.
pub async fn oidc_auth_middleware(
    State(auth_state): State<OidcAuthState>,
    mut request: Request<Body>,
    next: Next,
) -> Response {
    // Distinguish "no header" from "header present but not visible ASCII". The
    // latter is a malformed credential and must not be downgraded to anonymous.
    let auth_header = match request.headers().get(header::AUTHORIZATION) {
        None => None,
        Some(value) => match value.to_str() {
            Ok(text) => Some(text),
            Err(_) => {
                tracing::debug!("Authorization header is not valid visible ASCII");
                return invalid_request();
            },
        },
    };

    match auth_header {
        None => {
            if auth_state.validator.is_required() {
                tracing::debug!("Authentication required but no Authorization header");
                return unauthorized(
                    format!("Bearer realm=\"{}\"", auth_state.validator.issuer()),
                    "Authentication required",
                );
            }
            next.run(request).await
        },
        Some(header_value) => {
            let token = match bearer_token(header_value) {
                Some(token) => token,
                None => {
                    tracing::debug!("Invalid Authorization header format");
                    return invalid_request();
                },
            };
            match auth_state.validator.validate_token(token).await {
                Ok(user) => {
                    tracing::debug!(
                        user_id = %user.user_id,
                        scopes = ?user.scopes,
                        "User authenticated successfully"
                    );
                    request.extensions_mut().insert(AuthUser(user));
                    next.run(request).await
                },
                Err(e) => {
                    tracing::debug!(error = %e, "Token validation failed");
                    let challenge = match &e {
                        fraiseql_core::security::SecurityError::TokenExpired { .. } => {
                            "Bearer error=\"invalid_token\", error_description=\"Token has expired\""
                        },
                        fraiseql_core::security::SecurityError::InvalidToken => {
                            "Bearer error=\"invalid_token\", error_description=\"Token is invalid\""
                        },
                        _ => "Bearer error=\"invalid_token\"",
                    };
                    unauthorized(challenge.to_string(), "Invalid or expired token")
                },
            }
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Cloning an `AuthUser` must carry the wrapped user's identity along.
    #[test]
    fn test_auth_user_clone() {
        use chrono::Utc;

        let original = AuthUser(AuthenticatedUser {
            user_id: String::from("user123"),
            scopes: vec![String::from("read")],
            expires_at: Utc::now(),
        });
        let copy = original.clone();

        assert_eq!(original.0.user_id, copy.0.user_id);
    }

    /// Compile-time check that `OidcAuthState` implements `Clone`.
    #[test]
    fn test_oidc_auth_state_clone() {
        fn requires_clone<T: Clone>() {}
        requires_clone::<OidcAuthState>();
    }
}