Skip to main content

heliosdb_proxy/auth/
session.rs

1//! Session Management
2//!
3//! Manages authenticated sessions with token generation, validation, and lifecycle.
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use parking_lot::RwLock;
10use thiserror::Error;
11
12use super::config::{Identity, SessionConfig};
13
14/// Session errors
15#[derive(Debug, Error)]
16pub enum SessionError {
17    #[error("Session not found")]
18    NotFound,
19
20    #[error("Session expired")]
21    Expired,
22
23    #[error("Session invalidated")]
24    Invalidated,
25
26    #[error("Session limit exceeded")]
27    LimitExceeded,
28
29    #[error("Token generation failed")]
30    TokenGenerationFailed,
31
32    #[error("Invalid token format")]
33    InvalidTokenFormat,
34}
35
36/// Session information
37#[derive(Debug, Clone)]
38pub struct Session {
39    /// Session ID
40    pub id: String,
41
42    /// Session token
43    pub token: String,
44
45    /// Associated identity
46    pub identity: Identity,
47
48    /// Creation time
49    pub created_at: chrono::DateTime<chrono::Utc>,
50
51    /// Last activity time
52    pub last_activity: chrono::DateTime<chrono::Utc>,
53
54    /// Expiration time
55    pub expires_at: chrono::DateTime<chrono::Utc>,
56
57    /// Absolute expiration (max session lifetime)
58    pub absolute_expires_at: chrono::DateTime<chrono::Utc>,
59
60    /// Client IP address
61    pub client_ip: Option<std::net::IpAddr>,
62
63    /// User agent
64    pub user_agent: Option<String>,
65
66    /// Session metadata
67    pub metadata: HashMap<String, String>,
68
69    /// Whether session is active
70    pub active: bool,
71}
72
73impl Session {
74    /// Check if session is expired
75    pub fn is_expired(&self) -> bool {
76        let now = chrono::Utc::now();
77        now > self.expires_at || now > self.absolute_expires_at
78    }
79
80    /// Check if session is valid
81    pub fn is_valid(&self) -> bool {
82        self.active && !self.is_expired()
83    }
84
85    /// Get remaining time
86    pub fn remaining_time(&self) -> Option<Duration> {
87        let now = chrono::Utc::now();
88        let expires = self.expires_at.min(self.absolute_expires_at);
89
90        if expires > now {
91            (expires - now).to_std().ok()
92        } else {
93            None
94        }
95    }
96
97    /// Get session duration
98    pub fn duration(&self) -> Duration {
99        let now = chrono::Utc::now();
100        (now - self.created_at).to_std().unwrap_or(Duration::ZERO)
101    }
102}
103
104/// Session manager
105pub struct SessionManager {
106    /// Configuration
107    config: SessionConfig,
108
109    /// Active sessions by ID
110    sessions: Arc<RwLock<HashMap<String, Session>>>,
111
112    /// Session lookup by token
113    tokens: Arc<RwLock<HashMap<String, String>>>,
114
115    /// Sessions by user
116    user_sessions: Arc<RwLock<HashMap<String, Vec<String>>>>,
117
118    /// Last cleanup time
119    last_cleanup: Arc<RwLock<Instant>>,
120}
121
122impl SessionManager {
123    /// Create a new session manager
124    pub fn new(config: SessionConfig) -> Self {
125        Self {
126            config,
127            sessions: Arc::new(RwLock::new(HashMap::new())),
128            tokens: Arc::new(RwLock::new(HashMap::new())),
129            user_sessions: Arc::new(RwLock::new(HashMap::new())),
130            last_cleanup: Arc::new(RwLock::new(Instant::now())),
131        }
132    }
133
134    /// Create a builder
135    pub fn builder() -> SessionManagerBuilder {
136        SessionManagerBuilder::new()
137    }
138
139    /// Create a new session
140    pub fn create_session(
141        &self,
142        identity: Identity,
143        client_ip: Option<std::net::IpAddr>,
144        user_agent: Option<String>,
145    ) -> Result<Session, SessionError> {
146        // Check session limit
147        if self.config.max_sessions_per_user > 0 {
148            let user_sessions = self.user_sessions.read();
149            if let Some(sessions) = user_sessions.get(&identity.user_id) {
150                if sessions.len() >= self.config.max_sessions_per_user {
151                    return Err(SessionError::LimitExceeded);
152                }
153            }
154        }
155
156        // Generate session ID and token
157        let session_id = self.generate_session_id();
158        let token = self.generate_token();
159
160        let now = chrono::Utc::now();
161        let expires_at = now + chrono::Duration::from_std(self.config.idle_timeout)
162            .unwrap_or(chrono::Duration::hours(1));
163        let absolute_expires_at = now + chrono::Duration::from_std(self.config.absolute_timeout)
164            .unwrap_or(chrono::Duration::hours(24));
165
166        let session = Session {
167            id: session_id.clone(),
168            token: token.clone(),
169            identity: identity.clone(),
170            created_at: now,
171            last_activity: now,
172            expires_at,
173            absolute_expires_at,
174            client_ip,
175            user_agent,
176            metadata: HashMap::new(),
177            active: true,
178        };
179
180        // Store session
181        self.sessions.write().insert(session_id.clone(), session.clone());
182        self.tokens.write().insert(token.clone(), session_id.clone());
183        self.user_sessions.write()
184            .entry(identity.user_id.clone())
185            .or_insert_with(Vec::new)
186            .push(session_id);
187
188        // Cleanup old sessions periodically
189        self.maybe_cleanup();
190
191        Ok(session)
192    }
193
194    /// Get session by token
195    pub fn get_session(&self, token: &str) -> Result<Session, SessionError> {
196        let session_id = self.tokens.read()
197            .get(token)
198            .cloned()
199            .ok_or(SessionError::NotFound)?;
200
201        let session = self.sessions.read()
202            .get(&session_id)
203            .cloned()
204            .ok_or(SessionError::NotFound)?;
205
206        if !session.active {
207            return Err(SessionError::Invalidated);
208        }
209
210        if session.is_expired() {
211            return Err(SessionError::Expired);
212        }
213
214        Ok(session)
215    }
216
217    /// Get session by ID
218    pub fn get_session_by_id(&self, session_id: &str) -> Result<Session, SessionError> {
219        let session = self.sessions.read()
220            .get(session_id)
221            .cloned()
222            .ok_or(SessionError::NotFound)?;
223
224        if !session.active {
225            return Err(SessionError::Invalidated);
226        }
227
228        if session.is_expired() {
229            return Err(SessionError::Expired);
230        }
231
232        Ok(session)
233    }
234
235    /// Validate token and return identity
236    pub fn validate_token(&self, token: &str) -> Result<Identity, SessionError> {
237        let session = self.get_session(token)?;
238        Ok(session.identity)
239    }
240
241    /// Refresh session (extend expiration)
242    pub fn refresh_session(&self, token: &str) -> Result<Session, SessionError> {
243        let session_id = self.tokens.read()
244            .get(token)
245            .cloned()
246            .ok_or(SessionError::NotFound)?;
247
248        let mut sessions = self.sessions.write();
249        let session = sessions.get_mut(&session_id)
250            .ok_or(SessionError::NotFound)?;
251
252        if !session.active {
253            return Err(SessionError::Invalidated);
254        }
255
256        if session.is_expired() {
257            return Err(SessionError::Expired);
258        }
259
260        // Update activity time and expiration
261        let now = chrono::Utc::now();
262        session.last_activity = now;
263
264        // Extend idle timeout but not beyond absolute timeout
265        let new_expires = now + chrono::Duration::from_std(self.config.idle_timeout)
266            .unwrap_or(chrono::Duration::hours(1));
267        session.expires_at = new_expires.min(session.absolute_expires_at);
268
269        Ok(session.clone())
270    }
271
272    /// Invalidate a session
273    pub fn invalidate_session(&self, token: &str) -> Result<(), SessionError> {
274        let session_id = self.tokens.read()
275            .get(token)
276            .cloned()
277            .ok_or(SessionError::NotFound)?;
278
279        self.invalidate_session_by_id(&session_id)
280    }
281
282    /// Invalidate session by ID
283    pub fn invalidate_session_by_id(&self, session_id: &str) -> Result<(), SessionError> {
284        let mut sessions = self.sessions.write();
285        let session = sessions.get_mut(session_id)
286            .ok_or(SessionError::NotFound)?;
287
288        session.active = false;
289
290        // Remove from token lookup
291        self.tokens.write().remove(&session.token);
292
293        // Remove from user sessions
294        let user_id = session.identity.user_id.clone();
295        let mut user_sessions = self.user_sessions.write();
296        if let Some(sessions) = user_sessions.get_mut(&user_id) {
297            sessions.retain(|id| id != session_id);
298        }
299
300        Ok(())
301    }
302
303    /// Invalidate all sessions for a user
304    pub fn invalidate_user_sessions(&self, user_id: &str) {
305        let session_ids: Vec<String> = self.user_sessions.read()
306            .get(user_id)
307            .cloned()
308            .unwrap_or_default();
309
310        for session_id in session_ids {
311            let _ = self.invalidate_session_by_id(&session_id);
312        }
313    }
314
315    /// List all sessions for a user
316    pub fn list_user_sessions(&self, user_id: &str) -> Vec<Session> {
317        let session_ids: Vec<String> = self.user_sessions.read()
318            .get(user_id)
319            .cloned()
320            .unwrap_or_default();
321
322        let sessions = self.sessions.read();
323        session_ids.iter()
324            .filter_map(|id| sessions.get(id).cloned())
325            .filter(|s| s.is_valid())
326            .collect()
327    }
328
329    /// Update session metadata
330    pub fn update_metadata(
331        &self,
332        token: &str,
333        key: impl Into<String>,
334        value: impl Into<String>,
335    ) -> Result<(), SessionError> {
336        let session_id = self.tokens.read()
337            .get(token)
338            .cloned()
339            .ok_or(SessionError::NotFound)?;
340
341        let mut sessions = self.sessions.write();
342        let session = sessions.get_mut(&session_id)
343            .ok_or(SessionError::NotFound)?;
344
345        session.metadata.insert(key.into(), value.into());
346        Ok(())
347    }
348
349    /// Get session statistics
350    pub fn stats(&self) -> SessionStats {
351        let sessions = self.sessions.read();
352        let active = sessions.values().filter(|s| s.is_valid()).count();
353        let expired = sessions.values().filter(|s| s.is_expired()).count();
354        let invalidated = sessions.values().filter(|s| !s.active).count();
355
356        SessionStats {
357            total: sessions.len(),
358            active,
359            expired,
360            invalidated,
361        }
362    }
363
364    /// Cleanup expired sessions
365    pub fn cleanup(&self) {
366        let expired_ids: Vec<String> = {
367            let sessions = self.sessions.read();
368            sessions.iter()
369                .filter(|(_, s)| s.is_expired() || !s.active)
370                .map(|(id, _)| id.clone())
371                .collect()
372        };
373
374        for id in expired_ids {
375            let _ = self.invalidate_session_by_id(&id);
376            self.sessions.write().remove(&id);
377        }
378
379        *self.last_cleanup.write() = Instant::now();
380    }
381
382    /// Maybe run cleanup if enough time has passed
383    fn maybe_cleanup(&self) {
384        let should_cleanup = {
385            let last = self.last_cleanup.read();
386            last.elapsed() > Duration::from_secs(60)
387        };
388
389        if should_cleanup {
390            self.cleanup();
391        }
392    }
393
394    /// Generate a session ID
395    fn generate_session_id(&self) -> String {
396        use std::collections::hash_map::RandomState;
397        use std::hash::{BuildHasher, Hasher};
398
399        let mut hasher = RandomState::new().build_hasher();
400        hasher.write_u128(std::time::SystemTime::now()
401            .duration_since(std::time::UNIX_EPOCH)
402            .unwrap()
403            .as_nanos());
404        hasher.write_usize(std::process::id() as usize);
405
406        let hash1 = hasher.finish();
407        hasher.write_u64(hash1);
408        let hash2 = hasher.finish();
409
410        format!("sess_{:016x}{:016x}", hash1, hash2)
411    }
412
413    /// Generate a session token
414    fn generate_token(&self) -> String {
415        use std::collections::hash_map::RandomState;
416        use std::hash::{BuildHasher, Hasher};
417
418        let mut hasher = RandomState::new().build_hasher();
419        hasher.write_u128(std::time::SystemTime::now()
420            .duration_since(std::time::UNIX_EPOCH)
421            .unwrap()
422            .as_nanos());
423
424        let mut token_bytes = Vec::new();
425        for _ in 0..4 {
426            hasher.write_u64(hasher.finish());
427            token_bytes.extend_from_slice(&hasher.finish().to_le_bytes());
428        }
429
430        // Encode as URL-safe base64
431        use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
432        URL_SAFE_NO_PAD.encode(&token_bytes)
433    }
434}
435
/// Point-in-time counts of the sessions held by a session manager.
#[derive(Debug, Clone)]
pub struct SessionStats {
    /// Number of sessions currently stored, in any state.
    pub total: usize,

