ricecoder_modes/
mode_switcher.rs

1//! Mode switcher for handling mode transitions with context preservation
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6
7use crate::error::{ModeError, Result};
8use crate::mode::Mode;
9use crate::models::ModeContext;
10
11/// Handles mode transitions with context preservation
12///
13/// The ModeSwitcher is responsible for:
14/// - Validating mode transitions
15/// - Preserving context across switches
16/// - Restoring context after switches
17/// - Managing mode-specific data
18pub struct ModeSwitcher {
19    /// Available modes
20    modes: HashMap<String, Arc<dyn Mode>>,
21    /// Current active mode
22    current_mode: Arc<RwLock<Option<String>>>,
23    /// Saved contexts for each mode
24    saved_contexts: Arc<RwLock<HashMap<String, ModeContext>>>,
25    /// Current execution context
26    context: Arc<RwLock<ModeContext>>,
27}
28
29impl ModeSwitcher {
30    /// Create a new mode switcher with the given context
31    pub fn new(context: ModeContext) -> Self {
32        Self {
33            modes: HashMap::new(),
34            current_mode: Arc::new(RwLock::new(None)),
35            saved_contexts: Arc::new(RwLock::new(HashMap::new())),
36            context: Arc::new(RwLock::new(context)),
37        }
38    }
39
40    /// Register a mode
41    pub fn register_mode(&mut self, mode: Arc<dyn Mode>) {
42        self.modes.insert(mode.id().to_string(), mode);
43    }
44
45    /// Get a registered mode by ID
46    pub fn get_mode(&self, id: &str) -> Result<Arc<dyn Mode>> {
47        self.modes
48            .get(id)
49            .cloned()
50            .ok_or_else(|| ModeError::NotFound(id.to_string()))
51    }
52
53    /// Get the current active mode
54    pub async fn current_mode(&self) -> Result<Option<Arc<dyn Mode>>> {
55        let mode_id = self.current_mode.read().await;
56        match mode_id.as_ref() {
57            Some(id) => Ok(Some(self.get_mode(id)?)),
58            None => Ok(None),
59        }
60    }
61
62    /// Get the current mode ID
63    pub async fn current_mode_id(&self) -> Option<String> {
64        self.current_mode.read().await.clone()
65    }
66
67    /// Switch to a different mode with context preservation
68    ///
69    /// This method:
70    /// 1. Validates the mode exists
71    /// 2. Saves the current context if switching from a mode
72    /// 3. Restores the saved context for the new mode if available
73    /// 4. Creates a fresh context if the mode hasn't been visited before
74    /// 5. Updates the current mode
75    pub async fn switch_mode(&self, mode_id: &str) -> Result<Arc<dyn Mode>> {
76        // Validate the target mode exists
77        let target_mode = self.get_mode(mode_id)?;
78
79        // Save current context if we're switching from a mode
80        if let Some(current_id) = self.current_mode.read().await.as_ref() {
81            let ctx = self.context.read().await.clone();
82            let mut saved = self.saved_contexts.write().await;
83            saved.insert(current_id.clone(), ctx);
84        }
85
86        // Restore context for the new mode if available, otherwise create fresh context
87        let mut saved = self.saved_contexts.write().await;
88        if let Some(saved_ctx) = saved.remove(mode_id) {
89            let mut ctx = self.context.write().await;
90            *ctx = saved_ctx;
91        } else {
92            // Create a fresh context for this mode (preserving session_id)
93            let mut ctx = self.context.write().await;
94            let session_id = ctx.session_id.clone();
95            let project_path = ctx.project_path.clone();
96            *ctx = ModeContext::new(session_id);
97            ctx.project_path = project_path;
98        }
99
100        // Update current mode
101        let mut current = self.current_mode.write().await;
102        *current = Some(mode_id.to_string());
103
104        Ok(target_mode)
105    }
106
107    /// Get the current context
108    pub async fn context(&self) -> ModeContext {
109        self.context.read().await.clone()
110    }
111
112    /// Update the context with a closure
113    pub async fn update_context<F>(&self, f: F) -> Result<()>
114    where
115        F: FnOnce(&mut ModeContext),
116    {
117        let mut ctx = self.context.write().await;
118        f(&mut ctx);
119        Ok(())
120    }
121
122    /// Save the current context for a specific mode
123    pub async fn save_context_for_mode(&self, mode_id: &str) -> Result<()> {
124        let ctx = self.context.read().await.clone();
125        let mut saved = self.saved_contexts.write().await;
126        saved.insert(mode_id.to_string(), ctx);
127        Ok(())
128    }
129
130    /// Restore context for a specific mode
131    pub async fn restore_context_for_mode(&self, mode_id: &str) -> Result<()> {
132        let mut saved = self.saved_contexts.write().await;
133        if let Some(saved_ctx) = saved.remove(mode_id) {
134            let mut ctx = self.context.write().await;
135            *ctx = saved_ctx;
136            Ok(())
137        } else {
138            Err(ModeError::ContextError(format!(
139                "No saved context for mode: {}",
140                mode_id
141            )))
142        }
143    }
144
145    /// Check if a mode is registered
146    pub fn has_mode(&self, id: &str) -> bool {
147        self.modes.contains_key(id)
148    }
149
150    /// Get the number of registered modes
151    pub fn mode_count(&self) -> usize {
152        self.modes.len()
153    }
154
155    /// Get all registered mode IDs
156    pub fn list_mode_ids(&self) -> Vec<String> {
157        self.modes.keys().cloned().collect()
158    }
159
160    /// Check if context is saved for a mode
161    pub async fn has_saved_context(&self, mode_id: &str) -> bool {
162        self.saved_contexts.read().await.contains_key(mode_id)
163    }
164
165    /// Get the number of saved contexts
166    pub async fn saved_context_count(&self) -> usize {
167        self.saved_contexts.read().await.len()
168    }
169
170    /// Clear all saved contexts
171    pub async fn clear_saved_contexts(&self) {
172        self.saved_contexts.write().await.clear();
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{
        Capability, MessageRole, ModeConfig, ModeConstraints, ModeResponse, Operation,
    };

    /// Minimal `Mode` implementation used as a fixture by the tests below.
    struct TestMode {
        id: String,
        config: ModeConfig,
    }

    #[async_trait::async_trait]
    impl Mode for TestMode {
        fn id(&self) -> &str {
            &self.id
        }

        fn name(&self) -> &str {
            "Test Mode"
        }

        fn description(&self) -> &str {
            "A test mode"
        }

        fn system_prompt(&self) -> &str {
            "You are a test mode"
        }

        async fn process(&self, _input: &str, _context: &ModeContext) -> Result<ModeResponse> {
            Ok(ModeResponse::new(
                "Test response".to_string(),
                self.id.clone(),
            ))
        }

        fn capabilities(&self) -> Vec<Capability> {
            vec![Capability::QuestionAnswering]
        }

        fn config(&self) -> &ModeConfig {
            &self.config
        }

        fn can_execute(&self, _operation: &Operation) -> bool {
            true
        }

        fn constraints(&self) -> ModeConstraints {
            locked_down_constraints()
        }
    }

    /// Constraints with every permission switched off, shared by all fixtures.
    fn locked_down_constraints() -> ModeConstraints {
        ModeConstraints {
            allow_file_operations: false,
            allow_command_execution: false,
            allow_code_generation: false,
            require_specs: false,
            auto_think_more_threshold: None,
        }
    }

    /// Builds a `TestMode` with the given ID and config system prompt.
    fn make_mode(id: &str, system_prompt: &str) -> Arc<TestMode> {
        Arc::new(TestMode {
            id: id.to_string(),
            config: ModeConfig {
                temperature: 0.7,
                max_tokens: 1000,
                system_prompt: system_prompt.to_string(),
                capabilities: vec![Capability::QuestionAnswering],
                constraints: locked_down_constraints(),
            },
        })
    }

    /// Builds a switcher around a fresh context for the shared test session.
    fn new_switcher() -> ModeSwitcher {
        ModeSwitcher::new(ModeContext::new("test-session".to_string()))
    }

    #[test]
    fn test_mode_switcher_creation() {
        let switcher = new_switcher();
        assert_eq!(switcher.mode_count(), 0);
    }

    #[test]
    fn test_register_mode() {
        let mut switcher = new_switcher();

        switcher.register_mode(make_mode("test", "Test"));
        assert_eq!(switcher.mode_count(), 1);
        assert!(switcher.has_mode("test"));
    }

    #[tokio::test]
    async fn test_switch_mode() {
        let mut switcher = new_switcher();
        switcher.register_mode(make_mode("test", "Test"));

        let result = switcher.switch_mode("test").await;
        assert!(result.is_ok());

        let current = switcher.current_mode().await;
        assert!(current.is_ok());
        assert!(current.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_switch_nonexistent_mode() {
        let switcher = new_switcher();
        let result = switcher.switch_mode("nonexistent").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_context_preservation_on_switch() {
        let mut switcher = new_switcher();
        switcher.register_mode(make_mode("mode1", "Mode 1"));
        switcher.register_mode(make_mode("mode2", "Mode 2"));

        // Switch to mode1
        switcher.switch_mode("mode1").await.unwrap();

        // Add a message to the context
        switcher
            .update_context(|ctx| {
                ctx.add_message(MessageRole::User, "Hello from mode1".to_string());
            })
            .await
            .unwrap();

        let ctx1 = switcher.context().await;
        assert_eq!(ctx1.conversation_history.len(), 1);

        // Switch to mode2
        switcher.switch_mode("mode2").await.unwrap();

        // Context should be empty for mode2
        let ctx2 = switcher.context().await;
        assert_eq!(ctx2.conversation_history.len(), 0);

        // Switch back to mode1
        switcher.switch_mode("mode1").await.unwrap();

        // Context should be restored
        let ctx1_restored = switcher.context().await;
        assert_eq!(ctx1_restored.conversation_history.len(), 1);
        assert_eq!(
            ctx1_restored.conversation_history[0].content,
            "Hello from mode1"
        );
    }

    #[tokio::test]
    async fn test_save_and_restore_context() {
        let switcher = new_switcher();

        // Add a message to the context
        switcher
            .update_context(|ctx| {
                ctx.add_message(MessageRole::User, "Test message".to_string());
            })
            .await
            .unwrap();

        // Save context for a mode
        switcher.save_context_for_mode("test-mode").await.unwrap();
        assert!(switcher.has_saved_context("test-mode").await);

        // Clear the current context
        switcher
            .update_context(|ctx| {
                ctx.conversation_history.clear();
            })
            .await
            .unwrap();

        let ctx = switcher.context().await;
        assert_eq!(ctx.conversation_history.len(), 0);

        // Restore context
        switcher
            .restore_context_for_mode("test-mode")
            .await
            .unwrap();

        let restored_ctx = switcher.context().await;
        assert_eq!(restored_ctx.conversation_history.len(), 1);
        assert_eq!(restored_ctx.conversation_history[0].content, "Test message");
    }

    #[tokio::test]
    async fn test_current_mode_id() {
        let mut switcher = new_switcher();
        switcher.register_mode(make_mode("test", "Test"));

        assert!(switcher.current_mode_id().await.is_none());

        switcher.switch_mode("test").await.unwrap();

        let mode_id = switcher.current_mode_id().await;
        assert_eq!(mode_id, Some("test".to_string()));
    }

    // Registration and listing are synchronous, so no async runtime is needed.
    #[test]
    fn test_list_mode_ids() {
        let mut switcher = new_switcher();
        switcher.register_mode(make_mode("mode1", "Mode 1"));
        switcher.register_mode(make_mode("mode2", "Mode 2"));

        let ids = switcher.list_mode_ids();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&"mode1".to_string()));
        assert!(ids.contains(&"mode2".to_string()));
    }

    #[tokio::test]
    async fn test_clear_saved_contexts() {
        let switcher = new_switcher();

        switcher.save_context_for_mode("mode1").await.unwrap();
        switcher.save_context_for_mode("mode2").await.unwrap();

        assert_eq!(switcher.saved_context_count().await, 2);

        switcher.clear_saved_contexts().await;

        assert_eq!(switcher.saved_context_count().await, 0);
    }
}