datasynth_core/llm/
cache.rs1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use sha2::{Digest, Sha256};
5
6pub struct LlmCache {
8 entries: Arc<RwLock<HashMap<u64, CacheEntry>>>,
9 max_entries: usize,
10}
11
/// A single cached value together with its usage counter.
#[derive(Clone)]
struct CacheEntry {
    /// Number of times this entry has been stored or read; the entry with
    /// the smallest count is the one chosen for eviction.
    access_count: u64,
    /// The cached text.
    content: String,
}
17
18impl LlmCache {
19 pub fn new(max_entries: usize) -> Self {
21 Self {
22 entries: Arc::new(RwLock::new(HashMap::new())),
23 max_entries,
24 }
25 }
26
27 pub fn cache_key(prompt: &str, system: Option<&str>, seed: Option<u64>) -> u64 {
29 let mut hasher = Sha256::new();
30 hasher.update(prompt.as_bytes());
31 if let Some(sys) = system {
32 hasher.update(sys.as_bytes());
33 }
34 if let Some(s) = seed {
35 hasher.update(s.to_le_bytes());
36 }
37 let hash = hasher.finalize();
38 u64::from_le_bytes(hash[..8].try_into().unwrap_or([0u8; 8]))
39 }
40
41 pub fn get(&self, key: u64) -> Option<String> {
43 let mut entries = self.entries.write().ok()?;
44 if let Some(entry) = entries.get_mut(&key) {
45 entry.access_count += 1;
46 Some(entry.content.clone())
47 } else {
48 None
49 }
50 }
51
52 pub fn insert(&self, key: u64, content: String) {
54 if let Ok(mut entries) = self.entries.write() {
55 if entries.len() >= self.max_entries {
57 if let Some((&evict_key, _)) = entries.iter().min_by_key(|(_, v)| v.access_count) {
58 entries.remove(&evict_key);
59 }
60 }
61 entries.insert(
62 key,
63 CacheEntry {
64 content,
65 access_count: 1,
66 },
67 );
68 }
69 }
70
71 pub fn len(&self) -> usize {
73 self.entries.read().map(|e| e.len()).unwrap_or(0)
74 }
75
76 pub fn is_empty(&self) -> bool {
78 self.len() == 0
79 }
80
81 pub fn clear(&self) {
83 if let Ok(mut entries) = self.entries.write() {
84 entries.clear();
85 }
86 }
87}
88
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// A value stored under a derived key can be read back and is counted.
    #[test]
    fn test_cache_insert_and_get() {
        let cache = LlmCache::new(100);
        let key = LlmCache::cache_key("test", None, Some(42));

        cache.insert(key, String::from("response"));

        assert_eq!(cache.get(key).as_deref(), Some("response"));
        assert_eq!(cache.len(), 1);
    }

    /// Looking up a key that was never inserted yields nothing.
    #[test]
    fn test_cache_miss() {
        let cache = LlmCache::new(100);
        assert!(cache.get(12345).is_none());
    }

    /// Inserting past the limit keeps the cache at its configured size.
    #[test]
    fn test_cache_eviction() {
        let cache = LlmCache::new(2);
        for (key, value) in [(1, "a"), (2, "b"), (3, "c")] {
            cache.insert(key, value.to_string());
        }
        assert_eq!(cache.len(), 2);
    }

    /// Identical inputs always map to the same key.
    #[test]
    fn test_cache_key_deterministic() {
        let first = LlmCache::cache_key("prompt", Some("system"), Some(42));
        let second = LlmCache::cache_key("prompt", Some("system"), Some(42));
        assert_eq!(first, second);
    }

    /// Different prompts map to different keys.
    #[test]
    fn test_cache_key_differs() {
        let first = LlmCache::cache_key("prompt1", None, None);
        let second = LlmCache::cache_key("prompt2", None, None);
        assert_ne!(first, second);
    }

    /// `clear` empties a populated cache.
    #[test]
    fn test_cache_clear() {
        let cache = LlmCache::new(100);
        cache.insert(1, String::from("a"));
        cache.insert(2, String::from("b"));
        assert_eq!(cache.len(), 2);

        cache.clear();

        assert!(cache.is_empty());
    }
}