1#[cfg(feature = "metrics")]
2use std::sync::atomic::AtomicU64;
3use std::sync::atomic::{AtomicUsize, Ordering};
4
5use parking_lot::RwLock;
6
7use crate::erased::{Entry, ErasedKey, ErasedKeyRef};
8use crate::guard::Guard;
9#[cfg(feature = "metrics")]
10use crate::metrics::CacheMetrics;
11use crate::shard::Shard;
12use crate::traits::CacheKey;
13
/// A sharded, size-bounded in-memory cache for type-erased values.
///
/// Keys are erased to [`ErasedKey`] and routed to one of `shard_count`
/// shards by hash; each shard is protected by its own `RwLock`, so
/// operations on different shards do not contend.
pub struct Cache {
    /// Per-shard storage; a key's shard is `hash & (shard_count - 1)`.
    shards: Vec<RwLock<Shard>>,
    /// Approximate total bytes currently stored across all shards.
    /// Maintained with `Relaxed` atomics, so it may briefly lag shard state.
    current_size: AtomicUsize,
    /// Approximate number of live entries across all shards.
    entry_count: AtomicUsize,
    /// Number of shards; always a power of two (see `compute_shard_count`)
    /// so `get_shard` can mask the hash instead of taking a modulus.
    shard_count: usize,
    /// Configured total capacity in bytes. Only read by `metrics()`, hence
    /// the dead-code allowance when the "metrics" feature is disabled.
    #[cfg_attr(not(feature = "metrics"), allow(dead_code))]
    max_size_bytes: usize,
    /// Lookups that found an entry (`get` / `get_clone`).
    #[cfg(feature = "metrics")]
    hits: AtomicU64,
    /// Lookups that found nothing.
    #[cfg(feature = "metrics")]
    misses: AtomicU64,
    /// Insertions of previously absent keys.
    #[cfg(feature = "metrics")]
    inserts: AtomicU64,
    /// Insertions that replaced an existing entry.
    #[cfg(feature = "metrics")]
    updates: AtomicU64,
    /// Entries evicted by shards to stay within their size budget.
    #[cfg(feature = "metrics")]
    evictions: AtomicU64,
    /// Explicit removals via `remove()`.
    #[cfg(feature = "metrics")]
    removals: AtomicU64,
}
97
/// Smallest useful per-shard byte budget; capacities below this collapse to
/// a single shard.
const MIN_SHARD_SIZE: usize = 4096;

/// Default shard count for [`Cache::new`]; scaled down for small capacities.
const DEFAULT_SHARD_COUNT: usize = 64;

/// Picks the actual shard count for a cache of `capacity` bytes.
///
/// The result is always a power of two (so `get_shard` can mask the hash
/// instead of taking a modulus), is at least 1, and — fixed here — never
/// exceeds `capacity / MIN_SHARD_SIZE`, so every shard keeps a budget of at
/// least `MIN_SHARD_SIZE` bytes whenever the capacity allows it.
fn compute_shard_count(capacity: usize, desired_shards: usize) -> usize {
    // Hard cap so each shard gets at least MIN_SHARD_SIZE bytes.
    let max_shards = (capacity / MIN_SHARD_SIZE).max(1);

    let target = desired_shards.min(max_shards).max(1);
    let rounded = target.next_power_of_two();

    // BUG FIX: rounding up to a power of two can overshoot `max_shards`
    // (e.g. max_shards == 3 rounds 3 up to 4, leaving each shard below
    // MIN_SHARD_SIZE). When that happens, round DOWN to the previous power
    // of two instead. `rounded > max_shards >= 1` implies `rounded >= 2`,
    // so the halving never yields 0.
    if rounded <= max_shards {
        rounded
    } else {
        rounded / 2
    }
}
122
impl Cache {
    /// Creates a cache holding at most `max_size_bytes` bytes, with a shard
    /// count chosen automatically from the capacity (at most
    /// `DEFAULT_SHARD_COUNT`).
    pub fn new(max_size_bytes: usize) -> Self {
        let shard_count = compute_shard_count(max_size_bytes, DEFAULT_SHARD_COUNT);
        Self::with_shards_internal(max_size_bytes, shard_count)
    }

    /// Like [`Cache::new`], but with a caller-suggested shard count. The
    /// suggestion is a hint: `compute_shard_count` clamps it for the
    /// capacity and makes it a power of two.
    pub fn with_shards(max_size_bytes: usize, shard_count: usize) -> Self {
        let shard_count = compute_shard_count(max_size_bytes, shard_count);
        Self::with_shards_internal(max_size_bytes, shard_count)
    }

