// codetether_agent/swarm/cache.rs

1//! Swarm result caching for avoiding duplicate task execution
2//!
3//! Uses content-based hashing to identify identical tasks and cache
4//! their results to disk for reuse across executions.
5
6use super::{SubTask, SubTaskResult};
7use anyhow::Result;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::PathBuf;
11use std::time::{Duration, SystemTime};
12use tokio::fs;
13use tracing;
14
/// Cache entry storing a subtask result with metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry {
    /// The cached result returned on a cache hit.
    pub result: SubTaskResult,
    /// When this entry was created; compared against the configured TTL
    /// by `is_expired` to decide expiry.
    pub created_at: SystemTime,
    /// Hash of the task instruction (see `generate_content_hash`),
    /// re-checked on lookup to detect stale or colliding entries.
    pub content_hash: String,
    /// Human-readable task name, kept only for logging/debugging.
    pub task_name: String,
}
27
28impl CacheEntry {
29    /// Check if this entry has expired
30    pub fn is_expired(&self, ttl: Duration) -> bool {
31        match self.created_at.elapsed() {
32            Ok(elapsed) => elapsed > ttl,
33            Err(_) => {
34                // Clock went backwards, treat as expired
35                true
36            }
37        }
38    }
39}
40
/// Cache statistics for monitoring.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
    /// Number of cache hits.
    pub hits: u64,
    /// Number of cache misses.
    pub misses: u64,
    /// Number of entries evicted due to the `max_entries` limit.
    pub evictions: u64,
    /// Number of expired entries removed (on lookup or while loading the
    /// index from disk).
    pub expired_removed: u64,
    /// Current number of entries in the in-memory cache index.
    pub current_entries: usize,
}
55
56impl CacheStats {
57    /// Total number of cache lookups
58    pub fn total_lookups(&self) -> u64 {
59        self.hits + self.misses
60    }
61
62    /// Cache hit rate (0.0 to 1.0)
63    pub fn hit_rate(&self) -> f64 {
64        let total = self.total_lookups();
65        if total == 0 {
66            0.0
67        } else {
68            self.hits as f64 / total as f64
69        }
70    }
71
72    /// Record a cache hit
73    pub fn record_hit(&mut self) {
74        self.hits += 1;
75    }
76
77    /// Record a cache miss
78    pub fn record_miss(&mut self) {
79        self.misses += 1;
80    }
81}
82
/// Configuration for the swarm cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Whether caching is enabled; when `false`, get/put are no-ops.
    pub enabled: bool,
    /// Time-to-live for cache entries, in seconds.
    pub ttl_secs: u64,
    /// Maximum number of entries in the cache; oldest entries are
    /// evicted once this is reached.
    pub max_entries: usize,
    /// Maximum size of cache directory in MB.
    // NOTE(review): this limit does not appear to be enforced anywhere
    // in this file — only `max_entries` is checked. Confirm intent.
    pub max_size_mb: u64,
    /// Cache directory path (`None` = use the platform default).
    pub cache_dir: Option<PathBuf>,
    /// Whether to bypass the cache for this execution; like `enabled =
    /// false`, it makes get/put no-ops but is meant as a per-run toggle.
    pub bypass: bool,
}
99
100impl Default for CacheConfig {
101    fn default() -> Self {
102        Self {
103            enabled: true,
104            ttl_secs: 86400, // 24 hours
105            max_entries: 1000,
106            max_size_mb: 100,
107            cache_dir: None,
108            bypass: false,
109        }
110    }
111}
112
/// Swarm result cache using content-based hashing.
///
/// Entries live both on disk (one JSON file per key under `cache_dir`)
/// and in an in-memory index that serves lookups.
pub struct SwarmCache {
    /// Runtime configuration (TTL, limits, bypass flag).
    config: CacheConfig,
    /// Resolved directory where entry files are written.
    cache_dir: PathBuf,
    /// Hit/miss/eviction counters.
    stats: CacheStats,
    /// In-memory index of cache entries, keyed by the task hash.
    index: HashMap<String, CacheEntry>,
}
121
122impl SwarmCache {
123    /// Create a new cache with the given configuration
124    pub async fn new(config: CacheConfig) -> Result<Self> {
125        let cache_dir = config
126            .cache_dir
127            .clone()
128            .unwrap_or_else(Self::default_cache_dir);
129
130        // Ensure cache directory exists
131        fs::create_dir_all(&cache_dir).await?;
132
133        let mut cache = Self {
134            config,
135            cache_dir,
136            stats: CacheStats::default(),
137            index: HashMap::new(),
138        };
139
140        // Load existing index
141        cache.load_index().await?;
142
143        tracing::info!(
144            cache_dir = %cache.cache_dir.display(),
145            entries = cache.index.len(),
146            "Swarm cache initialized"
147        );
148
149        Ok(cache)
150    }
151
152    /// Get the default cache directory
153    fn default_cache_dir() -> PathBuf {
154        directories::ProjectDirs::from("com", "codetether", "agent")
155            .map(|dirs| dirs.cache_dir().join("swarm"))
156            .unwrap_or_else(|| PathBuf::from(".cache/swarm"))
157    }
158
159    /// Generate a cache key from task content using SHA-256
160    pub fn generate_key(task: &SubTask) -> String {
161        use std::collections::hash_map::DefaultHasher;
162        use std::hash::{Hash, Hasher};
163
164        // Create a deterministic hash from task content
165        let mut hasher = DefaultHasher::new();
166        task.name.hash(&mut hasher);
167        task.instruction.hash(&mut hasher);
168        task.specialty.hash(&mut hasher);
169        task.max_steps.hash(&mut hasher);
170
171        // Include context that affects execution
172        if let Some(parent) = &task.context.parent_task {
173            parent.hash(&mut hasher);
174        }
175
176        // Hash dependency results that are part of the context
177        let mut dep_keys: Vec<_> = task.context.dependency_results.keys().collect();
178        dep_keys.sort(); // Ensure deterministic ordering
179        for key in dep_keys {
180            key.hash(&mut hasher);
181            task.context.dependency_results[key].hash(&mut hasher);
182        }
183
184        format!("{:016x}", hasher.finish())
185    }
186
187    /// Get a cached result if available and not expired
188    pub async fn get(&mut self, task: &SubTask) -> Option<SubTaskResult> {
189        if !self.config.enabled || self.config.bypass {
190            return None;
191        }
192
193        let key = Self::generate_key(task);
194
195        // Check in-memory index first
196        if let Some(entry) = self.index.get(&key) {
197            let ttl = Duration::from_secs(self.config.ttl_secs);
198
199            if entry.is_expired(ttl) {
200                tracing::debug!(key = %key, "Cache entry expired");
201                self.stats.expired_removed += 1;
202                self.index.remove(&key);
203                let _ = self.remove_from_disk(&key);
204                self.stats.record_miss();
205                return None;
206            }
207
208            // Verify content hash matches
209            let current_hash = Self::generate_content_hash(task);
210            if entry.content_hash != current_hash {
211                tracing::debug!(key = %key, "Content hash mismatch, cache invalid");
212                self.index.remove(&key);
213                let _ = self.remove_from_disk(&key);
214                self.stats.record_miss();
215                return None;
216            }
217
218            tracing::info!(key = %key, task_name = %entry.task_name, "Cache hit");
219            self.stats.record_hit();
220            return Some(entry.result.clone());
221        }
222
223        self.stats.record_miss();
224        None
225    }
226
227    /// Store a result in the cache
228    pub async fn put(&mut self, task: &SubTask, result: &SubTaskResult) -> Result<()> {
229        if !self.config.enabled || self.config.bypass {
230            return Ok(());
231        }
232
233        // Only cache successful results
234        if !result.success {
235            tracing::debug!(task_id = %task.id, "Not caching failed result");
236            return Ok(());
237        }
238
239        // Check if we need to evict entries
240        self.enforce_size_limits().await?;
241
242        let key = Self::generate_key(task);
243        let content_hash = Self::generate_content_hash(task);
244
245        let entry = CacheEntry {
246            result: result.clone(),
247            created_at: SystemTime::now(),
248            content_hash,
249            task_name: task.name.clone(),
250        };
251
252        // Store on disk
253        self.save_to_disk(&key, &entry).await?;
254
255        // Update in-memory index
256        self.index.insert(key.clone(), entry);
257        self.stats.current_entries = self.index.len();
258
259        tracing::info!(key = %key, task_name = %task.name, "Cached result");
260
261        Ok(())
262    }
263
264    /// Generate a content hash for verification
265    fn generate_content_hash(task: &SubTask) -> String {
266        use std::collections::hash_map::DefaultHasher;
267        use std::hash::{Hash, Hasher};
268
269        let mut hasher = DefaultHasher::new();
270        task.instruction.hash(&mut hasher);
271        format!("{:016x}", hasher.finish())
272    }
273
274    /// Enforce size limits by evicting oldest entries
275    async fn enforce_size_limits(&mut self) -> Result<()> {
276        if self.index.len() < self.config.max_entries {
277            return Ok(());
278        }
279
280        // Sort by creation time and remove oldest
281        let mut entries: Vec<_> = self
282            .index
283            .iter()
284            .map(|(k, v)| (k.clone(), v.created_at))
285            .collect();
286        entries.sort_by(|a, b| a.1.cmp(&b.1));
287
288        let to_remove = self.index.len() - self.config.max_entries + 1;
289        for (key, _) in entries.into_iter().take(to_remove) {
290            self.index.remove(&key);
291            let _ = self.remove_from_disk(&key);
292            self.stats.evictions += 1;
293        }
294
295        self.stats.current_entries = self.index.len();
296
297        Ok(())
298    }
299
300    /// Get cache statistics
301    pub fn stats(&self) -> &CacheStats {
302        &self.stats
303    }
304
305    /// Get mutable reference to stats
306    pub fn stats_mut(&mut self) -> &mut CacheStats {
307        &mut self.stats
308    }
309
310    /// Clear all cache entries
311    pub async fn clear(&mut self) -> Result<()> {
312        self.index.clear();
313        self.stats.current_entries = 0;
314
315        // Remove all files in cache directory
316        let mut entries = fs::read_dir(&self.cache_dir).await?;
317        while let Some(entry) = entries.next_entry().await? {
318            let path = entry.path();
319            if path.extension().map_or(false, |e| e == "json") {
320                let _ = fs::remove_file(&path).await;
321            }
322        }
323
324        tracing::info!("Cache cleared");
325        Ok(())
326    }
327
328    /// Get the path for a cache entry file
329    fn entry_path(&self, key: &str) -> PathBuf {
330        self.cache_dir.join(format!("{}.json", key))
331    }
332
333    /// Get the cache directory path.
334    pub fn cache_dir(&self) -> &std::path::Path {
335        &self.cache_dir
336    }
337
338    /// Save an entry to disk
339    async fn save_to_disk(&self, key: &str, entry: &CacheEntry) -> Result<()> {
340        let path = self.entry_path(key);
341        let json = serde_json::to_string_pretty(entry)?;
342        fs::write(&path, json).await?;
343        Ok(())
344    }
345
346    /// Remove an entry from disk
347    async fn remove_from_disk(&self, key: &str) -> Result<()> {
348        let path = self.entry_path(key);
349        if path.exists() {
350            fs::remove_file(&path).await?;
351        }
352        Ok(())
353    }
354
355    /// Load index from disk
356    async fn load_index(&mut self) -> Result<()> {
357        let ttl = Duration::from_secs(self.config.ttl_secs);
358
359        let mut entries = match fs::read_dir(&self.cache_dir).await {
360            Ok(entries) => entries,
361            Err(_) => return Ok(()),
362        };
363
364        while let Some(entry) = entries.next_entry().await? {
365            let path = entry.path();
366            if path.extension().map_or(false, |e| e == "json") {
367                if let Some(key) = path.file_stem().and_then(|s| s.to_str()) {
368                    match fs::read_to_string(&path).await {
369                        Ok(json) => {
370                            if let Ok(cache_entry) = serde_json::from_str::<CacheEntry>(&json) {
371                                if !cache_entry.is_expired(ttl) {
372                                    self.index.insert(key.to_string(), cache_entry);
373                                } else {
374                                    self.stats.expired_removed += 1;
375                                    let _ = fs::remove_file(&path).await;
376                                }
377                            }
378                        }
379                        Err(e) => {
380                            tracing::warn!(path = %path.display(), error = %e, "Failed to read cache entry");
381                        }
382                    }
383                }
384            }
385        }
386
387        self.stats.current_entries = self.index.len();
388        Ok(())
389    }
390
391    /// Set bypass mode for current execution
392    pub fn set_bypass(&mut self, bypass: bool) {
393        self.config.bypass = bypass;
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Build a minimal SubTask for exercising the cache key paths.
    fn create_test_task(name: &str, instruction: &str) -> SubTask {
        SubTask::new(name, instruction)
    }

    /// Build a SubTaskResult with fixed metadata and the given success
    /// flag (only successful results are cacheable).
    fn create_test_result(success: bool) -> SubTaskResult {
        SubTaskResult {
            subtask_id: "test-123".to_string(),
            subagent_id: "agent-123".to_string(),
            success,
            result: "test result".to_string(),
            steps: 5,
            tool_calls: 3,
            execution_time_ms: 1000,
            error: None,
            artifacts: vec![],
        }
    }

    // Round trip: miss -> put -> hit, with hit/miss counters verified.
    #[tokio::test]
    async fn test_cache_basic_operations() {
        let temp_dir = tempdir().unwrap();
        let config = CacheConfig {
            enabled: true,
            ttl_secs: 3600,
            max_entries: 100,
            max_size_mb: 10,
            cache_dir: Some(temp_dir.path().to_path_buf()),
            bypass: false,
        };

        let mut cache = SwarmCache::new(config).await.unwrap();

        let task = create_test_task("test task", "do something");
        let result = create_test_result(true);

        // Initially should be a miss
        assert!(cache.get(&task).await.is_none());
        assert_eq!(cache.stats().misses, 1);

        // Store the result
        cache.put(&task, &result).await.unwrap();

        // Now should be a hit
        let cached = cache.get(&task).await;
        assert!(cached.is_some());
        assert_eq!(cache.stats().hits, 1);
        assert_eq!(cached.unwrap().result, result.result);
    }

    // Tasks with different content must hash to different keys.
    #[tokio::test]
    async fn test_cache_different_tasks() {
        let temp_dir = tempdir().unwrap();
        let config = CacheConfig {
            enabled: true,
            ttl_secs: 3600,
            max_entries: 100,
            max_size_mb: 10,
            cache_dir: Some(temp_dir.path().to_path_buf()),
            bypass: false,
        };

        let mut cache = SwarmCache::new(config).await.unwrap();

        let task1 = create_test_task("task 1", "do something");
        let task2 = create_test_task("task 2", "do something else");
        let result = create_test_result(true);

        // Store only task1
        cache.put(&task1, &result).await.unwrap();

        // task1 should hit, task2 should miss
        assert!(cache.get(&task1).await.is_some());
        assert!(cache.get(&task2).await.is_none());
    }

    // With bypass set, both put and get must be no-ops.
    #[tokio::test]
    async fn test_cache_bypass() {
        let temp_dir = tempdir().unwrap();
        let config = CacheConfig {
            enabled: true,
            ttl_secs: 3600,
            max_entries: 100,
            max_size_mb: 10,
            cache_dir: Some(temp_dir.path().to_path_buf()),
            bypass: true, // Bypass enabled
        };

        let mut cache = SwarmCache::new(config).await.unwrap();

        let task = create_test_task("test", "instruction");
        let result = create_test_result(true);

        // Store the result
        cache.put(&task, &result).await.unwrap();

        // Should still be a miss due to bypass
        assert!(cache.get(&task).await.is_none());
    }

    // Failed results must never enter the cache.
    #[tokio::test]
    async fn test_cache_failed_results_not_cached() {
        let temp_dir = tempdir().unwrap();
        let config = CacheConfig {
            enabled: true,
            ttl_secs: 3600,
            max_entries: 100,
            max_size_mb: 10,
            cache_dir: Some(temp_dir.path().to_path_buf()),
            bypass: false,
        };

        let mut cache = SwarmCache::new(config).await.unwrap();

        let task = create_test_task("test", "instruction");
        let failed_result = create_test_result(false);

        // Try to store failed result
        cache.put(&task, &failed_result).await.unwrap();

        // Should not be cached
        assert!(cache.get(&task).await.is_none());
    }

    // clear() must empty both the index and the entry counter.
    #[tokio::test]
    async fn test_cache_clear() {
        let temp_dir = tempdir().unwrap();
        let config = CacheConfig {
            enabled: true,
            ttl_secs: 3600,
            max_entries: 100,
            max_size_mb: 10,
            cache_dir: Some(temp_dir.path().to_path_buf()),
            bypass: false,
        };

        let mut cache = SwarmCache::new(config).await.unwrap();

        let task = create_test_task("test", "instruction");
        let result = create_test_result(true);

        cache.put(&task, &result).await.unwrap();
        assert!(cache.get(&task).await.is_some());

        cache.clear().await.unwrap();
        assert!(cache.get(&task).await.is_none());
        assert_eq!(cache.stats().current_entries, 0);
    }
}