1use cid::Cid;
7use libp2p::PeerId;
8use parking_lot::RwLock;
9use serde::Serialize;
10use std::collections::{HashMap, HashSet};
11use std::time::{Duration, Instant};
12use tracing::{debug, info};
13
/// Default lifetime of a cached provider set: one hour.
const DEFAULT_TTL: Duration = Duration::from_secs(60 * 60);

/// Default capacity of the cache, in distinct CIDs.
const DEFAULT_MAX_ENTRIES: usize = 10_000;
19
/// Tunables for [`ProviderCache`].
#[derive(Debug, Clone)]
pub struct ProviderCacheConfig {
    /// How long a provider set stays valid, measured from when it was stored
    /// (reads do not extend it).
    pub ttl: Duration,
    /// Number of distinct CIDs at which `put` starts evicting entries.
    pub max_entries: usize,
    /// `needs_refresh` reports `true` when an entry holds fewer providers than this.
    pub min_providers: usize,
}
30
31impl Default for ProviderCacheConfig {
32 fn default() -> Self {
33 Self {
34 ttl: DEFAULT_TTL,
35 max_entries: DEFAULT_MAX_ENTRIES,
36 min_providers: 1,
37 }
38 }
39}
40
/// One cache entry: the provider set for a CID plus bookkeeping for TTL and LRU.
#[derive(Debug, Clone)]
struct CachedProviders {
    /// Peers recorded as providers for the CID.
    providers: HashSet<PeerId>,
    /// Time of insertion; TTL expiry is measured from here.
    cached_at: Instant,
    /// Time of the most recent `touch`; used to pick LRU eviction victims.
    last_accessed: Instant,
    /// Number of `touch` calls. NOTE(review): only incremented in the code visible
    /// here, never read — confirm whether anything consumes it.
    access_count: u64,
}
53
54impl CachedProviders {
55 fn new(providers: HashSet<PeerId>) -> Self {
56 let now = Instant::now();
57 Self {
58 providers,
59 cached_at: now,
60 last_accessed: now,
61 access_count: 0,
62 }
63 }
64
65 fn is_expired(&self, ttl: Duration) -> bool {
66 self.cached_at.elapsed() > ttl
67 }
68
69 fn touch(&mut self) {
70 self.last_accessed = Instant::now();
71 self.access_count += 1;
72 }
73}
74
/// Cache from content ID to the set of peers known to provide it, with TTL expiry
/// and LRU eviction.
///
/// State lives behind `parking_lot::RwLock`s, so every method takes `&self`.
/// Methods that need both locks acquire `cache` before `stats`. Keep that order to
/// avoid lock-order deadlocks.
pub struct ProviderCache {
    /// Limits and thresholds; set at construction and not modified afterwards.
    config: ProviderCacheConfig,
    /// The entries, keyed by CID.
    cache: RwLock<HashMap<Cid, CachedProviders>>,
    /// Hit/miss/eviction/expiration counters, guarded separately from the map.
    stats: RwLock<CacheStats>,
}
84
/// Running counters backing [`ProviderCache::stats`].
#[derive(Default)]
struct CacheStats {
    /// Lookups via `get` that returned a live entry.
    hits: u64,
    /// Lookups via `get` that found no entry or an expired one.
    misses: u64,
    /// Entries removed by LRU eviction to make room for new ones.
    evictions: u64,
    /// Entries removed because their TTL elapsed.
    expirations: u64,
}
93
94impl ProviderCache {
95 pub fn new() -> Self {
97 Self::with_config(ProviderCacheConfig::default())
98 }
99
100 pub fn with_config(config: ProviderCacheConfig) -> Self {
102 Self {
103 config,
104 cache: RwLock::new(HashMap::new()),
105 stats: RwLock::new(CacheStats::default()),
106 }
107 }
108
109 pub fn get(&self, cid: &Cid) -> Option<Vec<PeerId>> {
113 let mut cache = self.cache.write();
114 let mut stats = self.stats.write();
115
116 if let Some(entry) = cache.get_mut(cid) {
117 if entry.is_expired(self.config.ttl) {
118 cache.remove(cid);
120 stats.expirations += 1;
121 stats.misses += 1;
122 debug!("Provider cache expired for {}", cid);
123 return None;
124 }
125
126 entry.touch();
128 stats.hits += 1;
129 debug!(
130 "Provider cache hit for {} ({} providers)",
131 cid,
132 entry.providers.len()
133 );
134 return Some(entry.providers.iter().cloned().collect());
135 }
136
137 stats.misses += 1;
139 None
140 }
141
142 pub fn has_providers(&self, cid: &Cid) -> bool {
144 let cache = self.cache.read();
145 if let Some(entry) = cache.get(cid) {
146 !entry.is_expired(self.config.ttl) && !entry.providers.is_empty()
147 } else {
148 false
149 }
150 }
151
152 pub fn needs_refresh(&self, cid: &Cid) -> bool {
154 let cache = self.cache.read();
155 if let Some(entry) = cache.get(cid) {
156 entry.is_expired(self.config.ttl) || entry.providers.len() < self.config.min_providers
157 } else {
158 true
159 }
160 }
161
162 pub fn put(&self, cid: Cid, providers: Vec<PeerId>) {
164 let provider_set: HashSet<PeerId> = providers.into_iter().collect();
165
166 if provider_set.is_empty() {
167 debug!("Not caching empty provider list for {}", cid);
168 return;
169 }
170
171 let mut cache = self.cache.write();
172
173 if cache.len() >= self.config.max_entries {
175 self.evict_lru(&mut cache);
176 }
177
178 let count = provider_set.len();
179 cache.insert(cid, CachedProviders::new(provider_set));
180 info!("Cached {} providers for {}", count, cid);
181 }
182
183 pub fn add_provider(&self, cid: &Cid, provider: PeerId) {
185 let mut cache = self.cache.write();
186
187 if let Some(entry) = cache.get_mut(cid) {
188 if !entry.is_expired(self.config.ttl) {
189 entry.providers.insert(provider);
190 entry.touch();
191 debug!("Added provider {} to cache for {}", provider, cid);
192 }
193 }
194 }
195
196 pub fn remove_provider(&self, cid: &Cid, provider: &PeerId) {
198 let mut cache = self.cache.write();
199
200 if let Some(entry) = cache.get_mut(cid) {
201 entry.providers.remove(provider);
202 debug!("Removed provider {} from cache for {}", provider, cid);
203 }
204 }
205
206 pub fn invalidate(&self, cid: &Cid) {
208 let mut cache = self.cache.write();
209 cache.remove(cid);
210 debug!("Invalidated cache for {}", cid);
211 }
212
213 pub fn cleanup_expired(&self) {
215 let mut cache = self.cache.write();
216 let mut stats = self.stats.write();
217 let ttl = self.config.ttl;
218
219 let before = cache.len();
220 cache.retain(|_, entry| !entry.is_expired(ttl));
221 let removed = before - cache.len();
222
223 if removed > 0 {
224 stats.expirations += removed as u64;
225 info!("Cleaned up {} expired provider cache entries", removed);
226 }
227 }
228
229 pub fn clear(&self) {
231 let mut cache = self.cache.write();
232 cache.clear();
233 info!("Provider cache cleared");
234 }
235
236 pub fn stats(&self) -> ProviderCacheStats {
238 let cache = self.cache.read();
239 let stats = self.stats.read();
240
241 let total_providers: usize = cache.values().map(|e| e.providers.len()).sum();
242 let hit_rate = if stats.hits + stats.misses > 0 {
243 stats.hits as f64 / (stats.hits + stats.misses) as f64
244 } else {
245 0.0
246 };
247
248 ProviderCacheStats {
249 entries: cache.len(),
250 max_entries: self.config.max_entries,
251 total_providers,
252 hits: stats.hits,
253 misses: stats.misses,
254 hit_rate,
255 evictions: stats.evictions,
256 expirations: stats.expirations,
257 }
258 }
259
260 pub fn len(&self) -> usize {
262 self.cache.read().len()
263 }
264
265 pub fn is_empty(&self) -> bool {
267 self.cache.read().is_empty()
268 }
269
270 fn evict_lru(&self, cache: &mut HashMap<Cid, CachedProviders>) {
272 let to_evict = (self.config.max_entries / 10).max(1);
274
275 let mut entries: Vec<_> = cache
276 .iter()
277 .map(|(cid, entry)| (*cid, entry.last_accessed))
278 .collect();
279
280 entries.sort_by(|a, b| a.1.cmp(&b.1));
282
283 let mut stats = self.stats.write();
284 for (cid, _) in entries.into_iter().take(to_evict) {
285 cache.remove(&cid);
286 stats.evictions += 1;
287 }
288
289 debug!("Evicted {} LRU cache entries", to_evict);
290 }
291}
292
293impl Default for ProviderCache {
294 fn default() -> Self {
295 Self::new()
296 }
297}
298
/// Point-in-time snapshot of cache occupancy and counters, returned by
/// [`ProviderCache::stats`].
///
/// serde's derive emits fields in declaration order, so reordering the fields
/// changes the serialized output.
#[derive(Debug, Clone, Serialize)]
pub struct ProviderCacheStats {
    /// Entries currently stored (may include expired entries not yet removed).
    pub entries: usize,
    /// Configured capacity (`ProviderCacheConfig::max_entries`).
    pub max_entries: usize,
    /// Sum of provider-set sizes across all stored entries.
    pub total_providers: usize,
    /// Cumulative `get` hits.
    pub hits: u64,
    /// Cumulative `get` misses.
    pub misses: u64,
    /// `hits / (hits + misses)`, or `0.0` before any lookup.
    pub hit_rate: f64,
    /// Cumulative LRU evictions.
    pub evictions: u64,
    /// Cumulative TTL expirations.
    pub expirations: u64,
}
319
#[cfg(test)]
mod tests {
    use super::*;
    use multihash_codetable::{Code, MultihashDigest};

