1use crate::security::{AuthConfig, Result, SecurityError};
2use axum::{
3 extract::{Request, State},
4 http::{header::AUTHORIZATION, HeaderMap, StatusCode},
5 middleware::Next,
6 response::Response,
7};
8use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{SystemTime, UNIX_EPOCH};
14use tokio::sync::RwLock;
15use tracing::{debug, info, warn};
16use uuid::Uuid;
17
18#[derive(Debug, Serialize, Deserialize, Clone)]
20pub struct Claims {
21 pub sub: String, pub name: String, pub role: String, pub permissions: Vec<String>, pub exp: u64, pub iat: u64, pub jti: String, }
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ApiKey {
33 pub key_id: String,
34 pub key_hash: String,
35 pub name: String,
36 pub role: String,
37 pub permissions: Vec<String>,
38 pub created_at: u64,
39 pub expires_at: Option<u64>,
40 pub last_used: Option<u64>,
41 pub active: bool,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct UserSession {
47 pub user_id: String,
48 pub name: String,
49 pub role: String,
50 pub permissions: Vec<String>,
51 pub authenticated_at: u64,
52 pub last_activity: u64,
53 pub auth_method: AuthMethod,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub enum AuthMethod {
58 JWT,
59 ApiKey,
60 MTLS,
61}
62
63pub struct AuthManager {
65 config: AuthConfig,
66 api_keys: Arc<RwLock<HashMap<String, ApiKey>>>,
67 active_sessions: Arc<RwLock<HashMap<String, UserSession>>>,
68 encoding_key: EncodingKey,
69 decoding_key: DecodingKey,
70}
71
72impl AuthManager {
73 pub fn new(config: AuthConfig) -> Result<Self> {
74 let encoding_key = EncodingKey::from_secret(config.jwt_secret.as_bytes());
75 let decoding_key = DecodingKey::from_secret(config.jwt_secret.as_bytes());
76
77 Ok(Self {
78 config,
79 api_keys: Arc::new(RwLock::new(HashMap::new())),
80 active_sessions: Arc::new(RwLock::new(HashMap::new())),
81 encoding_key,
82 decoding_key,
83 })
84 }
85
86 pub async fn create_jwt_token(
88 &self,
89 user_id: &str,
90 name: &str,
91 role: &str,
92 permissions: Vec<String>,
93 ) -> Result<String> {
94 if !self.config.enabled {
95 return Err(SecurityError::AuthenticationFailed {
96 message: "Authentication is disabled".to_string(),
97 });
98 }
99
100 let now = SystemTime::now()
101 .duration_since(UNIX_EPOCH)
102 .unwrap()
103 .as_secs();
104
105 let claims = Claims {
106 sub: user_id.to_string(),
107 name: name.to_string(),
108 role: role.to_string(),
109 permissions,
110 exp: now + self.config.jwt_expiry_seconds,
111 iat: now,
112 jti: Uuid::new_v4().to_string(),
113 };
114
115 let header = Header::new(Algorithm::HS256);
116 let token = encode(&header, &claims, &self.encoding_key).map_err(|e| {
117 SecurityError::AuthenticationFailed {
118 message: format!("Failed to create JWT token: {e}"),
119 }
120 })?;
121
122 let session = UserSession {
124 user_id: user_id.to_string(),
125 name: name.to_string(),
126 role: role.to_string(),
127 permissions: claims.permissions.clone(),
128 authenticated_at: now,
129 last_activity: now,
130 auth_method: AuthMethod::JWT,
131 };
132
133 self.active_sessions
134 .write()
135 .await
136 .insert(claims.jti.clone(), session);
137
138 info!("JWT token created for user: {} ({})", name, user_id);
139 Ok(token)
140 }
141
142 pub async fn validate_jwt_token(&self, token: &str) -> Result<Claims> {
144 if !self.config.enabled {
145 return Err(SecurityError::AuthenticationFailed {
146 message: "Authentication is disabled".to_string(),
147 });
148 }
149
150 let validation = Validation::new(Algorithm::HS256);
151 let token_data = decode::<Claims>(token, &self.decoding_key, &validation).map_err(|e| {
152 SecurityError::AuthenticationFailed {
153 message: format!("Invalid JWT token: {e}"),
154 }
155 })?;
156
157 let claims = token_data.claims;
158
159 let mut sessions = self.active_sessions.write().await;
161 if let Some(session) = sessions.get_mut(&claims.jti) {
162 let now = SystemTime::now()
163 .duration_since(UNIX_EPOCH)
164 .unwrap()
165 .as_secs();
166
167 if now - session.last_activity > (self.config.session_timeout_minutes * 60) {
169 sessions.remove(&claims.jti);
170 return Err(SecurityError::AuthenticationFailed {
171 message: "Session expired".to_string(),
172 });
173 }
174
175 session.last_activity = now;
177 } else {
178 return Err(SecurityError::AuthenticationFailed {
179 message: "Session not found".to_string(),
180 });
181 }
182
183 debug!("JWT token validated for user: {}", claims.sub);
184 Ok(claims)
185 }
186
187 pub async fn create_api_key(
189 &self,
190 name: &str,
191 role: &str,
192 permissions: Vec<String>,
193 expires_in_days: Option<u32>,
194 ) -> Result<(String, ApiKey)> {
195 if !self.config.enabled || !self.config.api_key_enabled {
196 return Err(SecurityError::AuthenticationFailed {
197 message: "API key authentication is disabled".to_string(),
198 });
199 }
200
201 let key_id = Uuid::new_v4().to_string();
202 let raw_key = format!("ak_{}", Uuid::new_v4().simple());
203 let key_hash = self.hash_api_key(&raw_key);
204
205 let now = SystemTime::now()
206 .duration_since(UNIX_EPOCH)
207 .unwrap()
208 .as_secs();
209
210 let expires_at = expires_in_days.map(|days| now + (days as u64 * 24 * 60 * 60));
211
212 let api_key = ApiKey {
213 key_id: key_id.clone(),
214 key_hash,
215 name: name.to_string(),
216 role: role.to_string(),
217 permissions,
218 created_at: now,
219 expires_at,
220 last_used: None,
221 active: true,
222 };
223
224 self.api_keys
225 .write()
226 .await
227 .insert(key_id.clone(), api_key.clone());
228
229 info!("API key created: {} for role: {}", name, role);
230 Ok((raw_key, api_key))
231 }
232
233 pub async fn validate_api_key(&self, key: &str) -> Result<ApiKey> {
235 if !self.config.enabled || !self.config.api_key_enabled {
236 return Err(SecurityError::AuthenticationFailed {
237 message: "API key authentication is disabled".to_string(),
238 });
239 }
240
241 let key_hash = self.hash_api_key(key);
242 let mut api_keys = self.api_keys.write().await;
243
244 for (_, api_key) in api_keys.iter_mut() {
245 if api_key.key_hash == key_hash && api_key.active {
246 let now = SystemTime::now()
247 .duration_since(UNIX_EPOCH)
248 .unwrap()
249 .as_secs();
250
251 if let Some(expires_at) = api_key.expires_at {
253 if now > expires_at {
254 return Err(SecurityError::AuthenticationFailed {
255 message: "API key expired".to_string(),
256 });
257 }
258 }
259
260 api_key.last_used = Some(now);
262
263 debug!("API key validated: {}", api_key.name);
264 return Ok(api_key.clone());
265 }
266 }
267
268 Err(SecurityError::AuthenticationFailed {
269 message: "Invalid API key".to_string(),
270 })
271 }
272
273 pub async fn revoke_api_key(&self, key_id: &str) -> Result<()> {
275 let mut api_keys = self.api_keys.write().await;
276
277 if let Some(api_key) = api_keys.get_mut(key_id) {
278 api_key.active = false;
279 info!("API key revoked: {}", api_key.name);
280 Ok(())
281 } else {
282 Err(SecurityError::AuthenticationFailed {
283 message: "API key not found".to_string(),
284 })
285 }
286 }
287
288 pub async fn get_active_sessions(&self) -> Vec<UserSession> {
290 let sessions = self.active_sessions.read().await;
291 sessions.values().cloned().collect()
292 }
293
294 pub async fn revoke_session(&self, session_id: &str) -> Result<()> {
296 let mut sessions = self.active_sessions.write().await;
297
298 if sessions.remove(session_id).is_some() {
299 info!("Session revoked: {}", session_id);
300 Ok(())
301 } else {
302 Err(SecurityError::AuthenticationFailed {
303 message: "Session not found".to_string(),
304 })
305 }
306 }
307
308 pub async fn cleanup_expired_sessions(&self) -> Result<usize> {
310 let mut sessions = self.active_sessions.write().await;
311 let now = SystemTime::now()
312 .duration_since(UNIX_EPOCH)
313 .unwrap()
314 .as_secs();
315
316 let timeout_seconds = self.config.session_timeout_minutes * 60;
317 let initial_count = sessions.len();
318
319 sessions.retain(|_, session| now - session.last_activity <= timeout_seconds);
320
321 let removed_count = initial_count - sessions.len();
322
323 if removed_count > 0 {
324 info!("Cleaned up {} expired sessions", removed_count);
325 }
326
327 Ok(removed_count)
328 }
329
330 fn hash_api_key(&self, key: &str) -> String {
331 let mut hasher = Sha256::new();
332 hasher.update(key.as_bytes());
333 hasher.update(self.config.jwt_secret.as_bytes()); hex::encode(hasher.finalize())
335 }
336
337 pub fn is_enabled(&self) -> bool {
338 self.config.enabled
339 }
340
341 pub fn is_api_key_enabled(&self) -> bool {
342 self.config.enabled && self.config.api_key_enabled
343 }
344
345 pub fn is_mtls_enabled(&self) -> bool {
346 self.config.enabled && self.config.mtls_enabled
347 }
348}
349
350pub async fn auth_middleware(
352 State(auth_manager): State<Arc<AuthManager>>,
353 headers: HeaderMap,
354 mut request: Request,
355 next: Next,
356) -> std::result::Result<Response, StatusCode> {
357 if !auth_manager.is_enabled() {
358 return Ok(next.run(request).await);
359 }
360
361 if let Some(auth_header) = headers.get(AUTHORIZATION) {
363 if let Ok(auth_str) = auth_header.to_str() {
364 if let Some(token) = auth_str.strip_prefix("Bearer ") {
365 match auth_manager.validate_jwt_token(token).await {
366 Ok(claims) => {
367 request.extensions_mut().insert(claims);
368 return Ok(next.run(request).await);
369 }
370 Err(e) => {
371 debug!("JWT validation failed: {}", e);
372 }
373 }
374 }
375 }
376 }
377
378 if let Some(api_key_header) = headers.get("X-API-Key") {
380 if let Ok(api_key) = api_key_header.to_str() {
381 match auth_manager.validate_api_key(api_key).await {
382 Ok(key_info) => {
383 let claims = Claims {
385 sub: key_info.key_id.clone(),
386 name: key_info.name,
387 role: key_info.role,
388 permissions: key_info.permissions,
389 exp: key_info.expires_at.unwrap_or(u64::MAX),
390 iat: key_info.created_at,
391 jti: key_info.key_id,
392 };
393 request.extensions_mut().insert(claims);
394 return Ok(next.run(request).await);
395 }
396 Err(e) => {
397 debug!("API key validation failed: {}", e);
398 }
399 }
400 }
401 }
402
403 warn!("Authentication failed for request");
405 Err(StatusCode::UNAUTHORIZED)
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_auth_manager_creation() {
        let config = AuthConfig::default();
        let manager = AuthManager::new(config).unwrap();
        assert!(!manager.is_enabled());
    }

    #[tokio::test]
    async fn test_jwt_token_disabled() {
        let config = AuthConfig::default();
        let manager = AuthManager::new(config).unwrap();

        let result = manager
            .create_jwt_token("user1", "Test User", "user", vec!["read".to_string()])
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_jwt_token_creation_and_validation() {
        let mut config = AuthConfig::default();
        config.enabled = true;
        config.jwt_secret = "test-secret".to_string();

        let manager = AuthManager::new(config).unwrap();

        let token = manager
            .create_jwt_token(
                "user1",
                "Test User",
                "admin",
                vec!["read".to_string(), "write".to_string()],
            )
            .await
            .unwrap();

        let claims = manager.validate_jwt_token(&token).await.unwrap();
        assert_eq!(claims.sub, "user1");
        assert_eq!(claims.name, "Test User");
        assert_eq!(claims.role, "admin");
        assert_eq!(
            claims.permissions,
            vec!["read".to_string(), "write".to_string()]
        );
    }

    #[tokio::test]
    async fn test_api_key_creation_and_validation() {
        let mut config = AuthConfig::default();
        config.enabled = true;
        config.api_key_enabled = true;

        let manager = AuthManager::new(config).unwrap();

        let (raw_key, api_key) = manager
            .create_api_key("test-key", "user", vec!["read".to_string()], Some(30))
            .await
            .unwrap();
        assert!(!raw_key.is_empty());
        assert_eq!(api_key.name, "test-key");
        assert_eq!(api_key.role, "user");

        let validated_key = manager.validate_api_key(&raw_key).await.unwrap();
        assert_eq!(validated_key.name, "test-key");
        assert_eq!(validated_key.role, "user");
    }

    #[tokio::test]
    async fn test_unknown_api_key_rejected() {
        let mut config = AuthConfig::default();
        config.enabled = true;
        config.api_key_enabled = true;

        let manager = AuthManager::new(config).unwrap();

        assert!(manager.validate_api_key("ak_does_not_exist").await.is_err());
    }

    #[tokio::test]
    async fn test_api_key_disabled() {
        let mut config = AuthConfig::default();
        config.enabled = true;
        config.api_key_enabled = false;

        let manager = AuthManager::new(config).unwrap();

        assert!(!manager.is_api_key_enabled());
        assert!(manager
            .create_api_key("test-key", "user", vec![], None)
            .await
            .is_err());
        assert!(manager.validate_api_key("ak_anything").await.is_err());
    }

    #[tokio::test]
    async fn test_invalid_jwt_token() {
        let mut config = AuthConfig::default();
        config.enabled = true;

        let manager = AuthManager::new(config).unwrap();

        let result = manager.validate_jwt_token("invalid.jwt.token").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_session_cleanup() {
        let mut config = AuthConfig::default();
        config.enabled = true;
        config.jwt_secret = "test-secret-key-for-unit-testing-with-sufficient-length".to_string();
        config.session_timeout_minutes = 1;

        let manager = AuthManager::new(config).unwrap();

        let token = manager
            .create_jwt_token("user1", "Test User", "user", vec!["read".to_string()])
            .await
            .unwrap();

        // Backdate every session past the 1-minute idle timeout.
        {
            let mut sessions = manager.active_sessions.write().await;
            for (_, session) in sessions.iter_mut() {
                session.last_activity = SystemTime::now()
                    .duration_since(UNIX_EPOCH)
                    .unwrap()
                    .as_secs()
                    - 120;
            }
        }

        let removed = manager.cleanup_expired_sessions().await.unwrap();
        assert_eq!(removed, 1, "Should have removed 1 expired session");

        let validation_result = manager.validate_jwt_token(&token).await;
        assert!(
            validation_result.is_err(),
            "Token whose session was cleaned up must no longer validate"
        );
    }

    #[tokio::test]
    async fn test_session_revocation() {
        let mut config = AuthConfig::default();
        config.enabled = true;
        config.jwt_secret = "test-secret".to_string();

        let manager = AuthManager::new(config).unwrap();

        let token = manager
            .create_jwt_token("user1", "Test User", "user", vec![])
            .await
            .unwrap();
        let claims = manager.validate_jwt_token(&token).await.unwrap();
        assert_eq!(manager.get_active_sessions().await.len(), 1);

        manager.revoke_session(&claims.jti).await.unwrap();

        assert!(manager.get_active_sessions().await.is_empty());
        assert!(manager.validate_jwt_token(&token).await.is_err());
        assert!(manager.revoke_session(&claims.jti).await.is_err());
    }

    #[tokio::test]
    async fn test_api_key_revocation() {
        let mut config = AuthConfig::default();
        config.enabled = true;
        config.api_key_enabled = true;

        let manager = AuthManager::new(config).unwrap();

        let (raw_key, api_key) = manager
            .create_api_key("test-key", "user", vec!["read".to_string()], None)
            .await
            .unwrap();

        assert!(manager.validate_api_key(&raw_key).await.is_ok());

        manager.revoke_api_key(&api_key.key_id).await.unwrap();

        assert!(manager.validate_api_key(&raw_key).await.is_err());
    }
}
555}