    /// Number of sessions that are active and not expired.
    pub active: usize,

    /// Number of sessions past one of their deadlines.
    pub expired: usize,

    /// Number of sessions that have been deactivated.
    pub invalidated: usize,
}
451
452/// Session manager builder
453pub struct SessionManagerBuilder {
454    config: SessionConfig,
455}
456
457impl SessionManagerBuilder {
458    /// Create a new builder
459    pub fn new() -> Self {
460        Self {
461            config: SessionConfig::default(),
462        }
463    }
464
465    /// Set idle timeout
466    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
467        self.config.idle_timeout = timeout;
468        self
469    }
470
471    /// Set absolute timeout
472    pub fn absolute_timeout(mut self, timeout: Duration) -> Self {
473        self.config.absolute_timeout = timeout;
474        self
475    }
476
477    /// Set max sessions per user
478    pub fn max_sessions_per_user(mut self, max: usize) -> Self {
479        self.config.max_sessions_per_user = max;
480        self
481    }
482
483    /// Enable secure cookies
484    pub fn secure_cookies(mut self, secure: bool) -> Self {
485        self.config.secure_cookies = secure;
486        self
487    }
488
489    /// Build the session manager
490    pub fn build(self) -> SessionManager {
491        SessionManager::new(self.config)
492    }
493}
494
495impl Default for SessionManagerBuilder {
496    fn default() -> Self {
497        Self::new()
498    }
499}
500
/// Attributes used when rendering the session cookie headers.
#[derive(Debug, Clone)]
pub struct CookieOptions {
    /// Cookie name.
    pub name: String,

