Skip to main content

fraiseql_auth/
middleware.rs

1//! Authentication middleware for Axum request handlers.
2use std::sync::Arc;
3
4use axum::{
5    http::StatusCode,
6    response::{IntoResponse, Response},
7};
8use serde::{Deserialize, Serialize};
9
10use crate::{
11    error::{AuthError, Result},
12    jwt::{Claims, JwtValidator},
13    session::SessionStore,
14};
15
16/// Authenticated user extracted from JWT token
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct AuthenticatedUser {
19    /// User ID from token claims
20    pub user_id: String,
21    /// Full JWT claims
22    pub claims:  Claims,
23}
24
25impl AuthenticatedUser {
26    /// Get a custom claim from the JWT
27    pub fn get_custom_claim(&self, key: &str) -> Option<&serde_json::Value> {
28        self.claims.get_custom(key)
29    }
30
31    /// Check if user has a specific role
32    pub fn has_role(&self, role: &str) -> bool {
33        if let Some(serde_json::Value::String(user_role)) = self.claims.get_custom("role") {
34            user_role == role
35        } else if let Some(serde_json::Value::Array(roles)) = self.claims.get_custom("roles") {
36            roles.iter().any(|r| {
37                if let serde_json::Value::String(r_str) = r {
38                    r_str == role
39                } else {
40                    false
41                }
42            })
43        } else {
44            false
45        }
46    }
47}
48
49/// Authentication middleware configuration
50pub struct AuthMiddleware {
51    validator:      Arc<JwtValidator>,
52    _session_store: Arc<dyn SessionStore>,
53    public_key:     Vec<u8>,
54    _optional:      bool,
55}
56
57impl AuthMiddleware {
58    /// Create a new authentication middleware
59    ///
60    /// # Arguments
61    /// * `validator` - JWT validator
62    /// * `session_store` - Session storage backend
63    /// * `public_key` - Public key for JWT signature verification
64    /// * `optional` - If true, missing auth is not an error
65    pub fn new(
66        validator: Arc<JwtValidator>,
67        session_store: Arc<dyn SessionStore>,
68        public_key: Vec<u8>,
69        optional: bool,
70    ) -> Self {
71        Self {
72            validator,
73            _session_store: session_store,
74            public_key,
75            _optional: optional,
76        }
77    }
78
79    /// Validate a Bearer token and extract claims.
80    ///
81    /// # Errors
82    ///
83    /// Returns `AuthError::InvalidToken` if the token signature is invalid,
84    /// expired, or does not match the expected issuer/audience.
85    /// Returns `AuthError::KeyError` if the public key cannot be used for
86    /// verification.
87    pub async fn validate_token(&self, token: &str) -> Result<Claims> {
88        self.validator.validate(token, &self.public_key)
89    }
90}
91
impl AuthError {
    /// Translate this error into the `(status, error code, client message)` triple
    /// that is sent back to the caller.
    ///
    /// SECURITY: the client-facing message is always one of a few fixed, generic
    /// strings (or a retry hint for rate limiting); variant payloads are never echoed.
    #[allow(clippy::cognitive_complexity)] // Reason: exhaustive 1:1 mapping of AuthError variants to HTTP response tuples
    fn response_parts(&self) -> (StatusCode, &'static str, String) {
        const AUTH_FAILED: &str = "Authentication failed";

        // Most variants share the generic authentication-failure message and differ
        // only in status and machine-readable code.
        let generic =
            |status: StatusCode, code: &'static str| (status, code, AUTH_FAILED.to_string());

        match self {
            Self::TokenExpired => generic(StatusCode::UNAUTHORIZED, "token_expired"),
            Self::InvalidSignature => generic(StatusCode::UNAUTHORIZED, "invalid_signature"),
            // Claim and OIDC replay-protection failures all collapse onto one code so
            // responses never reveal which specific check failed (oracle resistance).
            Self::InvalidToken { .. }
            | Self::MissingClaim { .. }
            | Self::InvalidClaimValue { .. }
            | Self::MissingNonce
            | Self::NonceMismatch
            | Self::MissingAuthTime
            | Self::SessionTooOld { .. } => generic(StatusCode::UNAUTHORIZED, "invalid_token"),
            Self::TokenNotFound => generic(StatusCode::UNAUTHORIZED, "token_not_found"),
            Self::SessionRevoked => generic(StatusCode::UNAUTHORIZED, "session_revoked"),
            Self::OAuthError { .. } => generic(StatusCode::UNAUTHORIZED, "oauth_error"),
            Self::SessionError { .. } => generic(StatusCode::UNAUTHORIZED, "session_error"),
            Self::InvalidState => generic(StatusCode::BAD_REQUEST, "invalid_state"),
            Self::PkceError { .. } => generic(StatusCode::BAD_REQUEST, "pkce_error"),
            Self::Forbidden { .. } => {
                (StatusCode::FORBIDDEN, "forbidden", "Permission denied".to_string())
            },
            Self::DatabaseError { .. }
            | Self::ConfigError { .. }
            | Self::OidcMetadataError { .. }
            | Self::Internal { .. }
            | Self::SystemTimeError { .. } => (
                StatusCode::INTERNAL_SERVER_ERROR,
                "server_error",
                "Service temporarily unavailable".to_string(),
            ),
            Self::RateLimited { retry_after_secs } => (
                StatusCode::TOO_MANY_REQUESTS,
                "rate_limited",
                format!("Too many requests. Retry after {retry_after_secs} seconds"),
            ),
        }
    }

