Skip to main content

heliosdb_proxy/auth/
session.rs

1//! Session Management
2//!
3//! Manages authenticated sessions with token generation, validation, and lifecycle.
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use parking_lot::RwLock;
10use thiserror::Error;
11
12use super::config::{Identity, SessionConfig};
13
14/// Session errors
15#[derive(Debug, Error)]
16pub enum SessionError {
17    #[error("Session not found")]
18    NotFound,
19
20    #[error("Session expired")]
21    Expired,
22
23    #[error("Session invalidated")]
24    Invalidated,
25
26    #[error("Session limit exceeded")]
27    LimitExceeded,
28
29    #[error("Token generation failed")]
30    TokenGenerationFailed,
31
32    #[error("Invalid token format")]
33    InvalidTokenFormat,
34}
35
/// Session information
///
/// A single authenticated session as stored by [`SessionManager`]. The manager hands
/// out owned clones, so mutating a returned `Session` does not change the stored copy.
#[derive(Debug, Clone)]
pub struct Session {
    /// Session ID (`sess_` followed by 32 hex digits when generated by `SessionManager`)
    pub id: String,

    /// Opaque bearer token (URL-safe base64, no padding) used to look the session up
    pub token: String,

    /// Identity the session was created for
    pub identity: Identity,

    /// Creation time (UTC)
    pub created_at: chrono::DateTime<chrono::Utc>,

    /// Last activity time; set at creation and updated by `SessionManager::refresh_session`
    pub last_activity: chrono::DateTime<chrono::Utc>,

    /// Idle-timeout deadline; `refresh_session` pushes it forward but never past
    /// `absolute_expires_at`
    pub expires_at: chrono::DateTime<chrono::Utc>,

    /// Absolute expiration (max session lifetime); never extended
    pub absolute_expires_at: chrono::DateTime<chrono::Utc>,

    /// Client IP address supplied at creation, if any
    pub client_ip: Option<std::net::IpAddr>,

    /// User agent supplied at creation, if any
    pub user_agent: Option<String>,

    /// Free-form key/value metadata (see `SessionManager::update_metadata`)
    pub metadata: HashMap<String, String>,

    /// `false` once the session has been invalidated
    pub active: bool,
}
72
73impl Session {
74    /// Check if session is expired
75    pub fn is_expired(&self) -> bool {
76        let now = chrono::Utc::now();
77        now > self.expires_at || now > self.absolute_expires_at
78    }
79
80    /// Check if session is valid
81    pub fn is_valid(&self) -> bool {
82        self.active && !self.is_expired()
83    }
84
85    /// Get remaining time
86    pub fn remaining_time(&self) -> Option<Duration> {
87        let now = chrono::Utc::now();
88        let expires = self.expires_at.min(self.absolute_expires_at);
89
90        if expires > now {
91            (expires - now).to_std().ok()
92        } else {
93            None
94        }
95    }
96
97    /// Get session duration
98    pub fn duration(&self) -> Duration {
99        let now = chrono::Utc::now();
100        (now - self.created_at).to_std().unwrap_or(Duration::ZERO)
101    }
102}
103
/// Session manager
///
/// Keeps every session in memory, indexed by session ID, by token and by user ID.
/// Where code in this file holds more than one of these locks at a time, it acquires
/// them in the order `sessions` -> `tokens` -> `user_sessions`. Keep that order when
/// adding code to avoid lock-order deadlocks.
///
/// NOTE(review): `SessionManager` is not `Clone` in this file, so the `Arc` wrappers
/// are not shared with anything visible here.
pub struct SessionManager {
    /// Configuration (idle/absolute timeouts and the per-user session limit)
    config: SessionConfig,

    /// Session ID -> session. Invalidated and expired entries remain here until
    /// `cleanup` removes them.
    sessions: Arc<RwLock<HashMap<String, Session>>>,

    /// Token -> session ID. Entries are removed when a session is invalidated.
    tokens: Arc<RwLock<HashMap<String, String>>>,

    /// User ID -> IDs of that user's sessions.
    user_sessions: Arc<RwLock<HashMap<String, Vec<String>>>>,

