codex_memory/security/auth.rs

1use crate::security::{AuthConfig, Result, SecurityError};
2use axum::{
3    extract::{Request, State},
4    http::{header::AUTHORIZATION, HeaderMap, StatusCode},
5    middleware::Next,
6    response::Response,
7};
8use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11use std::collections::HashMap;
12use std::sync::Arc;
13use std::time::{SystemTime, UNIX_EPOCH};
14use tokio::sync::RwLock;
15use tracing::{debug, info, warn};
16use uuid::Uuid;
17
18/// JWT Claims structure
19#[derive(Debug, Serialize, Deserialize, Clone)]
20pub struct Claims {
21    pub sub: String,              // Subject (user ID)
22    pub name: String,             // User name
23    pub role: String,             // User role
24    pub permissions: Vec<String>, // User permissions
25    pub exp: u64,                 // Expiration time
26    pub iat: u64,                 // Issued at
27    pub jti: String,              // JWT ID
28}
29
30/// API Key structure
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ApiKey {
33    pub key_id: String,
34    pub key_hash: String,
35    pub name: String,
36    pub role: String,
37    pub permissions: Vec<String>,
38    pub created_at: u64,
39    pub expires_at: Option<u64>,
40    pub last_used: Option<u64>,
41    pub active: bool,
42}
43
44/// User session information
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct UserSession {
47    pub user_id: String,
48    pub name: String,
49    pub role: String,
50    pub permissions: Vec<String>,
51    pub authenticated_at: u64,
52    pub last_activity: u64,
53    pub auth_method: AuthMethod,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub enum AuthMethod {
58    JWT,
59    ApiKey,
60    MTLS,
61}
62
63/// Authentication manager
64pub struct AuthManager {
65    config: AuthConfig,
66    api_keys: Arc<RwLock<HashMap<String, ApiKey>>>,
67    active_sessions: Arc<RwLock<HashMap<String, UserSession>>>,
68    encoding_key: EncodingKey,
69    decoding_key: DecodingKey,
70}
71
72impl AuthManager {
73    pub fn new(config: AuthConfig) -> Result<Self> {
74        let encoding_key = EncodingKey::from_secret(config.jwt_secret.as_bytes());
75        let decoding_key = DecodingKey::from_secret(config.jwt_secret.as_bytes());
76
77        Ok(Self {
78            config,
79            api_keys: Arc::new(RwLock::new(HashMap::new())),
80            active_sessions: Arc::new(RwLock::new(HashMap::new())),
81            encoding_key,
82            decoding_key,
83        })
84    }
85
86    /// Create a new JWT token for a user
87    pub async fn create_jwt_token(
88        &self,
89        user_id: &str,
90        name: &str,
91        role: &str,
92        permissions: Vec<String>,
93    ) -> Result<String> {
94        if !self.config.enabled {
95            return Err(SecurityError::AuthenticationFailed {
96                message: "Authentication is disabled".to_string(),
97            });
98        }
99
100        let now = SystemTime::now()
101            .duration_since(UNIX_EPOCH)
102            .unwrap()
103            .as_secs();
104
105        let claims = Claims {
106            sub: user_id.to_string(),
107            name: name.to_string(),
108            role: role.to_string(),
109            permissions,
110            exp: now + self.config.jwt_expiry_seconds,
111            iat: now,
112            jti: Uuid::new_v4().to_string(),
113        };
114
115        let header = Header::new(Algorithm::HS256);
116        let token = encode(&header, &claims, &self.encoding_key).map_err(|e| {
117            SecurityError::AuthenticationFailed {
118                message: format!("Failed to create JWT token: {e}"),
119            }
120        })?;
121
122        // Store session
123        let session = UserSession {
124            user_id: user_id.to_string(),
125            name: name.to_string(),
126            role: role.to_string(),
127            permissions: claims.permissions.clone(),
128            authenticated_at: now,
129            last_activity: now,
130            auth_method: AuthMethod::JWT,
131        };
132
133        self.active_sessions
134            .write()
135            .await
136            .insert(claims.jti.clone(), session);
137
138        info!("JWT token created for user: {} ({})", name, user_id);
139        Ok(token)
140    }
141
142    /// Validate and decode JWT token
143    pub async fn validate_jwt_token(&self, token: &str) -> Result<Claims> {
144        if !self.config.enabled {
145            return Err(SecurityError::AuthenticationFailed {
146                message: "Authentication is disabled".to_string(),
147            });
148        }
149
150        let validation = Validation::new(Algorithm::HS256);
151        let token_data = decode::<Claims>(token, &self.decoding_key, &validation).map_err(|e| {
152            SecurityError::AuthenticationFailed {
153                message: format!("Invalid JWT token: {e}"),
154            }
155        })?;
156
157        let claims = token_data.claims;
158
159        // Check if session is still active
160        let mut sessions = self.active_sessions.write().await;
161        if let Some(session) = sessions.get_mut(&claims.jti) {
162            let now = SystemTime::now()
163                .duration_since(UNIX_EPOCH)
164                .unwrap()
165                .as_secs();
166
167            // Check session timeout
168            if now - session.last_activity > (self.config.session_timeout_minutes * 60) {
169                sessions.remove(&claims.jti);
170                return Err(SecurityError::AuthenticationFailed {
171                    message: "Session expired".to_string(),
172                });
173            }
174
175            // Update last activity
176            session.last_activity = now;
177        } else {
178            return Err(SecurityError::AuthenticationFailed {
179                message: "Session not found".to_string(),
180            });
181        }
182
183        debug!("JWT token validated for user: {}", claims.sub);
184        Ok(claims)
185    }
186
187    /// Create a new API key
188    pub async fn create_api_key(
189        &self,
190        name: &str,
191        role: &str,
192        permissions: Vec<String>,
193        expires_in_days: Option<u32>,
194    ) -> Result<(String, ApiKey)> {
195        if !self.config.enabled || !self.config.api_key_enabled {
196            return Err(SecurityError::AuthenticationFailed {
197                message: "API key authentication is disabled".to_string(),
198            });
199        }
200
201        let key_id = Uuid::new_v4().to_string();
202        let raw_key = format!("ak_{}", Uuid::new_v4().simple());
203        let key_hash = self.hash_api_key(&raw_key);
204
205        let now = SystemTime::now()
206            .duration_since(UNIX_EPOCH)
207            .unwrap()
208            .as_secs();
209
210        let expires_at = expires_in_days.map(|days| now + (days as u64 * 24 * 60 * 60));
211
212        let api_key = ApiKey {
213            key_id: key_id.clone(),
214            key_hash,
215            name: name.to_string(),
216            role: role.to_string(),
217            permissions,
218            created_at: now,
219            expires_at,
220            last_used: None,
221            active: true,
222        };
223
224        self.api_keys
225            .write()
226            .await
227            .insert(key_id.clone(), api_key.clone());
228
229        info!("API key created: {} for role: {}", name, role);
230        Ok((raw_key, api_key))
231    }
232
233    /// Validate API key
234    pub async fn validate_api_key(&self, key: &str) -> Result<ApiKey> {
235        if !self.config.enabled || !self.config.api_key_enabled {
236            return Err(SecurityError::AuthenticationFailed {
237                message: "API key authentication is disabled".to_string(),
238            });
239        }
240
241        let key_hash = self.hash_api_key(key);
242        let mut api_keys = self.api_keys.write().await;
243
244        for (_, api_key) in api_keys.iter_mut() {
245            if api_key.key_hash == key_hash && api_key.active {
246                let now = SystemTime::now()
247                    .duration_since(UNIX_EPOCH)
248                    .unwrap()
249                    .as_secs();
250
251                // Check expiration
252                if let Some(expires_at) = api_key.expires_at {
253                    if now > expires_at {
254                        return Err(SecurityError::AuthenticationFailed {
255                            message: "API key expired".to_string(),
256                        });
257                    }
258                }
259
260                // Update last used
261                api_key.last_used = Some(now);
262
263                debug!("API key validated: {}", api_key.name);
264                return Ok(api_key.clone());
265            }
266        }
267
268        Err(SecurityError::AuthenticationFailed {
269            message: "Invalid API key".to_string(),
270        })
271    }
272
273    /// Revoke API key
274    pub async fn revoke_api_key(&self, key_id: &str) -> Result<()> {
275        let mut api_keys = self.api_keys.write().await;
276
277        if let Some(api_key) = api_keys.get_mut(key_id) {
278            api_key.active = false;
279            info!("API key revoked: {}", api_key.name);
280            Ok(())
281        } else {
282            Err(SecurityError::AuthenticationFailed {
283                message: "API key not found".to_string(),
284            })
285        }
286    }
287
288    /// Get active sessions
289    pub async fn get_active_sessions(&self) -> Vec<UserSession> {
290        let sessions = self.active_sessions.read().await;
291        sessions.values().cloned().collect()
292    }
293
294    /// Revoke user session
295    pub async fn revoke_session(&self, session_id: &str) -> Result<()> {
296        let mut sessions = self.active_sessions.write().await;
297
298        if sessions.remove(session_id).is_some() {
299            info!("Session revoked: {}", session_id);
300            Ok(())
301        } else {
302            Err(SecurityError::AuthenticationFailed {
303                message: "Session not found".to_string(),
304            })
305        }
306    }
307
308    /// Clean up expired sessions
309    pub async fn cleanup_expired_sessions(&self) -> Result<usize> {
310        let mut sessions = self.active_sessions.write().await;
311        let now = SystemTime::now()
312            .duration_since(UNIX_EPOCH)
313            .unwrap()
314            .as_secs();
315
316        let timeout_seconds = self.config.session_timeout_minutes * 60;
317        let initial_count = sessions.len();
318
319        sessions.retain(|_, session| now - session.last_activity <= timeout_seconds);
320
321        let removed_count = initial_count - sessions.len();
322
323        if removed_count > 0 {
324            info!("Cleaned up {} expired sessions", removed_count);
325        }
326
327        Ok(removed_count)
328    }
329
330    fn hash_api_key(&self, key: &str) -> String {
331        let mut hasher = Sha256::new();
332        hasher.update(key.as_bytes());
333        hasher.update(self.config.jwt_secret.as_bytes()); // Use JWT secret as salt
334        hex::encode(hasher.finalize())
335    }
336
337    pub fn is_enabled(&self) -> bool {
338        self.config.enabled
339    }
340
341    pub fn is_api_key_enabled(&self) -> bool {
342        self.config.enabled && self.config.api_key_enabled
343    }
344
345    pub fn is_mtls_enabled(&self) -> bool {
346        self.config.enabled && self.config.mtls_enabled
347    }
348}
349
350/// Authentication middleware for Axum
351pub async fn auth_middleware(
352    State(auth_manager): State<Arc<AuthManager>>,
353    headers: HeaderMap,
354    mut request: Request,
355    next: Next,
356) -> std::result::Result<Response, StatusCode> {
357    if !auth_manager.is_enabled() {
358        return Ok(next.run(request).await);
359    }
360
361    // Try JWT authentication first
362    if let Some(auth_header) = headers.get(AUTHORIZATION) {
363        if let Ok(auth_str) = auth_header.to_str() {
364            if let Some(token) = auth_str.strip_prefix("Bearer ") {
365                match auth_manager.validate_jwt_token(token).await {
366                    Ok(claims) => {
367                        request.extensions_mut().insert(claims);
368                        return Ok(next.run(request).await);
369                    }
370                    Err(e) => {
371                        debug!("JWT validation failed: {}", e);
372                    }
373                }
374            }
375        }
376    }
377
378    // Try API key authentication
379    if let Some(api_key_header) = headers.get("X-API-Key") {
380        if let Ok(api_key) = api_key_header.to_str() {
381            match auth_manager.validate_api_key(api_key).await {
382                Ok(key_info) => {
383                    // Convert API key to claims-like structure
384                    let claims = Claims {
385                        sub: key_info.key_id.clone(),
386                        name: key_info.name,
387                        role: key_info.role,
388                        permissions: key_info.permissions,
389                        exp: key_info.expires_at.unwrap_or(u64::MAX),
390                        iat: key_info.created_at,
391                        jti: key_info.key_id,
392                    };
393                    request.extensions_mut().insert(claims);
394                    return Ok(next.run(request).await);
395                }
396                Err(e) => {
397                    debug!("API key validation failed: {}", e);
398                }
399            }
400        }
401    }
402
403    // Authentication failed
404    warn!("Authentication failed for request");
405    Err(StatusCode::UNAUTHORIZED)
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_auth_manager_creation() {
        let config = AuthConfig::default();
        let manager = AuthManager::new(config).unwrap();
        assert!(!manager.is_enabled());
    }

    #[tokio::test]
    async fn test_jwt_token_disabled() {
        let config = AuthConfig::default();
        let manager = AuthManager::new(config).unwrap();

        let result = manager
            .create_jwt_token("user1", "Test User", "user", vec!["read".to_string()])
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_api_key_disabled() {
        // Default config has authentication disabled, so key creation must be refused.
        let manager = AuthManager::new(AuthConfig::default()).unwrap();

        let result = manager
            .create_api_key("test-key", "user", vec!["read".to_string()], None)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_jwt_token_creation_and_validation() {
        let mut config = AuthConfig::default();
        config.enabled = true;
        config.jwt_secret = "test-secret".to_string();

        let manager = AuthManager::new(config).unwrap();

        let token = manager
            .create_jwt_token(
                "user1",
                "Test User",
                "admin",
                vec!["read".to_string(), "write".to_string()],
            )
            .await
            .unwrap();

        let claims = manager.validate_jwt_token(&token).await.unwrap();
        assert_eq!(claims.sub, "user1");
        assert_eq!(claims.name, "Test User");
        assert_eq!(claims.role, "admin");
        assert_eq!(
            claims.permissions,
            vec!["read".to_string(), "write".to_string()]
        );
    }

    #[tokio::test]
    async fn test_api_key_creation_and_validation() {
        let mut config = AuthConfig::default();
        config.enabled = true;
        config.api_key_enabled = true;

        let manager = AuthManager::new(config).unwrap();

        let (raw_key, api_key) = manager
            .create_api_key("test-key", "user", vec!["read".to_string()], Some(30))
            .await
            .unwrap();
        assert!(!raw_key.is_empty());
        assert_eq!(api_key.name, "test-key");
        assert_eq!(api_key.role, "user");

        let validated_key = manager.validate_api_key(&raw_key).await.unwrap();
        assert_eq!(validated_key.name, "test-key");
        assert_eq!(validated_key.role, "user");
    }

    #[tokio::test]
    async fn test_invalid_jwt_token() {
        let mut config = AuthConfig::default();
        config.enabled = true;

        let manager = AuthManager::new(config).unwrap();

        let result = manager.validate_jwt_token("invalid.jwt.token").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_session_cleanup() {
        let mut config = AuthConfig::default();
        config.enabled = true;
        config.jwt_secret = "test-secret-key-for-unit-testing-with-sufficient-length".to_string();
        config.session_timeout_minutes = 1; // 1 minute timeout

        let manager = AuthManager::new(config).unwrap();

        // Issuing a token also creates its session.
        let token = manager
            .create_jwt_token("user1", "Test User", "user", vec!["read".to_string()])
            .await
            .unwrap();

        // Age every session past the 1-minute idle timeout.
        {
            let mut sessions = manager.active_sessions.write().await;
            let two_minutes_ago = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs()
                - 120;
            for session in sessions.values_mut() {
                session.last_activity = two_minutes_ago;
            }
        }

        let removed = manager.cleanup_expired_sessions().await.unwrap();
        assert_eq!(removed, 1, "Should have removed 1 expired session");

        // validate_jwt_token requires a live session entry. Once cleanup has
        // evicted it, the still cryptographically valid token is rejected.
        assert!(manager.validate_jwt_token(&token).await.is_err());
    }

    #[tokio::test]
    async fn test_revoke_session() {
        let mut config = AuthConfig::default();
        config.enabled = true;
        config.jwt_secret = "test-secret".to_string();

        let manager = AuthManager::new(config).unwrap();

        let token = manager
            .create_jwt_token("user1", "Test User", "user", vec!["read".to_string()])
            .await
            .unwrap();
        let claims = manager.validate_jwt_token(&token).await.unwrap();

        // Sessions are keyed by the token's jti.
        manager.revoke_session(&claims.jti).await.unwrap();
        assert!(manager.validate_jwt_token(&token).await.is_err());

        // A second revocation finds nothing to remove.
        assert!(manager.revoke_session(&claims.jti).await.is_err());
    }

    #[tokio::test]
    async fn test_api_key_revocation() {
        let mut config = AuthConfig::default();
        config.enabled = true;
        config.api_key_enabled = true;

        let manager = AuthManager::new(config).unwrap();

        let (raw_key, api_key) = manager
            .create_api_key("test-key", "user", vec!["read".to_string()], None)
            .await
            .unwrap();

        assert!(manager.validate_api_key(&raw_key).await.is_ok());

        manager.revoke_api_key(&api_key.key_id).await.unwrap();

        assert!(manager.validate_api_key(&raw_key).await.is_err());
    }
}