use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::time::{Duration, Instant};
use crate::lobe::LobeOutput;
/// In-memory cache of lobe outputs, keyed by lobe name and input text.
///
/// An entry can be served only while less than `ttl` has elapsed since it was
/// stored. Expired entries are not removed automatically; they stay in the map
/// until `evict_expired` or `clear` is called.
pub struct LobeCache {
    /// How long an entry remains servable after insertion.
    ttl: Duration,
    /// Entries keyed by `(lobe name, hash of the input text)`. The key holds only
    /// a hash, so each entry also keeps the full input for verification.
    entries: HashMap<(String, u64), CacheEntry>,
}
/// A single cached lobe result plus the data needed to validate it on lookup.
struct CacheEntry {
    /// The value handed back on a cache hit.
    output: LobeOutput,
    /// The exact input text this output was stored under. The map key only
    /// carries a hash of it, so lookups compare against this to reject collisions.
    input: String,
    /// When the entry was stored; its age is compared against the cache TTL.
    created_at: Instant,
}
impl LobeCache {
    /// Creates an empty cache whose entries stay servable for `ttl` after insertion.
    ///
    /// A `ttl` of `Duration::ZERO` means nothing stored is ever returned by `get`.
    pub fn new(ttl: Duration) -> Self {
        Self {
            ttl,
            entries: HashMap::new(),
        }
    }

    /// Looks up the cached output for `lobe_name` + `input`.
    ///
    /// Returns `None` in three cases:
    /// - nothing is stored under the key;
    /// - the stored entry belongs to a different input whose hash collides with `input`;
    /// - `ttl` or more has elapsed since the entry was stored.
    ///
    /// Stale entries are left in place; use [`Self::evict_expired`] to drop them.
    pub fn get(&self, lobe_name: &str, input: &str) -> Option<&LobeOutput> {
        let lookup = (lobe_name.to_owned(), hash_input(input));
        let entry = self.entries.get(&lookup)?;
        // The key only carries a hash, so confirm the full text first; `&&`
        // short-circuits, so the clock is consulted only for a genuine match.
        let is_live = entry.input == input && entry.created_at.elapsed() < self.ttl;
        is_live.then_some(&entry.output)
    }

    /// Stores `output` for `lobe_name` + `input`, timestamped with the current time.
    ///
    /// Whatever already occupies the same key is replaced. That covers an earlier
    /// output for the same input, and also an entry for a different input whose
    /// hash happens to collide with this one.
    pub fn put(&mut self, lobe_name: &str, input: &str, output: LobeOutput) {
        let slot = (lobe_name.to_owned(), hash_input(input));
        let fresh_entry = CacheEntry {
            output,
            input: input.to_owned(),
            created_at: Instant::now(),
        };
        self.entries.insert(slot, fresh_entry);
    }

    /// Removes every entry whose age has reached `ttl`; fresh entries are kept.
    pub fn evict_expired(&mut self) {
        let ttl = self.ttl;
        self.entries
            .retain(|_, entry| entry.created_at.elapsed() < ttl);
    }

    /// Removes every entry, fresh or stale.
    pub fn clear(&mut self) {
        self.entries.clear();
    }

    /// Number of stored entries.
    ///
    /// The count includes stale entries that have not yet been removed by
    /// [`Self::evict_expired`].
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when no entries are stored; stale-but-unevicted entries count as stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Reduces `input` to the `u64` used as the second half of a cache key.
///
/// The value comes from a freshly constructed `DefaultHasher`, so equal strings
/// hash equally within one build. Per the std docs, the algorithm may change
/// between Rust releases, so these values should not be persisted. Distinct
/// strings can still collide; callers must compare the real text before
/// trusting a match.
fn hash_input(input: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(input, &mut state);
    state.finish()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Cache whose TTL is long enough that nothing expires during a test.
    fn long_lived() -> LobeCache {
        LobeCache::new(Duration::from_secs(60))
    }

    #[test]
    fn test_new_cache_is_empty() {
        let cache = long_lived();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_cache_hit() {
        let mut cache = long_lived();
        cache.put("analyst", "test input", LobeOutput::new("analysis result", 0.9));
        let cached = cache
            .get("analyst", "test input")
            .expect("freshly stored entry should be returned");
        assert_eq!(cached.content, "analysis result");
    }

    #[test]
    fn test_cache_miss_different_input() {
        let mut cache = long_lived();
        cache.put("analyst", "input A", LobeOutput::new("result A", 0.9));
        assert!(cache.get("analyst", "input B").is_none());
    }

    #[test]
    fn test_cache_miss_different_lobe() {
        let mut cache = long_lived();
        cache.put("analyst", "input", LobeOutput::new("result", 0.9));
        assert!(cache.get("critic", "input").is_none());
    }

    #[test]
    fn test_same_input_different_lobes_coexist() {
        let mut cache = long_lived();
        cache.put("analyst", "input", LobeOutput::new("from analyst", 0.9));
        cache.put("critic", "input", LobeOutput::new("from critic", 0.4));
        assert_eq!(cache.len(), 2);
        let analyst = cache.get("analyst", "input").expect("analyst entry should be cached");
        assert_eq!(analyst.content, "from analyst");
        let critic = cache.get("critic", "input").expect("critic entry should be cached");
        assert_eq!(critic.content, "from critic");
    }

    #[test]
    fn test_cache_expiry() {
        let mut cache = LobeCache::new(Duration::ZERO);
        cache.put("analyst", "input", LobeOutput::new("result", 0.9));
        assert!(cache.get("analyst", "input").is_none());
        // `get` only hides stale entries; it does not remove them.
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_cache_overwrite() {
        let mut cache = long_lived();
        cache.put("analyst", "input", LobeOutput::new("old", 0.5));
        cache.put("analyst", "input", LobeOutput::new("new", 0.9));
        let cached = cache
            .get("analyst", "input")
            .expect("overwritten entry should still be cached");
        assert_eq!(cached.content, "new");
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_evict_expired() {
        let mut cache = LobeCache::new(Duration::ZERO);
        cache.put("a", "input", LobeOutput::new("a", 0.5));
        cache.put("b", "input", LobeOutput::new("b", 0.5));
        assert_eq!(cache.len(), 2);
        cache.evict_expired();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_evict_expired_keeps_fresh_entries() {
        let mut cache = long_lived();
        cache.put("a", "input", LobeOutput::new("a", 0.5));
        cache.evict_expired();
        assert_eq!(cache.len(), 1);
        assert!(cache.get("a", "input").is_some());
    }

    #[test]
    fn test_clear() {
        let mut cache = long_lived();
        cache.put("a", "input", LobeOutput::new("a", 0.5));
        cache.put("b", "input", LobeOutput::new("b", 0.5));
        cache.clear();
        assert!(cache.is_empty());
    }
}