use super::{SubTask, SubTaskResult};
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::time::{Duration, SystemTime};
use tokio::fs;
use tracing;
/// A single cached subtask result, persisted as one JSON file on disk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry {
    /// The cached (successful) subtask result.
    pub result: SubTaskResult,
    /// When the entry was written; compared against the TTL on lookup.
    pub created_at: SystemTime,
    /// Hash of the task's instruction at cache time; a mismatch on lookup
    /// invalidates the entry.
    pub content_hash: String,
    /// Human-readable task name, kept for log messages.
    pub task_name: String,
}
impl CacheEntry {
    /// Returns `true` when this entry is older than `ttl`.
    ///
    /// If the age cannot be computed (e.g. the system clock moved backwards
    /// so `SystemTime::elapsed` fails), the entry is conservatively treated
    /// as expired.
    pub fn is_expired(&self, ttl: Duration) -> bool {
        self.created_at
            .elapsed()
            .map(|age| age > ttl)
            .unwrap_or(true)
    }
}
/// Running counters describing cache effectiveness.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
    /// Lookups that returned a cached result.
    pub hits: u64,
    /// Lookups that found nothing usable (absent, expired, or invalidated).
    pub misses: u64,
    /// Entries removed to stay under the configured capacity.
    pub evictions: u64,
    /// Entries dropped because they outlived the TTL.
    pub expired_removed: u64,
    /// Number of entries currently held in the in-memory index.
    pub current_entries: usize,
}
impl CacheStats {
    /// Total number of lookups performed (hits plus misses).
    pub fn total_lookups(&self) -> u64 {
        self.hits + self.misses
    }

    /// Fraction of lookups that hit, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` before any lookup has happened, avoiding a division
    /// by zero.
    pub fn hit_rate(&self) -> f64 {
        match self.total_lookups() {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }

    /// Count one successful lookup.
    pub fn record_hit(&mut self) {
        self.hits += 1;
    }

    /// Count one failed lookup.
    pub fn record_miss(&mut self) {
        self.misses += 1;
    }
}
/// Tunable settings for the swarm cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Master switch; when `false`, `get`/`put` are no-ops.
    pub enabled: bool,
    /// Entry time-to-live in seconds.
    pub ttl_secs: u64,
    /// Maximum number of entries kept before oldest-first eviction.
    pub max_entries: usize,
    /// Size cap in megabytes.
    /// NOTE(review): not enforced anywhere in this file — confirm intent.
    pub max_size_mb: u64,
    /// On-disk location; `None` selects the platform default directory.
    pub cache_dir: Option<PathBuf>,
    /// Temporarily skip the cache without disabling it permanently.
    pub bypass: bool,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
enabled: true,
ttl_secs: 86400, max_entries: 1000,
max_size_mb: 100,
cache_dir: None,
bypass: false,
}
}
}
/// Disk-backed result cache for swarm subtasks.
///
/// Keeps an in-memory index of entries mirrored as JSON files in
/// `cache_dir`; the index is rebuilt from disk on construction.
pub struct SwarmCache {
    // Settings controlling TTL, capacity, and bypass behavior.
    config: CacheConfig,
    // Resolved on-disk location (config value or platform default).
    cache_dir: PathBuf,
    // Hit/miss/eviction counters.
    stats: CacheStats,
    // Cache key -> entry; mirrors the `.json` files in `cache_dir`.
    index: HashMap<String, CacheEntry>,
}
impl SwarmCache {
    /// Create a cache rooted at the configured (or default) directory,
    /// creating the directory if needed and loading any non-expired
    /// entries already persisted there.
    pub async fn new(config: CacheConfig) -> Result<Self> {
        let cache_dir = config
            .cache_dir
            .clone()
            .unwrap_or_else(Self::default_cache_dir);
        fs::create_dir_all(&cache_dir).await?;
        let mut cache = Self {
            config,
            cache_dir,
            stats: CacheStats::default(),
            index: HashMap::new(),
        };
        cache.load_index().await?;
        tracing::info!(
            cache_dir = %cache.cache_dir.display(),
            entries = cache.index.len(),
            "Swarm cache initialized"
        );
        Ok(cache)
    }