    /// When `cleanup` last ran. `maybe_cleanup` only runs cleanup once more than
    /// 60 seconds have elapsed since this instant.
    last_cleanup: Arc<RwLock<Instant>>,
}
121
122impl SessionManager {
123    /// Create a new session manager
124    pub fn new(config: SessionConfig) -> Self {
125        Self {
126            config,
127            sessions: Arc::new(RwLock::new(HashMap::new())),
128            tokens: Arc::new(RwLock::new(HashMap::new())),
129            user_sessions: Arc::new(RwLock::new(HashMap::new())),
130            last_cleanup: Arc::new(RwLock::new(Instant::now())),
131        }
132    }
133
134    /// Create a builder
135    pub fn builder() -> SessionManagerBuilder {
136        SessionManagerBuilder::new()
137    }
138
139    /// Create a new session
140    pub fn create_session(
141        &self,
142        identity: Identity,
143        client_ip: Option<std::net::IpAddr>,
144        user_agent: Option<String>,
145    ) -> Result<Session, SessionError> {
146        // Check session limit
147        if self.config.max_sessions_per_user > 0 {
148            let user_sessions = self.user_sessions.read();
149            if let Some(sessions) = user_sessions.get(&identity.user_id) {
150                if sessions.len() >= self.config.max_sessions_per_user {
151                    return Err(SessionError::LimitExceeded);
152                }
153            }
154        }
155
156        // Generate session ID and token
157        let session_id = self.generate_session_id();
158        let token = self.generate_token();
159
160        let now = chrono::Utc::now();
161        let expires_at = now
162            + chrono::Duration::from_std(self.config.idle_timeout)
163                .unwrap_or(chrono::Duration::hours(1));
164        let absolute_expires_at = now
165            + chrono::Duration::from_std(self.config.absolute_timeout)
166                .unwrap_or(chrono::Duration::hours(24));
167
168        let session = Session {
169            id: session_id.clone(),
170            token: token.clone(),
171            identity: identity.clone(),
172            created_at: now,
173            last_activity: now,
174            expires_at,
175            absolute_expires_at,
176            client_ip,
177            user_agent,
178            metadata: HashMap::new(),
179            active: true,
180        };
181
182        // Store session
183        self.sessions
184            .write()
185            .insert(session_id.clone(), session.clone());
186        self.tokens
187            .write()
188            .insert(token.clone(), session_id.clone());
189        self.user_sessions
190            .write()
191            .entry(identity.user_id.clone())
192            .or_default()
193            .push(session_id);
194
195        // Cleanup old sessions periodically
196        self.maybe_cleanup();
197
198        Ok(session)
199    }
200
201    /// Get session by token
202    pub fn get_session(&self, token: &str) -> Result<Session, SessionError> {
203        let session_id = self
204            .tokens
205            .read()
206            .get(token)
207            .cloned()
208            .ok_or(SessionError::NotFound)?;
209
210        let session = self
211            .sessions
212            .read()
213            .get(&session_id)
214            .cloned()
215            .ok_or(SessionError::NotFound)?;
216
217        if !session.active {
218            return Err(SessionError::Invalidated);
219        }
220
221        if session.is_expired() {
222            return Err(SessionError::Expired);
223        }
224
225        Ok(session)
226    }
227
228    /// Get session by ID
229    pub fn get_session_by_id(&self, session_id: &str) -> Result<Session, SessionError> {
230        let session = self
231            .sessions
232            .read()
233            .get(session_id)
234            .cloned()
235            .ok_or(SessionError::NotFound)?;
236
237        if !session.active {
238            return Err(SessionError::Invalidated);
239        }
240
241        if session.is_expired() {
242            return Err(SessionError::Expired);
243        }
244
245        Ok(session)
246    }
247
248    /// Validate token and return identity
249    pub fn validate_token(&self, token: &str) -> Result<Identity, SessionError> {
250        let session = self.get_session(token)?;
251        Ok(session.identity)
252    }
253
254    /// Refresh session (extend expiration)
255    pub fn refresh_session(&self, token: &str) -> Result<Session, SessionError> {
256        let session_id = self
257            .tokens
258            .read()
259            .get(token)
260            .cloned()
261            .ok_or(SessionError::NotFound)?;
262
263        let mut sessions = self.sessions.write();
264        let session = sessions
265            .get_mut(&session_id)
266            .ok_or(SessionError::NotFound)?;
267
268        if !session.active {
269            return Err(SessionError::Invalidated);
270        }
271
272        if session.is_expired() {
273            return Err(SessionError::Expired);
274        }
275
276        // Update activity time and expiration
277        let now = chrono::Utc::now();
278        session.last_activity = now;
279
280        // Extend idle timeout but not beyond absolute timeout
281        let new_expires = now
282            + chrono::Duration::from_std(self.config.idle_timeout)
283                .unwrap_or(chrono::Duration::hours(1));
284        session.expires_at = new_expires.min(session.absolute_expires_at);
285
286        Ok(session.clone())
287    }
288
289    /// Invalidate a session
290    pub fn invalidate_session(&self, token: &str) -> Result<(), SessionError> {
291        let session_id = self
292            .tokens
293            .read()
294            .get(token)
295            .cloned()
296            .ok_or(SessionError::NotFound)?;
297
298        self.invalidate_session_by_id(&session_id)
299    }
300
301    /// Invalidate session by ID
302    pub fn invalidate_session_by_id(&self, session_id: &str) -> Result<(), SessionError> {
303        let mut sessions = self.sessions.write();
304        let session = sessions.get_mut(session_id).ok_or(SessionError::NotFound)?;
305
306        session.active = false;
307
308        // Remove from token lookup
309        self.tokens.write().remove(&session.token);
310
311        // Remove from user sessions
312        let user_id = session.identity.user_id.clone();
313        let mut user_sessions = self.user_sessions.write();
314        if let Some(sessions) = user_sessions.get_mut(&user_id) {
315            sessions.retain(|id| id != session_id);
316        }
317
318        Ok(())
319    }
320
321    /// Invalidate all sessions for a user
322    pub fn invalidate_user_sessions(&self, user_id: &str) {
323        let session_ids: Vec<String> = self
324            .user_sessions
325            .read()
326            .get(user_id)
327            .cloned()
328            .unwrap_or_default();
329
330        for session_id in session_ids {
331            let _ = self.invalidate_session_by_id(&session_id);
332        }
333    }
334
335    /// List all sessions for a user
336    pub fn list_user_sessions(&self, user_id: &str) -> Vec<Session> {
337        let session_ids: Vec<String> = self
338            .user_sessions
339            .read()
340            .get(user_id)
341            .cloned()
342            .unwrap_or_default();
343
344        let sessions = self.sessions.read();
345        session_ids
346            .iter()
347            .filter_map(|id| sessions.get(id).cloned())
348            .filter(|s| s.is_valid())
349            .collect()
350    }
351
352    /// Update session metadata
353    pub fn update_metadata(
354        &self,
355        token: &str,
356        key: impl Into<String>,
357        value: impl Into<String>,
358    ) -> Result<(), SessionError> {
359        let session_id = self
360            .tokens
361            .read()
362            .get(token)
363            .cloned()
364            .ok_or(SessionError::NotFound)?;
365
366        let mut sessions = self.sessions.write();
367        let session = sessions
368            .get_mut(&session_id)
369            .ok_or(SessionError::NotFound)?;
370
371        session.metadata.insert(key.into(), value.into());
372        Ok(())
373    }
374
375    /// Get session statistics
376    pub fn stats(&self) -> SessionStats {
377        let sessions = self.sessions.read();
378        let active = sessions.values().filter(|s| s.is_valid()).count();
379        let expired = sessions.values().filter(|s| s.is_expired()).count();
380        let invalidated = sessions.values().filter(|s| !s.active).count();
381
382        SessionStats {
383            total: sessions.len(),
384            active,
385            expired,
386            invalidated,
387        }
388    }
389
390    /// Cleanup expired sessions
391    pub fn cleanup(&self) {
392        let expired_ids: Vec<String> = {
393            let sessions = self.sessions.read();
394            sessions
395                .iter()
396                .filter(|(_, s)| s.is_expired() || !s.active)
397                .map(|(id, _)| id.clone())
398                .collect()
399        };
400
401        for id in expired_ids {
402            let _ = self.invalidate_session_by_id(&id);
403            self.sessions.write().remove(&id);
404        }
405
406        *self.last_cleanup.write() = Instant::now();
407    }
408
409    /// Maybe run cleanup if enough time has passed
410    fn maybe_cleanup(&self) {
411        let should_cleanup = {
412            let last = self.last_cleanup.read();
413            last.elapsed() > Duration::from_secs(60)
414        };
415
416        if should_cleanup {
417            self.cleanup();
418        }
419    }
420
421    /// Generate a session ID
422    fn generate_session_id(&self) -> String {
423        use std::collections::hash_map::RandomState;
424        use std::hash::{BuildHasher, Hasher};
425
426        let mut hasher = RandomState::new().build_hasher();
427        hasher.write_u128(
428            std::time::SystemTime::now()
429                .duration_since(std::time::UNIX_EPOCH)
430                .unwrap()
431                .as_nanos(),
432        );
433        hasher.write_usize(std::process::id() as usize);
434
435        let hash1 = hasher.finish();
436        hasher.write_u64(hash1);
437        let hash2 = hasher.finish();
438
439        format!("sess_{:016x}{:016x}", hash1, hash2)
440    }
441
442    /// Generate a session token
443    fn generate_token(&self) -> String {
444        use std::collections::hash_map::RandomState;
445        use std::hash::{BuildHasher, Hasher};
446
447        let mut hasher = RandomState::new().build_hasher();
448        hasher.write_u128(
449            std::time::SystemTime::now()
450                .duration_since(std::time::UNIX_EPOCH)
451                .unwrap()
452                .as_nanos(),
453        );
454
455        let mut token_bytes = Vec::new();
456        for _ in 0..4 {
457            hasher.write_u64(hasher.finish());
458            token_bytes.extend_from_slice(&hasher.finish().to_le_bytes());
459        }
460
461        // Encode as URL-safe base64
462        use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
463        URL_SAFE_NO_PAD.encode(&token_bytes)
464    }
465}
466
/// Session statistics
///
/// Point-in-time counts produced by `SessionManager::stats`. The buckets overlap: a
/// session that is both invalidated and past its deadline is counted in `expired`
/// *and* in `invalidated`. As a result, `active + expired + invalidated` can exceed `total`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SessionStats {
    /// Sessions currently held in memory, including ones awaiting cleanup
    pub total: usize,

    /// Sessions that are active and not expired
    pub active: usize,

    /// Sessions past either deadline, whether or not they are still active
    pub expired: usize,

    /// Sessions that have been invalidated (`active == false`)
    pub invalidated: usize,
}
482
483/// Session manager builder
484pub struct SessionManagerBuilder {
485    config: SessionConfig,
486}
487
488impl SessionManagerBuilder {
489    /// Create a new builder
490    pub fn new() -> Self {
491        Self {
492            config: SessionConfig::default(),
493        }
494    }
495
496    /// Set idle timeout
497    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
498        self.config.idle_timeout = timeout;
499        self
500    }
501
502    /// Set absolute timeout
503    pub fn absolute_timeout(mut self, timeout: Duration) -> Self {
504        self.config.absolute_timeout = timeout;
505        self
506    }
507
508    /// Set max sessions per user
509    pub fn max_sessions_per_user(mut self, max: usize) -> Self {
510        self.config.max_sessions_per_user = max;
511        self
512    }
513
514    /// Enable secure cookies
515    pub fn secure_cookies(mut self, secure: bool) -> Self {
516        self.config.secure_cookies = secure;
517        self
518    }
519
520    /// Build the session manager
521    pub fn build(self) -> SessionManager {
522        SessionManager::new(self.config)
523    }
524}
525
526impl Default for SessionManagerBuilder {
527    fn default() -> Self {
528        Self::new()
529    }
530}
531
/// Session cookie options
///
/// Controls the attributes written by [`CookieOptions::to_set_cookie_header`] and
/// [`CookieOptions::to_delete_cookie_header`].
#[derive(Debug, Clone)]
pub struct CookieOptions {
    /// Cookie name
    pub name: String,

