Skip to main content

ai_agent/utils/
memoize.rs

1// Source: /data/home/swei/claudecode/openclaudecode/src/utils/memoize.ts
2//! Memoization utilities with TTL and LRU support.
3
4use std::collections::HashMap;
5use std::hash::Hash;
6use std::sync::{Arc, Mutex};
7use std::time::Instant;
8
/// A single memoized result together with the instant its computation started.
#[derive(Clone)]
struct CacheEntry<T> {
    value: T,
    timestamp: Instant,
}

/// A synchronous memoized function whose cached results expire after a time-to-live (TTL).
///
/// Behaviour of [`MemoizedFunction::call`]:
/// - If a cached value exists and is no older than `cache_lifetime_ms`, it is returned
///   without invoking the wrapped function.
/// - Otherwise (no entry, or the entry has expired) the wrapped function is invoked on the
///   calling thread, its result is cached, and that fresh result is returned.
///
/// Stale values are never returned and nothing is refreshed in the background. The cache
/// lock is not held while the wrapped function runs, so concurrent callers that miss on the
/// same arguments may each compute the value (the last one to finish is stored). Entries are
/// only replaced or removed by [`MemoizedFunction::clear`]; they are never evicted.
pub struct MemoizedFunction<Args, Result> {
    f: Arc<dyn Fn(Args) -> Result + Send + Sync>,
    cache: Arc<Mutex<HashMap<Args, CacheEntry<Result>>>>,
    cache_lifetime_ms: u64,
}

impl<Args, Result> MemoizedFunction<Args, Result>
where
    Args: Clone + std::fmt::Debug + Hash + Eq + Send + 'static,
    Result: Clone + Send + 'static,
{
    /// Wraps `f` so that each result is cached for `cache_lifetime_ms` milliseconds.
    pub fn new(
        f: impl Fn(Args) -> Result + Send + Sync + 'static,
        cache_lifetime_ms: u64,
    ) -> Self {
        Self {
            f: Arc::new(f),
            cache: Arc::new(Mutex::new(HashMap::new())),
            cache_lifetime_ms,
        }
    }

    /// Returns the cached result for `args` if it is still fresh, otherwise recomputes it.
    ///
    /// The stored timestamp is taken when the call starts, so time spent inside the wrapped
    /// function counts toward the entry's TTL.
    ///
    /// # Panics
    ///
    /// Panics if the internal cache mutex is poisoned. A panic inside the wrapped function
    /// propagates to the caller and leaves the cache unchanged.
    pub fn call(&self, args: Args) -> Result {
        let now = Instant::now();

        {
            let cache = self.cache.lock().unwrap();
            if let Some(cached) = cache.get(&args) {
                let age_ms = now.duration_since(cached.timestamp).as_millis() as u64;
                if age_ms <= self.cache_lifetime_ms {
                    return cached.value.clone();
                }
            }
        } // Lock released here: `f` may be slow and may call back into this memoizer.

        let value = (self.f)(args.clone());

        self.cache.lock().unwrap().insert(
            args,
            CacheEntry {
                value: value.clone(),
                timestamp: now,
            },
        );

        value
    }

    /// Removes every cached entry so the next call for any arguments recomputes.
    ///
    /// A `call` that is already computing when this runs still stores its result afterwards.
    pub fn clear(&self) {
        self.cache.lock().unwrap().clear();
    }
}
79
80/// Creates a memoized function that returns cached values while refreshing in parallel.
81pub fn memoize_with_ttl<Args, Result>(
82    f: impl Fn(Args) -> Result + Send + Sync + 'static,
83    cache_lifetime_ms: u64,
84) -> MemoizedFunction<Args, Result>
85where
86    Args: Clone + std::fmt::Debug + Hash + Eq + Send + 'static,
87    Result: Clone + Send + 'static,
88{
89    MemoizedFunction::new(f, cache_lifetime_ms)
90}
91
92// ============================================================================
93// Async memoization
94// ============================================================================
95
/// One cached result of an async memoized function.
///
/// `id` differs for every inserted entry, which lets a background task check whether the
/// entry it set out to refresh is still the one in the cache (it may have been cleared or
/// replaced in the meantime).
struct AsyncCacheEntry<T> {
    id: u64,
    value: T,
    timestamp: Instant,
    refreshing: bool,
}

