ricecoder_modes/
task_config.rs

1use crate::error::{ModeError, Result};
2use crate::models::ThinkMoreConfig;
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5
6/// Manages per-task Think More configuration with context storage
7#[derive(Debug, Clone)]
8pub struct TaskConfigManager {
9    /// Per-task configurations
10    configs: Arc<Mutex<HashMap<String, TaskConfig>>>,
11}
12
13/// Configuration for a specific task
14#[derive(Debug, Clone)]
15pub struct TaskConfig {
16    /// Task identifier
17    pub task_id: String,
18    /// Think More configuration for this task
19    pub think_more_config: ThinkMoreConfig,
20    /// Custom context data for this task
21    pub context: HashMap<String, serde_json::Value>,
22}
23
24impl TaskConfig {
25    /// Create a new task configuration
26    pub fn new(task_id: String, think_more_config: ThinkMoreConfig) -> Self {
27        Self {
28            task_id,
29            think_more_config,
30            context: HashMap::new(),
31        }
32    }
33
34    /// Add context data to the task
35    pub fn add_context(&mut self, key: String, value: serde_json::Value) {
36        self.context.insert(key, value);
37    }
38
39    /// Get context data from the task
40    pub fn get_context(&self, key: &str) -> Option<&serde_json::Value> {
41        self.context.get(key)
42    }
43
44    /// Remove context data from the task
45    pub fn remove_context(&mut self, key: &str) -> Option<serde_json::Value> {
46        self.context.remove(key)
47    }
48
49    /// Clear all context data
50    pub fn clear_context(&mut self) {
51        self.context.clear();
52    }
53}
54
55impl TaskConfigManager {
56    /// Create a new task configuration manager
57    pub fn new() -> Self {
58        Self {
59            configs: Arc::new(Mutex::new(HashMap::new())),
60        }
61    }
62
63    /// Register a task with its configuration
64    pub fn register_task(&self, task_id: String, config: ThinkMoreConfig) -> Result<()> {
65        let mut configs = self.configs.lock().map_err(|_| {
66            ModeError::ConfigError("Failed to acquire lock on task configs".to_string())
67        })?;
68        let task_config = TaskConfig::new(task_id.clone(), config);
69        configs.insert(task_id, task_config);
70        Ok(())
71    }
72
73    /// Get the configuration for a task
74    pub fn get_task_config(&self, task_id: &str) -> Result<Option<ThinkMoreConfig>> {
75        let configs = self.configs.lock().map_err(|_| {
76            ModeError::ConfigError("Failed to acquire lock on task configs".to_string())
77        })?;
78        Ok(configs.get(task_id).map(|tc| tc.think_more_config.clone()))
79    }
80
81    /// Update the configuration for a task
82    pub fn update_task_config(&self, task_id: &str, config: ThinkMoreConfig) -> Result<()> {
83        let mut configs = self.configs.lock().map_err(|_| {
84            ModeError::ConfigError("Failed to acquire lock on task configs".to_string())
85        })?;
86
87        if let Some(task_config) = configs.get_mut(task_id) {
88            task_config.think_more_config = config;
89            Ok(())
90        } else {
91            Err(ModeError::NotFound(format!("Task {} not found", task_id)))
92        }
93    }
94
95    /// Unregister a task
96    pub fn unregister_task(&self, task_id: &str) -> Result<()> {
97        let mut configs = self.configs.lock().map_err(|_| {
98            ModeError::ConfigError("Failed to acquire lock on task configs".to_string())
99        })?;
100        configs.remove(task_id);
101        Ok(())
102    }
103
104    /// Check if a task is registered
105    pub fn has_task(&self, task_id: &str) -> Result<bool> {
106        let configs = self.configs.lock().map_err(|_| {
107            ModeError::ConfigError("Failed to acquire lock on task configs".to_string())
108        })?;
109        Ok(configs.contains_key(task_id))
110    }
111
112    /// Add context data to a task
113    pub fn add_task_context(
114        &self,
115        task_id: &str,
116        key: String,
117        value: serde_json::Value,
118    ) -> Result<()> {
119        let mut configs = self.configs.lock().map_err(|_| {
120            ModeError::ConfigError("Failed to acquire lock on task configs".to_string())
121        })?;
122
123        if let Some(task_config) = configs.get_mut(task_id) {
124            task_config.add_context(key, value);
125            Ok(())
126        } else {
127            Err(ModeError::NotFound(format!("Task {} not found", task_id)))
128        }
129    }
130
131    /// Get context data from a task
132    pub fn get_task_context(&self, task_id: &str, key: &str) -> Result<Option<serde_json::Value>> {
133        let configs = self.configs.lock().map_err(|_| {
134            ModeError::ConfigError("Failed to acquire lock on task configs".to_string())
135        })?;
136
137        Ok(configs
138            .get(task_id)
139            .and_then(|tc| tc.get_context(key).cloned()))
140    }
141
142    /// Remove context data from a task
143    pub fn remove_task_context(
144        &self,
145        task_id: &str,
146        key: &str,
147    ) -> Result<Option<serde_json::Value>> {
148        let mut configs = self.configs.lock().map_err(|_| {
149            ModeError::ConfigError("Failed to acquire lock on task configs".to_string())
150        })?;
151
152        if let Some(task_config) = configs.get_mut(task_id) {
153            Ok(task_config.remove_context(key))
154        } else {
155            Err(ModeError::NotFound(format!("Task {} not found", task_id)))
156        }
157    }
158
159    /// Clear all context data for a task
160    pub fn clear_task_context(&self, task_id: &str) -> Result<()> {
161        let mut configs = self.configs.lock().map_err(|_| {
162            ModeError::ConfigError("Failed to acquire lock on task configs".to_string())
163        })?;
164
165        if let Some(task_config) = configs.get_mut(task_id) {
166            task_config.clear_context();
167            Ok(())
168        } else {
169            Err(ModeError::NotFound(format!("Task {} not found", task_id)))
170        }
171    }
172
173    /// Get all registered task IDs
174    pub fn get_all_task_ids(&self) -> Result<Vec<String>> {
175        let configs = self.configs.lock().map_err(|_| {
176            ModeError::ConfigError("Failed to acquire lock on task configs".to_string())
177        })?;
178        Ok(configs.keys().cloned().collect())
179    }
180
181    /// Get the number of registered tasks
182    pub fn task_count(&self) -> Result<usize> {
183        let configs = self.configs.lock().map_err(|_| {
184            ModeError::ConfigError("Failed to acquire lock on task configs".to_string())
185        })?;
186        Ok(configs.len())
187    }
188
189    /// Clear all tasks
190    pub fn clear_all_tasks(&self) -> Result<()> {
191        let mut configs = self.configs.lock().map_err(|_| {
192            ModeError::ConfigError("Failed to acquire lock on task configs".to_string())
193        })?;
194        configs.clear();
195        Ok(())
196    }
197
198    /// Get all task configurations
199    pub fn get_all_configs(&self) -> Result<Vec<TaskConfig>> {
200        let configs = self.configs.lock().map_err(|_| {
201            ModeError::ConfigError("Failed to acquire lock on task configs".to_string())
202        })?;
203        Ok(configs.values().cloned().collect())
204    }
205}
206
207impl Default for TaskConfigManager {
208    fn default() -> Self {
209        Self::new()
210    }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::ThinkingDepth;
    use std::time::Duration;

    /// Baseline Think More configuration used throughout these tests.
    fn create_test_config() -> ThinkMoreConfig {
        ThinkMoreConfig {
            enabled: true,
            depth: ThinkingDepth::Medium,
            timeout: Duration::from_secs(30),
            auto_enable: false,
        }
    }

    /// Creates a manager with `task1` registered using the baseline config.
    fn manager_with_task1() -> TaskConfigManager {
        let manager = TaskConfigManager::new();
        manager
            .register_task("task1".to_string(), create_test_config())
            .unwrap();
        manager
    }

    #[test]
    fn test_task_config_creation() {
        let task_config = TaskConfig::new("task1".to_string(), create_test_config());
        assert_eq!(task_config.task_id, "task1");
        assert_eq!(task_config.think_more_config.depth, ThinkingDepth::Medium);
        assert!(task_config.context.is_empty());
    }

    #[test]
    fn test_task_config_context() {
        let mut task_config = TaskConfig::new("task1".to_string(), create_test_config());

        task_config.add_context("key1".to_string(), serde_json::json!("value1"));
        assert_eq!(
            task_config.get_context("key1"),
            Some(&serde_json::json!("value1"))
        );

        assert_eq!(
            task_config.remove_context("key1"),
            Some(serde_json::json!("value1"))
        );
        assert!(task_config.get_context("key1").is_none());
    }

    #[test]
    fn test_task_config_clear_context() {
        let mut task_config = TaskConfig::new("task1".to_string(), create_test_config());
        task_config.add_context("key1".to_string(), serde_json::json!(1));
        task_config.add_context("key2".to_string(), serde_json::json!(2));

        task_config.clear_context();
        assert!(task_config.context.is_empty());
    }

    #[test]
    fn test_manager_register_task() {
        let manager = manager_with_task1();
        assert!(manager.has_task("task1").unwrap());
        assert!(!manager.has_task("task2").unwrap());
    }

    #[test]
    fn test_manager_register_task_replaces_existing() {
        let manager = manager_with_task1();
        manager
            .add_task_context("task1", "key1".to_string(), serde_json::json!("value1"))
            .unwrap();

        let mut replacement = create_test_config();
        replacement.depth = ThinkingDepth::Deep;
        manager
            .register_task("task1".to_string(), replacement)
            .unwrap();

        // Re-registering replaces the whole entry, context included.
        assert_eq!(manager.task_count().unwrap(), 1);
        assert_eq!(
            manager.get_task_config("task1").unwrap().map(|c| c.depth),
            Some(ThinkingDepth::Deep)
        );
        assert!(manager.get_task_context("task1", "key1").unwrap().is_none());
    }

    #[test]
    fn test_manager_get_task_config() {
        let manager = manager_with_task1();
        let retrieved = manager.get_task_config("task1").unwrap();
        assert_eq!(retrieved.map(|c| c.depth), Some(ThinkingDepth::Medium));
    }

    #[test]
    fn test_manager_update_task_config() {
        let manager = manager_with_task1();
        manager
            .add_task_context("task1", "key1".to_string(), serde_json::json!("value1"))
            .unwrap();

        let mut new_config = create_test_config();
        new_config.depth = ThinkingDepth::Deep;
        manager.update_task_config("task1", new_config).unwrap();

        let retrieved = manager.get_task_config("task1").unwrap();
        assert_eq!(retrieved.map(|c| c.depth), Some(ThinkingDepth::Deep));
        // Updating the configuration must not disturb the task's context.
        assert_eq!(
            manager.get_task_context("task1", "key1").unwrap(),
            Some(serde_json::json!("value1"))
        );
    }

    #[test]
    fn test_manager_unregister_task() {
        let manager = manager_with_task1();
        assert!(manager.has_task("task1").unwrap());

        manager.unregister_task("task1").unwrap();
        assert!(!manager.has_task("task1").unwrap());

        // Unregistering an unknown task is a no-op rather than an error.
        manager.unregister_task("task1").unwrap();
    }

    #[test]
    fn test_manager_add_task_context() {
        let manager = manager_with_task1();
        manager
            .add_task_context("task1", "key1".to_string(), serde_json::json!("value1"))
            .unwrap();

        let value = manager.get_task_context("task1", "key1").unwrap();
        assert_eq!(value, Some(serde_json::json!("value1")));
    }

    #[test]
    fn test_manager_remove_task_context() {
        let manager = manager_with_task1();
        manager
            .add_task_context("task1", "key1".to_string(), serde_json::json!("value1"))
            .unwrap();

        let removed = manager.remove_task_context("task1", "key1").unwrap();
        assert_eq!(removed, Some(serde_json::json!("value1")));
        assert!(manager.get_task_context("task1", "key1").unwrap().is_none());

        // Removing a missing key from a registered task is not an error.
        assert!(manager
            .remove_task_context("task1", "key1")
            .unwrap()
            .is_none());
    }

    #[test]
    fn test_manager_clear_task_context() {
        let manager = manager_with_task1();
        manager
            .add_task_context("task1", "key1".to_string(), serde_json::json!("value1"))
            .unwrap();
        manager
            .add_task_context("task1", "key2".to_string(), serde_json::json!("value2"))
            .unwrap();

        manager.clear_task_context("task1").unwrap();

        assert!(manager.get_task_context("task1", "key1").unwrap().is_none());
        assert!(manager.get_task_context("task1", "key2").unwrap().is_none());
        // Clearing context keeps the task itself registered.
        assert!(manager.has_task("task1").unwrap());
    }

    #[test]
    fn test_manager_context_ops_on_nonexistent_task() {
        let manager = TaskConfigManager::new();

        assert!(matches!(
            manager.add_task_context("missing", "key1".to_string(), serde_json::json!(1)),
            Err(ModeError::NotFound(_))
        ));
        assert!(matches!(
            manager.remove_task_context("missing", "key1"),
            Err(ModeError::NotFound(_))
        ));
        assert!(matches!(
            manager.clear_task_context("missing"),
            Err(ModeError::NotFound(_))
        ));
        // Reads are lenient: an unknown task simply has no context values.
        assert!(manager
            .get_task_context("missing", "key1")
            .unwrap()
            .is_none());
    }

    #[test]
    fn test_manager_get_all_task_ids() {
        let manager = TaskConfigManager::new();
        for id in ["task1", "task2", "task3"] {
            manager
                .register_task(id.to_string(), create_test_config())
                .unwrap();
        }

        // HashMap iteration order is unspecified, so sort before comparing.
        let mut ids = manager.get_all_task_ids().unwrap();
        ids.sort();
        assert_eq!(ids, ["task1", "task2", "task3"]);
    }

    #[test]
    fn test_manager_task_count() {
        let manager = TaskConfigManager::new();
        manager
            .register_task("task1".to_string(), create_test_config())
            .unwrap();
        manager
            .register_task("task2".to_string(), create_test_config())
            .unwrap();

        assert_eq!(manager.task_count().unwrap(), 2);
    }

    #[test]
    fn test_manager_clear_all_tasks() {
        let manager = TaskConfigManager::new();
        manager
            .register_task("task1".to_string(), create_test_config())
            .unwrap();
        manager
            .register_task("task2".to_string(), create_test_config())
            .unwrap();

        manager.clear_all_tasks().unwrap();
        assert_eq!(manager.task_count().unwrap(), 0);
    }

    #[test]
    fn test_manager_get_all_configs() {
        let manager = TaskConfigManager::new();
        manager
            .register_task("task1".to_string(), create_test_config())
            .unwrap();
        manager
            .register_task("task2".to_string(), create_test_config())
            .unwrap();

        let mut ids: Vec<String> = manager
            .get_all_configs()
            .unwrap()
            .into_iter()
            .map(|tc| tc.task_id)
            .collect();
        ids.sort();
        assert_eq!(ids, ["task1", "task2"]);
    }

    #[test]
    fn test_manager_default() {
        let manager = TaskConfigManager::default();
        assert_eq!(manager.task_count().unwrap(), 0);
    }

    #[test]
    fn test_manager_clone_shares_state() {
        let manager = TaskConfigManager::new();
        let handle = manager.clone();

        handle
            .register_task("task1".to_string(), create_test_config())
            .unwrap();
        assert!(manager.has_task("task1").unwrap());
    }

    #[test]
    fn test_manager_get_nonexistent_task_returns_none() {
        let manager = TaskConfigManager::new();
        let result = manager.get_task_config("nonexistent");
        assert!(matches!(result, Ok(None)));
    }

    #[test]
    fn test_manager_error_on_update_nonexistent_task() {
        let manager = TaskConfigManager::new();
        let result = manager.update_task_config("nonexistent", create_test_config());
        assert!(matches!(result, Err(ModeError::NotFound(_))));
    }
}