    /// Emit a server-side `warn` log carrying the private details of this error.
    ///
    /// These details are deliberately withheld from the client response produced
    /// from `response_parts`; variants with nothing useful to add log nothing.
    #[allow(clippy::cognitive_complexity)] // Reason: exhaustive match logging security-sensitive details per AuthError variant
    fn log_security_details(&self) {
        match self {
            // Token and claim validation details.
            Self::InvalidToken { reason } => tracing::warn!("Invalid token error: {reason}"),
            Self::MissingClaim { claim } => tracing::warn!("Missing required claim: {claim}"),
            Self::InvalidClaimValue { claim, reason } => {
                tracing::warn!("Invalid claim value for '{claim}': {reason}");
            },
            // OIDC replay-protection checks, formatted via AuthError's Display impl.
            Self::MissingNonce | Self::NonceMismatch => {
                tracing::warn!("OIDC nonce validation failed: {self}");
            },
            Self::MissingAuthTime | Self::SessionTooOld { .. } => {
                tracing::warn!("OIDC auth_time validation failed: {self}");
            },
            // Authorization and login-flow errors.
            Self::Forbidden { message } => tracing::warn!("Authorization denied: {message}"),
            Self::OAuthError { message } => tracing::warn!("OAuth provider error: {message}"),
            Self::SessionError { message } => tracing::warn!("Session error: {message}"),
            Self::OidcMetadataError { message } => {
                tracing::warn!("OIDC metadata error: {message}");
            },
            Self::PkceError { message } => tracing::warn!("PKCE error: {message}"),
            // Infrastructure failures that the client only sees as a generic 500.
            Self::DatabaseError { message } => {
                tracing::warn!("Database error (should not reach client): {message}");
            },
            Self::ConfigError { message } => {
                tracing::warn!("Configuration error (should not reach client): {message}");
            },
            Self::Internal { message } => {
                tracing::warn!("Internal error (should not reach client): {message}");
            },
            Self::SystemTimeError { message } => {
                tracing::warn!("System time error (should not reach client): {message}");
            },
            // Nothing beyond the variant itself worth logging.
            Self::TokenExpired
            | Self::InvalidSignature
            | Self::TokenNotFound
            | Self::SessionRevoked
            | Self::InvalidState
            | Self::RateLimited { .. } => {},
        }
    }
}
195            | Self::InvalidState
196            | Self::RateLimited { .. } => {},
197        }
198    }
199}
200
201impl IntoResponse for AuthError {
202    fn into_response(self) -> Response {
203        self.log_security_details();
204        let (status, error_code, sanitized_message) = self.response_parts();
205
206        let body = serde_json::json!({
207            "errors": [{
208                "message": sanitized_message,
209                "extensions": {
210                    "code": error_code
211                }
212            }]
213        });
214
215        (status, axum::Json(body)).into_response()
216    }
217}
218
#[cfg(test)]
mod tests {
    #[allow(clippy::wildcard_imports)]
    // Reason: test module — wildcard keeps test boilerplate minimal
    use super::*;

