use std::sync::Arc;
use dashmap::DashMap;
use tracing::{debug, trace};
use xds_core::NodeHash;
use crate::snapshot::Snapshot;
use crate::stats::CacheStats;
use crate::watch::WatchManager;
/// Storage abstraction for per-node configuration snapshots.
///
/// The `Send + Sync` supertraits allow one cache instance to be shared
/// between threads (for example behind an `Arc`).
pub trait Cache: Send + Sync {
    /// Returns the snapshot currently stored for `node`, if there is one.
    fn get_snapshot(&self, node: NodeHash) -> Option<Arc<Snapshot>>;

    /// Stores `snapshot` as the current snapshot for `node`.
    fn set_snapshot(&self, node: NodeHash, snapshot: Snapshot);

    /// Removes whatever snapshot is stored for `node`, if any.
    fn clear_snapshot(&self, node: NodeHash);

    /// Returns how many nodes currently have a stored snapshot.
    fn snapshot_count(&self) -> usize;
}
/// Concurrent [`Cache`] implementation backed by a sharded `DashMap`,
/// with watch notifications on updates and lookup/update statistics.
#[derive(Debug)]
pub struct ShardedCache {
    /// Latest snapshot per node; stored behind `Arc` so readers receive cheap clones.
    snapshots: DashMap<NodeHash, Arc<Snapshot>>,
    /// Watch registry that is notified whenever a node's snapshot is set.
    watches: WatchManager,
    /// Hit / miss / set / clear counters.
    stats: CacheStats,
}
impl Default for ShardedCache {
fn default() -> Self {
Self::new()
}
}
impl ShardedCache {
pub fn new() -> Self {
Self::with_capacity(64)
}
pub fn with_capacity(capacity: usize) -> Self {
Self {
snapshots: DashMap::with_capacity(capacity),
watches: WatchManager::new(),
stats: CacheStats::new(),
}
}
#[inline]
pub fn watches(&self) -> &WatchManager {
&self.watches
}
#[inline]
pub fn stats(&self) -> &CacheStats {
&self.stats
}
#[inline]
pub fn create_watch(&self, node: NodeHash) -> crate::watch::Watch {
self.watches.create_watch(node)
}
#[inline]
pub fn cancel_watch(&self, watch_id: crate::watch::WatchId) {
self.watches.cancel_watch(watch_id)
}
pub fn nodes(&self) -> Vec<NodeHash> {
self.snapshots.iter().map(|r| *r.key()).collect()
}
pub fn has_snapshot(&self, node: NodeHash) -> bool {
self.snapshots.contains_key(&node)
}
pub fn iter(&self) -> impl Iterator<Item = (NodeHash, Arc<Snapshot>)> + '_ {
self.snapshots
.iter()
.map(|r| (*r.key(), Arc::clone(r.value())))
}
}
impl Cache for ShardedCache {
    /// Looks up `node` and records the outcome as a hit or a miss.
    ///
    /// The map guard is dropped inside the lookup expression, before stats
    /// are updated and the trace event is emitted.
    fn get_snapshot(&self, node: NodeHash) -> Option<Arc<Snapshot>> {
        let cached = self
            .snapshots
            .get(&node)
            .map(|entry| Arc::clone(entry.value()));
        match cached {
            Some(_) => {
                self.stats.record_hit();
                trace!(node = %node, "cache hit");
            }
            None => {
                self.stats.record_miss();
                trace!(node = %node, "cache miss");
            }
        }
        cached
    }

    /// Stores `snapshot` for `node`, replacing any previous one, then
    /// notifies the watch registry for that node.
    ///
    /// The snapshot is inserted into the map before watchers are notified.
    fn set_snapshot(&self, node: NodeHash, snapshot: Snapshot) {
        let shared = Arc::new(snapshot);
        self.snapshots.insert(node, Arc::clone(&shared));
        self.stats.record_set();
        debug!(
            node = %node,
            version = %shared.version(),
            resources = shared.total_resources(),
            "set snapshot"
        );
        // NOTE(review): insert and notify are separate steps, so two concurrent
        // writers for the same node may notify in a different order than their
        // inserts landed — confirm whether watchers rely on ordering.
        self.watches.notify(node, shared);
    }