    /// Value of the `Path` attribute.
    pub path: String,

    /// Value of the `Domain` attribute; omitted from the header when `None`.
    pub domain: Option<String>,

    /// Whether to emit the `Secure` attribute.
    pub secure: bool,

    /// Whether to emit the `HttpOnly` attribute.
    pub http_only: bool,

    /// Value of the `SameSite` attribute.
    pub same_site: SameSite,

    /// Value of the `Max-Age` attribute in whole seconds; omitted when `None`.
    pub max_age: Option<Duration>,
}

/// Possible values of the `SameSite` cookie attribute.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SameSite {
    Strict,
    Lax,
    None,
}

impl SameSite {
    /// Render the complete `SameSite=...` attribute.
    fn as_attribute(&self) -> &'static str {
        match self {
            SameSite::Strict => "SameSite=Strict",
            SameSite::Lax => "SameSite=Lax",
            SameSite::None => "SameSite=None",
        }
    }
}

impl Default for CookieOptions {
    /// A cookie named `session` on path `/`, with `Secure`, `HttpOnly` and
    /// `SameSite=Lax`, and no domain or max age.
    fn default() -> Self {
        Self {
            name: "session".to_string(),
            path: "/".to_string(),
            domain: None,
            secure: true,
            http_only: true,
            same_site: SameSite::Lax,
            max_age: None,
        }
    }
}