    /// `Path` attribute
    pub path: String,

    /// `Domain` attribute; omitted when `None`
    pub domain: Option<String>,

    /// Emit the `Secure` attribute
    pub secure: bool,

    /// Emit the `HttpOnly` attribute
    pub http_only: bool,

    /// `SameSite` attribute (always emitted)
    pub same_site: SameSite,

    /// `Max-Age` in whole seconds; omitted when `None`
    pub max_age: Option<Duration>,
}

/// SameSite attribute
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SameSite {
    /// `SameSite=Strict`
    Strict,
    /// `SameSite=Lax`
    Lax,
    /// `SameSite=None` (current browsers reject it unless `Secure` is also set)
    None,
}

impl SameSite {
    /// The attribute value exactly as written in a `Set-Cookie` header.
    pub fn as_str(&self) -> &'static str {
        match self {
            SameSite::Strict => "Strict",
            SameSite::Lax => "Lax",
            SameSite::None => "None",
        }
    }
}

impl Default for CookieOptions {
    fn default() -> Self {
        Self {
            name: "session".to_string(),
            path: "/".to_string(),
            domain: None,
            secure: true,
            http_only: true,
            same_site: SameSite::Lax,
            max_age: None,
        }
    }
}

impl CookieOptions {
    /// Build the `Set-Cookie` header value carrying `token`.
    ///
    /// Attribute order: name=value, `Path`, `Domain`, `Secure`, `HttpOnly`, `SameSite`,
    /// `Max-Age`. `token` is inserted verbatim; it must not contain `;`, whitespace or
    /// control characters (tokens from `SessionManager` are URL-safe base64).
    pub fn to_set_cookie_header(&self, token: &str) -> String {
        let mut parts = vec![
            format!("{}={}", self.name, token),
            format!("Path={}", self.path),
        ];

        if let Some(domain) = &self.domain {
            parts.push(format!("Domain={}", domain));
        }

        self.push_security_attributes(&mut parts);

        if let Some(max_age) = self.max_age {
            parts.push(format!("Max-Age={}", max_age.as_secs()));
        }

        parts.join("; ")
    }

