pulseengine_mcp_auth/session/session_manager.rs

1//! Session Management System for MCP Authentication
2//!
3//! This module provides comprehensive session management including JWT tokens,
4//! session storage, lifecycle management, and security features.
5
6use crate::{
7    AuthContext,
8    jwt::{JwtConfig, JwtError, JwtManager},
9};
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::sync::Arc;
13use thiserror::Error;
14use tokio::sync::RwLock;
15use tracing::{debug, error, info};
16use uuid::Uuid;
17
18/// Errors that can occur during session management
19#[derive(Debug, Error)]
20pub enum SessionError {
21    #[error("Session not found: {session_id}")]
22    SessionNotFound { session_id: String },
23
24    #[error("Session expired: {session_id}")]
25    SessionExpired { session_id: String },
26
27    #[error("Session invalid: {reason}")]
28    SessionInvalid { reason: String },
29
30    #[error("Maximum sessions exceeded for user: {user_id}")]
31    MaxSessionsExceeded { user_id: String },
32
33    #[error("Session creation failed: {reason}")]
34    CreationFailed { reason: String },
35
36    #[error("JWT error: {0}")]
37    JwtError(#[from] JwtError),
38
39    #[error("Storage error: {0}")]
40    StorageError(String),
41
42    #[error("Invalid session token")]
43    InvalidToken,
44}
45
46/// Session information
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct Session {
49    /// Unique session identifier
50    pub session_id: String,
51
52    /// User/API key identifier
53    pub user_id: String,
54
55    /// Authentication context
56    pub auth_context: AuthContext,
57
58    /// Session creation timestamp
59    pub created_at: chrono::DateTime<chrono::Utc>,
60
61    /// Session last accessed timestamp
62    pub last_accessed: chrono::DateTime<chrono::Utc>,
63
64    /// Session expiration timestamp
65    pub expires_at: chrono::DateTime<chrono::Utc>,
66
67    /// Client IP address
68    pub client_ip: Option<String>,
69
70    /// User agent string
71    pub user_agent: Option<String>,
72
73    /// Session metadata
74    pub metadata: HashMap<String, String>,
75
76    /// Whether session is active
77    pub is_active: bool,
78
79    /// JWT refresh token (if applicable)
80    pub refresh_token: Option<String>,
81}
82
83impl Session {
84    /// Create a new session
85    pub fn new(user_id: String, auth_context: AuthContext, duration: chrono::Duration) -> Self {
86        let now = chrono::Utc::now();
87        let session_id = Uuid::new_v4().to_string();
88
89        Self {
90            session_id,
91            user_id,
92            auth_context,
93            created_at: now,
94            last_accessed: now,
95            expires_at: now + duration,
96            client_ip: None,
97            user_agent: None,
98            metadata: HashMap::new(),
99            is_active: true,
100            refresh_token: None,
101        }
102    }
103
104    /// Check if session is expired
105    pub fn is_expired(&self) -> bool {
106        chrono::Utc::now() > self.expires_at
107    }
108
109    /// Update last accessed timestamp
110    pub fn touch(&mut self) {
111        self.last_accessed = chrono::Utc::now();
112    }
113
114    /// Add metadata to session
115    pub fn with_metadata(mut self, key: String, value: String) -> Self {
116        self.metadata.insert(key, value);
117        self
118    }
119
120    /// Add client information
121    pub fn with_client_info(
122        mut self,
123        client_ip: Option<String>,
124        user_agent: Option<String>,
125    ) -> Self {
126        self.client_ip = client_ip;
127        self.user_agent = user_agent;
128        self
129    }
130}
131
132/// Session storage trait for different backends
133#[async_trait::async_trait]
134pub trait SessionStorage: Send + Sync {
135    /// Store a session
136    async fn store_session(&self, session: &Session) -> Result<(), SessionError>;
137
138    /// Retrieve a session by ID
139    async fn get_session(&self, session_id: &str) -> Result<Option<Session>, SessionError>;
140
141    /// Update an existing session
142    async fn update_session(&self, session: &Session) -> Result<(), SessionError>;
143
144    /// Delete a session
145    async fn delete_session(&self, session_id: &str) -> Result<(), SessionError>;
146
147    /// Get all sessions for a user
148    async fn get_user_sessions(&self, user_id: &str) -> Result<Vec<Session>, SessionError>;
149
150    /// Clean up expired sessions
151    async fn cleanup_expired(&self) -> Result<u64, SessionError>;
152
153    /// Get session count for a user
154    async fn get_session_count(&self, user_id: &str) -> Result<usize, SessionError>;
155}
156
157/// In-memory session storage implementation
158pub struct MemorySessionStorage {
159    sessions: Arc<RwLock<HashMap<String, Session>>>,
160    user_sessions: Arc<RwLock<HashMap<String, Vec<String>>>>,
161}
162
163impl MemorySessionStorage {
164    pub fn new() -> Self {
165        Self {
166            sessions: Arc::new(RwLock::new(HashMap::new())),
167            user_sessions: Arc::new(RwLock::new(HashMap::new())),
168        }
169    }
170}
171
172impl Default for MemorySessionStorage {
173    fn default() -> Self {
174        Self::new()
175    }
176}
177
178#[async_trait::async_trait]
179impl SessionStorage for MemorySessionStorage {
180    async fn store_session(&self, session: &Session) -> Result<(), SessionError> {
181        let mut sessions = self.sessions.write().await;
182        let mut user_sessions = self.user_sessions.write().await;
183
184        sessions.insert(session.session_id.clone(), session.clone());
185
186        user_sessions
187            .entry(session.user_id.clone())
188            .or_insert_with(Vec::new)
189            .push(session.session_id.clone());
190
191        debug!(
192            "Stored session {} for user {}",
193            session.session_id, session.user_id
194        );
195        Ok(())
196    }
197
198    async fn get_session(&self, session_id: &str) -> Result<Option<Session>, SessionError> {
199        let sessions = self.sessions.read().await;
200        Ok(sessions.get(session_id).cloned())
201    }
202
203    async fn update_session(&self, session: &Session) -> Result<(), SessionError> {
204        let mut sessions = self.sessions.write().await;
205        if sessions.contains_key(&session.session_id) {
206            sessions.insert(session.session_id.clone(), session.clone());
207            debug!("Updated session {}", session.session_id);
208            Ok(())
209        } else {
210            Err(SessionError::SessionNotFound {
211                session_id: session.session_id.clone(),
212            })
213        }
214    }
215
216    async fn delete_session(&self, session_id: &str) -> Result<(), SessionError> {
217        let mut sessions = self.sessions.write().await;
218        let mut user_sessions = self.user_sessions.write().await;
219
220        if let Some(session) = sessions.remove(session_id) {
221            if let Some(user_session_list) = user_sessions.get_mut(&session.user_id) {
222                user_session_list.retain(|id| id != session_id);
223                if user_session_list.is_empty() {
224                    user_sessions.remove(&session.user_id);
225                }
226            }
227            debug!("Deleted session {}", session_id);
228            Ok(())
229        } else {
230            Err(SessionError::SessionNotFound {
231                session_id: session_id.to_string(),
232            })
233        }
234    }
235
236    async fn get_user_sessions(&self, user_id: &str) -> Result<Vec<Session>, SessionError> {
237        let sessions = self.sessions.read().await;
238        let user_sessions = self.user_sessions.read().await;
239
240        let mut result = Vec::new();
241        if let Some(session_ids) = user_sessions.get(user_id) {
242            for session_id in session_ids {
243                if let Some(session) = sessions.get(session_id) {
244                    result.push(session.clone());
245                }
246            }
247        }
248
249        Ok(result)
250    }
251
252    async fn cleanup_expired(&self) -> Result<u64, SessionError> {
253        let mut sessions = self.sessions.write().await;
254        let mut user_sessions = self.user_sessions.write().await;
255        let mut removed_count = 0u64;
256
257        let now = chrono::Utc::now();
258        let expired_sessions: Vec<String> = sessions
259            .iter()
260            .filter(|(_, session)| session.expires_at < now)
261            .map(|(id, _)| id.clone())
262            .collect();
263
264        for session_id in expired_sessions {
265            if let Some(session) = sessions.remove(&session_id) {
266                if let Some(user_session_list) = user_sessions.get_mut(&session.user_id) {
267                    user_session_list.retain(|id| id != &session_id);
268                    if user_session_list.is_empty() {
269                        user_sessions.remove(&session.user_id);
270                    }
271                }
272                removed_count += 1;
273            }
274        }
275
276        if removed_count > 0 {
277            info!("Cleaned up {} expired sessions", removed_count);
278        }
279
280        Ok(removed_count)
281    }
282
283    async fn get_session_count(&self, user_id: &str) -> Result<usize, SessionError> {
284        let user_sessions = self.user_sessions.read().await;
285        Ok(user_sessions.get(user_id).map(|v| v.len()).unwrap_or(0))
286    }
287}
288
289/// Configuration for session management
290#[derive(Debug, Clone)]
291pub struct SessionConfig {
292    /// Default session duration
293    pub default_duration: chrono::Duration,
294
295    /// Maximum session duration
296    pub max_duration: chrono::Duration,
297
298    /// Maximum sessions per user
299    pub max_sessions_per_user: usize,
300
301    /// Enable JWT tokens for sessions
302    pub enable_jwt: bool,
303
304    /// JWT configuration
305    pub jwt_config: JwtConfig,
306
307    /// Enable session refresh
308    pub enable_refresh: bool,
309
310    /// Refresh token duration
311    pub refresh_duration: chrono::Duration,
312
313    /// Cleanup interval for expired sessions
314    pub cleanup_interval: chrono::Duration,
315
316    /// Enable session extension on access
317    pub extend_on_access: bool,
318
319    /// Session extension duration
320    pub extension_duration: chrono::Duration,
321}
322
323impl Default for SessionConfig {
324    fn default() -> Self {
325        Self {
326            default_duration: chrono::Duration::hours(24),
327            max_duration: chrono::Duration::days(7),
328            max_sessions_per_user: 10,
329            enable_jwt: true,
330            jwt_config: JwtConfig::default(),
331            enable_refresh: true,
332            refresh_duration: chrono::Duration::days(30),
333            cleanup_interval: chrono::Duration::hours(1),
334            extend_on_access: true,
335            extension_duration: chrono::Duration::hours(1),
336        }
337    }
338}
339
340/// Session manager for handling session lifecycle
341pub struct SessionManager {
342    config: SessionConfig,
343    storage: Arc<dyn SessionStorage>,
344    jwt_manager: Option<Arc<JwtManager>>,
345}
346
347impl SessionManager {
348    /// Create a new session manager
349    pub fn new(config: SessionConfig, storage: Arc<dyn SessionStorage>) -> Self {
350        let jwt_manager = if config.enable_jwt {
351            match JwtManager::new(config.jwt_config.clone()) {
352                Ok(manager) => Some(Arc::new(manager)),
353                Err(e) => {
354                    error!("Failed to create JWT manager: {}", e);
355                    None
356                }
357            }
358        } else {
359            None
360        };
361
362        Self {
363            config,
364            storage,
365            jwt_manager,
366        }
367    }
368
369    /// Create with default configuration and memory storage
370    pub fn with_default_config() -> Self {
371        Self::new(
372            SessionConfig::default(),
373            Arc::new(MemorySessionStorage::new()),
374        )
375    }
376
377    /// Create a new session for a user
378    pub async fn create_session(
379        &self,
380        user_id: String,
381        auth_context: AuthContext,
382        duration: Option<chrono::Duration>,
383        client_ip: Option<String>,
384        user_agent: Option<String>,
385    ) -> Result<(Session, Option<String>), SessionError> {
386        // Check session limits
387        let session_count = self.storage.get_session_count(&user_id).await?;
388        if session_count >= self.config.max_sessions_per_user {
389            return Err(SessionError::MaxSessionsExceeded { user_id });
390        }
391
392        // Use provided duration or default
393        let session_duration = duration.unwrap_or(self.config.default_duration);
394
395        // Ensure duration doesn't exceed maximum
396        let final_duration = std::cmp::min(session_duration, self.config.max_duration);
397
398        // Create session
399        let mut session = Session::new(user_id.clone(), auth_context, final_duration)
400            .with_client_info(client_ip, user_agent);
401
402        // Generate JWT token if enabled
403        let jwt_token = if let Some(jwt_manager) = &self.jwt_manager {
404            let token = jwt_manager
405                .generate_access_token(
406                    session
407                        .auth_context
408                        .user_id
409                        .clone()
410                        .unwrap_or_else(|| user_id.clone()),
411                    session.auth_context.roles.clone(),
412                    session.auth_context.api_key_id.clone(),
413                    session.client_ip.clone(),
414                    Some(session.session_id.clone()),
415                    vec!["api".to_string()],
416                )
417                .await?;
418            Some(token)
419        } else {
420            None
421        };
422
423        // Generate refresh token if enabled
424        if self.config.enable_refresh {
425            session.refresh_token = Some(Uuid::new_v4().to_string());
426        }
427
428        // Store session
429        self.storage.store_session(&session).await?;
430
431        info!(
432            "Created session {} for user {} (duration: {} hours)",
433            session.session_id,
434            user_id,
435            final_duration.num_hours()
436        );
437
438        Ok((session, jwt_token))
439    }
440
441    /// Get a session by ID
442    pub async fn get_session(&self, session_id: &str) -> Result<Session, SessionError> {
443        let session = self.storage.get_session(session_id).await?.ok_or_else(|| {
444            SessionError::SessionNotFound {
445                session_id: session_id.to_string(),
446            }
447        })?;
448
449        if session.is_expired() {
450            // Clean up expired session
451            let _ = self.storage.delete_session(session_id).await;
452            return Err(SessionError::SessionExpired {
453                session_id: session_id.to_string(),
454            });
455        }
456
457        if !session.is_active {
458            return Err(SessionError::SessionInvalid {
459                reason: "Session is inactive".to_string(),
460            });
461        }
462
463        Ok(session)
464    }
465
466    /// Validate and refresh a session
467    pub async fn validate_session(&self, session_id: &str) -> Result<Session, SessionError> {
468        let mut session = self.get_session(session_id).await?;
469
470        // Update last accessed time
471        session.touch();
472
473        // Extend session if configured
474        if self.config.extend_on_access {
475            let new_expiry = chrono::Utc::now() + self.config.extension_duration;
476            if new_expiry < session.expires_at + self.config.max_duration {
477                session.expires_at = new_expiry;
478            }
479        }
480
481        // Update session in storage
482        self.storage.update_session(&session).await?;
483
484        debug!("Validated and updated session {}", session_id);
485        Ok(session)
486    }
487
488    /// Validate a JWT token and return session
489    pub async fn validate_jwt_token(&self, token: &str) -> Result<AuthContext, SessionError> {
490        let jwt_manager =
491            self.jwt_manager
492                .as_ref()
493                .ok_or_else(|| SessionError::SessionInvalid {
494                    reason: "JWT not enabled".to_string(),
495                })?;
496
497        let auth_context = jwt_manager.token_to_auth_context(token).await?;
498        Ok(auth_context)
499    }
500
501    /// Refresh a session using refresh token
502    pub async fn refresh_session(
503        &self,
504        session_id: &str,
505        refresh_token: &str,
506    ) -> Result<(Session, Option<String>), SessionError> {
507        let session = self.get_session(session_id).await?;
508
509        // Validate refresh token
510        if !self.config.enable_refresh {
511            return Err(SessionError::SessionInvalid {
512                reason: "Session refresh not enabled".to_string(),
513            });
514        }
515
516        let stored_refresh_token =
517            session
518                .refresh_token
519                .as_ref()
520                .ok_or_else(|| SessionError::SessionInvalid {
521                    reason: "No refresh token available".to_string(),
522                })?;
523
524        if stored_refresh_token != refresh_token {
525            return Err(SessionError::InvalidToken);
526        }
527
528        // Create new session
529        self.create_session(
530            session.user_id.clone(),
531            session.auth_context.clone(),
532            Some(self.config.default_duration),
533            session.client_ip.clone(),
534            session.user_agent.clone(),
535        )
536        .await
537    }
538
539    /// Terminate a session
540    pub async fn terminate_session(&self, session_id: &str) -> Result<(), SessionError> {
541        self.storage.delete_session(session_id).await?;
542        info!("Terminated session {}", session_id);
543        Ok(())
544    }
545
546    /// Terminate all sessions for a user
547    pub async fn terminate_user_sessions(&self, user_id: &str) -> Result<u64, SessionError> {
548        let sessions = self.storage.get_user_sessions(user_id).await?;
549        let mut terminated_count = 0u64;
550
551        for session in sessions {
552            if self
553                .storage
554                .delete_session(&session.session_id)
555                .await
556                .is_ok()
557            {
558                terminated_count += 1;
559            }
560        }
561
562        info!(
563            "Terminated {} sessions for user {}",
564            terminated_count, user_id
565        );
566        Ok(terminated_count)
567    }
568
569    /// Get all active sessions for a user
570    pub async fn get_user_sessions(&self, user_id: &str) -> Result<Vec<Session>, SessionError> {
571        let sessions = self.storage.get_user_sessions(user_id).await?;
572        let active_sessions = sessions
573            .into_iter()
574            .filter(|s| !s.is_expired() && s.is_active)
575            .collect();
576
577        Ok(active_sessions)
578    }
579
580    /// Clean up expired sessions
581    pub async fn cleanup_expired_sessions(&self) -> Result<u64, SessionError> {
582        self.storage.cleanup_expired().await
583    }
584
585    /// Start background cleanup task
586    pub async fn start_cleanup_task(&self) -> tokio::task::JoinHandle<()> {
587        let storage = Arc::clone(&self.storage);
588        let interval = self.config.cleanup_interval;
589
590        tokio::spawn(async move {
591            let mut cleanup_interval = tokio::time::interval(
592                interval
593                    .to_std()
594                    .unwrap_or(std::time::Duration::from_secs(3600)),
595            );
596
597            loop {
598                cleanup_interval.tick().await;
599
600                match storage.cleanup_expired().await {
601                    Ok(count) => {
602                        if count > 0 {
603                            debug!("Cleanup task removed {} expired sessions", count);
604                        }
605                    }
606                    Err(e) => {
607                        error!("Session cleanup failed: {}", e);
608                    }
609                }
610            }
611        })
612    }
613
614    /// Get session statistics
615    pub async fn get_session_stats(&self) -> Result<SessionStats, SessionError> {
616        // This is a simplified implementation for memory storage
617        // Real implementations would query the storage backend
618        Ok(SessionStats {
619            total_sessions: 0,   // Would count all sessions
620            active_sessions: 0,  // Would count active sessions
621            expired_sessions: 0, // Would count expired sessions
622        })
623    }
624}
625
626/// Session statistics
627#[derive(Debug, Clone, Serialize, Deserialize)]
628pub struct SessionStats {
629    pub total_sessions: u64,
630    pub active_sessions: u64,
631    pub expired_sessions: u64,
632}
633
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Role;

    /// Operator auth context with read/write permissions, shared by every test.
    fn create_test_auth_context() -> AuthContext {
        AuthContext {
            user_id: Some("test_user".to_string()),
            roles: vec![Role::Operator],
            api_key_id: Some("test_key".to_string()),
            permissions: vec!["read".to_string(), "write".to_string()],
        }
    }

    /// Open a session for `test_user` with the given lifetime and no client info.
    async fn open_session(
        manager: &SessionManager,
        duration: Option<chrono::Duration>,
    ) -> Result<(Session, Option<String>), SessionError> {
        manager
            .create_session(
                "test_user".to_string(),
                create_test_auth_context(),
                duration,
                None,
                None,
            )
            .await
    }

    #[tokio::test]
    async fn test_session_creation() {
        let manager = SessionManager::with_default_config();

        let (session, jwt_token) = manager
            .create_session(
                "test_user".to_string(),
                create_test_auth_context(),
                None,
                Some("127.0.0.1".to_string()),
                Some("TestAgent/1.0".to_string()),
            )
            .await
            .expect("session creation should succeed");

        assert_eq!(session.user_id, "test_user");
        assert!(!session.is_expired());
        // The default configuration has JWT support switched on.
        assert!(jwt_token.is_some());
    }

    #[tokio::test]
    async fn test_session_validation() {
        let manager = SessionManager::with_default_config();
        let (session, _) = open_session(&manager, None).await.unwrap();

        let validated = manager
            .validate_session(&session.session_id)
            .await
            .expect("a fresh session should validate");
        assert!(validated.last_accessed > session.last_accessed);
    }

    #[tokio::test]
    async fn test_session_expiration() {
        let manager = SessionManager::with_default_config();
        let (session, _) = open_session(&manager, Some(chrono::Duration::milliseconds(1)))
            .await
            .unwrap();

        // Outlive the 1 ms lifetime.
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        assert!(matches!(
            manager.get_session(&session.session_id).await,
            Err(SessionError::SessionExpired { .. })
        ));
    }

    #[tokio::test]
    async fn test_session_limits() {
        let config = SessionConfig {
            max_sessions_per_user: 2,
            ..Default::default()
        };
        let manager = SessionManager::new(config, Arc::new(MemorySessionStorage::new()));

        // The first two sessions fit within the limit...
        for _ in 0..2 {
            assert!(open_session(&manager, None).await.is_ok());
        }

        // ...the third one does not.
        assert!(matches!(
            open_session(&manager, None).await,
            Err(SessionError::MaxSessionsExceeded { .. })
        ));
    }

    #[tokio::test]
    async fn test_session_termination() {
        let manager = SessionManager::with_default_config();
        let (session, _) = open_session(&manager, None).await.unwrap();

        assert!(manager.get_session(&session.session_id).await.is_ok());
        assert!(manager.terminate_session(&session.session_id).await.is_ok());

        // Once terminated, the session is gone from storage.
        assert!(matches!(
            manager.get_session(&session.session_id).await,
            Err(SessionError::SessionNotFound { .. })
        ));
    }

    #[tokio::test]
    async fn test_cleanup_expired_sessions() {
        let manager = SessionManager::with_default_config();
        open_session(&manager, Some(chrono::Duration::milliseconds(1)))
            .await
            .unwrap();

        // Let the session expire before sweeping.
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;

        let removed = manager
            .cleanup_expired_sessions()
            .await
            .expect("cleanup should succeed");
        assert!(removed > 0);
    }
}