fraiseql_server/middleware/
oidc_auth.rs1use std::sync::Arc;
7
8use axum::{
9 body::Body,
10 extract::State,
11 http::{Request, StatusCode, header},
12 middleware::Next,
13 response::{IntoResponse, Response},
14};
15use fraiseql_core::security::{AuthenticatedUser, OidcValidator};
16
17#[derive(Clone)]
19pub struct OidcAuthState {
20 pub validator: Arc<OidcValidator>,
22}
23
24impl OidcAuthState {
25 #[must_use]
27 pub const fn new(validator: Arc<OidcValidator>) -> Self {
28 Self { validator }
29 }
30}
31
32#[derive(Clone, Debug)]
37pub struct AuthUser(pub AuthenticatedUser);
38
39fn extract_access_token_cookie(headers: &axum::http::HeaderMap) -> Option<String> {
49 headers.get(header::COOKIE).and_then(|v| v.to_str().ok()).and_then(|cookies| {
50 cookies.split(';').find_map(|part| {
51 let part = part.trim();
52 part.strip_prefix("__Host-access_token=")
53 .map(|v| v.trim_matches('"').to_owned())
54 })
55 })
56}
57
58#[allow(clippy::cognitive_complexity)] pub async fn oidc_auth_middleware(
84 State(auth_state): State<OidcAuthState>,
85 mut request: Request<Body>,
86 next: Next,
87) -> Response {
88 let token_string: Option<String> = {
92 let auth_header = request
93 .headers()
94 .get(header::AUTHORIZATION)
95 .and_then(|value| value.to_str().ok());
96
97 match auth_header {
98 Some(header_value) => {
99 if !header_value.starts_with("Bearer ") {
100 tracing::debug!("Invalid Authorization header format");
101 return (
102 StatusCode::UNAUTHORIZED,
103 [(
104 header::WWW_AUTHENTICATE,
105 "Bearer error=\"invalid_request\"".to_string(),
106 )],
107 "Invalid Authorization header format",
108 )
109 .into_response();
110 }
111 Some(header_value[7..].to_owned())
112 },
113 None => extract_access_token_cookie(request.headers()),
114 }
115 };
116
117 match token_string {
118 None => {
119 if auth_state.validator.is_required() {
120 tracing::debug!("Authentication required but no token found (header or cookie)");
121 return (
122 StatusCode::UNAUTHORIZED,
123 [(
124 header::WWW_AUTHENTICATE,
125 format!("Bearer realm=\"{}\"", auth_state.validator.issuer()),
126 )],
127 "Authentication required",
128 )
129 .into_response();
130 }
131 next.run(request).await
133 },
134 Some(token) => {
135 match auth_state.validator.validate_token(&token).await {
137 Ok(user) => {
138 tracing::debug!(
139 user_id = %user.user_id,
140 scopes = ?user.scopes,
141 "User authenticated successfully"
142 );
143 request.extensions_mut().insert(AuthUser(user));
145 next.run(request).await
146 },
147 Err(e) => {
148 tracing::debug!(error = %e, "Token validation failed");
149 let (www_authenticate, body) = match &e {
150 fraiseql_core::security::SecurityError::TokenExpired { .. } => (
151 "Bearer error=\"invalid_token\", error_description=\"Token has expired\"",
152 "Token has expired",
153 ),
154 fraiseql_core::security::SecurityError::InvalidToken => (
155 "Bearer error=\"invalid_token\", error_description=\"Token is invalid\"",
156 "Token is invalid",
157 ),
158 _ => ("Bearer error=\"invalid_token\"", "Invalid or expired token"),
159 };
160 (
161 StatusCode::UNAUTHORIZED,
162 [(header::WWW_AUTHENTICATE, www_authenticate.to_string())],
163 body,
164 )
165 .into_response()
166 },
167 }
168 },
169 }
170}
171
#[cfg(test)]
mod tests {
    // Reason: panicking on a malformed fixture is the desired failure mode in tests.
    #![allow(clippy::unwrap_used)]

    use super::*;

    /// Builds a header map holding a single `Cookie` header with the given raw value.
    fn headers_with_cookie(raw: &str) -> axum::http::HeaderMap {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(header::COOKIE, raw.parse().unwrap());
        headers
    }

    #[test]
    fn test_auth_user_clone() {
        let original = AuthUser(AuthenticatedUser {
            user_id: "user123".to_string(),
            scopes: vec!["read".to_string()],
            expires_at: chrono::Utc::now(),
            extra_claims: std::collections::HashMap::new(),
        });

        let duplicate = original.clone();

        assert_eq!(original.0.user_id, duplicate.0.user_id);
    }

    #[test]
    fn test_oidc_auth_state_clone() {
        // Compile-time check: instantiating the generic proves the `Clone` bound holds.
        fn assert_clone<T>()
        where
            T: Clone,
        {
        }
        assert_clone::<OidcAuthState>();
    }

    #[test]
    fn test_cookie_fallback_extracts_token() {
        let headers =
            headers_with_cookie("__Host-access_token=my.jwt.token; Path=/; SameSite=Strict");

        assert_eq!(
            extract_access_token_cookie(&headers),
            Some("my.jwt.token".to_owned())
        );
    }

    #[test]
    fn test_cookie_fallback_strips_rfc6265_quotes() {
        let headers = headers_with_cookie("__Host-access_token=\"my.jwt.token\"");

        assert_eq!(
            extract_access_token_cookie(&headers),
            Some("my.jwt.token".to_owned())
        );
    }

    #[test]
    fn test_cookie_fallback_absent_returns_none() {
        let headers = headers_with_cookie("session=abc; other=xyz");

        assert_eq!(extract_access_token_cookie(&headers), None);
    }

    #[test]
    fn test_cookie_fallback_no_cookie_header_returns_none() {
        let empty = axum::http::HeaderMap::new();

        assert_eq!(extract_access_token_cookie(&empty), None);
    }

    #[test]
    fn test_cookie_fallback_multiple_cookies_finds_correct_one() {
        let headers =
            headers_with_cookie("session=abc; __Host-access_token=correct.token; csrf=xyz");

        assert_eq!(
            extract_access_token_cookie(&headers),
            Some("correct.token".to_owned())
        );
    }
}