1use crate::error::Result;
4use async_trait::async_trait;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
/// A single client session tracked by a [`SessionManager`].
///
/// Timestamps hold whatever the manager's clock returns; `MemorySessionManager` and
/// the default `SessionManager::touch_session` both use nanoseconds since the Unix epoch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Session identifier; `MemorySessionManager` uses it as the map key.
    pub id: String,
    /// Principal passed to `create_session`, or `None` when created without one.
    pub principal: Option<String>,
    /// Time the session was created.
    pub created_at: u64,
    /// Time of the most recent activity; refreshed by `touch_session` and compared
    /// against the cutoff in `cleanup_expired`.
    pub last_activity: u64,
    /// Free-form JSON metadata, read and written through `SessionContext`.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Whether the session is active; `list_active_sessions` only returns sessions
    /// with this flag set.
    pub active: bool,
}
24
25#[async_trait]
27pub trait SessionManager: Send + Sync {
28 async fn create_session(&mut self, principal: Option<String>) -> Result<Session>;
30
31 async fn get_session(&self, session_id: &str) -> Result<Option<Session>>;
33
34 async fn update_session(&mut self, session: Session) -> Result<()>;
36
37 async fn delete_session(&mut self, session_id: &str) -> Result<()>;
39
40 async fn list_active_sessions(&self) -> Result<Vec<Session>>;
42
43 async fn cleanup_expired(&mut self, max_age_secs: u64) -> Result<u32>;
45
46 async fn touch_session(&mut self, session_id: &str) -> Result<()> {
48 if let Some(mut session) = self.get_session(session_id).await? {
49 session.last_activity = ic_cdk::api::time();
50 self.update_session(session).await?;
51 }
52 Ok(())
53 }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct SessionConfig {
59 pub max_duration: u64,
61 pub timeout: u64,
63 pub max_per_principal: u32,
65 pub require_auth: bool,
67}
68
69impl Default for SessionConfig {
70 fn default() -> Self {
71 Self {
72 max_duration: 86400, timeout: 3600, max_per_principal: 10,
75 require_auth: false,
76 }
77 }
78}
79
80pub struct MemorySessionManager {
82 sessions: HashMap<String, Session>,
83 #[cfg(test)]
84 mock_time: Option<u64>,
85}
86
87impl MemorySessionManager {
88 pub fn new(_config: SessionConfig) -> Self {
90 Self {
91 sessions: HashMap::new(),
92 #[cfg(test)]
93 mock_time: None,
94 }
95 }
96
97 #[cfg(test)]
98 fn get_time(&self) -> u64 {
99 self.mock_time.unwrap_or_else(|| {
100 std::time::SystemTime::now()
101 .duration_since(std::time::UNIX_EPOCH)
102 .unwrap()
103 .as_nanos() as u64
104 })
105 }
106
107 #[cfg(not(test))]
108 fn get_time(&self) -> u64 {
109 ic_cdk::api::time()
110 }
111
112 fn generate_session_id() -> String {
113 use std::time::{SystemTime, UNIX_EPOCH};
114 let timestamp = SystemTime::now()
115 .duration_since(UNIX_EPOCH)
116 .unwrap()
117 .as_nanos();
118 format!("session_{}", timestamp)
119 }
120}
121
122#[async_trait]
123impl SessionManager for MemorySessionManager {
124 async fn create_session(&mut self, principal: Option<String>) -> Result<Session> {
125 let now = self.get_time();
126 let session = Session {
127 id: Self::generate_session_id(),
128 principal,
129 created_at: now,
130 last_activity: now,
131 metadata: HashMap::new(),
132 active: true,
133 };
134
135 self.sessions.insert(session.id.clone(), session.clone());
136 Ok(session)
137 }
138
139 async fn get_session(&self, session_id: &str) -> Result<Option<Session>> {
140 Ok(self.sessions.get(session_id).cloned())
141 }
142
143 async fn update_session(&mut self, session: Session) -> Result<()> {
144 self.sessions.insert(session.id.clone(), session);
145 Ok(())
146 }
147
148 async fn delete_session(&mut self, session_id: &str) -> Result<()> {
149 self.sessions.remove(session_id);
150 Ok(())
151 }
152
153 async fn list_active_sessions(&self) -> Result<Vec<Session>> {
154 Ok(self
155 .sessions
156 .values()
157 .filter(|s| s.active)
158 .cloned()
159 .collect())
160 }
161
162 async fn cleanup_expired(&mut self, max_age_secs: u64) -> Result<u32> {
163 let now = self.get_time();
164 let cutoff = now.saturating_sub(max_age_secs * 1_000_000_000); let expired: Vec<String> = self
167 .sessions
168 .iter()
169 .filter(|(_, session)| session.last_activity < cutoff)
170 .map(|(id, _)| id.clone())
171 .collect();
172
173 let count = expired.len() as u32;
174 for id in expired {
175 self.sessions.remove(&id);
176 }
177
178 Ok(count)
179 }
180}
181
182#[derive(Debug, Clone)]
184pub struct SessionContext {
185 pub session: Session,
186 pub authenticated: bool,
187}
188
189impl SessionContext {
190 pub fn new(session: Session, authenticated: bool) -> Self {
192 Self {
193 session,
194 authenticated,
195 }
196 }
197
198 pub fn get_metadata<T: for<'de> Deserialize<'de>>(&self, key: &str) -> Option<T> {
200 self.session
201 .metadata
202 .get(key)
203 .and_then(|v| serde_json::from_value(v.clone()).ok())
204 }
205
206 pub fn set_metadata<T: Serialize>(&mut self, key: String, value: T) -> Result<()> {
208 let json_value =
209 serde_json::to_value(value).map_err(crate::error::IcarusError::Serialization)?;
210 self.session.metadata.insert(key, json_value);
211 Ok(())
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    const SEC: u64 = 1_000_000_000;

    /// Create / get / list / delete round-trip on the in-memory manager.
    #[tokio::test]
    async fn test_memory_session_manager() {
        let config = SessionConfig::default();
        let mut manager = MemorySessionManager::new(config);

        let session = manager
            .create_session(Some("test-principal".to_string()))
            .await
            .unwrap();
        assert!(session.active);
        assert_eq!(session.principal, Some("test-principal".to_string()));

        let retrieved = manager.get_session(&session.id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, session.id);

        let active = manager.list_active_sessions().await.unwrap();
        assert_eq!(active.len(), 1);

        manager.delete_session(&session.id).await.unwrap();
        let deleted = manager.get_session(&session.id).await.unwrap();
        assert!(deleted.is_none());
    }

    /// Sessions are expired only once they have been idle longer than the max age.
    #[tokio::test]
    async fn test_cleanup_expired_uses_last_activity() {
        let mut manager = MemorySessionManager::new(SessionConfig::default());
        manager.mock_time = Some(1_000 * SEC);
        let session = manager.create_session(None).await.unwrap();
        assert_eq!(session.created_at, 1_000 * SEC);

        // 30 s idle with a 60 s limit: kept.
        manager.mock_time = Some(1_030 * SEC);
        assert_eq!(manager.cleanup_expired(60).await.unwrap(), 0);
        assert!(manager.get_session(&session.id).await.unwrap().is_some());

        // 61 s idle with a 60 s limit: removed.
        manager.mock_time = Some(1_061 * SEC);
        assert_eq!(manager.cleanup_expired(60).await.unwrap(), 1);
        assert!(manager.get_session(&session.id).await.unwrap().is_none());
    }

    /// Metadata written through the context reads back as the same type only.
    #[test]
    fn test_session_context_metadata_roundtrip() {
        let session = Session {
            id: "s".to_string(),
            principal: None,
            created_at: 0,
            last_activity: 0,
            metadata: HashMap::new(),
            active: true,
        };
        let mut ctx = SessionContext::new(session, false);

        ctx.set_metadata("count".to_string(), 3u32).unwrap();
        assert_eq!(ctx.get_metadata::<u32>("count"), Some(3));
        assert_eq!(ctx.get_metadata::<String>("count"), None);
        assert_eq!(ctx.get_metadata::<u32>("missing"), None);
    }
}