    /// Baseline claims for subject `user123` with no custom claims.
    fn sample_claims() -> crate::Claims {
        crate::Claims {
            sub:   "user123".to_string(),
            iat:   1000,
            exp:   2000,
            iss:   "https://example.com".to_string(),
            aud:   vec!["api".to_string()],
            extra: std::collections::HashMap::new(),
        }
    }

    /// Baseline claims plus a single custom claim `key = value`.
    fn claims_with(key: &str, value: serde_json::Value) -> crate::Claims {
        let mut claims = sample_claims();
        claims.extra.insert(key.to_string(), value);
        claims
    }

    /// Wrap `claims` in an `AuthenticatedUser` whose `user_id` is `user123`.
    fn user_from(claims: crate::Claims) -> AuthenticatedUser {
        AuthenticatedUser {
            user_id: "user123".to_string(),
            claims,
        }
    }

    /// Assert the exact client-visible mapping of `error`.
    ///
    /// Checks the `(status, code, message)` parts, and optionally that neither the
    /// message nor the code echoes `secret`. Finally checks that the rendered
    /// response uses the same status.
    fn assert_client_view(
        error: AuthError,
        status: StatusCode,
        code: &str,
        message: &str,
        secret: Option<&str>,
    ) {
        let (actual_status, actual_code, actual_message) = error.response_parts();
        assert_eq!(actual_status, status);
        assert_eq!(actual_code, code);
        assert_eq!(actual_message, message);
        if let Some(secret) = secret {
            assert!(
                !actual_message.contains(secret),
                "client message leaks internal detail: {actual_message}"
            );
            assert!(!actual_code.contains(secret), "error code leaks internal detail: {actual_code}");
        }
        assert_eq!(error.into_response().status(), status);
    }

    #[test]
    fn test_authenticated_user_clone() {
        let user = user_from(sample_claims());
        let cloned = user.clone();
        assert_eq!(user.user_id, "user123");
        assert_eq!(cloned.user_id, user.user_id);
    }

    #[test]
    fn test_has_role_single_string() {
        let user = user_from(claims_with("role", serde_json::json!("admin")));

        assert!(user.has_role("admin"));
        assert!(!user.has_role("user"));
    }

    #[test]
    fn test_has_role_array() {
        let user = user_from(claims_with("roles", serde_json::json!(["admin", "user", "editor"])));

        assert!(user.has_role("admin"));
        assert!(user.has_role("user"));
        assert!(user.has_role("editor"));
        assert!(!user.has_role("moderator"));
    }

    #[test]
    fn test_has_role_without_role_claims() {
        let user = user_from(sample_claims());
        assert!(!user.has_role("admin"));
    }

    #[test]
    fn test_has_role_ignores_non_string_array_entries() {
        let user = user_from(claims_with("roles", serde_json::json!([1, null, "admin"])));

        assert!(user.has_role("admin"));
        assert!(!user.has_role("1"));
    }

    #[test]
    fn test_has_role_non_string_role_falls_back_to_roles() {
        let mut claims = claims_with("role", serde_json::json!(42));
        claims.extra.insert("roles".to_string(), serde_json::json!(["admin"]));
        let user = user_from(claims);

        assert!(user.has_role("admin"));
        assert!(!user.has_role("42"));
    }

    #[test]
    fn test_get_custom_claim() {
        let user = user_from(claims_with("org_id", serde_json::json!("org_456")));

        assert_eq!(user.get_custom_claim("org_id"), Some(&serde_json::json!("org_456")));
        assert_eq!(user.get_custom_claim("nonexistent"), None);
    }

    // SECURITY: Tests for error message sanitization to ensure internal details are never exposed

    #[test]
    fn test_invalid_token_sanitized() {
        // SECURITY: Ensure cryptographic details are not exposed
        let secret = "RS256 signature mismatch at offset 512 bytes";
        assert_client_view(
            AuthError::InvalidToken {
                reason: secret.to_string(),
            },
            StatusCode::UNAUTHORIZED,
            "invalid_token",
            "Authentication failed",
            Some(secret),
        );
    }