impl CookieOptions {
    /// Build the `Set-Cookie` header value that stores `token`.
    ///
    /// Attributes are emitted in the order `Path`, `Domain`, `Secure`,
    /// `HttpOnly`, `SameSite`, `Max-Age`. `Secure` is also emitted whenever
    /// `same_site` is [`SameSite::None`], because browsers reject
    /// `SameSite=None` cookies that are not `Secure`.
    ///
    /// `token` is inserted verbatim; callers must pass a cookie-safe value.
    pub fn to_set_cookie_header(&self, token: &str) -> String {
        let mut parts = vec![
            format!("{}={}", self.name, token),
            format!("Path={}", self.path),
        ];

        if let Some(domain) = &self.domain {
            parts.push(format!("Domain={}", domain));
        }

        self.push_security_attributes(&mut parts);

        if let Some(max_age) = self.max_age {
            parts.push(format!("Max-Age={}", max_age.as_secs()));
        }

        parts.join("; ")
    }

    /// Build the `Set-Cookie` header value that deletes the cookie.
    ///
    /// The header carries the same `Secure`, `HttpOnly` and `SameSite`
    /// attributes as the header that set the cookie. Without them a browser
    /// can refuse the deletion (for example for `__Secure-`/`__Host-`
    /// prefixed names, which are only accepted with `Secure`).
    pub fn to_delete_cookie_header(&self) -> String {
        let mut parts = vec![
            format!("{}=", self.name),
            format!("Path={}", self.path),
            "Max-Age=0".to_string(),
            "Expires=Thu, 01 Jan 1970 00:00:00 GMT".to_string(),
        ];

        if let Some(domain) = &self.domain {
            parts.push(format!("Domain={}", domain));
        }

        self.push_security_attributes(&mut parts);

        parts.join("; ")
    }

