// oxigdal_cluster/cache_coherency.rs

//! Distributed cache with coherency protocol.
//!
//! This module implements a distributed cache system with cache coherency,
//! distributed LRU eviction, cache warming, and compression.
5
6use crate::error::{ClusterError, Result};
7use crate::worker_pool::WorkerId;
8use dashmap::DashMap;
9use lru::LruCache;
10use parking_lot::RwLock;
11use serde::{Deserialize, Serialize};
12use std::collections::HashSet;
13use std::num::NonZeroUsize;
14use std::sync::Arc;
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::time::{Duration, Instant};
17
/// Distributed cache manager.
///
/// Cheap to clone: all state lives behind an `Arc`, so every clone shares
/// the same local cache, directory, and statistics.
#[derive(Clone)]
pub struct DistributedCache {
    /// Shared state; cloning the wrapper is a refcount bump only.
    inner: Arc<DistributedCacheInner>,
}
23
/// Shared state behind `DistributedCache`'s `Arc`.
struct DistributedCacheInner {
    /// Local cache (per node); LRU-evicting, guarded by a `parking_lot` RwLock.
    local_cache: Arc<RwLock<LruCache<CacheKey, CacheEntry>>>,

    /// Cache directory (key -> set of workers believed to hold a copy).
    cache_directory: DashMap<CacheKey, HashSet<WorkerId>>,

    /// Invalidation queue: the most recent invalidation record per key.
    invalidations: DashMap<CacheKey, InvalidationRecord>,

    /// Configuration (immutable after construction).
    config: CacheConfig,

    /// Statistics (lock-free atomic counters).
    stats: Arc<CacheStatistics>,
}
40
/// Cache configuration.
///
/// All limits apply to the local (per-node) cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Maximum local cache size (number of entries). A value of zero falls
    /// back to an internal default capacity.
    pub max_local_entries: usize,

    /// Maximum entry size in bytes (checked against the uncompressed size).
    pub max_entry_size: usize,

    /// Enable compression for entries larger than `compression_threshold`.
    pub enable_compression: bool,

    /// Compression threshold in bytes; entries at or below this size are
    /// stored uncompressed.
    pub compression_threshold: usize,

    /// Cache entry TTL, measured from the entry's creation time.
    pub entry_ttl: Duration,

    /// Coherency protocol used across nodes.
    pub coherency_protocol: CoherencyProtocol,

    /// Enable cache warming (prefetching keys into the directory).
    pub enable_warming: bool,

    /// Maximum number of keys processed per `warm_cache` call.
    pub warming_prefetch_size: usize,
}
68
69impl Default for CacheConfig {
70    fn default() -> Self {
71        Self {
72            max_local_entries: 10000,
73            max_entry_size: 100 * 1024 * 1024, // 100 MB
74            enable_compression: true,
75            compression_threshold: 1024, // 1 KB
76            entry_ttl: Duration::from_secs(3600),
77            coherency_protocol: CoherencyProtocol::Invalidation,
78            enable_warming: true,
79            warming_prefetch_size: 100,
80        }
81    }
82}
83
/// Cache coherency protocol.
///
/// NOTE(review): only the invalidation path is exercised in this module
/// (`DistributedCache::invalidate`); confirm update-based propagation is
/// implemented elsewhere before relying on `Update`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CoherencyProtocol {
    /// Invalidation-based (invalidate copies on write)
    Invalidation,

    /// Update-based (update all copies on write)
    Update,
}
93
/// Cache key: a `(namespace, key)` pair.
///
/// Both parts participate in equality and hashing, so the same key string
/// in different namespaces refers to different entries.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CacheKey {
    /// Namespace (logical grouping of keys)
    pub namespace: String,

    /// Key within the namespace
    pub key: String,
}
103
104impl CacheKey {
105    /// Create a new cache key.
106    pub fn new(namespace: String, key: String) -> Self {
107        Self { namespace, key }
108    }
109}
110
/// Cache entry.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// Entry payload; holds compressed bytes when `compressed` is true,
    /// otherwise the raw data.
    pub data: Vec<u8>,

    /// Whether `data` holds compressed bytes.
    pub compressed: bool,

    /// Entry version (set to 1 on insert).
    pub version: u64,

    /// Creation time; entries older than the configured TTL are evictable.
    pub created_at: Instant,

    /// Last accessed time (refreshed on every cache hit).
    pub last_accessed: Instant,

    /// Number of times this entry has been read.
    pub access_count: u64,

    /// Original (uncompressed) size in bytes.
    pub size_bytes: usize,
}
135
/// Record of a single invalidation, kept for coherency bookkeeping.
#[derive(Debug, Clone)]
pub struct InvalidationRecord {
    /// The invalidated cache key.
    pub key: CacheKey,

