icarus_core/
session.rs

1//! Session management for stateful MCP interactions
2
3use crate::error::Result;
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8/// A user session
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Session {
11    /// Unique session identifier
12    pub id: String,
13    /// User principal (ICP identity)
14    pub principal: Option<String>,
15    /// Session creation timestamp
16    pub created_at: u64,
17    /// Last activity timestamp
18    pub last_activity: u64,
19    /// Session metadata
20    pub metadata: HashMap<String, serde_json::Value>,
21    /// Whether the session is active
22    pub active: bool,
23}
24
25/// Trait for managing sessions
26#[async_trait]
27pub trait SessionManager: Send + Sync {
28    /// Create a new session
29    async fn create_session(&mut self, principal: Option<String>) -> Result<Session>;
30
31    /// Get a session by ID
32    async fn get_session(&self, session_id: &str) -> Result<Option<Session>>;
33
34    /// Update a session
35    async fn update_session(&mut self, session: Session) -> Result<()>;
36
37    /// Delete a session
38    async fn delete_session(&mut self, session_id: &str) -> Result<()>;
39
40    /// List active sessions
41    async fn list_active_sessions(&self) -> Result<Vec<Session>>;
42
43    /// Clean up expired sessions
44    async fn cleanup_expired(&mut self, max_age_secs: u64) -> Result<u32>;
45
46    /// Touch a session to update last activity
47    async fn touch_session(&mut self, session_id: &str) -> Result<()> {
48        if let Some(mut session) = self.get_session(session_id).await? {
49            session.last_activity = ic_cdk::api::time();
50            self.update_session(session).await?;
51        }
52        Ok(())
53    }
54}
55
56/// Session configuration
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct SessionConfig {
59    /// Maximum session duration in seconds
60    pub max_duration: u64,
61    /// Session timeout (inactivity) in seconds
62    pub timeout: u64,
63    /// Maximum number of concurrent sessions per principal
64    pub max_per_principal: u32,
65    /// Whether to require authentication
66    pub require_auth: bool,
67}
68
69impl Default for SessionConfig {
70    fn default() -> Self {
71        Self {
72            max_duration: 86400, // 24 hours
73            timeout: 3600,       // 1 hour
74            max_per_principal: 10,
75            require_auth: false,
76        }
77    }
78}
79
80/// In-memory session manager for testing
81pub struct MemorySessionManager {
82    sessions: HashMap<String, Session>,
83    #[cfg(test)]
84    mock_time: Option<u64>,
85}
86
87impl MemorySessionManager {
88    /// Create a new memory session manager
89    pub fn new(_config: SessionConfig) -> Self {
90        Self {
91            sessions: HashMap::new(),
92            #[cfg(test)]
93            mock_time: None,
94        }
95    }
96
97    #[cfg(test)]
98    fn get_time(&self) -> u64 {
99        self.mock_time.unwrap_or_else(|| {
100            std::time::SystemTime::now()
101                .duration_since(std::time::UNIX_EPOCH)
102                .unwrap()
103                .as_nanos() as u64
104        })
105    }
106
107    #[cfg(not(test))]
108    fn get_time(&self) -> u64 {
109        ic_cdk::api::time()
110    }
111
112    fn generate_session_id() -> String {
113        use std::time::{SystemTime, UNIX_EPOCH};
114        let timestamp = SystemTime::now()
115            .duration_since(UNIX_EPOCH)
116            .unwrap()
117            .as_nanos();
118        format!("session_{}", timestamp)
119    }
120}
121
122#[async_trait]
123impl SessionManager for MemorySessionManager {
124    async fn create_session(&mut self, principal: Option<String>) -> Result<Session> {
125        let now = self.get_time();
126        let session = Session {
127            id: Self::generate_session_id(),
128            principal,
129            created_at: now,
130            last_activity: now,
131            metadata: HashMap::new(),
132            active: true,
133        };
134
135        self.sessions.insert(session.id.clone(), session.clone());
136        Ok(session)
137    }
138
139    async fn get_session(&self, session_id: &str) -> Result<Option<Session>> {
140        Ok(self.sessions.get(session_id).cloned())
141    }
142
143    async fn update_session(&mut self, session: Session) -> Result<()> {
144        self.sessions.insert(session.id.clone(), session);
145        Ok(())
146    }
147
148    async fn delete_session(&mut self, session_id: &str) -> Result<()> {
149        self.sessions.remove(session_id);
150        Ok(())
151    }
152
153    async fn list_active_sessions(&self) -> Result<Vec<Session>> {
154        Ok(self
155            .sessions
156            .values()
157            .filter(|s| s.active)
158            .cloned()
159            .collect())
160    }
161
162    async fn cleanup_expired(&mut self, max_age_secs: u64) -> Result<u32> {
163        let now = self.get_time();
164        let cutoff = now.saturating_sub(max_age_secs * 1_000_000_000); // Convert to nanos
165
166        let expired: Vec<String> = self
167            .sessions
168            .iter()
169            .filter(|(_, session)| session.last_activity < cutoff)
170            .map(|(id, _)| id.clone())
171            .collect();
172
173        let count = expired.len() as u32;
174        for id in expired {
175            self.sessions.remove(&id);
176        }
177
178        Ok(count)
179    }
180}
181
182/// Session context for request handling
183#[derive(Debug, Clone)]
184pub struct SessionContext {
185    pub session: Session,
186    pub authenticated: bool,
187}
188
189impl SessionContext {
190    /// Create a new session context
191    pub fn new(session: Session, authenticated: bool) -> Self {
192        Self {
193            session,
194            authenticated,
195        }
196    }
197
198    /// Get session metadata value
199    pub fn get_metadata<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
200        self.session
201            .metadata
202            .get(key)
203            .and_then(|v| serde_json::from_value(v.clone()).ok())
204    }
205
206    /// Set session metadata value
207    pub fn set_metadata<T: Serialize>(&mut self, key: String, value: T) -> Result<()> {
208        let json_value =
209            serde_json::to_value(value).map_err(crate::error::IcarusError::Serialization)?;
210        self.session.metadata.insert(key, json_value);
211        Ok(())
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    const NANOS_PER_SEC: u64 = 1_000_000_000;

    #[tokio::test]
    async fn test_memory_session_manager() {
        let config = SessionConfig::default();
        let mut manager = MemorySessionManager::new(config);

        // Create session
        let session = manager
            .create_session(Some("test-principal".to_string()))
            .await
            .unwrap();
        assert!(session.active);
        assert_eq!(session.principal, Some("test-principal".to_string()));

        // Get session
        let retrieved = manager.get_session(&session.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, session.id);

        // List active sessions
        let active = manager.list_active_sessions().await.unwrap();
        assert_eq!(active.len(), 1);

        // Delete session
        manager.delete_session(&session.id).await.unwrap();
        let deleted = manager.get_session(&session.id).await.unwrap();
        assert!(deleted.is_none());
    }

    /// Expiry is measured from `last_activity` using the manager's clock, and a
    /// session idle for exactly the max age is kept.
    #[tokio::test]
    async fn test_cleanup_expired_boundary() {
        let mut manager = MemorySessionManager::new(SessionConfig::default());
        let start = 1_000 * NANOS_PER_SEC;
        manager.mock_time = Some(start);

        let session = manager.create_session(None).await.unwrap();
        assert_eq!(session.created_at, start);
        assert_eq!(session.last_activity, start);

        // Advance the clock by ten seconds.
        manager.mock_time = Some(start + 10 * NANOS_PER_SEC);

        // Idle 10s: well within a 60s limit, and exactly at a 10s limit.
        assert_eq!(manager.cleanup_expired(60).await.unwrap(), 0);
        assert_eq!(manager.cleanup_expired(10).await.unwrap(), 0);
        assert!(manager.get_session(&session.id).await.unwrap().is_some());

        // Idle 10s exceeds a 5s limit: the session is removed.
        assert_eq!(manager.cleanup_expired(5).await.unwrap(), 1);
        assert!(manager.get_session(&session.id).await.unwrap().is_none());
        assert!(manager.list_active_sessions().await.unwrap().is_empty());
    }
}