mockforge_foundation/intelligent_behavior/
session.rs1use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use super::session_state::SessionState;
10use crate::Result;
11
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
14#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
15#[serde(rename_all = "lowercase")]
16#[derive(Default)]
17pub enum SessionTrackingMethod {
18 #[default]
20 Cookie,
21 Header,
23 QueryParam,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
30pub struct SessionTracking {
31 #[serde(default)]
33 pub method: SessionTrackingMethod,
34
35 #[serde(default = "default_cookie_name")]
37 pub cookie_name: String,
38
39 #[serde(default = "default_header_name")]
41 pub header_name: String,
42
43 #[serde(default = "default_query_param")]
45 pub query_param: String,
46
47 #[serde(default = "default_true")]
49 pub auto_create: bool,
50}
51
52impl Default for SessionTracking {
53 fn default() -> Self {
54 Self {
55 method: SessionTrackingMethod::Cookie,
56 cookie_name: default_cookie_name(),
57 header_name: default_header_name(),
58 query_param: default_query_param(),
59 auto_create: true,
60 }
61 }
62}
63
/// Serde/`Default` fallback for [`SessionTracking::cookie_name`].
fn default_cookie_name() -> String {
    String::from("mockforge_session")
}

/// Serde/`Default` fallback for [`SessionTracking::header_name`].
fn default_header_name() -> String {
    String::from("X-Session-ID")
}

/// Serde/`Default` fallback for [`SessionTracking::query_param`].
fn default_query_param() -> String {
    String::from("session_id")
}

/// Serde fallback for boolean fields that should default to enabled
/// (used by [`SessionTracking::auto_create`]).
fn default_true() -> bool {
    !false
}
79
80pub struct SessionManager {
82 sessions: Arc<RwLock<HashMap<String, SessionState>>>,
84
85 config: SessionTracking,
87
88 timeout_seconds: u64,
90}
91
92impl SessionManager {
93 pub fn new(config: SessionTracking, timeout_seconds: u64) -> Self {
95 Self {
96 sessions: Arc::new(RwLock::new(HashMap::new())),
97 config,
98 timeout_seconds,
99 }
100 }
101
102 pub fn generate_session_id() -> String {
104 Uuid::new_v4().to_string()
105 }
106
107 pub async fn get_or_create_session(&self, session_id: Option<String>) -> Result<String> {
109 let session_id = match session_id {
110 Some(id) => {
111 let sessions = self.sessions.read().await;
113 if sessions.contains_key(&id) {
114 id
115 } else if self.config.auto_create {
116 drop(sessions); let new_id = id.clone();
118 self.create_session(new_id.clone()).await?;
119 new_id
120 } else {
121 return Err(crate::Error::internal(format!("Session '{}' not found", id)));
122 }
123 }
124 None => {
125 if self.config.auto_create {
126 let new_id = Self::generate_session_id();
127 self.create_session(new_id.clone()).await?;
128 new_id
129 } else {
130 return Err(crate::Error::internal(
131 "No session ID provided and auto-create is disabled",
132 ));
133 }
134 }
135 };
136
137 Ok(session_id)
138 }
139
140 pub async fn create_session(&self, session_id: String) -> Result<String> {
142 let mut sessions = self.sessions.write().await;
143
144 if sessions.contains_key(&session_id) {
145 return Err(crate::Error::internal(format!("Session '{}' already exists", session_id)));
146 }
147
148 let state = SessionState::new(session_id.clone());
149 sessions.insert(session_id.clone(), state);
150
151 Ok(session_id)
152 }
153
154 pub async fn get_session(&self, session_id: &str) -> Option<SessionState> {
156 let sessions = self.sessions.read().await;
157 sessions.get(session_id).cloned()
158 }
159
160 pub async fn update_session(&self, session_id: &str, state: SessionState) -> Result<()> {
162 let mut sessions = self.sessions.write().await;
163
164 if !sessions.contains_key(session_id) {
165 return Err(crate::Error::internal(format!("Session '{}' not found", session_id)));
166 }
167
168 sessions.insert(session_id.to_string(), state);
169 Ok(())
170 }
171
172 pub async fn delete_session(&self, session_id: &str) -> Result<()> {
174 let mut sessions = self.sessions.write().await;
175 sessions.remove(session_id);
176 Ok(())
177 }
178
179 pub async fn list_sessions(&self) -> Vec<String> {
181 let sessions = self.sessions.read().await;
182 sessions.keys().cloned().collect()
183 }
184
185 pub async fn cleanup_expired_sessions(&self) -> usize {
187 let timeout = chrono::Duration::seconds(self.timeout_seconds as i64);
188 let mut sessions = self.sessions.write().await;
189
190 let expired: Vec<String> = sessions
191 .iter()
192 .filter(|(_, state)| state.is_inactive(timeout))
193 .map(|(id, _)| id.clone())
194 .collect();
195
196 let count = expired.len();
197 for id in expired {
198 sessions.remove(&id);
199 }
200
201 count
202 }
203
204 pub async fn session_count(&self) -> usize {
206 let sessions = self.sessions.read().await;
207 sessions.len()
208 }
209
210 pub async fn clear_all(&self) {
212 let mut sessions = self.sessions.write().await;
213 sessions.clear();
214 }
215
216 pub fn config(&self) -> &SessionTracking {
218 &self.config
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager with the default tracking config and the given timeout.
    fn manager_with_timeout(timeout_seconds: u64) -> SessionManager {
        SessionManager::new(SessionTracking::default(), timeout_seconds)
    }

    #[tokio::test]
    async fn test_session_manager_create_session() {
        let manager = manager_with_timeout(3600);

        let id = manager.create_session(String::from("test_session")).await.unwrap();
        assert_eq!(id, "test_session");

        let fetched = manager.get_session(&id).await;
        assert!(fetched.is_some());
    }

    #[tokio::test]
    async fn test_session_manager_get_or_create() {
        let manager = manager_with_timeout(3600);

        // No ID supplied: a new session is generated.
        let generated = manager.get_or_create_session(None).await.unwrap();
        assert!(!generated.is_empty());

        // Supplying the generated ID returns the same session.
        let resolved = manager.get_or_create_session(Some(generated.clone())).await.unwrap();
        assert_eq!(generated, resolved);
    }

    #[tokio::test]
    async fn test_session_manager_delete_session() {
        let manager = manager_with_timeout(3600);

        let id = manager.create_session(String::from("test_delete")).await.unwrap();
        assert!(manager.get_session(&id).await.is_some());

        manager.delete_session(&id).await.unwrap();
        assert!(manager.get_session(&id).await.is_none());
    }

    #[tokio::test]
    async fn test_session_manager_list_sessions() {
        let manager = manager_with_timeout(3600);

        for name in ["session1", "session2"] {
            manager.create_session(name.to_string()).await.unwrap();
        }

        // HashMap order is unspecified, so compare a sorted copy.
        let mut listed = manager.list_sessions().await;
        listed.sort();
        assert_eq!(listed, vec!["session1", "session2"]);
    }

    #[tokio::test]
    async fn test_session_manager_clear_all() {
        let manager = manager_with_timeout(3600);

        for name in ["session1", "session2"] {
            manager.create_session(name.to_string()).await.unwrap();
        }
        assert_eq!(manager.session_count().await, 2);

        manager.clear_all().await;
        assert_eq!(manager.session_count().await, 0);
    }

    #[tokio::test]
    async fn test_session_cleanup_expired() {
        // One-second timeout, then wait past it so the session goes stale.
        let manager = manager_with_timeout(1);
        let id = manager.create_session(String::from("test_expire")).await.unwrap();

        tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;

        let removed = manager.cleanup_expired_sessions().await;
        assert_eq!(removed, 1);
        assert!(manager.get_session(&id).await.is_none());
    }
}