use super::{SubTask, SubTaskResult};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::{Duration, SystemTime};
use tokio::fs;
use tracing;

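/// A cached subtask result, stored with the metadata needed to validate it on lookup.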
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry {
    pub result: SubTaskResult,
    pub created_at: SystemTime,
    pub content_hash: String,
    pub task_name: String,
}

impl CacheEntry {
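    /// Whether this entry is older than `ttl`; clock errors are treated as expired.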
    pub fn is_expired(&self, ttl: Duration) -> bool {
        match self.created_at.elapsed() {
            Ok(elapsed) => elapsed > ttl,
            Err(_) => true,
        }
    }
}

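/// Aggregate counters describing cache effectiveness for the current process.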
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
    pub hits: u64,
    pub misses: u64,
    pub evictions: u64,
    pub expired_removed: u64,
    pub current_entries: usize,
}

impl CacheStats {
    pub fn total_lookups(&self) -> u64 {
        self.hits + self.misses
    }

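    /// Fraction of lookups that were hits, in `[0.0, 1.0]`; `0.0` when there have been no lookups.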
    pub fn hit_rate(&self) -> f64 {
        let total = self.total_lookups();
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }

    pub fn record_hit(&mut self) {
        self.hits += 1;
    }

    pub fn record_miss(&mut self) {
        self.misses += 1;
    }
}

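/// Cache tuning options: `ttl_secs` bounds entry age, `max_entries` bounds the index size,
/// and `bypass` disables lookups and writes without discarding existing entries.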
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    pub enabled: bool,
    pub ttl_secs: u64,
    pub max_entries: usize,
    pub max_size_mb: u64,
    pub cache_dir: Option<PathBuf>,
    pub bypass: bool,
}

impl Default for CacheConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            ttl_secs: 86400, // 24 hours
            max_entries: 1000,
            max_size_mb: 100,
            cache_dir: None,
            bypass: false,
        }
    }
}

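/// Disk-backed cache of successful subtask results, keyed by a hash of the task definition.
/// Entries are written as individual JSON files under `cache_dir` and mirrored in an in-memory index.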
pub struct SwarmCache {
    config: CacheConfig,
    cache_dir: PathBuf,
    stats: CacheStats,
    index: HashMap<String, CacheEntry>,
}

impl SwarmCache {
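    /// Creates the cache directory if needed and loads any unexpired entries from disk.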
    pub async fn new(config: CacheConfig) -> Result<Self> {
        let cache_dir = config
            .cache_dir
            .clone()
            .unwrap_or_else(Self::default_cache_dir);

        fs::create_dir_all(&cache_dir).await?;

        let mut cache = Self {
            config,
            cache_dir,
            stats: CacheStats::default(),
            index: HashMap::new(),
        };

        cache.load_index().await?;

        tracing::info!(
            cache_dir = %cache.cache_dir.display(),
            entries = cache.index.len(),
            "Swarm cache initialized"
        );

        Ok(cache)
    }

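    /// Platform cache directory via `directories`, falling back to `.cache/swarm` in the working directory.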
    fn default_cache_dir() -> PathBuf {
        directories::ProjectDirs::from("com", "codetether", "agent")
            .map(|dirs| dirs.cache_dir().join("swarm"))
            .unwrap_or_else(|| PathBuf::from(".cache/swarm"))
    }

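    /// Derives a stable cache key from the task definition: name, instruction, specialty,
    /// step budget, parent task, and dependency results (hashed in sorted key order).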
    pub fn generate_key(task: &SubTask) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        task.name.hash(&mut hasher);
        task.instruction.hash(&mut hasher);
        task.specialty.hash(&mut hasher);
        task.max_steps.hash(&mut hasher);

        if let Some(parent) = &task.context.parent_task {
            parent.hash(&mut hasher);
        }

        let mut dep_keys: Vec<_> = task.context.dependency_results.keys().collect();
        dep_keys.sort();
        for key in dep_keys {
            key.hash(&mut hasher);
            task.context.dependency_results[key].hash(&mut hasher);
        }

        format!("{:016x}", hasher.finish())
    }

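    /// Looks up a cached result for `task`, dropping entries that have expired or whose
    /// instruction has changed since they were cached. Updates hit/miss statistics.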
    pub async fn get(&mut self, task: &SubTask) -> Option<SubTaskResult> {
        if !self.config.enabled || self.config.bypass {
            return None;
        }

        let key = Self::generate_key(task);

        if let Some(entry) = self.index.get(&key) {
            let ttl = Duration::from_secs(self.config.ttl_secs);

            if entry.is_expired(ttl) {
                tracing::debug!(key = %key, "Cache entry expired");
                self.stats.expired_removed += 1;
                self.index.remove(&key);
                let _ = self.remove_from_disk(&key).await;
                self.stats.record_miss();
                return None;
            }

            let current_hash = Self::generate_content_hash(task);
            if entry.content_hash != current_hash {
                tracing::debug!(key = %key, "Content hash mismatch, cache invalid");
                self.index.remove(&key);
                let _ = self.remove_from_disk(&key).await;
                self.stats.record_miss();
                return None;
            }

            tracing::info!(key = %key, task_name = %entry.task_name, "Cache hit");
            self.stats.record_hit();
            return Some(entry.result.clone());
        }

        self.stats.record_miss();
        None
    }

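    /// Stores a successful result for `task`, evicting the oldest entries first if the
    /// index is at capacity. Failed results are never cached.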
    pub async fn put(&mut self, task: &SubTask, result: &SubTaskResult) -> Result<()> {
        if !self.config.enabled || self.config.bypass {
            return Ok(());
        }

        if !result.success {
            tracing::debug!(task_id = %task.id, "Not caching failed result");
            return Ok(());
        }

        self.enforce_size_limits().await?;

        let key = Self::generate_key(task);
        let content_hash = Self::generate_content_hash(task);

        let entry = CacheEntry {
            result: result.clone(),
            created_at: SystemTime::now(),
            content_hash,
            task_name: task.name.clone(),
        };

        self.save_to_disk(&key, &entry).await?;

        self.index.insert(key.clone(), entry);
        self.stats.current_entries = self.index.len();

        tracing::info!(key = %key, task_name = %task.name, "Cached result");

        Ok(())
    }

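    /// Hashes only the instruction text; used to invalidate an entry whose instruction has
    /// changed since it was cached.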
    fn generate_content_hash(task: &SubTask) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        task.instruction.hash(&mut hasher);
        format!("{:016x}", hasher.finish())
    }

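    /// Evicts the oldest entries (by creation time) until there is room for one more entry
    /// under `max_entries`.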
    async fn enforce_size_limits(&mut self) -> Result<()> {
        if self.index.len() < self.config.max_entries {
            return Ok(());
        }

        let mut entries: Vec<_> = self
            .index
            .iter()
            .map(|(k, v)| (k.clone(), v.created_at))
            .collect();
        entries.sort_by(|a, b| a.1.cmp(&b.1));

        let to_remove = self.index.len() - self.config.max_entries + 1;
        for (key, _) in entries.into_iter().take(to_remove) {
            self.index.remove(&key);
            let _ = self.remove_from_disk(&key).await;
            self.stats.evictions += 1;
        }

        self.stats.current_entries = self.index.len();

        Ok(())
    }

    pub fn stats(&self) -> &CacheStats {
        &self.stats
    }

    pub fn stats_mut(&mut self) -> &mut CacheStats {
        &mut self.stats
    }

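    /// Removes all entries from the in-memory index and deletes the cached JSON files on disk.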
    pub async fn clear(&mut self) -> Result<()> {
        self.index.clear();
        self.stats.current_entries = 0;

        let mut entries = fs::read_dir(&self.cache_dir).await?;
        while let Some(entry) = entries.next_entry().await? {
            let path = entry.path();
            if path.extension().map_or(false, |e| e == "json") {
                let _ = fs::remove_file(&path).await;
            }
        }

        tracing::info!("Cache cleared");
        Ok(())
    }

    fn entry_path(&self, key: &str) -> PathBuf {
        self.cache_dir.join(format!("{}.json", key))
    }

    pub fn cache_dir(&self) -> &std::path::Path {
        &self.cache_dir
    }

    async fn save_to_disk(&self, key: &str, entry: &CacheEntry) -> Result<()> {
        let path = self.entry_path(key);
        let json = serde_json::to_string_pretty(entry)?;
        fs::write(&path, json).await?;
        Ok(())
    }

    async fn remove_from_disk(&self, key: &str) -> Result<()> {
        let path = self.entry_path(key);
        if path.exists() {
            fs::remove_file(&path).await?;
        }
        Ok(())
    }

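    /// Populates the in-memory index from the JSON files in `cache_dir`, deleting entries
    /// that have already expired and warning on files that cannot be read.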
    async fn load_index(&mut self) -> Result<()> {
        let ttl = Duration::from_secs(self.config.ttl_secs);

        let mut entries = match fs::read_dir(&self.cache_dir).await {
            Ok(entries) => entries,
            Err(_) => return Ok(()),
        };

        while let Some(entry) = entries.next_entry().await? {
            let path = entry.path();
            if path.extension().map_or(false, |e| e == "json") {
                if let Some(key) = path.file_stem().and_then(|s| s.to_str()) {
                    match fs::read_to_string(&path).await {
                        Ok(json) => {
                            if let Ok(cache_entry) = serde_json::from_str::<CacheEntry>(&json) {
                                if !cache_entry.is_expired(ttl) {
                                    self.index.insert(key.to_string(), cache_entry);
                                } else {
                                    self.stats.expired_removed += 1;
                                    let _ = fs::remove_file(&path).await;
                                }
                            }
                        }
                        Err(e) => {
                            tracing::warn!(path = %path.display(), error = %e, "Failed to read cache entry");
                        }
                    }
                }
            }
        }

        self.stats.current_entries = self.index.len();
        Ok(())
    }

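    /// Toggles bypass mode at runtime; existing entries are kept but not used while bypassed.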
    pub fn set_bypass(&mut self, bypass: bool) {
        self.config.bypass = bypass;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    fn create_test_task(name: &str, instruction: &str) -> SubTask {
        SubTask::new(name, instruction)
    }

    fn create_test_result(success: bool) -> SubTaskResult {
        SubTaskResult {
            subtask_id: "test-123".to_string(),
            subagent_id: "agent-123".to_string(),
            success,
            result: "test result".to_string(),
            steps: 5,
            tool_calls: 3,
            execution_time_ms: 1000,
            error: None,
            artifacts: vec![],
        }
    }

    #[tokio::test]
    async fn test_cache_basic_operations() {
        let temp_dir = tempdir().unwrap();
        let config = CacheConfig {
            enabled: true,
            ttl_secs: 3600,
            max_entries: 100,
            max_size_mb: 10,
            cache_dir: Some(temp_dir.path().to_path_buf()),
            bypass: false,
        };

        let mut cache = SwarmCache::new(config).await.unwrap();

        let task = create_test_task("test task", "do something");
        let result = create_test_result(true);

        assert!(cache.get(&task).await.is_none());
        assert_eq!(cache.stats().misses, 1);

        cache.put(&task, &result).await.unwrap();

        let cached = cache.get(&task).await;
        assert!(cached.is_some());
        assert_eq!(cache.stats().hits, 1);
        assert_eq!(cached.unwrap().result, result.result);
    }

    #[tokio::test]
    async fn test_cache_different_tasks() {
        let temp_dir = tempdir().unwrap();
        let config = CacheConfig {
            enabled: true,
            ttl_secs: 3600,
            max_entries: 100,
            max_size_mb: 10,
            cache_dir: Some(temp_dir.path().to_path_buf()),
            bypass: false,
        };

        let mut cache = SwarmCache::new(config).await.unwrap();

        let task1 = create_test_task("task 1", "do something");
        let task2 = create_test_task("task 2", "do something else");
        let result = create_test_result(true);

        cache.put(&task1, &result).await.unwrap();

        assert!(cache.get(&task1).await.is_some());
        assert!(cache.get(&task2).await.is_none());
    }

    #[tokio::test]
    async fn test_cache_bypass() {
        let temp_dir = tempdir().unwrap();
        let config = CacheConfig {
            enabled: true,
            ttl_secs: 3600,
            max_entries: 100,
            max_size_mb: 10,
            cache_dir: Some(temp_dir.path().to_path_buf()),
            bypass: true,
        };

        let mut cache = SwarmCache::new(config).await.unwrap();

        let task = create_test_task("test", "instruction");
        let result = create_test_result(true);

        cache.put(&task, &result).await.unwrap();

        assert!(cache.get(&task).await.is_none());
    }

    #[tokio::test]
    async fn test_cache_failed_results_not_cached() {
        let temp_dir = tempdir().unwrap();
        let config = CacheConfig {
            enabled: true,
            ttl_secs: 3600,
            max_entries: 100,
            max_size_mb: 10,
            cache_dir: Some(temp_dir.path().to_path_buf()),
            bypass: false,
        };

        let mut cache = SwarmCache::new(config).await.unwrap();

        let task = create_test_task("test", "instruction");
        let failed_result = create_test_result(false);

        cache.put(&task, &failed_result).await.unwrap();

        assert!(cache.get(&task).await.is_none());
    }

    #[tokio::test]
    async fn test_cache_clear() {
        let temp_dir = tempdir().unwrap();
        let config = CacheConfig {
            enabled: true,
            ttl_secs: 3600,
            max_entries: 100,
            max_size_mb: 10,
            cache_dir: Some(temp_dir.path().to_path_buf()),
            bypass: false,
        };

        let mut cache = SwarmCache::new(config).await.unwrap();

        let task = create_test_task("test", "instruction");
        let result = create_test_result(true);

        cache.put(&task, &result).await.unwrap();
        assert!(cache.get(&task).await.is_some());

        cache.clear().await.unwrap();
        assert!(cache.get(&task).await.is_none());
        assert_eq!(cache.stats().current_entries, 0);
    }
}