// datasynth_core/llm/cache.rs
use std::collections::HashMap;
use std::sync::{Arc, PoisonError, RwLock};

use sha2::{Digest, Sha256};

6pub struct LlmCache {
8 entries: Arc<RwLock<HashMap<u64, CacheEntry>>>,
9 max_entries: usize,
10}
11
/// A single cached response together with its usage counter.
#[derive(Clone)]
struct CacheEntry {
    /// Starts at 1 on insert and grows by one on every hit; eviction removes
    /// the entry with the smallest value.
    access_count: u64,
    /// The cached response text; hits hand out a clone of it.
    content: String,
}

18impl LlmCache {
19 pub fn new(max_entries: usize) -> Self {
21 Self {
22 entries: Arc::new(RwLock::new(HashMap::new())),
23 max_entries,
24 }
25 }
26
27 pub fn cache_key(prompt: &str, system: Option<&str>, seed: Option<u64>) -> u64 {
29 let mut hasher = Sha256::new();
30 hasher.update(prompt.as_bytes());
31 if let Some(sys) = system {
32 hasher.update(sys.as_bytes());
33 }
34 if let Some(s) = seed {
35 hasher.update(s.to_le_bytes());
36 }
37 let hash = hasher.finalize();
38 u64::from_le_bytes(hash[..8].try_into().unwrap_or([0u8; 8]))
39 }
40
41 pub fn get(&self, key: u64) -> Option<String> {
43 let mut entries = self.entries.write().ok()?;
44 if let Some(entry) = entries.get_mut(&key) {
45 entry.access_count += 1;
46 Some(entry.content.clone())
47 } else {
48 None
49 }
50 }
51
52 pub fn insert(&self, key: u64, content: String) {
54 if let Ok(mut entries) = self.entries.write() {
55 if entries.len() >= self.max_entries {
57 if let Some((&evict_key, _)) = entries.iter().min_by_key(|(_, v)| v.access_count) {
58 entries.remove(&evict_key);
59 }
60 }
61 entries.insert(
62 key,
63 CacheEntry {
64 content,
65 access_count: 1,
66 },
67 );
68 }
69 }
70
71 pub fn len(&self) -> usize {
73 self.entries.read().map(|e| e.len()).unwrap_or(0)
74 }
75
76 pub fn is_empty(&self) -> bool {
78 self.len() == 0
79 }
80
81 pub fn clear(&self) {
83 if let Ok(mut entries) = self.entries.write() {
84 entries.clear();
85 }
86 }
87}
88
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_insert_and_get() {
        let cache = LlmCache::new(100);
        let key = LlmCache::cache_key("test", None, Some(42));
        cache.insert(key, "response".to_string());
        assert_eq!(cache.get(key), Some("response".to_string()));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_cache_miss() {
        let cache = LlmCache::new(100);
        assert_eq!(cache.get(12345), None);
    }

    #[test]
    fn test_cache_eviction() {
        let cache = LlmCache::new(2);
        cache.insert(1, "a".to_string());
        cache.insert(2, "b".to_string());
        cache.insert(3, "c".to_string());
        assert_eq!(cache.len(), 2);
        // The entry just inserted is never the one evicted to make room for itself.
        assert_eq!(cache.get(3), Some("c".to_string()));
    }

    #[test]
    fn test_cache_eviction_prefers_least_accessed() {
        let cache = LlmCache::new(2);
        cache.insert(1, "a".to_string());
        cache.insert(2, "b".to_string());
        // A hit on key 1 makes key 2 the unique least-accessed entry.
        assert_eq!(cache.get(1), Some("a".to_string()));
        cache.insert(3, "c".to_string());
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(1), Some("a".to_string()));
        assert_eq!(cache.get(2), None);
        assert_eq!(cache.get(3), Some("c".to_string()));
    }

    #[test]
    fn test_cache_insert_overwrites_existing_key() {
        let cache = LlmCache::new(100);
        cache.insert(7, "old".to_string());
        cache.insert(7, "new".to_string());
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(7), Some("new".to_string()));
    }

    #[test]
    fn test_cache_key_deterministic() {
        let k1 = LlmCache::cache_key("prompt", Some("system"), Some(42));
        let k2 = LlmCache::cache_key("prompt", Some("system"), Some(42));
        assert_eq!(k1, k2);
    }

    #[test]
    fn test_cache_key_differs() {
        let k1 = LlmCache::cache_key("prompt1", None, None);
        let k2 = LlmCache::cache_key("prompt2", None, None);
        assert_ne!(k1, k2);
    }

    #[test]
    fn test_cache_key_depends_on_system_and_seed() {
        let base = LlmCache::cache_key("prompt", None, None);
        assert_ne!(base, LlmCache::cache_key("prompt", Some("system"), None));
        assert_ne!(base, LlmCache::cache_key("prompt", None, Some(42)));
        assert_ne!(
            LlmCache::cache_key("prompt", None, Some(1)),
            LlmCache::cache_key("prompt", None, Some(2))
        );
    }

    #[test]
    fn test_cache_clear() {
        let cache = LlmCache::new(100);
        cache.insert(1, "a".to_string());
        cache.insert(2, "b".to_string());
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert!(cache.is_empty());
    }
}