    /// Append `Secure`, `HttpOnly` and `SameSite` so that the set and delete
    /// headers always agree on them.
    fn push_security_attributes(&self, parts: &mut Vec<String>) {
        if self.secure || self.same_site == SameSite::None {
            parts.push("Secure".to_string());
        }

        if self.http_only {
            parts.push("HttpOnly".to_string());
        }

        parts.push(self.same_site.as_attribute().to_string());
    }
}
597
#[cfg(test)]
mod tests {
    use super::*;

    /// Identity shared by all tests; its user ID is always `user123`.
    fn test_identity() -> Identity {
        Identity {
            user_id: "user123".to_string(),
            name: Some("Test User".to_string()),
            email: Some("test@example.com".to_string()),
            roles: vec!["user".to_string()],
            groups: Vec::new(),
            tenant_id: None,
            claims: HashMap::new(),
            auth_method: "test".to_string(),
            authenticated_at: chrono::Utc::now(),
        }
    }

    #[test]
    fn test_create_session() {
        let manager = SessionManager::builder()
            .idle_timeout(Duration::from_secs(3600))
            .absolute_timeout(Duration::from_secs(86400))
            .build();

        let session = manager.create_session(
            test_identity(),
            None,
            Some("Test Agent".to_string()),
        ).unwrap();

        assert!(session.is_valid());
        assert!(session.active);
        assert!(!session.is_expired());
    }

    #[test]
    fn test_get_session() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();
        let token = session.token.clone();

