1use std::collections::HashMap;
2use std::sync::Mutex;
3use std::time::{Duration, Instant, SystemTime};
4
5static PROVIDER_CACHE: std::sync::LazyLock<Mutex<ProviderCache>> =
6 std::sync::LazyLock::new(|| Mutex::new(ProviderCache::new()));
7
/// One cached payload together with its ownership and lifetime bookkeeping.
struct CacheEntry {
    /// Provider this entry is attributed to (used for hit counting and invalidation).
    provider_id: String,
    /// The cached payload returned to callers.
    data: String,
    /// Wall-clock time the entry was stored.
    // Written by `ProviderCache::set` but never read in this file, hence the allow.
    #[allow(dead_code)]
    created_at: SystemTime,
    /// Monotonic deadline; the entry is treated as absent once `Instant::now()` reaches it.
    expires_at: Instant,
}
15
/// Cache activity attributed to a single provider.
#[derive(Debug, Clone, Default)]
pub struct ProviderCacheStats {
    /// Provider these numbers belong to.
    pub provider_id: String,
    /// Lookups that found a live entry.
    pub hits: u64,
    /// Lookups that found no live entry.
    pub misses: u64,
    /// Number of live entries stored for this provider when the snapshot was taken.
    pub entry_count: usize,
    /// Wall-clock time of the most recent store for this provider, if any.
    pub last_fetch: Option<SystemTime>,
}

