Skip to main content

fraiseql_server/auth/
middleware.rs

1// Authentication middleware for Axum
2use std::sync::Arc;
3
4use axum::{
5    http::StatusCode,
6    response::{IntoResponse, Response},
7};
8use serde::{Deserialize, Serialize};
9
10use crate::auth::{
11    error::{AuthError, Result},
12    jwt::{Claims, JwtValidator},
13    session::SessionStore,
14};
15
16/// Authenticated user extracted from JWT token
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct AuthenticatedUser {
19    /// User ID from token claims
20    pub user_id: String,
21    /// Full JWT claims
22    pub claims:  Claims,
23}
24
25impl AuthenticatedUser {
26    /// Get a custom claim from the JWT
27    pub fn get_custom_claim(&self, key: &str) -> Option<&serde_json::Value> {
28        self.claims.get_custom(key)
29    }
30
31    /// Check if user has a specific role
32    pub fn has_role(&self, role: &str) -> bool {
33        if let Some(serde_json::Value::String(user_role)) = self.claims.get_custom("role") {
34            user_role == role
35        } else if let Some(serde_json::Value::Array(roles)) = self.claims.get_custom("roles") {
36            roles.iter().any(|r| {
37                if let serde_json::Value::String(r_str) = r {
38                    r_str == role
39                } else {
40                    false
41                }
42            })
43        } else {
44            false
45        }
46    }
47}
48
49/// Authentication middleware configuration
50pub struct AuthMiddleware {
51    validator:      Arc<JwtValidator>,
52    _session_store: Arc<dyn SessionStore>,
53    public_key:     Vec<u8>,
54    _optional:      bool,
55}
56
57impl AuthMiddleware {
58    /// Create a new authentication middleware
59    ///
60    /// # Arguments
61    /// * `validator` - JWT validator
62    /// * `session_store` - Session storage backend
63    /// * `public_key` - Public key for JWT signature verification
64    /// * `optional` - If true, missing auth is not an error
65    pub fn new(
66        validator: Arc<JwtValidator>,
67        session_store: Arc<dyn SessionStore>,
68        public_key: Vec<u8>,
69        optional: bool,
70    ) -> Self {
71        Self {
72            validator,
73            _session_store: session_store,
74            public_key,
75            _optional: optional,
76        }
77    }
78
79    /// Validate a Bearer token and extract claims
80    pub async fn validate_token(&self, token: &str) -> Result<Claims> {
81        self.validator.validate(token, &self.public_key)
82    }
83}
84
85impl IntoResponse for AuthError {
86    fn into_response(self) -> Response {
87        let (status, error, message) = match self {
88            AuthError::TokenExpired => {
89                (StatusCode::UNAUTHORIZED, "token_expired", "Authentication token has expired")
90            },
91            AuthError::InvalidSignature => {
92                (StatusCode::UNAUTHORIZED, "invalid_signature", "Token signature is invalid")
93            },
94            AuthError::InvalidToken { ref reason } => {
95                (StatusCode::UNAUTHORIZED, "invalid_token", reason.as_str())
96            },
97            AuthError::TokenNotFound => {
98                (StatusCode::UNAUTHORIZED, "token_not_found", "Authentication token not found")
99            },
100            AuthError::SessionRevoked => {
101                (StatusCode::UNAUTHORIZED, "session_revoked", "Session has been revoked")
102            },
103            AuthError::Forbidden { ref message } => {
104                (StatusCode::FORBIDDEN, "forbidden", message.as_str())
105            },
106            _ => (
107                StatusCode::INTERNAL_SERVER_ERROR,
108                "auth_error",
109                "An authentication error occurred",
110            ),
111        };
112
113        let body = serde_json::json!({
114            "errors": [{
115                "message": message,
116                "extensions": {
117                    "code": error
118                }
119            }]
120        });
121
122        (status, axum::Json(body)).into_response()
123    }
124}
125
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;
    use crate::auth::Claims;

    /// Build the fixture user shared by all tests, inserting `custom` into
    /// the claims' `extra` map.
    fn user_with_extra(custom: Vec<(&str, serde_json::Value)>) -> AuthenticatedUser {
        let mut claims = Claims {
            sub:   "user123".to_string(),
            iat:   1000,
            exp:   2000,
            iss:   "https://example.com".to_string(),
            aud:   vec!["api".to_string()],
            extra: HashMap::new(),
        };

        for (key, value) in custom {
            claims.extra.insert(key.to_string(), value);
        }

        AuthenticatedUser {
            user_id: "user123".to_string(),
            claims,
        }
    }

    #[test]
    fn test_authenticated_user_clone() {
        let user = user_with_extra(Vec::new());

        let _cloned = user.clone();
        assert_eq!(user.user_id, "user123");
    }

    #[test]
    fn test_has_role_single_string() {
        let user = user_with_extra(vec![("role", serde_json::json!("admin"))]);

        assert!(user.has_role("admin"));
        assert!(!user.has_role("user"));
    }

    #[test]
    fn test_has_role_array() {
        let user = user_with_extra(vec![(
            "roles",
            serde_json::json!(["admin", "user", "editor"]),
        )]);

        for granted in ["admin", "user", "editor"] {
            assert!(user.has_role(granted));
        }
        assert!(!user.has_role("moderator"));
    }

    #[test]
    fn test_get_custom_claim() {
        let user = user_with_extra(vec![("org_id", serde_json::json!("org_456"))]);

        assert_eq!(user.get_custom_claim("org_id"), Some(&serde_json::json!("org_456")));
        assert_eq!(user.get_custom_claim("nonexistent"), None);
    }
}