Skip to main content

mockforge_foundation/intelligent_behavior/
session.rs

1//! Session management for tracking state across requests
2
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use super::session_state::SessionState;
10use crate::Result;
11
12/// Session tracking method
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
14#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
15#[serde(rename_all = "lowercase")]
16#[derive(Default)]
17pub enum SessionTrackingMethod {
18    /// Track via cookie
19    #[default]
20    Cookie,
21    /// Track via HTTP header
22    Header,
23    /// Track via query parameter
24    QueryParam,
25}
26
27/// Session tracking configuration
28#[derive(Debug, Clone, Serialize, Deserialize)]
29#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
30pub struct SessionTracking {
31    /// Tracking method
32    #[serde(default)]
33    pub method: SessionTrackingMethod,
34
35    /// Cookie name (if method is Cookie)
36    #[serde(default = "default_cookie_name")]
37    pub cookie_name: String,
38
39    /// Header name (if method is Header)
40    #[serde(default = "default_header_name")]
41    pub header_name: String,
42
43    /// Query parameter name (if method is QueryParam)
44    #[serde(default = "default_query_param")]
45    pub query_param: String,
46
47    /// Automatically create sessions if not present
48    #[serde(default = "default_true")]
49    pub auto_create: bool,
50}
51
52impl Default for SessionTracking {
53    fn default() -> Self {
54        Self {
55            method: SessionTrackingMethod::Cookie,
56            cookie_name: default_cookie_name(),
57            header_name: default_header_name(),
58            query_param: default_query_param(),
59            auto_create: true,
60        }
61    }
62}
63
/// Serde fallback for `SessionTracking::cookie_name`.
fn default_cookie_name() -> String {
    String::from("mockforge_session")
}
67
/// Serde fallback for `SessionTracking::header_name`.
fn default_header_name() -> String {
    "X-Session-ID".to_owned()
}
71
/// Serde fallback for `SessionTracking::query_param`.
fn default_query_param() -> String {
    String::from("session_id")
}
75
/// Serde fallback for boolean settings that are enabled unless explicitly
/// turned off (used for `SessionTracking::auto_create`).
const fn default_true() -> bool {
    true
}
79
80/// Session manager for tracking and managing sessions
81pub struct SessionManager {
82    /// Active sessions
83    sessions: Arc<RwLock<HashMap<String, SessionState>>>,
84
85    /// Session tracking configuration
86    config: SessionTracking,
87
88    /// Session timeout in seconds
89    timeout_seconds: u64,
90}
91
92impl SessionManager {
93    /// Create a new session manager
94    pub fn new(config: SessionTracking, timeout_seconds: u64) -> Self {
95        Self {
96            sessions: Arc::new(RwLock::new(HashMap::new())),
97            config,
98            timeout_seconds,
99        }
100    }
101
102    /// Generate a new session ID
103    pub fn generate_session_id() -> String {
104        Uuid::new_v4().to_string()
105    }
106
107    /// Get or create a session
108    pub async fn get_or_create_session(&self, session_id: Option<String>) -> Result<String> {
109        let session_id = match session_id {
110            Some(id) => {
111                // Check if session exists
112                let sessions = self.sessions.read().await;
113                if sessions.contains_key(&id) {
114                    id
115                } else if self.config.auto_create {
116                    drop(sessions); // Release read lock
117                    let new_id = id.clone();
118                    self.create_session(new_id.clone()).await?;
119                    new_id
120                } else {
121                    return Err(crate::Error::internal(format!("Session '{}' not found", id)));
122                }
123            }
124            None => {
125                if self.config.auto_create {
126                    let new_id = Self::generate_session_id();
127                    self.create_session(new_id.clone()).await?;
128                    new_id
129                } else {
130                    return Err(crate::Error::internal(
131                        "No session ID provided and auto-create is disabled",
132                    ));
133                }
134            }
135        };
136
137        Ok(session_id)
138    }
139
140    /// Create a new session
141    pub async fn create_session(&self, session_id: String) -> Result<String> {
142        let mut sessions = self.sessions.write().await;
143
144        if sessions.contains_key(&session_id) {
145            return Err(crate::Error::internal(format!("Session '{}' already exists", session_id)));
146        }
147
148        let state = SessionState::new(session_id.clone());
149        sessions.insert(session_id.clone(), state);
150
151        Ok(session_id)
152    }
153
154    /// Get a session by ID
155    pub async fn get_session(&self, session_id: &str) -> Option<SessionState> {
156        let sessions = self.sessions.read().await;
157        sessions.get(session_id).cloned()
158    }
159
160    /// Update a session
161    pub async fn update_session(&self, session_id: &str, state: SessionState) -> Result<()> {
162        let mut sessions = self.sessions.write().await;
163
164        if !sessions.contains_key(session_id) {
165            return Err(crate::Error::internal(format!("Session '{}' not found", session_id)));
166        }
167
168        sessions.insert(session_id.to_string(), state);
169        Ok(())
170    }
171
172    /// Delete a session
173    pub async fn delete_session(&self, session_id: &str) -> Result<()> {
174        let mut sessions = self.sessions.write().await;
175        sessions.remove(session_id);
176        Ok(())
177    }
178
179    /// List all active session IDs
180    pub async fn list_sessions(&self) -> Vec<String> {
181        let sessions = self.sessions.read().await;
182        sessions.keys().cloned().collect()
183    }
184
185    /// Clean up expired sessions
186    pub async fn cleanup_expired_sessions(&self) -> usize {
187        let timeout = chrono::Duration::seconds(self.timeout_seconds as i64);
188        let mut sessions = self.sessions.write().await;
189
190        let expired: Vec<String> = sessions
191            .iter()
192            .filter(|(_, state)| state.is_inactive(timeout))
193            .map(|(id, _)| id.clone())
194            .collect();
195
196        let count = expired.len();
197        for id in expired {
198            sessions.remove(&id);
199        }
200
201        count
202    }
203
204    /// Get the number of active sessions
205    pub async fn session_count(&self) -> usize {
206        let sessions = self.sessions.read().await;
207        sessions.len()
208    }
209
210    /// Clear all sessions
211    pub async fn clear_all(&self) {
212        let mut sessions = self.sessions.write().await;
213        sessions.clear();
214    }
215
216    /// Get session tracking configuration
217    pub fn config(&self) -> &SessionTracking {
218        &self.config
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Manager with the default tracking config and the given timeout.
    fn manager_with_timeout(timeout_seconds: u64) -> SessionManager {
        SessionManager::new(SessionTracking::default(), timeout_seconds)
    }

    /// Manager whose config refuses to create sessions implicitly.
    fn manager_without_auto_create() -> SessionManager {
        let config = SessionTracking { auto_create: false, ..SessionTracking::default() };
        SessionManager::new(config, 3600)
    }

    #[tokio::test]
    async fn test_session_manager_create_session() {
        let manager = manager_with_timeout(3600);

        let session_id = manager.create_session("test_session".to_string()).await.unwrap();
        assert_eq!(session_id, "test_session");

        let state = manager.get_session(&session_id).await;
        assert!(state.is_some());
    }

    #[tokio::test]
    async fn test_session_manager_get_or_create() {
        let manager = manager_with_timeout(3600);

        // Create with auto-create
        let session_id = manager.get_or_create_session(None).await.unwrap();
        assert!(!session_id.is_empty());

        // Get existing
        let same_id = manager.get_or_create_session(Some(session_id.clone())).await.unwrap();
        assert_eq!(session_id, same_id);
    }

    #[tokio::test]
    async fn test_session_manager_delete_session() {
        let manager = manager_with_timeout(3600);

        let session_id = manager.create_session("test_delete".to_string()).await.unwrap();
        assert!(manager.get_session(&session_id).await.is_some());

        manager.delete_session(&session_id).await.unwrap();
        assert!(manager.get_session(&session_id).await.is_none());
    }

    #[tokio::test]
    async fn test_session_manager_list_sessions() {
        let manager = manager_with_timeout(3600);

        manager.create_session("session1".to_string()).await.unwrap();
        manager.create_session("session2".to_string()).await.unwrap();

        let sessions = manager.list_sessions().await;
        assert_eq!(sessions.len(), 2);
        assert!(sessions.contains(&"session1".to_string()));
        assert!(sessions.contains(&"session2".to_string()));
    }

    #[tokio::test]
    async fn test_session_manager_clear_all() {
        let manager = manager_with_timeout(3600);

        manager.create_session("session1".to_string()).await.unwrap();
        manager.create_session("session2".to_string()).await.unwrap();

        assert_eq!(manager.session_count().await, 2);

        manager.clear_all().await;
        assert_eq!(manager.session_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_cleanup_expired() {
        let manager = manager_with_timeout(1); // 1 second timeout

        let session_id = manager.create_session("test_expire".to_string()).await.unwrap();

        // Wait for session to expire
        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;

        let cleaned = manager.cleanup_expired_sessions().await;
        assert_eq!(cleaned, 1);
        assert!(manager.get_session(&session_id).await.is_none());
    }

    #[tokio::test]
    async fn test_create_session_rejects_duplicate_id() {
        let manager = manager_with_timeout(3600);

        manager.create_session("dup".to_string()).await.unwrap();
        assert!(manager.create_session("dup".to_string()).await.is_err());
        assert_eq!(manager.session_count().await, 1);
    }

    #[tokio::test]
    async fn test_get_or_create_adopts_unknown_client_id() {
        let manager = manager_with_timeout(3600);

        let id = manager.get_or_create_session(Some("client-chosen".to_string())).await.unwrap();
        assert_eq!(id, "client-chosen");
        assert!(manager.get_session("client-chosen").await.is_some());
    }

    #[tokio::test]
    async fn test_get_or_create_without_auto_create() {
        let manager = manager_without_auto_create();

        assert!(manager.get_or_create_session(None).await.is_err());
        assert!(manager.get_or_create_session(Some("missing".to_string())).await.is_err());
        assert_eq!(manager.session_count().await, 0);

        // Known sessions are still resolved when auto-create is off.
        manager.create_session("known".to_string()).await.unwrap();
        let id = manager.get_or_create_session(Some("known".to_string())).await.unwrap();
        assert_eq!(id, "known");
    }

    #[tokio::test]
    async fn test_update_session_requires_existing_session() {
        let manager = manager_with_timeout(3600);

        let state = SessionState::new("ghost".to_string());
        assert!(manager.update_session("ghost", state).await.is_err());
        assert!(manager.get_session("ghost").await.is_none());

        let id = manager.create_session("real".to_string()).await.unwrap();
        assert!(manager.update_session(&id, SessionState::new(id.clone())).await.is_ok());
    }
}