// codetether_agent/swarm/cache.rs
1//! Swarm result caching for avoiding duplicate task execution
2//!
3//! Uses content-based hashing to identify identical tasks and cache
4//! their results to disk for reuse across executions.
5
6use super::{SubTask, SubTaskResult};
7use anyhow::Result;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::PathBuf;
11use std::time::{Duration, SystemTime};
12use tokio::fs;
13use tracing;
14
/// Cache entry storing a subtask result with metadata.
///
/// Persisted to disk as one JSON file per entry and mirrored in the
/// in-memory index of the swarm cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry {
    /// The cached result
    pub result: SubTaskResult,
    /// When this entry was created (used for TTL expiry and age-based eviction)
    pub created_at: SystemTime,
    /// Hash of the task content (for verification against the current task)
    pub content_hash: String,
    /// Task name for debugging
    pub task_name: String,
}
27
28impl CacheEntry {
29    /// Check if this entry has expired
30    pub fn is_expired(&self, ttl: Duration) -> bool {
31        match self.created_at.elapsed() {
32            Ok(elapsed) => elapsed > ttl,
33            Err(_) => {
34                // Clock went backwards, treat as expired
35                true
36            }
37        }
38    }
39}
40
41/// Cache statistics for monitoring
42#[derive(Debug, Clone, Default, Serialize, Deserialize)]
43pub struct CacheStats {
44    /// Number of cache hits
45    pub hits: u64,
46    /// Number of cache misses
47    pub misses: u64,
48    /// Number of entries evicted due to size limits
49    pub evictions: u64,
50    /// Number of expired entries removed
51    pub expired_removed: u64,
52    /// Current number of entries in cache
53    pub current_entries: usize,
54}
55
56impl CacheStats {
57    /// Total number of cache lookups
58    pub fn total_lookups(&self) -> u64 {
59        self.hits + self.misses
60    }
61
62    /// Cache hit rate (0.0 to 1.0)
63    pub fn hit_rate(&self) -> f64 {
64        let total = self.total_lookups();
65        if total == 0 {
66            0.0
67        } else {
68            self.hits as f64 / total as f64
69        }
70    }
71
72    /// Record a cache hit
73    pub fn record_hit(&mut self) {
74        self.hits += 1;
75    }
76
77    /// Record a cache miss
78    pub fn record_miss(&mut self) {
79        self.misses += 1;
80    }
81}
82
/// Configuration for the swarm cache
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Whether caching is enabled (when false, both reads and writes are no-ops)
    pub enabled: bool,
    /// Time-to-live for cache entries (seconds)
    pub ttl_secs: u64,
    /// Maximum number of entries in cache; oldest entries are evicted past this
    pub max_entries: usize,
    /// Maximum size of cache directory in MB
    // NOTE(review): nothing in the visible SwarmCache impl reads this field —
    // confirm whether size-based eviction is implemented elsewhere or pending.
    pub max_size_mb: u64,
    /// Cache directory path (None = use default)
    pub cache_dir: Option<PathBuf>,
    /// Whether to bypass cache for this execution (get/put become no-ops)
    pub bypass: bool,
}
99
100impl Default for CacheConfig {
101    fn default() -> Self {
102        Self {
103            enabled: true,
104            ttl_secs: 86400, // 24 hours
105            max_entries: 1000,
106            max_size_mb: 100,
107            cache_dir: None,
108            bypass: false,
109        }
110    }
111}
112
/// Swarm result cache using content-based hashing.
///
/// Entries live as one JSON file per key under `cache_dir` and are
/// mirrored in an in-memory index for fast lookups.
pub struct SwarmCache {
    /// Behavior settings (TTL, size limits, bypass flag)
    config: CacheConfig,
    /// Resolved directory where entry files are stored
    cache_dir: PathBuf,
    /// Hit/miss/eviction counters
    stats: CacheStats,
    /// In-memory index of cache entries
    index: HashMap<String, CacheEntry>,
}
121
122impl SwarmCache {
123    /// Create a new cache with the given configuration
124    pub async fn new(config: CacheConfig) -> Result<Self> {
125        let cache_dir = config
126            .cache_dir
127            .clone()
128            .unwrap_or_else(Self::default_cache_dir);
129
130        // Ensure cache directory exists
131        fs::create_dir_all(&cache_dir).await?;
132
133        let mut cache = Self {
134            config,
135            cache_dir,
136            stats: CacheStats::default(),
137            index: HashMap::new(),
138        };
139
140        // Load existing index
141        cache.load_index().await?;
142
143        tracing::info!(
144            cache_dir = %cache.cache_dir.display(),
145            entries = cache.index.len(),
146            "Swarm cache initialized"
147        );
148
149        Ok(cache)
150    }
151
152    /// Get the default cache directory
153    fn default_cache_dir() -> PathBuf {
154        crate::config::Config::data_dir()
155            .map(|dirs| dirs.join("cache").join("swarm"))
156            .unwrap_or_else(|| PathBuf::from(".codetether-agent/cache/swarm"))
157    }
158
159    /// Generate a cache key from task content using SHA-256
160    pub fn generate_key(task: &SubTask) -> String {
161        use std::collections::hash_map::DefaultHasher;
162        use std::hash::{Hash, Hasher};
163
164        // Create a deterministic hash from task content
165        let mut hasher = DefaultHasher::new();
166        task.name.hash(&mut hasher);
167        task.instruction.hash(&mut hasher);
168        task.specialty.hash(&mut hasher);
169        task.max_steps.hash(&mut hasher);
170
171        // Include context that affects execution
172        if let Some(parent) = &task.context.parent_task {
173            parent.hash(&mut hasher);
174        }
175
176        // Hash dependency results that are part of the context
177        let mut dep_keys: Vec<_> = task.context.dependency_results.keys().collect();
178        dep_keys.sort(); // Ensure deterministic ordering
179        for key in dep_keys {
180            key.hash(&mut hasher);
181            task.context.dependency_results[key].hash(&mut hasher);
182        }
183
184        format!("{:016x}", hasher.finish())
185    }
186
187    /// Get a cached result if available and not expired
188    pub async fn get(&mut self, task: &SubTask) -> Option<SubTaskResult> {
189        if !self.config.enabled || self.config.bypass {
190            return None;
191        }
192
193        let key = Self::generate_key(task);
194
195        // Check in-memory index first
196        if let Some(entry) = self.index.get(&key) {
197            let ttl = Duration::from_secs(self.config.ttl_secs);
198
199            if entry.is_expired(ttl) {
200                tracing::debug!(key = %key, "Cache entry expired");
201                self.stats.expired_removed += 1;
202                self.index.remove(&key);
203                let _ = self.remove_from_disk(&key);
204                self.stats.record_miss();
205                return None;
206            }
207
208            // Verify content hash matches
209            let current_hash = Self::generate_content_hash(task);
210            if entry.content_hash != current_hash {
211                tracing::debug!(key = %key, "Content hash mismatch, cache invalid");
212                self.index.remove(&key);
213                let _ = self.remove_from_disk(&key);
214                self.stats.record_miss();
215                return None;
216            }
217
218            tracing::info!(key = %key, task_name = %entry.task_name, "Cache hit");
219            self.stats.record_hit();
220            return Some(entry.result.clone());
221        }
222
223        self.stats.record_miss();
224        None
225    }
226
227    /// Store a result in the cache
228    pub async fn put(&mut self, task: &SubTask, result: &SubTaskResult) -> Result<()> {
229        if !self.config.enabled || self.config.bypass {
230            return Ok(());
231        }
232
233        // Only cache successful results
234        if !result.success {
235            tracing::debug!(task_id = %task.id, "Not caching failed result");
236            return Ok(());
237        }
238
239        // Check if we need to evict entries
240        self.enforce_size_limits().await?;
241
242        let key = Self::generate_key(task);
243        let content_hash = Self::generate_content_hash(task);
244
245        let entry = CacheEntry {
246            result: result.clone(),
247            created_at: SystemTime::now(),
248            content_hash,
249            task_name: task.name.clone(),
250        };
251
252        // Store on disk
253        self.save_to_disk(&key, &entry).await?;
254
255        // Update in-memory index
256        self.index.insert(key.clone(), entry);
257        self.stats.current_entries = self.index.len();
258
259        tracing::info!(key = %key, task_name = %task.name, "Cached result");
260
261        Ok(())
262    }
263
264    /// Generate a content hash for verification
265    fn generate_content_hash(task: &SubTask) -> String {
266        use std::collections::hash_map::DefaultHasher;
267        use std::hash::{Hash, Hasher};
268
269        let mut hasher = DefaultHasher::new();
270        task.instruction.hash(&mut hasher);
271        format!("{:016x}", hasher.finish())
272    }
273
274    /// Enforce size limits by evicting oldest entries
275    async fn enforce_size_limits(&mut self) -> Result<()> {
276        if self.index.len() < self.config.max_entries {
277            return Ok(());
278        }
279
280        // Sort by creation time and remove oldest
281        let mut entries: Vec<_> = self
282            .index
283            .iter()
284            .map(|(k, v)| (k.clone(), v.created_at))
285            .collect();
286        entries.sort_by(|a, b| a.1.cmp(&b.1));
287
288        let to_remove = self.index.len() - self.config.max_entries + 1;
289        for (key, _) in entries.into_iter().take(to_remove) {
290            self.index.remove(&key);
291            let _ = self.remove_from_disk(&key);
292            self.stats.evictions += 1;
293        }
294
295        self.stats.current_entries = self.index.len();
296
297        Ok(())
298    }
299
300    /// Get cache statistics
301    pub fn stats(&self) -> &CacheStats {
302        &self.stats
303    }
304
305    /// Get mutable reference to stats
306    pub fn stats_mut(&mut self) -> &mut CacheStats {
307        &mut self.stats
308    }
309
310    /// Clear all cache entries
311    pub async fn clear(&mut self) -> Result<()> {
312        self.index.clear();
313        self.stats.current_entries = 0;
314
315        // Remove all files in cache directory
316        let mut entries = fs::read_dir(&self.cache_dir).await?;
317        while let Some(entry) = entries.next_entry().await? {
318            let path = entry.path();
319            if path.extension().is_some_and(|e| e == "json") {
320                let _ = fs::remove_file(&path).await;
321            }
322        }
323
324        tracing::info!("Cache cleared");
325        Ok(())
326    }
327
328    /// Get the path for a cache entry file
329    fn entry_path(&self, key: &str) -> PathBuf {
330        self.cache_dir.join(format!("{}.json", key))
331    }
332
333    /// Get the cache directory path.
334    pub fn cache_dir(&self) -> &std::path::Path {
335        &self.cache_dir
336    }
337
338    /// Save an entry to disk
339    async fn save_to_disk(&self, key: &str, entry: &CacheEntry) -> Result<()> {
340        let path = self.entry_path(key);
341        let json = serde_json::to_string_pretty(entry)?;
342        fs::write(&path, json).await?;
343        Ok(())
344    }
345
346    /// Remove an entry from disk
347    async fn remove_from_disk(&self, key: &str) -> Result<()> {
348        let path = self.entry_path(key);
349        if path.exists() {
350            fs::remove_file(&path).await?;
351        }
352        Ok(())
353    }
354
355    /// Load index from disk
356    async fn load_index(&mut self) -> Result<()> {
357        let ttl = Duration::from_secs(self.config.ttl_secs);
358
359        let mut entries = match fs::read_dir(&self.cache_dir).await {
360            Ok(entries) => entries,
361            Err(_) => return Ok(()),
362        };
363
364        while let Some(entry) = entries.next_entry().await? {
365            let path = entry.path();
366            if path.extension().is_some_and(|e| e == "json")
367                && let Some(key) = path.file_stem().and_then(|s| s.to_str())
368            {
369                match fs::read_to_string(&path).await {
370                    Ok(json) => {
371                        if let Ok(cache_entry) = serde_json::from_str::<CacheEntry>(&json) {
372                            if !cache_entry.is_expired(ttl) {
373                                self.index.insert(key.to_string(), cache_entry);
374                            } else {
375                                self.stats.expired_removed += 1;
376                                let _ = fs::remove_file(&path).await;
377                            }
378                        }
379                    }
380                    Err(e) => {
381                        tracing::warn!(path = %path.display(), error = %e, "Failed to read cache entry");
382                    }
383                }
384            }
385        }
386
387        self.stats.current_entries = self.index.len();
388        Ok(())
389    }
390
391    /// Set bypass mode for current execution
392    pub fn set_bypass(&mut self, bypass: bool) {
393        self.config.bypass = bypass;
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Standard test configuration rooted at `dir`.
    fn test_config(dir: &std::path::Path, bypass: bool) -> CacheConfig {
        CacheConfig {
            enabled: true,
            ttl_secs: 3600,
            max_entries: 100,
            max_size_mb: 10,
            cache_dir: Some(dir.to_path_buf()),
            bypass,
        }
    }

    fn create_test_task(name: &str, instruction: &str) -> SubTask {
        SubTask::new(name, instruction)
    }

    fn create_test_result(success: bool) -> SubTaskResult {
        SubTaskResult {
            subtask_id: "test-123".to_string(),
            subagent_id: "agent-123".to_string(),
            success,
            result: "test result".to_string(),
            steps: 5,
            tool_calls: 3,
            execution_time_ms: 1000,
            error: None,
            artifacts: vec![],
            retry_count: 0,
        }
    }

    #[tokio::test]
    async fn test_cache_basic_operations() {
        let tmp = tempdir().unwrap();
        let mut cache = SwarmCache::new(test_config(tmp.path(), false))
            .await
            .unwrap();

        let task = create_test_task("test task", "do something");
        let result = create_test_result(true);

        // First lookup misses: nothing stored yet.
        assert!(cache.get(&task).await.is_none());
        assert_eq!(cache.stats().misses, 1);

        cache.put(&task, &result).await.unwrap();

        // Second lookup hits and returns the stored payload.
        let cached = cache.get(&task).await;
        assert!(cached.is_some());
        assert_eq!(cache.stats().hits, 1);
        assert_eq!(cached.unwrap().result, result.result);
    }

    #[tokio::test]
    async fn test_cache_different_tasks() {
        let tmp = tempdir().unwrap();
        let mut cache = SwarmCache::new(test_config(tmp.path(), false))
            .await
            .unwrap();

        let stored = create_test_task("task 1", "do something");
        let other = create_test_task("task 2", "do something else");
        let result = create_test_result(true);

        // Store only the first task.
        cache.put(&stored, &result).await.unwrap();

        // Keys are content-based, so the other task must miss.
        assert!(cache.get(&stored).await.is_some());
        assert!(cache.get(&other).await.is_none());
    }

    #[tokio::test]
    async fn test_cache_bypass() {
        let tmp = tempdir().unwrap();
        // Bypass enabled: reads and writes become no-ops.
        let mut cache = SwarmCache::new(test_config(tmp.path(), true))
            .await
            .unwrap();

        let task = create_test_task("test", "instruction");
        let result = create_test_result(true);

        cache.put(&task, &result).await.unwrap();

        // Still a miss because the cache is bypassed.
        assert!(cache.get(&task).await.is_none());
    }

    #[tokio::test]
    async fn test_cache_failed_results_not_cached() {
        let tmp = tempdir().unwrap();
        let mut cache = SwarmCache::new(test_config(tmp.path(), false))
            .await
            .unwrap();

        let task = create_test_task("test", "instruction");
        let failed = create_test_result(false);

        // Storing a failed result is silently skipped...
        cache.put(&task, &failed).await.unwrap();

        // ...so the subsequent lookup misses.
        assert!(cache.get(&task).await.is_none());
    }

    #[tokio::test]
    async fn test_cache_clear() {
        let tmp = tempdir().unwrap();
        let mut cache = SwarmCache::new(test_config(tmp.path(), false))
            .await
            .unwrap();

        let task = create_test_task("test", "instruction");
        let result = create_test_result(true);

        cache.put(&task, &result).await.unwrap();
        assert!(cache.get(&task).await.is_some());

        // Clearing wipes both the index and the on-disk files.
        cache.clear().await.unwrap();
        assert!(cache.get(&task).await.is_none());
        assert_eq!(cache.stats().current_entries, 0);
    }
}