// fraiseql_server/middleware/oidc_auth.rs
use std::sync::Arc;

use axum::{
    body::Body,
    extract::State,
    http::{Request, StatusCode, header},
    middleware::Next,
    response::{IntoResponse, Response},
};
use fraiseql_core::security::{AuthenticatedUser, OidcValidator, SecurityError};

17#[derive(Clone)]
19pub struct OidcAuthState {
20 pub validator: Arc<OidcValidator>,
22}
23
24impl OidcAuthState {
25 #[must_use]
27 pub const fn new(validator: Arc<OidcValidator>) -> Self {
28 Self { validator }
29 }
30}
31
32#[derive(Clone, Debug)]
37pub struct AuthUser(pub AuthenticatedUser);
38
39fn extract_access_token_cookie(headers: &axum::http::HeaderMap) -> Option<String> {
49 headers
50 .get(header::COOKIE)
51 .and_then(|v| v.to_str().ok())
52 .and_then(|cookies| {
53 cookies.split(';').find_map(|part| {
54 let part = part.trim();
55 part.strip_prefix("__Host-access_token=")
56 .map(|v| v.trim_matches('"').to_owned())
57 })
58 })
59}
60
61#[allow(clippy::cognitive_complexity)] pub async fn oidc_auth_middleware(
87 State(auth_state): State<OidcAuthState>,
88 mut request: Request<Body>,
89 next: Next,
90) -> Response {
91 let token_string: Option<String> = {
95 let auth_header = request
96 .headers()
97 .get(header::AUTHORIZATION)
98 .and_then(|value| value.to_str().ok());
99
100 match auth_header {
101 Some(header_value) => {
102 if !header_value.starts_with("Bearer ") {
103 tracing::debug!("Invalid Authorization header format");
104 return (
105 StatusCode::UNAUTHORIZED,
106 [(
107 header::WWW_AUTHENTICATE,
108 "Bearer error=\"invalid_request\"".to_string(),
109 )],
110 "Invalid Authorization header format",
111 )
112 .into_response();
113 }
114 Some(header_value[7..].to_owned())
115 },
116 None => extract_access_token_cookie(request.headers()),
117 }
118 };
119
120 match token_string {
121 None => {
122 if auth_state.validator.is_required() {
123 tracing::debug!("Authentication required but no token found (header or cookie)");
124 return (
125 StatusCode::UNAUTHORIZED,
126 [(
127 header::WWW_AUTHENTICATE,
128 format!("Bearer realm=\"{}\"", auth_state.validator.issuer()),
129 )],
130 "Authentication required",
131 )
132 .into_response();
133 }
134 next.run(request).await
136 },
137 Some(token) => {
138 match auth_state.validator.validate_token(&token).await {
140 Ok(user) => {
141 tracing::debug!(
142 user_id = %user.user_id,
143 scopes = ?user.scopes,
144 "User authenticated successfully"
145 );
146 request.extensions_mut().insert(AuthUser(user));
148 next.run(request).await
149 },
150 Err(e) => {
151 tracing::debug!(error = %e, "Token validation failed");
152 let (www_authenticate, body) = match &e {
153 fraiseql_core::security::SecurityError::TokenExpired { .. } => (
154 "Bearer error=\"invalid_token\", error_description=\"Token has expired\"",
155 "Token has expired",
156 ),
157 fraiseql_core::security::SecurityError::InvalidToken => (
158 "Bearer error=\"invalid_token\", error_description=\"Token is invalid\"",
159 "Token is invalid",
160 ),
161 _ => ("Bearer error=\"invalid_token\"", "Invalid or expired token"),
162 };
163 (
164 StatusCode::UNAUTHORIZED,
165 [(header::WWW_AUTHENTICATE, www_authenticate.to_string())],
166 body,
167 )
168 .into_response()
169 },
170 }
171 },
172 }
173}
174
#[cfg(test)]
mod tests {
    // `unwrap` is acceptable here: a failure means the test fixture itself is broken.
    #![allow(clippy::unwrap_used)]

    use super::*;

    /// Builds a header map containing a single `Cookie` header with the given raw value.
    fn cookie_headers(raw: &str) -> axum::http::HeaderMap {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(header::COOKIE, raw.parse().unwrap());
        headers
    }

    #[test]
    fn test_auth_user_clone() {
        use chrono::Utc;

        let original = AuthUser(AuthenticatedUser {
            user_id: "user123".to_string(),
            scopes: vec!["read".to_string()],
            expires_at: Utc::now(),
            extra_claims: std::collections::HashMap::new(),
        });
        let copy = original.clone();

        assert_eq!(copy.0.user_id, original.0.user_id);
    }

    #[test]
    fn test_oidc_auth_state_clone() {
        // Compile-time check only: fails to build if `OidcAuthState` stops being `Clone`.
        fn assert_clone<T: Clone>() {}
        assert_clone::<OidcAuthState>();
    }

    #[test]
    fn test_cookie_fallback_extracts_token() {
        let headers = cookie_headers("__Host-access_token=my.jwt.token; Path=/; SameSite=Strict");
        assert_eq!(extract_access_token_cookie(&headers).as_deref(), Some("my.jwt.token"));
    }

    #[test]
    fn test_cookie_fallback_strips_rfc6265_quotes() {
        let headers = cookie_headers("__Host-access_token=\"my.jwt.token\"");
        assert_eq!(extract_access_token_cookie(&headers).as_deref(), Some("my.jwt.token"));
    }

    #[test]
    fn test_cookie_fallback_absent_returns_none() {
        let headers = cookie_headers("session=abc; other=xyz");
        assert!(extract_access_token_cookie(&headers).is_none());
    }

    #[test]
    fn test_cookie_fallback_no_cookie_header_returns_none() {
        assert!(extract_access_token_cookie(&axum::http::HeaderMap::new()).is_none());
    }

    #[test]
    fn test_cookie_fallback_multiple_cookies_finds_correct_one() {
        let headers = cookie_headers("session=abc; __Host-access_token=correct.token; csrf=xyz");
        assert_eq!(extract_access_token_cookie(&headers).as_deref(), Some("correct.token"));
    }
}