1use crate::error::{ClusterError, Result};
7use crate::worker_pool::WorkerId;
8use dashmap::DashMap;
9use lru::LruCache;
10use parking_lot::RwLock;
11use serde::{Deserialize, Serialize};
12use std::collections::HashSet;
13use std::num::NonZeroUsize;
14use std::sync::Arc;
15use std::sync::atomic::{AtomicU64, Ordering};
16use std::time::{Duration, Instant};
17
/// Cheaply cloneable handle to a shared distributed cache.
///
/// All clones share the same inner state through an `Arc`, so any clone
/// observes puts, invalidations, and statistics from every other clone.
#[derive(Clone)]
pub struct DistributedCache {
    /// Shared cache state (LRU store, directory, invalidations, stats).
    inner: Arc<DistributedCacheInner>,
}
23
/// Shared state behind a [`DistributedCache`] handle.
struct DistributedCacheInner {
    /// Local LRU store of cached entries, guarded by a read/write lock.
    local_cache: Arc<RwLock<LruCache<CacheKey, CacheEntry>>>,

    /// For each key, the set of workers known to hold it.
    cache_directory: DashMap<CacheKey, HashSet<WorkerId>>,

    /// History of issued invalidations, keyed by the invalidated key.
    /// NOTE(review): only `clear` empties this map — it may grow without
    /// bound; confirm pruning happens elsewhere in the crate.
    invalidations: DashMap<CacheKey, InvalidationRecord>,

    /// Behavior configuration (sizes, TTL, compression, warming).
    config: CacheConfig,

    /// Atomic counters for hits, misses, evictions, etc.
    stats: Arc<CacheStatistics>,
}
40
/// Tunable parameters for a [`DistributedCache`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Capacity of the local LRU store; zero falls back to a default
    /// capacity at construction time.
    pub max_local_entries: usize,

    /// Maximum accepted payload size in bytes; larger puts are rejected.
    pub max_entry_size: usize,

    /// Whether payloads above `compression_threshold` are compressed.
    pub enable_compression: bool,

    /// Payloads strictly larger than this many bytes are compressed
    /// (when compression is enabled).
    pub compression_threshold: usize,

    /// Time-to-live measured from entry creation; expired entries are
    /// removed by `evict_expired`.
    pub entry_ttl: Duration,

    /// Coherency protocol selection.
    /// NOTE(review): stored but not consulted in this file — verify usage
    /// at the call sites.
    pub coherency_protocol: CoherencyProtocol,

    /// Whether `warm_cache` is allowed to pre-register keys.
    pub enable_warming: bool,

    /// Maximum number of keys a single `warm_cache` call processes.
    pub warming_prefetch_size: usize,
}
68
69impl Default for CacheConfig {
70 fn default() -> Self {
71 Self {
72 max_local_entries: 10000,
73 max_entry_size: 100 * 1024 * 1024, enable_compression: true,
75 compression_threshold: 1024, entry_ttl: Duration::from_secs(3600),
77 coherency_protocol: CoherencyProtocol::Invalidation,
78 enable_warming: true,
79 warming_prefetch_size: 100,
80 }
81 }
82}
83
/// Cache coherency strategy selector.
///
/// NOTE(review): neither variant is consulted in this file — the intended
/// semantics (invalidate-on-write vs. propagate-updates) should be
/// confirmed against the code that reads `CacheConfig::coherency_protocol`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CoherencyProtocol {
    /// Invalidation-based coherency.
    Invalidation,

    /// Update-based coherency.
    Update,
}
93
/// Namespaced cache key: entries are addressed by `(namespace, key)`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CacheKey {
    /// Logical grouping for the key (e.g. per subsystem or tenant).
    pub namespace: String,

    /// Key string, unique within its namespace.
    pub key: String,
}
103
104impl CacheKey {
105 pub fn new(namespace: String, key: String) -> Self {
107 Self { namespace, key }
108 }
109}
110
/// A single cached value plus its bookkeeping metadata.
#[derive(Debug, Clone)]
pub struct CacheEntry {
    /// Stored payload; compressed bytes when `compressed` is true.
    pub data: Vec<u8>,

    /// Whether `data` holds a compressed representation.
    pub compressed: bool,

    /// Entry version number (set to 1 on insert by `put`).
    pub version: u64,

    /// When the entry was inserted; TTL expiry is measured from here.
    pub created_at: Instant,

    /// When the entry was last read via `get`.
    pub last_accessed: Instant,

    /// Number of successful `get` calls served by this entry.
    pub access_count: u64,