impl<T> AsyncCacheEntry<T> {
    /// Builds an entry stamped with the current instant and with no refresh in progress.
    fn new(value: T, id: u64) -> Self {
        let timestamp = Instant::now();
        AsyncCacheEntry {
            id,
            value,
            timestamp,
            refreshing: false,
        }
    }
}
114
115struct AsyncInner<Args, Result> {
116    f: Arc<
117        dyn Fn(Args) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result> + Send>>
118            + Send
119            + Sync,
120    >,
121    cache: HashMap<Args, AsyncCacheEntry<Result>>,
122    /// In-flight cold-miss dedup: shared slot + notify when result arrives
123    in_flight:
124        HashMap<Args, (Arc<Mutex<Option<Result>>>, Arc<tokio::sync::Notify>)>,
125    cache_lifetime_ms: u64,
126    next_id: u64,
127}
128
129/// Async memoized function with background refresh and cold-miss dedup.
130pub struct AsyncMemoized<Args, Result> {
131    inner: Arc<Mutex<AsyncInner<Args, Result>>>,
132}
133
134impl<Args, Result> AsyncMemoized<Args, Result>
135where
136    Args: Clone + std::fmt::Debug + Hash + Eq + Send + 'static,
137    Result: Clone + Send + 'static,
138{
139    pub fn new(
140        f: impl Fn(Args) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result> + Send>>
141            + Send
142            + Sync
143            + 'static,
144        cache_lifetime_ms: u64,
145    ) -> Self {
146        Self {
147            inner: Arc::new(Mutex::new(AsyncInner {
148                f: Arc::new(f),
149                cache: HashMap::new(),
150                in_flight: HashMap::new(),
151                cache_lifetime_ms,
152                next_id: 1,
153            })),
154        }
155    }
156
157    pub async fn call(&self, args: Args) -> Result {
158        let now = Instant::now();
159
160        // 1. Check for in-flight dedup - another caller already computing
161        let maybe_slot_notify = {
162            let inner = self.inner.lock().unwrap();
163            inner.in_flight.get(&args).map(|(s, n)| (s.clone(), n.clone()))
164        };
165        if let Some((slot, notify)) = maybe_slot_notify {
166            notify.notified().await;
167            if let Some(ref result) = *slot.lock().unwrap() {
168                return result.clone();
169            }
170        }
171
172        // 2. Check cache
173        {
174            let mut inner = self.inner.lock().unwrap();
175            if let Some(cached) = inner.cache.get(&args) {
176                let age = now.duration_since(cached.timestamp).as_millis() as u64;
177
178                if age <= inner.cache_lifetime_ms {
179                    return cached.value.clone();
180                }
181
182                // Stale - return stale value and refresh in background
183                if !cached.refreshing {
184                    let f = inner.f.clone();
185                    let inner_arc = self.inner.clone();
186                    let stale_args = args.clone();
187                    let stale_id = cached.id;
188
189                    tokio::spawn(async move {
190                        let new_value = f(stale_args.clone()).await;
191                        let mut c = inner_arc.lock().unwrap();
192                        if let Some(entry) = c.cache.get(&stale_args) {
193                            if entry.id == stale_id {
194                                let id = c.next_id + 1;
195                                c.next_id = id;
196                                c.cache
197                                    .insert(stale_args, AsyncCacheEntry::new(new_value, id));
198                            }
199                        }
200                    });
201                }
202
203                return cached.value.clone();
204            }
205        }
206
207        // 3. Cold miss - spawn task and wait
208        let (slot, notify) = (
209            Arc::new(Mutex::new(None)),
210            Arc::new(tokio::sync::Notify::new()),
211        );
212        {
213            let mut inner = self.inner.lock().unwrap();
214            inner.in_flight
215                .insert(args.clone(), (slot.clone(), notify.clone()));
216        }
217
218        let f = self.inner.lock().unwrap().f.clone();
219        let inner_arc = self.inner.clone();
220        let cold_args = args.clone();
221        let result = f(args).await;
222
223        // Store in shared slot for dedup
224        {
225            let mut s = slot.lock().unwrap();
226            *s = Some(result.clone());
227        }
228        notify.notify_one();
229
230        // Remove in-flight and store in cache
231        {
232            let mut c = inner_arc.lock().unwrap();
233            c.in_flight.remove(&cold_args);
234            let id = c.next_id + 1;
235            c.next_id = id;
236            c.cache
237                .insert(cold_args, AsyncCacheEntry::new(result.clone(), id));
238        }
239
240        result
241    }
242
243    pub fn clear(&self) {
244        let mut inner = self.inner.lock().unwrap();
245        inner.cache.clear();
246        inner.in_flight.clear();
247    }
248}
249
250/// Creates a memoized async function that returns cached values while refreshing in parallel.
251pub fn memoize_with_ttl_async<Args, Result>(
252    f: impl Fn(Args) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result> + Send>>
253        + Send
254        + Sync
255        + 'static,
256    cache_lifetime_ms: u64,
257) -> AsyncMemoized<Args, Result>
258where
259    Args: Clone + std::fmt::Debug + Hash + Eq + Send + 'static,
260    Result: Clone + Send + 'static,
261{
262    AsyncMemoized::new(f, cache_lifetime_ms)
263}
264
265// ============================================================================
266// LRU memoization
267// ============================================================================
268
/// A memoized function backed by a bounded least-recently-used (LRU) cache.
///
/// Results are stored under `key_fn(&args)`. Before a new key is inserted into a full cache,
/// the least recently used entries are evicted, so at most `max_size` results are retained.
/// A `max_size` of 0 disables caching entirely: every call recomputes.
pub struct LruMemoized<Args, K, Result> {
    /// The wrapped function.
    f: Arc<dyn Fn(Args) -> Result + Send + Sync + 'static>,
    /// Cached results by key.
    cache: Arc<Mutex<HashMap<K, Result>>>,
    /// Cached keys ordered from least (front) to most (back) recently used; holds exactly
    /// the keys present in `cache`.
    order: Arc<Mutex<Vec<K>>>,
    /// Maximum number of cached results.
    max_size: usize,
    /// Derives the cache key from the call arguments.
    key_fn: Arc<dyn Fn(&Args) -> K + Send + Sync + 'static>,
}