    #[test]
    fn test_missing_claim_sanitized() {
        // SECURITY: Ensure claim names are not exposed to attackers
        let secret = "sensitive_user_id";
        assert_client_view(
            AuthError::MissingClaim {
                claim: secret.to_string(),
            },
            StatusCode::UNAUTHORIZED,
            "invalid_token",
            "Authentication failed",
            Some(secret),
        );
    }

    #[test]
    fn test_invalid_claim_value_sanitized() {
        // SECURITY: Ensure claim validation rules are not exposed
        let secret = "Must match pattern: ^[0-9]{10,}$";
        assert_client_view(
            AuthError::InvalidClaimValue {
                claim:  "exp".to_string(),
                reason: secret.to_string(),
            },
            StatusCode::UNAUTHORIZED,
            "invalid_token",
            "Authentication failed",
            Some(secret),
        );
    }

    #[test]
    fn test_database_error_sanitized() {
        // SECURITY: NEVER expose database errors to clients
        let secret = "Connection to 192.168.1.100:5432 failed: timeout";
        assert_client_view(
            AuthError::DatabaseError {
                message: secret.to_string(),
            },
            StatusCode::INTERNAL_SERVER_ERROR,
            "server_error",
            "Service temporarily unavailable",
            Some(secret),
        );
    }

    #[test]
    fn test_config_error_sanitized() {
        // SECURITY: NEVER expose configuration details to clients
        let secret = "Secret key missing in /etc/fraiseql/config.toml";
        assert_client_view(
            AuthError::ConfigError {
                message: secret.to_string(),
            },
            StatusCode::INTERNAL_SERVER_ERROR,
            "server_error",
            "Service temporarily unavailable",
            Some(secret),
        );
    }

    #[test]
    fn test_oauth_error_sanitized() {
        // SECURITY: Don't expose OAuth provider details
        let secret = "GitHub API returned 500 from https://api.github.com/user (rate limited)";
        assert_client_view(
            AuthError::OAuthError {
                message: secret.to_string(),
            },
            StatusCode::UNAUTHORIZED,
            "oauth_error",
            "Authentication failed",
            Some(secret),
        );
    }

    #[test]
    fn test_session_error_sanitized() {
        // SECURITY: Don't expose session implementation details
        let secret = "Redis connection pool exhausted: 0/10 available";
        assert_client_view(
            AuthError::SessionError {
                message: secret.to_string(),
            },
            StatusCode::UNAUTHORIZED,
            "session_error",
            "Authentication failed",
            Some(secret),
        );
    }

    #[test]
    fn test_forbidden_error_sanitized() {
        // SECURITY: Don't expose permission logic details
        let secret = "User lacks role=admin AND permission=write:config for operation";
        assert_client_view(
            AuthError::Forbidden {
                message: secret.to_string(),
            },
            StatusCode::FORBIDDEN,
            "forbidden",
            "Permission denied",
            Some(secret),
        );
    }

    #[test]
    fn test_internal_error_sanitized() {
        // SECURITY: NEVER expose internal errors to clients
        let secret = "Panic in JWT validation thread: index out of bounds";
        assert_client_view(
            AuthError::Internal {
                message: secret.to_string(),
            },
            StatusCode::INTERNAL_SERVER_ERROR,
            "server_error",
            "Service temporarily unavailable",
            Some(secret),
        );
    }

    #[test]
    fn test_system_time_error_sanitized() {
        // SECURITY: Don't expose system errors to clients
        let secret = "System clock jumped backward by 3600 seconds";
        assert_client_view(
            AuthError::SystemTimeError {
                message: secret.to_string(),
            },
            StatusCode::INTERNAL_SERVER_ERROR,
            "server_error",
            "Service temporarily unavailable",
            Some(secret),
        );
    }

    #[test]
    fn test_rate_limited_error_message() {
        // Rate limited errors CAN expose retry timing (it's not sensitive)
        assert_client_view(
            AuthError::RateLimited {
                retry_after_secs: 60,
            },
            StatusCode::TOO_MANY_REQUESTS,
            "rate_limited",
            "Too many requests. Retry after 60 seconds",
            None,
        );
    }

