Skip to main content

cognis_rag/retrievers/
caching.rs

//! Caches retriever results, keyed by the query string.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use cognis_core::wrappers::{CacheBackend, MemoryCache};
8use cognis_core::{Result, Runnable, RunnableConfig};
9
10use crate::document::Document;
11
12/// Wraps any `Runnable<String, Vec<Document>>` so identical query
13/// strings return cached document lists.
14pub struct CachingRetriever {
15    inner: Arc<dyn Runnable<String, Vec<Document>>>,
16    backend: Arc<dyn CacheBackend<String, Vec<Document>>>,
17}
18
19impl CachingRetriever {
20    /// Build with a memory cache.
21    pub fn new(inner: Arc<dyn Runnable<String, Vec<Document>>>) -> Self {
22        Self {
23            inner,
24            backend: Arc::new(MemoryCache::<String, Vec<Document>>::new()),
25        }
26    }
27
28    /// Override the cache backend.
29    pub fn with_backend(mut self, b: Arc<dyn CacheBackend<String, Vec<Document>>>) -> Self {
30        self.backend = b;
31        self
32    }
33}
34
35#[async_trait]
36impl Runnable<String, Vec<Document>> for CachingRetriever {
37    async fn invoke(&self, query: String, config: RunnableConfig) -> Result<Vec<Document>> {
38        if let Some(hit) = self.backend.get(&query).await {
39            return Ok(hit);
40        }
41        let out = self.inner.invoke(query.clone(), config).await?;
42        self.backend.set(query, out.clone()).await;
43        Ok(out)
44    }
45    fn name(&self) -> &str {
46        "CachingRetriever"
47    }
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Stub retriever that counts how many times it is invoked and echoes
    /// the query back as a single document.
    struct Counter {
        calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl Runnable<String, Vec<Document>> for Counter {
        async fn invoke(&self, q: String, _: RunnableConfig) -> Result<Vec<Document>> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(vec![Document::new(q)])
        }
    }

    /// Three lookups with two distinct queries must reach the inner
    /// retriever exactly twice; the repeated query is served from the cache.
    #[tokio::test]
    async fn second_identical_query_hits_cache() {
        let calls = Arc::new(AtomicUsize::new(0));
        let counter = Counter {
            calls: Arc::clone(&calls),
        };
        let retriever = CachingRetriever::new(Arc::new(counter));

        for query in ["a", "a", "b"] {
            let _ = retriever
                .invoke(query.to_string(), RunnableConfig::default())
                .await
                .unwrap();
        }

        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}