agent_chain_core/
caches.rs

1//! Optional caching layer for language models.
2//!
3//! Distinct from provider-based prompt caching.
4//!
5//! A cache is useful for two reasons:
6//!
7//! 1. It can save you money by reducing the number of API calls you make to the LLM
8//!    provider if you're often requesting the same completion multiple times.
9//! 2. It can speed up your application by reducing the number of API calls you make to the
10//!    LLM provider.
11//!
12//! Mirrors `langchain_core.caches`.
13
14use async_trait::async_trait;
15use std::collections::HashMap;
16use std::sync::RwLock;
17
18use crate::outputs::Generation;
19
20/// The return type for cache operations - a sequence of Generations.
21pub type CacheReturnValue = Vec<Generation>;
22
23/// Interface for a caching layer for LLMs and Chat models.
24///
25/// The cache interface consists of the following methods:
26///
27/// - `lookup`: Look up a value based on a prompt and `llm_string`.
28/// - `update`: Update the cache based on a prompt and `llm_string`.
29/// - `clear`: Clear the cache.
30///
31/// In addition, the cache interface provides an async version of each method.
32///
33/// The default implementation of the async methods is to run the synchronous
34/// method directly. It's recommended to override the async methods
35/// and provide async implementations to avoid unnecessary overhead.
#[async_trait]
pub trait BaseCache: Send + Sync {
    /// Look up based on `prompt` and `llm_string`.
    ///
    /// A cache implementation is expected to generate a key from the 2-tuple
    /// of `prompt` and `llm_string` (e.g., by concatenating them with a delimiter).
    ///
    /// # Arguments
    ///
    /// * `prompt` - A string representation of the prompt.
    ///   In the case of a chat model, the prompt is a non-trivial
    ///   serialization of the prompt into the language model.
    /// * `llm_string` - A string representation of the LLM configuration.
    ///   This is used to capture the invocation parameters of the LLM
    ///   (e.g., model name, temperature, stop tokens, max tokens, etc.).
    ///   These invocation parameters are serialized into a string representation.
    ///
    /// # Returns
    ///
    /// On a cache miss, return `None`. On a cache hit, return the cached value.
    /// The cached value is a list of `Generation` (or subclasses).
    fn lookup(&self, prompt: &str, llm_string: &str) -> Option<CacheReturnValue>;

    /// Update cache based on `prompt` and `llm_string`.
    ///
    /// The prompt and llm_string are used to generate a key for the cache.
    /// The key should match that of the lookup method.
    ///
    /// # Arguments
    ///
    /// * `prompt` - A string representation of the prompt.
    ///   In the case of a chat model, the prompt is a non-trivial
    ///   serialization of the prompt into the language model.
    /// * `llm_string` - A string representation of the LLM configuration.
    ///   This is used to capture the invocation parameters of the LLM
    ///   (e.g., model name, temperature, stop tokens, max tokens, etc.).
    ///   These invocation parameters are serialized into a string representation.
    /// * `return_val` - The value to be cached. The value is a list of `Generation`
    ///   (or subclasses).
    fn update(&self, prompt: &str, llm_string: &str, return_val: CacheReturnValue);

    /// Clear the entire cache, removing all stored entries.
    fn clear(&self);

    /// Async look up based on `prompt` and `llm_string`.
    ///
    /// A cache implementation is expected to generate a key from the 2-tuple
    /// of `prompt` and `llm_string` (e.g., by concatenating them with a delimiter).
    ///
    /// The default implementation delegates to the synchronous [`BaseCache::lookup`];
    /// override with a truly async implementation to avoid blocking the executor.
    ///
    /// # Arguments
    ///
    /// * `prompt` - A string representation of the prompt.
    ///   In the case of a chat model, the prompt is a non-trivial
    ///   serialization of the prompt into the language model.
    /// * `llm_string` - A string representation of the LLM configuration.
    ///   This is used to capture the invocation parameters of the LLM
    ///   (e.g., model name, temperature, stop tokens, max tokens, etc.).
    ///   These invocation parameters are serialized into a string representation.
    ///
    /// # Returns
    ///
    /// On a cache miss, return `None`. On a cache hit, return the cached value.
    /// The cached value is a list of `Generation` (or subclasses).
    async fn alookup(&self, prompt: &str, llm_string: &str) -> Option<CacheReturnValue> {
        self.lookup(prompt, llm_string)
    }