impl ProviderCacheStats {
    /// Share of lookups that were hits, or `0.0` when there have been no lookups at all.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
35
36#[derive(Debug, Clone, Default)]
38pub struct CacheMetrics {
39 pub total_hits: u64,
40 pub total_misses: u64,
41 pub total_entries: usize,
42 pub provider_stats: Vec<ProviderCacheStats>,
43}
44
45impl CacheMetrics {
46 pub fn total_hit_rate(&self) -> f64 {
47 let total = self.total_hits + self.total_misses;
48 if total == 0 {
49 return 0.0;
50 }
51 self.total_hits as f64 / total as f64
52 }
53}
54
55struct ProviderCache {
56 entries: HashMap<String, CacheEntry>,
57 hits: HashMap<String, u64>,
58 misses: HashMap<String, u64>,
59 last_fetch: HashMap<String, SystemTime>,
60}
61
62impl ProviderCache {
63 fn new() -> Self {
64 Self {
65 entries: HashMap::new(),
66 hits: HashMap::new(),
67 misses: HashMap::new(),
68 last_fetch: HashMap::new(),
69 }
70 }
71
72 fn get(&mut self, key: &str) -> Option<&str> {
73 self.entries.retain(|_, v| v.expires_at > Instant::now());
74 if let Some(entry) = self.entries.get(key) {
75 *self.hits.entry(entry.provider_id.clone()).or_insert(0) += 1;
76 Some(entry.data.as_str())
77 } else {
78 let provider = key.split(':').next().unwrap_or("unknown");
79 *self.misses.entry(provider.to_string()).or_insert(0) += 1;
80 None
81 }
82 }
83
84 fn set(&mut self, key: String, data: String, ttl: Duration, provider_id: &str) {
85 let now = SystemTime::now();
86 self.last_fetch.insert(provider_id.to_string(), now);
87 self.entries.insert(
88 key,
89 CacheEntry {
90 data,
91 expires_at: Instant::now() + ttl,
92 created_at: now,
93 provider_id: provider_id.to_string(),
94 },
95 );
96 }
97
98 fn invalidate_provider(&mut self, provider_id: &str) -> usize {
99 let before = self.entries.len();
100 self.entries.retain(|_, v| v.provider_id != provider_id);
101 before - self.entries.len()
102 }
103
104 fn invalidate_all(&mut self) -> usize {
105 let count = self.entries.len();
106 self.entries.clear();
107 count
108 }
109
110 fn metrics(&mut self) -> CacheMetrics {
111 self.entries.retain(|_, v| v.expires_at > Instant::now());
112
113 let mut by_provider: HashMap<String, ProviderCacheStats> = HashMap::new();
114
115 for entry in self.entries.values() {
116 let stats = by_provider.entry(entry.provider_id.clone()).or_default();
117 stats.provider_id.clone_from(&entry.provider_id);
118 stats.entry_count += 1;
119 }
120
121 for (pid, &count) in &self.hits {
122 let stats = by_provider.entry(pid.clone()).or_default();
123 stats.provider_id.clone_from(pid);
124 stats.hits = count;
125 }
126 for (pid, &count) in &self.misses {
127 let stats = by_provider.entry(pid.clone()).or_default();
128 stats.provider_id.clone_from(pid);
129 stats.misses = count;
130 }
131 for (pid, &ts) in &self.last_fetch {
132 let stats = by_provider.entry(pid.clone()).or_default();
133 stats.provider_id.clone_from(pid);
134 stats.last_fetch = Some(ts);
135 }
136
137 let mut provider_stats: Vec<_> = by_provider.into_values().collect();
138 provider_stats.sort_by(|a, b| a.provider_id.cmp(&b.provider_id));
139
140 CacheMetrics {
141 total_hits: self.hits.values().sum(),
142 total_misses: self.misses.values().sum(),
143 total_entries: self.entries.len(),
144 provider_stats,
145 }
146 }
147}
148
149pub fn get_cached(key: &str) -> Option<String> {
150 PROVIDER_CACHE
151 .lock()
152 .ok()
153 .and_then(|mut c| c.get(key).map(std::string::ToString::to_string))
154}
155
156pub fn set_cached(key: &str, data: &str, ttl_secs: u64) {
157 set_cached_with_provider(
158 key,
159 data,
160 ttl_secs,
161 key.split(':').next().unwrap_or("unknown"),
162 );
163}
164
165pub fn set_cached_with_provider(key: &str, data: &str, ttl_secs: u64, provider_id: &str) {
166 if let Ok(mut cache) = PROVIDER_CACHE.lock() {
167 cache.set(
168 key.to_string(),
169 data.to_string(),
170 Duration::from_secs(ttl_secs),
171 provider_id,
172 );
173 }
174}
175
176pub fn invalidate_provider(provider_id: &str) -> usize {
177 PROVIDER_CACHE
178 .lock()
179 .ok()
180 .map_or(0, |mut c| c.invalidate_provider(provider_id))
181}
182
183pub fn invalidate_all() -> usize {
184 PROVIDER_CACHE
185 .lock()
186 .ok()
187 .map_or(0, |mut c| c.invalidate_all())
188}
189
190pub fn cache_metrics() -> CacheMetrics {
191 PROVIDER_CACHE
192 .lock()
193 .ok()
194 .map(|mut c| c.metrics())
195 .unwrap_or_default()
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// A TTL long enough that nothing expires while a test runs.
    fn long_ttl() -> Duration {
        Duration::from_mins(1)
    }

    /// Builds a cache pre-populated with `(key, value, provider)` triples, all using `long_ttl()`.
    fn cache_with(rows: &[(&str, &str, &str)]) -> ProviderCache {
        let mut cache = ProviderCache::new();
        for &(key, value, provider) in rows {
            cache.set(key.to_string(), value.to_string(), long_ttl(), provider);
        }
        cache
    }

    /// Finds the stats row for `provider` in a metrics snapshot.
    fn stats_of<'a>(metrics: &'a CacheMetrics, provider: &str) -> &'a ProviderCacheStats {
        metrics
            .provider_stats
            .iter()
            .find(|row| row.provider_id == provider)
            .unwrap()
    }

    #[test]
    fn cache_set_and_get() {
        let mut cache = cache_with(&[("test:key", "value", "test")]);
        assert_eq!(cache.get("test:key"), Some("value"));
    }

    #[test]
    fn cache_expired_entry_returns_none() {
        let mut cache = ProviderCache::new();
        cache.set("test:key".to_string(), "value".to_string(), Duration::ZERO, "test");
        std::thread::sleep(Duration::from_millis(10));
        assert_eq!(cache.get("test:key"), None);
    }

    #[test]
    fn cache_tracks_hits_and_misses() {
        let mut cache = cache_with(&[("github:key", "data", "github")]);
        for key in ["github:key", "github:key", "github:missing"] {
            cache.get(key);
        }

        let metrics = cache.metrics();
        assert_eq!(metrics.total_hits, 2);
        assert_eq!(metrics.total_misses, 1);
        assert!((metrics.total_hit_rate() - 0.666).abs() < 0.01);
    }

    #[test]
    fn cache_invalidate_provider() {
        let mut cache = cache_with(&[
            ("github:a", "1", "github"),
            ("github:b", "2", "github"),
            ("gitlab:c", "3", "gitlab"),
        ]);

        assert_eq!(cache.invalidate_provider("github"), 2);
        assert_eq!(cache.get("github:a"), None);
        assert_eq!(cache.get("gitlab:c"), Some("3"));
    }

    #[test]
    fn cache_invalidate_all() {
        let mut cache = cache_with(&[("a", "1", "x"), ("b", "2", "y")]);

        assert_eq!(cache.invalidate_all(), 2);
        assert_eq!(cache.get("a"), None);
    }

    #[test]
    fn cache_metrics_per_provider() {
        let mut cache = cache_with(&[("github:x", "a", "github"), ("gitlab:y", "b", "gitlab")]);
        cache.get("github:x");
        cache.get("gitlab:miss");

        let metrics = cache.metrics();
        assert_eq!(metrics.provider_stats.len(), 2);

        let github = stats_of(&metrics, "github");
        assert_eq!(github.entry_count, 1);
        assert_eq!(github.hits, 1);

        let gitlab = stats_of(&metrics, "gitlab");
        assert_eq!(gitlab.entry_count, 1);
        assert!(gitlab.last_fetch.is_some());
    }

    #[test]
    fn provider_cache_stats_hit_rate() {
        let stats = ProviderCacheStats {
            hits: 3,
            misses: 1,
            entry_count: 2,
            provider_id: "test".to_string(),
            last_fetch: None,
        };
        assert!((stats.hit_rate() - 0.75).abs() < f64::EPSILON);
    }

    #[test]
    fn provider_cache_stats_hit_rate_zero() {
        let empty = ProviderCacheStats::default();
        assert!(empty.hit_rate().abs() < f64::EPSILON);
    }
}