    /// Builds the cache once the final shard count is known. `shard_count`
    /// must already be a power of two (guaranteed by `compute_shard_count`)
    /// because `get_shard` masks with `shard_count - 1`.
    fn with_shards_internal(max_size_bytes: usize, shard_count: usize) -> Self {
        // The total budget is split evenly; integer division may drop up to
        // `shard_count - 1` bytes of configured capacity.
        let size_per_shard = max_size_bytes / shard_count;

        let shards = (0..shard_count).map(|_| RwLock::new(Shard::new(size_per_shard))).collect();

        Self {
            shards,
            current_size: AtomicUsize::new(0),
            entry_count: AtomicUsize::new(0),
            shard_count,
            max_size_bytes,
            #[cfg(feature = "metrics")]
            hits: AtomicU64::new(0),
            #[cfg(feature = "metrics")]
            misses: AtomicU64::new(0),
            #[cfg(feature = "metrics")]
            inserts: AtomicU64::new(0),
            #[cfg(feature = "metrics")]
            updates: AtomicU64::new(0),
            #[cfg(feature = "metrics")]
            evictions: AtomicU64::new(0),
            #[cfg(feature = "metrics")]
            removals: AtomicU64::new(0),
        }
    }

    /// Inserts `value` under `key`, returning the previously stored value
    /// for that key (if any, and if it downcasts to `K::Value`). The target
    /// shard may evict entries to stay within its budget; global
    /// size/count/metric counters are adjusted for both the insert and any
    /// evictions.
    pub fn insert<K: CacheKey>(&self, key: K, value: K::Value) -> Option<K::Value> {
        let erased_key = ErasedKey::new(&key);
        let policy = key.policy();
        let entry = Entry::new(value, policy);
        // Captured before `entry` is moved into the shard below.
        let entry_size = entry.size;

        let shard_lock = self.get_shard(erased_key.hash);

        let mut shard = shard_lock.write();

        // The shard reports the replaced entry (if the key existed) plus how
        // many entries/bytes it evicted to make room.
        let (old_entry, (num_evictions, evicted_size)) = shard.insert(erased_key, entry);

        if let Some(ref old) = old_entry {
            // Replacement: adjust global size by the delta only. Signed math
            // because the new entry may be smaller than the old one.
            let size_diff = entry_size as isize - old.size as isize;
            if size_diff > 0 {
                self.current_size.fetch_add(size_diff as usize, Ordering::Relaxed);
            } else {
                self.current_size.fetch_sub((-size_diff) as usize, Ordering::Relaxed);
            }
            #[cfg(feature = "metrics")]
            self.updates.fetch_add(1, Ordering::Relaxed);
        } else {
            // Fresh key: account the whole entry and one more element.
            self.current_size.fetch_add(entry_size, Ordering::Relaxed);
            self.entry_count.fetch_add(1, Ordering::Relaxed);
            #[cfg(feature = "metrics")]
            self.inserts.fetch_add(1, Ordering::Relaxed);
        }

        if num_evictions > 0 {
            self.entry_count.fetch_sub(num_evictions, Ordering::Relaxed);
            self.current_size.fetch_sub(evicted_size, Ordering::Relaxed);
            #[cfg(feature = "metrics")]
            self.evictions.fetch_add(num_evictions as u64, Ordering::Relaxed);
        }

        // Downcast the displaced entry back to the concrete value type;
        // `None` if the stored entry held a different type for this key.
        old_entry.and_then(|e| e.into_value::<K::Value>())
    }

    /// Looks up `key` and returns a [`Guard`] borrowing the cached value in
    /// place (no clone). The guard keeps the shard's read lock held, so hold
    /// it briefly — writers to that shard block until it drops.
    pub fn get<K: CacheKey>(&self, key: &K) -> Option<Guard<'_, K::Value>> {
        let key_ref = ErasedKeyRef::new(key); let shard_lock = self.get_shard(key_ref.hash);

        let shard = shard_lock.read();

        let Some(entry) = shard.get_ref(&key_ref) else {
            #[cfg(feature = "metrics")]
            self.misses.fetch_add(1, Ordering::Relaxed);
            return None;
        };

        #[cfg(feature = "metrics")]
        self.hits.fetch_add(1, Ordering::Relaxed);

