Skip to main content

fraiseql_server/middleware/
oidc_auth.rs

1//! OIDC Authentication Middleware
2//!
3//! Provides JWT authentication for GraphQL endpoints using OIDC discovery.
4//! Supports Auth0, Keycloak, Okta, Cognito, Azure AD, and any OIDC-compliant provider.
5
6use std::sync::Arc;
7
8use axum::{
9    body::Body,
10    extract::State,
11    http::{Request, StatusCode, header},
12    middleware::Next,
13    response::{IntoResponse, Response},
14};
15use fraiseql_core::security::{AuthenticatedUser, OidcValidator};
16
17/// State for OIDC authentication middleware.
18#[derive(Clone)]
19pub struct OidcAuthState {
20    /// The OIDC validator.
21    pub validator: Arc<OidcValidator>,
22}
23
24impl OidcAuthState {
25    /// Create new OIDC auth state.
26    #[must_use]
27    pub const fn new(validator: Arc<OidcValidator>) -> Self {
28        Self { validator }
29    }
30}
31
32/// Request extension containing the authenticated user.
33///
34/// After authentication middleware runs, handlers can extract this
35/// to access the authenticated user information.
36#[derive(Clone, Debug)]
37pub struct AuthUser(pub AuthenticatedUser);
38
39/// Extract the bearer token from a raw `Cookie` header value.
40///
41/// Looks for `__Host-access_token=<value>` in the semicolon-separated cookie
42/// string and returns the token value, stripping RFC 6265 double-quotes if
43/// present.  Returns `None` if the cookie is absent.
44///
45/// This is used as a fallback by [`oidc_auth_middleware`] when no
46/// `Authorization: Bearer` header is present, to support browser flows where
47/// the JWT is stored in an `HttpOnly` cookie inaccessible to client-side script.
48fn extract_access_token_cookie(headers: &axum::http::HeaderMap) -> Option<String> {
49    headers
50        .get(header::COOKIE)
51        .and_then(|v| v.to_str().ok())
52        .and_then(|cookies| {
53            cookies.split(';').find_map(|part| {
54                let part = part.trim();
55                part.strip_prefix("__Host-access_token=")
56                    .map(|v| v.trim_matches('"').to_owned())
57            })
58        })
59}
60
61/// OIDC authentication middleware.
62///
63/// Validates JWT tokens from the `Authorization: Bearer` header using
64/// OIDC/JWKS.  When no `Authorization` header is present, falls back to the
65/// `__Host-access_token` `HttpOnly` cookie set by the PKCE callback.
66///
67/// # Behavior
68///
69/// - If auth is required and no token (header or cookie): returns 401 Unauthorized
70/// - If token is invalid/expired: returns 401 Unauthorized
71/// - If token is valid: adds `AuthUser` to request extensions
72/// - If auth is optional and no token: allows request through (no `AuthUser`)
73///
74/// # Example
75///
76/// ```text
77/// // Requires: OIDC provider reachable for JWKS discovery, running Axum application.
78/// use axum::{middleware, Router};
79///
80/// let oidc_state = OidcAuthState::new(validator);
81/// let app = Router::new()
82///     .route("/graphql", post(graphql_handler))
83///     .layer(middleware::from_fn_with_state(oidc_state, oidc_auth_middleware));
84/// ```
85#[allow(clippy::cognitive_complexity)] // Reason: OIDC authentication middleware with token parsing, validation, and claims extraction
86pub async fn oidc_auth_middleware(
87    State(auth_state): State<OidcAuthState>,
88    mut request: Request<Body>,
89    next: Next,
90) -> Response {
91    // Prefer Authorization: Bearer header; fall back to __Host-access_token cookie.
92    // The token is extracted as an owned String to avoid borrow conflicts with
93    // request.extensions_mut() later in this function.
94    let token_string: Option<String> = {
95        let auth_header = request
96            .headers()
97            .get(header::AUTHORIZATION)
98            .and_then(|value| value.to_str().ok());
99
100        match auth_header {
101            Some(header_value) => {
102                if !header_value.starts_with("Bearer ") {
103                    tracing::debug!("Invalid Authorization header format");
104                    return (
105                        StatusCode::UNAUTHORIZED,
106                        [(
107                            header::WWW_AUTHENTICATE,
108                            "Bearer error=\"invalid_request\"".to_string(),
109                        )],
110                        "Invalid Authorization header format",
111                    )
112                        .into_response();
113                }
114                Some(header_value[7..].to_owned())
115            },
116            None => extract_access_token_cookie(request.headers()),
117        }
118    };
119
120    match token_string {
121        None => {
122            if auth_state.validator.is_required() {
123                tracing::debug!("Authentication required but no token found (header or cookie)");
124                return (
125                    StatusCode::UNAUTHORIZED,
126                    [(
127                        header::WWW_AUTHENTICATE,
128                        format!("Bearer realm=\"{}\"", auth_state.validator.issuer()),
129                    )],
130                    "Authentication required",
131                )
132                    .into_response();
133            }
134            // Auth is optional, continue without user context
135            next.run(request).await
136        },
137        Some(token) => {
138            // Validate token
139            match auth_state.validator.validate_token(&token).await {
140                Ok(user) => {
141                    tracing::debug!(
142                        user_id = %user.user_id,
143                        scopes = ?user.scopes,
144                        "User authenticated successfully"
145                    );
146                    // Add authenticated user to request extensions
147                    request.extensions_mut().insert(AuthUser(user));
148                    next.run(request).await
149                },
150                Err(e) => {
151                    tracing::debug!(error = %e, "Token validation failed");
152                    let (www_authenticate, body) = match &e {
153                        fraiseql_core::security::SecurityError::TokenExpired { .. } => (
154                            "Bearer error=\"invalid_token\", error_description=\"Token has expired\"",
155                            "Token has expired",
156                        ),
157                        fraiseql_core::security::SecurityError::InvalidToken => (
158                            "Bearer error=\"invalid_token\", error_description=\"Token is invalid\"",
159                            "Token is invalid",
160                        ),
161                        _ => ("Bearer error=\"invalid_token\"", "Invalid or expired token"),
162                    };
163                    (
164                        StatusCode::UNAUTHORIZED,
165                        [(header::WWW_AUTHENTICATE, www_authenticate.to_string())],
166                        body,
167                    )
168                        .into_response()
169                },
170            }
171        },
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a header map holding a single `Cookie` header with `raw` as its value.
    fn cookie_headers(raw: &'static str) -> axum::http::HeaderMap {
        let mut map = axum::http::HeaderMap::new();
        map.insert(header::COOKIE, axum::http::HeaderValue::from_static(raw));
        map
    }

    #[test]
    fn test_auth_user_clone() {
        let original = AuthUser(AuthenticatedUser {
            user_id:      "user123".to_string(),
            scopes:       vec!["read".to_string()],
            expires_at:   chrono::Utc::now(),
            extra_claims: std::collections::HashMap::new(),
        });

        let copy = original.clone();

        assert_eq!(copy.0.user_id, original.0.user_id);
    }

    #[test]
    fn test_oidc_auth_state_clone() {
        // Building a real validator needs a reachable provider; a compile-time
        // Clone bound is enough to prove the derive is in place.
        fn requires_clone<T: Clone>() {}
        requires_clone::<OidcAuthState>();
    }

    #[test]
    fn test_cookie_fallback_extracts_token() {
        let headers = cookie_headers("__Host-access_token=my.jwt.token; Path=/; SameSite=Strict");
        assert_eq!(extract_access_token_cookie(&headers).as_deref(), Some("my.jwt.token"));
    }

    #[test]
    fn test_cookie_fallback_strips_rfc6265_quotes() {
        let headers = cookie_headers("__Host-access_token=\"my.jwt.token\"");
        assert_eq!(extract_access_token_cookie(&headers).as_deref(), Some("my.jwt.token"));
    }

    #[test]
    fn test_cookie_fallback_absent_returns_none() {
        let headers = cookie_headers("session=abc; other=xyz");
        assert!(extract_access_token_cookie(&headers).is_none());
    }

    #[test]
    fn test_cookie_fallback_no_cookie_header_returns_none() {
        let empty = axum::http::HeaderMap::new();
        assert!(extract_access_token_cookie(&empty).is_none());
    }

    #[test]
    fn test_cookie_fallback_multiple_cookies_finds_correct_one() {
        let headers = cookie_headers("session=abc; __Host-access_token=correct.token; csrf=xyz");
        assert_eq!(extract_access_token_cookie(&headers).as_deref(), Some("correct.token"));
    }
}
253
254        let token = extract_access_token_cookie(&headers);
255        assert_eq!(token.as_deref(), Some("correct.token"));
256    }
257}