    /// Original (pre-compression) payload size in bytes.
    pub size_bytes: usize,
}
135
/// Record of a single invalidation issued via `DistributedCache::invalidate`.
#[derive(Debug, Clone)]
pub struct InvalidationRecord {
    /// The invalidated key.
    pub key: CacheKey,

    /// Version supplied by the caller of `invalidate`.
    pub version: u64,

    /// When the invalidation was issued.
    pub timestamp: Instant,

    /// Workers that were registered as holding the key at invalidation time.
    pub workers: HashSet<WorkerId>,
}
151
/// Internal atomic counters; snapshotted into [`CacheStats`] by
/// `get_statistics`. All updates use relaxed ordering.
#[derive(Debug, Default)]
struct CacheStatistics {
    /// Successful local lookups.
    hits: AtomicU64,

    /// Local lookup misses.
    misses: AtomicU64,

    /// Entries removed by LRU pressure or TTL expiry.
    evictions: AtomicU64,

    /// Invalidations issued.
    invalidations: AtomicU64,

    /// Payloads compressed on insert.
    compressions: AtomicU64,

    /// Compressed payloads decompressed on read.
    decompressions: AtomicU64,

    /// Total bytes written into the cache (post-compression sizes).
    bytes_stored: AtomicU64,

    /// Total bytes avoided by compression (original minus compressed).
    bytes_saved: AtomicU64,
}
179
180impl DistributedCache {
181 const DEFAULT_CAPACITY: usize = 1000;
183
184 pub fn new(config: CacheConfig) -> Self {
186 let capacity = NonZeroUsize::new(config.max_local_entries)
189 .unwrap_or(NonZeroUsize::new(Self::DEFAULT_CAPACITY).unwrap_or(NonZeroUsize::MIN));
190
191 Self {
192 inner: Arc::new(DistributedCacheInner {
193 local_cache: Arc::new(RwLock::new(LruCache::new(capacity))),
194 cache_directory: DashMap::new(),
195 invalidations: DashMap::new(),
196 config,
197 stats: Arc::new(CacheStatistics::default()),
198 }),
199 }
200 }
201
202 pub fn with_defaults() -> Self {
204 Self::new(CacheConfig::default())
205 }
206
207 pub fn put(&self, key: CacheKey, data: Vec<u8>, worker_id: WorkerId) -> Result<()> {
209 if data.len() > self.inner.config.max_entry_size {
210 return Err(ClusterError::CacheError(
211 "Entry size exceeds maximum".to_string(),
212 ));
213 }
214
215 let original_size = data.len();
216
217 let (data, compressed) = if self.inner.config.enable_compression
219 && data.len() > self.inner.config.compression_threshold
220 {
221 match self.compress_data(&data) {
222 Ok(compressed_data) => {
223 let saved = original_size.saturating_sub(compressed_data.len());
224 self.inner
225 .stats
226 .bytes_saved
227 .fetch_add(saved as u64, Ordering::Relaxed);
228 self.inner
229 .stats
230 .compressions
231 .fetch_add(1, Ordering::Relaxed);
232 (compressed_data, true)
233 }
234 Err(_) => (data, false),
235 }
236 } else {
237 (data, false)
238 };
239
240 let entry = CacheEntry {
241 data: data.clone(),
242 compressed,
243 version: 1,
244 created_at: Instant::now(),
245 last_accessed: Instant::now(),
246 access_count: 0,
247 size_bytes: original_size,
248 };
249
250 let mut cache = self.inner.local_cache.write();
252 if let Some((evicted_key, _)) = cache.push(key.clone(), entry) {
253 self.inner.stats.evictions.fetch_add(1, Ordering::Relaxed);
254
255 self.inner.cache_directory.remove(&evicted_key);
257 }
258 drop(cache);
259
260 self.inner
262 .cache_directory
263 .entry(key)
264 .or_default()
265 .insert(worker_id);
266
267 self.inner
268 .stats
269 .bytes_stored
270 .fetch_add(data.len() as u64, Ordering::Relaxed);
271
272 Ok(())
273 }
274
275 pub fn get(&self, key: &CacheKey) -> Result<Option<Vec<u8>>> {
277 let mut cache = self.inner.local_cache.write();
278
279 if let Some(entry) = cache.get_mut(key) {
280 entry.last_accessed = Instant::now();
281 entry.access_count += 1;
282
283 self.inner.stats.hits.fetch_add(1, Ordering::Relaxed);
284
285 let data = if entry.compressed {
287 self.inner
288 .stats
289 .decompressions
290 .fetch_add(1, Ordering::Relaxed);
291 self.decompress_data(&entry.data)?
292 } else {
293 entry.data.clone()
294 };
295
296 Ok(Some(data))
297 } else {
298 self.inner.stats.misses.fetch_add(1, Ordering::Relaxed);
299 Ok(None)
300 }
301 }
302
303 pub fn remove(&self, key: &CacheKey, worker_id: WorkerId) -> Result<()> {
305 let mut cache = self.inner.local_cache.write();
307 cache.pop(key);
308 drop(cache);
309
310 if let Some(mut locations) = self.inner.cache_directory.get_mut(key) {
312 locations.remove(&worker_id);
313 if locations.is_empty() {
314 drop(locations);
315 self.inner.cache_directory.remove(key);
316 }
317 }
318
319 Ok(())
320 }
321
322 pub fn invalidate(&self, key: CacheKey, version: u64) -> Result<Vec<WorkerId>> {
324 let workers = self
326 .inner
327 .cache_directory
328 .get(&key)
329 .map(|locs| locs.iter().copied().collect::<Vec<_>>())
330 .unwrap_or_default();
331
332 let invalidation = InvalidationRecord {
334 key: key.clone(),
335 version,
336 timestamp: Instant::now(),
337 workers: workers.iter().copied().collect(),
338 };
339
340 self.inner.invalidations.insert(key.clone(), invalidation);
341
342 let mut cache = self.inner.local_cache.write();
344 cache.pop(&key);
345 drop(cache);
346
347 self.inner.cache_directory.remove(&key);
349
350 self.inner
351 .stats
352 .invalidations
353 .fetch_add(1, Ordering::Relaxed);
354
355 Ok(workers)
356 }
357
358 pub fn contains(&self, key: &CacheKey) -> bool {
360 self.inner.local_cache.write().contains(key)
361 }
362
363 pub fn get_locations(&self, key: &CacheKey) -> Vec<WorkerId> {
365 self.inner
366 .cache_directory
367 .get(key)
368 .map(|locs| locs.iter().copied().collect())
369 .unwrap_or_default()
370 }
371
372 pub fn warm_cache(&self, keys: Vec<CacheKey>, worker_id: WorkerId) -> Result<usize> {
374 if !self.inner.config.enable_warming {
375 return Ok(0);
376 }
377
378 let mut warmed = 0;
379
380 for key in keys
381 .into_iter()
382 .take(self.inner.config.warming_prefetch_size)
383 {
384 if self.contains(&key) {
386 continue;
387 }
388
389 self.inner
391 .cache_directory
392 .entry(key)
393 .or_default()
394 .insert(worker_id);
395
396 warmed += 1;
397 }
398
399 Ok(warmed)
400 }
401
402 fn compress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
404 oxiarc_zstd::encode_all(data, 3)
405 .map_err(|e| ClusterError::CacheError(format!("Compression error: {}", e)))
406 }
407
408 fn decompress_data(&self, data: &[u8]) -> Result<Vec<u8>> {
410 oxiarc_zstd::decode_all(data)
411 .map_err(|e| ClusterError::CacheError(format!("Decompression error: {}", e)))
412 }
413
414 pub fn evict_expired(&self) -> usize {
416 let mut cache = self.inner.local_cache.write();
417 let now = Instant::now();
418 let ttl = self.inner.config.entry_ttl;
419
420 let expired_keys: Vec<_> = cache
421 .iter()
422 .filter(|(_, entry)| now.duration_since(entry.created_at) > ttl)
423 .map(|(key, _)| key.clone())
424 .collect();
425
426 let count = expired_keys.len();
427
428 for key in expired_keys {
429 cache.pop(&key);
430 self.inner.cache_directory.remove(&key);
431 }
432
433 self.inner
434 .stats
435 .evictions
436 .fetch_add(count as u64, Ordering::Relaxed);
437
438 count
439 }
440
441 pub fn get_statistics(&self) -> CacheStats {
443 let hits = self.inner.stats.hits.load(Ordering::Relaxed);
444 let misses = self.inner.stats.misses.load(Ordering::Relaxed);
445
446 let total_requests = hits + misses;
447 let hit_rate = if total_requests > 0 {
448 hits as f64 / total_requests as f64
449 } else {
450 0.0
451 };
452
453 let bytes_stored = self.inner.stats.bytes_stored.load(Ordering::Relaxed);
454 let bytes_saved = self.inner.stats.bytes_saved.load(Ordering::Relaxed);
455
456 let compression_ratio = if bytes_stored > 0 {
457 1.0 - (bytes_saved as f64 / bytes_stored as f64)
458 } else {
459 1.0
460 };
461
462 CacheStats {
463 hits,
464 misses,
465 hit_rate,
466 evictions: self.inner.stats.evictions.load(Ordering::Relaxed),
467 invalidations: self.inner.stats.invalidations.load(Ordering::Relaxed),
468 compressions: self.inner.stats.compressions.load(Ordering::Relaxed),
469 decompressions: self.inner.stats.decompressions.load(Ordering::Relaxed),
470 bytes_stored,
471 bytes_saved,
472 compression_ratio,
473 total_entries: self.inner.local_cache.read().len(),
474 directory_entries: self.inner.cache_directory.len(),
475 }
476 }
477
478 pub fn clear(&self) {
480 self.inner.local_cache.write().clear();
481 self.inner.cache_directory.clear();
482 self.inner.invalidations.clear();
483 }
484}
485
/// Serializable snapshot of cache counters, produced by
/// `DistributedCache::get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    /// Successful local lookups.
    pub hits: u64,

    /// Local lookup misses.
    pub misses: u64,

    /// hits / (hits + misses); 0.0 when no requests have been made.
    pub hit_rate: f64,

    /// Entries removed by LRU pressure or TTL expiry.
    pub evictions: u64,

    /// Invalidations issued.
    pub invalidations: u64,

    /// Payloads compressed on insert.
    pub compressions: u64,

    /// Compressed payloads decompressed on read.
    pub decompressions: u64,

    /// Total bytes written into the cache (post-compression sizes).
    pub bytes_stored: u64,

    /// Total bytes avoided by compression.
    pub bytes_saved: u64,

    /// Compression effectiveness ratio (1.0 means no savings).
    pub compression_ratio: f64,

    /// Entries currently in the local LRU store.
    pub total_entries: usize,

    /// Keys currently tracked in the location directory.
    pub directory_entries: usize,
}
525
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;

    // Previously these tests discarded errors with `.ok()` and guarded the
    // assertions behind `if let Ok(Some(..))`, so a failing put/get or a
    // silent cache miss would PASS the test. They now unwrap explicitly so
    // any failure actually fails.

    #[test]
    fn test_cache_creation() {
        let cache = DistributedCache::with_defaults();
        let stats = cache.get_statistics();
        assert_eq!(stats.hits, 0);
    }

    #[test]
    fn test_cache_put_get() {
        let cache = DistributedCache::with_defaults();
        let worker_id = WorkerId::new();
        let key = CacheKey::new("test".to_string(), "key1".to_string());
        let data = vec![1, 2, 3, 4, 5];

        cache
            .put(key.clone(), data.clone(), worker_id)
            .expect("put should succeed");

        let retrieved = cache
            .get(&key)
            .expect("get should succeed")
            .expect("entry should be present");
        assert_eq!(retrieved, data);
    }

    #[test]
    fn test_cache_invalidation() {
        let cache = DistributedCache::with_defaults();
        let worker_id = WorkerId::new();
        let key = CacheKey::new("test".to_string(), "key1".to_string());
        let data = vec![1, 2, 3, 4, 5];

        cache
            .put(key.clone(), data, worker_id)
            .expect("put should succeed");
        assert!(cache.contains(&key));

        cache
            .invalidate(key.clone(), 2)
            .expect("invalidate should succeed");
        assert!(!cache.contains(&key));
    }

    #[test]
    fn test_cache_compression() {
        let config = CacheConfig {
            compression_threshold: 10,
            ..Default::default()
        };

        let cache = DistributedCache::new(config);
        let worker_id = WorkerId::new();
        let key = CacheKey::new("test".to_string(), "key1".to_string());
        // Highly repetitive payload well above the 10-byte threshold, so it
        // must take the compression path.
        let data = vec![1; 1000];
        cache
            .put(key.clone(), data.clone(), worker_id)
            .expect("put should succeed");

        let stats = cache.get_statistics();
        assert!(stats.compressions > 0);

        // The round-trip must decompress back to the original bytes.
        let retrieved = cache
            .get(&key)
            .expect("get should succeed")
            .expect("entry should be present");
        assert_eq!(retrieved, data);
    }

    #[test]
    fn test_cache_hit_rate() {
        let cache = DistributedCache::with_defaults();
        let worker_id = WorkerId::new();

        let key1 = CacheKey::new("test".to_string(), "key1".to_string());
        cache
            .put(key1.clone(), vec![1, 2, 3], worker_id)
            .expect("put should succeed");

        // One hit...
        assert!(cache.get(&key1).expect("get should succeed").is_some());

        // ...and one miss.
        let key2 = CacheKey::new("test".to_string(), "key2".to_string());
        assert!(cache.get(&key2).expect("get should succeed").is_none());

        let stats = cache.get_statistics();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 0.5).abs() < 0.01);
    }
}