Skip to main content

cognis_core/wrappers/
cache.rs

1//! In-memory key-based cache wrapper.
2
3use std::collections::HashMap;
4use std::hash::Hash;
5use std::marker::PhantomData;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use tokio::sync::RwLock;
10
11use crate::runnable::{Runnable, RunnableConfig};
12use crate::Result;
13
14/// Pluggable cache backend.
15///
16/// Implementations decide eviction, TTL, and persistence. The default
17/// [`MemoryCache`] is unbounded and lives only as long as the process.
18#[async_trait]
19pub trait CacheBackend<K, V>: Send + Sync
20where
21    K: Send + Sync + 'static,
22    V: Send + Sync + 'static,
23{
24    /// Look up a value. `None` if not present.
25    async fn get(&self, key: &K) -> Option<V>;
26    /// Store a value.
27    async fn set(&self, key: K, value: V);
28}
29
30/// Unbounded in-memory `HashMap`-backed cache.
31pub struct MemoryCache<K, V> {
32    inner: RwLock<HashMap<K, V>>,
33}
34
35impl<K, V> Default for MemoryCache<K, V> {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl<K, V> MemoryCache<K, V> {
42    /// Construct an empty `MemoryCache`.
43    pub fn new() -> Self {
44        Self {
45            inner: RwLock::new(HashMap::new()),
46        }
47    }
48}
49
50#[async_trait]
51impl<K, V> CacheBackend<K, V> for MemoryCache<K, V>
52where
53    K: Hash + Eq + Send + Sync + Clone + 'static,
54    V: Clone + Send + Sync + 'static,
55{
56    async fn get(&self, key: &K) -> Option<V> {
57        self.inner.read().await.get(key).cloned()
58    }
59    async fn set(&self, key: K, value: V) {
60        self.inner.write().await.insert(key, value);
61    }
62}
63
64type KeyFn<I, K> = dyn Fn(&I) -> K + Send + Sync;
65
66/// Caches results of the inner runnable, keyed by a user-supplied
67/// `key_fn(&I)`. On miss, runs the inner and stores the output.
68pub struct Cache<R, I, O, K, B> {
69    inner: R,
70    backend: Arc<B>,
71    key_fn: Arc<KeyFn<I, K>>,
72    _phantom: PhantomData<fn(I) -> O>,
73}
74
75impl<R, I, O, K, B> Cache<R, I, O, K, B>
76where
77    R: Runnable<I, O>,
78    I: Send + 'static,
79    O: Send + Sync + Clone + 'static,
80    K: Send + Sync + 'static,
81    B: CacheBackend<K, O>,
82{
83    /// Build a cache wrapper.
84    pub fn new<F>(inner: R, backend: Arc<B>, key_fn: F) -> Self
85    where
86        F: Fn(&I) -> K + Send + Sync + 'static,
87    {
88        Self {
89            inner,
90            backend,
91            key_fn: Arc::new(key_fn),
92            _phantom: PhantomData,
93        }
94    }
95}
96
97#[async_trait]
98impl<R, I, O, K, B> Runnable<I, O> for Cache<R, I, O, K, B>
99where
100    R: Runnable<I, O>,
101    I: Send + 'static,
102    O: Clone + Send + Sync + 'static,
103    K: Send + Sync + 'static,
104    B: CacheBackend<K, O> + 'static,
105{
106    async fn invoke(&self, input: I, config: RunnableConfig) -> Result<O> {
107        let key = (self.key_fn)(&input);
108        if let Some(hit) = self.backend.get(&key).await {
109            return Ok(hit);
110        }
111        let out = self.inner.invoke(input, config).await?;
112        self.backend.set(key, out.clone()).await;
113        Ok(out)
114    }
115    fn name(&self) -> &str {
116        "Cache"
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Test runnable that multiplies its input by 10 and counts invocations.
    struct Counter {
        calls: Arc<AtomicU32>,
    }

    #[async_trait]
    impl Runnable<u32, u32> for Counter {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> Result<u32> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(input * 10)
        }
    }

    #[tokio::test]
    async fn caches_on_repeated_input() {
        let calls = Arc::new(AtomicU32::new(0));
        let backend = Arc::new(MemoryCache::<u32, u32>::new());
        let cached = Cache::new(
            Counter {
                calls: calls.clone(),
            },
            backend,
            |i: &u32| *i,
        );
        let cfg = RunnableConfig::default();
        assert_eq!(cached.invoke(3, cfg.clone()).await.unwrap(), 30);
        assert_eq!(cached.invoke(3, cfg.clone()).await.unwrap(), 30);
        assert_eq!(cached.invoke(4, cfg.clone()).await.unwrap(), 40);
        // 3 ran once and its second call hit the cache; 4 ran once.
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn key_fn_defines_cache_identity() {
        let calls = Arc::new(AtomicU32::new(0));
        let cached = Cache::new(
            Counter {
                calls: calls.clone(),
            },
            Arc::new(MemoryCache::<u32, u32>::new()),
            |i: &u32| *i % 2,
        );
        let cfg = RunnableConfig::default();
        assert_eq!(cached.invoke(2, cfg.clone()).await.unwrap(), 20);
        // 4 maps to the same key (0) as 2, so the stored output for 2 is served.
        assert_eq!(cached.invoke(4, cfg.clone()).await.unwrap(), 20);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn memory_cache_overwrites_existing_key() {
        let cache = MemoryCache::<&'static str, u32>::new();
        assert_eq!(cache.get(&"a").await, None);
        cache.set("a", 1).await;
        cache.set("a", 2).await;
        assert_eq!(cache.get(&"a").await, Some(2));
    }
}