use std::collections::HashMap;
use std::sync::{Arc, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard};

use sha2::{Digest, Sha256};
/// Bounded in-memory cache of LLM responses, keyed by `u64` (typically produced by
/// `LlmCache::cache_key`).
///
/// The map sits behind an `RwLock`, so every operation works through `&self`.
pub struct LlmCache {
    /// Number of entries at which `insert` starts evicting before adding a new one.
    max_entries: usize,
    /// Cached responses indexed by key.
    entries: Arc<RwLock<HashMap<u64, CacheEntry>>>,
}
/// A cached response together with the usage counter the eviction policy ranks by.
#[derive(Clone, Debug, PartialEq, Eq)]
struct CacheEntry {
    /// The cached response text returned on a hit.
    content: String,
    /// Starts at 1 when the entry is inserted and grows with every cache hit; the entry
    /// with the lowest count is the one evicted when the cache is full.
    access_count: u64,
}
impl LlmCache {
pub fn new(max_entries: usize) -> Self {
Self {
entries: Arc::new(RwLock::new(HashMap::new())),
max_entries,
}
}
pub fn cache_key(prompt: &str, system: Option<&str>, seed: Option<u64>) -> u64 {
let mut hasher = Sha256::new();
hasher.update(prompt.as_bytes());
if let Some(sys) = system {
hasher.update(sys.as_bytes());
}
if let Some(s) = seed {
hasher.update(s.to_le_bytes());
}
let hash = hasher.finalize();
u64::from_le_bytes(hash[..8].try_into().unwrap_or([0u8; 8]))
}
pub fn get(&self, key: u64) -> Option<String> {
let mut entries = self.entries.write().ok()?;
if let Some(entry) = entries.get_mut(&key) {
entry.access_count += 1;
Some(entry.content.clone())
} else {
None
}
}
pub fn insert(&self, key: u64, content: String) {
if let Ok(mut entries) = self.entries.write() {
if entries.len() >= self.max_entries {
if let Some((&evict_key, _)) = entries.iter().min_by_key(|(_, v)| v.access_count) {
entries.remove(&evict_key);
}
}
entries.insert(
key,
CacheEntry {
content,
access_count: 1,
},
);
}
}
pub fn len(&self) -> usize {
self.entries.read().map(|e| e.len()).unwrap_or(0)
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn clear(&self) {
if let Ok(mut entries) = self.entries.write() {
entries.clear();
}
}
}
#[cfg(test)]
// Tests may unwrap freely: panicking is exactly how a failing test reports.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_insert_and_get() {
        let cache = LlmCache::new(100);
        let key = LlmCache::cache_key("test", None, Some(42));
        cache.insert(key, "response".to_string());
        assert_eq!(cache.get(key), Some("response".to_string()));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_cache_miss() {
        let cache = LlmCache::new(100);
        assert_eq!(cache.get(12345), None);
    }

    #[test]
    fn test_cache_eviction() {
        let cache = LlmCache::new(2);
        cache.insert(1, "a".to_string());
        cache.insert(2, "b".to_string());
        cache.insert(3, "c".to_string());
        assert_eq!(cache.len(), 2);
        // The newest entry is kept and exactly one of the two tied older entries is gone.
        assert_eq!(cache.get(3), Some("c".to_string()));
        assert_ne!(cache.get(1).is_some(), cache.get(2).is_some());
    }

    #[test]
    fn test_cache_eviction_prefers_least_used() {
        let cache = LlmCache::new(2);
        cache.insert(1, "a".to_string());
        cache.insert(2, "b".to_string());
        // A hit raises key 1's access count above key 2's, so key 2 is the victim.
        assert_eq!(cache.get(1), Some("a".to_string()));
        cache.insert(3, "c".to_string());
        assert_eq!(cache.get(2), None);
        assert_eq!(cache.get(1), Some("a".to_string()));
        assert_eq!(cache.get(3), Some("c".to_string()));
    }

    #[test]
    fn test_cache_key_deterministic() {
        let k1 = LlmCache::cache_key("prompt", Some("system"), Some(42));
        let k2 = LlmCache::cache_key("prompt", Some("system"), Some(42));
        assert_eq!(k1, k2);
    }

    #[test]
    fn test_cache_key_differs() {
        let k1 = LlmCache::cache_key("prompt1", None, None);
        let k2 = LlmCache::cache_key("prompt2", None, None);
        assert_ne!(k1, k2);
    }

    #[test]
    fn test_cache_key_depends_on_system_and_seed() {
        let base = LlmCache::cache_key("prompt", None, None);
        assert_ne!(base, LlmCache::cache_key("prompt", Some("system"), None));
        assert_ne!(base, LlmCache::cache_key("prompt", None, Some(7)));
        assert_ne!(
            LlmCache::cache_key("prompt", None, Some(1)),
            LlmCache::cache_key("prompt", None, Some(2))
        );
    }

    #[test]
    fn test_cache_clear() {
        let cache = LlmCache::new(100);
        cache.insert(1, "a".to_string());
        cache.insert(2, "b".to_string());
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert!(cache.is_empty());
    }
}