    /// Deterministic CIDv1 over the SHA2-256 digest of `data`.
    /// Codec `0x55` is presumably multicodec `raw`.
    fn make_cid(data: &[u8]) -> Cid {
        let hash = Code::Sha2_256.digest(data);
        Cid::new_v1(0x55, hash)
    }

    /// A fresh random peer identity on every call.
    fn random_peer_id() -> PeerId {
        PeerId::random()
    }

    /// Round-trip: unknown CID misses and needs refresh; after `put` it hits with
    /// exactly the stored peers and no longer needs refresh.
    #[test]
    fn test_provider_cache_basic() {
        let cache = ProviderCache::new();
        let cid = make_cid(b"test data");
        let peer1 = random_peer_id();
        let peer2 = random_peer_id();

        assert!(cache.get(&cid).is_none());
        assert!(cache.needs_refresh(&cid));

        cache.put(cid, vec![peer1, peer2]);

        let providers = cache.get(&cid).unwrap();
        assert_eq!(providers.len(), 2);
        assert!(providers.contains(&peer1));
        assert!(providers.contains(&peer2));
        assert!(cache.has_providers(&cid));
        assert!(!cache.needs_refresh(&cid));
    }

    /// `add_provider` grows a cached set; `remove_provider` shrinks it.
    #[test]
    fn test_provider_cache_add_remove() {
        let cache = ProviderCache::new();
        let cid = make_cid(b"test");
        let peer1 = random_peer_id();
        let peer2 = random_peer_id();
        let peer3 = random_peer_id();

        cache.put(cid, vec![peer1, peer2]);

        cache.add_provider(&cid, peer3);
        let providers = cache.get(&cid).unwrap();
        assert_eq!(providers.len(), 3);

        cache.remove_provider(&cid, &peer1);
        let providers = cache.get(&cid).unwrap();
        assert_eq!(providers.len(), 2);
        assert!(!providers.contains(&peer1));
    }

