use std::sync::Arc;
use axum::{
body::Body,
extract::State,
http::{Request, StatusCode, header},
middleware::Next,
response::{IntoResponse, Response},
};
use fraiseql_core::security::{AuthenticatedUser, OidcValidator};
/// State for [`oidc_auth_middleware`], holding the shared OIDC token validator.
///
/// `Clone` only bumps the [`Arc`] reference count; every clone shares one validator.
#[derive(Clone)]
pub struct OidcAuthState {
    /// Validator used by the middleware to check tokens, to obtain the issuer for
    /// `WWW-Authenticate` challenges, and to decide whether requests without a token
    /// may proceed.
    pub validator: Arc<OidcValidator>,
}

impl OidcAuthState {
    /// Wraps an already-shared validator in middleware state.
    #[must_use]
    pub const fn new(validator: Arc<OidcValidator>) -> Self {
        OidcAuthState { validator }
    }
}
/// Authenticated user placed in the request extensions by [`oidc_auth_middleware`].
///
/// It is inserted only after a token has been validated successfully; requests that
/// proceed without a token carry no `AuthUser`.
#[derive(Debug, Clone)]
pub struct AuthUser(pub AuthenticatedUser);
/// Extracts the access token from the `__Host-access_token` cookie, if present.
///
/// Every `Cookie` header field is scanned, not only the first one. HTTP/2 clients may
/// split cookies across several `cookie` fields (RFC 9113 §8.2.3). Each field is split
/// on `;` into `name=value` pairs, and the first `__Host-access_token` pair with a
/// non-empty value wins. Surrounding double quotes are stripped from that value.
///
/// Returns `None` in any of these cases:
/// - there is no `Cookie` header;
/// - no header field is valid visible ASCII;
/// - no non-empty `__Host-access_token` cookie exists.
///
/// Empty values are skipped on purpose. A cleared cookie then counts as "no token"
/// rather than as an invalid one, so it does not cause a 401 on endpoints where
/// authentication is optional.
pub(crate) fn extract_access_token_cookie(headers: &axum::http::HeaderMap) -> Option<String> {
    const COOKIE_NAME_PREFIX: &str = "__Host-access_token=";

    headers
        .get_all(header::COOKIE)
        .iter()
        // Header fields that are not visible ASCII cannot hold a usable token; ignore them.
        .filter_map(|value| value.to_str().ok())
        .flat_map(|cookies| cookies.split(';'))
        .find_map(|pair| {
            let value = pair.trim().strip_prefix(COOKIE_NAME_PREFIX)?.trim_matches('"');
            (!value.is_empty()).then(|| value.to_owned())
        })
}
#[allow(clippy::cognitive_complexity)] pub async fn oidc_auth_middleware(
State(auth_state): State<OidcAuthState>,
mut request: Request<Body>,
next: Next,
) -> Response {
let token_string: Option<String> = {
let auth_header = request
.headers()
.get(header::AUTHORIZATION)
.and_then(|value| value.to_str().ok());
match auth_header {
Some(header_value) => {
if !header_value.starts_with("Bearer ") {
tracing::debug!("Invalid Authorization header format");
return (
StatusCode::UNAUTHORIZED,
[(
header::WWW_AUTHENTICATE,
"Bearer error=\"invalid_request\"".to_string(),
)],
"Invalid Authorization header format",
)
.into_response();
}
Some(header_value[7..].to_owned())
},
None => extract_access_token_cookie(request.headers()),
}
};
match token_string {
None => {
if auth_state.validator.is_required() {
tracing::debug!("Authentication required but no token found (header or cookie)");
return (
StatusCode::UNAUTHORIZED,
[(
header::WWW_AUTHENTICATE,
format!("Bearer realm=\"{}\"", auth_state.validator.issuer()),
)],
"Authentication required",
)
.into_response();
}
next.run(request).await
},
Some(token) => {
match auth_state.validator.validate_token(&token).await {
Ok(user) => {
tracing::debug!(
user_id = %user.user_id,
scopes = ?user.scopes,
"User authenticated successfully"
);
request.extensions_mut().insert(AuthUser(user));
next.run(request).await
},
Err(e) => {
tracing::debug!(error = %e, "Token validation failed");
let (www_authenticate, body) = match &e {
fraiseql_core::security::SecurityError::TokenExpired { .. } => (
"Bearer error=\"invalid_token\", error_description=\"Token has expired\"",
"Token has expired",
),
fraiseql_core::security::SecurityError::InvalidToken => (
"Bearer error=\"invalid_token\", error_description=\"Token is invalid\"",
"Token is invalid",
),
_ => ("Bearer error=\"invalid_token\"", "Invalid or expired token"),
};
(
StatusCode::UNAUTHORIZED,
[(header::WWW_AUTHENTICATE, www_authenticate.to_string())],
body,
)
.into_response()
},
}
},
}
}