kotoba_security/
session.rs

//! Session management for authentication: server-side session storage, session lifecycle handling, and session-cookie helpers.
2
3use crate::error::{SecurityError, Result};
4use crate::config::{SessionConfig, SessionStoreType, SameSitePolicy};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10/// Session data structure
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct SessionData {
13    pub session_id: String,
14    pub user_id: String,
15    pub roles: Vec<String>,
16    pub permissions: Vec<String>,
17    pub attributes: HashMap<String, serde_json::Value>,
18    pub created_at: chrono::DateTime<chrono::Utc>,
19    pub expires_at: chrono::DateTime<chrono::Utc>,
20    pub last_accessed_at: chrono::DateTime<chrono::Utc>,
21    pub ip_address: Option<String>,
22    pub user_agent: Option<String>,
23}
24
25impl SessionData {
26    /// Create new session data
27    pub fn new(
28        session_id: String,
29        user_id: String,
30        roles: Vec<String>,
31        permissions: Vec<String>,
32        max_age_seconds: Option<u64>,
33        ip_address: Option<String>,
34        user_agent: Option<String>,
35    ) -> Self {
36        let now = chrono::Utc::now();
37        let expires_at = max_age_seconds
38            .map(|secs| now + chrono::Duration::seconds(secs as i64))
39            .unwrap_or_else(|| now + chrono::Duration::hours(24)); // Default 24 hours
40
41        Self {
42            session_id,
43            user_id,
44            roles,
45            permissions,
46            attributes: HashMap::new(),
47            created_at: now,
48            expires_at,
49            last_accessed_at: now,
50            ip_address,
51            user_agent,
52        }
53    }
54
55    /// Check if session is expired
56    pub fn is_expired(&self) -> bool {
57        chrono::Utc::now() > self.expires_at
58    }
59
60    /// Update last accessed time
61    pub fn touch(&mut self) {
62        self.last_accessed_at = chrono::Utc::now();
63    }
64
65    /// Extend session expiration
66    pub fn extend(&mut self, additional_seconds: i64) {
67        self.expires_at = chrono::Utc::now() + chrono::Duration::seconds(additional_seconds);
68    }
69
70    /// Get remaining time until expiration in seconds
71    pub fn time_until_expiry(&self) -> i64 {
72        let now = chrono::Utc::now();
73        (self.expires_at - now).num_seconds()
74    }
75
76    /// Check if user has specific role
77    pub fn has_role(&self, role: &str) -> bool {
78        self.roles.contains(&role.to_string())
79    }
80
81    /// Check if user has specific permission
82    pub fn has_permission(&self, permission: &str) -> bool {
83        self.permissions.contains(&permission.to_string())
84    }
85
86    /// Add custom attribute
87    pub fn set_attribute(&mut self, key: String, value: serde_json::Value) {
88        self.attributes.insert(key, value);
89    }
90
91    /// Get custom attribute
92    pub fn get_attribute(&self, key: &str) -> Option<&serde_json::Value> {
93        self.attributes.get(key)
94    }
95
96    /// Remove custom attribute
97    pub fn remove_attribute(&mut self, key: &str) -> Option<serde_json::Value> {
98        self.attributes.remove(key)
99    }
100}
101
102/// Session store trait for different storage backends
103#[async_trait::async_trait]
104pub trait SessionStore: Send + Sync {
105    /// Store session data
106    async fn store(&self, session: SessionData) -> Result<()>;
107
108    /// Retrieve session data by ID
109    async fn get(&self, session_id: &str) -> Result<Option<SessionData>>;
110
111    /// Update session data
112    async fn update(&self, session: SessionData) -> Result<()>;
113
114    /// Delete session by ID
115    async fn delete(&self, session_id: &str) -> Result<()>;
116
117    /// Delete all sessions for a user
118    async fn delete_user_sessions(&self, user_id: &str) -> Result<usize>;
119
120    /// Clean up expired sessions
121    async fn cleanup_expired(&self) -> Result<usize>;
122
123    /// Get session count
124    async fn count(&self) -> Result<usize>;
125}
126
127/// In-memory session store
128pub struct MemorySessionStore {
129    sessions: Arc<RwLock<HashMap<String, SessionData>>>,
130}
131
132impl MemorySessionStore {
133    pub fn new() -> Self {
134        Self {
135            sessions: Arc::new(RwLock::new(HashMap::new())),
136        }
137    }
138}
139
140#[async_trait::async_trait]
141impl SessionStore for MemorySessionStore {
142    async fn store(&self, session: SessionData) -> Result<()> {
143        let mut sessions = self.sessions.write().await;
144        sessions.insert(session.session_id.clone(), session);
145        Ok(())
146    }
147
148    async fn get(&self, session_id: &str) -> Result<Option<SessionData>> {
149        let sessions = self.sessions.read().await;
150        Ok(sessions.get(session_id).cloned())
151    }
152
153    async fn update(&self, session: SessionData) -> Result<()> {
154        let mut sessions = self.sessions.write().await;
155        sessions.insert(session.session_id.clone(), session);
156        Ok(())
157    }
158
159    async fn delete(&self, session_id: &str) -> Result<()> {
160        let mut sessions = self.sessions.write().await;
161        sessions.remove(session_id);
162        Ok(())
163    }
164
165    async fn delete_user_sessions(&self, user_id: &str) -> Result<usize> {
166        let mut sessions = self.sessions.write().await;
167        let keys_to_remove: Vec<String> = sessions
168            .iter()
169            .filter(|(_, session)| session.user_id == user_id)
170            .map(|(key, _)| key.clone())
171            .collect();
172
173        let count = keys_to_remove.len();
174        for key in keys_to_remove {
175            sessions.remove(&key);
176        }
177
178        Ok(count)
179    }
180
181    async fn cleanup_expired(&self) -> Result<usize> {
182        let mut sessions = self.sessions.write().await;
183        let expired_keys: Vec<String> = sessions
184            .iter()
185            .filter(|(_, session)| session.is_expired())
186            .map(|(key, _)| key.clone())
187            .collect();
188
189        let count = expired_keys.len();
190        for key in expired_keys {
191            sessions.remove(&key);
192        }
193
194        Ok(count)
195    }
196
197    async fn count(&self) -> Result<usize> {
198        let sessions = self.sessions.read().await;
199        Ok(sessions.len())
200    }
201}
202
203/// Session manager for handling session lifecycle
204pub struct SessionManager {
205    config: SessionConfig,
206    store: Box<dyn SessionStore>,
207}
208
209impl SessionManager {
210    /// Create new session manager
211    pub fn new(config: SessionConfig) -> Self {
212        let store: Box<dyn SessionStore> = match config.store_type {
213            SessionStoreType::Memory => Box::new(MemorySessionStore::new()),
214            SessionStoreType::Redis => {
215                // TODO: Implement Redis session store
216                Box::new(MemorySessionStore::new())
217            }
218            SessionStoreType::Database => {
219                // TODO: Implement database session store
220                Box::new(MemorySessionStore::new())
221            }
222        };
223
224        Self { config, store }
225    }
226
227    /// Create new session
228    pub async fn create_session(
229        &self,
230        user_id: &str,
231        roles: Vec<String>,
232        permissions: Vec<String>,
233        ip_address: Option<String>,
234        user_agent: Option<String>,
235    ) -> Result<SessionData> {
236        let session_id = self.generate_session_id();
237        let session = SessionData::new(
238            session_id,
239            user_id.to_string(),
240            roles,
241            permissions,
242            self.config.max_age_seconds,
243            ip_address,
244            user_agent,
245        );
246
247        self.store.store(session.clone()).await?;
248        Ok(session)
249    }
250
251    /// Get session by ID
252    pub async fn get_session(&self, session_id: &str) -> Result<Option<SessionData>> {
253        let mut session = self.store.get(session_id).await?;
254
255        if let Some(ref mut session) = session {
256            // Check if session is expired
257            if session.is_expired() {
258                // Clean up expired session
259                self.store.delete(session_id).await?;
260                return Ok(None);
261            }
262
263            // Update last accessed time
264            session.touch();
265            self.store.update(session.clone()).await?;
266        }
267
268        Ok(session)
269    }
270
271    /// Update session
272    pub async fn update_session(&self, session: SessionData) -> Result<()> {
273        if session.is_expired() {
274            return Err(SecurityError::SessionExpired);
275        }
276
277        self.store.update(session).await
278    }
279
280    /// Delete session
281    pub async fn delete_session(&self, session_id: &str) -> Result<()> {
282        self.store.delete(session_id).await
283    }
284
285    /// Delete all sessions for a user
286    pub async fn delete_user_sessions(&self, user_id: &str) -> Result<usize> {
287        self.store.delete_user_sessions(user_id).await
288    }
289
290    /// Extend session expiration
291    pub async fn extend_session(&self, session_id: &str, additional_seconds: i64) -> Result<()> {
292        let mut session = self.store.get(session_id).await?
293            .ok_or_else(|| SecurityError::SessionInvalid)?;
294
295        if session.is_expired() {
296            return Err(SecurityError::SessionExpired);
297        }
298
299        session.extend(additional_seconds);
300        self.store.update(session).await
301    }
302
303    /// Validate session and return user information
304    pub async fn validate_session(&self, session_id: &str) -> Result<Option<SessionData>> {
305        self.get_session(session_id).await
306    }
307
308    /// Clean up expired sessions
309    pub async fn cleanup_expired_sessions(&self) -> Result<usize> {
310        self.store.cleanup_expired().await
311    }
312
313    /// Get session statistics
314    pub async fn get_stats(&self) -> Result<SessionStats> {
315        let count = self.store.count().await?;
316        Ok(SessionStats { total_sessions: count })
317    }
318
319    /// Generate unique session ID
320    fn generate_session_id(&self) -> String {
321        use uuid::Uuid;
322        Uuid::new_v4().to_string()
323    }
324}
325
326/// Session statistics
327#[derive(Debug, Clone, Serialize, Deserialize)]
328pub struct SessionStats {
329    pub total_sessions: usize,
330}
331
332/// Cookie configuration for session management
333#[derive(Debug, Clone)]
334pub struct CookieConfig {
335    pub name: String,
336    pub secure: bool,
337    pub http_only: bool,
338    pub same_site: SameSitePolicy,
339    pub domain: Option<String>,
340    pub path: Option<String>,
341}
342
343impl Default for CookieConfig {
344    fn default() -> Self {
345        Self {
346            name: "session_id".to_string(),
347            secure: true,
348            http_only: true,
349            same_site: SameSitePolicy::Lax,
350            domain: None,
351            path: Some("/".to_string()),
352        }
353    }
354}
355
356impl From<&SessionConfig> for CookieConfig {
357    fn from(config: &SessionConfig) -> Self {
358        Self {
359            name: config.cookie_name.clone(),
360            secure: config.cookie_secure,
361            http_only: config.cookie_http_only,
362            same_site: config.cookie_same_site.clone(),
363            domain: None,
364            path: Some("/".to_string()),
365        }
366    }
367}
368
/// Namespace for helpers that build and parse session-cookie header values.
/// Carries no state; all functionality is in associated functions.
pub struct SessionCookie;
371
372impl SessionCookie {
373    /// Generate session cookie header value
374    pub fn generate_cookie_header(session_id: &str, config: &CookieConfig, max_age: Option<u64>) -> String {
375        let mut cookie = format!("{}={}", config.name, session_id);
376
377        if config.http_only {
378            cookie.push_str("; HttpOnly");
379        }
380
381        if config.secure {
382            cookie.push_str("; Secure");
383        }
384
385        match config.same_site {
386            SameSitePolicy::Strict => cookie.push_str("; SameSite=Strict"),
387            SameSitePolicy::Lax => cookie.push_str("; SameSite=Lax"),
388            SameSitePolicy::None => cookie.push_str("; SameSite=None"),
389        }
390
391        if let Some(domain) = &config.domain {
392            cookie.push_str(&format!("; Domain={}", domain));
393        }
394
395        if let Some(path) = &config.path {
396            cookie.push_str(&format!("; Path={}", path));
397        }
398
399        if let Some(max_age) = max_age {
400            cookie.push_str(&format!("; Max-Age={}", max_age));
401        }
402
403        cookie
404    }
405
406    /// Parse session ID from cookie header
407    pub fn parse_session_id(cookie_header: &str, cookie_name: &str) -> Option<String> {
408        for cookie in cookie_header.split(';') {
409            let cookie = cookie.trim();
410            if let Some(value) = cookie.strip_prefix(&format!("{}=", cookie_name)) {
411                return Some(value.to_string());
412            }
413        }
414        None
415    }
416
417    /// Generate delete cookie header
418    pub fn generate_delete_cookie_header(config: &CookieConfig) -> String {
419        let mut cookie = format!("{}=; Max-Age=0", config.name);
420
421        if config.http_only {
422            cookie.push_str("; HttpOnly");
423        }
424
425        if config.secure {
426            cookie.push_str("; Secure");
427        }
428
429        match config.same_site {
430            SameSitePolicy::Strict => cookie.push_str("; SameSite=Strict"),
431            SameSitePolicy::Lax => cookie.push_str("; SameSite=Lax"),
432            SameSitePolicy::None => cookie.push_str("; SameSite=None"),
433        }
434
435        if let Some(domain) = &config.domain {
436            cookie.push_str(&format!("; Domain={}", domain));
437        }
438
439        if let Some(path) = &config.path {
440            cookie.push_str(&format!("; Path={}", path));
441        }
442
443        cookie
444    }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager from `SessionConfig::default()`.
    fn create_test_manager() -> SessionManager {
        SessionManager::new(SessionConfig::default())
    }

    #[tokio::test]
    async fn test_session_creation() {
        let manager = create_test_manager();

        let session = manager
            .create_session(
                "user123",
                vec!["admin".to_string()],
                vec!["read".to_string(), "write".to_string()],
                Some("127.0.0.1".to_string()),
                Some("Test Browser".to_string()),
            )
            .await
            .unwrap();

        assert_eq!(session.user_id, "user123");
        assert!(session.has_role("admin"));
        assert!(session.has_permission("read"));
        assert!(!session.is_expired());
        assert!(session.time_until_expiry() > 0);
    }

    #[tokio::test]
    async fn test_session_retrieval() {
        let manager = create_test_manager();

        let session = manager
            .create_session("user123", vec!["user".to_string()], vec![], None, None)
            .await
            .unwrap();

        let retrieved = manager.get_session(&session.session_id).await.unwrap().unwrap();
        assert_eq!(retrieved.user_id, "user123");
        assert_eq!(retrieved.session_id, session.session_id);
    }

    #[tokio::test]
    async fn test_session_update() {
        let manager = create_test_manager();

        let mut session = manager
            .create_session("user123", vec!["user".to_string()], vec![], None, None)
            .await
            .unwrap();

        session.set_attribute("theme".to_string(), serde_json::Value::String("dark".to_string()));
        manager.update_session(session.clone()).await.unwrap();

        let updated = manager.get_session(&session.session_id).await.unwrap().unwrap();
        assert_eq!(updated.get_attribute("theme").unwrap().as_str().unwrap(), "dark");
    }

    #[tokio::test]
    async fn test_session_deletion() {
        let manager = create_test_manager();

        let session = manager
            .create_session("user123", vec!["user".to_string()], vec![], None, None)
            .await
            .unwrap();

        // The session is visible before deletion...
        assert!(manager.get_session(&session.session_id).await.unwrap().is_some());

        manager.delete_session(&session.session_id).await.unwrap();

        // ...and gone afterwards.
        assert!(manager.get_session(&session.session_id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_user_session_deletion() {
        let manager = create_test_manager();

        // Two independent sessions for the same user.
        let session1 = manager
            .create_session("user123", vec!["user".to_string()], vec![], None, None)
            .await
            .unwrap();
        let session2 = manager
            .create_session("user123", vec!["user".to_string()], vec![], None, None)
            .await
            .unwrap();

        let deleted_count = manager.delete_user_sessions("user123").await.unwrap();
        assert_eq!(deleted_count, 2);

        assert!(manager.get_session(&session1.session_id).await.unwrap().is_none());
        assert!(manager.get_session(&session2.session_id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_session_extension() {
        let manager = create_test_manager();

        let session = manager
            .create_session("user123", vec!["user".to_string()], vec![], None, None)
            .await
            .unwrap();

        let original_expiry = session.time_until_expiry();

        // Extend by one hour.
        manager.extend_session(&session.session_id, 3600).await.unwrap();

        let updated = manager.get_session(&session.session_id).await.unwrap().unwrap();
        let new_expiry = updated.time_until_expiry();

        // The extended session must outlive the original one.
        assert!(new_expiry > original_expiry);
    }

    #[test]
    fn test_cookie_header_generation() {
        let config = CookieConfig::default();
        let cookie_header = SessionCookie::generate_cookie_header("session123", &config, Some(3600));

        assert!(cookie_header.contains("session_id=session123"));
        assert!(cookie_header.contains("HttpOnly"));
        assert!(cookie_header.contains("Secure"));
        assert!(cookie_header.contains("SameSite=Lax"));
        assert!(cookie_header.contains("Max-Age=3600"));
    }

    #[test]
    fn test_cookie_header_exact_format() {
        let config = CookieConfig::default();

        assert_eq!(
            SessionCookie::generate_cookie_header("abc", &config, Some(60)),
            "session_id=abc; HttpOnly; Secure; SameSite=Lax; Path=/; Max-Age=60"
        );
        assert_eq!(
            SessionCookie::generate_cookie_header("abc", &config, None),
            "session_id=abc; HttpOnly; Secure; SameSite=Lax; Path=/"
        );
        assert_eq!(
            SessionCookie::generate_delete_cookie_header(&config),
            "session_id=; Max-Age=0; HttpOnly; Secure; SameSite=Lax; Path=/"
        );
    }

    #[test]
    fn test_cookie_parsing() {
        let cookie_header = "session_id=abc123; other=value; session_id=def456";
        let session_id = SessionCookie::parse_session_id(cookie_header, "session_id");

        assert_eq!(session_id, Some("abc123".to_string()));
    }

    #[test]
    fn test_cookie_parsing_requires_exact_name() {
        assert_eq!(SessionCookie::parse_session_id("other=value", "session_id"), None);
        assert_eq!(SessionCookie::parse_session_id("xsession_id=abc", "session_id"), None);
        assert_eq!(SessionCookie::parse_session_id("", "session_id"), None);
        assert_eq!(
            SessionCookie::parse_session_id("a=1; session_id=xyz", "session_id"),
            Some("xyz".to_string())
        );
    }

    #[test]
    fn test_delete_cookie_generation() {
        let config = CookieConfig::default();
        let delete_header = SessionCookie::generate_delete_cookie_header(&config);

        assert!(delete_header.contains("session_id="));
        assert!(delete_header.contains("Max-Age=0"));
        assert!(delete_header.contains("HttpOnly"));
        assert!(delete_header.contains("Secure"));
    }

    #[test]
    fn test_session_attributes() {
        let mut session = SessionData::new(
            "session123".to_string(),
            "user123".to_string(),
            vec!["user".to_string()],
            vec!["read".to_string()],
            Some(3600),
            None,
            None,
        );

        // Role and permission checks.
        assert!(session.has_role("user"));
        assert!(!session.has_role("admin"));
        assert!(session.has_permission("read"));
        assert!(!session.has_permission("write"));

        // Custom attributes.
        session.set_attribute("theme".to_string(), serde_json::Value::String("dark".to_string()));
        session.set_attribute("locale".to_string(), serde_json::Value::String("en".to_string()));

        assert_eq!(session.get_attribute("theme").unwrap().as_str().unwrap(), "dark");
        assert_eq!(session.get_attribute("locale").unwrap().as_str().unwrap(), "en");

        // Attribute removal returns the old value and clears the key.
        let removed = session.remove_attribute("theme");
        assert_eq!(removed.unwrap().as_str().unwrap(), "dark");
        assert!(session.get_attribute("theme").is_none());
    }
}