Skip to main content

lean_ctx/core/providers/
cache.rs

1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::{Duration, Instant, SystemTime};
4
5static PROVIDER_CACHE: std::sync::LazyLock<Mutex<ProviderCache>> =
6    std::sync::LazyLock::new(|| Mutex::new(ProviderCache::new()));
7
/// One cached payload together with its bookkeeping.
struct CacheEntry {
    /// Provider that stored the entry; hits, invalidation and per-provider
    /// metrics are all keyed on this value.
    provider_id: String,
    /// Payload handed back verbatim on a cache hit.
    data: String,
    /// Monotonic deadline; the entry is treated as expired once `now` reaches it.
    expires_at: Instant,
    // Written on every insert but never read by the code in this module;
    // kept for diagnostics, hence the lint allowance.
    #[allow(dead_code)]
    created_at: SystemTime,
}
15
/// Cache statistics for a single provider.
#[derive(Debug, Clone, Default)]
pub struct ProviderCacheStats {
    /// Provider these numbers belong to.
    pub provider_id: String,
    /// Lookups that found a live entry.
    pub hits: u64,
    /// Lookups that found no live entry.
    pub misses: u64,
    /// Number of entries currently stored for this provider.
    pub entry_count: usize,
    /// When this provider last stored data in the cache, if ever.
    pub last_fetch: Option<SystemTime>,
}