    /// Relies on a real sleep (100 ms) outlasting the 50 ms TTL; the expired
    /// lookup must miss and be counted as an expiration.
    #[test]
    fn test_provider_cache_expiration() {
        let config = ProviderCacheConfig {
            ttl: Duration::from_millis(50),
            ..Default::default()
        };
        let cache = ProviderCache::with_config(config);
        let cid = make_cid(b"expiring");
        let peer = random_peer_id();

        cache.put(cid, vec![peer]);
        assert!(cache.get(&cid).is_some());

        std::thread::sleep(Duration::from_millis(100));

        assert!(cache.get(&cid).is_none());
        assert!(cache.needs_refresh(&cid));

        let stats = cache.stats();
        assert!(stats.expirations > 0);
    }

    /// Fills the cache to capacity, touches two entries, then inserts a new CID.
    /// Only checks that capacity is respected and that an eviction was recorded.
    /// Which entries survive is not asserted, presumably because LRU ordering
    /// depends on `Instant` resolution.
    #[test]
    fn test_provider_cache_lru_eviction() {
        let config = ProviderCacheConfig {
            ttl: Duration::from_secs(3600),
            max_entries: 5,
            ..Default::default()
        };
        let cache = ProviderCache::with_config(config);
        let peer = random_peer_id();

        for i in 0..5 {
            let cid = make_cid(&[i as u8]);
            cache.put(cid, vec![peer]);
        }

        assert_eq!(cache.len(), 5);

        let cid_2 = make_cid(&[2]);
        let cid_3 = make_cid(&[3]);
        cache.get(&cid_2);
        cache.get(&cid_3);

        let new_cid = make_cid(&[100]);
        cache.put(new_cid, vec![peer]);

        assert!(cache.len() <= 5);

        let stats = cache.stats();
        assert!(stats.evictions > 0);
    }