        let retrieved = manager.get_session(&token).unwrap();
        assert_eq!(retrieved.id, session.id);
    }

    #[test]
    fn test_validate_token() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();
        let identity = manager.validate_token(&session.token).unwrap();

        assert_eq!(identity.user_id, "user123");
    }

    #[test]
    fn test_refresh_session() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();

        // Let the clock advance so the refreshed activity time is later.
        std::thread::sleep(Duration::from_millis(10));

        let refreshed = manager.refresh_session(&session.token).unwrap();
        assert!(refreshed.last_activity > session.last_activity);

        // A refresh slides the idle deadline but never past the absolute
        // one, and it leaves the absolute deadline untouched.
        assert!(refreshed.expires_at <= refreshed.absolute_expires_at);
        assert_eq!(refreshed.absolute_expires_at, session.absolute_expires_at);
    }

    #[test]
    fn test_invalidate_session() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();
        manager.invalidate_session(&session.token).unwrap();

        assert!(manager.get_session(&session.token).is_err());
    }

    #[test]
    fn test_session_limit() {
        let manager = SessionManager::builder()
            .max_sessions_per_user(2)
            .build();

        let _ = manager.create_session(test_identity(), None, None).unwrap();
        let _ = manager.create_session(test_identity(), None, None).unwrap();

        let result = manager.create_session(test_identity(), None, None);
        assert!(matches!(result, Err(SessionError::LimitExceeded)));
    }

    #[test]
    fn test_list_user_sessions() {
        let manager = SessionManager::new(SessionConfig::default());

        let _ = manager.create_session(test_identity(), None, None).unwrap();
        let _ = manager.create_session(test_identity(), None, None).unwrap();

        let sessions = manager.list_user_sessions("user123");
        assert_eq!(sessions.len(), 2);
    }

    #[test]
    fn test_invalidate_user_sessions() {
        let manager = SessionManager::new(SessionConfig::default());

        let s1 = manager.create_session(test_identity(), None, None).unwrap();
        let s2 = manager.create_session(test_identity(), None, None).unwrap();

        manager.invalidate_user_sessions("user123");

        assert!(manager.get_session(&s1.token).is_err());
        assert!(manager.get_session(&s2.token).is_err());
    }

    #[test]
    fn test_session_stats() {
        let manager = SessionManager::new(SessionConfig::default());

        let _ = manager.create_session(test_identity(), None, None).unwrap();
        let s2 = manager.create_session(test_identity(), None, None).unwrap();
        manager.invalidate_session(&s2.token).unwrap();

        // The invalidated session stays stored until a cleanup pass runs.
        let stats = manager.stats();
        assert_eq!(stats.total, 2);
        assert_eq!(stats.active, 1);
        assert_eq!(stats.invalidated, 1);
    }

    #[test]
    fn test_update_metadata() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();
        manager.update_metadata(&session.token, "key", "value").unwrap();

        let updated = manager.get_session(&session.token).unwrap();
        assert_eq!(updated.metadata.get("key"), Some(&"value".to_string()));
    }

    #[test]
    fn test_cookie_options() {
        let options = CookieOptions {
            name: "session".to_string(),
            path: "/".to_string(),
            domain: Some("example.com".to_string()),
            secure: true,
            http_only: true,
            same_site: SameSite::Strict,
            max_age: Some(Duration::from_secs(3600)),
        };

        let header = options.to_set_cookie_header("token123");

        assert!(header.contains("session=token123"));
        assert!(header.contains("Path=/"));
        assert!(header.contains("Domain=example.com"));
        assert!(header.contains("Secure"));
        assert!(header.contains("HttpOnly"));
        assert!(header.contains("SameSite=Strict"));
        assert!(header.contains("Max-Age=3600"));
    }

    #[test]
    fn test_delete_cookie() {
        let options = CookieOptions::default();
        let header = options.to_delete_cookie_header();

        assert!(header.contains("session="));
        assert!(header.contains("Max-Age=0"));
        assert!(header.contains("Expires=Thu, 01 Jan 1970"));
    }

    #[test]
    fn test_session_remaining_time() {
        let manager = SessionManager::builder()
            .idle_timeout(Duration::from_secs(3600))
            .build();

        let session = manager.create_session(test_identity(), None, None).unwrap();

        let remaining = session.remaining_time().unwrap();
        assert!(remaining > Duration::from_secs(3500)); // Should be close to 1 hour
        assert!(remaining <= Duration::from_secs(3600));
    }
}