Skip to main content

fraiseql_server/middleware/
oidc_auth.rs

1//! OIDC Authentication Middleware
2//!
3//! Provides JWT authentication for GraphQL endpoints using OIDC discovery.
4//! Supports Auth0, Keycloak, Okta, Cognito, Azure AD, and any OIDC-compliant provider.
5
6use std::sync::Arc;
7
8use axum::{
9    body::Body,
10    extract::State,
11    http::{Request, StatusCode, header},
12    middleware::Next,
13    response::{IntoResponse, Response},
14};
15use fraiseql_core::security::{AuthenticatedUser, OidcValidator};
16
17/// State for OIDC authentication middleware.
18#[derive(Clone)]
19pub struct OidcAuthState {
20    /// The OIDC validator.
21    pub validator: Arc<OidcValidator>,
22}
23
24impl OidcAuthState {
25    /// Create new OIDC auth state.
26    #[must_use]
27    pub const fn new(validator: Arc<OidcValidator>) -> Self {
28        Self { validator }
29    }
30}
31
32/// Request extension containing the authenticated user.
33///
34/// After authentication middleware runs, handlers can extract this
35/// to access the authenticated user information.
36#[derive(Clone, Debug)]
37pub struct AuthUser(pub AuthenticatedUser);
38
39/// OIDC authentication middleware.
40///
41/// Validates JWT tokens from the Authorization header using OIDC/JWKS.
42///
43/// # Behavior
44///
45/// - If auth is required and no token: returns 401 Unauthorized
46/// - If token is invalid/expired: returns 401 Unauthorized
47/// - If token is valid: adds `AuthUser` to request extensions
48/// - If auth is optional and no token: allows request through (no `AuthUser`)
49///
50/// # Example
51///
52/// ```text
53/// // Requires: OIDC provider reachable for JWKS discovery, running Axum application.
54/// use axum::{middleware, Router};
55///
56/// let oidc_state = OidcAuthState::new(validator);
57/// let app = Router::new()
58///     .route("/graphql", post(graphql_handler))
59///     .layer(middleware::from_fn_with_state(oidc_state, oidc_auth_middleware));
60/// ```
61#[allow(clippy::cognitive_complexity)] // Reason: OIDC authentication middleware with token parsing, validation, and claims extraction
62pub async fn oidc_auth_middleware(
63    State(auth_state): State<OidcAuthState>,
64    mut request: Request<Body>,
65    next: Next,
66) -> Response {
67    // Extract Authorization header
68    let auth_header = request
69        .headers()
70        .get(header::AUTHORIZATION)
71        .and_then(|value| value.to_str().ok());
72
73    match auth_header {
74        None => {
75            // No authorization header
76            if auth_state.validator.is_required() {
77                tracing::debug!("Authentication required but no Authorization header");
78                return (
79                    StatusCode::UNAUTHORIZED,
80                    [(
81                        header::WWW_AUTHENTICATE,
82                        format!("Bearer realm=\"{}\"", auth_state.validator.issuer()),
83                    )],
84                    "Authentication required",
85                )
86                    .into_response();
87            }
88            // Auth is optional, continue without user context
89            next.run(request).await
90        },
91        Some(header_value) => {
92            // Extract bearer token
93            if !header_value.starts_with("Bearer ") {
94                tracing::debug!("Invalid Authorization header format");
95                return (
96                    StatusCode::UNAUTHORIZED,
97                    [(header::WWW_AUTHENTICATE, "Bearer error=\"invalid_request\"".to_string())],
98                    "Invalid Authorization header format",
99                )
100                    .into_response();
101            }
102
103            let token = &header_value[7..];
104
105            // Validate token
106            match auth_state.validator.validate_token(token).await {
107                Ok(user) => {
108                    tracing::debug!(
109                        user_id = %user.user_id,
110                        scopes = ?user.scopes,
111                        "User authenticated successfully"
112                    );
113                    // Add authenticated user to request extensions
114                    request.extensions_mut().insert(AuthUser(user));
115                    next.run(request).await
116                },
117                Err(e) => {
118                    tracing::debug!(error = %e, "Token validation failed");
119                    let (www_authenticate, body) = match &e {
120                        fraiseql_core::security::SecurityError::TokenExpired { .. } => (
121                            "Bearer error=\"invalid_token\", error_description=\"Token has expired\"",
122                            "Token has expired",
123                        ),
124                        fraiseql_core::security::SecurityError::InvalidToken => (
125                            "Bearer error=\"invalid_token\", error_description=\"Token is invalid\"",
126                            "Token is invalid",
127                        ),
128                        _ => ("Bearer error=\"invalid_token\"", "Invalid or expired token"),
129                    };
130                    (
131                        StatusCode::UNAUTHORIZED,
132                        [(header::WWW_AUTHENTICATE, www_authenticate.to_string())],
133                        body,
134                    )
135                        .into_response()
136                },
137            }
138        },
139    }
140}
141
#[cfg(test)]
mod tests {
    use chrono::Utc;

    use super::*;

    /// A cloned `AuthUser` must still identify the same user.
    #[test]
    fn test_auth_user_clone() {
        let original = AuthUser(AuthenticatedUser {
            user_id:    "user123".to_string(),
            scopes:     vec!["read".to_string()],
            expires_at: Utc::now(),
        });

        let duplicate = original.clone();

        assert_eq!(original.0.user_id, duplicate.0.user_id);
    }

    /// Compile-time check that `OidcAuthState` implements `Clone`, which axum's
    /// `from_fn_with_state` needs. A real validator is not required because the
    /// check only instantiates a trait bound.
    #[test]
    fn test_oidc_auth_state_clone() {
        fn requires_clone<S: Clone>() {}
        requires_clone::<OidcAuthState>();
    }
}