Skip to main content

leenfetch_core/
cache.rs

1use std::sync::Mutex;
2use std::time::{Duration, Instant};
3
/// TTL, in seconds, used by `Cache::default_ttl` (and therefore by the
/// `Default` impl).
const DEFAULT_TTL_SECS: u64 = 5;

/// One stored value plus the instant at which it stops being served.
struct CacheEntry<T> {
    expires_at: Instant,
    value: T,
}

/// String-keyed cache whose entries expire after a fixed `ttl`.
///
/// The entry map sits behind a `Mutex`, so all operations work through
/// `&self`.
pub struct Cache<T> {
    ttl: Duration,
    entries: Mutex<std::collections::HashMap<String, CacheEntry<T>>>,
}
15
16impl<T> Cache<T> {
17    pub fn new(ttl_secs: u64) -> Self {
18        Self {
19            entries: Mutex::new(std::collections::HashMap::new()),
20            ttl: Duration::from_secs(ttl_secs),
21        }
22    }
23
24    pub fn default_ttl() -> Self {
25        Self::new(DEFAULT_TTL_SECS)
26    }
27
28    pub fn get_or_compute<F>(&self, key: &str, compute: F) -> T
29    where
30        F: FnOnce() -> T,
31        T: Clone,
32    {
33        let now = Instant::now();
34
35        // Try to get from cache first (without lock)
36        {
37            let entries = &mut *self.entries.lock().unwrap_or_else(|e| e.into_inner());
38
39            if let Some(entry) = entries.get(key) {
40                if entry.expires_at > now {
41                    return entry.value.clone();
42                }
43                // Entry expired
44                entries.remove(key);
45            }
46        }
47
48        // Compute new value (outside of lock)
49        let value = compute();
50
51        // Store in cache
52        {
53            let entries = &mut *self.entries.lock().unwrap_or_else(|e| e.into_inner());
54            let expires_at = now + self.ttl;
55            entries.insert(
56                key.to_string(),
57                CacheEntry {
58                    value: value.clone(),
59                    expires_at,
60                },
61            );
62        }
63
64        value
65    }
66
67    pub fn invalidate(&self, key: &str) {
68        self.entries
69            .lock()
70            .unwrap_or_else(|e| e.into_inner())
71            .remove(key);
72    }
73
74    pub fn clear(&self) {
75        self.entries
76            .lock()
77            .unwrap_or_else(|e| e.into_inner())
78            .clear();
79    }
80}
81
82impl<T> Default for Cache<T> {
83    fn default() -> Self {
84        Self::default_ttl()
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_cache_basic() {
        let cache = Cache::new(1);
        let call_count = Cell::new(0);

        let get_value = || {
            call_count.set(call_count.get() + 1);
            42
        };

        // First call computes.
        assert_eq!(cache.get_or_compute("key", get_value), 42);
        assert_eq!(call_count.get(), 1);

        // Second call is served from the cache.
        assert_eq!(cache.get_or_compute("key", get_value), 42);
        assert_eq!(call_count.get(), 1);
    }

    /// Exercises a real (1 s) TTL, so it has to sleep past it;
    /// `test_zero_ttl_never_serves_cached` covers the expiry path instantly.
    #[test]
    fn test_cache_expiry() {
        let cache = Cache::new(1);
        let call_count = Cell::new(0);

        let get_value = || {
            call_count.set(call_count.get() + 1);
            42
        };

        assert_eq!(cache.get_or_compute("key", get_value), 42);
        assert_eq!(call_count.get(), 1);

        // Wait past the TTL.
        thread::sleep(Duration::from_millis(1100));

        // The entry has expired, so the value is computed again.
        assert_eq!(cache.get_or_compute("key", get_value), 42);
        assert_eq!(call_count.get(), 2);
    }

    #[test]
    fn test_zero_ttl_never_serves_cached() {
        let cache = Cache::new(0);
        let call_count = Cell::new(0);

        // Return the running count so each recomputation yields a new value.
        let next = || {
            call_count.set(call_count.get() + 1);
            call_count.get()
        };

        assert_eq!(cache.get_or_compute("key", next), 1);
        assert_eq!(cache.get_or_compute("key", next), 2);
    }

    #[test]
    fn test_keys_are_independent() {
        let cache = Cache::new(60);
        let call_count = Cell::new(0);
        let next = || {
            call_count.set(call_count.get() + 1);
            call_count.get()
        };

        assert_eq!(cache.get_or_compute("a", next), 1);
        assert_eq!(cache.get_or_compute("b", next), 2);
        assert_eq!(cache.get_or_compute("a", next), 1);
        assert_eq!(call_count.get(), 2);
    }

    #[test]
    fn test_invalidate_forces_recompute() {
        let cache = Cache::new(60);
        let call_count = Cell::new(0);
        let next = || {
            call_count.set(call_count.get() + 1);
            call_count.get()
        };

        assert_eq!(cache.get_or_compute("a", next), 1);
        assert_eq!(cache.get_or_compute("b", next), 2);

        cache.invalidate("a");
        // Invalidating an absent key is a no-op.
        cache.invalidate("missing");

        assert_eq!(cache.get_or_compute("a", next), 3);
        // Other keys are untouched.
        assert_eq!(cache.get_or_compute("b", next), 2);
    }

    #[test]
    fn test_clear_drops_every_key() {
        let cache = Cache::new(60);
        let call_count = Cell::new(0);
        let next = || {
            call_count.set(call_count.get() + 1);
            call_count.get()
        };

        assert_eq!(cache.get_or_compute("a", next), 1);
        assert_eq!(cache.get_or_compute("b", next), 2);

        cache.clear();

        assert_eq!(cache.get_or_compute("a", next), 3);
        assert_eq!(cache.get_or_compute("b", next), 4);
    }

    #[test]
    fn test_default_uses_default_ttl() {
        let cache: Cache<u8> = Cache::default();
        assert_eq!(cache.ttl, Duration::from_secs(DEFAULT_TTL_SECS));
    }
}