Skip to main content

leenfetch_core/
cache.rs

1use std::sync::Mutex;
2use std::time::{Duration, Instant};
3
/// Default lifetime, in seconds, of entries in a cache built with
/// [`Cache::default_ttl`] or via [`Default`].
const DEFAULT_TTL_SECS: u64 = 5;

/// String-keyed memoization cache whose entries expire after a fixed
/// time-to-live (TTL).
///
/// All methods take `&self`; the entry map is guarded by a [`Mutex`].
/// Expired entries are discarded lazily: only when their key is looked up
/// again, or through [`Cache::invalidate`] / [`Cache::clear`].
pub struct Cache<T> {
    entries: Mutex<EntryMap<T>>,
    ttl: Duration,
}

/// Backing storage: key -> cached entry.
type EntryMap<T> = std::collections::HashMap<String, CacheEntry<T>>;

/// A cached value together with the instant it was stored.
struct CacheEntry<T> {
    value: T,
    /// When `value` was inserted into the cache.
    inserted_at: Instant,
}

impl<T> CacheEntry<T> {
    /// Returns `true` while less than `ttl` has elapsed between insertion and `now`.
    ///
    /// Comparing elapsed time against the TTL, instead of precomputing
    /// `inserted_at + ttl`, cannot overflow, so any TTL (even `u64::MAX`
    /// seconds) is accepted without panicking.
    fn is_fresh(&self, now: Instant, ttl: Duration) -> bool {
        now.saturating_duration_since(self.inserted_at) < ttl
    }
}

impl<T> Cache<T> {
    /// Creates an empty cache whose entries live for `ttl_secs` seconds.
    ///
    /// A TTL of `0` effectively disables caching: every lookup recomputes.
    pub fn new(ttl_secs: u64) -> Self {
        Self::with_ttl(Duration::from_secs(ttl_secs))
    }

    /// Creates an empty cache whose entries live for `ttl`.
    ///
    /// Unlike [`Cache::new`], this accepts sub-second TTLs.
    pub fn with_ttl(ttl: Duration) -> Self {
        Self {
            entries: Mutex::new(EntryMap::new()),
            ttl,
        }
    }

    /// Creates an empty cache with a TTL of [`DEFAULT_TTL_SECS`] seconds.
    pub fn default_ttl() -> Self {
        Self::new(DEFAULT_TTL_SECS)
    }

    /// Returns the cached value for `key` if it is still fresh. Otherwise calls
    /// `compute`, caches its result under `key`, and returns it.
    ///
    /// `compute` runs without the lock held, so a slow computation does not
    /// block other callers. As a consequence, concurrent misses on the same key
    /// may each run their own `compute`; whichever finishes last is the value
    /// left in the cache.
    ///
    /// # Panics
    ///
    /// Propagates any panic raised by `compute` or by `T::clone`; in that case
    /// no new value is cached.
    pub fn get_or_compute<F>(&self, key: &str, compute: F) -> T
    where
        F: FnOnce() -> T,
        T: Clone,
    {
        // Fast path: serve a fresh hit, holding the lock only for the lookup.
        {
            let mut entries = self.lock_entries();
            if let Some(entry) = entries.get(key) {
                if entry.is_fresh(Instant::now(), self.ttl) {
                    return entry.value.clone();
                }
                // Stale: drop it now so it does not linger if `compute` panics.
                entries.remove(key);
            }
        }

        // Slow path: compute outside the lock.
        let value = compute();

        // The TTL starts once the value exists, so a slow `compute` neither
        // shortens the entry's lifetime nor stores an already-expired entry.
        let entry = CacheEntry {
            value: value.clone(),
            inserted_at: Instant::now(),
        };
        self.lock_entries().insert(key.to_owned(), entry);

        value
    }

    /// Removes the entry for `key`, if present, so the next lookup recomputes it.
    pub fn invalidate(&self, key: &str) {
        self.lock_entries().remove(key);
    }

    /// Removes every entry from the cache.
    pub fn clear(&self) {
        self.lock_entries().clear();
    }

    /// Locks the entry map, recovering it if the lock is poisoned.
    ///
    /// Poisoning is ignored so that a panic in one caller (for example inside
    /// `T::clone` or a `compute` closure's caller) does not make the cache
    /// unusable for everyone else.
    fn lock_entries(&self) -> std::sync::MutexGuard<'_, EntryMap<T>> {
        self.entries
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl<T> Default for Cache<T> {
    /// Equivalent to [`Cache::default_ttl`].
    fn default() -> Self {
        Self::default_ttl()
    }
}
81
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    #[test]
    fn test_cache_basic() {
        // Generous TTL so the second lookup can never race an expiry on a slow machine.
        let cache = Cache::new(60);
        let call_count = Cell::new(0);

        let get_value = || {
            call_count.set(call_count.get() + 1);
            42
        };

        // First call computes.
        assert_eq!(cache.get_or_compute("key", get_value), 42);
        assert_eq!(call_count.get(), 1);

        // Second call is served from the cache.
        assert_eq!(cache.get_or_compute("key", get_value), 42);
        assert_eq!(call_count.get(), 1);
    }

    #[test]
    fn test_cache_expiry() {
        // A zero TTL makes every entry stale immediately, which exercises the
        // expiry path without sleeping.
        let cache = Cache::new(0);
        let call_count = Cell::new(0);

        let get_value = || {
            call_count.set(call_count.get() + 1);
            42
        };

        assert_eq!(cache.get_or_compute("key", get_value), 42);
        assert_eq!(call_count.get(), 1);

        // The previous entry has already expired, so this recomputes.
        assert_eq!(cache.get_or_compute("key", get_value), 42);
        assert_eq!(call_count.get(), 2);
    }

    #[test]
    fn test_keys_are_independent() {
        let cache = Cache::new(60);
        assert_eq!(cache.get_or_compute("a", || 1), 1);
        assert_eq!(cache.get_or_compute("b", || 2), 2);
        // "a" is still cached; its compute closure must not run.
        assert_eq!(cache.get_or_compute("a", || 3), 1);
    }

    #[test]
    fn test_invalidate() {
        let cache = Cache::new(60);
        assert_eq!(cache.get_or_compute("key", || 1), 1);
        cache.invalidate("key");
        assert_eq!(cache.get_or_compute("key", || 2), 2);
        // Invalidating a missing key is a no-op.
        cache.invalidate("missing");
        assert_eq!(cache.get_or_compute("key", || 3), 2);
    }

    #[test]
    fn test_clear() {
        let cache = Cache::new(60);
        assert_eq!(cache.get_or_compute("a", || 1), 1);
        assert_eq!(cache.get_or_compute("b", || 2), 2);
        cache.clear();
        assert_eq!(cache.get_or_compute("a", || 3), 3);
        assert_eq!(cache.get_or_compute("b", || 4), 4);
    }
}