    #[test]
    fn test_token_expired_returns_generic_message() {
        assert_client_view(
            AuthError::TokenExpired,
            StatusCode::UNAUTHORIZED,
            "token_expired",
            "Authentication failed",
            None,
        );
    }

    #[test]
    fn test_invalid_signature_returns_generic_message() {
        assert_client_view(
            AuthError::InvalidSignature,
            StatusCode::UNAUTHORIZED,
            "invalid_signature",
            "Authentication failed",
            None,
        );
    }

    #[test]
    fn test_token_not_found_and_session_revoked() {
        assert_client_view(
            AuthError::TokenNotFound,
            StatusCode::UNAUTHORIZED,
            "token_not_found",
            "Authentication failed",
            None,
        );
        assert_client_view(
            AuthError::SessionRevoked,
            StatusCode::UNAUTHORIZED,
            "session_revoked",
            "Authentication failed",
            None,
        );
    }

    #[test]
    fn test_oidc_replay_errors_share_generic_code() {
        // SECURITY: replay-protection failures must be indistinguishable to clients
        for error in [AuthError::MissingNonce, AuthError::NonceMismatch, AuthError::MissingAuthTime] {
            assert_client_view(
                error,
                StatusCode::UNAUTHORIZED,
                "invalid_token",
                "Authentication failed",
                None,
            );
        }
    }

    #[test]
    fn test_invalid_state_error() {
        // Invalid state (CSRF token) should return BAD_REQUEST
        assert_client_view(
            AuthError::InvalidState,
            StatusCode::BAD_REQUEST,
            "invalid_state",
            "Authentication failed",
            None,
        );
    }

    #[test]
    fn test_pkce_error_returns_bad_request() {
        let secret = "Challenge verification failed";
        assert_client_view(
            AuthError::PkceError {
                message: secret.to_string(),
            },
            StatusCode::BAD_REQUEST,
            "pkce_error",
            "Authentication failed",
            Some(secret),
        );
    }

    #[test]
    fn test_oidc_metadata_error_returns_server_error() {
        let secret = "Failed to fetch metadata";
        assert_client_view(
            AuthError::OidcMetadataError {
                message: secret.to_string(),
            },
            StatusCode::INTERNAL_SERVER_ERROR,
            "server_error",
            "Service temporarily unavailable",
            Some(secret),
        );
    }

    #[test]
    fn test_all_errors_have_status_codes() {
        // Every constructible variant must map onto one of the expected statuses.
        let allowed = [
            StatusCode::UNAUTHORIZED,
            StatusCode::FORBIDDEN,
            StatusCode::BAD_REQUEST,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::TOO_MANY_REQUESTS,
        ];

        let errors = [
            AuthError::TokenExpired,
            AuthError::InvalidSignature,
            AuthError::InvalidState,
            AuthError::TokenNotFound,
            AuthError::SessionRevoked,
            AuthError::MissingNonce,
            AuthError::NonceMismatch,
            AuthError::MissingAuthTime,
            AuthError::InvalidToken {
                reason: "test".to_string(),
            },
            AuthError::MissingClaim {
                claim: "test".to_string(),
            },
            AuthError::InvalidClaimValue {
                claim:  "test".to_string(),
                reason: "test".to_string(),
            },
            AuthError::OAuthError {
                message: "test".to_string(),
            },
            AuthError::SessionError {
                message: "test".to_string(),
            },
            AuthError::DatabaseError {
                message: "test".to_string(),
            },
            AuthError::ConfigError {
                message: "test".to_string(),
            },
            AuthError::OidcMetadataError {
                message: "test".to_string(),
            },
            AuthError::PkceError {
                message: "test".to_string(),
            },
            AuthError::Forbidden {
                message: "test".to_string(),
            },
            AuthError::Internal {
                message: "test".to_string(),
            },
            AuthError::SystemTimeError {
                message: "test".to_string(),
            },
            AuthError::RateLimited {
                retry_after_secs: 60,
            },
        ];

        for error in errors {
            let status = error.into_response().status();
            assert!(allowed.contains(&status), "Unexpected status code: {status}");
        }
    }
}
559                status
560            );
561        }
562    }
563}