impl ProviderCacheStats {
    /// Fraction of lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
35
36/// Global cache statistics across all providers.
37#[derive(Debug, Clone, Default)]
38pub struct CacheMetrics {
39    pub total_hits: u64,
40    pub total_misses: u64,
41    pub total_entries: usize,
42    pub provider_stats: Vec<ProviderCacheStats>,
43}
44
45impl CacheMetrics {
46    pub fn total_hit_rate(&self) -> f64 {
47        let total = self.total_hits + self.total_misses;
48        if total == 0 {
49            return 0.0;
50        }
51        self.total_hits as f64 / total as f64
52    }
53}
54
55struct ProviderCache {
56    entries: HashMap<String, CacheEntry>,
57    hits: HashMap<String, u64>,
58    misses: HashMap<String, u64>,
59    last_fetch: HashMap<String, SystemTime>,
60}
61
62impl ProviderCache {
63    fn new() -> Self {
64        Self {
65            entries: HashMap::new(),
66            hits: HashMap::new(),
67            misses: HashMap::new(),
68            last_fetch: HashMap::new(),
69        }
70    }
71
72    fn get(&mut self, key: &str) -> Option<&str> {
73        self.entries.retain(|_, v| v.expires_at > Instant::now());
74        if let Some(entry) = self.entries.get(key) {
75            *self.hits.entry(entry.provider_id.clone()).or_insert(0) += 1;
76            Some(entry.data.as_str())
77        } else {
78            let provider = key.split(':').next().unwrap_or("unknown");
79            *self.misses.entry(provider.to_string()).or_insert(0) += 1;
80            None
81        }
82    }
83
84    fn set(&mut self, key: String, data: String, ttl: Duration, provider_id: &str) {
85        let now = SystemTime::now();
86        self.last_fetch.insert(provider_id.to_string(), now);
87        self.entries.insert(
88            key,
89            CacheEntry {
90                data,
91                expires_at: Instant::now() + ttl,
92                created_at: now,
93                provider_id: provider_id.to_string(),
94            },
95        );
96    }
97
98    fn invalidate_provider(&mut self, provider_id: &str) -> usize {
99        let before = self.entries.len();
100        self.entries.retain(|_, v| v.provider_id != provider_id);
101        before - self.entries.len()
102    }
103
104    fn invalidate_all(&mut self) -> usize {
105        let count = self.entries.len();
106        self.entries.clear();
107        count
108    }
109
110    fn metrics(&mut self) -> CacheMetrics {
111        self.entries.retain(|_, v| v.expires_at > Instant::now());
112
113        let mut by_provider: HashMap<String, ProviderCacheStats> = HashMap::new();
114
115        for entry in self.entries.values() {
116            let stats = by_provider.entry(entry.provider_id.clone()).or_default();
117            stats.provider_id.clone_from(&entry.provider_id);
118            stats.entry_count += 1;
119        }
120
121        for (pid, &count) in &self.hits {
122            let stats = by_provider.entry(pid.clone()).or_default();
123            stats.provider_id.clone_from(pid);
124            stats.hits = count;
125        }
126        for (pid, &count) in &self.misses {
127            let stats = by_provider.entry(pid.clone()).or_default();
128            stats.provider_id.clone_from(pid);
129            stats.misses = count;
130        }
131        for (pid, &ts) in &self.last_fetch {
132            let stats = by_provider.entry(pid.clone()).or_default();
133            stats.provider_id.clone_from(pid);
134            stats.last_fetch = Some(ts);
135        }
136
137        let mut provider_stats: Vec<_> = by_provider.into_values().collect();
138        provider_stats.sort_by(|a, b| a.provider_id.cmp(&b.provider_id));
139
140        CacheMetrics {
141            total_hits: self.hits.values().sum(),
142            total_misses: self.misses.values().sum(),
143            total_entries: self.entries.len(),
144            provider_stats,
145        }
146    }
147}
148
149pub fn get_cached(key: &str) -> Option<String> {
150    PROVIDER_CACHE
151        .lock()
152        .ok()
153        .and_then(|mut c| c.get(key).map(std::string::ToString::to_string))
154}
155
156pub fn set_cached(key: &str, data: &str, ttl_secs: u64) {
157    set_cached_with_provider(
158        key,
159        data,
160        ttl_secs,
161        key.split(':').next().unwrap_or("unknown"),
162    );
163}
164
165pub fn set_cached_with_provider(key: &str, data: &str, ttl_secs: u64, provider_id: &str) {
166    if let Ok(mut cache) = PROVIDER_CACHE.lock() {
167        cache.set(
168            key.to_string(),
169            data.to_string(),
170            Duration::from_secs(ttl_secs),
171            provider_id,
172        );
173    }
174}
175
176pub fn invalidate_provider(provider_id: &str) -> usize {
177    PROVIDER_CACHE
178        .lock()
179        .ok()
180        .map_or(0, |mut c| c.invalidate_provider(provider_id))
181}
182
183pub fn invalidate_all() -> usize {
184    PROVIDER_CACHE
185        .lock()
186        .ok()
187        .map_or(0, |mut c| c.invalidate_all())
188}
189
190pub fn cache_metrics() -> CacheMetrics {
191    PROVIDER_CACHE
192        .lock()
193        .ok()
194        .map(|mut c| c.metrics())
195        .unwrap_or_default()
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// TTL long enough that nothing expires while a test is running.
    const ONE_MINUTE: Duration = Duration::from_secs(60);

    /// Builds a cache pre-filled with `(key, value, provider)` triples, each
    /// stored with a one-minute TTL.
    fn seeded(items: &[(&str, &str, &str)]) -> ProviderCache {
        let mut cache = ProviderCache::new();
        for &(key, value, provider) in items {
            cache.set(key.to_owned(), value.to_owned(), ONE_MINUTE, provider);
        }
        cache
    }

    /// Finds the stats record for `id` in a metrics snapshot.
    fn stats_for<'a>(metrics: &'a CacheMetrics, id: &str) -> &'a ProviderCacheStats {
        metrics
            .provider_stats
            .iter()
            .find(|s| s.provider_id == id)
            .unwrap()
    }

    #[test]
    fn cache_set_and_get() {
        let mut cache = seeded(&[("test:key", "value", "test")]);
        assert_eq!(cache.get("test:key"), Some("value"));
    }

    #[test]
    fn cache_expired_entry_returns_none() {
        let mut cache = ProviderCache::new();
        cache.set("test:key".to_owned(), "value".to_owned(), Duration::ZERO, "test");
        std::thread::sleep(Duration::from_millis(10));
        assert_eq!(cache.get("test:key"), None);
    }

    #[test]
    fn cache_tracks_hits_and_misses() {
        let mut cache = seeded(&[("github:key", "data", "github")]);
        // Two hits followed by one miss.
        for key in ["github:key", "github:key", "github:missing"] {
            let _ = cache.get(key);
        }

        let metrics = cache.metrics();
        assert_eq!(metrics.total_hits, 2);
        assert_eq!(metrics.total_misses, 1);
        assert!((metrics.total_hit_rate() - 0.666).abs() < 0.01);
    }

    #[test]
    fn cache_invalidate_provider() {
        let mut cache = seeded(&[
            ("github:a", "1", "github"),
            ("github:b", "2", "github"),
            ("gitlab:c", "3", "gitlab"),
        ]);

        assert_eq!(cache.invalidate_provider("github"), 2);
        assert_eq!(cache.get("github:a"), None);
        assert_eq!(cache.get("gitlab:c"), Some("3"));
    }

    #[test]
    fn cache_invalidate_all() {
        let mut cache = seeded(&[("a", "1", "x"), ("b", "2", "y")]);

        assert_eq!(cache.invalidate_all(), 2);
        assert_eq!(cache.get("a"), None);
    }

    #[test]
    fn cache_metrics_per_provider() {
        let mut cache = seeded(&[("github:x", "a", "github"), ("gitlab:y", "b", "gitlab")]);
        let _ = cache.get("github:x");
        let _ = cache.get("gitlab:miss");

        let metrics = cache.metrics();
        assert_eq!(metrics.provider_stats.len(), 2);

        let github = stats_for(&metrics, "github");
        assert_eq!((github.entry_count, github.hits), (1, 1));

        let gitlab = stats_for(&metrics, "gitlab");
        assert_eq!(gitlab.entry_count, 1);
        assert!(gitlab.last_fetch.is_some());
    }

    #[test]
    fn provider_cache_stats_hit_rate() {
        let stats = ProviderCacheStats {
            provider_id: "test".into(),
            hits: 3,
            misses: 1,
            entry_count: 2,
            last_fetch: None,
        };
        assert!((stats.hit_rate() - 0.75).abs() < f64::EPSILON);
    }

    #[test]
    fn provider_cache_stats_hit_rate_zero() {
        assert!(ProviderCacheStats::default().hit_rate().abs() < f64::EPSILON);
    }
}