    /// Version supplied by the caller of `invalidate`.
    pub version: u64,

    /// When the invalidation was issued.
    pub timestamp: Instant,

    /// Workers that held a copy at invalidation time.
    pub workers: HashSet<WorkerId>,
}
151
/// Internal cache statistics.
///
/// All counters use relaxed ordering: they are monotonic tallies, not
/// synchronization points.
#[derive(Debug, Default)]
struct CacheStatistics {
    /// Cache hits (key found in the local cache).
    hits: AtomicU64,

    /// Cache misses (key absent from the local cache).
    misses: AtomicU64,

    /// Entries evicted (LRU overflow or TTL expiry).
    evictions: AtomicU64,

    /// Invalidations issued via the coherency protocol.
    invalidations: AtomicU64,

    /// Successful compressions performed on writes.
    compressions: AtomicU64,

    /// Decompressions performed on reads.
    decompressions: AtomicU64,

    /// Total bytes stored (post-compression sizes).
    bytes_stored: AtomicU64,

    /// Total bytes saved by compression (original minus compressed).
    bytes_saved: AtomicU64,
}
179
180impl DistributedCache {
181    /// Default cache capacity when configured value is zero.
182    const DEFAULT_CAPACITY: usize = 1000;
183
184    /// Create a new distributed cache.
185    pub fn new(config: CacheConfig) -> Self {
186        // NonZeroUsize::new returns None for 0, so we use a default capacity in that case
187        // Using MIN (which is 1) as the fallback ensures we always have a valid NonZeroUsize
188        let capacity = NonZeroUsize::new(config.max_local_entries)
189            .unwrap_or(NonZeroUsize::new(Self::DEFAULT_CAPACITY).unwrap_or(NonZeroUsize::MIN));
190
191        Self {
192            inner: Arc::new(DistributedCacheInner {
193                local_cache: Arc::new(RwLock::new(LruCache::new(capacity))),
194                cache_directory: DashMap::new(),
195                invalidations: DashMap::new(),
196                config,
197                stats: Arc::new(CacheStatistics::default()),
198            }),
199        }
200    }
201
202    /// Create with default configuration.
203    pub fn with_defaults() -> Self {
204        Self::new(CacheConfig::default())
205    }
206
207    /// Put entry in cache.
208    pub fn put(&self, key: CacheKey, data: Vec<u8>, worker_id: WorkerId) -> Result<()> {
209        if data.len() > self.inner.config.max_entry_size {
210            return Err(ClusterError::CacheError(
211                "Entry size exceeds maximum".to_string(),
212            ));
213        }
214
215        let original_size = data.len();
216
217        // Compress if enabled and above threshold
218        let (data, compressed) = if self.inner.config.enable_compression
219            && data.len() > self.inner.config.compression_threshold
220        {
221            match self.compress_data(&data) {
222                Ok(compressed_data) => {
223                    let saved = original_size.saturating_sub(compressed_data.len());
224                    self.inner
225                        .stats
226                        .bytes_saved
227                        .fetch_add(saved as u64, Ordering::Relaxed);
228                    self.inner
229                        .stats
230                        .compressions
231                        .fetch_add(1, Ordering::Relaxed);
232                    (compressed_data, true)
233                }
234                Err(_) => (data, false),
235            }
236        } else {
237            (data, false)
238        };
239
240        let entry = CacheEntry {
241            data: data.clone(),
242            compressed,
243            version: 1,
244            created_at: Instant::now(),
245            last_accessed: Instant::now(),
246            access_count: 0,
247            size_bytes: original_size,
248        };
249
250        // Store in local cache
251        let mut cache = self.inner.local_cache.write();
252        if let Some((evicted_key, _)) = cache.push(key.clone(), entry) {
253            self.inner.stats.evictions.fetch_add(1, Ordering::Relaxed);
254
255            // Remove from directory
256            self.inner.cache_directory.remove(&evicted_key);
257        }
258        drop(cache);
259
260        // Update directory
261        self.inner
262            .cache_directory
263            .entry(key)
264            .or_default()
265            .insert(worker_id);
266
267        self.inner
268            .stats
269            .bytes_stored
270            .fetch_add(data.len() as u64, Ordering::Relaxed);
271
272        Ok(())
273    }
274
275    /// Get entry from cache.
276    pub fn get(&self, key: &CacheKey) -> Result<Option<Vec<u8>>> {
277        let mut cache = self.inner.local_cache.write();
278
279        if let Some(entry) = cache.get_mut(key) {
280            entry.last_accessed = Instant::now();
281            entry.access_count += 1;
282
283            self.inner.stats.hits.fetch_add(1, Ordering::Relaxed);
284
285            // Decompress if needed
286            let data = if entry.compressed {
287                self.inner
288                    .stats
289                    .decompressions
290                    .fetch_add(1, Ordering::Relaxed);
291                self.decompress_data(&entry.data)?
292            } else {
293                entry.data.clone()
294            };
295
296            Ok(Some(data))
297        } else {
298            self.inner.stats.misses.fetch_add(1, Ordering::Relaxed);
299            Ok(None)
300        }
301    }
302
303    /// Remove entry from cache.
304    pub fn remove(&self, key: &CacheKey, worker_id: WorkerId) -> Result<()> {
305        // Remove from local cache
306        let mut cache = self.inner.local_cache.write();
307        cache.pop(key);
308        drop(cache);
309
310        // Update directory
311        if let Some(mut locations) = self.inner.cache_directory.get_mut(key) {
312            locations.remove(&worker_id);
313            if locations.is_empty() {
314                drop(locations);
315                self.inner.cache_directory.remove(key);
316            }
317        }
318
319        Ok(())
320    }
321
322    /// Invalidate cache entry (coherency protocol).
323    pub fn invalidate(&self, key: CacheKey, version: u64) -> Result<Vec<WorkerId>> {
324        // Get all workers with this entry
325        let workers = self
326            .inner
327            .cache_directory
328            .get(&key)
329            .map(|locs| locs.iter().copied().collect::<Vec<_>>())
330            .unwrap_or_default();
331
332        // Record invalidation
333        let invalidation = InvalidationRecord {
334            key: key.clone(),
335            version,
336            timestamp: Instant::now(),
337            workers: workers.iter().copied().collect(),
338        };
339
340        self.inner.invalidations.insert(key.clone(), invalidation);
341
342        // Remove from local cache
343        let mut cache = self.inner.local_cache.write();
344        cache.pop(&key);
345        drop(cache);
346
347        // Clear directory
348        self.inner.cache_directory.remove(&key);
349
350        self.inner
351            .stats
352            .invalidations
353            .fetch_add(1, Ordering::Relaxed);
354
355        Ok(workers)
356    }
357
358    /// Check if entry exists in cache.
359    pub fn contains(&self, key: &CacheKey) -> bool {
360        self.inner.local_cache.write().contains(key)
361    }
362
363    /// Get cache entry locations.
364    pub fn get_locations(&self, key: &CacheKey) -> Vec<WorkerId> {
365        self.inner
366            .cache_directory
367            .get(key)
368            .map(|locs| locs.iter().copied().collect())
369            .unwrap_or_default()
370    }
371
372    /// Warm cache with entries.
373    pub fn warm_cache(&self, keys: Vec<CacheKey>, worker_id: WorkerId) -> Result<usize> {
374        if !self.inner.config.enable_warming {
375            return Ok(0);
376        }
377
378        let mut warmed = 0;
379
380        for key in keys
381            .into_iter()
382            .take(self.inner.config.warming_prefetch_size)
383        {
384            // Check if already in cache
385            if self.contains(&key) {
386                continue;
387            }
388
389            // Mark as available on this worker (would need to fetch in real impl)
390            self.inner
391                .cache_directory
392                .entry(key)
393                .or_default()
394                .insert(worker_id);
395
396            warmed += 1;
397        }
398
399        Ok(warmed)
400    }
401
402    /// Compress data using zstd.
403    fn compress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
404        oxiarc_zstd::encode_all(data, 3)
405            .map_err(|e| ClusterError::CacheError(format!("Compression error: {}", e)))
406    }
407
408    /// Decompress data using zstd.
409    fn decompress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
410        oxiarc_zstd::decode_all(data)
411            .map_err(|e| ClusterError::CacheError(format!("Decompression error: {}", e)))
412    }
413
414    /// Evict expired entries.
415    pub fn evict_expired(&self) -> usize {
416        let mut cache = self.inner.local_cache.write();
417        let now = Instant::now();
418        let ttl = self.inner.config.entry_ttl;
419
420        let expired_keys: Vec<_> = cache
421            .iter()
422            .filter(|(_, entry)| now.duration_since(entry.created_at) > ttl)
423            .map(|(key, _)| key.clone())
424            .collect();
425
426        let count = expired_keys.len();
427
428        for key in expired_keys {
429            cache.pop(&key);
430            self.inner.cache_directory.remove(&key);
431        }
432
433        self.inner
434            .stats
435            .evictions
436            .fetch_add(count as u64, Ordering::Relaxed);
437
438        count
439    }
440
441    /// Get cache statistics.
442    pub fn get_statistics(&self) -> CacheStats {
443        let hits = self.inner.stats.hits.load(Ordering::Relaxed);
444        let misses = self.inner.stats.misses.load(Ordering::Relaxed);
445
446        let total_requests = hits + misses;
447        let hit_rate = if total_requests > 0 {
448            hits as f64 / total_requests as f64
449        } else {
450            0.0
451        };
452
453        let bytes_stored = self.inner.stats.bytes_stored.load(Ordering::Relaxed);
454        let bytes_saved = self.inner.stats.bytes_saved.load(Ordering::Relaxed);
455
456        let compression_ratio = if bytes_stored > 0 {
457            1.0 - (bytes_saved as f64 / bytes_stored as f64)
458        } else {
459            1.0
460        };
461
462        CacheStats {
463            hits,
464            misses,
465            hit_rate,
466            evictions: self.inner.stats.evictions.load(Ordering::Relaxed),
467            invalidations: self.inner.stats.invalidations.load(Ordering::Relaxed),
468            compressions: self.inner.stats.compressions.load(Ordering::Relaxed),
469            decompressions: self.inner.stats.decompressions.load(Ordering::Relaxed),
470            bytes_stored,
471            bytes_saved,
472            compression_ratio,
473            total_entries: self.inner.local_cache.read().len(),
474            directory_entries: self.inner.cache_directory.len(),
475        }
476    }
477
478    /// Clear cache.
479    pub fn clear(&self) {
480        self.inner.local_cache.write().clear();
481        self.inner.cache_directory.clear();
482        self.inner.invalidations.clear();
483    }
484}
485
/// Cache statistics snapshot.
///
/// Produced by `DistributedCache::get_statistics`; values are point-in-time
/// reads of relaxed atomic counters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    /// Cache hits
    pub hits: u64,

    /// Cache misses
    pub misses: u64,

    /// Hit rate (0.0-1.0); 0.0 when no requests have been made
    pub hit_rate: f64,

    /// Evictions (LRU overflow or TTL expiry)
    pub evictions: u64,

    /// Invalidations issued
    pub invalidations: u64,

    /// Compressions performed
    pub compressions: u64,

    /// Decompressions performed
    pub decompressions: u64,

    /// Bytes stored (compressed)
    pub bytes_stored: u64,

    /// Bytes saved by compression (original minus compressed)
    pub bytes_saved: u64,

    /// Compression ratio (1.0 means no savings from compression)
    pub compression_ratio: f64,

    /// Total entries in local cache
    pub total_entries: usize,

    /// Total entries in directory
    pub directory_entries: usize,
}
525
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_creation() {
        let cache = DistributedCache::with_defaults();
        let stats = cache.get_statistics();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.total_entries, 0);
    }

    #[test]
    fn test_cache_put_get() {
        let cache = DistributedCache::with_defaults();
        let worker_id = WorkerId::new();
        let key = CacheKey::new("test".to_string(), "key1".to_string());
        let data = vec![1, 2, 3, 4, 5];

        // The old `.ok()` discarded a failed put; fail loudly instead.
        cache
            .put(key.clone(), data.clone(), worker_id)
            .expect("put should succeed");

        // The old `if let Ok(Some(..))` silently passed on `Ok(None)`;
        // demand an actual hit with the exact payload.
        let retrieved = cache.get(&key).expect("get should succeed");
        assert_eq!(retrieved, Some(data));
    }

    #[test]
    fn test_cache_invalidation() {
        let cache = DistributedCache::with_defaults();
        let worker_id = WorkerId::new();
        let key = CacheKey::new("test".to_string(), "key1".to_string());

        cache
            .put(key.clone(), vec![1, 2, 3, 4, 5], worker_id)
            .expect("put should succeed");
        assert!(cache.contains(&key));

        let notified = cache.invalidate(key.clone(), 2).expect("invalidate should succeed");
        assert!(notified.contains(&worker_id));
        assert!(!cache.contains(&key));
    }

    #[test]
    fn test_cache_compression() {
        let config = CacheConfig {
            compression_threshold: 10,
            ..Default::default()
        };

        let cache = DistributedCache::new(config);
        let worker_id = WorkerId::new();
        let key = CacheKey::new("test".to_string(), "key1".to_string());
        let data = vec![1; 1000]; // Repeated data compresses well

        cache
            .put(key.clone(), data.clone(), worker_id)
            .expect("put should succeed");

        let stats = cache.get_statistics();
        assert!(stats.compressions > 0);
        assert!(stats.bytes_saved > 0);

        // Round trip must transparently decompress back to the original.
        let retrieved = cache.get(&key).expect("get should succeed");
        assert_eq!(retrieved, Some(data));
    }

    #[test]
    fn test_cache_hit_rate() {
        let cache = DistributedCache::with_defaults();
        let worker_id = WorkerId::new();

        let key1 = CacheKey::new("test".to_string(), "key1".to_string());
        cache
            .put(key1.clone(), vec![1, 2, 3], worker_id)
            .expect("put should succeed");

        // Hit
        assert!(cache.get(&key1).expect("get should succeed").is_some());

        // Miss
        let key2 = CacheKey::new("test".to_string(), "key2".to_string());
        assert!(cache.get(&key2).expect("get should succeed").is_none());

        let stats = cache.get_statistics();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 0.5).abs() < 0.01);
    }
}