    /// Async update cache based on `prompt` and `llm_string`.
    ///
    /// The prompt and llm_string are used to generate a key for the cache.
    /// The key should match that of the look up method.
    ///
    /// The default implementation delegates to the synchronous [`BaseCache::update`];
    /// override with a truly async implementation to avoid blocking the executor.
    ///
    /// # Arguments
    ///
    /// * `prompt` - A string representation of the prompt.
    ///   In the case of a chat model, the prompt is a non-trivial
    ///   serialization of the prompt into the language model.
    /// * `llm_string` - A string representation of the LLM configuration.
    ///   This is used to capture the invocation parameters of the LLM
    ///   (e.g., model name, temperature, stop tokens, max tokens, etc.).
    ///   These invocation parameters are serialized into a string representation.
    /// * `return_val` - The value to be cached. The value is a list of `Generation`
    ///   (or subclasses).
    async fn aupdate(&self, prompt: &str, llm_string: &str, return_val: CacheReturnValue) {
        self.update(prompt, llm_string, return_val);
    }

    /// Async clear cache. Defaults to delegating to the synchronous [`BaseCache::clear`].
    async fn aclear(&self) {
        self.clear();
    }
}
128
/// Cache that stores things in memory.
///
/// Eviction is insertion-order (FIFO) with refresh-on-update: re-caching an
/// existing key moves it to the back of the queue, but reads do not refresh a
/// key's position (i.e., this is not strict LRU).
#[derive(Debug)]
pub struct InMemoryCache {
    /// The internal cache storage using (prompt, llm_string) as key.
    cache: RwLock<HashMap<(String, String), CacheReturnValue>>,
    /// The maximum number of items to store in the cache.
    /// If `None`, the cache has no maximum size.
    maxsize: Option<usize>,
    /// Order of keys for LRU-style eviction (stores keys in insertion order).
    /// Only `update` refreshes a key's position here; `lookup` does not touch it.
    key_order: RwLock<Vec<(String, String)>>,
}
140
141impl InMemoryCache {
142    /// Initialize with empty cache.
143    ///
144    /// # Arguments
145    ///
146    /// * `maxsize` - The maximum number of items to store in the cache.
147    ///   If `None`, the cache has no maximum size.
148    ///   If the cache exceeds the maximum size, the oldest items are removed.
149    ///
150    /// # Panics
151    ///
152    /// Panics if `maxsize` is less than or equal to `0`.
153    pub fn new(maxsize: Option<usize>) -> Self {
154        if let Some(size) = maxsize
155            && size == 0
156        {
157            panic!("maxsize must be greater than 0");
158        }
159        Self {
160            cache: RwLock::new(HashMap::new()),
161            maxsize,
162            key_order: RwLock::new(Vec::new()),
163        }
164    }
165
166    /// Create a new InMemoryCache with no maximum size.
167    pub fn unbounded() -> Self {
168        Self::new(None)
169    }
170}
171
172impl Default for InMemoryCache {
173    fn default() -> Self {
174        Self::unbounded()
175    }
176}
177
178#[async_trait]
179impl BaseCache for InMemoryCache {
180    fn lookup(&self, prompt: &str, llm_string: &str) -> Option<CacheReturnValue> {
181        let cache = self.cache.read().expect("Lock poisoned");
182        cache
183            .get(&(prompt.to_string(), llm_string.to_string()))
184            .cloned()
185    }
186
187    fn update(&self, prompt: &str, llm_string: &str, return_val: CacheReturnValue) {
188        let key = (prompt.to_string(), llm_string.to_string());
189        let mut cache = self.cache.write().expect("Lock poisoned");
190        let mut key_order = self.key_order.write().expect("Lock poisoned");
191
192        // If key already exists, remove it from the order list (it will be added at the end)
193        if cache.contains_key(&key) {
194            key_order.retain(|k| k != &key);
195        } else if let Some(maxsize) = self.maxsize {
196            // If at capacity, remove the oldest item
197            if cache.len() >= maxsize
198                && let Some(oldest_key) = key_order.first().cloned()
199            {
200                cache.remove(&oldest_key);
201                key_order.remove(0);
202            }
203        }
204
205        cache.insert(key.clone(), return_val);
206        key_order.push(key);
207    }
208
209    fn clear(&self) {
210        let mut cache = self.cache.write().expect("Lock poisoned");
211        let mut key_order = self.key_order.write().expect("Lock poisoned");
212        cache.clear();
213        key_order.clear();
214    }
215
216    async fn alookup(&self, prompt: &str, llm_string: &str) -> Option<CacheReturnValue> {
217        self.lookup(prompt, llm_string)
218    }
219
220    async fn aupdate(&self, prompt: &str, llm_string: &str, return_val: CacheReturnValue) {
221        self.update(prompt, llm_string, return_val);
222    }
223
224    async fn aclear(&self) {
225        self.clear();
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::outputs::Generation;

    // --- Construction ---

    #[test]
    fn test_in_memory_cache_new() {
        let cache = InMemoryCache::new(None);
        assert!(cache.lookup("prompt", "llm").is_none());
    }

    #[test]
    fn test_in_memory_cache_unbounded() {
        let cache = InMemoryCache::unbounded();
        assert!(cache.lookup("prompt", "llm").is_none());
    }

    #[test]
    fn test_in_memory_cache_default() {
        // Default must behave like an unbounded cache.
        let cache = InMemoryCache::default();
        assert!(cache.lookup("prompt", "llm").is_none());
    }

    #[test]
    #[should_panic(expected = "maxsize must be greater than 0")]
    fn test_in_memory_cache_zero_maxsize() {
        // A bounded cache of size 0 is rejected at construction time.
        InMemoryCache::new(Some(0));
    }

    // --- Basic lookup/update/clear semantics ---

    #[test]
    fn test_in_memory_cache_lookup_miss() {
        let cache = InMemoryCache::new(None);
        let result = cache.lookup("prompt", "llm_string");
        assert!(result.is_none());
    }

    #[test]
    fn test_in_memory_cache_update_and_lookup() {
        let cache = InMemoryCache::new(None);
        let generations = vec![Generation::new("Hello, world!")];

        cache.update("prompt", "llm_string", generations.clone());

        let result = cache.lookup("prompt", "llm_string");
        assert!(result.is_some());
        let cached = result.unwrap();
        assert_eq!(cached.len(), 1);
        assert_eq!(cached[0].text, "Hello, world!");
    }

    #[test]
    fn test_in_memory_cache_clear() {
        let cache = InMemoryCache::new(None);
        let generations = vec![Generation::new("Hello")];

        cache.update("prompt1", "llm", generations.clone());
        cache.update("prompt2", "llm", generations.clone());

        assert!(cache.lookup("prompt1", "llm").is_some());
        assert!(cache.lookup("prompt2", "llm").is_some());

        cache.clear();

        // clear() removes every entry, regardless of key.
        assert!(cache.lookup("prompt1", "llm").is_none());
        assert!(cache.lookup("prompt2", "llm").is_none());
    }

    // --- Bounded-size eviction behavior ---

    #[test]
    fn test_in_memory_cache_maxsize() {
        let cache = InMemoryCache::new(Some(2));

        cache.update("prompt1", "llm", vec![Generation::new("1")]);
        cache.update("prompt2", "llm", vec![Generation::new("2")]);

        assert!(cache.lookup("prompt1", "llm").is_some());
        assert!(cache.lookup("prompt2", "llm").is_some());

        // Adding third item should evict the first (oldest)
        cache.update("prompt3", "llm", vec![Generation::new("3")]);

        assert!(cache.lookup("prompt1", "llm").is_none()); // Evicted
        assert!(cache.lookup("prompt2", "llm").is_some());
        assert!(cache.lookup("prompt3", "llm").is_some());
    }

    #[test]
    fn test_in_memory_cache_update_existing_key() {
        let cache = InMemoryCache::new(None);

        cache.update("prompt", "llm", vec![Generation::new("first")]);
        let result = cache.lookup("prompt", "llm").unwrap();
        assert_eq!(result[0].text, "first");

        // Re-updating the same key replaces the stored value.
        cache.update("prompt", "llm", vec![Generation::new("second")]);
        let result = cache.lookup("prompt", "llm").unwrap();
        assert_eq!(result[0].text, "second");
    }

    #[test]
    fn test_in_memory_cache_different_llm_strings() {
        // The cache key is the (prompt, llm_string) pair, so the same prompt
        // under different LLM configurations produces distinct entries.
        let cache = InMemoryCache::new(None);

        cache.update("prompt", "llm1", vec![Generation::new("from llm1")]);
        cache.update("prompt", "llm2", vec![Generation::new("from llm2")]);

        let result1 = cache.lookup("prompt", "llm1").unwrap();
        assert_eq!(result1[0].text, "from llm1");

        let result2 = cache.lookup("prompt", "llm2").unwrap();
        assert_eq!(result2[0].text, "from llm2");
    }

    // --- Async trait methods (default to delegating to sync versions) ---

    #[tokio::test]
    async fn test_in_memory_cache_alookup() {
        let cache = InMemoryCache::new(None);
        let generations = vec![Generation::new("async test")];

        cache.update("prompt", "llm", generations);

        let result = cache.alookup("prompt", "llm").await;
        assert!(result.is_some());
        assert_eq!(result.unwrap()[0].text, "async test");
    }

    #[tokio::test]
    async fn test_in_memory_cache_aupdate() {
        let cache = InMemoryCache::new(None);
        let generations = vec![Generation::new("async update")];

        cache.aupdate("prompt", "llm", generations).await;

        let result = cache.lookup("prompt", "llm");
        assert!(result.is_some());
        assert_eq!(result.unwrap()[0].text, "async update");
    }

    #[tokio::test]
    async fn test_in_memory_cache_aclear() {
        let cache = InMemoryCache::new(None);

        cache.update("prompt", "llm", vec![Generation::new("test")]);
        assert!(cache.lookup("prompt", "llm").is_some());

        cache.aclear().await;
        assert!(cache.lookup("prompt", "llm").is_none());
    }

    #[test]
    fn test_in_memory_cache_maxsize_update_refreshes_position() {
        let cache = InMemoryCache::new(Some(2));

        cache.update("prompt1", "llm", vec![Generation::new("1")]);
        cache.update("prompt2", "llm", vec![Generation::new("2")]);

        // Update prompt1 - should move it to the end of the queue
        cache.update("prompt1", "llm", vec![Generation::new("1 updated")]);

        // Adding prompt3 should evict prompt2 (now oldest) instead of prompt1
        cache.update("prompt3", "llm", vec![Generation::new("3")]);

        assert!(cache.lookup("prompt1", "llm").is_some()); // Still present
        assert!(cache.lookup("prompt2", "llm").is_none()); // Evicted
        assert!(cache.lookup("prompt3", "llm").is_some()); // New
    }

    #[test]
    fn test_in_memory_cache_multiple_generations() {
        // A cached value is a sequence; order and length must be preserved.
        let cache = InMemoryCache::new(None);
        let generations = vec![
            Generation::new("First generation"),
            Generation::new("Second generation"),
            Generation::new("Third generation"),
        ];

        cache.update("prompt", "llm", generations);

        let result = cache.lookup("prompt", "llm").unwrap();
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].text, "First generation");
        assert_eq!(result[1].text, "Second generation");
        assert_eq!(result[2].text, "Third generation");
    }
}
411}