        // NOTE(review): a failed downcast below still returns `None` even
        // though the lookup was already counted as a hit — confirm intended.
        let value_ref = entry.value_ref::<K::Value>()?;
        let value_ptr = value_ref as *const K::Value;

        // SAFETY: `value_ptr` points into an entry owned by this shard, and
        // the read guard `shard` is moved into the returned `Guard`, so the
        // shard (and thus the entry) cannot be mutated or dropped while the
        // guard is alive — assumes `Guard::new` retains the lock guard for
        // its whole lifetime (TODO confirm against `Guard`'s contract).
        unsafe { Some(Guard::new(shard, value_ptr)) }
    }

    /// Looks up `key` and returns a clone of the cached value. Unlike
    /// [`Cache::get`], the shard read lock is released before returning.
    pub fn get_clone<K: CacheKey>(&self, key: &K) -> Option<K::Value>
    where
        K::Value: Clone,
    {
        let key_ref = ErasedKeyRef::new(key); let shard_lock = self.get_shard(key_ref.hash);

        let shard = shard_lock.read();

        let Some(entry) = shard.get_ref(&key_ref) else {
            #[cfg(feature = "metrics")]
            self.misses.fetch_add(1, Ordering::Relaxed);
            return None;
        };

        #[cfg(feature = "metrics")]
        self.hits.fetch_add(1, Ordering::Relaxed);

        // `None` here means the stored entry's type did not match `K::Value`.
        entry.value_ref::<K::Value>().cloned()
    }

    /// Removes `key`, returning its value if it was present (and downcasts
    /// to `K::Value`). Size, count, and removal metrics are updated.
    pub fn remove<K: CacheKey>(&self, key: &K) -> Option<K::Value> {
        let erased_key = ErasedKey::new(key);
        let shard_lock = self.get_shard(erased_key.hash);

        let mut shard = shard_lock.write();
        let entry = shard.remove(&erased_key)?;

        self.current_size.fetch_sub(entry.size, Ordering::Relaxed);
        self.entry_count.fetch_sub(1, Ordering::Relaxed);

        #[cfg(feature = "metrics")]
        self.removals.fetch_add(1, Ordering::Relaxed);

        entry.into_value::<K::Value>()
    }

    /// Returns `true` if `key` is currently present. Does not touch the
    /// hit/miss metrics.
    pub fn contains<K: CacheKey>(&self, key: &K) -> bool {
        let key_ref = ErasedKeyRef::new(key); let shard_lock = self.get_shard(key_ref.hash);
        let shard = shard_lock.read();

        shard.get_ref(&key_ref).is_some()
    }

    /// Approximate total bytes currently cached (Relaxed load; may lag
    /// concurrent inserts/removals).
    pub fn size(&self) -> usize {
        self.current_size.load(Ordering::Relaxed)
    }

    /// Approximate number of cached entries (Relaxed load).
    pub fn len(&self) -> usize {
        self.entry_count.load(Ordering::Relaxed)
    }

    /// `true` when `len()` is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Removes every entry from every shard and resets all counters,
    /// including metrics. Shards are cleared one at a time, so concurrent
    /// readers may observe a partially cleared cache.
    pub fn clear(&self) {
        for shard_lock in &self.shards {
            let mut shard = shard_lock.write();
            shard.clear();
        }
        self.current_size.store(0, Ordering::Relaxed);
        self.entry_count.store(0, Ordering::Relaxed);

        #[cfg(feature = "metrics")]
        {
            self.hits.store(0, Ordering::Relaxed);
            self.misses.store(0, Ordering::Relaxed);
            self.inserts.store(0, Ordering::Relaxed);
            self.updates.store(0, Ordering::Relaxed);
            self.evictions.store(0, Ordering::Relaxed);
            self.removals.store(0, Ordering::Relaxed);
        }
    }

    /// Snapshot of the cache's counters. Each field is loaded independently
    /// with Relaxed ordering, so the snapshot is not atomic across fields.
    #[cfg(feature = "metrics")]
    pub fn metrics(&self) -> CacheMetrics {
        CacheMetrics {
            hits: self.hits.load(Ordering::Relaxed),
            misses: self.misses.load(Ordering::Relaxed),
            inserts: self.inserts.load(Ordering::Relaxed),
            updates: self.updates.load(Ordering::Relaxed),
            evictions: self.evictions.load(Ordering::Relaxed),
            removals: self.removals.load(Ordering::Relaxed),
            current_size_bytes: self.current_size.load(Ordering::Relaxed),
            capacity_bytes: self.max_size_bytes,
            entry_count: self.entry_count.load(Ordering::Relaxed),
        }
    }

    /// Maps a key hash to its shard. Relies on `shard_count` being a power
    /// of two so the mask is equivalent to `hash % shard_count`.
    fn get_shard(&self, hash: u64) -> &RwLock<Shard> {
        let index = (hash as usize) & (self.shard_count - 1);
        &self.shards[index]
    }
}
435
// SAFETY: NOTE(review): these manual impls assert that `Shard` (and the
// type-erased entries it stores) are safe to send/share across threads. If
// every field of `Cache` were automatically `Send + Sync`, these impls would
// be unnecessary — confirm why auto-derivation does not apply (presumably
// raw pointers or non-auto-trait internals in `Shard`/`Entry`) and document
// the invariant here; otherwise consider removing them so the compiler
// verifies thread-safety.
unsafe impl Send for Cache {}
unsafe impl Sync for Cache {}
439
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DeepSizeOf;

    /// Minimal key type for exercising the cache; hashes by its inner id.
    #[derive(Hash, Eq, PartialEq, Clone, Debug)]
    struct TestKey(u64);

    impl CacheKey for TestKey {
        type Value = TestValue;
    }

    /// Minimal value type whose deep size tracks its string payload.
    #[derive(Clone, Debug, PartialEq, DeepSizeOf)]
    struct TestValue {
        data: String,
    }

    #[test]
    fn test_compute_shard_count_scales_with_capacity() {
        // (capacity, desired shards, expected result)
        let cases = [
            (1024, 64, 1),
            (4095, 64, 1),
            (4096, 64, 1),
            (8192, 64, 2),
            (65536, 64, 16),
            (256 * 1024, 64, 64),
            (1024 * 1024, 64, 64),
            (8192, 128, 2),
            (1024 * 1024, 128, 128),
        ];
        for (capacity, desired, expected) in cases {
            assert_eq!(compute_shard_count(capacity, desired), expected);
        }
    }

    #[test]
    fn test_cache_insert_and_get() {
        let cache = Cache::new(1024);
        let k = TestKey(1);
        let v = TestValue {
            data: "hello".to_string(),
        };

        cache.insert(k.clone(), v.clone());

        assert_eq!(cache.get_clone(&k).expect("key should exist"), v);
    }

    #[test]
    fn test_cache_remove() {
        let cache = Cache::new(1024);
        let k = TestKey(1);
        let v = TestValue {
            data: "hello".to_string(),
        };

        cache.insert(k.clone(), v.clone());
        assert!(cache.contains(&k));

        assert_eq!(cache.remove(&k).expect("key should exist"), v);
        assert!(!cache.contains(&k));
    }

    #[test]
    fn test_cache_eviction() {
        // Capacity far below 15 * 50-byte payloads forces evictions.
        let cache = Cache::with_shards(1000, 4);

        for id in 0..15u64 {
            cache.insert(
                TestKey(id),
                TestValue {
                    data: "x".repeat(50),
                },
            );
        }

        assert!(cache.len() < 15, "Cache should have evicted some entries");
        assert!(cache.size() <= 1000, "Cache size should be <= 1000, got {}", cache.size());
    }

    #[test]
    fn test_cache_concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        let cache = Arc::new(Cache::new(10240));

        // Four writer threads, each inserting and re-reading 100 keys.
        let handles: Vec<_> = (0..4u64)
            .map(|t| {
                let cache = Arc::clone(&cache);
                thread::spawn(move || {
                    for i in 0..100 {
                        let key = TestKey(t * 100 + i);
                        let value = TestValue {
                            data: format!("value-{}", i),
                        };
                        cache.insert(key.clone(), value.clone());

                        // The entry may have been evicted; only compare when
                        // it is still present.
                        if let Some(found) = cache.get_clone(&key) {
                            assert_eq!(found, value);
                        }
                    }
                })
            })
            .collect();

        for handle in handles {
            handle.join().expect("thread should not panic");
        }

        assert!(!cache.is_empty());
    }

    #[test]
    fn test_cache_is_send_sync() {
        // Compile-time check that Cache can cross thread boundaries.
        fn require_send_sync<T: Send + Sync>() {}
        require_send_sync::<Cache>();
    }
}