1#[cfg(feature = "metrics")]
2use std::sync::atomic::AtomicU64;
3use std::sync::atomic::{AtomicUsize, Ordering};
4
5use parking_lot::RwLock;
6
7use crate::erased::{Entry, ErasedKey, ErasedKeyLookup, ErasedKeyRef};
8use crate::guard::Guard;
9#[cfg(feature = "metrics")]
10use crate::metrics::CacheMetrics;
11use crate::shard::Shard;
12use crate::traits::{CacheKey, CacheKeyLookup};
13
14pub struct Cache {
67 shards: Vec<RwLock<Shard>>,
69 current_size: AtomicUsize,
71 entry_count: AtomicUsize,
73 shard_count: usize,
75 #[cfg_attr(not(feature = "metrics"), allow(dead_code))]
77 max_size_bytes: usize,
78 #[cfg(feature = "metrics")]
80 hits: AtomicU64,
81 #[cfg(feature = "metrics")]
83 misses: AtomicU64,
84 #[cfg(feature = "metrics")]
86 inserts: AtomicU64,
87 #[cfg(feature = "metrics")]
89 updates: AtomicU64,
90 #[cfg(feature = "metrics")]
92 evictions: AtomicU64,
93 #[cfg(feature = "metrics")]
95 removals: AtomicU64,
96}
97
/// Per-shard byte budget (4 KiB) used to limit how many shards a given
/// capacity is divided into, so small caches are not split into tiny shards.
const MIN_SHARD_SIZE: usize = 4 * 1024;
104
/// Shard count `Cache::new` asks for before it is limited by the capacity.
const DEFAULT_SHARD_COUNT: usize = 1 << 6;
110
111fn compute_shard_count(capacity: usize, desired_shards: usize) -> usize {
116 let max_shards = (capacity / MIN_SHARD_SIZE).max(1);
118
119 desired_shards.min(max_shards).next_power_of_two().max(1)
121}
122
123impl Cache {
124 pub fn new(max_size_bytes: usize) -> Self {
135 let shard_count = compute_shard_count(max_size_bytes, DEFAULT_SHARD_COUNT);
136 Self::with_shards_internal(max_size_bytes, shard_count)
137 }
138
139 pub fn with_shards(max_size_bytes: usize, shard_count: usize) -> Self {
148 let shard_count = compute_shard_count(max_size_bytes, shard_count);
149 Self::with_shards_internal(max_size_bytes, shard_count)
150 }
151
152 fn with_shards_internal(max_size_bytes: usize, shard_count: usize) -> Self {
154 let size_per_shard = max_size_bytes / shard_count;
156
157 let shards = (0..shard_count).map(|_| RwLock::new(Shard::new(size_per_shard))).collect();
159
160 Self {
161 shards,
162 current_size: AtomicUsize::new(0),
163 entry_count: AtomicUsize::new(0),
164 shard_count,
165 max_size_bytes,
166 #[cfg(feature = "metrics")]
167 hits: AtomicU64::new(0),
168 #[cfg(feature = "metrics")]
169 misses: AtomicU64::new(0),
170 #[cfg(feature = "metrics")]
171 inserts: AtomicU64::new(0),
172 #[cfg(feature = "metrics")]
173 updates: AtomicU64::new(0),
174 #[cfg(feature = "metrics")]
175 evictions: AtomicU64::new(0),
176 #[cfg(feature = "metrics")]
177 removals: AtomicU64::new(0),
178 }
179 }
180
181 pub fn insert<K: CacheKey>(&self, key: K, value: K::Value) -> Option<K::Value> {
192 let erased_key = ErasedKey::new(&key);
193 let policy = key.policy();
194 let entry = Entry::new(value, policy);
195 let entry_size = entry.size;
196
197 let shard_lock = self.get_shard(erased_key.hash);
199
200 let mut shard = shard_lock.write();
202
203 let (old_entry, (num_evictions, evicted_size)) = shard.insert(erased_key, entry);
205
206 if let Some(ref old) = old_entry {
207 let size_diff = entry_size as isize - old.size as isize;
209 if size_diff > 0 {
210 self.current_size.fetch_add(size_diff as usize, Ordering::Relaxed);
211 } else {
212 self.current_size.fetch_sub((-size_diff) as usize, Ordering::Relaxed);
213 }
214 #[cfg(feature = "metrics")]
216 self.updates.fetch_add(1, Ordering::Relaxed);
217 } else {
218 self.current_size.fetch_add(entry_size, Ordering::Relaxed);
220 self.entry_count.fetch_add(1, Ordering::Relaxed);
221 #[cfg(feature = "metrics")]
223 self.inserts.fetch_add(1, Ordering::Relaxed);
224 }
225
226 if num_evictions > 0 {
228 self.entry_count.fetch_sub(num_evictions, Ordering::Relaxed);
229 self.current_size.fetch_sub(evicted_size, Ordering::Relaxed);
230 #[cfg(feature = "metrics")]
232 self.evictions.fetch_add(num_evictions as u64, Ordering::Relaxed);
233 }
234
235 old_entry.and_then(|e| e.into_value::<K::Value>())
236 }
237
238 pub fn get<K: CacheKey>(&self, key: &K) -> Option<Guard<'_, K::Value>> {
252 let key_ref = ErasedKeyRef::new(key);
253 let shard_lock = self.get_shard(key_ref.hash);
254
255 let shard = shard_lock.read();
257
258 let Some(entry) = shard.get_ref(&key_ref) else {
260 #[cfg(feature = "metrics")]
262 self.misses.fetch_add(1, Ordering::Relaxed);
263 return None;
264 };
265
266 #[cfg(feature = "metrics")]
268 self.hits.fetch_add(1, Ordering::Relaxed);
269
270 let value_ref = entry.value_ref::<K::Value>()?;
272 let value_ptr = value_ref as *const K::Value;
273
274 unsafe { Some(Guard::new(shard, value_ptr)) }
277 }
278
279 pub fn get_clone<K: CacheKey>(&self, key: &K) -> Option<K::Value>
291 where
292 K::Value: Clone,
293 {
294 let key_ref = ErasedKeyRef::new(key);
295 let shard_lock = self.get_shard(key_ref.hash);
296
297 let shard = shard_lock.read();
299
300 let Some(entry) = shard.get_ref(&key_ref) else {
301 #[cfg(feature = "metrics")]
303 self.misses.fetch_add(1, Ordering::Relaxed);
304 return None;
305 };
306
307 #[cfg(feature = "metrics")]
309 self.hits.fetch_add(1, Ordering::Relaxed);
310
311 entry.value_ref::<K::Value>().cloned()
313 }
314
315 pub fn get_by<K, Q>(&self, key: &Q) -> Option<Guard<'_, K::Value>>
349 where
350 K: CacheKey,
351 Q: CacheKeyLookup<K> + ?Sized,
352 {
353 let key_ref = ErasedKeyLookup::new(key);
354 let shard_lock = self.get_shard(key_ref.hash);
355
356 let shard = shard_lock.read();
358
359 let Some(entry) = shard.get_ref_by(&key_ref) else {
361 #[cfg(feature = "metrics")]
363 self.misses.fetch_add(1, Ordering::Relaxed);
364 return None;
365 };
366
367 #[cfg(feature = "metrics")]
369 self.hits.fetch_add(1, Ordering::Relaxed);
370
371 let value_ref = entry.value_ref::<K::Value>()?;
373 let value_ptr = value_ref as *const K::Value;
374
375 unsafe { Some(Guard::new(shard, value_ptr)) }
378 }
379
380 pub fn get_clone_by<K, Q>(&self, key: &Q) -> Option<K::Value>
395 where
396 K: CacheKey,
397 K::Value: Clone,
398 Q: CacheKeyLookup<K> + ?Sized,
399 {
400 let key_ref = ErasedKeyLookup::new(key);
401 let shard_lock = self.get_shard(key_ref.hash);
402
403 let shard = shard_lock.read();
405
406 let Some(entry) = shard.get_ref_by(&key_ref) else {
407 #[cfg(feature = "metrics")]
409 self.misses.fetch_add(1, Ordering::Relaxed);
410 return None;
411 };
412
413 #[cfg(feature = "metrics")]
415 self.hits.fetch_add(1, Ordering::Relaxed);
416
417 entry.value_ref::<K::Value>().cloned()
419 }
420
421 pub fn remove<K: CacheKey>(&self, key: &K) -> Option<K::Value> {
433 let erased_key = ErasedKey::new(key);
434 let shard_lock = self.get_shard(erased_key.hash);
435
436 let mut shard = shard_lock.write();
437 let entry = shard.remove(&erased_key)?;
438
439 self.current_size.fetch_sub(entry.size, Ordering::Relaxed);
440 self.entry_count.fetch_sub(1, Ordering::Relaxed);
441
442 #[cfg(feature = "metrics")]
444 self.removals.fetch_add(1, Ordering::Relaxed);
445
446 entry.into_value::<K::Value>()
447 }
448
449 pub fn contains<K: CacheKey>(&self, key: &K) -> bool {
451 let key_ref = ErasedKeyRef::new(key);
452 let shard_lock = self.get_shard(key_ref.hash);
453 let shard = shard_lock.read();
454
455 shard.get_ref(&key_ref).is_some()
457 }
458
459 pub fn contains_by<K, Q>(&self, key: &Q) -> bool
471 where
472 K: CacheKey,
473 Q: CacheKeyLookup<K> + ?Sized,
474 {
475 let key_ref = ErasedKeyLookup::new(key);
476 let shard_lock = self.get_shard(key_ref.hash);
477 let shard = shard_lock.read();
478
479 shard.get_ref_by(&key_ref).is_some()
481 }
482
483 pub fn size(&self) -> usize {
485 self.current_size.load(Ordering::Relaxed)
486 }
487
488 pub fn len(&self) -> usize {
490 self.entry_count.load(Ordering::Relaxed)
491 }
492
493 pub fn is_empty(&self) -> bool {
495 self.len() == 0
496 }
497
498 pub fn clear(&self) {
507 for shard_lock in &self.shards {
508 let mut shard = shard_lock.write();
509 shard.clear();
510 }
511 self.current_size.store(0, Ordering::Relaxed);
512 self.entry_count.store(0, Ordering::Relaxed);
513
514 #[cfg(feature = "metrics")]
516 {
517 self.hits.store(0, Ordering::Relaxed);
518 self.misses.store(0, Ordering::Relaxed);
519 self.inserts.store(0, Ordering::Relaxed);
520 self.updates.store(0, Ordering::Relaxed);
521 self.evictions.store(0, Ordering::Relaxed);
522 self.removals.store(0, Ordering::Relaxed);
523 }
524 }
525
526 #[cfg(feature = "metrics")]
545 pub fn metrics(&self) -> CacheMetrics {
546 CacheMetrics {
547 hits: self.hits.load(Ordering::Relaxed),
548 misses: self.misses.load(Ordering::Relaxed),
549 inserts: self.inserts.load(Ordering::Relaxed),
550 updates: self.updates.load(Ordering::Relaxed),
551 evictions: self.evictions.load(Ordering::Relaxed),
552 removals: self.removals.load(Ordering::Relaxed),
553 current_size_bytes: self.current_size.load(Ordering::Relaxed),
554 capacity_bytes: self.max_size_bytes,
555 entry_count: self.entry_count.load(Ordering::Relaxed),
556 }
557 }
558
559 fn get_shard(&self, hash: u64) -> &RwLock<Shard> {
561 let index = (hash as usize) & (self.shard_count - 1);
562 &self.shards[index]
563 }
564}
565
// SAFETY: NOTE(review): these impls assert thread-safety by hand instead of
// letting the compiler derive it from the fields. Every field of `Cache` is a
// `Vec` of parking_lot `RwLock<Shard>`s, a plain `usize`, or an atomic, and a
// `RwLock<T>` is only `Sync` when `T: Send + Sync`. The assertion is therefore
// sound only if `Shard` (including the type-erased values it stores) is
// `Send + Sync`. Confirm that `CacheKey::Value` / `Entry` require `Send + Sync`;
// if they do not, caching e.g. an `Rc` would make these impls unsound. If
// `Shard` is already `Send + Sync`, these impls are redundant and can go.
unsafe impl Send for Cache {}
// SAFETY: shared access goes through the per-shard `RwLock`s and the atomics;
// soundness rests on the same `Shard: Send + Sync` assumption noted above.
unsafe impl Sync for Cache {}
569
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DeepSizeOf;

    #[derive(Hash, Eq, PartialEq, Clone, Debug)]
    struct TestKey(u64);

    impl CacheKey for TestKey {
        type Value = TestValue;
    }

    #[derive(Clone, Debug, PartialEq, DeepSizeOf)]
    struct TestValue {
        data: String,
    }

    #[test]
    fn test_compute_shard_count_scales_with_capacity() {
        // Below one MIN_SHARD_SIZE worth of capacity there is a single shard.
        assert_eq!(compute_shard_count(1024, 64), 1);
        assert_eq!(compute_shard_count(4095, 64), 1);
        assert_eq!(compute_shard_count(4096, 64), 1);

        assert_eq!(compute_shard_count(8192, 64), 2);
        assert_eq!(compute_shard_count(65536, 64), 16);

        // Large enough capacities get the full request.
        assert_eq!(compute_shard_count(256 * 1024, 64), 64);
        assert_eq!(compute_shard_count(1024 * 1024, 64), 64);

        assert_eq!(compute_shard_count(8192, 128), 2);
        assert_eq!(compute_shard_count(1024 * 1024, 128), 128);
    }

    #[test]
    fn test_cache_insert_and_get() {
        let cache = Cache::new(1024);

        let key = TestKey(1);
        let value = TestValue {
            data: "hello".to_string(),
        };

        cache.insert(key.clone(), value.clone());

        let retrieved = cache.get_clone(&key).expect("key should exist");
        assert_eq!(retrieved, value);
    }

    #[test]
    fn test_cache_insert_existing_key_returns_old_value() {
        let cache = Cache::new(1024);
        let key = TestKey(7);
        let first = TestValue {
            data: "first".to_string(),
        };
        let second = TestValue {
            data: "second".to_string(),
        };

        assert_eq!(cache.insert(key.clone(), first.clone()), None);
        assert_eq!(cache.insert(key.clone(), second.clone()), Some(first));
        // A replacement must not be counted as a new entry.
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get_clone(&key), Some(second));
    }

    #[test]
    fn test_cache_remove() {
        let cache = Cache::new(1024);

        let key = TestKey(1);
        let value = TestValue {
            data: "hello".to_string(),
        };

        cache.insert(key.clone(), value.clone());
        assert!(cache.contains(&key));

        let removed = cache.remove(&key).expect("key should exist");
        assert_eq!(removed, value);
        assert!(!cache.contains(&key));
    }

    #[test]
    fn test_cache_remove_missing_key() {
        let cache = Cache::new(1024);

        assert_eq!(cache.remove(&TestKey(42)), None);
        assert!(cache.is_empty());
        assert_eq!(cache.size(), 0);
    }

    #[test]
    fn test_cache_clear_resets_counters() {
        let cache = Cache::new(64 * 1024);
        for i in 0..8 {
            cache.insert(
                TestKey(i),
                TestValue {
                    data: format!("v{i}"),
                },
            );
        }
        assert_eq!(cache.len(), 8);
        assert!(cache.size() > 0);

        cache.clear();

        assert!(cache.is_empty());
        assert_eq!(cache.size(), 0);
        assert!((0..8).all(|i| !cache.contains(&TestKey(i))));
    }

    #[test]
    fn test_cache_eviction() {
        // 1000 bytes is below MIN_SHARD_SIZE, so this is a single shard even
        // though 4 were requested; 15 values of 50+ bytes cannot all fit.
        let cache = Cache::with_shards(1000, 4);

        for i in 0..15 {
            let key = TestKey(i);
            let value = TestValue {
                data: "x".repeat(50),
            };
            cache.insert(key, value);
        }

        assert!(cache.len() < 15, "Cache should have evicted some entries");
        assert!(
            cache.size() <= 1000,
            "Cache size should be <= 1000, got {}",
            cache.size()
        );
    }

    #[test]
    fn test_cache_concurrent_access() {
        use std::sync::Arc;
        use std::thread;

        let cache = Arc::new(Cache::new(10240));
        let mut handles = vec![];

        for t in 0..4 {
            let cache = cache.clone();
            handles.push(thread::spawn(move || {
                for i in 0..100 {
                    let key = TestKey(t * 100 + i);
                    let value = TestValue {
                        data: format!("value-{}", i),
                    };
                    cache.insert(key.clone(), value.clone());

                    // The entry may already have been evicted by other threads.
                    if let Some(retrieved) = cache.get_clone(&key) {
                        assert_eq!(retrieved, value);
                    }
                }
            }));
        }

        for handle in handles {
            handle.join().expect("thread should not panic");
        }

        assert!(!cache.is_empty());
    }

    #[test]
    fn test_cache_is_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}

        assert_send::<Cache>();
        assert_sync::<Cache>();
    }

    #[cfg(feature = "metrics")]
    #[test]
    fn test_cache_metrics_count_hits_and_misses() {
        let cache = Cache::new(1024);
        let key = TestKey(1);
        cache.insert(
            key.clone(),
            TestValue {
                data: "m".to_string(),
            },
        );

        assert!(cache.get_clone(&key).is_some());
        assert!(cache.get_clone(&TestKey(2)).is_none());

        let metrics = cache.metrics();
        assert_eq!(metrics.hits, 1);
        assert_eq!(metrics.misses, 1);
        assert_eq!(metrics.inserts, 1);
        assert_eq!(metrics.capacity_bytes, 1024);
    }

    /// Owned two-part key, looked up below through the borrowed `DbCacheKeyRef`.
    #[derive(Hash, Eq, PartialEq, Clone, Debug)]
    struct DbCacheKey(String, String);

    impl CacheKey for DbCacheKey {
        type Value = TestValue;
    }

    struct DbCacheKeyRef<'a>(&'a str, &'a str);

    impl std::hash::Hash for DbCacheKeyRef<'_> {
        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
            // Must hash field-by-field exactly like the derived `Hash` on
            // `DbCacheKey` (`&str` and `String` hash identically).
            self.0.hash(state);
            self.1.hash(state);
        }
    }

    impl CacheKeyLookup<DbCacheKey> for DbCacheKeyRef<'_> {
        fn eq_key(&self, key: &DbCacheKey) -> bool {
            self.0 == key.0 && self.1 == key.1
        }

        fn to_owned_key(self) -> DbCacheKey {
            DbCacheKey(self.0.to_owned(), self.1.to_owned())
        }
    }

    #[test]
    fn test_borrowed_key_lookup_get_by() {
        let cache = Cache::new(1024);

        let key = DbCacheKey("namespace".to_string(), "database".to_string());
        let value = TestValue {
            data: "test_data".to_string(),
        };

        cache.insert(key.clone(), value.clone());

        let borrowed_key = DbCacheKeyRef("namespace", "database");
        let retrieved = cache.get_by::<DbCacheKey, _>(&borrowed_key);
        assert!(retrieved.is_some());
        assert_eq!(*retrieved.unwrap(), value);

        let borrowed_key_missing = DbCacheKeyRef("namespace", "missing");
        let retrieved = cache.get_by::<DbCacheKey, _>(&borrowed_key_missing);
        assert!(retrieved.is_none());
    }

    #[test]
    fn test_borrowed_key_lookup_get_clone_by() {
        let cache = Cache::new(1024);

        let key = DbCacheKey("ns".to_string(), "db".to_string());
        let value = TestValue {
            data: "cloned_data".to_string(),
        };

        cache.insert(key.clone(), value.clone());

        let borrowed_key = DbCacheKeyRef("ns", "db");
        let retrieved = cache.get_clone_by::<DbCacheKey, _>(&borrowed_key);
        assert_eq!(retrieved, Some(value));

        let borrowed_key_missing = DbCacheKeyRef("ns", "missing");
        let retrieved = cache.get_clone_by::<DbCacheKey, _>(&borrowed_key_missing);
        assert_eq!(retrieved, None);
    }

    #[test]
    fn test_borrowed_key_lookup_contains_by() {
        let cache = Cache::new(1024);

        let key = DbCacheKey("catalog".to_string(), "schema".to_string());
        let value = TestValue {
            data: "contains_test".to_string(),
        };

        cache.insert(key.clone(), value);

        let borrowed_key = DbCacheKeyRef("catalog", "schema");
        assert!(cache.contains_by::<DbCacheKey, _>(&borrowed_key));

        let borrowed_key_missing = DbCacheKeyRef("catalog", "missing");
        assert!(!cache.contains_by::<DbCacheKey, _>(&borrowed_key_missing));
    }

    #[test]
    fn test_borrowed_key_lookup_multiple_entries() {
        let cache = Cache::new(4096);

        for i in 0..10 {
            let key = DbCacheKey(format!("ns{}", i), format!("db{}", i));
            let value = TestValue {
                data: format!("data{}", i),
            };
            cache.insert(key, value);
        }

        for i in 0..10 {
            let ns = format!("ns{}", i);
            let db = format!("db{}", i);
            let borrowed_key = DbCacheKeyRef(&ns, &db);

            let retrieved = cache.get_clone_by::<DbCacheKey, _>(&borrowed_key);
            assert!(retrieved.is_some());
            assert_eq!(retrieved.unwrap().data, format!("data{}", i));
        }
    }

    #[test]
    fn test_borrowed_key_existing_api_still_works() {
        let cache = Cache::new(1024);

        let key = DbCacheKey("test".to_string(), "key".to_string());
        let value = TestValue {
            data: "existing_api".to_string(),
        };

        cache.insert(key.clone(), value.clone());

        let retrieved = cache.get(&key);
        assert!(retrieved.is_some());
        assert_eq!(*retrieved.unwrap(), value);

        assert!(cache.contains(&key));

        let cloned = cache.get_clone(&key);
        assert_eq!(cloned, Some(value));
    }
}