1use std::sync::Arc;
7
8use dashmap::DashMap;
9use tracing::{debug, trace};
10use xds_core::NodeHash;
11
12use crate::snapshot::Snapshot;
13use crate::stats::CacheStats;
14use crate::watch::WatchManager;
15
16pub trait Cache: Send + Sync {
20 fn get_snapshot(&self, node: NodeHash) -> Option<Arc<Snapshot>>;
22
23 fn set_snapshot(&self, node: NodeHash, snapshot: Snapshot);
27
28 fn clear_snapshot(&self, node: NodeHash);
30
31 fn snapshot_count(&self) -> usize;
33}
34
35#[derive(Debug)]
53pub struct ShardedCache {
54 snapshots: DashMap<NodeHash, Arc<Snapshot>>,
56 watches: WatchManager,
58 stats: CacheStats,
60}
61
62impl Default for ShardedCache {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68impl ShardedCache {
69 pub fn new() -> Self {
71 Self::with_capacity(64)
72 }
73
74 pub fn with_capacity(capacity: usize) -> Self {
76 Self {
77 snapshots: DashMap::with_capacity(capacity),
78 watches: WatchManager::new(),
79 stats: CacheStats::new(),
80 }
81 }
82
83 #[inline]
85 pub fn watches(&self) -> &WatchManager {
86 &self.watches
87 }
88
89 #[inline]
91 pub fn stats(&self) -> &CacheStats {
92 &self.stats
93 }
94
95 #[inline]
101 pub fn create_watch(&self, node: NodeHash) -> crate::watch::Watch {
102 self.watches.create_watch(node)
103 }
104
105 #[inline]
107 pub fn cancel_watch(&self, watch_id: crate::watch::WatchId) {
108 self.watches.cancel_watch(watch_id)
109 }
110
111 pub fn nodes(&self) -> Vec<NodeHash> {
113 self.snapshots.iter().map(|r| *r.key()).collect()
114 }
115
116 pub fn has_snapshot(&self, node: NodeHash) -> bool {
118 self.snapshots.contains_key(&node)
119 }
120
121 pub fn iter(&self) -> impl Iterator<Item = (NodeHash, Arc<Snapshot>)> + '_ {
125 self.snapshots
126 .iter()
127 .map(|r| (*r.key(), Arc::clone(r.value())))
128 }
129}
130
131impl Cache for ShardedCache {
132 fn get_snapshot(&self, node: NodeHash) -> Option<Arc<Snapshot>> {
133 let result = self.snapshots.get(&node).map(|r| Arc::clone(&*r));
136
137 if result.is_some() {
138 self.stats.record_hit();
139 trace!(node = %node, "cache hit");
140 } else {
141 self.stats.record_miss();
142 trace!(node = %node, "cache miss");
143 }
144
145 result
146 }
147
148 fn set_snapshot(&self, node: NodeHash, snapshot: Snapshot) {
149 let snapshot = Arc::new(snapshot);
150
151 self.snapshots.insert(node, Arc::clone(&snapshot));
153 self.stats.record_set();
154
155 debug!(
156 node = %node,
157 version = %snapshot.version(),
158 resources = snapshot.total_resources(),
159 "set snapshot"
160 );
161
162 self.watches.notify(node, snapshot);
164 }
165
166 fn clear_snapshot(&self, node: NodeHash) {
167 if self.snapshots.remove(&node).is_some() {
168 self.stats.record_clear();
169 debug!(node = %node, "cleared snapshot");
170 }
171 }
172
173 fn snapshot_count(&self) -> usize {
174 self.snapshots.len()
175 }
176}
177
/// Builder for a [`ShardedCache`] with custom sizing.
///
/// Options that are never set fall back to built-in defaults in `build`.
#[derive(Debug, Default)]
// NOTE(review): `dead_code` is presumably allowed because nothing inside the
// crate constructs the builder yet — confirm, and drop the attribute once used.
#[allow(dead_code)]
pub struct CacheBuilder {
    /// Initial snapshot-map capacity; `None` means "use the default".
    capacity: Option<usize>,
    /// Value handed to `WatchManager::with_buffer_size`; `None` means "use the default".
    watch_buffer_size: Option<usize>,
}
185
186#[allow(dead_code)] impl CacheBuilder {
188 pub fn new() -> Self {
190 Self::default()
191 }
192
193 pub fn capacity(mut self, capacity: usize) -> Self {
195 self.capacity = Some(capacity);
196 self
197 }
198
199 pub fn watch_buffer_size(mut self, size: usize) -> Self {
201 self.watch_buffer_size = Some(size);
202 self
203 }
204
205 pub fn build(self) -> ShardedCache {
207 let capacity = self.capacity.unwrap_or(64);
208 let watch_buffer = self.watch_buffer_size.unwrap_or(16);
209
210 ShardedCache {
211 snapshots: DashMap::with_capacity(capacity),
212 watches: WatchManager::with_buffer_size(watch_buffer),
213 stats: CacheStats::new(),
214 }
215 }
216}
217
// Unit tests for `ShardedCache`: basic CRUD, stats counters, watch
// notifications, the builder, and multi-threaded access.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;
    use std::time::Duration;

    // Round-trips a single node through get (miss) -> set -> get (hit) -> clear.
    #[test]
    fn cache_basic_operations() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");

        assert!(cache.get_snapshot(node).is_none());
        assert_eq!(cache.snapshot_count(), 0);

        let snapshot = Snapshot::builder().version("v1").build();
        cache.set_snapshot(node, snapshot);

        assert!(cache.has_snapshot(node));
        assert_eq!(cache.snapshot_count(), 1);

        let retrieved = cache.get_snapshot(node).unwrap();
        assert_eq!(retrieved.version(), "v1");

        cache.clear_snapshot(node);
        assert!(!cache.has_snapshot(node));
        assert_eq!(cache.snapshot_count(), 0);
    }

    // Each counter is checked right after the single operation that should bump it.
    #[test]
    fn cache_stats_tracking() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");

        // Result deliberately ignored: only the miss counter matters here.
        cache.get_snapshot(node);
        assert_eq!(cache.stats().snapshot_misses(), 1);

        cache.set_snapshot(node, Snapshot::builder().version("v1").build());
        assert_eq!(cache.stats().snapshots_set(), 1);

        cache.get_snapshot(node);
        assert_eq!(cache.stats().snapshot_hits(), 1);
    }

    // A watch created before `set_snapshot` receives the snapshot that was set.
    #[tokio::test]
    async fn cache_watch_notification() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");

        let mut watch = cache.create_watch(node);

        cache.set_snapshot(node, Snapshot::builder().version("v1").build());

        let snapshot = watch.recv().await.unwrap();
        assert_eq!(snapshot.version(), "v1");
    }

    // Smoke test only: checks the builder produces an empty cache; it does not
    // verify that the requested capacity or buffer size were applied.
    #[test]
    fn cache_builder() {
        let cache = CacheBuilder::new()
            .capacity(128)
            .watch_buffer_size(32)
            .build();

        assert_eq!(cache.snapshot_count(), 0);
    }

    // 10 threads x 100 reads of a pre-populated node; every read must hit.
    #[test]
    fn cache_concurrent_reads() {
        let cache = Arc::new(ShardedCache::new());
        let node = NodeHash::from_id("test-node");

        cache.set_snapshot(node, Snapshot::builder().version("v1").build());

        let read_count = Arc::new(AtomicUsize::new(0));
        let mut handles = vec![];

        for _ in 0..10 {
            let cache = Arc::clone(&cache);
            let count = Arc::clone(&read_count);
            handles.push(thread::spawn(move || {
                for _ in 0..100 {
                    if cache.get_snapshot(node).is_some() {
                        count.fetch_add(1, Ordering::Relaxed);
                    }
                }
            }));
        }

        for handle in handles {
            handle.join().expect("Thread panicked");
        }

        // Relaxed is sufficient: `join` above synchronizes with every worker.
        assert_eq!(read_count.load(Ordering::Relaxed), 1000);
    }

    // 10 threads each write 100 distinct nodes ("node-{i}-{j}"), so no two
    // writes share a key and all 1000 must be present afterwards.
    #[test]
    fn cache_concurrent_writes() {
        let cache = Arc::new(ShardedCache::new());
        let mut handles = vec![];

        for i in 0..10 {
            let cache = Arc::clone(&cache);
            handles.push(thread::spawn(move || {
                for j in 0..100 {
                    let node = NodeHash::from_id(&format!("node-{}-{}", i, j));
                    cache
                        .set_snapshot(node, Snapshot::builder().version(format!("v{}", j)).build());
                }
            }));
        }

        for handle in handles {
            handle.join().expect("Thread panicked");
        }

        assert_eq!(cache.snapshot_count(), 1000);
    }

    // One writer repeatedly replaces a node's snapshot while five readers poll it.
    #[test]
    fn cache_concurrent_read_write() {
        let cache = Arc::new(ShardedCache::new());
        let node = NodeHash::from_id("contended-node");

        // Seeded before any thread starts so readers can never observe an empty slot.
        cache.set_snapshot(node, Snapshot::builder().version("v0").build());

        let reads = Arc::new(AtomicUsize::new(0));
        let writes = Arc::new(AtomicUsize::new(0));
        let mut handles = vec![];

        // Writer: 50 replacements of the same node; the sleep spreads them out
        // so they interleave with the readers.
        {
            let cache = Arc::clone(&cache);
            let writes = Arc::clone(&writes);
            handles.push(thread::spawn(move || {
                for i in 1..=50 {
                    cache
                        .set_snapshot(node, Snapshot::builder().version(format!("v{}", i)).build());
                    writes.fetch_add(1, Ordering::Relaxed);
                    thread::sleep(Duration::from_micros(100));
                }
            }));
        }

        // Readers: 5 x 100 polls.
        for _ in 0..5 {
            let cache = Arc::clone(&cache);
            let reads = Arc::clone(&reads);
            handles.push(thread::spawn(move || {
                for _ in 0..100 {
                    if cache.get_snapshot(node).is_some() {
                        reads.fetch_add(1, Ordering::Relaxed);
                    }
                    thread::sleep(Duration::from_micros(50));
                }
            }));
        }

        for handle in handles {
            handle.join().expect("Thread panicked");
        }

        assert_eq!(writes.load(Ordering::Relaxed), 50);
        // Every read hits because `set_snapshot` replaces the entry and nothing
        // in this test clears it.
        assert_eq!(reads.load(Ordering::Relaxed), 500);
    }

    // Fills 10_000 nodes, then spot-checks the first, last and two middle entries.
    #[test]
    fn cache_many_nodes() {
        let cache = ShardedCache::with_capacity(10000);

        for i in 0..10000 {
            let node = NodeHash::from_id(&format!("node-{}", i));
            cache.set_snapshot(node, Snapshot::builder().version(format!("v{}", i)).build());
        }

        assert_eq!(cache.snapshot_count(), 10000);

        for i in [0, 999, 5000, 9999] {
            let node = NodeHash::from_id(&format!("node-{}", i));
            let snap = cache.get_snapshot(node).unwrap();
            assert_eq!(snap.version(), format!("v{}", i));
        }
    }

    // Setting the same node twice replaces the snapshot, and both sets are counted.
    #[test]
    fn cache_snapshot_update() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");

        cache.set_snapshot(node, Snapshot::builder().version("v1").build());
        assert_eq!(cache.get_snapshot(node).unwrap().version(), "v1");

        cache.set_snapshot(node, Snapshot::builder().version("v2").build());
        assert_eq!(cache.get_snapshot(node).unwrap().version(), "v2");

        assert_eq!(cache.stats().snapshots_set(), 2);
    }

    // Two watches on the same node each receive the update independently.
    #[tokio::test]
    async fn cache_multiple_watches_same_node() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");

        let mut watch1 = cache.create_watch(node);
        let mut watch2 = cache.create_watch(node);

        cache.set_snapshot(node, Snapshot::builder().version("v1").build());

        let snap1 = watch1.recv().await.unwrap();
        let snap2 = watch2.recv().await.unwrap();
        assert_eq!(snap1.version(), "v1");
        assert_eq!(snap2.version(), "v1");
    }

    // All three sets happen before the first `recv`, so this relies on the
    // watch buffering at least three pending snapshots and delivering them in
    // the order they were set.
    #[tokio::test]
    async fn cache_watch_receives_updates() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");

        let mut watch = cache.create_watch(node);

        for i in 1..=3 {
            cache.set_snapshot(node, Snapshot::builder().version(format!("v{}", i)).build());
        }

        let snap1 = watch.recv().await.unwrap();
        assert_eq!(snap1.version(), "v1");

        let snap2 = watch.recv().await.unwrap();
        assert_eq!(snap2.version(), "v2");

        let snap3 = watch.recv().await.unwrap();
        assert_eq!(snap3.version(), "v3");
    }

    // Clearing an absent node must be a harmless no-op.
    #[test]
    fn cache_clear_nonexistent_node() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("nonexistent");

        cache.clear_snapshot(node);
        assert_eq!(cache.snapshot_count(), 0);
    }

    // The wildcard key is stored and retrieved like any other node.
    #[test]
    fn cache_wildcard_node() {
        let cache = ShardedCache::new();
        let wildcard = NodeHash::wildcard();

        cache.set_snapshot(wildcard, Snapshot::builder().version("v1").build());
        assert!(cache.has_snapshot(wildcard));

        let snap = cache.get_snapshot(wildcard).unwrap();
        assert_eq!(snap.version(), "v1");
    }

    // Sanity check that similar-looking IDs (including a rearranged one) hash apart.
    #[test]
    fn cache_node_hash_collision_unlikely() {
        let node1 = NodeHash::from_id("node-1");
        let node2 = NodeHash::from_id("node-2");
        let node3 = NodeHash::from_id("1-node");

        assert_ne!(node1, node2);
        assert_ne!(node2, node3);
        assert_ne!(node1, node3);
    }
}
520}