    /// Build a `Set-Cookie` header value that deletes the cookie.
    ///
    /// The same `Path`/`Domain` and security attributes as the set header are repeated.
    /// Browsers ignore a write to a `__Secure-`/`__Host-` prefixed cookie that lacks
    /// `Secure`, so a bare deletion header could fail to clear such a cookie on logout.
    pub fn to_delete_cookie_header(&self) -> String {
        let mut parts = vec![format!("{}=", self.name), format!("Path={}", self.path)];

        if let Some(domain) = &self.domain {
            parts.push(format!("Domain={}", domain));
        }

        parts.push("Max-Age=0".to_string());
        parts.push("Expires=Thu, 01 Jan 1970 00:00:00 GMT".to_string());
        self.push_security_attributes(&mut parts);

        parts.join("; ")
    }

    /// Append `Secure` / `HttpOnly` (when enabled) and `SameSite`, in that order.
    fn push_security_attributes(&self, parts: &mut Vec<String>) {
        if self.secure {
            parts.push("Secure".to_string());
        }
        if self.http_only {
            parts.push("HttpOnly".to_string());
        }
        parts.push(format!("SameSite={}", self.same_site.as_str()));
    }
}
628
#[cfg(test)]
mod tests {
    use super::*;

    /// Identity fixture shared by the tests (user ID `user123`).
    fn test_identity() -> Identity {
        Identity {
            user_id: "user123".to_string(),
            name: Some("Test User".to_string()),
            email: Some("test@example.com".to_string()),
            roles: vec!["user".to_string()],
            groups: Vec::new(),
            tenant_id: None,
            claims: HashMap::new(),
            auth_method: "test".to_string(),
            authenticated_at: chrono::Utc::now(),
        }
    }