    /// Removes the snapshot for `node`.
    ///
    /// Stats and logging only happen when an entry was actually removed.
    /// Watchers are not notified of the removal.
    fn clear_snapshot(&self, node: NodeHash) {
        let removed = self.snapshots.remove(&node).is_some();
        if removed {
            self.stats.record_clear();
            debug!(node = %node, "cleared snapshot");
        }
    }

    /// Number of nodes that currently hold a snapshot.
    fn snapshot_count(&self) -> usize {
        self.snapshots.len()
    }
}
/// Builder for a [`ShardedCache`] with a custom map capacity and/or watch
/// buffer size.
///
/// Options left unset fall back to defaults when the cache is built.
#[derive(Debug, Default)]
// NOTE(review): `dead_code` is allowed here without a stated reason —
// presumably because nothing in the crate uses the builder yet; confirm.
#[allow(dead_code)]
pub struct CacheBuilder {
    /// Requested initial capacity of the snapshot map (`None` = default).
    capacity: Option<usize>,
    /// Requested watch buffer size (`None` = default).
    watch_buffer_size: Option<usize>,
}
// NOTE(review): `dead_code` is allowed without a stated reason — presumably
// the builder is not used inside the crate yet; confirm before removing.
#[allow(dead_code)]
impl CacheBuilder {
    /// Creates a builder with every option unset.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the initial capacity of the snapshot map (default: 64).
    #[must_use]
    pub fn capacity(mut self, capacity: usize) -> Self {
        self.capacity = Some(capacity);
        self
    }

    /// Sets the buffer size passed to `WatchManager::with_buffer_size`.
    ///
    /// When left unset, the built cache uses `WatchManager::new()`. This is
    /// the same watch configuration that [`ShardedCache::new`] and
    /// [`ShardedCache::with_capacity`] use.
    #[must_use]
    pub fn watch_buffer_size(mut self, size: usize) -> Self {
        self.watch_buffer_size = Some(size);
        self
    }

    /// Consumes the builder and constructs the configured [`ShardedCache`].
    ///
    /// A builder with no options set yields the same configuration as
    /// [`ShardedCache::new`].
    #[must_use]
    pub fn build(self) -> ShardedCache {
        let capacity = self.capacity.unwrap_or(64);
        // Defer to `WatchManager::new()` when no size was requested. This keeps
        // the builder's default from drifting away from the one
        // `ShardedCache::with_capacity` uses (previously hard-coded to 16 here).
        let watches = match self.watch_buffer_size {
            Some(size) => WatchManager::with_buffer_size(size),
            None => WatchManager::new(),
        };
        ShardedCache {
            snapshots: DashMap::with_capacity(capacity),
            watches,
            stats: CacheStats::new(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;
    use std::time::Duration;

    /// Set, get, and clear round-trip on a single node.
    #[test]
    fn cache_basic_operations() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");
        assert!(cache.get_snapshot(node).is_none());
        assert_eq!(cache.snapshot_count(), 0);
        let snapshot = Snapshot::builder().version("v1").build();
        cache.set_snapshot(node, snapshot);
        assert!(cache.has_snapshot(node));
        assert_eq!(cache.snapshot_count(), 1);
        let retrieved = cache.get_snapshot(node).unwrap();
        assert_eq!(retrieved.version(), "v1");
        cache.clear_snapshot(node);
        assert!(!cache.has_snapshot(node));
        assert_eq!(cache.snapshot_count(), 0);
    }

    /// Misses, sets, and hits are counted by `get_snapshot` / `set_snapshot`.
    #[test]
    fn cache_stats_tracking() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");
        cache.get_snapshot(node);
        assert_eq!(cache.stats().snapshot_misses(), 1);
        cache.set_snapshot(node, Snapshot::builder().version("v1").build());
        assert_eq!(cache.stats().snapshots_set(), 1);
        cache.get_snapshot(node);
        assert_eq!(cache.stats().snapshot_hits(), 1);
    }

    /// A watch created before a set receives the new snapshot.
    #[tokio::test]
    async fn cache_watch_notification() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");
        let mut watch = cache.create_watch(node);
        cache.set_snapshot(node, Snapshot::builder().version("v1").build());
        let snapshot = watch.recv().await.unwrap();
        assert_eq!(snapshot.version(), "v1");
    }

