// fraiseql_server/middleware/oidc_auth.rs

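//! OIDC Authentication Middleware
//!
//! Provides JWT authentication for GraphQL endpoints using OIDC discovery.
//! Supports Auth0, Keycloak, Okta, Cognito, Azure AD, and any OIDC-compliant provider.
//! Tokens are read from the `Authorization: Bearer` header or, as a fallback,
//! from the `__Host-access_token` cookie.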
5
6use std::sync::Arc;
7
8use axum::{
9    body::Body,
10    extract::State,
11    http::{Request, StatusCode, header},
12    middleware::Next,
13    response::{IntoResponse, Response},
14};
15use fraiseql_core::security::{AuthenticatedUser, OidcValidator};
16
17/// State for OIDC authentication middleware.
18#[derive(Clone)]
19pub struct OidcAuthState {
20    /// The OIDC validator.
21    pub validator: Arc<OidcValidator>,
22}
23
24impl OidcAuthState {
25    /// Create new OIDC auth state.
26    #[must_use]
27    pub const fn new(validator: Arc<OidcValidator>) -> Self {
28        Self { validator }
29    }
30}
31
32/// Request extension containing the authenticated user.
33///
34/// After authentication middleware runs, handlers can extract this
35/// to access the authenticated user information.
36#[derive(Clone, Debug)]
37pub struct AuthUser(pub AuthenticatedUser);
38
39/// Extract the bearer token from a raw `Cookie` header value.
40///
41/// Looks for `__Host-access_token=<value>` in the semicolon-separated cookie
42/// string and returns the token value, stripping RFC 6265 double-quotes if
43/// present.  Returns `None` if the cookie is absent.
44///
45/// This is used as a fallback by [`oidc_auth_middleware`] when no
46/// `Authorization: Bearer` header is present, to support browser flows where
47/// the JWT is stored in an `HttpOnly` cookie inaccessible to client-side script.
48fn extract_access_token_cookie(headers: &axum::http::HeaderMap) -> Option<String> {
49    headers.get(header::COOKIE).and_then(|v| v.to_str().ok()).and_then(|cookies| {
50        cookies.split(';').find_map(|part| {
51            let part = part.trim();
52            part.strip_prefix("__Host-access_token=")
53                .map(|v| v.trim_matches('"').to_owned())
54        })
55    })
56}
57
58/// OIDC authentication middleware.
59///
60/// Validates JWT tokens from the `Authorization: Bearer` header using
61/// OIDC/JWKS.  When no `Authorization` header is present, falls back to the
62/// `__Host-access_token` `HttpOnly` cookie set by the PKCE callback.
63///
64/// # Behavior
65///
66/// - If auth is required and no token (header or cookie): returns 401 Unauthorized
67/// - If token is invalid/expired: returns 401 Unauthorized
68/// - If token is valid: adds `AuthUser` to request extensions
69/// - If auth is optional and no token: allows request through (no `AuthUser`)
70///
71/// # Example
72///
73/// ```text
74/// // Requires: OIDC provider reachable for JWKS discovery, running Axum application.
75/// use axum::{middleware, Router};
76///
77/// let oidc_state = OidcAuthState::new(validator);
78/// let app = Router::new()
79///     .route("/graphql", post(graphql_handler))
80///     .layer(middleware::from_fn_with_state(oidc_state, oidc_auth_middleware));
81/// ```
82#[allow(clippy::cognitive_complexity)] // Reason: OIDC authentication middleware with token parsing, validation, and claims extraction
83pub async fn oidc_auth_middleware(
84    State(auth_state): State<OidcAuthState>,
85    mut request: Request<Body>,
86    next: Next,
87) -> Response {
88    // Prefer Authorization: Bearer header; fall back to __Host-access_token cookie.
89    // The token is extracted as an owned String to avoid borrow conflicts with
90    // request.extensions_mut() later in this function.
91    let token_string: Option<String> = {
92        let auth_header = request
93            .headers()
94            .get(header::AUTHORIZATION)
95            .and_then(|value| value.to_str().ok());
96
97        match auth_header {
98            Some(header_value) => {
99                if !header_value.starts_with("Bearer ") {
100                    tracing::debug!("Invalid Authorization header format");
101                    return (
102                        StatusCode::UNAUTHORIZED,
103                        [(
104                            header::WWW_AUTHENTICATE,
105                            "Bearer error=\"invalid_request\"".to_string(),
106                        )],
107                        "Invalid Authorization header format",
108                    )
109                        .into_response();
110                }
111                Some(header_value[7..].to_owned())
112            },
113            None => extract_access_token_cookie(request.headers()),
114        }
115    };
116
117    match token_string {
118        None => {
119            if auth_state.validator.is_required() {
120                tracing::debug!("Authentication required but no token found (header or cookie)");
121                return (
122                    StatusCode::UNAUTHORIZED,
123                    [(
124                        header::WWW_AUTHENTICATE,
125                        format!("Bearer realm=\"{}\"", auth_state.validator.issuer()),
126                    )],
127                    "Authentication required",
128                )
129                    .into_response();
130            }
131            // Auth is optional, continue without user context
132            next.run(request).await
133        },
134        Some(token) => {
135            // Validate token
136            match auth_state.validator.validate_token(&token).await {
137                Ok(user) => {
138                    tracing::debug!(
139                        user_id = %user.user_id,
140                        scopes = ?user.scopes,
141                        "User authenticated successfully"
142                    );
143                    // Add authenticated user to request extensions
144                    request.extensions_mut().insert(AuthUser(user));
145                    next.run(request).await
146                },
147                Err(e) => {
148                    tracing::debug!(error = %e, "Token validation failed");
149                    let (www_authenticate, body) = match &e {
150                        fraiseql_core::security::SecurityError::TokenExpired { .. } => (
151                            "Bearer error=\"invalid_token\", error_description=\"Token has expired\"",
152                            "Token has expired",
153                        ),
154                        fraiseql_core::security::SecurityError::InvalidToken => (
155                            "Bearer error=\"invalid_token\", error_description=\"Token is invalid\"",
156                            "Token is invalid",
157                        ),
158                        _ => ("Bearer error=\"invalid_token\"", "Invalid or expired token"),
159                    };
160                    (
161                        StatusCode::UNAUTHORIZED,
162                        [(header::WWW_AUTHENTICATE, www_authenticate.to_string())],
163                        body,
164                    )
165                        .into_response()
166                },
167            }
168        },
169    }
170}
171
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable

    use super::*;

    /// Build a header map holding one `Cookie` header whose value is `raw`.
    fn cookie_headers(raw: &str) -> axum::http::HeaderMap {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(header::COOKIE, raw.parse().unwrap());
        headers
    }

    #[test]
    fn test_auth_user_clone() {
        use chrono::Utc;

        let original = AuthUser(AuthenticatedUser {
            user_id:      "user123".to_string(),
            scopes:       vec!["read".to_string()],
            expires_at:   Utc::now(),
            extra_claims: std::collections::HashMap::new(),
        });
        let copy = original.clone();

        assert_eq!(original.0.user_id, copy.0.user_id);
    }

    #[test]
    fn test_oidc_auth_state_clone() {
        // A real validator needs a reachable OIDC provider, so this only
        // proves at compile time that the state type implements `Clone`.
        fn assert_clone<T: Clone>() {}
        assert_clone::<OidcAuthState>();
    }

    #[test]
    fn test_cookie_fallback_extracts_token() {
        let headers =
            cookie_headers("__Host-access_token=my.jwt.token; Path=/; SameSite=Strict");
        assert_eq!(extract_access_token_cookie(&headers).as_deref(), Some("my.jwt.token"));
    }

    #[test]
    fn test_cookie_fallback_strips_rfc6265_quotes() {
        let headers = cookie_headers("__Host-access_token=\"my.jwt.token\"");
        assert_eq!(extract_access_token_cookie(&headers).as_deref(), Some("my.jwt.token"));
    }

    #[test]
    fn test_cookie_fallback_absent_returns_none() {
        let headers = cookie_headers("session=abc; other=xyz");
        assert!(extract_access_token_cookie(&headers).is_none());
    }

    #[test]
    fn test_cookie_fallback_no_cookie_header_returns_none() {
        let empty = axum::http::HeaderMap::new();
        assert!(extract_access_token_cookie(&empty).is_none());
    }

    #[test]
    fn test_cookie_fallback_multiple_cookies_finds_correct_one() {
        let headers = cookie_headers("session=abc; __Host-access_token=correct.token; csrf=xyz");
        assert_eq!(extract_access_token_cookie(&headers).as_deref(), Some("correct.token"));
    }
}