1use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6
7use crate::error::{ModeError, Result};
8use crate::mode::Mode;
9use crate::models::ModeContext;
10
11pub struct ModeSwitcher {
19 modes: HashMap<String, Arc<dyn Mode>>,
21 current_mode: Arc<RwLock<Option<String>>>,
23 saved_contexts: Arc<RwLock<HashMap<String, ModeContext>>>,
25 context: Arc<RwLock<ModeContext>>,
27}
28
29impl ModeSwitcher {
30 pub fn new(context: ModeContext) -> Self {
32 Self {
33 modes: HashMap::new(),
34 current_mode: Arc::new(RwLock::new(None)),
35 saved_contexts: Arc::new(RwLock::new(HashMap::new())),
36 context: Arc::new(RwLock::new(context)),
37 }
38 }
39
40 pub fn register_mode(&mut self, mode: Arc<dyn Mode>) {
42 self.modes.insert(mode.id().to_string(), mode);
43 }
44
45 pub fn get_mode(&self, id: &str) -> Result<Arc<dyn Mode>> {
47 self.modes
48 .get(id)
49 .cloned()
50 .ok_or_else(|| ModeError::NotFound(id.to_string()))
51 }
52
53 pub async fn current_mode(&self) -> Result<Option<Arc<dyn Mode>>> {
55 let mode_id = self.current_mode.read().await;
56 match mode_id.as_ref() {
57 Some(id) => Ok(Some(self.get_mode(id)?)),
58 None => Ok(None),
59 }
60 }
61
62 pub async fn current_mode_id(&self) -> Option<String> {
64 self.current_mode.read().await.clone()
65 }
66
67 pub async fn switch_mode(&self, mode_id: &str) -> Result<Arc<dyn Mode>> {
76 let target_mode = self.get_mode(mode_id)?;
78
79 if let Some(current_id) = self.current_mode.read().await.as_ref() {
81 let ctx = self.context.read().await.clone();
82 let mut saved = self.saved_contexts.write().await;
83 saved.insert(current_id.clone(), ctx);
84 }
85
86 let mut saved = self.saved_contexts.write().await;
88 if let Some(saved_ctx) = saved.remove(mode_id) {
89 let mut ctx = self.context.write().await;
90 *ctx = saved_ctx;
91 } else {
92 let mut ctx = self.context.write().await;
94 let session_id = ctx.session_id.clone();
95 let project_path = ctx.project_path.clone();
96 *ctx = ModeContext::new(session_id);
97 ctx.project_path = project_path;
98 }
99
100 let mut current = self.current_mode.write().await;
102 *current = Some(mode_id.to_string());
103
104 Ok(target_mode)
105 }
106
107 pub async fn context(&self) -> ModeContext {
109 self.context.read().await.clone()
110 }
111
112 pub async fn update_context<F>(&self, f: F) -> Result<()>
114 where
115 F: FnOnce(&mut ModeContext),
116 {
117 let mut ctx = self.context.write().await;
118 f(&mut ctx);
119 Ok(())
120 }
121
122 pub async fn save_context_for_mode(&self, mode_id: &str) -> Result<()> {
124 let ctx = self.context.read().await.clone();
125 let mut saved = self.saved_contexts.write().await;
126 saved.insert(mode_id.to_string(), ctx);
127 Ok(())
128 }
129
130 pub async fn restore_context_for_mode(&self, mode_id: &str) -> Result<()> {
132 let mut saved = self.saved_contexts.write().await;
133 if let Some(saved_ctx) = saved.remove(mode_id) {
134 let mut ctx = self.context.write().await;
135 *ctx = saved_ctx;
136 Ok(())
137 } else {
138 Err(ModeError::ContextError(format!(
139 "No saved context for mode: {}",
140 mode_id
141 )))
142 }
143 }
144
145 pub fn has_mode(&self, id: &str) -> bool {
147 self.modes.contains_key(id)
148 }
149
150 pub fn mode_count(&self) -> usize {
152 self.modes.len()
153 }
154
155 pub fn list_mode_ids(&self) -> Vec<String> {
157 self.modes.keys().cloned().collect()
158 }
159
160 pub async fn has_saved_context(&self, mode_id: &str) -> bool {
162 self.saved_contexts.read().await.contains_key(mode_id)
163 }
164
165 pub async fn saved_context_count(&self) -> usize {
167 self.saved_contexts.read().await.len()
168 }
169
170 pub async fn clear_saved_contexts(&self) {
172 self.saved_contexts.write().await.clear();
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{Capability, ModeConfig, ModeConstraints, Operation};

    /// Minimal `Mode` implementation; only the id differs between instances.
    struct TestMode {
        id: String,
        config: ModeConfig,
    }

    /// Constraints shared by every test mode: everything disabled.
    fn test_constraints() -> ModeConstraints {
        ModeConstraints {
            allow_file_operations: false,
            allow_command_execution: false,
            allow_code_generation: false,
            require_specs: false,
            auto_think_more_threshold: None,
        }
    }

    /// Builds a test mode with the given id and config system prompt.
    fn test_mode(id: &str, system_prompt: &str) -> Arc<TestMode> {
        Arc::new(TestMode {
            id: id.to_string(),
            config: ModeConfig {
                temperature: 0.7,
                max_tokens: 1000,
                system_prompt: system_prompt.to_string(),
                capabilities: vec![Capability::QuestionAnswering],
                constraints: test_constraints(),
            },
        })
    }

    /// Fresh switcher on the shared test session.
    fn new_switcher() -> ModeSwitcher {
        ModeSwitcher::new(ModeContext::new("test-session".to_string()))
    }

    /// Switcher with "mode1" and "mode2" registered.
    fn switcher_with_two_modes() -> ModeSwitcher {
        let mut switcher = new_switcher();
        switcher.register_mode(test_mode("mode1", "Mode 1"));
        switcher.register_mode(test_mode("mode2", "Mode 2"));
        switcher
    }

    #[async_trait::async_trait]
    impl Mode for TestMode {
        fn id(&self) -> &str {
            &self.id
        }

        fn name(&self) -> &str {
            "Test Mode"
        }

        fn description(&self) -> &str {
            "A test mode"
        }

        fn system_prompt(&self) -> &str {
            "You are a test mode"
        }

        async fn process(
            &self,
            _input: &str,
            _context: &ModeContext,
        ) -> Result<crate::models::ModeResponse> {
            Ok(crate::models::ModeResponse::new(
                "Test response".to_string(),
                self.id.clone(),
            ))
        }

        fn capabilities(&self) -> Vec<Capability> {
            vec![Capability::QuestionAnswering]
        }

        fn config(&self) -> &ModeConfig {
            &self.config
        }

        fn can_execute(&self, _operation: &Operation) -> bool {
            true
        }

        fn constraints(&self) -> ModeConstraints {
            test_constraints()
        }
    }

    #[test]
    fn test_mode_switcher_creation() {
        let switcher = new_switcher();
        assert_eq!(switcher.mode_count(), 0);
    }

    #[test]
    fn test_register_mode() {
        let mut switcher = new_switcher();
        switcher.register_mode(test_mode("test", "Test"));
        assert_eq!(switcher.mode_count(), 1);
        assert!(switcher.has_mode("test"));
    }

    #[tokio::test]
    async fn test_switch_mode() {
        let mut switcher = new_switcher();
        switcher.register_mode(test_mode("test", "Test"));

        let result = switcher.switch_mode("test").await;
        assert!(result.is_ok());

        let current = switcher.current_mode().await;
        assert!(current.is_ok());
        assert!(current.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_switch_nonexistent_mode() {
        let switcher = new_switcher();
        let result = switcher.switch_mode("nonexistent").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_failed_switch_keeps_current_mode() {
        let switcher = switcher_with_two_modes();
        switcher.switch_mode("mode1").await.unwrap();

        assert!(switcher.switch_mode("nonexistent").await.is_err());
        assert_eq!(switcher.current_mode_id().await, Some("mode1".to_string()));
        assert_eq!(switcher.saved_context_count().await, 0);
    }

    #[tokio::test]
    async fn test_context_preservation_on_switch() {
        let switcher = switcher_with_two_modes();

        switcher.switch_mode("mode1").await.unwrap();
        switcher
            .update_context(|ctx| {
                ctx.add_message(
                    crate::models::MessageRole::User,
                    "Hello from mode1".to_string(),
                );
            })
            .await
            .unwrap();

        let ctx1 = switcher.context().await;
        assert_eq!(ctx1.conversation_history.len(), 1);

        // First visit to mode2 starts with an empty history.
        switcher.switch_mode("mode2").await.unwrap();
        let ctx2 = switcher.context().await;
        assert_eq!(ctx2.conversation_history.len(), 0);

        // Returning to mode1 restores what it had.
        switcher.switch_mode("mode1").await.unwrap();
        let ctx1_restored = switcher.context().await;
        assert_eq!(ctx1_restored.conversation_history.len(), 1);
        assert_eq!(
            ctx1_restored.conversation_history[0].content,
            "Hello from mode1"
        );
    }

    #[tokio::test]
    async fn test_switch_to_active_mode_keeps_context() {
        let switcher = switcher_with_two_modes();

        switcher.switch_mode("mode1").await.unwrap();
        switcher
            .update_context(|ctx| {
                ctx.add_message(crate::models::MessageRole::User, "kept".to_string());
            })
            .await
            .unwrap();

        switcher.switch_mode("mode1").await.unwrap();

        let ctx = switcher.context().await;
        assert_eq!(ctx.conversation_history.len(), 1);
        assert_eq!(ctx.conversation_history[0].content, "kept");
        assert_eq!(switcher.saved_context_count().await, 0);
    }

    #[tokio::test]
    async fn test_saved_contexts_track_outgoing_modes() {
        let switcher = switcher_with_two_modes();

        switcher.switch_mode("mode1").await.unwrap();
        assert_eq!(switcher.saved_context_count().await, 0);

        switcher.switch_mode("mode2").await.unwrap();
        assert_eq!(switcher.saved_context_count().await, 1);
        assert!(switcher.has_saved_context("mode1").await);

        switcher.switch_mode("mode1").await.unwrap();
        assert_eq!(switcher.saved_context_count().await, 1);
        assert!(switcher.has_saved_context("mode2").await);
        assert!(!switcher.has_saved_context("mode1").await);
    }

    #[tokio::test]
    async fn test_save_and_restore_context() {
        let switcher = new_switcher();

        switcher
            .update_context(|ctx| {
                ctx.add_message(crate::models::MessageRole::User, "Test message".to_string());
            })
            .await
            .unwrap();

        switcher.save_context_for_mode("test-mode").await.unwrap();
        assert!(switcher.has_saved_context("test-mode").await);

        switcher
            .update_context(|ctx| {
                ctx.conversation_history.clear();
            })
            .await
            .unwrap();

        let ctx = switcher.context().await;
        assert_eq!(ctx.conversation_history.len(), 0);

        switcher
            .restore_context_for_mode("test-mode")
            .await
            .unwrap();

        let restored_ctx = switcher.context().await;
        assert_eq!(restored_ctx.conversation_history.len(), 1);
        assert_eq!(restored_ctx.conversation_history[0].content, "Test message");
    }

    #[tokio::test]
    async fn test_restore_without_saved_context_fails() {
        let switcher = new_switcher();
        assert!(switcher.restore_context_for_mode("missing").await.is_err());
    }

    #[tokio::test]
    async fn test_current_mode_id() {
        let mut switcher = new_switcher();
        switcher.register_mode(test_mode("test", "Test"));

        assert!(switcher.current_mode_id().await.is_none());

        switcher.switch_mode("test").await.unwrap();

        let mode_id = switcher.current_mode_id().await;
        assert_eq!(mode_id, Some("test".to_string()));
    }

    #[tokio::test]
    async fn test_list_mode_ids() {
        let switcher = switcher_with_two_modes();

        let ids = switcher.list_mode_ids();
        assert_eq!(ids.len(), 2);
        assert!(ids.contains(&"mode1".to_string()));
        assert!(ids.contains(&"mode2".to_string()));
    }

    #[tokio::test]
    async fn test_clear_saved_contexts() {
        let switcher = new_switcher();

        switcher.save_context_for_mode("mode1").await.unwrap();
        switcher.save_context_for_mode("mode2").await.unwrap();

        assert_eq!(switcher.saved_context_count().await, 2);

        switcher.clear_saved_contexts().await;

        assert_eq!(switcher.saved_context_count().await, 0);
    }
}