// ricecoder_sessions/router.rs
use crate::error::{SessionError, SessionResult};
use crate::models::{Message, MessageRole, Session, SessionContext};
use std::collections::HashMap;

7#[derive(Debug, Clone)]
10pub struct SessionRouter {
11 sessions: HashMap<String, Session>,
13 active_session_id: Option<String>,
15 message_session_map: HashMap<String, String>, }
18
19impl SessionRouter {
20 pub fn new() -> Self {
22 Self {
23 sessions: HashMap::new(),
24 active_session_id: None,
25 message_session_map: HashMap::new(),
26 }
27 }
28
29 pub fn create_session(
31 &mut self,
32 name: String,
33 context: SessionContext,
34 ) -> SessionResult<Session> {
35 let session = Session::new(name, context);
36 let session_id = session.id.clone();
37
38 self.sessions.insert(session_id.clone(), session.clone());
39
40 if self.active_session_id.is_none() {
42 self.active_session_id = Some(session_id);
43 }
44
45 Ok(session)
46 }
47
48 pub fn route_to_active_session(&mut self, message_content: &str) -> SessionResult<String> {
51 let session_id = self
52 .active_session_id
53 .as_ref()
54 .ok_or(SessionError::Invalid("No active session".to_string()))?
55 .clone();
56
57 let session = self
58 .sessions
59 .get_mut(&session_id)
60 .ok_or(SessionError::NotFound(session_id.clone()))?;
61
62 let message = Message::new(MessageRole::User, message_content.to_string());
64 let message_id = message.id.clone();
65
66 session.history.push(message);
67 session.updated_at = chrono::Utc::now();
68
69 self.message_session_map
71 .insert(message_id, session_id.clone());
72
73 Ok(session_id)
74 }
75
76 pub fn route_to_session(
79 &mut self,
80 session_id: &str,
81 message_content: &str,
82 ) -> SessionResult<String> {
83 let session = self
84 .sessions
85 .get_mut(session_id)
86 .ok_or(SessionError::NotFound(session_id.to_string()))?;
87
88 let message = Message::new(MessageRole::User, message_content.to_string());
90 let message_id = message.id.clone();
91
92 session.history.push(message);
93 session.updated_at = chrono::Utc::now();
94
95 self.message_session_map
97 .insert(message_id, session_id.to_string());
98
99 Ok(session_id.to_string())
100 }
101
102 pub fn get_active_session(&self) -> SessionResult<Session> {
104 let session_id = self
105 .active_session_id
106 .as_ref()
107 .ok_or(SessionError::Invalid("No active session".to_string()))?;
108
109 self.sessions
110 .get(session_id)
111 .cloned()
112 .ok_or_else(|| SessionError::NotFound(session_id.clone()))
113 }
114
115 pub fn get_session(&self, session_id: &str) -> SessionResult<Session> {
117 self.sessions
118 .get(session_id)
119 .cloned()
120 .ok_or_else(|| SessionError::NotFound(session_id.to_string()))
121 }
122
123 pub fn switch_session(&mut self, session_id: &str) -> SessionResult<Session> {
125 let session = self.get_session(session_id)?;
127
128 self.active_session_id = Some(session_id.to_string());
129
130 Ok(session)
131 }
132
133 pub fn active_session_id(&self) -> Option<&str> {
135 self.active_session_id.as_deref()
136 }
137
138 pub fn list_sessions(&self) -> Vec<Session> {
140 self.sessions.values().cloned().collect()
141 }
142
143 pub fn get_message_session(&self, message_id: &str) -> Option<String> {
145 self.message_session_map.get(message_id).cloned()
146 }
147
148 pub fn verify_message_in_session(&self, message_id: &str, session_id: &str) -> bool {
150 self.message_session_map
151 .get(message_id)
152 .map(|id| id == session_id)
153 .unwrap_or(false)
154 }
155
156 pub fn delete_session(&mut self, session_id: &str) -> SessionResult<()> {
158 if !self.sessions.contains_key(session_id) {
159 return Err(SessionError::NotFound(session_id.to_string()));
160 }
161
162 self.message_session_map.retain(|_, sid| sid != session_id);
164
165 self.sessions.remove(session_id);
166
167 if self.active_session_id.as_deref() == Some(session_id) {
169 self.active_session_id = self.sessions.keys().next().cloned();
170 }
171
172 Ok(())
173 }
174
175 pub fn update_session(&mut self, session: Session) -> SessionResult<()> {
177 if !self.sessions.contains_key(&session.id) {
178 return Err(SessionError::NotFound(session.id.clone()));
179 }
180
181 self.sessions.insert(session.id.clone(), session);
182 Ok(())
183 }
184
185 pub fn session_count(&self) -> usize {
187 self.sessions.len()
188 }
189}
190
191impl Default for SessionRouter {
192 fn default() -> Self {
193 Self::new()
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::SessionMode;

    /// Builds a minimal chat context shared by every test.
    fn create_test_context() -> SessionContext {
        SessionContext::new("openai".to_string(), "gpt-4".to_string(), SessionMode::Chat)
    }

    #[test]
    fn test_create_session() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        let session = router
            .create_session("Test Session".to_string(), context)
            .unwrap();

        assert_eq!(session.name, "Test Session");
        assert_eq!(router.session_count(), 1);
    }

    #[test]
    fn test_route_to_active_session() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        router
            .create_session("Test Session".to_string(), context)
            .unwrap();

        let session_id = router.route_to_active_session("Hello").unwrap();

        let session = router.get_session(&session_id).unwrap();
        assert_eq!(session.history.len(), 1);
        assert_eq!(session.history[0].content, "Hello");
    }

    #[test]
    fn test_route_to_active_session_without_sessions_fails() {
        let mut router = SessionRouter::new();

        let result = router.route_to_active_session("Hello");

        assert!(matches!(result, Err(SessionError::Invalid(_))));
        assert!(router.get_active_session().is_err());
    }

    #[test]
    fn test_route_to_specific_session() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        let session1 = router
            .create_session("Session 1".to_string(), context.clone())
            .unwrap();
        let session2 = router
            .create_session("Session 2".to_string(), context)
            .unwrap();

        let routed_session_id = router
            .route_to_session(&session2.id, "Message to session 2")
            .unwrap();

        assert_eq!(routed_session_id, session2.id);

        let s1 = router.get_session(&session1.id).unwrap();
        let s2 = router.get_session(&session2.id).unwrap();

        assert_eq!(s1.history.len(), 0);
        assert_eq!(s2.history.len(), 1);
    }

    #[test]
    fn test_route_to_unknown_session_fails() {
        let mut router = SessionRouter::new();

        let result = router.route_to_session("missing", "Hello");

        assert!(matches!(result, Err(SessionError::NotFound(ref id)) if id == "missing"));
    }

    #[test]
    fn test_switch_session() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        let session1 = router
            .create_session("Session 1".to_string(), context.clone())
            .unwrap();
        let session2 = router
            .create_session("Session 2".to_string(), context)
            .unwrap();

        assert_eq!(router.active_session_id(), Some(session1.id.as_str()));

        router.switch_session(&session2.id).unwrap();

        assert_eq!(router.active_session_id(), Some(session2.id.as_str()));
    }

    #[test]
    fn test_switch_to_unknown_session_keeps_active() {
        let mut router = SessionRouter::new();

        let session = router
            .create_session("Session".to_string(), create_test_context())
            .unwrap();

        assert!(router.switch_session("missing").is_err());
        assert_eq!(router.active_session_id(), Some(session.id.as_str()));
    }

    #[test]
    fn test_message_isolation() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        let session1 = router
            .create_session("Session 1".to_string(), context.clone())
            .unwrap();
        let session2 = router
            .create_session("Session 2".to_string(), context)
            .unwrap();

        router.route_to_session(&session1.id, "Message 1").unwrap();

        router.switch_session(&session2.id).unwrap();
        router.route_to_active_session("Message 2").unwrap();

        let s1 = router.get_session(&session1.id).unwrap();
        let s2 = router.get_session(&session2.id).unwrap();

        assert_eq!(s1.history.len(), 1);
        assert_eq!(s2.history.len(), 1);
        assert_eq!(s1.history[0].content, "Message 1");
        assert_eq!(s2.history[0].content, "Message 2");
    }

    #[test]
    fn test_delete_session() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        let session = router
            .create_session("Test Session".to_string(), context)
            .unwrap();

        router.delete_session(&session.id).unwrap();

        assert_eq!(router.session_count(), 0);
        assert!(router.get_session(&session.id).is_err());
        assert_eq!(router.active_session_id(), None);
    }

    #[test]
    fn test_delete_active_session_falls_back_to_remaining() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        let session1 = router
            .create_session("Session 1".to_string(), context.clone())
            .unwrap();
        let session2 = router
            .create_session("Session 2".to_string(), context)
            .unwrap();

        assert_eq!(router.active_session_id(), Some(session1.id.as_str()));

        router.delete_session(&session1.id).unwrap();

        // With only one session left, it must become the active one.
        assert_eq!(router.active_session_id(), Some(session2.id.as_str()));
    }

    #[test]
    fn test_delete_session_removes_message_mappings() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        let session1 = router
            .create_session("Session 1".to_string(), context.clone())
            .unwrap();
        let session2 = router
            .create_session("Session 2".to_string(), context)
            .unwrap();

        router.route_to_session(&session1.id, "Gone").unwrap();
        router.route_to_session(&session2.id, "Kept").unwrap();

        let gone_id = router.get_session(&session1.id).unwrap().history[0]
            .id
            .clone();
        let kept_id = router.get_session(&session2.id).unwrap().history[0]
            .id
            .clone();

        router.delete_session(&session1.id).unwrap();

        assert_eq!(router.get_message_session(&gone_id), None);
        assert_eq!(router.get_message_session(&kept_id), Some(session2.id.clone()));
    }

    #[test]
    fn test_delete_unknown_session_fails() {
        let mut router = SessionRouter::new();

        let result = router.delete_session("missing");

        assert!(matches!(result, Err(SessionError::NotFound(_))));
    }

    #[test]
    fn test_update_session() {
        let mut router = SessionRouter::new();

        let session = router
            .create_session("Original".to_string(), create_test_context())
            .unwrap();

        let mut updated = router.get_session(&session.id).unwrap();
        updated.name = "Renamed".to_string();
        router.update_session(updated).unwrap();

        assert_eq!(router.get_session(&session.id).unwrap().name, "Renamed");
        assert_eq!(router.session_count(), 1);
    }

    #[test]
    fn test_update_unknown_session_fails() {
        let mut router = SessionRouter::new();
        let stray = Session::new("Stray".to_string(), create_test_context());

        let result = router.update_session(stray);

        assert!(matches!(result, Err(SessionError::NotFound(_))));
        assert_eq!(router.session_count(), 0);
    }

    #[test]
    fn test_get_message_session() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        let session = router
            .create_session("Test Session".to_string(), context)
            .unwrap();

        let session_id = router.route_to_active_session("Hello").unwrap();
        let message_id = router.get_session(&session_id).unwrap().history[0]
            .id
            .clone();

        assert_eq!(router.get_message_session(&message_id), Some(session.id));
    }

    #[test]
    fn test_verify_message_in_session() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        let session1 = router
            .create_session("Session 1".to_string(), context.clone())
            .unwrap();
        let session2 = router
            .create_session("Session 2".to_string(), context)
            .unwrap();

        router.route_to_session(&session1.id, "Message").unwrap();
        let message_id = router.get_session(&session1.id).unwrap().history[0]
            .id
            .clone();

        assert!(router.verify_message_in_session(&message_id, &session1.id));
        assert!(!router.verify_message_in_session(&message_id, &session2.id));
        assert!(!router.verify_message_in_session("unknown", &session1.id));
    }
}