use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};
/// Cache key: the sub-agent name plus a 64-bit hash of the prompt.
///
/// Only the prompt's hash is kept (not the prompt text) so keys stay small. Two different
/// prompts for the same agent whose hashes collide would share one entry; with a 64-bit hash
/// this is extremely unlikely but not impossible.
type CacheKey = (String, u64);

/// Thread-safe cache of sub-agent responses, keyed by agent name and prompt.
///
/// Cloning is cheap and every clone shares the same storage, so a [`put`](Self::put) through
/// one handle is visible to [`get`](Self::get) through any other.
#[derive(Clone, Debug)]
pub struct SubAgentCache {
    inner: Arc<Mutex<HashMap<CacheKey, String>>>,
}

impl SubAgentCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Returns the response cached for `agent_name` / `prompt`, or `None` if nothing has been
    /// stored for that pair since the last [`invalidate`](Self::invalidate).
    pub fn get(&self, agent_name: &str, prompt: &str) -> Option<String> {
        let key = make_key(agent_name, prompt);
        self.entries().get(&key).cloned()
    }

    /// Stores `response` for `agent_name` / `prompt`, replacing any previous entry for that pair.
    pub fn put(&self, agent_name: &str, prompt: &str, response: &str) {
        let key = make_key(agent_name, prompt);
        // Allocate outside the critical section to keep the lock hold time minimal.
        let value = response.to_string();
        self.entries().insert(key, value);
    }

    /// Discards every cached response.
    ///
    /// Entries are removed, not just marked stale. Their memory is reclaimed and
    /// [`len`](Self::len) drops to zero. A response that was computed before this call but is
    /// passed to [`put`](Self::put) afterwards is still cached; callers must not `put` results
    /// they know to be outdated.
    pub fn invalidate(&self) {
        self.entries().clear();
    }

    /// Number of responses currently cached.
    pub fn len(&self) -> usize {
        self.entries().len()
    }

    /// Returns `true` when no responses are cached.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Locks the shared map, recovering the guard if the mutex was poisoned.
    ///
    /// Every critical section in this type is a single `HashMap` call (get/insert/clear/len).
    /// A panic elsewhere while the lock was held therefore cannot leave the map half-updated.
    /// Recovering keeps the cache working instead of silently turning every operation into a
    /// no-op.
    fn entries(&self) -> std::sync::MutexGuard<'_, HashMap<CacheKey, String>> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl Default for SubAgentCache {
    /// Equivalent to [`SubAgentCache::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Builds the cache key for `agent_name` / `prompt`.
///
/// `DefaultHasher::new()` produces the same hashes for every instance within one build. Its
/// algorithm is not guaranteed stable across Rust releases, which is fine for this in-memory
/// cache, but keys must never be persisted.
fn make_key(agent_name: &str, prompt: &str) -> CacheKey {
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    prompt.hash(&mut hasher);
    (agent_name.to_string(), hasher.finish())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A stored response is returned for the exact same agent and prompt.
    #[test]
    fn cache_hit_after_put() {
        let c = SubAgentCache::new();
        c.put("reviewer", "review this code", "looks good!");
        let hit = c.get("reviewer", "review this code");
        assert_eq!(hit.as_deref(), Some("looks good!"));
    }

    /// A different prompt for the same agent misses.
    #[test]
    fn cache_miss_different_prompt() {
        let c = SubAgentCache::new();
        c.put("reviewer", "review this code", "looks good!");
        assert!(c.get("reviewer", "review OTHER code").is_none());
    }

    /// The same prompt sent to a different agent misses.
    #[test]
    fn cache_miss_different_agent() {
        let c = SubAgentCache::new();
        c.put("reviewer", "review this", "looks good!");
        assert!(c.get("testgen", "review this").is_none());
    }

    /// An entry visible before `invalidate` is gone afterwards.
    #[test]
    fn invalidation_clears_stale_entries() {
        let c = SubAgentCache::new();
        c.put("reviewer", "prompt", "result");
        assert!(c.get("reviewer", "prompt").is_some());
        c.invalidate();
        assert!(c.get("reviewer", "prompt").is_none());
    }

    /// Entries stored after `invalidate` are served; ones stored before are not.
    #[test]
    fn entries_after_invalidation_are_fresh() {
        let c = SubAgentCache::new();
        c.put("reviewer", "old prompt", "old result");
        c.invalidate();
        c.put("reviewer", "new prompt", "new result");
        assert!(c.get("reviewer", "old prompt").is_none());
        let fresh = c.get("reviewer", "new prompt");
        assert_eq!(fresh.as_deref(), Some("new result"));
    }

    /// `is_empty` holds on a new cache and `len` counts distinct keys.
    #[test]
    fn len_tracks_entries() {
        let c = SubAgentCache::new();
        assert!(c.is_empty());
        c.put("a", "p1", "r1");
        c.put("b", "p2", "r2");
        assert_eq!(c.len(), 2);
    }

    /// Clones share storage: a write through one handle is readable through another.
    #[test]
    fn shared_across_clones() {
        let original = SubAgentCache::new();
        let handle = original.clone();
        original.put("agent", "prompt", "result");
        assert_eq!(handle.get("agent", "prompt").as_deref(), Some("result"));
    }
}