Skip to main content

xds_cache/
cache.rs

//! Cache trait and ShardedCache implementation.
//!
//! The cache stores snapshots keyed by node hash. The [`ShardedCache`]
//! implementation uses `DashMap`, whose per-shard locking allows concurrent
//! access without a single global lock.
5
6use std::sync::Arc;
7
8use dashmap::DashMap;
9use tracing::{debug, trace};
10use xds_core::NodeHash;
11
12use crate::snapshot::Snapshot;
13use crate::stats::CacheStats;
14use crate::watch::WatchManager;
15
16/// Trait for xDS snapshot caches.
17///
18/// Provides the interface for storing and retrieving snapshots.
19pub trait Cache: Send + Sync {
20    /// Get a snapshot for a node.
21    fn get_snapshot(&self, node: NodeHash) -> Option<Arc<Snapshot>>;
22
23    /// Set a snapshot for a node.
24    ///
25    /// This will notify any watches for this node.
26    fn set_snapshot(&self, node: NodeHash, snapshot: Snapshot);
27
28    /// Clear the snapshot for a node.
29    fn clear_snapshot(&self, node: NodeHash);
30
31    /// Get the number of cached snapshots.
32    fn snapshot_count(&self) -> usize;
33}
34
35/// A high-performance sharded cache using DashMap.
36///
37/// This cache implementation:
38/// - Uses `DashMap` for lock-free concurrent reads
39/// - Automatically notifies watches on snapshot updates
40/// - Tracks statistics for monitoring
41///
42/// ## Thread Safety
43///
44/// All operations are thread-safe. The cache uses `DashMap` internally,
45/// which provides fine-grained locking at the bucket level rather than
46/// a global lock.
47///
48/// ## Important
49///
50/// All `DashMap` references are dropped before any async operations
51/// to prevent holding locks across await points.
52#[derive(Debug)]
53pub struct ShardedCache {
54    /// Snapshots keyed by node hash.
55    snapshots: DashMap<NodeHash, Arc<Snapshot>>,
56    /// Watch manager for notifications.
57    watches: WatchManager,
58    /// Statistics.
59    stats: CacheStats,
60}
61
62impl Default for ShardedCache {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68impl ShardedCache {
69    /// Create a new sharded cache with default settings.
70    pub fn new() -> Self {
71        Self::with_capacity(64)
72    }
73
74    /// Create a new sharded cache with a specific initial capacity.
75    pub fn with_capacity(capacity: usize) -> Self {
76        Self {
77            snapshots: DashMap::with_capacity(capacity),
78            watches: WatchManager::new(),
79            stats: CacheStats::new(),
80        }
81    }
82
83    /// Get the watch manager for creating watches.
84    #[inline]
85    pub fn watches(&self) -> &WatchManager {
86        &self.watches
87    }
88
89    /// Get cache statistics.
90    #[inline]
91    pub fn stats(&self) -> &CacheStats {
92        &self.stats
93    }
94
95    /// Create a watch for a node.
96    ///
97    /// The watch will receive updates when the snapshot for this node changes.
98    /// If a snapshot already exists, the caller should check with `get_snapshot`
99    /// first.
100    #[inline]
101    pub fn create_watch(&self, node: NodeHash) -> crate::watch::Watch {
102        self.watches.create_watch(node)
103    }
104
105    /// Cancel a watch.
106    #[inline]
107    pub fn cancel_watch(&self, watch_id: crate::watch::WatchId) {
108        self.watches.cancel_watch(watch_id)
109    }
110
111    /// Get all node hashes in the cache.
112    pub fn nodes(&self) -> Vec<NodeHash> {
113        self.snapshots.iter().map(|r| *r.key()).collect()
114    }
115
116    /// Check if a snapshot exists for a node.
117    pub fn has_snapshot(&self, node: NodeHash) -> bool {
118        self.snapshots.contains_key(&node)
119    }
120
121    /// Iterate over all snapshots.
122    ///
123    /// Note: This acquires read locks on all shards.
124    pub fn iter(&self) -> impl Iterator<Item = (NodeHash, Arc<Snapshot>)> + '_ {
125        self.snapshots
126            .iter()
127            .map(|r| (*r.key(), Arc::clone(r.value())))
128    }
129}
130
131impl Cache for ShardedCache {
132    fn get_snapshot(&self, node: NodeHash) -> Option<Arc<Snapshot>> {
133        // DashMap::get returns a Ref that holds a read lock.
134        // We clone the Arc and drop the Ref immediately.
135        let result = self.snapshots.get(&node).map(|r| Arc::clone(&*r));
136
137        if result.is_some() {
138            self.stats.record_hit();
139            trace!(node = %node, "cache hit");
140        } else {
141            self.stats.record_miss();
142            trace!(node = %node, "cache miss");
143        }
144
145        result
146    }
147
148    fn set_snapshot(&self, node: NodeHash, snapshot: Snapshot) {
149        let snapshot = Arc::new(snapshot);
150
151        // Insert snapshot (DashMap insert is lock-free for the caller)
152        self.snapshots.insert(node, Arc::clone(&snapshot));
153        self.stats.record_set();
154
155        debug!(
156            node = %node,
157            version = %snapshot.version(),
158            resources = snapshot.total_resources(),
159            "set snapshot"
160        );
161
162        // Notify watches (no DashMap lock held)
163        self.watches.notify(node, snapshot);
164    }
165
166    fn clear_snapshot(&self, node: NodeHash) {
167        if self.snapshots.remove(&node).is_some() {
168            self.stats.record_clear();
169            debug!(node = %node, "cleared snapshot");
170        }
171    }
172
173    fn snapshot_count(&self) -> usize {
174        self.snapshots.len()
175    }
176}
177
/// Step-by-step configuration for a [`ShardedCache`].
///
/// Every option is optional; anything left unset falls back to the default
/// chosen in [`CacheBuilder::build`].
#[derive(Debug, Default)]
#[allow(dead_code)] // Public API surface
pub struct CacheBuilder {
    /// Initial capacity of the snapshot map, if one was requested.
    capacity: Option<usize>,
    /// Buffer size handed to the watch manager, if one was requested.
    watch_buffer_size: Option<usize>,
}
185
186#[allow(dead_code)] // Public API surface
187impl CacheBuilder {
188    /// Create a new cache builder.
189    pub fn new() -> Self {
190        Self::default()
191    }
192
193    /// Set the initial capacity.
194    pub fn capacity(mut self, capacity: usize) -> Self {
195        self.capacity = Some(capacity);
196        self
197    }
198
199    /// Set the watch channel buffer size.
200    pub fn watch_buffer_size(mut self, size: usize) -> Self {
201        self.watch_buffer_size = Some(size);
202        self
203    }
204
205    /// Build the cache.
206    pub fn build(self) -> ShardedCache {
207        let capacity = self.capacity.unwrap_or(64);
208        let watch_buffer = self.watch_buffer_size.unwrap_or(16);
209
210        ShardedCache {
211            snapshots: DashMap::with_capacity(capacity),
212            watches: WatchManager::with_buffer_size(watch_buffer),
213            stats: CacheStats::new(),
214        }
215    }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;
    use std::time::Duration;

    /// Build a snapshot that carries nothing but the given version.
    fn versioned(version: &str) -> Snapshot {
        Snapshot::builder().version(version.to_owned()).build()
    }

    #[test]
    fn cache_basic_operations() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");

        // A fresh cache has nothing stored.
        assert!(cache.get_snapshot(node).is_none());
        assert_eq!(cache.snapshot_count(), 0);

        // Storing a snapshot makes it visible.
        cache.set_snapshot(node, versioned("v1"));
        assert!(cache.has_snapshot(node));
        assert_eq!(cache.snapshot_count(), 1);

        let stored = cache.get_snapshot(node).unwrap();
        assert_eq!(stored.version(), "v1");

        // Clearing removes it again.
        cache.clear_snapshot(node);
        assert!(!cache.has_snapshot(node));
        assert_eq!(cache.snapshot_count(), 0);
    }

    #[test]
    fn cache_stats_tracking() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");

        // Lookup on an empty cache counts as a miss.
        let _ = cache.get_snapshot(node);
        assert_eq!(cache.stats().snapshot_misses(), 1);

        // Storing counts as a set.
        cache.set_snapshot(node, versioned("v1"));
        assert_eq!(cache.stats().snapshots_set(), 1);

        // Lookup of a stored snapshot counts as a hit.
        let _ = cache.get_snapshot(node);
        assert_eq!(cache.stats().snapshot_hits(), 1);
    }

    #[tokio::test]
    async fn cache_watch_notification() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");

        let mut watch = cache.create_watch(node);

        cache.set_snapshot(node, versioned("v1"));

        // The watch was registered before the set, so it must see "v1".
        let received = watch.recv().await.unwrap();
        assert_eq!(received.version(), "v1");
    }

    #[test]
    fn cache_builder() {
        let builder = CacheBuilder::new().capacity(128).watch_buffer_size(32);
        let cache = builder.build();

        assert_eq!(cache.snapshot_count(), 0);
    }

    // === Concurrent Access Tests ===

    #[test]
    fn cache_concurrent_reads() {
        let cache = Arc::new(ShardedCache::new());
        let node = NodeHash::from_id("test-node");

        // Seed the cache so every read can succeed.
        cache.set_snapshot(node, versioned("v1"));

        let hits = Arc::new(AtomicUsize::new(0));

        // 10 reader threads, 100 lookups each.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let cache = Arc::clone(&cache);
                let hits = Arc::clone(&hits);
                thread::spawn(move || {
                    for _ in 0..100 {
                        if cache.get_snapshot(node).is_some() {
                            hits.fetch_add(1, Ordering::Relaxed);
                        }
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("Thread panicked");
        }

        // Every single lookup must have found the snapshot.
        assert_eq!(hits.load(Ordering::Relaxed), 1000);
    }

    #[test]
    fn cache_concurrent_writes() {
        let cache = Arc::new(ShardedCache::new());

        // 10 writer threads, each storing 100 snapshots under distinct nodes.
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    for j in 0..100 {
                        let node = NodeHash::from_id(&format!("node-{}-{}", i, j));
                        cache.set_snapshot(node, versioned(&format!("v{}", j)));
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("Thread panicked");
        }

        // 10 * 100 distinct nodes were written.
        assert_eq!(cache.snapshot_count(), 1000);
    }

    #[test]
    fn cache_concurrent_read_write() {
        let cache = Arc::new(ShardedCache::new());
        let node = NodeHash::from_id("contended-node");

        // Seed the node so readers never observe an empty slot.
        cache.set_snapshot(node, versioned("v0"));

        let reads = Arc::new(AtomicUsize::new(0));
        let writes = Arc::new(AtomicUsize::new(0));

        // One writer replacing the snapshot 50 times.
        let writer = {
            let cache = Arc::clone(&cache);
            let writes = Arc::clone(&writes);
            thread::spawn(move || {
                for i in 1..=50 {
                    cache.set_snapshot(node, versioned(&format!("v{}", i)));
                    writes.fetch_add(1, Ordering::Relaxed);
                    thread::sleep(Duration::from_micros(100));
                }
            })
        };

        // Five readers hammering the same node while the writer runs.
        let readers: Vec<_> = (0..5)
            .map(|_| {
                let cache = Arc::clone(&cache);
                let reads = Arc::clone(&reads);
                thread::spawn(move || {
                    for _ in 0..100 {
                        if cache.get_snapshot(node).is_some() {
                            reads.fetch_add(1, Ordering::Relaxed);
                        }
                        thread::sleep(Duration::from_micros(50));
                    }
                })
            })
            .collect();

        for handle in std::iter::once(writer).chain(readers) {
            handle.join().expect("Thread panicked");
        }

        assert_eq!(writes.load(Ordering::Relaxed), 50);
        // The snapshot is replaced, never removed, so no read may miss.
        assert_eq!(reads.load(Ordering::Relaxed), 500);
    }

    // === Large Snapshot Tests ===

    #[test]
    fn cache_many_nodes() {
        let cache = ShardedCache::with_capacity(10000);

        // Populate 10,000 distinct nodes.
        for i in 0..10000 {
            let node = NodeHash::from_id(&format!("node-{}", i));
            cache.set_snapshot(node, versioned(&format!("v{}", i)));
        }

        assert_eq!(cache.snapshot_count(), 10000);

        // Spot-check a few entries across the range.
        for i in [0, 999, 5000, 9999] {
            let node = NodeHash::from_id(&format!("node-{}", i));
            let snap = cache.get_snapshot(node).unwrap();
            assert_eq!(snap.version(), format!("v{}", i));
        }
    }

    #[test]
    fn cache_snapshot_update() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");

        // First version.
        cache.set_snapshot(node, versioned("v1"));
        assert_eq!(cache.get_snapshot(node).unwrap().version(), "v1");

        // Replacement version.
        cache.set_snapshot(node, versioned("v2"));
        assert_eq!(cache.get_snapshot(node).unwrap().version(), "v2");

        // Both stores are counted.
        assert_eq!(cache.stats().snapshots_set(), 2);
    }

    // === Watch Tests ===

    #[tokio::test]
    async fn cache_multiple_watches_same_node() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");

        let mut first = cache.create_watch(node);
        let mut second = cache.create_watch(node);

        cache.set_snapshot(node, versioned("v1"));

        // Each watch gets its own copy of the notification.
        let from_first = first.recv().await.unwrap();
        let from_second = second.recv().await.unwrap();
        assert_eq!(from_first.version(), "v1");
        assert_eq!(from_second.version(), "v1");
    }

    #[tokio::test]
    async fn cache_watch_receives_updates() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");

        let mut watch = cache.create_watch(node);

        // Publish three versions back to back.
        for i in 1..=3 {
            cache.set_snapshot(node, versioned(&format!("v{}", i)));
        }

        // They must arrive buffered and in publication order.
        for expected in ["v1", "v2", "v3"] {
            let received = watch.recv().await.unwrap();
            assert_eq!(received.version(), expected);
        }
    }

    // === Edge Cases ===

    #[test]
    fn cache_clear_nonexistent_node() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("nonexistent");

        // Clearing an unknown node must be a harmless no-op.
        cache.clear_snapshot(node);
        assert_eq!(cache.snapshot_count(), 0);
    }

    #[test]
    fn cache_wildcard_node() {
        let cache = ShardedCache::new();
        let wildcard = NodeHash::wildcard();

        cache.set_snapshot(wildcard, versioned("v1"));
        assert!(cache.has_snapshot(wildcard));

        let stored = cache.get_snapshot(wildcard).unwrap();
        assert_eq!(stored.version(), "v1");
    }

    #[test]
    fn cache_node_hash_collision_unlikely() {
        // FNV-1a should give different hashes for similar strings
        let first = NodeHash::from_id("node-1");
        let second = NodeHash::from_id("node-2");
        let third = NodeHash::from_id("1-node");

        // Sanity check only: distinct hashes are likely, not guaranteed.
        assert_ne!(first, second);
        assert_ne!(second, third);
        assert_ne!(first, third);
    }
}