1use std::collections::HashMap;
5use std::hash::Hash;
6use std::sync::{Arc, Mutex};
7use std::time::Instant;
8
/// A cached value together with the instant it was stored.
#[derive(Clone)]
struct CacheEntry<T> {
    value: T,
    timestamp: Instant,
}

/// Synchronous memoizer whose cached results expire `cache_lifetime_ms`
/// milliseconds after they were stored.
///
/// Results are keyed by the full argument value. An expired entry is recomputed on
/// the next call for its key; entries are otherwise only removed by [`Self::clear`].
/// The cache lock is not held while `f` runs, so concurrent misses on the same key
/// may each call `f`, and the result stored last wins.
pub struct MemoizedFunction<Args, Result> {
    f: Arc<dyn Fn(Args) -> Result + Send + Sync>,
    cache: Arc<Mutex<HashMap<Args, CacheEntry<Result>>>>,
    cache_lifetime_ms: u64,
}

impl<Args, Result> MemoizedFunction<Args, Result>
where
    Args: Clone + std::fmt::Debug + Hash + Eq + Send + 'static,
    Result: Clone + Send + 'static,
{
    /// Wraps `f` with a cache whose entries stay fresh for `cache_lifetime_ms` ms.
    pub fn new(
        f: impl Fn(Args) -> Result + Send + Sync + 'static,
        cache_lifetime_ms: u64,
    ) -> Self {
        Self {
            f: Arc::new(f),
            cache: Arc::new(Mutex::new(HashMap::new())),
            cache_lifetime_ms,
        }
    }

    /// Returns the cached result for `args` if it is at most `cache_lifetime_ms`
    /// old; otherwise calls `f`, caches the result and returns it.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex is poisoned.
    pub fn call(&self, args: Args) -> Result {
        let ttl = std::time::Duration::from_millis(self.cache_lifetime_ms);

        // Scope the guard so the lock is released before `f` runs.
        {
            let cache = self.cache.lock().unwrap();
            if let Some(entry) = cache.get(&args) {
                if entry.timestamp.elapsed() <= ttl {
                    return entry.value.clone();
                }
            }
        }

        let value = (self.f)(args.clone());

        // Stamp the entry once the value exists. The lifetime counts from when the
        // result became available, not from when this call started. Otherwise a slow
        // `f` would store entries that are already (partly) expired.
        let entry = CacheEntry {
            value: value.clone(),
            timestamp: Instant::now(),
        };
        self.cache.lock().unwrap().insert(args, entry);

        value
    }

    /// Removes every cached entry.
    ///
    /// # Panics
    ///
    /// Panics if the cache mutex is poisoned.
    pub fn clear(&self) {
        self.cache.lock().unwrap().clear();
    }
}
79
80pub fn memoize_with_ttl<Args, Result>(
82 f: impl Fn(Args) -> Result + Send + Sync + 'static,
83 cache_lifetime_ms: u64,
84) -> MemoizedFunction<Args, Result>
85where
86 Args: Clone + std::fmt::Debug + Hash + Eq + Send + 'static,
87 Result: Clone + Send + 'static,
88{
89 MemoizedFunction::new(f, cache_lifetime_ms)
90}
91
/// One cached async result, stamped with its creation time.
struct AsyncCacheEntry<T> {
    /// Identifier compared by background refreshes to detect whether the entry
    /// was replaced or removed while the refresh was running.
    id: u64,
    value: T,
    timestamp: Instant,
    /// Set while a background refresh for this entry is outstanding.
    refreshing: bool,
}

