1use async_trait::async_trait;
15use std::collections::HashMap;
16use std::sync::RwLock;
17
18use crate::outputs::Generation;
19
/// Value stored in and returned from a cache: the list of [`Generation`]s
/// recorded for a single cache key.
pub type CacheReturnValue = Vec<Generation>;
22
23#[async_trait]
37pub trait BaseCache: Send + Sync {
38 fn lookup(&self, prompt: &str, llm_string: &str) -> Option<CacheReturnValue>;
58
59 fn update(&self, prompt: &str, llm_string: &str, return_val: CacheReturnValue);
76
77 fn clear(&self);
79
80 async fn alookup(&self, prompt: &str, llm_string: &str) -> Option<CacheReturnValue> {
100 self.lookup(prompt, llm_string)
101 }
102
103 async fn aupdate(&self, prompt: &str, llm_string: &str, return_val: CacheReturnValue) {
120 self.update(prompt, llm_string, return_val);
121 }
122
123 async fn aclear(&self) {
125 self.clear();
126 }
127}
128
129#[derive(Debug)]
131pub struct InMemoryCache {
132 cache: RwLock<HashMap<(String, String), CacheReturnValue>>,
134 maxsize: Option<usize>,
137 key_order: RwLock<Vec<(String, String)>>,
139}
140
141impl InMemoryCache {
142 pub fn new(maxsize: Option<usize>) -> Self {
154 if let Some(size) = maxsize
155 && size == 0
156 {
157 panic!("maxsize must be greater than 0");
158 }
159 Self {
160 cache: RwLock::new(HashMap::new()),
161 maxsize,
162 key_order: RwLock::new(Vec::new()),
163 }
164 }
165
166 pub fn unbounded() -> Self {
168 Self::new(None)
169 }
170}
171
172impl Default for InMemoryCache {
173 fn default() -> Self {
174 Self::unbounded()
175 }
176}
177
178#[async_trait]
179impl BaseCache for InMemoryCache {
180 fn lookup(&self, prompt: &str, llm_string: &str) -> Option<CacheReturnValue> {
181 let cache = self.cache.read().expect("Lock poisoned");
182 cache
183 .get(&(prompt.to_string(), llm_string.to_string()))
184 .cloned()
185 }
186
187 fn update(&self, prompt: &str, llm_string: &str, return_val: CacheReturnValue) {
188 let key = (prompt.to_string(), llm_string.to_string());
189 let mut cache = self.cache.write().expect("Lock poisoned");
190 let mut key_order = self.key_order.write().expect("Lock poisoned");
191
192 if cache.contains_key(&key) {
194 key_order.retain(|k| k != &key);
195 } else if let Some(maxsize) = self.maxsize {
196 if cache.len() >= maxsize
198 && let Some(oldest_key) = key_order.first().cloned()
199 {
200 cache.remove(&oldest_key);
201 key_order.remove(0);
202 }
203 }
204
205 cache.insert(key.clone(), return_val);
206 key_order.push(key);
207 }
208
209 fn clear(&self) {
210 let mut cache = self.cache.write().expect("Lock poisoned");
211 let mut key_order = self.key_order.write().expect("Lock poisoned");
212 cache.clear();
213 key_order.clear();
214 }
215
216 async fn alookup(&self, prompt: &str, llm_string: &str) -> Option<CacheReturnValue> {
217 self.lookup(prompt, llm_string)
218 }
219
220 async fn aupdate(&self, prompt: &str, llm_string: &str, return_val: CacheReturnValue) {
221 self.update(prompt, llm_string, return_val);
222 }
223
224 async fn aclear(&self) {
225 self.clear();
226 }
227}
228
/// Unit tests for [`InMemoryCache`] through both its synchronous and
/// asynchronous [`BaseCache`] methods.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::outputs::Generation;

    /// Builds one `Generation` per text, in order.
    fn make_generations(texts: &[&str]) -> Vec<Generation> {
        texts.iter().map(|text| Generation::new(*text)).collect()
    }

    /// Whether the cache currently holds an entry for `(prompt, llm_string)`.
    fn is_cached(cache: &InMemoryCache, prompt: &str, llm_string: &str) -> bool {
        cache.lookup(prompt, llm_string).is_some()
    }

    #[test]
    fn test_in_memory_cache_new() {
        let cache = InMemoryCache::new(None);
        assert!(!is_cached(&cache, "prompt", "llm"));
    }

    #[test]
    fn test_in_memory_cache_unbounded() {
        let cache = InMemoryCache::unbounded();
        assert!(!is_cached(&cache, "prompt", "llm"));
    }

    #[test]
    fn test_in_memory_cache_default() {
        let cache = InMemoryCache::default();
        assert!(!is_cached(&cache, "prompt", "llm"));
    }

    #[test]
    #[should_panic(expected = "maxsize must be greater than 0")]
    fn test_in_memory_cache_zero_maxsize() {
        let _ = InMemoryCache::new(Some(0));
    }

    #[test]
    fn test_in_memory_cache_lookup_miss() {
        let cache = InMemoryCache::new(None);
        assert!(cache.lookup("prompt", "llm_string").is_none());
    }

    #[test]
    fn test_in_memory_cache_update_and_lookup() {
        let cache = InMemoryCache::new(None);
        cache.update("prompt", "llm_string", make_generations(&["Hello, world!"]));

        let cached = cache
            .lookup("prompt", "llm_string")
            .expect("entry should be cached after update");
        assert_eq!(cached.len(), 1);
        assert_eq!(cached[0].text, "Hello, world!");
    }

    #[test]
    fn test_in_memory_cache_clear() {
        let cache = InMemoryCache::new(None);
        for prompt in ["prompt1", "prompt2"] {
            cache.update(prompt, "llm", make_generations(&["Hello"]));
        }
        for prompt in ["prompt1", "prompt2"] {
            assert!(is_cached(&cache, prompt, "llm"));
        }

        cache.clear();

        for prompt in ["prompt1", "prompt2"] {
            assert!(!is_cached(&cache, prompt, "llm"));
        }
    }

    #[test]
    fn test_in_memory_cache_maxsize() {
        let cache = InMemoryCache::new(Some(2));

        cache.update("prompt1", "llm", make_generations(&["1"]));
        cache.update("prompt2", "llm", make_generations(&["2"]));
        assert!(is_cached(&cache, "prompt1", "llm"));
        assert!(is_cached(&cache, "prompt2", "llm"));

        // A third key exceeds the limit, so the oldest key must be evicted.
        cache.update("prompt3", "llm", make_generations(&["3"]));

        assert!(!is_cached(&cache, "prompt1", "llm"));
        assert!(is_cached(&cache, "prompt2", "llm"));
        assert!(is_cached(&cache, "prompt3", "llm"));
    }

    #[test]
    fn test_in_memory_cache_update_existing_key() {
        let cache = InMemoryCache::new(None);

        for expected in ["first", "second"] {
            cache.update("prompt", "llm", make_generations(&[expected]));
            let cached = cache.lookup("prompt", "llm").unwrap();
            assert_eq!(cached[0].text, expected);
        }
    }

    #[test]
    fn test_in_memory_cache_different_llm_strings() {
        let cache = InMemoryCache::new(None);

        cache.update("prompt", "llm1", make_generations(&["from llm1"]));
        cache.update("prompt", "llm2", make_generations(&["from llm2"]));

        let from_llm1 = cache.lookup("prompt", "llm1").unwrap();
        assert_eq!(from_llm1[0].text, "from llm1");

        let from_llm2 = cache.lookup("prompt", "llm2").unwrap();
        assert_eq!(from_llm2[0].text, "from llm2");
    }

    #[tokio::test]
    async fn test_in_memory_cache_alookup() {
        let cache = InMemoryCache::new(None);
        cache.update("prompt", "llm", make_generations(&["async test"]));

        let cached = cache
            .alookup("prompt", "llm")
            .await
            .expect("entry should be visible through alookup");
        assert_eq!(cached[0].text, "async test");
    }

    #[tokio::test]
    async fn test_in_memory_cache_aupdate() {
        let cache = InMemoryCache::new(None);
        cache
            .aupdate("prompt", "llm", make_generations(&["async update"]))
            .await;

        let cached = cache
            .lookup("prompt", "llm")
            .expect("entry written through aupdate should be cached");
        assert_eq!(cached[0].text, "async update");
    }

    #[tokio::test]
    async fn test_in_memory_cache_aclear() {
        let cache = InMemoryCache::new(None);

        cache.update("prompt", "llm", make_generations(&["test"]));
        assert!(is_cached(&cache, "prompt", "llm"));

        cache.aclear().await;
        assert!(!is_cached(&cache, "prompt", "llm"));
    }

    #[test]
    fn test_in_memory_cache_maxsize_update_refreshes_position() {
        let cache = InMemoryCache::new(Some(2));

        cache.update("prompt1", "llm", make_generations(&["1"]));
        cache.update("prompt2", "llm", make_generations(&["2"]));

        // Rewriting prompt1 makes prompt2 the oldest entry.
        cache.update("prompt1", "llm", make_generations(&["1 updated"]));

        // Inserting a third key must therefore evict prompt2, not prompt1.
        cache.update("prompt3", "llm", make_generations(&["3"]));

        assert!(is_cached(&cache, "prompt1", "llm"));
        assert!(!is_cached(&cache, "prompt2", "llm"));
        assert!(is_cached(&cache, "prompt3", "llm"));
    }

    #[test]
    fn test_in_memory_cache_multiple_generations() {
        let expected_texts = ["First generation", "Second generation", "Third generation"];
        let cache = InMemoryCache::new(None);

        cache.update("prompt", "llm", make_generations(&expected_texts));

        let cached = cache.lookup("prompt", "llm").unwrap();
        assert_eq!(cached.len(), 3);
        for (generation, expected) in cached.iter().zip(expected_texts) {
            assert_eq!(generation.text, expected);
        }
    }
}
411}