cognis_rag/retrievers/
caching.rs1use std::sync::Arc;
4
5use async_trait::async_trait;
6
7use cognis_core::wrappers::{CacheBackend, MemoryCache};
8use cognis_core::{Result, Runnable, RunnableConfig};
9
10use crate::document::Document;
11
12pub struct CachingRetriever {
15 inner: Arc<dyn Runnable<String, Vec<Document>>>,
16 backend: Arc<dyn CacheBackend<String, Vec<Document>>>,
17}
18
19impl CachingRetriever {
20 pub fn new(inner: Arc<dyn Runnable<String, Vec<Document>>>) -> Self {
22 Self {
23 inner,
24 backend: Arc::new(MemoryCache::<String, Vec<Document>>::new()),
25 }
26 }
27
28 pub fn with_backend(mut self, b: Arc<dyn CacheBackend<String, Vec<Document>>>) -> Self {
30 self.backend = b;
31 self
32 }
33}
34
35#[async_trait]
36impl Runnable<String, Vec<Document>> for CachingRetriever {
37 async fn invoke(&self, query: String, config: RunnableConfig) -> Result<Vec<Document>> {
38 if let Some(hit) = self.backend.get(&query).await {
39 return Ok(hit);
40 }
41 let out = self.inner.invoke(query.clone(), config).await?;
42 self.backend.set(query, out.clone()).await;
43 Ok(out)
44 }
45 fn name(&self) -> &str {
46 "CachingRetriever"
47 }
48}
49
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double: echoes the query as a single document and counts how often it runs.
    struct Counter {
        calls: Arc<AtomicUsize>,
    }

    #[async_trait]
    impl Runnable<String, Vec<Document>> for Counter {
        async fn invoke(&self, q: String, _: RunnableConfig) -> Result<Vec<Document>> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(vec![Document::new(q)])
        }
    }

    #[tokio::test]
    async fn second_identical_query_hits_cache() {
        let calls = Arc::new(AtomicUsize::new(0));
        let inner: Arc<dyn Runnable<String, Vec<Document>>> = Arc::new(Counter {
            calls: Arc::clone(&calls),
        });
        let r = CachingRetriever::new(inner);

        // First query is a miss and reaches the inner retriever.
        let first = r
            .invoke("a".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Repeating the query is served from the cache: no new inner call, same documents.
        let second = r
            .invoke("a".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert_eq!(second.len(), first.len());

        // A different query is a fresh miss.
        let _ = r
            .invoke("b".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn retrievers_sharing_a_backend_share_entries() {
        let calls = Arc::new(AtomicUsize::new(0));
        let backend: Arc<dyn CacheBackend<String, Vec<Document>>> =
            Arc::new(MemoryCache::<String, Vec<Document>>::new());
        let make = || {
            let inner: Arc<dyn Runnable<String, Vec<Document>>> = Arc::new(Counter {
                calls: Arc::clone(&calls),
            });
            CachingRetriever::new(inner).with_backend(Arc::clone(&backend))
        };
        let (first, second) = (make(), make());

        let _ = first
            .invoke("q".into(), RunnableConfig::default())
            .await
            .unwrap();
        // The second retriever sees the entry written through the shared backend.
        let _ = second
            .invoke("q".into(), RunnableConfig::default())
            .await
            .unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}