    /// Fallback cache location used when `CacheConfig::cache_dir` is unset.
    fn default_cache_dir() -> PathBuf {
        crate::config::Config::data_dir()
            .map(|dirs| dirs.join("cache").join("swarm"))
            .unwrap_or_else(|| PathBuf::from(".codetether-agent/cache/swarm"))
    }

    /// Derive a deterministic cache key from every input that influences a
    /// task's result: name, instruction, specialty, step budget, optional
    /// parent task, and dependency results.
    ///
    /// Dependency keys are hashed in sorted order so the key is stable
    /// despite `HashMap`'s arbitrary iteration order.
    pub fn generate_key(task: &SubTask) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        task.name.hash(&mut hasher);
        task.instruction.hash(&mut hasher);
        task.specialty.hash(&mut hasher);
        task.max_steps.hash(&mut hasher);
        if let Some(parent) = &task.context.parent_task {
            parent.hash(&mut hasher);
        }
        let mut dep_keys: Vec<_> = task.context.dependency_results.keys().collect();
        dep_keys.sort();
        for key in dep_keys {
            key.hash(&mut hasher);
            task.context.dependency_results[key].hash(&mut hasher);
        }
        format!("{:016x}", hasher.finish())
    }

    /// Look up a cached result for `task`.
    ///
    /// Returns `None` (recording a miss) when caching is disabled or
    /// bypassed, the key is absent, the entry outlived the TTL, or the
    /// task's instruction changed since the entry was written. Expired and
    /// invalidated entries are removed from both the index and disk.
    pub async fn get(&mut self, task: &SubTask) -> Option<SubTaskResult> {
        if !self.config.enabled || self.config.bypass {
            return None;
        }
        let key = Self::generate_key(task);
        if let Some(entry) = self.index.get(&key) {
            let ttl = Duration::from_secs(self.config.ttl_secs);
            if entry.is_expired(ttl) {
                tracing::debug!(key = %key, "Cache entry expired");
                self.stats.expired_removed += 1;
                self.index.remove(&key);
                // BUG FIX: this future was previously dropped unpolled
                // (`let _ = self.remove_from_disk(&key);`), so the stale
                // file was never deleted and would be reloaded on the next
                // startup. It must be awaited; errors stay best-effort.
                let _ = self.remove_from_disk(&key).await;
                self.stats.current_entries = self.index.len();
                self.stats.record_miss();
                return None;
            }
            let current_hash = Self::generate_content_hash(task);
            if entry.content_hash != current_hash {
                tracing::debug!(key = %key, "Content hash mismatch, cache invalid");
                self.index.remove(&key);
                // Same fix as above: await the disk removal.
                let _ = self.remove_from_disk(&key).await;
                self.stats.current_entries = self.index.len();
                self.stats.record_miss();
                return None;
            }
            tracing::info!(key = %key, task_name = %entry.task_name, "Cache hit");
            self.stats.record_hit();
            return Some(entry.result.clone());
        }
        self.stats.record_miss();
        None
    }

    /// Store a successful result for `task`, evicting oldest entries first
    /// if the cache is at capacity.
    ///
    /// Failed results are deliberately not cached so transient errors do
    /// not poison future runs.
    pub async fn put(&mut self, task: &SubTask, result: &SubTaskResult) -> Result<()> {
        if !self.config.enabled || self.config.bypass {
            return Ok(());
        }
        if !result.success {
            tracing::debug!(task_id = %task.id, "Not caching failed result");
            return Ok(());
        }
        self.enforce_size_limits().await?;
        let key = Self::generate_key(task);
        let content_hash = Self::generate_content_hash(task);
        let entry = CacheEntry {
            result: result.clone(),
            created_at: SystemTime::now(),
            content_hash,
            task_name: task.name.clone(),
        };
        self.save_to_disk(&key, &entry).await?;
        self.index.insert(key.clone(), entry);
        self.stats.current_entries = self.index.len();
        tracing::info!(key = %key, task_name = %task.name, "Cached result");
        Ok(())
    }

    /// Hash of the task content whose change invalidates a cached result
    /// (currently just the instruction text).
    fn generate_content_hash(task: &SubTask) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        task.instruction.hash(&mut hasher);
        format!("{:016x}", hasher.finish())
    }

    /// Evict oldest-first until there is room for one more entry.
    ///
    /// Only `max_entries` is enforced here; `max_size_mb` is not checked.
    async fn enforce_size_limits(&mut self) -> Result<()> {
        if self.index.len() < self.config.max_entries {
            return Ok(());
        }
        // Snapshot (key, creation time) pairs and sort oldest-first.
        let mut entries: Vec<_> = self
            .index
            .iter()
            .map(|(k, v)| (k.clone(), v.created_at))
            .collect();
        entries.sort_by_key(|e| e.1);
        // Free enough slots that the caller's pending insert fits.
        let to_remove = self.index.len() - self.config.max_entries + 1;
        for (key, _) in entries.into_iter().take(to_remove) {
            self.index.remove(&key);
            // BUG FIX: await the removal; a dropped future never runs.
            let _ = self.remove_from_disk(&key).await;
            self.stats.evictions += 1;
        }
        self.stats.current_entries = self.index.len();
        Ok(())
    }

    /// Read-only view of the cache statistics.
    pub fn stats(&self) -> &CacheStats {
        &self.stats
    }

    /// Mutable access to the cache statistics.
    pub fn stats_mut(&mut self) -> &mut CacheStats {
        &mut self.stats
    }

    /// Drop every entry from memory and delete all `.json` entry files on
    /// disk. Hit/miss counters are intentionally preserved.
    pub async fn clear(&mut self) -> Result<()> {
        self.index.clear();
        self.stats.current_entries = 0;
        let mut entries = fs::read_dir(&self.cache_dir).await?;
        while let Some(entry) = entries.next_entry().await? {
            let path = entry.path();
            if path.extension().is_some_and(|e| e == "json") {
                let _ = fs::remove_file(&path).await;
            }
        }
        tracing::info!("Cache cleared");
        Ok(())
    }

    /// Path of the JSON file backing `key`.
    fn entry_path(&self, key: &str) -> PathBuf {
        self.cache_dir.join(format!("{}.json", key))
    }

    /// Directory where entries are persisted.
    pub fn cache_dir(&self) -> &std::path::Path {
        &self.cache_dir
    }

    /// Serialize `entry` to pretty-printed JSON and write it to disk.
    async fn save_to_disk(&self, key: &str, entry: &CacheEntry) -> Result<()> {
        let path = self.entry_path(key);
        let json = serde_json::to_string_pretty(entry)?;
        fs::write(&path, json).await?;
        Ok(())
    }

    /// Delete the on-disk file for `key`; a missing file is not an error.
    async fn remove_from_disk(&self, key: &str) -> Result<()> {
        // Attempt the removal directly rather than calling the blocking
        // `Path::exists` inside an async fn (which also raced with
        // concurrent deletions); map NotFound to success.
        match fs::remove_file(self.entry_path(key)).await {
            Ok(()) => Ok(()),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(e) => Err(e.into()),
        }
    }

    /// Rebuild the in-memory index from the cache directory.
    ///
    /// Expired entries are counted and deleted; unparseable files are
    /// skipped; unreadable files are logged. A missing directory is
    /// treated as an empty cache.
    async fn load_index(&mut self) -> Result<()> {
        let ttl = Duration::from_secs(self.config.ttl_secs);
        let mut entries = match fs::read_dir(&self.cache_dir).await {
            Ok(entries) => entries,
            Err(_) => return Ok(()),
        };
        while let Some(entry) = entries.next_entry().await? {
            let path = entry.path();
            if path.extension().is_some_and(|e| e == "json")
                && let Some(key) = path.file_stem().and_then(|s| s.to_str())
            {
                match fs::read_to_string(&path).await {
                    Ok(json) => {
                        if let Ok(cache_entry) = serde_json::from_str::<CacheEntry>(&json) {
                            if !cache_entry.is_expired(ttl) {
                                self.index.insert(key.to_string(), cache_entry);
                            } else {
                                self.stats.expired_removed += 1;
                                let _ = fs::remove_file(&path).await;
                            }
                        }
                    }
                    Err(e) => {
                        tracing::warn!(path = %path.display(), error = %e, "Failed to read cache entry");
                    }
                }
            }
        }
        self.stats.current_entries = self.index.len();
        Ok(())
    }

    /// Toggle bypass mode: skip reads and writes without permanently
    /// disabling the cache.
    pub fn set_bypass(&mut self, bypass: bool) {
        self.config.bypass = bypass;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Shorthand for building a subtask with the given name and instruction.
    fn create_test_task(name: &str, instruction: &str) -> SubTask {
        SubTask::new(name, instruction)
    }

    /// Canned result whose `success` flag is chosen by the caller.
    fn create_test_result(success: bool) -> SubTaskResult {
        SubTaskResult {
            subtask_id: "test-123".to_string(),
            subagent_id: "agent-123".to_string(),
            success,
            result: "test result".to_string(),
            steps: 5,
            tool_calls: 3,
            execution_time_ms: 1000,
            error: None,
            artifacts: vec![],
            retry_count: 0,
        }
    }

    /// Standard test configuration rooted at `dir`.
    fn test_config(dir: &std::path::Path, bypass: bool) -> CacheConfig {
        CacheConfig {
            enabled: true,
            ttl_secs: 3600,
            max_entries: 100,
            max_size_mb: 10,
            cache_dir: Some(dir.to_path_buf()),
            bypass,
        }
    }

    #[tokio::test]
    async fn test_cache_basic_operations() {
        let temp_dir = tempdir().unwrap();
        let mut cache = SwarmCache::new(test_config(temp_dir.path(), false))
            .await
            .unwrap();
        let task = create_test_task("test task", "do something");
        let result = create_test_result(true);
        // First lookup misses.
        assert!(cache.get(&task).await.is_none());
        assert_eq!(cache.stats().misses, 1);
        // After storing, the same task hits and round-trips the payload.
        cache.put(&task, &result).await.unwrap();
        let cached = cache.get(&task).await;
        assert!(cached.is_some());
        assert_eq!(cache.stats().hits, 1);
        assert_eq!(cached.unwrap().result, result.result);
    }

    #[tokio::test]
    async fn test_cache_different_tasks() {
        let temp_dir = tempdir().unwrap();
        let mut cache = SwarmCache::new(test_config(temp_dir.path(), false))
            .await
            .unwrap();
        let task1 = create_test_task("task 1", "do something");
        let task2 = create_test_task("task 2", "do something else");
        let result = create_test_result(true);
        cache.put(&task1, &result).await.unwrap();
        // Only the stored task is found; a different task misses.
        assert!(cache.get(&task1).await.is_some());
        assert!(cache.get(&task2).await.is_none());
    }

    #[tokio::test]
    async fn test_cache_bypass() {
        let temp_dir = tempdir().unwrap();
        let mut cache = SwarmCache::new(test_config(temp_dir.path(), true))
            .await
            .unwrap();
        let task = create_test_task("test", "instruction");
        let result = create_test_result(true);
        // With bypass enabled, both put and get are no-ops.
        cache.put(&task, &result).await.unwrap();
        assert!(cache.get(&task).await.is_none());
    }

    #[tokio::test]
    async fn test_cache_failed_results_not_cached() {
        let temp_dir = tempdir().unwrap();
        let mut cache = SwarmCache::new(test_config(temp_dir.path(), false))
            .await
            .unwrap();
        let task = create_test_task("test", "instruction");
        let failed_result = create_test_result(false);
        // Failed results are rejected, so the lookup still misses.
        cache.put(&task, &failed_result).await.unwrap();
        assert!(cache.get(&task).await.is_none());
    }

    #[tokio::test]
    async fn test_cache_clear() {
        let temp_dir = tempdir().unwrap();
        let mut cache = SwarmCache::new(test_config(temp_dir.path(), false))
            .await
            .unwrap();
        let task = create_test_task("test", "instruction");
        let result = create_test_result(true);
        cache.put(&task, &result).await.unwrap();
        assert!(cache.get(&task).await.is_some());
        // Clearing empties both the index and the entry counter.
        cache.clear().await.unwrap();
        assert!(cache.get(&task).await.is_none());
        assert_eq!(cache.stats().current_entries, 0);
    }
}