ricecoder_sessions/
router.rs

1//! Session routing for message handling
2
3use crate::error::{SessionError, SessionResult};
4use crate::models::{Message, MessageRole, Session, SessionContext};
5use std::collections::HashMap;
6
7/// Routes messages to the appropriate session
8/// Manages active session state and ensures messages are routed to the correct session
9#[derive(Debug, Clone)]
10pub struct SessionRouter {
11    /// All sessions indexed by ID
12    sessions: HashMap<String, Session>,
13    /// Currently active session ID
14    active_session_id: Option<String>,
15    /// Tracks which session each message belongs to
16    message_session_map: HashMap<String, String>, // message_id -> session_id
17}
18
19impl SessionRouter {
20    /// Create a new session router
21    pub fn new() -> Self {
22        Self {
23            sessions: HashMap::new(),
24            active_session_id: None,
25            message_session_map: HashMap::new(),
26        }
27    }
28
29    /// Create a new session and set it as active
30    pub fn create_session(
31        &mut self,
32        name: String,
33        context: SessionContext,
34    ) -> SessionResult<Session> {
35        let session = Session::new(name, context);
36        let session_id = session.id.clone();
37
38        self.sessions.insert(session_id.clone(), session.clone());
39
40        // Set as active if it's the first session
41        if self.active_session_id.is_none() {
42            self.active_session_id = Some(session_id);
43        }
44
45        Ok(session)
46    }
47
48    /// Route a message to the active session
49    /// Returns the session ID the message was routed to
50    pub fn route_to_active_session(&mut self, message_content: &str) -> SessionResult<String> {
51        let session_id = self
52            .active_session_id
53            .as_ref()
54            .ok_or(SessionError::Invalid("No active session".to_string()))?
55            .clone();
56
57        let session = self
58            .sessions
59            .get_mut(&session_id)
60            .ok_or(SessionError::NotFound(session_id.clone()))?;
61
62        // Create a message and add it to the session history
63        let message = Message::new(MessageRole::User, message_content.to_string());
64        let message_id = message.id.clone();
65
66        session.history.push(message);
67        session.updated_at = chrono::Utc::now();
68
69        // Track which session this message belongs to
70        self.message_session_map
71            .insert(message_id, session_id.clone());
72
73        Ok(session_id)
74    }
75
76    /// Route a message to a specific session
77    /// Returns the session ID the message was routed to
78    pub fn route_to_session(
79        &mut self,
80        session_id: &str,
81        message_content: &str,
82    ) -> SessionResult<String> {
83        let session = self
84            .sessions
85            .get_mut(session_id)
86            .ok_or(SessionError::NotFound(session_id.to_string()))?;
87
88        // Create a message and add it to the session history
89        let message = Message::new(MessageRole::User, message_content.to_string());
90        let message_id = message.id.clone();
91
92        session.history.push(message);
93        session.updated_at = chrono::Utc::now();
94
95        // Track which session this message belongs to
96        self.message_session_map
97            .insert(message_id, session_id.to_string());
98
99        Ok(session_id.to_string())
100    }
101
102    /// Get the active session
103    pub fn get_active_session(&self) -> SessionResult<Session> {
104        let session_id = self
105            .active_session_id
106            .as_ref()
107            .ok_or(SessionError::Invalid("No active session".to_string()))?;
108
109        self.sessions
110            .get(session_id)
111            .cloned()
112            .ok_or_else(|| SessionError::NotFound(session_id.clone()))
113    }
114
115    /// Get a session by ID
116    pub fn get_session(&self, session_id: &str) -> SessionResult<Session> {
117        self.sessions
118            .get(session_id)
119            .cloned()
120            .ok_or_else(|| SessionError::NotFound(session_id.to_string()))
121    }
122
123    /// Switch to a different session
124    pub fn switch_session(&mut self, session_id: &str) -> SessionResult<Session> {
125        // Verify the session exists
126        let session = self.get_session(session_id)?;
127
128        self.active_session_id = Some(session_id.to_string());
129
130        Ok(session)
131    }
132
133    /// Get the ID of the active session
134    pub fn active_session_id(&self) -> Option<&str> {
135        self.active_session_id.as_deref()
136    }
137
138    /// List all sessions
139    pub fn list_sessions(&self) -> Vec<Session> {
140        self.sessions.values().cloned().collect()
141    }
142
143    /// Get which session a message belongs to
144    pub fn get_message_session(&self, message_id: &str) -> Option<String> {
145        self.message_session_map.get(message_id).cloned()
146    }
147
148    /// Verify that a message belongs to a specific session
149    pub fn verify_message_in_session(&self, message_id: &str, session_id: &str) -> bool {
150        self.message_session_map
151            .get(message_id)
152            .map(|id| id == session_id)
153            .unwrap_or(false)
154    }
155
156    /// Delete a session
157    pub fn delete_session(&mut self, session_id: &str) -> SessionResult<()> {
158        if !self.sessions.contains_key(session_id) {
159            return Err(SessionError::NotFound(session_id.to_string()));
160        }
161
162        // Remove all messages from this session from the tracking map
163        self.message_session_map.retain(|_, sid| sid != session_id);
164
165        self.sessions.remove(session_id);
166
167        // If the deleted session was active, switch to another session
168        if self.active_session_id.as_deref() == Some(session_id) {
169            self.active_session_id = self.sessions.keys().next().cloned();
170        }
171
172        Ok(())
173    }
174
175    /// Update a session
176    pub fn update_session(&mut self, session: Session) -> SessionResult<()> {
177        if !self.sessions.contains_key(&session.id) {
178            return Err(SessionError::NotFound(session.id.clone()));
179        }
180
181        self.sessions.insert(session.id.clone(), session);
182        Ok(())
183    }
184
185    /// Get the number of sessions
186    pub fn session_count(&self) -> usize {
187        self.sessions.len()
188    }
189}
190
191impl Default for SessionRouter {
192    fn default() -> Self {
193        Self::new()
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::SessionMode;

    fn create_test_context() -> SessionContext {
        SessionContext::new("openai".to_string(), "gpt-4".to_string(), SessionMode::Chat)
    }

    #[test]
    fn test_create_session() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        let session = router
            .create_session("Test Session".to_string(), context)
            .unwrap();

        assert_eq!(session.name, "Test Session");
        assert_eq!(router.session_count(), 1);
    }

    #[test]
    fn test_route_to_active_session() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        router
            .create_session("Test Session".to_string(), context)
            .unwrap();

        let session_id = router.route_to_active_session("Hello").unwrap();

        let session = router.get_session(&session_id).unwrap();
        assert_eq!(session.history.len(), 1);
        assert_eq!(session.history[0].content, "Hello");
    }

    #[test]
    fn test_route_to_active_session_without_sessions_fails() {
        let mut router = SessionRouter::new();

        assert!(router.route_to_active_session("Hello").is_err());
        assert!(router.get_active_session().is_err());
    }

    #[test]
    fn test_route_to_specific_session() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        let session1 = router
            .create_session("Session 1".to_string(), context.clone())
            .unwrap();
        let session2 = router
            .create_session("Session 2".to_string(), context)
            .unwrap();

        // Route message to session 2
        let routed_session_id = router
            .route_to_session(&session2.id, "Message to session 2")
            .unwrap();

        assert_eq!(routed_session_id, session2.id);

        // Verify message is in session 2, not session 1
        let s1 = router.get_session(&session1.id).unwrap();
        let s2 = router.get_session(&session2.id).unwrap();

        assert_eq!(s1.history.len(), 0);
        assert_eq!(s2.history.len(), 1);
    }

    #[test]
    fn test_route_to_unknown_session_fails() {
        let mut router = SessionRouter::new();

        assert!(router.route_to_session("missing", "Hello").is_err());
    }

    #[test]
    fn test_switch_session() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        let session1 = router
            .create_session("Session 1".to_string(), context.clone())
            .unwrap();
        let session2 = router
            .create_session("Session 2".to_string(), context)
            .unwrap();

        // Initially session1 is active
        assert_eq!(router.active_session_id(), Some(session1.id.as_str()));

        // Switch to session 2
        router.switch_session(&session2.id).unwrap();

        assert_eq!(router.active_session_id(), Some(session2.id.as_str()));
    }

    #[test]
    fn test_switch_to_unknown_session_keeps_active() {
        let mut router = SessionRouter::new();
        let session = router
            .create_session("Session".to_string(), create_test_context())
            .unwrap();

        assert!(router.switch_session("missing").is_err());
        assert_eq!(router.active_session_id(), Some(session.id.as_str()));
    }

    #[test]
    fn test_message_isolation() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        let session1 = router
            .create_session("Session 1".to_string(), context.clone())
            .unwrap();
        let session2 = router
            .create_session("Session 2".to_string(), context)
            .unwrap();

        // Route message to session 1
        router.route_to_session(&session1.id, "Message 1").unwrap();

        // Switch to session 2 and route message
        router.switch_session(&session2.id).unwrap();
        router.route_to_active_session("Message 2").unwrap();

        // Verify messages are isolated
        let s1 = router.get_session(&session1.id).unwrap();
        let s2 = router.get_session(&session2.id).unwrap();

        assert_eq!(s1.history.len(), 1);
        assert_eq!(s2.history.len(), 1);
        assert_eq!(s1.history[0].content, "Message 1");
        assert_eq!(s2.history[0].content, "Message 2");
    }

    #[test]
    fn test_delete_session() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        let session = router
            .create_session("Test Session".to_string(), context)
            .unwrap();

        router.delete_session(&session.id).unwrap();

        assert_eq!(router.session_count(), 0);
        assert!(router.get_session(&session.id).is_err());
        // With no sessions left there is nothing to be active.
        assert_eq!(router.active_session_id(), None);
        assert!(router.get_active_session().is_err());
    }

    #[test]
    fn test_delete_unknown_session_fails() {
        let mut router = SessionRouter::new();

        assert!(router.delete_session("missing").is_err());
    }

    #[test]
    fn test_delete_active_session_falls_back_to_remaining() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        let session1 = router
            .create_session("Session 1".to_string(), context.clone())
            .unwrap();
        let session2 = router
            .create_session("Session 2".to_string(), context)
            .unwrap();

        assert_eq!(router.active_session_id(), Some(session1.id.as_str()));

        router.delete_session(&session1.id).unwrap();

        // The only remaining session must take over.
        assert_eq!(router.active_session_id(), Some(session2.id.as_str()));
        let routed = router.route_to_active_session("After delete").unwrap();
        assert_eq!(routed, session2.id);
    }

    #[test]
    fn test_delete_session_clears_message_tracking() {
        let mut router = SessionRouter::new();
        let session = router
            .create_session("Session".to_string(), create_test_context())
            .unwrap();

        router.route_to_session(&session.id, "Message").unwrap();
        let message_id = router.get_session(&session.id).unwrap().history[0]
            .id
            .clone();

        router.delete_session(&session.id).unwrap();

        assert_eq!(router.get_message_session(&message_id), None);
        assert!(!router.verify_message_in_session(&message_id, &session.id));
    }

    #[test]
    fn test_update_session() {
        let mut router = SessionRouter::new();
        let session = router
            .create_session("Original".to_string(), create_test_context())
            .unwrap();

        let mut changed = router.get_session(&session.id).unwrap();
        changed.name = "Renamed".to_string();
        router.update_session(changed).unwrap();

        assert_eq!(router.get_session(&session.id).unwrap().name, "Renamed");
        assert_eq!(router.session_count(), 1);
    }

    #[test]
    fn test_update_unknown_session_fails() {
        let mut router = SessionRouter::new();
        let stray = Session::new("Stray".to_string(), create_test_context());

        assert!(router.update_session(stray).is_err());
        assert_eq!(router.session_count(), 0);
    }

    #[test]
    fn test_get_message_session() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        let session = router
            .create_session("Test Session".to_string(), context)
            .unwrap();

        let session_id = router.route_to_active_session("Hello").unwrap();
        let message_id = router.get_session(&session_id).unwrap().history[0]
            .id
            .clone();

        assert_eq!(router.get_message_session(&message_id), Some(session.id));
    }

    #[test]
    fn test_verify_message_in_session() {
        let mut router = SessionRouter::new();
        let context = create_test_context();

        let session1 = router
            .create_session("Session 1".to_string(), context.clone())
            .unwrap();
        let session2 = router
            .create_session("Session 2".to_string(), context)
            .unwrap();

        router.route_to_session(&session1.id, "Message").unwrap();
        let message_id = router.get_session(&session1.id).unwrap().history[0]
            .id
            .clone();

        assert!(router.verify_message_in_session(&message_id, &session1.id));
        assert!(!router.verify_message_in_session(&message_id, &session2.id));
        assert!(!router.verify_message_in_session("unknown", &session1.id));
    }
}