    /// The counts depend on call order: a miss before insert, a hit after, and a
    /// miss on a never-inserted CID give 1 hit and 2 misses.
    #[test]
    fn test_provider_cache_stats() {
        let cache = ProviderCache::new();
        let cid1 = make_cid(b"one");
        let cid2 = make_cid(b"two");
        let peer = random_peer_id();

        cache.get(&cid1);

        cache.put(cid1, vec![peer]);
        cache.get(&cid1);

        cache.get(&cid2);

        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 2);
        assert_eq!(stats.entries, 1);
    }

    /// `invalidate` makes a cached CID report no providers.
    #[test]
    fn test_provider_cache_invalidate() {
        let cache = ProviderCache::new();
        let cid = make_cid(b"invalidate me");
        let peer = random_peer_id();

        cache.put(cid, vec![peer]);
        assert!(cache.has_providers(&cid));

        cache.invalidate(&cid);
        assert!(!cache.has_providers(&cid));
    }

    /// Relies on a real sleep (50 ms) outlasting the 10 ms TTL; after that,
    /// `cleanup_expired` must empty the cache.
    #[test]
    fn test_provider_cache_cleanup() {
        let config = ProviderCacheConfig {
            ttl: Duration::from_millis(10),
            ..Default::default()
        };
        let cache = ProviderCache::with_config(config);
        let peer = random_peer_id();

        for i in 0..5 {
            let cid = make_cid(&[i as u8]);
            cache.put(cid, vec![peer]);
        }

        assert_eq!(cache.len(), 5);

        std::thread::sleep(Duration::from_millis(50));

        cache.cleanup_expired();
        assert_eq!(cache.len(), 0);
    }

    /// `put` with an empty provider list stores nothing.
    #[test]
    fn test_provider_cache_empty_providers_not_cached() {
        let cache = ProviderCache::new();
        let cid = make_cid(b"empty");

        cache.put(cid, vec![]);
        assert!(!cache.has_providers(&cid));
        assert_eq!(cache.len(), 0);
    }
}