Skip to main content

fraiseql_server/middleware/
oidc_auth.rs

1//! OIDC Authentication Middleware
2//!
3//! Provides JWT authentication for GraphQL endpoints using OIDC discovery.
4//! Supports Auth0, Keycloak, Okta, Cognito, Azure AD, and any OIDC-compliant provider.
5
6use std::sync::Arc;
7
8use axum::{
9    body::Body,
10    extract::State,
11    http::{Request, StatusCode, header},
12    middleware::Next,
13    response::{IntoResponse, Response},
14};
15use fraiseql_core::security::{AuthenticatedUser, OidcValidator};
16
17/// State for OIDC authentication middleware.
18#[derive(Clone)]
19pub struct OidcAuthState {
20    /// The OIDC validator.
21    pub validator: Arc<OidcValidator>,
22}
23
24impl OidcAuthState {
25    /// Create new OIDC auth state.
26    #[must_use]
27    pub fn new(validator: Arc<OidcValidator>) -> Self {
28        Self { validator }
29    }
30}
31
32/// Request extension containing the authenticated user.
33///
34/// After authentication middleware runs, handlers can extract this
35/// to access the authenticated user information.
36#[derive(Clone, Debug)]
37pub struct AuthUser(pub AuthenticatedUser);
38
39/// OIDC authentication middleware.
40///
41/// Validates JWT tokens from the Authorization header using OIDC/JWKS.
42///
43/// # Behavior
44///
45/// - If auth is required and no token: returns 401 Unauthorized
46/// - If token is invalid/expired: returns 401 Unauthorized
47/// - If token is valid: adds `AuthUser` to request extensions
48/// - If auth is optional and no token: allows request through (no AuthUser)
49///
50/// # Example
51///
52/// ```ignore
53/// use axum::{middleware, Router};
54///
55/// let oidc_state = OidcAuthState::new(validator);
56/// let app = Router::new()
57///     .route("/graphql", post(graphql_handler))
58///     .layer(middleware::from_fn_with_state(oidc_state, oidc_auth_middleware));
59/// ```
60pub async fn oidc_auth_middleware(
61    State(auth_state): State<OidcAuthState>,
62    mut request: Request<Body>,
63    next: Next,
64) -> Response {
65    // Extract Authorization header
66    let auth_header = request
67        .headers()
68        .get(header::AUTHORIZATION)
69        .and_then(|value| value.to_str().ok());
70
71    match auth_header {
72        None => {
73            // No authorization header
74            if auth_state.validator.is_required() {
75                tracing::debug!("Authentication required but no Authorization header");
76                return (
77                    StatusCode::UNAUTHORIZED,
78                    [(
79                        header::WWW_AUTHENTICATE,
80                        format!("Bearer realm=\"{}\"", auth_state.validator.issuer()),
81                    )],
82                    "Authentication required",
83                )
84                    .into_response();
85            }
86            // Auth is optional, continue without user context
87            next.run(request).await
88        },
89        Some(header_value) => {
90            // Extract bearer token
91            if !header_value.starts_with("Bearer ") {
92                tracing::debug!("Invalid Authorization header format");
93                return (
94                    StatusCode::UNAUTHORIZED,
95                    [(header::WWW_AUTHENTICATE, "Bearer error=\"invalid_request\"".to_string())],
96                    "Invalid Authorization header format",
97                )
98                    .into_response();
99            }
100
101            let token = &header_value[7..];
102
103            // Validate token
104            match auth_state.validator.validate_token(token).await {
105                Ok(user) => {
106                    tracing::debug!(
107                        user_id = %user.user_id,
108                        scopes = ?user.scopes,
109                        "User authenticated successfully"
110                    );
111                    // Add authenticated user to request extensions
112                    request.extensions_mut().insert(AuthUser(user));
113                    next.run(request).await
114                },
115                Err(e) => {
116                    tracing::debug!(error = %e, "Token validation failed");
117                    let error_description = match &e {
118                        fraiseql_core::security::SecurityError::TokenExpired { .. } => {
119                            "Bearer error=\"invalid_token\", error_description=\"Token has expired\""
120                        },
121                        fraiseql_core::security::SecurityError::InvalidToken => {
122                            "Bearer error=\"invalid_token\", error_description=\"Token is invalid\""
123                        },
124                        _ => "Bearer error=\"invalid_token\"",
125                    };
126                    (
127                        StatusCode::UNAUTHORIZED,
128                        [(header::WWW_AUTHENTICATE, error_description.to_string())],
129                        "Invalid or expired token",
130                    )
131                        .into_response()
132                },
133            }
134        },
135    }
136}
137
#[cfg(test)]
mod tests {
    use chrono::Utc;

    use super::*;

    #[test]
    fn test_auth_user_clone() {
        // A clone must carry the same user identity as the original.
        let auth_user = AuthUser(AuthenticatedUser {
            user_id:    "user123".to_string(),
            scopes:     vec!["read".to_string()],
            expires_at: Utc::now(),
        });

        let cloned = auth_user.clone();

        assert_eq!(cloned.0.user_id, auth_user.0.user_id);
    }

    #[test]
    fn test_oidc_auth_state_clone() {
        // Constructing a real `OidcValidator` is impractical in a unit test, so
        // this only checks at compile time that `OidcAuthState: Clone`.
        fn requires_clone<T: Clone>() {}
        requires_clone::<OidcAuthState>();
    }
}