impl<Args, K, Result> LruMemoized<Args, K, Result>
where
    Args: std::fmt::Debug + Hash + Eq + Clone,
    Result: Clone,
    K: Hash + Eq + Clone,
{
    /// Wraps `f` with an LRU cache of at most `max_cache_size` results keyed by `key_fn`.
    pub fn new(
        f: impl Fn(Args) -> Result + Send + Sync + 'static,
        key_fn: impl Fn(&Args) -> K + Send + Sync + 'static,
        max_cache_size: usize,
    ) -> Self {
        Self {
            f: Arc::new(f),
            cache: Arc::new(Mutex::new(HashMap::new())),
            order: Arc::new(Mutex::new(Vec::new())),
            max_size: max_cache_size,
            key_fn: Arc::new(key_fn),
        }
    }

    /// Returns the cached result for `key_fn(&args)`, computing and caching it on a miss.
    ///
    /// A hit marks the key as most recently used. The wrapped function runs with no lock
    /// held, so a panic in `f` leaves the cache usable and `f` may call back into this
    /// memoizer. As a consequence, concurrent misses on the same key may each compute it
    /// (the last result stored wins).
    ///
    /// # Panics
    ///
    /// Propagates panics from `f` and from `key_fn`.
    pub fn call(&self, args: Args) -> Result {
        let key = (self.key_fn)(&args);

        // Lock order is always `cache` then `order`, in every method, to avoid deadlocks.
        {
            let cache = self.cache.lock().unwrap();
            if let Some(value) = cache.get(&key) {
                Self::touch(&mut self.order.lock().unwrap(), &key);
                return value.clone();
            }
        }

        let result = (self.f)(args);

        if self.max_size == 0 {
            return result;
        }

        let mut cache = self.cache.lock().unwrap();
        let mut order = self.order.lock().unwrap();
        if cache.contains_key(&key) {
            // A concurrent call stored this key while we were computing: overwrite its value
            // and refresh its recency instead of evicting anything.
            Self::touch(&mut order, &key);
        } else {
            // Evict least-recently-used keys until there is room for one more.
            while cache.len() >= self.max_size && !order.is_empty() {
                let lru = order.remove(0);
                cache.remove(&lru);
            }
            order.push(key.clone());
        }
        cache.insert(key, result.clone());

        result
    }

    /// Moves `key` to the most-recently-used end of `order`; does nothing if it is absent.
    fn touch(order: &mut Vec<K>, key: &K) {
        if let Some(pos) = order.iter().position(|k| k == key) {
            let k = order.remove(pos);
            order.push(k);
        }
    }

    /// Removes every cached result.
    pub fn clear(&self) {
        let mut cache = self.cache.lock().unwrap();
        let mut order = self.order.lock().unwrap();
        cache.clear();
        order.clear();
    }

    /// Number of results currently cached.
    pub fn size(&self) -> usize {
        self.cache.lock().unwrap().len()
    }

    /// Removes `key` from the cache, returning whether it was present.
    pub fn delete(&self, key: &K) -> bool {
        let mut cache = self.cache.lock().unwrap();
        let mut order = self.order.lock().unwrap();
        order.retain(|k| k != key);
        cache.remove(key).is_some()
    }

    /// Returns a clone of the cached value for `key` without changing its recency.
    pub fn get(&self, key: &K) -> Option<Result> {
        self.cache.lock().unwrap().get(key).cloned()
    }

    /// Whether `key` is cached; does not change its recency.
    pub fn has(&self, key: &K) -> bool {
        self.cache.lock().unwrap().contains_key(key)
    }
}
356
357/// Creates a memoized function with LRU (Least Recently Used) eviction policy.
358pub fn memoize_with_lru<Args, K, Result>(
359    f: impl Fn(Args) -> Result + Send + Sync + 'static,
360    key_fn: impl Fn(&Args) -> K + Send + Sync + 'static,
361    max_cache_size: usize,
362) -> LruMemoized<Args, K, Result>
363where
364    Args: std::fmt::Debug + Hash + Eq + Clone,
365    Result: Clone,
366    K: Hash + Eq + Clone,
367{
368    LruMemoized::new(f, key_fn, max_cache_size)
369}
370
#[cfg(test)]
mod tests {
    use super::*;

