// datasynth_core/llm/cache.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use sha2::{Digest, Sha256};
5
6/// In-memory LRU cache for LLM responses.
7pub struct LlmCache {
8    entries: Arc<RwLock<HashMap<u64, CacheEntry>>>,
9    max_entries: usize,
10}
11
12#[derive(Clone)]
13struct CacheEntry {
14    content: String,
15    access_count: u64,
16}
17
18impl LlmCache {
19    /// Create a new cache with maximum entry count.
20    pub fn new(max_entries: usize) -> Self {
21        Self {
22            entries: Arc::new(RwLock::new(HashMap::new())),
23            max_entries,
24        }
25    }
26
27    /// Compute cache key from prompt + system + seed.
28    pub fn cache_key(prompt: &str, system: Option<&str>, seed: Option<u64>) -> u64 {
29        let mut hasher = Sha256::new();
30        hasher.update(prompt.as_bytes());
31        if let Some(sys) = system {
32            hasher.update(sys.as_bytes());
33        }
34        if let Some(s) = seed {
35            hasher.update(s.to_le_bytes());
36        }
37        let hash = hasher.finalize();
38        u64::from_le_bytes(hash[..8].try_into().unwrap_or([0u8; 8]))
39    }
40
41    /// Get a cached response.
42    pub fn get(&self, key: u64) -> Option<String> {
43        let mut entries = self.entries.write().ok()?;
44        if let Some(entry) = entries.get_mut(&key) {
45            entry.access_count += 1;
46            Some(entry.content.clone())
47        } else {
48            None
49        }
50    }
51
52    /// Insert a response into the cache.
53    pub fn insert(&self, key: u64, content: String) {
54        if let Ok(mut entries) = self.entries.write() {
55            // Evict least-accessed entry if at capacity
56            if entries.len() >= self.max_entries {
57                if let Some((&evict_key, _)) = entries.iter().min_by_key(|(_, v)| v.access_count) {
58                    entries.remove(&evict_key);
59                }
60            }
61            entries.insert(
62                key,
63                CacheEntry {
64                    content,
65                    access_count: 1,
66                },
67            );
68        }
69    }
70
71    /// Number of entries in cache.
72    pub fn len(&self) -> usize {
73        self.entries.read().map(|e| e.len()).unwrap_or(0)
74    }
75
76    /// Whether the cache is empty.
77    pub fn is_empty(&self) -> bool {
78        self.len() == 0
79    }
80
81    /// Clear all entries.
82    pub fn clear(&self) {
83        if let Ok(mut entries) = self.entries.write() {
84            entries.clear();
85        }
86    }
87}
88
#[cfg(test)]
// Permits `unwrap()` in tests in case crate-level lint settings deny it elsewhere.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// A stored response comes back unchanged and is counted once.
    #[test]
    fn test_cache_insert_and_get() {
        let cache = LlmCache::new(100);
        let key = LlmCache::cache_key("test", None, Some(42));
        cache.insert(key, String::from("response"));

        let cached = cache.get(key);
        assert_eq!(cached.as_deref(), Some("response"));
        assert_eq!(cache.len(), 1);
    }

    /// Looking up a key that was never inserted misses.
    #[test]
    fn test_cache_miss() {
        let empty = LlmCache::new(100);
        assert!(empty.get(12345).is_none());
    }

    /// Exceeding capacity drops an entry instead of growing past the limit.
    #[test]
    fn test_cache_eviction() {
        let cache = LlmCache::new(2);
        for (key, value) in [(1, "a"), (2, "b"), (3, "c")] {
            cache.insert(key, value.to_string());
        }
        // Capacity is 2, so the third insert must have pushed one entry out.
        assert_eq!(cache.len(), 2);
    }

    /// Identical inputs always produce the identical key.
    #[test]
    fn test_cache_key_deterministic() {
        let make = || LlmCache::cache_key("prompt", Some("system"), Some(42));
        assert_eq!(make(), make());
    }

    /// Different prompts produce different keys.
    #[test]
    fn test_cache_key_differs() {
        let first = LlmCache::cache_key("prompt1", None, None);
        let second = LlmCache::cache_key("prompt2", None, None);
        assert_ne!(first, second);
    }

    /// `clear` empties a populated cache.
    #[test]
    fn test_cache_clear() {
        let cache = LlmCache::new(100);
        for (key, value) in [(1, "a"), (2, "b")] {
            cache.insert(key, value.to_string());
        }
        assert_eq!(cache.len(), 2);

        cache.clear();
        assert!(cache.is_empty());
    }
}