    #[test]
    fn test_create_session() {
        let manager = SessionManager::builder()
            .idle_timeout(Duration::from_secs(3600))
            .absolute_timeout(Duration::from_secs(86400))
            .build();

        let session = manager
            .create_session(test_identity(), None, Some("Test Agent".to_string()))
            .unwrap();

        assert!(session.is_valid());
        assert!(session.active);
        assert!(!session.is_expired());
    }

    #[test]
    fn test_get_session() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();
        let token = session.token.clone();

        let retrieved = manager.get_session(&token).unwrap();
        assert_eq!(retrieved.id, session.id);
    }

    #[test]
    fn test_validate_token() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();
        let identity = manager.validate_token(&session.token).unwrap();

        assert_eq!(identity.user_id, "user123");
    }

    #[test]
    fn test_refresh_session() {
        // Explicit timeouts so the idle deadline (1h) is well inside the absolute cap (24h)
        // regardless of `SessionConfig::default()`.
        let manager = SessionManager::builder()
            .idle_timeout(Duration::from_secs(3600))
            .absolute_timeout(Duration::from_secs(86400))
            .build();

        let session = manager.create_session(test_identity(), None, None).unwrap();
        let original_expires = session.expires_at;

        std::thread::sleep(Duration::from_millis(10));

        let refreshed = manager.refresh_session(&session.token).unwrap();
        assert!(refreshed.last_activity > session.last_activity);
        // The idle deadline must actually move forward.
        assert!(refreshed.expires_at > original_expires);
    }

    #[test]
    fn test_invalidate_session() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();
        manager.invalidate_session(&session.token).unwrap();

        assert!(manager.get_session(&session.token).is_err());
    }

    #[test]
    fn test_session_limit() {
        let manager = SessionManager::builder().max_sessions_per_user(2).build();

        let _ = manager.create_session(test_identity(), None, None).unwrap();
        let _ = manager.create_session(test_identity(), None, None).unwrap();

        let result = manager.create_session(test_identity(), None, None);
        assert!(matches!(result, Err(SessionError::LimitExceeded)));
    }

    #[test]
    fn test_list_user_sessions() {
        let manager = SessionManager::new(SessionConfig::default());

        let _ = manager.create_session(test_identity(), None, None).unwrap();
        let _ = manager.create_session(test_identity(), None, None).unwrap();

        let sessions = manager.list_user_sessions("user123");
        assert_eq!(sessions.len(), 2);
    }

    #[test]
    fn test_invalidate_user_sessions() {
        let manager = SessionManager::new(SessionConfig::default());

        let s1 = manager.create_session(test_identity(), None, None).unwrap();
        let s2 = manager.create_session(test_identity(), None, None).unwrap();

        manager.invalidate_user_sessions("user123");

        assert!(manager.get_session(&s1.token).is_err());
        assert!(manager.get_session(&s2.token).is_err());
    }

    #[test]
    fn test_session_stats() {
        let manager = SessionManager::new(SessionConfig::default());

        let _ = manager.create_session(test_identity(), None, None).unwrap();
        let s2 = manager.create_session(test_identity(), None, None).unwrap();
        manager.invalidate_session(&s2.token).unwrap();

        let stats = manager.stats();
        assert_eq!(stats.total, 2);
        assert_eq!(stats.active, 1);
    }

    #[test]
    fn test_update_metadata() {
        let manager = SessionManager::new(SessionConfig::default());

        let session = manager.create_session(test_identity(), None, None).unwrap();
        manager
            .update_metadata(&session.token, "key", "value")
            .unwrap();

        let updated = manager.get_session(&session.token).unwrap();
        assert_eq!(updated.metadata.get("key"), Some(&"value".to_string()));
    }

    #[test]
    fn test_cookie_options() {
        let options = CookieOptions {
            name: "session".to_string(),
            path: "/".to_string(),
            domain: Some("example.com".to_string()),
            secure: true,
            http_only: true,
            same_site: SameSite::Strict,
            max_age: Some(Duration::from_secs(3600)),
        };

        let header = options.to_set_cookie_header("token123");

        assert!(header.contains("session=token123"));
        assert!(header.contains("Path=/"));
        assert!(header.contains("Domain=example.com"));
        assert!(header.contains("Secure"));
        assert!(header.contains("HttpOnly"));
        assert!(header.contains("SameSite=Strict"));
        assert!(header.contains("Max-Age=3600"));
    }

    #[test]
    fn test_delete_cookie() {
        let options = CookieOptions::default();
        let header = options.to_delete_cookie_header();

        assert!(header.contains("session="));
        assert!(header.contains("Max-Age=0"));
        assert!(header.contains("Expires=Thu, 01 Jan 1970"));
    }

    #[test]
    fn test_session_remaining_time() {
        let manager = SessionManager::builder()
            .idle_timeout(Duration::from_secs(3600))
            .build();

        let session = manager.create_session(test_identity(), None, None).unwrap();

        let remaining = session.remaining_time().unwrap();
        assert!(remaining > Duration::from_secs(3500)); // Should be close to 1 hour
        assert!(remaining <= Duration::from_secs(3600));
    }
}