    /// A builder with explicit options produces an empty cache.
    #[test]
    fn cache_builder() {
        let cache = CacheBuilder::new()
            .capacity(128)
            .watch_buffer_size(32)
            .build();
        assert_eq!(cache.snapshot_count(), 0);
    }

    /// A builder with no options yields a cache whose watches still deliver.
    #[tokio::test]
    async fn cache_builder_defaults_deliver_watch_updates() {
        let cache = CacheBuilder::new().build();
        let node = NodeHash::from_id("test-node");
        let mut watch = cache.create_watch(node);
        cache.set_snapshot(node, Snapshot::builder().version("v1").build());
        let snapshot = watch.recv().await.unwrap();
        assert_eq!(snapshot.version(), "v1");
    }

    /// Many threads reading the same node all observe the snapshot.
    #[test]
    fn cache_concurrent_reads() {
        let cache = Arc::new(ShardedCache::new());
        let node = NodeHash::from_id("test-node");
        cache.set_snapshot(node, Snapshot::builder().version("v1").build());
        let read_count = Arc::new(AtomicUsize::new(0));
        let mut handles = vec![];
        for _ in 0..10 {
            let cache = Arc::clone(&cache);
            let count = Arc::clone(&read_count);
            handles.push(thread::spawn(move || {
                for _ in 0..100 {
                    if cache.get_snapshot(node).is_some() {
                        count.fetch_add(1, Ordering::Relaxed);
                    }
                }
            }));
        }
        for handle in handles {
            handle.join().expect("Thread panicked");
        }
        assert_eq!(read_count.load(Ordering::Relaxed), 1000);
    }

    /// Many threads writing distinct nodes all land in the cache.
    #[test]
    fn cache_concurrent_writes() {
        let cache = Arc::new(ShardedCache::new());
        let mut handles = vec![];
        for i in 0..10 {
            let cache = Arc::clone(&cache);
            handles.push(thread::spawn(move || {
                for j in 0..100 {
                    let node = NodeHash::from_id(&format!("node-{}-{}", i, j));
                    cache.set_snapshot(
                        node,
                        Snapshot::builder().version(format!("v{}", j)).build(),
                    );
                }
            }));
        }
        for handle in handles {
            handle.join().expect("Thread panicked");
        }
        assert_eq!(cache.snapshot_count(), 1000);
    }

    /// Readers never see a missing snapshot while a writer keeps replacing it.
    #[test]
    fn cache_concurrent_read_write() {
        let cache = Arc::new(ShardedCache::new());
        let node = NodeHash::from_id("contended-node");
        cache.set_snapshot(node, Snapshot::builder().version("v0").build());
        let reads = Arc::new(AtomicUsize::new(0));
        let writes = Arc::new(AtomicUsize::new(0));
        let mut handles = vec![];
        {
            let cache = Arc::clone(&cache);
            let writes = Arc::clone(&writes);
            handles.push(thread::spawn(move || {
                for i in 1..=50 {
                    cache.set_snapshot(
                        node,
                        Snapshot::builder().version(format!("v{}", i)).build(),
                    );
                    writes.fetch_add(1, Ordering::Relaxed);
                    thread::sleep(Duration::from_micros(100));
                }
            }));
        }
        for _ in 0..5 {
            let cache = Arc::clone(&cache);
            let reads = Arc::clone(&reads);
            handles.push(thread::spawn(move || {
                for _ in 0..100 {
                    if cache.get_snapshot(node).is_some() {
                        reads.fetch_add(1, Ordering::Relaxed);
                    }
                    thread::sleep(Duration::from_micros(50));
                }
            }));
        }
        for handle in handles {
            handle.join().expect("Thread panicked");
        }
        assert_eq!(writes.load(Ordering::Relaxed), 50);
        assert_eq!(reads.load(Ordering::Relaxed), 500);
    }

    /// Large node counts are stored and retrievable by id.
    #[test]
    fn cache_many_nodes() {
        let cache = ShardedCache::with_capacity(10000);
        for i in 0..10000 {
            let node = NodeHash::from_id(&format!("node-{}", i));
            cache.set_snapshot(node, Snapshot::builder().version(format!("v{}", i)).build());
        }
        assert_eq!(cache.snapshot_count(), 10000);
        for i in [0, 999, 5000, 9999] {
            let node = NodeHash::from_id(&format!("node-{}", i));
            let snap = cache.get_snapshot(node).unwrap();
            assert_eq!(snap.version(), format!("v{}", i));
        }
    }

