1use std::collections::HashMap;
4use std::hash::Hash;
5use std::marker::PhantomData;
6use std::sync::Arc;
7
8use async_trait::async_trait;
9use tokio::sync::RwLock;
10
11use crate::runnable::{Runnable, RunnableConfig};
12use crate::Result;
13
14#[async_trait]
19pub trait CacheBackend<K, V>: Send + Sync
20where
21 K: Send + Sync + 'static,
22 V: Send + Sync + 'static,
23{
24 async fn get(&self, key: &K) -> Option<V>;
26 async fn set(&self, key: K, value: V);
28}
29
30pub struct MemoryCache<K, V> {
32 inner: RwLock<HashMap<K, V>>,
33}
34
35impl<K, V> Default for MemoryCache<K, V> {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl<K, V> MemoryCache<K, V> {
42 pub fn new() -> Self {
44 Self {
45 inner: RwLock::new(HashMap::new()),
46 }
47 }
48}
49
50#[async_trait]
51impl<K, V> CacheBackend<K, V> for MemoryCache<K, V>
52where
53 K: Hash + Eq + Send + Sync + Clone + 'static,
54 V: Clone + Send + Sync + 'static,
55{
56 async fn get(&self, key: &K) -> Option<V> {
57 self.inner.read().await.get(key).cloned()
58 }
59 async fn set(&self, key: K, value: V) {
60 self.inner.write().await.insert(key, value);
61 }
62}
63
/// Type-erased function that derives a cache key from a borrowed input.
type KeyFn<I, K> = dyn Fn(&I) -> K + Sync + Send;

/// Runnable wrapper that stores the wrapped runnable's outputs in a
/// [`CacheBackend`], keyed by a caller-supplied function of the input.
pub struct Cache<R, I, O, K, B> {
    /// Runnable that produces values on a cache miss.
    inner: R,
    /// Storage for computed outputs. It is supplied by the caller as an
    /// `Arc`, so the same backend can be handed to several wrappers.
    backend: Arc<B>,
    /// Computes the cache key from each input before the input is consumed.
    key_fn: Arc<KeyFn<I, K>>,
    /// Mentions `I` and `O` without storing them. A `fn(I) -> O` pointer is
    /// always `Send + Sync`, so the wrapper's auto traits don't depend on
    /// `I` or `O`.
    _phantom: PhantomData<fn(I) -> O>,
}
74
75impl<R, I, O, K, B> Cache<R, I, O, K, B>
76where
77 R: Runnable<I, O>,
78 I: Send + 'static,
79 O: Send + Sync + Clone + 'static,
80 K: Send + Sync + 'static,
81 B: CacheBackend<K, O>,
82{
83 pub fn new<F>(inner: R, backend: Arc<B>, key_fn: F) -> Self
85 where
86 F: Fn(&I) -> K + Send + Sync + 'static,
87 {
88 Self {
89 inner,
90 backend,
91 key_fn: Arc::new(key_fn),
92 _phantom: PhantomData,
93 }
94 }
95}
96
97#[async_trait]
98impl<R, I, O, K, B> Runnable<I, O> for Cache<R, I, O, K, B>
99where
100 R: Runnable<I, O>,
101 I: Send + 'static,
102 O: Clone + Send + Sync + 'static,
103 K: Send + Sync + 'static,
104 B: CacheBackend<K, O> + 'static,
105{
106 async fn invoke(&self, input: I, config: RunnableConfig) -> Result<O> {
107 let key = (self.key_fn)(&input);
108 if let Some(hit) = self.backend.get(&key).await {
109 return Ok(hit);
110 }
111 let out = self.inner.invoke(input, config).await?;
112 self.backend.set(key, out.clone()).await;
113 Ok(out)
114 }
115 fn name(&self) -> &str {
116 "Cache"
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Test runnable that multiplies its input by 10 and counts how many times
    /// it actually runs, so tests can tell cache hits from misses.
    struct Counter {
        calls: Arc<AtomicU32>,
    }

    #[async_trait]
    impl Runnable<u32, u32> for Counter {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> Result<u32> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(input * 10)
        }
    }

    /// Builds a `Counter` that reports into the shared `calls` counter.
    fn counter(calls: &Arc<AtomicU32>) -> Counter {
        Counter {
            calls: Arc::clone(calls),
        }
    }

    #[tokio::test]
    async fn caches_on_repeated_input() {
        let calls = Arc::new(AtomicU32::new(0));
        let backend = Arc::new(MemoryCache::<u32, u32>::new());
        let cached = Cache::new(counter(&calls), backend, |i: &u32| *i);
        let cfg = RunnableConfig::default();

        // First call for key 3: miss, inner runnable runs.
        assert_eq!(cached.invoke(3, cfg.clone()).await.unwrap(), 30);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Same key again: served from the backend, inner runnable not called.
        assert_eq!(cached.invoke(3, cfg.clone()).await.unwrap(), 30);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // New key: miss, inner runnable runs once more.
        assert_eq!(cached.invoke(4, cfg).await.unwrap(), 40);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn key_fn_decides_what_counts_as_a_hit() {
        let calls = Arc::new(AtomicU32::new(0));
        let backend = Arc::new(MemoryCache::<u32, u32>::new());
        // Collapse inputs by parity, so 2 and 4 share a cache key.
        let cached = Cache::new(counter(&calls), backend, |i: &u32| *i % 2);
        let cfg = RunnableConfig::default();

        assert_eq!(cached.invoke(2, cfg.clone()).await.unwrap(), 20);
        // 4 maps to the same key as 2, so the value cached for 2 comes back.
        assert_eq!(cached.invoke(4, cfg).await.unwrap(), 20);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn backend_is_shared_between_wrappers() {
        let calls = Arc::new(AtomicU32::new(0));
        let backend = Arc::new(MemoryCache::<u32, u32>::new());
        let first = Cache::new(counter(&calls), Arc::clone(&backend), |i: &u32| *i);
        let second = Cache::new(counter(&calls), backend, |i: &u32| *i);
        let cfg = RunnableConfig::default();

        assert_eq!(first.invoke(7, cfg.clone()).await.unwrap(), 70);
        // The second wrapper sees the entry that was written through the first.
        assert_eq!(second.invoke(7, cfg).await.unwrap(), 70);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}