impl<T> AsyncCacheEntry<T> {
    /// Wraps `value` as a non-refreshing entry stamped with the current instant.
    fn new(value: T, id: u64) -> Self {
        let timestamp = Instant::now();
        AsyncCacheEntry {
            id,
            value,
            timestamp,
            refreshing: false,
        }
    }
}
114
115struct AsyncInner<Args, Result> {
116 f: Arc<
117 dyn Fn(Args) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result> + Send>>
118 + Send
119 + Sync,
120 >,
121 cache: HashMap<Args, AsyncCacheEntry<Result>>,
122 in_flight:
124 HashMap<Args, (Arc<Mutex<Option<Result>>>, Arc<tokio::sync::Notify>)>,
125 cache_lifetime_ms: u64,
126 next_id: u64,
127}
128
129pub struct AsyncMemoized<Args, Result> {
131 inner: Arc<Mutex<AsyncInner<Args, Result>>>,
132}
133
134impl<Args, Result> AsyncMemoized<Args, Result>
135where
136 Args: Clone + std::fmt::Debug + Hash + Eq + Send + 'static,
137 Result: Clone + Send + 'static,
138{
139 pub fn new(
140 f: impl Fn(Args) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result> + Send>>
141 + Send
142 + Sync
143 + 'static,
144 cache_lifetime_ms: u64,
145 ) -> Self {
146 Self {
147 inner: Arc::new(Mutex::new(AsyncInner {
148 f: Arc::new(f),
149 cache: HashMap::new(),
150 in_flight: HashMap::new(),
151 cache_lifetime_ms,
152 next_id: 1,
153 })),
154 }
155 }
156
157 pub async fn call(&self, args: Args) -> Result {
158 let now = Instant::now();
159
160 let maybe_slot_notify = {
162 let inner = self.inner.lock().unwrap();
163 inner.in_flight.get(&args).map(|(s, n)| (s.clone(), n.clone()))
164 };
165 if let Some((slot, notify)) = maybe_slot_notify {
166 notify.notified().await;
167 if let Some(ref result) = *slot.lock().unwrap() {
168 return result.clone();
169 }
170 }
171
172 {
174 let mut inner = self.inner.lock().unwrap();
175 if let Some(cached) = inner.cache.get(&args) {
176 let age = now.duration_since(cached.timestamp).as_millis() as u64;
177
178 if age <= inner.cache_lifetime_ms {
179 return cached.value.clone();
180 }
181
182 if !cached.refreshing {
184 let f = inner.f.clone();
185 let inner_arc = self.inner.clone();
186 let stale_args = args.clone();
187 let stale_id = cached.id;
188
189 tokio::spawn(async move {
190 let new_value = f(stale_args.clone()).await;
191 let mut c = inner_arc.lock().unwrap();
192 if let Some(entry) = c.cache.get(&stale_args) {
193 if entry.id == stale_id {
194 let id = c.next_id + 1;
195 c.next_id = id;
196 c.cache
197 .insert(stale_args, AsyncCacheEntry::new(new_value, id));
198 }
199 }
200 });
201 }
202
203 return cached.value.clone();
204 }
205 }
206
207 let (slot, notify) = (
209 Arc::new(Mutex::new(None)),
210 Arc::new(tokio::sync::Notify::new()),
211 );
212 {
213 let mut inner = self.inner.lock().unwrap();
214 inner.in_flight
215 .insert(args.clone(), (slot.clone(), notify.clone()));
216 }
217
218 let f = self.inner.lock().unwrap().f.clone();
219 let inner_arc = self.inner.clone();
220 let cold_args = args.clone();
221 let result = f(args).await;
222
223 {
225 let mut s = slot.lock().unwrap();
226 *s = Some(result.clone());
227 }
228 notify.notify_one();
229
230 {
232 let mut c = inner_arc.lock().unwrap();
233 c.in_flight.remove(&cold_args);
234 let id = c.next_id + 1;
235 c.next_id = id;
236 c.cache
237 .insert(cold_args, AsyncCacheEntry::new(result.clone(), id));
238 }
239
240 result
241 }
242
243 pub fn clear(&self) {
244 let mut inner = self.inner.lock().unwrap();
245 inner.cache.clear();
246 inner.in_flight.clear();
247 }
248}
249
250pub fn memoize_with_ttl_async<Args, Result>(
252 f: impl Fn(Args) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result> + Send>>
253 + Send
254 + Sync
255 + 'static,
256 cache_lifetime_ms: u64,
257) -> AsyncMemoized<Args, Result>
258where
259 Args: Clone + std::fmt::Debug + Hash + Eq + Send + 'static,
260 Result: Clone + Send + 'static,
261{
262 AsyncMemoized::new(f, cache_lifetime_ms)
263}
264
/// Cached values and their recency order, kept behind one lock so they cannot drift apart.
struct LruState<K, V> {
    values: HashMap<K, V>,
    /// Keys ordered from least recently used (front) to most recently used (back).
    order: std::collections::VecDeque<K>,
}

impl<K: Eq, V> LruState<K, V> {
    /// Removes `key` from the recency order and returns it, if present.
    fn unlink(&mut self, key: &K) -> Option<K> {
        let pos = self.order.iter().position(|k| k == key)?;
        self.order.remove(pos)
    }
}

/// Synchronous memoizer that keeps at most `max_size` results and evicts the
/// least recently used one when full.
///
/// Results are keyed by `key_fn(&args)`, not by the arguments themselves. The
/// internal lock is held while `f` runs, so calls on one instance are serialized.
pub struct LruMemoized<Args, K, Result> {
    f: Arc<dyn Fn(Args) -> Result + Send + Sync + 'static>,
    key_fn: Arc<dyn Fn(&Args) -> K + Send + Sync + 'static>,
    state: Arc<Mutex<LruState<K, Result>>>,
    max_size: usize,
}