    /// Boxed future type expected by `memoize_with_ttl_async` for `i32` results.
    type BoxedI32Future = std::pin::Pin<Box<dyn std::future::Future<Output = i32> + Send>>;

    /// An async "double the input" function that bumps `calls` every time it actually runs.
    fn counting_doubler(
        calls: Arc<Mutex<i32>>,
    ) -> impl Fn(i32) -> BoxedI32Future + Send + Sync + 'static {
        move |x: i32| {
            let calls = Arc::clone(&calls);
            Box::pin(async move {
                *calls.lock().unwrap() += 1;
                x * 2
            }) as BoxedI32Future
        }
    }

    #[test]
    fn test_memoize_with_ttl_basic() {
        let counter = Arc::new(Mutex::new(0));
        let memoized = memoize_with_ttl(
            move |_x: i32| {
                let mut calls = counter.lock().unwrap();
                *calls += 1;
                *calls
            },
            1000,
        );

        // The second call falls within the TTL, so it sees the first call's result.
        assert_eq!(memoized.call(1), 1);
        assert_eq!(memoized.call(1), 1);
    }

    #[test]
    fn test_memoize_with_lru_basic() {
        let memoized = memoize_with_lru(|x: i32| x * 2, |&x: &i32| x, 2);

        assert_eq!(memoized.call(1), 2);
        assert_eq!(memoized.call(2), 4);
    }

    #[test]
    fn test_lru_eviction() {
        let memoized = memoize_with_lru(|x: i32| x * 2, |&x: &i32| x, 2);

        for &(input, expected) in &[(1, 2), (2, 4), (3, 6)] {
            assert_eq!(memoized.call(input), expected);
        }

        // Capacity is 2, so storing key 3 evicted the least recently used key, 1.
        assert!(!memoized.has(&1));
    }

    #[tokio::test]
    async fn test_async_memoize_basic() {
        let counter = Arc::new(Mutex::new(0));
        let memoized = memoize_with_ttl_async(counting_doubler(Arc::clone(&counter)), 1000);

        assert_eq!(memoized.call(1).await, 2);
        assert_eq!(memoized.call(1).await, 2);

        // Only the first call reached the underlying function.
        assert_eq!(*counter.lock().unwrap(), 1);
    }

    #[tokio::test]
    async fn test_async_memoize_clear() {
        let counter = Arc::new(Mutex::new(0));
        let memoized = memoize_with_ttl_async(counting_doubler(Arc::clone(&counter)), 1000);

        assert_eq!(memoized.call(1).await, 2);
        memoized.clear();
        assert_eq!(memoized.call(1).await, 2);

        // Clearing forced a second real invocation.
        assert_eq!(*counter.lock().unwrap(), 2);
    }
}