1use std::sync::Arc;
3
4use axum::{
5 http::StatusCode,
6 response::{IntoResponse, Response},
7};
8use serde::{Deserialize, Serialize};
9
10use crate::{
11 error::{AuthError, Result},
12 jwt::{Claims, JwtValidator},
13 session::SessionStore,
14};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct AuthenticatedUser {
19 pub user_id: String,
21 pub claims: Claims,
23}
24
25impl AuthenticatedUser {
26 pub fn get_custom_claim(&self, key: &str) -> Option<&serde_json::Value> {
28 self.claims.get_custom(key)
29 }
30
31 pub fn has_role(&self, role: &str) -> bool {
33 if let Some(serde_json::Value::String(user_role)) = self.claims.get_custom("role") {
34 user_role == role
35 } else if let Some(serde_json::Value::Array(roles)) = self.claims.get_custom("roles") {
36 roles.iter().any(|r| {
37 if let serde_json::Value::String(r_str) = r {
38 r_str == role
39 } else {
40 false
41 }
42 })
43 } else {
44 false
45 }
46 }
47}
48
49pub struct AuthMiddleware {
51 validator: Arc<JwtValidator>,
52 _session_store: Arc<dyn SessionStore>,
53 public_key: Vec<u8>,
54 _optional: bool,
55}
56
57impl AuthMiddleware {
58 pub fn new(
66 validator: Arc<JwtValidator>,
67 session_store: Arc<dyn SessionStore>,
68 public_key: Vec<u8>,
69 optional: bool,
70 ) -> Self {
71 Self {
72 validator,
73 _session_store: session_store,
74 public_key,
75 _optional: optional,
76 }
77 }
78
79 pub async fn validate_token(&self, token: &str) -> Result<Claims> {
88 self.validator.validate(token, &self.public_key)
89 }
90}
91
92impl AuthError {
93 #[allow(clippy::cognitive_complexity)] fn response_parts(&self) -> (StatusCode, &'static str, String) {
98 match self {
99 Self::TokenExpired => {
100 (StatusCode::UNAUTHORIZED, "token_expired", "Authentication failed".to_string())
101 },
102 Self::InvalidSignature => (
103 StatusCode::UNAUTHORIZED,
104 "invalid_signature",
105 "Authentication failed".to_string(),
106 ),
107 Self::InvalidToken { .. }
108 | Self::MissingClaim { .. }
109 | Self::InvalidClaimValue { .. }
110 | Self::MissingNonce
113 | Self::NonceMismatch
114 | Self::MissingAuthTime
115 | Self::SessionTooOld { .. } => {
116 (StatusCode::UNAUTHORIZED, "invalid_token", "Authentication failed".to_string())
117 },
118 Self::TokenNotFound => {
119 (StatusCode::UNAUTHORIZED, "token_not_found", "Authentication failed".to_string())
120 },
121 Self::SessionRevoked => {
122 (StatusCode::UNAUTHORIZED, "session_revoked", "Authentication failed".to_string())
123 },
124 Self::InvalidState => {
125 (StatusCode::BAD_REQUEST, "invalid_state", "Authentication failed".to_string())
126 },
127 Self::Forbidden { .. } => {
128 (StatusCode::FORBIDDEN, "forbidden", "Permission denied".to_string())
129 },
130 Self::OAuthError { .. } => {
131 (StatusCode::UNAUTHORIZED, "oauth_error", "Authentication failed".to_string())
132 },
133 Self::SessionError { .. } => {
134 (StatusCode::UNAUTHORIZED, "session_error", "Authentication failed".to_string())
135 },
136 Self::DatabaseError { .. }
137 | Self::ConfigError { .. }
138 | Self::OidcMetadataError { .. }
139 | Self::Internal { .. }
140 | Self::SystemTimeError { .. } => (
141 StatusCode::INTERNAL_SERVER_ERROR,
142 "server_error",
143 "Service temporarily unavailable".to_string(),
144 ),
145 Self::PkceError { .. } => {
146 (StatusCode::BAD_REQUEST, "pkce_error", "Authentication failed".to_string())
147 },
148 Self::RateLimited { retry_after_secs } => (
149 StatusCode::TOO_MANY_REQUESTS,
150 "rate_limited",
151 format!("Too many requests. Retry after {retry_after_secs} seconds"),
152 ),
153 }
154 }
155
156 #[allow(clippy::cognitive_complexity)] fn log_security_details(&self) {
159 use tracing::warn;
160
161 match self {
162 Self::InvalidToken { reason } => warn!("Invalid token error: {reason}"),
163 Self::MissingClaim { claim } => warn!("Missing required claim: {claim}"),
164 Self::InvalidClaimValue { claim, reason } => {
165 warn!("Invalid claim value for '{claim}': {reason}");
166 },
167 Self::Forbidden { message } => warn!("Authorization denied: {message}"),
168 Self::OAuthError { message } => warn!("OAuth provider error: {message}"),
169 Self::SessionError { message } => warn!("Session error: {message}"),
170 Self::DatabaseError { message } => {
171 warn!("Database error (should not reach client): {message}");
172 },
173 Self::ConfigError { message } => {
174 warn!("Configuration error (should not reach client): {message}");
175 },
176 Self::OidcMetadataError { message } => warn!("OIDC metadata error: {message}"),
177 Self::PkceError { message } => warn!("PKCE error: {message}"),
178 Self::Internal { message } => {
179 warn!("Internal error (should not reach client): {message}");
180 },
181 Self::SystemTimeError { message } => {
182 warn!("System time error (should not reach client): {message}");
183 },
184 Self::MissingNonce | Self::NonceMismatch => {
185 warn!("OIDC nonce validation failed: {self}");
186 },
187 Self::MissingAuthTime | Self::SessionTooOld { .. } => {
188 warn!("OIDC auth_time validation failed: {self}");
189 },
190 Self::TokenExpired
192 | Self::InvalidSignature
193 | Self::TokenNotFound
194 | Self::SessionRevoked
195 | Self::InvalidState
196 | Self::RateLimited { .. } => {},
197 }
198 }
199}
200
201impl IntoResponse for AuthError {
202 fn into_response(self) -> Response {
203 self.log_security_details();
204 let (status, error_code, sanitized_message) = self.response_parts();
205
206 let body = serde_json::json!({
207 "errors": [{
208 "message": sanitized_message,
209 "extensions": {
210 "code": error_code
211 }
212 }]
213 });
214
215 (status, axum::Json(body)).into_response()
216 }
217}
218
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    #[allow(clippy::wildcard_imports)]
    use super::*;
    use crate::Claims;

    /// Client-facing message shared by most authentication failures.
    const AUTH_FAILED: &str = "Authentication failed";
    /// Client-facing message for server-side failures.
    const SERVER_UNAVAILABLE: &str = "Service temporarily unavailable";

    /// Baseline claims for `user123` with no custom claims.
    fn sample_claims() -> Claims {
        Claims {
            sub: "user123".to_string(),
            iat: 1000,
            exp: 2000,
            iss: "https://example.com".to_string(),
            aud: vec!["api".to_string()],
            extra: HashMap::new(),
        }
    }

    /// Wraps `claims` in an `AuthenticatedUser` for `user123`.
    fn user_with(claims: Claims) -> AuthenticatedUser {
        AuthenticatedUser {
            user_id: "user123".to_string(),
            claims,
        }
    }

    /// Asserts the complete client-facing mapping of `error`, then checks that
    /// the real HTTP response uses the same status.
    ///
    /// The message is compared for exact equality, which is what proves
    /// sanitization: any internal detail carried by the variant (reason, claim
    /// name, host, path, ...) would make it differ.
    fn assert_response(error: AuthError, status: StatusCode, code: &str, message: &str) {
        let (actual_status, actual_code, actual_message) = error.response_parts();
        assert_eq!(actual_status, status, "status for `{error}`");
        assert_eq!(actual_code, code, "error code for `{error}`");
        assert_eq!(actual_message, message, "client message for `{error}`");
        assert_eq!(error.into_response().status(), status);
    }

    #[test]
    fn test_authenticated_user_clone() {
        let user = user_with(sample_claims());

        let cloned = user.clone();
        assert_eq!(cloned.user_id, user.user_id);
        assert_eq!(cloned.claims.sub, "user123");
        assert_eq!(user.user_id, "user123");
    }

    #[test]
    fn test_has_role_single_string() {
        let mut claims = sample_claims();
        claims.extra.insert("role".to_string(), serde_json::json!("admin"));
        let user = user_with(claims);

        assert!(user.has_role("admin"));
        assert!(!user.has_role("user"));
    }

    #[test]
    fn test_has_role_array() {
        let mut claims = sample_claims();
        claims
            .extra
            .insert("roles".to_string(), serde_json::json!(["admin", "user", "editor"]));
        let user = user_with(claims);

        assert!(user.has_role("admin"));
        assert!(user.has_role("user"));
        assert!(user.has_role("editor"));
        assert!(!user.has_role("moderator"));
    }

    #[test]
    fn test_has_role_without_role_claims() {
        let user = user_with(sample_claims());

        assert!(!user.has_role("admin"));
        assert!(!user.has_role(""));
    }

    #[test]
    fn test_has_role_array_ignores_non_string_entries() {
        let mut claims = sample_claims();
        claims
            .extra
            .insert("roles".to_string(), serde_json::json!(["admin", 42, null]));
        let user = user_with(claims);

        assert!(user.has_role("admin"));
        assert!(!user.has_role("42"));
        assert!(!user.has_role("null"));
    }

    #[test]
    fn test_get_custom_claim() {
        let mut claims = sample_claims();
        claims.extra.insert("org_id".to_string(), serde_json::json!("org_456"));
        let user = user_with(claims);

        assert_eq!(user.get_custom_claim("org_id"), Some(&serde_json::json!("org_456")));
        assert_eq!(user.get_custom_claim("nonexistent"), None);
    }

    #[test]
    fn test_invalid_token_sanitized() {
        let error = AuthError::InvalidToken {
            reason: "RS256 signature mismatch at offset 512 bytes".to_string(),
        };
        assert_response(error, StatusCode::UNAUTHORIZED, "invalid_token", AUTH_FAILED);
    }

    #[test]
    fn test_missing_claim_sanitized() {
        let error = AuthError::MissingClaim {
            claim: "sensitive_user_id".to_string(),
        };
        assert_response(error, StatusCode::UNAUTHORIZED, "invalid_token", AUTH_FAILED);
    }

    #[test]
    fn test_invalid_claim_value_sanitized() {
        let error = AuthError::InvalidClaimValue {
            claim: "exp".to_string(),
            reason: "Must match pattern: ^[0-9]{10,}$".to_string(),
        };
        assert_response(error, StatusCode::UNAUTHORIZED, "invalid_token", AUTH_FAILED);
    }

    #[test]
    fn test_database_error_sanitized() {
        let error = AuthError::DatabaseError {
            message: "Connection to 192.168.1.100:5432 failed: timeout".to_string(),
        };
        assert_response(
            error,
            StatusCode::INTERNAL_SERVER_ERROR,
            "server_error",
            SERVER_UNAVAILABLE,
        );
    }

    #[test]
    fn test_config_error_sanitized() {
        let error = AuthError::ConfigError {
            message: "Secret key missing in /etc/fraiseql/config.toml".to_string(),
        };
        assert_response(
            error,
            StatusCode::INTERNAL_SERVER_ERROR,
            "server_error",
            SERVER_UNAVAILABLE,
        );
    }

    #[test]
    fn test_oauth_error_sanitized() {
        let error = AuthError::OAuthError {
            message: "GitHub API returned 500 from https://api.github.com/user (rate limited)"
                .to_string(),
        };
        assert_response(error, StatusCode::UNAUTHORIZED, "oauth_error", AUTH_FAILED);
    }

    #[test]
    fn test_session_error_sanitized() {
        let error = AuthError::SessionError {
            message: "Redis connection pool exhausted: 0/10 available".to_string(),
        };
        assert_response(error, StatusCode::UNAUTHORIZED, "session_error", AUTH_FAILED);
    }

    #[test]
    fn test_forbidden_error_sanitized() {
        let error = AuthError::Forbidden {
            message: "User lacks role=admin AND permission=write:config for operation".to_string(),
        };
        assert_response(error, StatusCode::FORBIDDEN, "forbidden", "Permission denied");
    }

    #[test]
    fn test_internal_error_sanitized() {
        let error = AuthError::Internal {
            message: "Panic in JWT validation thread: index out of bounds".to_string(),
        };
        assert_response(
            error,
            StatusCode::INTERNAL_SERVER_ERROR,
            "server_error",
            SERVER_UNAVAILABLE,
        );
    }

    #[test]
    fn test_system_time_error_sanitized() {
        let error = AuthError::SystemTimeError {
            message: "System clock jumped backward by 3600 seconds".to_string(),
        };
        assert_response(
            error,
            StatusCode::INTERNAL_SERVER_ERROR,
            "server_error",
            SERVER_UNAVAILABLE,
        );
    }

    #[test]
    fn test_rate_limited_error_message() {
        let error = AuthError::RateLimited {
            retry_after_secs: 60,
        };
        assert_response(
            error,
            StatusCode::TOO_MANY_REQUESTS,
            "rate_limited",
            "Too many requests. Retry after 60 seconds",
        );
    }

    #[test]
    fn test_token_expired_returns_generic_message() {
        assert_response(
            AuthError::TokenExpired,
            StatusCode::UNAUTHORIZED,
            "token_expired",
            AUTH_FAILED,
        );
    }

    #[test]
    fn test_invalid_signature_returns_generic_message() {
        assert_response(
            AuthError::InvalidSignature,
            StatusCode::UNAUTHORIZED,
            "invalid_signature",
            AUTH_FAILED,
        );
    }

    #[test]
    fn test_invalid_state_error() {
        assert_response(
            AuthError::InvalidState,
            StatusCode::BAD_REQUEST,
            "invalid_state",
            AUTH_FAILED,
        );
    }

    #[test]
    fn test_pkce_error_returns_bad_request() {
        let error = AuthError::PkceError {
            message: "Challenge verification failed".to_string(),
        };
        assert_response(error, StatusCode::BAD_REQUEST, "pkce_error", AUTH_FAILED);
    }

    #[test]
    fn test_oidc_metadata_error_returns_server_error() {
        let error = AuthError::OidcMetadataError {
            message: "Failed to fetch metadata".to_string(),
        };
        assert_response(
            error,
            StatusCode::INTERNAL_SERVER_ERROR,
            "server_error",
            SERVER_UNAVAILABLE,
        );
    }

    #[test]
    fn test_all_errors_have_status_codes() {
        let allowed = [
            StatusCode::UNAUTHORIZED,
            StatusCode::FORBIDDEN,
            StatusCode::BAD_REQUEST,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::TOO_MANY_REQUESTS,
        ];

        let errors = [
            AuthError::TokenExpired,
            AuthError::InvalidSignature,
            AuthError::InvalidState,
            AuthError::TokenNotFound,
            AuthError::SessionRevoked,
            AuthError::MissingNonce,
            AuthError::NonceMismatch,
            AuthError::MissingAuthTime,
            AuthError::InvalidToken {
                reason: "test".to_string(),
            },
            AuthError::MissingClaim {
                claim: "test".to_string(),
            },
            AuthError::InvalidClaimValue {
                claim: "test".to_string(),
                reason: "test".to_string(),
            },
            AuthError::OAuthError {
                message: "test".to_string(),
            },
            AuthError::SessionError {
                message: "test".to_string(),
            },
            AuthError::DatabaseError {
                message: "test".to_string(),
            },
            AuthError::ConfigError {
                message: "test".to_string(),
            },
            AuthError::OidcMetadataError {
                message: "test".to_string(),
            },
            AuthError::PkceError {
                message: "test".to_string(),
            },
            AuthError::Forbidden {
                message: "test".to_string(),
            },
            AuthError::Internal {
                message: "test".to_string(),
            },
            AuthError::SystemTimeError {
                message: "test".to_string(),
            },
            AuthError::RateLimited {
                retry_after_secs: 60,
            },
        ];

        for error in errors {
            let (_, code, message) = error.response_parts();
            assert!(!code.is_empty(), "empty error code for `{error}`");
            assert!(!message.is_empty(), "empty client message for `{error}`");

            let status = error.into_response().status();
            assert!(allowed.contains(&status), "Unexpected status code: {status}");
        }
    }
}