impl<Args, K, Result> LruMemoized<Args, K, Result>
where
    Args: std::fmt::Debug + Hash + Eq + Clone,
    Result: Clone,
    K: Hash + Eq + Clone,
{
    /// Wraps `f`, caching up to `max_cache_size` results keyed by `key_fn`.
    /// A capacity of 0 disables caching.
    pub fn new(
        f: impl Fn(Args) -> Result + Send + Sync + 'static,
        key_fn: impl Fn(&Args) -> K + Send + Sync + 'static,
        max_cache_size: usize,
    ) -> Self {
        Self {
            f: Arc::new(f),
            key_fn: Arc::new(key_fn),
            state: Arc::new(Mutex::new(LruState {
                values: HashMap::new(),
                order: std::collections::VecDeque::new(),
            })),
            max_size: max_cache_size,
        }
    }

    /// Returns the cached result for `key_fn(&args)`, marking it most recently used.
    /// On a miss it calls `f`, caches the result (evicting the least recently used
    /// entry if full) and returns it.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned, for example because `f` panicked
    /// during an earlier call.
    pub fn call(&self, args: Args) -> Result {
        let key = (self.key_fn)(&args);
        let mut state = self.state.lock().unwrap();

        if let Some(value) = state.values.get(&key).cloned() {
            if let Some(k) = state.unlink(&key) {
                state.order.push_back(k);
            }
            return value;
        }

        let result = (self.f)(args);

        // Honour a zero capacity: nothing is ever stored.
        if self.max_size == 0 {
            return result;
        }
        if state.values.len() >= self.max_size {
            if let Some(lru_key) = state.order.pop_front() {
                state.values.remove(&lru_key);
            }
        }
        state.values.insert(key.clone(), result.clone());
        state.order.push_back(key);

        result
    }

    /// Removes every cached entry.
    pub fn clear(&self) {
        let mut state = self.state.lock().unwrap();
        state.values.clear();
        state.order.clear();
    }

    /// Number of cached entries.
    pub fn size(&self) -> usize {
        self.state.lock().unwrap().values.len()
    }

    /// Removes `key` from the cache; returns whether it was present.
    pub fn delete(&self, key: &K) -> bool {
        let mut state = self.state.lock().unwrap();
        state.unlink(key);
        state.values.remove(key).is_some()
    }

    /// Returns a clone of the cached value for `key` without updating its recency.
    pub fn get(&self, key: &K) -> Option<Result> {
        self.state.lock().unwrap().values.get(key).cloned()
    }

    /// Returns whether `key` is cached, without updating its recency.
    pub fn has(&self, key: &K) -> bool {
        self.state.lock().unwrap().values.contains_key(key)
    }
}
356
357pub fn memoize_with_lru<Args, K, Result>(
359 f: impl Fn(Args) -> Result + Send + Sync + 'static,
360 key_fn: impl Fn(&Args) -> K + Send + Sync + 'static,
361 max_cache_size: usize,
362) -> LruMemoized<Args, K, Result>
363where
364 Args: std::fmt::Debug + Hash + Eq + Clone,
365 Result: Clone,
366 K: Hash + Eq + Clone,
367{
368 LruMemoized::new(f, key_fn, max_cache_size)
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Async doubling function that increments `calls` every time it actually runs.
    fn counting_doubler(
        calls: Arc<Mutex<i32>>,
    ) -> impl Fn(i32) -> std::pin::Pin<Box<dyn std::future::Future<Output = i32> + Send>>
           + Send
           + Sync
           + 'static {
        move |x: i32| {
            let calls = Arc::clone(&calls);
            Box::pin(async move {
                *calls.lock().unwrap() += 1;
                x * 2
            }) as std::pin::Pin<Box<dyn std::future::Future<Output = i32> + Send>>
        }
    }

    #[test]
    fn test_memoize_with_ttl_basic() {
        let invocations = Arc::new(Mutex::new(0));
        let memoized = memoize_with_ttl(
            move |_x: i32| {
                let mut count = invocations.lock().unwrap();
                *count += 1;
                *count
            },
            1000,
        );

        // A repeat call within the TTL is served from cache, so the count stays at 1.
        for _ in 0..2 {
            assert_eq!(memoized.call(1), 1);
        }
    }

    #[test]
    fn test_memoize_with_lru_basic() {
        let memoized = memoize_with_lru(|x: i32| x * 2, |&x: &i32| x, 2);

        assert_eq!(memoized.call(1), 2);
        assert_eq!(memoized.call(2), 4);
    }

    #[test]
    fn test_lru_eviction() {
        let memoized = memoize_with_lru(|x: i32| x * 2, |&x: &i32| x, 2);

        for &(input, expected) in &[(1, 2), (2, 4), (3, 6)] {
            assert_eq!(memoized.call(input), expected);
        }

        // Capacity is 2, so caching 3 evicts the least recently used key, 1.
        assert!(!memoized.has(&1));
    }

    #[tokio::test]
    async fn test_async_memoize_basic() {
        let calls = Arc::new(Mutex::new(0));
        let memoized = memoize_with_ttl_async(counting_doubler(Arc::clone(&calls)), 1000);

        assert_eq!(memoized.call(1).await, 2);
        assert_eq!(memoized.call(1).await, 2);

        // The second call was a cache hit.
        assert_eq!(*calls.lock().unwrap(), 1);
    }

    #[tokio::test]
    async fn test_async_memoize_clear() {
        let calls = Arc::new(Mutex::new(0));
        let memoized = memoize_with_ttl_async(counting_doubler(Arc::clone(&calls)), 1000);

        assert_eq!(memoized.call(1).await, 2);
        memoized.clear();
        assert_eq!(memoized.call(1).await, 2);

        // Clearing forced a second real invocation.
        assert_eq!(*calls.lock().unwrap(), 2);
    }
}