    /// Setting a node twice replaces the stored snapshot and counts both sets.
    #[test]
    fn cache_snapshot_update() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");
        cache.set_snapshot(node, Snapshot::builder().version("v1").build());
        assert_eq!(cache.get_snapshot(node).unwrap().version(), "v1");
        cache.set_snapshot(node, Snapshot::builder().version("v2").build());
        assert_eq!(cache.get_snapshot(node).unwrap().version(), "v2");
        assert_eq!(cache.stats().snapshots_set(), 2);
    }

    /// Every watch on a node receives the same update.
    #[tokio::test]
    async fn cache_multiple_watches_same_node() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");
        let mut watch1 = cache.create_watch(node);
        let mut watch2 = cache.create_watch(node);
        cache.set_snapshot(node, Snapshot::builder().version("v1").build());
        let snap1 = watch1.recv().await.unwrap();
        let snap2 = watch2.recv().await.unwrap();
        assert_eq!(snap1.version(), "v1");
        assert_eq!(snap2.version(), "v1");
    }

    /// A watch receives successive updates in order.
    #[tokio::test]
    async fn cache_watch_receives_updates() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");
        let mut watch = cache.create_watch(node);
        for i in 1..=3 {
            cache.set_snapshot(node, Snapshot::builder().version(format!("v{}", i)).build());
        }
        let snap1 = watch.recv().await.unwrap();
        assert_eq!(snap1.version(), "v1");
        let snap2 = watch.recv().await.unwrap();
        assert_eq!(snap2.version(), "v2");
        let snap3 = watch.recv().await.unwrap();
        assert_eq!(snap3.version(), "v3");
    }

    /// Clearing an unknown node is a harmless no-op.
    #[test]
    fn cache_clear_nonexistent_node() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("nonexistent");
        cache.clear_snapshot(node);
        assert_eq!(cache.snapshot_count(), 0);
    }

    /// A cleared node reads back as a miss and disappears from `nodes()`.
    #[test]
    fn cache_clear_removes_from_lookups() {
        let cache = ShardedCache::new();
        let node = NodeHash::from_id("test-node");
        cache.set_snapshot(node, Snapshot::builder().version("v1").build());
        cache.clear_snapshot(node);
        assert!(cache.get_snapshot(node).is_none());
        assert!(cache.nodes().is_empty());
        assert_eq!(cache.stats().snapshot_misses(), 1);
    }

    /// `nodes()` and `iter()` report exactly the stored entries.
    #[test]
    fn cache_nodes_and_iter_report_all_entries() {
        let cache = ShardedCache::new();
        let ids = ["alpha", "beta", "gamma"];
        for id in ids {
            cache.set_snapshot(NodeHash::from_id(id), Snapshot::builder().version(id).build());
        }

        let expected: HashSet<NodeHash> = ids.iter().map(|&id| NodeHash::from_id(id)).collect();
        let listed: HashSet<NodeHash> = cache.nodes().into_iter().collect();
        assert_eq!(listed, expected);

        let mut seen = 0;
        // Nothing mutates the cache inside this loop, so holding the
        // iterator's map guards here is safe.
        for (node, snapshot) in cache.iter() {
            let id = ids
                .iter()
                .copied()
                .find(|&id| NodeHash::from_id(id) == node)
                .expect("iter yielded a node that was never set");
            assert_eq!(snapshot.version(), id);
            seen += 1;
        }
        assert_eq!(seen, ids.len());
    }

    /// The wildcard node hash can be used as a regular key.
    #[test]
    fn cache_wildcard_node() {
        let cache = ShardedCache::new();
        let wildcard = NodeHash::wildcard();
        cache.set_snapshot(wildcard, Snapshot::builder().version("v1").build());
        assert!(cache.has_snapshot(wildcard));
        let snap = cache.get_snapshot(wildcard).unwrap();
        assert_eq!(snap.version(), "v1");
    }

    /// Similar node ids hash to distinct keys.
    #[test]
    fn cache_node_hash_collision_unlikely() {
        let node1 = NodeHash::from_id("node-1");
        let node2 = NodeHash::from_id("node-2");
        let node3 = NodeHash::from_id("1-node");
        assert_ne!(node1, node2);
        assert_ne!(node2, node3);
        assert_ne!(node1, node3);
    }
}