Skip to main content

fixed_cache/
lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![cfg_attr(not(feature = "std"), no_std)]
4#![allow(clippy::new_without_default)]
5
6#[cfg(feature = "alloc")]
7extern crate alloc;
8
9use core::{
10    cell::UnsafeCell,
11    convert::Infallible,
12    fmt,
13    hash::{BuildHasher, Hash},
14    marker::PhantomData,
15    mem::{self, MaybeUninit},
16    ptr,
17    sync::atomic::{AtomicUsize, Ordering},
18};
19use equivalent::Equivalent;
20
21#[cfg(feature = "stats")]
22mod stats;
23#[cfg(feature = "stats")]
24#[cfg_attr(docsrs, doc(cfg(feature = "stats")))]
25pub use stats::{AnyRef, CountingStatsHandler, Stats, StatsHandler};
26
/// Tag bit that is set while some thread holds a bucket's lock.
const LOCKED_BIT: usize = 0b01;
/// Tag bit marking a bucket whose data slot holds an initialized entry.
const ALIVE_BIT: usize = 0b10;
/// Count of low tag bits reserved for `LOCKED_BIT` and `ALIVE_BIT`.
const NEEDED_BITS: usize = 2;

/// Width, in bits, of the epoch stamp stored inside a tag.
const EPOCH_BITS: usize = 10;
/// The epoch stamp sits immediately above the lock/alive bits.
const EPOCH_SHIFT: usize = NEEDED_BITS;
/// Selects the epoch stamp within a tag.
const EPOCH_MASK: usize = !(usize::MAX << EPOCH_BITS) << EPOCH_SHIFT;
/// Total low tag bits reserved when epochs are enabled.
const EPOCH_NEEDED_BITS: usize = EPOCH_SHIFT + EPOCH_BITS;
35
#[cfg(feature = "rapidhash")]
type DefaultBuildHasher = core::hash::BuildHasherDefault<rapidhash::fast::RapidHasher<'static>>;
#[cfg(all(feature = "std", not(feature = "rapidhash")))]
type DefaultBuildHasher = std::hash::RandomState;
#[cfg(not(any(feature = "rapidhash", feature = "std")))]
type DefaultBuildHasher = core::hash::BuildHasherDefault<NoDefaultHasher>;

/// Uninhabited placeholder used as the default hasher when neither the `rapidhash` nor the
/// `std` feature is enabled.
///
/// It has no variants and does not implement `Hasher` here, so in that configuration callers
/// are expected to name a concrete hash builder type themselves.
#[cfg(not(any(feature = "rapidhash", feature = "std")))]
#[doc(hidden)]
pub enum NoDefaultHasher {}
46
/// Compile-time knobs for a [`Cache`].
///
/// Implement this on a marker type and pass it as the cache's `C` parameter to switch
/// optional behaviour on or off without any runtime cost.
pub trait CacheConfig {
    /// Enables recording of cache statistics.
    ///
    /// When on, the cache reports hits and misses for each key, which lets callers monitor
    /// how well the cache performs.
    ///
    /// Has an effect only if the `stats` feature is enabled as well.
    ///
    /// Defaults to `true`.
    const STATS: bool = true;

    /// Enables epoch tracking, which makes [`Cache::clear`] cheap.
    ///
    /// When on, the cache keeps an epoch counter that each [`Cache::clear`] call bumps.
    /// Entries stamped with an epoch other than the current one are treated as invalid, so
    /// all entries can be invalidated in O(1) without visiting every bucket.
    ///
    /// Defaults to `false`.
    const EPOCHS: bool = false;
}

/// The configuration used when none is specified: every [`CacheConfig`] default applies.
pub struct DefaultCacheConfig(());
impl CacheConfig for DefaultCacheConfig {}
74
/// A concurrent, fixed-size, set-associative cache.
///
/// This cache maps keys to values using a fixed number of buckets. Each key hashes to exactly
/// one bucket, and collisions are resolved by eviction (the new value replaces the old one).
///
/// # Thread Safety
///
/// The cache is safe to share across threads (`Send + Sync`) when `K` and `V` are `Send` and
/// `S` is `Send`/`Sync`. Lookups, insertions and removals use atomic try-locks on individual
/// buckets and never wait: if a bucket is momentarily locked by another thread, a lookup
/// reports a miss and an insertion is skipped.
///
/// # Limitations
///
/// - **Eviction on collision**: When two keys hash to the same bucket, the older entry is evicted.
/// - **No iteration**: Individual entries cannot be enumerated.
/// - **Best-effort insertion**: An insertion may be skipped under contention (see above).
///
/// # Type Parameters
///
/// - `K`: The key type. Must implement [`Hash`] + [`Eq`].
/// - `V`: The value type. Must implement [`Clone`].
/// - `S`: The hash builder type. Must implement [`BuildHasher`]. Defaults to [`rapidhash`] if
///   the `rapidhash` feature is enabled, otherwise to [`RandomState`] if `std` is enabled.
///   With neither feature the default is an uninhabited placeholder, so a hasher type must be
///   named explicitly.
/// - `C`: Compile-time configuration, see [`CacheConfig`]. Defaults to [`DefaultCacheConfig`].
///
/// # Example
///
/// ```
/// use fixed_cache::Cache;
///
/// let cache: Cache<u64, u64> = Cache::new(256, Default::default());
///
/// // Insert a value
/// cache.insert(42, 100);
/// assert_eq!(cache.get(&42), Some(100));
///
/// // Get or compute a value
/// let value = cache.get_or_insert_with(123, |&k| k * 2);
/// assert_eq!(value, 246);
/// ```
///
/// [`Hash`]: std::hash::Hash
/// [`Eq`]: std::cmp::Eq
/// [`Clone`]: std::clone::Clone
/// [`BuildHasher`]: std::hash::BuildHasher
/// [`RandomState`]: std::hash::RandomState
/// [`rapidhash`]: https://crates.io/crates/rapidhash
pub struct Cache<K, V, S = DefaultBuildHasher, C: CacheConfig = DefaultCacheConfig> {
    // Bucket storage: allocated by `Cache::new`, or a `&'static` slice from `Cache::new_static`.
    entries: *const [Bucket<(K, V)>],
    // Hash builder used to derive each key's bucket index and tag.
    build_hasher: S,
    // Optional statistics sink; only consulted when `C::STATS` is `true`.
    #[cfg(feature = "stats")]
    stats: Option<Stats<K, V>>,
    // `true` when `entries` was allocated by `Cache::new` and must be freed on drop.
    #[cfg(feature = "alloc")]
    drop: bool,
    // Current epoch; incremented by `Cache::clear` (only meaningful when `C::EPOCHS`).
    epoch: AtomicUsize,
    _config: PhantomData<C>,
}
130
131impl<K, V, S, C: CacheConfig> fmt::Debug for Cache<K, V, S, C> {
132    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133        f.debug_struct("Cache").finish_non_exhaustive()
134    }
135}
136
// SAFETY: `Cache` is safe to share across threads because `Bucket` uses atomic operations.
// `get`, `insert` and `remove` only touch a bucket's `(K, V)` while holding that bucket's
// lock, much like a `Mutex<(K, V)>`, which is why `K: Send` and `V: Send` (rather than `Sync`)
// are required here. `S` is only read through `&self`, hence `S: Sync` for `Sync`.
// NOTE(review): the thread-safety requirements of the optional `stats` field are not visible
// in this file; confirm that `Stats<K, V>` is sound to share under these bounds.
unsafe impl<K: Send, V: Send, S: Send, C: CacheConfig + Send> Send for Cache<K, V, S, C> {}
unsafe impl<K: Send, V: Send, S: Sync, C: CacheConfig + Sync> Sync for Cache<K, V, S, C> {}
140
141impl<K, V, S, C> Cache<K, V, S, C>
142where
143    K: Hash + Eq,
144    S: BuildHasher,
145    C: CacheConfig,
146{
147    const NEEDS_DROP: bool = Bucket::<(K, V)>::NEEDS_DROP;
148
149    /// Create a new cache with the specified number of entries and hasher.
150    ///
151    /// Dynamically allocates memory for the cache entries.
152    ///
153    /// # Panics
154    ///
155    /// Panics if `num`:
156    /// - is not a power of two.
157    /// - isn't at least 4.
158    // See len_assertion for why.
159    #[cfg(feature = "alloc")]
160    #[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
161    pub fn new(num: usize, build_hasher: S) -> Self {
162        Self::len_assertion(num);
163        let layout = alloc::alloc::Layout::array::<Bucket<(K, V)>>(num).unwrap();
164        let ptr = unsafe { alloc::alloc::alloc_zeroed(layout) };
165        if ptr.is_null() {
166            alloc::alloc::handle_alloc_error(layout);
167        }
168        let entries = ptr::slice_from_raw_parts(ptr.cast::<Bucket<(K, V)>>(), num);
169        Self::new_inner(entries, build_hasher, true)
170    }
171
172    /// Creates a new cache with the specified entries and hasher.
173    ///
174    /// # Panics
175    ///
176    /// See [`new`](Self::new).
177    #[inline]
178    pub const fn new_static(entries: &'static [Bucket<(K, V)>], build_hasher: S) -> Self {
179        Self::len_assertion(entries.len());
180        Self::new_inner(entries, build_hasher, false)
181    }
182
183    /// Sets the cache's statistics.
184    ///
185    /// Can only be called when [`CacheConfig::STATS`] is `true`.
186    #[cfg(feature = "stats")]
187    #[inline]
188    pub fn with_stats(mut self, stats: Option<Stats<K, V>>) -> Self {
189        self.set_stats(stats);
190        self
191    }
192
193    /// Sets the cache's statistics.
194    ///
195    /// Can only be called when [`CacheConfig::STATS`] is `true`.
196    #[cfg(feature = "stats")]
197    #[inline]
198    pub fn set_stats(&mut self, stats: Option<Stats<K, V>>) {
199        const { assert!(C::STATS, "can only set stats when C::STATS is true") }
200        self.stats = stats;
201    }
202
203    #[inline]
204    const fn new_inner(entries: *const [Bucket<(K, V)>], build_hasher: S, drop: bool) -> Self {
205        let _ = drop;
206        Self {
207            entries,
208            build_hasher,
209            #[cfg(feature = "stats")]
210            stats: None,
211            #[cfg(feature = "alloc")]
212            drop,
213            epoch: AtomicUsize::new(0),
214            _config: PhantomData,
215        }
216    }
217
218    #[inline]
219    const fn len_assertion(len: usize) {
220        // We need `RESERVED_BITS` bits to store metadata for each entry.
221        // Since we calculate the tag mask based on the index mask, and the index mask is (len - 1),
222        // we assert that the length's bottom `RESERVED_BITS` bits are zero.
223        let reserved = if C::EPOCHS { EPOCH_NEEDED_BITS } else { NEEDED_BITS };
224        assert!(len.is_power_of_two(), "length must be a power of two");
225        assert!((len & ((1 << reserved) - 1)) == 0, "len must have its bottom N bits set to zero");
226    }
227
228    #[inline]
229    const fn index_mask(&self) -> usize {
230        let n = self.capacity();
231        unsafe { core::hint::assert_unchecked(n.is_power_of_two()) };
232        n - 1
233    }
234
235    #[inline]
236    const fn tag_mask(&self) -> usize {
237        !self.index_mask()
238    }
239
240    /// Returns the current epoch of the cache.
241    ///
242    /// Only meaningful when [`CacheConfig::EPOCHS`] is `true`.
243    #[inline]
244    fn epoch(&self) -> usize {
245        self.epoch.load(Ordering::Relaxed)
246    }
247
248    /// Clears the cache by invalidating all entries.
249    ///
250    /// This is O(1): it simply increments an epoch counter. On epoch wraparound, falls back to
251    /// [`clear_slow`](Self::clear_slow).
252    ///
253    /// Can only be called when [`CacheConfig::EPOCHS`] is `true`.
254    #[inline]
255    pub fn clear(&self) {
256        const EPOCH_MAX: usize = (1 << EPOCH_BITS) - 1;
257        const { assert!(C::EPOCHS, "can only .clear() when C::EPOCHS is true") }
258        let prev = self.epoch.fetch_add(1, Ordering::Release);
259        // Tags store only the low EPOCH_BITS of the epoch, clear the cache on every low-bit wrap.
260        if (prev & EPOCH_MAX) == EPOCH_MAX {
261            self.clear_slow();
262        }
263    }
264
265    /// Clears the cache by invalidating all buckets.
266    ///
267    /// This is O(N) where N is the number of buckets. Prefer [`clear`](Self::clear) when
268    /// [`CacheConfig::EPOCHS`] is `true`.
269    ///
270    /// # Safety
271    ///
272    /// This method is safe but may race with concurrent operations. Callers should ensure
273    /// no other threads are accessing the cache during this operation.
274    #[inline(never)]
275    pub fn clear_slow(&self) {
276        // SAFETY: Callers ensure no other threads are accessing the cache.
277        unsafe {
278            for entry in &*self.entries {
279                let is_alive = Self::NEEDS_DROP && entry.is_alive();
280                // Store before dropping to avoid double-dropping if dropping panics.
281                entry.tag.store(0, Ordering::Relaxed);
282                if is_alive {
283                    (*entry.data.get()).assume_init_drop();
284                }
285            }
286        };
287    }
288
289    /// Returns the hash builder used by this cache.
290    #[inline]
291    pub const fn hasher(&self) -> &S {
292        &self.build_hasher
293    }
294
295    /// Returns the number of entries in this cache.
296    #[inline]
297    pub const fn capacity(&self) -> usize {
298        self.entries.len()
299    }
300
301    /// Returns the statistics of this cache.
302    #[cfg(feature = "stats")]
303    pub const fn stats(&self) -> Option<&Stats<K, V>> {
304        self.stats.as_ref()
305    }
306}
307
308impl<K, V, S, C> Cache<K, V, S, C>
309where
310    K: Hash + Eq,
311    V: Clone,
312    S: BuildHasher,
313    C: CacheConfig,
314{
315    /// Get an entry from the cache.
316    pub fn get<Q: ?Sized + Hash + Equivalent<K>>(&self, key: &Q) -> Option<V> {
317        let (bucket, tag) = self.calc(key);
318        self.get_inner(key, bucket, tag)
319    }
320
321    #[inline]
322    fn get_inner<Q: ?Sized + Hash + Equivalent<K>>(
323        &self,
324        key: &Q,
325        bucket: &Bucket<(K, V)>,
326        tag: usize,
327    ) -> Option<V> {
328        if bucket.try_lock(Some(tag)) {
329            // SAFETY: We hold the lock and bucket is alive, so we have exclusive access.
330            let (ck, v) = unsafe { (*bucket.data.get()).assume_init_ref() };
331            if key.equivalent(ck) {
332                #[cfg(feature = "stats")]
333                if C::STATS
334                    && let Some(stats) = &self.stats
335                {
336                    stats.record_hit(ck, v);
337                }
338                let v = v.clone();
339                bucket.unlock(tag);
340                return Some(v);
341            }
342            bucket.unlock(tag);
343        }
344        #[cfg(feature = "stats")]
345        if C::STATS
346            && let Some(stats) = &self.stats
347        {
348            stats.record_miss(AnyRef::new(&key));
349        }
350        None
351    }
352
353    /// Insert an entry into the cache.
354    pub fn insert(&self, key: K, value: V) {
355        let (bucket, tag) = self.calc(&key);
356        self.insert_inner(bucket, tag, || (key, value));
357    }
358
359    /// Remove an entry from the cache.
360    ///
361    /// Returns the value if the key was present in the cache.
362    pub fn remove<Q: ?Sized + Hash + Equivalent<K>>(&self, key: &Q) -> Option<V> {
363        let (bucket, tag) = self.calc(key);
364        if bucket.try_lock(Some(tag)) {
365            // SAFETY: We hold the lock and bucket is alive, so we have exclusive access.
366            let data = unsafe { &mut *bucket.data.get() };
367            let (ck, v) = unsafe { data.assume_init_ref() };
368            if key.equivalent(ck) {
369                let v = v.clone();
370                #[cfg(feature = "stats")]
371                if C::STATS
372                    && let Some(stats) = &self.stats
373                {
374                    stats.record_remove(ck, &v);
375                }
376                if Self::NEEDS_DROP {
377                    // SAFETY: We hold the lock, so we have exclusive access.
378                    unsafe { data.assume_init_drop() };
379                }
380                bucket.unlock(0);
381                return Some(v);
382            }
383            bucket.unlock(tag);
384        }
385        None
386    }
387
388    #[inline]
389    fn insert_inner(
390        &self,
391        bucket: &Bucket<(K, V)>,
392        tag: usize,
393        make_entry: impl FnOnce() -> (K, V),
394    ) {
395        #[inline(always)]
396        unsafe fn do_write<T>(ptr: *mut T, f: impl FnOnce() -> T) {
397            // This function is translated as:
398            // - allocate space for a T on the stack
399            // - call f() with the return value being put onto this stack space
400            // - memcpy from the stack to the heap
401            //
402            // Ideally we want LLVM to always realize that doing a stack
403            // allocation is unnecessary and optimize the code so it writes
404            // directly into the heap instead. It seems we get it to realize
405            // this most consistently if we put this critical line into it's
406            // own function instead of inlining it into the surrounding code.
407            unsafe { ptr::write(ptr, f()) };
408        }
409
410        let Some(prev_tag) = bucket.try_lock_ret(None) else {
411            return;
412        };
413
414        // SAFETY: We hold the lock, so we have exclusive access.
415        unsafe {
416            let is_alive = (prev_tag & !LOCKED_BIT) != 0;
417            let data = bucket.data.get().cast::<(K, V)>();
418
419            if C::STATS && cfg!(feature = "stats") {
420                #[cfg(feature = "stats")]
421                if is_alive {
422                    let (old_key, old_value) = ptr::replace(data, make_entry());
423                    if let Some(stats) = &self.stats {
424                        stats.record_insert(&(*data).0, &(*data).1, Some((&old_key, &old_value)));
425                    }
426                } else {
427                    do_write(data, make_entry);
428                    if let Some(stats) = &self.stats {
429                        stats.record_insert(&(*data).0, &(*data).1, None);
430                    }
431                }
432            } else {
433                if Self::NEEDS_DROP && is_alive {
434                    ptr::drop_in_place(data);
435                }
436                do_write(data, make_entry);
437            }
438        }
439        bucket.unlock(tag);
440    }
441
442    /// Gets a value from the cache, or inserts one computed by `f` if not present.
443    ///
444    /// If the key is found in the cache, returns a clone of the cached value.
445    /// Otherwise, calls `f` to compute the value, attempts to insert it, and returns it.
446    #[inline]
447    pub fn get_or_insert_with<F>(&self, key: K, f: F) -> V
448    where
449        F: FnOnce(&K) -> V,
450    {
451        let Ok(v) = self.get_or_try_insert_with(key, |key| Ok::<_, Infallible>(f(key)));
452        v
453    }
454
455    /// Gets a value from the cache, or inserts one computed by `f` if not present.
456    ///
457    /// If the key is found in the cache, returns a clone of the cached value.
458    /// Otherwise, calls `f` to compute the value, attempts to insert it, and returns it.
459    ///
460    /// This is the same as [`get_or_insert_with`], but takes a reference to the key, and a function
461    /// to get the key reference to an owned key.
462    ///
463    /// [`get_or_insert_with`]: Self::get_or_insert_with
464    #[inline]
465    pub fn get_or_insert_with_ref<'a, Q, F, Cvt>(&self, key: &'a Q, f: F, cvt: Cvt) -> V
466    where
467        Q: ?Sized + Hash + Equivalent<K>,
468        F: FnOnce(&'a Q) -> V,
469        Cvt: FnOnce(&'a Q) -> K,
470    {
471        let Ok(v) = self.get_or_try_insert_with_ref(key, |key| Ok::<_, Infallible>(f(key)), cvt);
472        v
473    }
474
475    /// Gets a value from the cache, or attempts to insert one computed by `f` if not present.
476    ///
477    /// If the key is found in the cache, returns `Ok` with a clone of the cached value.
478    /// Otherwise, calls `f` to compute the value. If `f` returns `Ok`, attempts to insert
479    /// the value and returns it. If `f` returns `Err`, the error is propagated.
480    #[inline]
481    pub fn get_or_try_insert_with<F, E>(&self, key: K, f: F) -> Result<V, E>
482    where
483        F: FnOnce(&K) -> Result<V, E>,
484    {
485        let mut key = mem::ManuallyDrop::new(key);
486        let mut read = false;
487        let r = self.get_or_try_insert_with_ref(&*key, f, |k| {
488            read = true;
489            unsafe { ptr::read(k) }
490        });
491        if !read {
492            unsafe { mem::ManuallyDrop::drop(&mut key) }
493        }
494        r
495    }
496
497    /// Gets a value from the cache, or attempts to insert one computed by `f` if not present.
498    ///
499    /// If the key is found in the cache, returns `Ok` with a clone of the cached value.
500    /// Otherwise, calls `f` to compute the value. If `f` returns `Ok`, attempts to insert
501    /// the value and returns it. If `f` returns `Err`, the error is propagated.
502    ///
503    /// This is the same as [`Self::get_or_try_insert_with`], but takes a reference to the key, and
504    /// a function to get the key reference to an owned key.
505    #[inline]
506    pub fn get_or_try_insert_with_ref<'a, Q, F, Cvt, E>(
507        &self,
508        key: &'a Q,
509        f: F,
510        cvt: Cvt,
511    ) -> Result<V, E>
512    where
513        Q: ?Sized + Hash + Equivalent<K>,
514        F: FnOnce(&'a Q) -> Result<V, E>,
515        Cvt: FnOnce(&'a Q) -> K,
516    {
517        let (bucket, tag) = self.calc(key);
518        if let Some(v) = self.get_inner(key, bucket, tag) {
519            return Ok(v);
520        }
521        let value = f(key)?;
522        self.insert_inner(bucket, tag, || (cvt(key), value.clone()));
523        Ok(value)
524    }
525
526    #[inline]
527    fn calc<Q: ?Sized + Hash + Equivalent<K>>(&self, key: &Q) -> (&Bucket<(K, V)>, usize) {
528        let hash = self.hash_key(key);
529        // SAFETY: index is masked to be within bounds.
530        let bucket = unsafe { (&*self.entries).get_unchecked(hash & self.index_mask()) };
531        let mut tag = hash & self.tag_mask();
532        if Self::NEEDS_DROP {
533            tag |= ALIVE_BIT;
534        }
535        if C::EPOCHS {
536            tag = (tag & !EPOCH_MASK) | ((self.epoch() << EPOCH_SHIFT) & EPOCH_MASK);
537        }
538        (bucket, tag)
539    }
540
541    #[inline]
542    fn hash_key<Q: ?Sized + Hash + Equivalent<K>>(&self, key: &Q) -> usize {
543        let hash = self.build_hasher.hash_one(key);
544
545        if cfg!(target_pointer_width = "32") {
546            ((hash >> 32) as usize) ^ (hash as usize)
547        } else {
548            hash as usize
549        }
550    }
551}
552
553impl<K, V, S, C: CacheConfig> Drop for Cache<K, V, S, C> {
554    fn drop(&mut self) {
555        #[cfg(feature = "alloc")]
556        if self.drop {
557            // SAFETY: `Drop` has exclusive access.
558            drop(unsafe { alloc::boxed::Box::from_raw(self.entries.cast_mut()) });
559        }
560    }
561}
562
563/// A single cache bucket that holds one key-value pair.
564///
565/// Buckets are aligned to 128 bytes to avoid false sharing between cache lines.
566/// Each bucket contains an atomic tag for lock-free synchronization and uninitialized
567/// storage for the data.
568///
569/// This type is public to allow use with the [`static_cache!`] macro for compile-time
570/// cache initialization. You typically don't need to interact with it directly.
571#[repr(C, align(128))]
572#[doc(hidden)]
573pub struct Bucket<T> {
574    tag: AtomicUsize,
575    data: UnsafeCell<MaybeUninit<T>>,
576}
577
578impl<T> Bucket<T> {
579    const NEEDS_DROP: bool = mem::needs_drop::<T>();
580
581    /// Creates a new zeroed bucket.
582    #[inline]
583    pub const fn new() -> Self {
584        Self { tag: AtomicUsize::new(0), data: UnsafeCell::new(MaybeUninit::zeroed()) }
585    }
586
587    #[inline]
588    fn try_lock(&self, expected: Option<usize>) -> bool {
589        self.try_lock_ret(expected).is_some()
590    }
591
592    #[inline]
593    fn try_lock_ret(&self, expected: Option<usize>) -> Option<usize> {
594        let state = self.tag.load(Ordering::Relaxed);
595        if let Some(expected) = expected {
596            if state != expected {
597                return None;
598            }
599        } else if state & LOCKED_BIT != 0 {
600            return None;
601        }
602        self.tag
603            .compare_exchange(state, state | LOCKED_BIT, Ordering::Acquire, Ordering::Relaxed)
604            .ok()
605    }
606
607    #[inline]
608    fn is_alive(&self) -> bool {
609        self.tag.load(Ordering::Relaxed) & ALIVE_BIT != 0
610    }
611
612    #[inline]
613    fn unlock(&self, tag: usize) {
614        self.tag.store(tag, Ordering::Release);
615    }
616}
617
618// SAFETY: `Bucket` is a specialized `Mutex<T>` that never blocks.
619unsafe impl<T: Send> Send for Bucket<T> {}
620unsafe impl<T: Send> Sync for Bucket<T> {}
621
622impl<T> Drop for Bucket<T> {
623    fn drop(&mut self) {
624        if Self::NEEDS_DROP && self.is_alive() {
625            // SAFETY: `Drop` has exclusive access.
626            unsafe { self.data.get_mut().assume_init_drop() };
627        }
628    }
629}
630
/// Declares a static cache with the given key type, value type, size, and optional hasher.
///
/// The size must be a power of two and at least 4, or at least 4096 when the cache's
/// [`CacheConfig::EPOCHS`] is `true`. Otherwise [`Cache::new_static`] panics, which is a
/// compile-time error inside a `static` initializer. When the hasher argument is omitted,
/// `Default::default()` is used.
///
/// # Example
///
/// ```
/// # #[cfg(feature = "rapidhash")] {
/// use fixed_cache::{Cache, static_cache};
///
/// type BuildHasher = std::hash::BuildHasherDefault<rapidhash::fast::RapidHasher<'static>>;
///
/// static MY_CACHE: Cache<u64, &'static str, BuildHasher> =
///     static_cache!(u64, &'static str, 1024, BuildHasher::new());
///
/// let value = MY_CACHE.get_or_insert_with(42, |_k| "hi");
/// assert_eq!(value, "hi");
///
/// let new_value = MY_CACHE.get_or_insert_with(42, |_k| "not hi");
/// assert_eq!(new_value, "hi");
/// # }
/// ```
#[macro_export]
macro_rules! static_cache {
    ($K:ty, $V:ty, $size:expr) => {
        // Fully qualified so a caller-side item named `Default` cannot hijack the expansion.
        $crate::static_cache!($K, $V, $size, ::core::default::Default::default())
    };
    ($K:ty, $V:ty, $size:expr, $hasher:expr) => {{
        static ENTRIES: [$crate::Bucket<($K, $V)>; $size] =
            [const { $crate::Bucket::new() }; $size];
        $crate::Cache::new_static(&ENTRIES, $hasher)
    }};
}
664
665#[cfg(test)]
666mod tests {
667    use super::*;
668    use std::{cell::Cell, rc::Rc, thread};
669
670    const fn iters(n: usize) -> usize {
671        if cfg!(miri) { n / 10 } else { n }
672    }
673
674    type BH = std::hash::BuildHasherDefault<rapidhash::fast::RapidHasher<'static>>;
675    type Cache<K, V> = super::Cache<K, V, BH>;
676
677    struct EpochConfig;
678    impl CacheConfig for EpochConfig {
679        const EPOCHS: bool = true;
680    }
681    type EpochCache<K, V> = super::Cache<K, V, BH, EpochConfig>;
682
683    struct NoStatsConfig;
684    impl CacheConfig for NoStatsConfig {
685        const STATS: bool = false;
686    }
687    type NoStatsCache<K, V> = super::Cache<K, V, BH, NoStatsConfig>;
688
689    fn new_cache<K: Hash + Eq, V: Clone>(size: usize) -> Cache<K, V> {
690        Cache::new(size, Default::default())
691    }
692
693    type Drops = Rc<Cell<usize>>;
694
695    fn drops() -> Drops {
696        Rc::new(Cell::new(0))
697    }
698
699    #[derive(Clone)]
700    struct DropKey {
701        id: u64,
702        drops: Drops,
703    }
704
705    impl DropKey {
706        fn new(id: u64, drops: &Drops) -> Self {
707            Self { id, drops: drops.clone() }
708        }
709    }
710
711    impl Hash for DropKey {
712        fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
713            self.id.hash(state);
714        }
715    }
716
717    impl PartialEq for DropKey {
718        fn eq(&self, other: &Self) -> bool {
719            self.id == other.id
720        }
721    }
722
723    impl Eq for DropKey {}
724
725    impl Drop for DropKey {
726        fn drop(&mut self) {
727            self.drops.set(self.drops.get() + 1);
728        }
729    }
730
731    #[derive(Clone)]
732    struct DropValue {
733        #[allow(dead_code)]
734        value: u64,
735        drops: Drops,
736    }
737
738    impl DropValue {
739        fn new(value: u64, drops: &Drops) -> Self {
740            Self { value, drops: drops.clone() }
741        }
742    }
743
744    impl Drop for DropValue {
745        fn drop(&mut self) {
746            self.drops.set(self.drops.get() + 1);
747        }
748    }
749
750    #[test]
751    fn basic_get_or_insert() {
752        let cache = new_cache(1024);
753
754        let mut computed = false;
755        let value = cache.get_or_insert_with(42, |&k| {
756            computed = true;
757            k * 2
758        });
759        assert!(computed);
760        assert_eq!(value, 84);
761
762        computed = false;
763        let value = cache.get_or_insert_with(42, |&k| {
764            computed = true;
765            k * 2
766        });
767        assert!(!computed);
768        assert_eq!(value, 84);
769    }
770
771    #[test]
772    fn different_keys() {
773        let cache: Cache<String, usize> = new_cache(1024);
774
775        let v1 = cache.get_or_insert_with("hello".to_string(), |s| s.len());
776        let v2 = cache.get_or_insert_with("world!".to_string(), |s| s.len());
777
778        assert_eq!(v1, 5);
779        assert_eq!(v2, 6);
780    }
781
782    #[test]
783    fn new_dynamic_allocation() {
784        let cache: Cache<u32, u32> = new_cache(64);
785        assert_eq!(cache.capacity(), 64);
786
787        cache.insert(1, 100);
788        assert_eq!(cache.get(&1), Some(100));
789    }
790
791    #[test]
792    fn get_miss() {
793        let cache = new_cache::<u64, u64>(64);
794        assert_eq!(cache.get(&999), None);
795    }
796
797    #[test]
798    fn insert_and_get() {
799        let cache: Cache<u64, String> = new_cache(64);
800
801        cache.insert(1, "one".to_string());
802        cache.insert(2, "two".to_string());
803        cache.insert(3, "three".to_string());
804
805        assert_eq!(cache.get(&1), Some("one".to_string()));
806        assert_eq!(cache.get(&2), Some("two".to_string()));
807        assert_eq!(cache.get(&3), Some("three".to_string()));
808        assert_eq!(cache.get(&4), None);
809    }
810
811    #[test]
812    fn insert_twice() {
813        let cache = new_cache(64);
814
815        cache.insert(42, 1);
816        assert_eq!(cache.get(&42), Some(1));
817
818        cache.insert(42, 2);
819        let v = cache.get(&42);
820        assert!(v == Some(1) || v == Some(2));
821    }
822
823    #[test]
824    fn remove_existing() {
825        let cache: Cache<u64, String> = new_cache(64);
826
827        cache.insert(1, "one".to_string());
828        assert_eq!(cache.get(&1), Some("one".to_string()));
829
830        let removed = cache.remove(&1);
831        assert_eq!(removed, Some("one".to_string()));
832        assert_eq!(cache.get(&1), None);
833    }
834
835    #[test]
836    fn remove_nonexistent() {
837        let cache = new_cache::<u64, u64>(64);
838        assert_eq!(cache.remove(&999), None);
839    }
840
841    #[test]
842    fn get_or_insert_with_ref() {
843        let cache: Cache<String, usize> = new_cache(64);
844
845        let key = "hello";
846        let value = cache.get_or_insert_with_ref(key, |s| s.len(), |s| s.to_string());
847        assert_eq!(value, 5);
848
849        let value2 = cache.get_or_insert_with_ref(key, |_| 999, |s| s.to_string());
850        assert_eq!(value2, 5);
851    }
852
853    #[test]
854    fn get_or_insert_with_ref_different_keys() {
855        let cache: Cache<String, usize> = new_cache(1024);
856
857        let v1 = cache.get_or_insert_with_ref("foo", |s| s.len(), |s| s.to_string());
858        let v2 = cache.get_or_insert_with_ref("barbaz", |s| s.len(), |s| s.to_string());
859
860        assert_eq!(v1, 3);
861        assert_eq!(v2, 6);
862    }
863
864    #[test]
865    fn capacity() {
866        let cache = new_cache::<u64, u64>(256);
867        assert_eq!(cache.capacity(), 256);
868
869        let cache2 = new_cache::<u64, u64>(128);
870        assert_eq!(cache2.capacity(), 128);
871    }
872
873    #[test]
874    fn hasher() {
875        let cache = new_cache::<u64, u64>(64);
876        let _ = cache.hasher();
877    }
878
879    #[test]
880    fn debug_impl() {
881        let cache = new_cache::<u64, u64>(64);
882        let debug_str = format!("{:?}", cache);
883        assert!(debug_str.contains("Cache"));
884    }
885
886    #[test]
887    fn bucket_new() {
888        let bucket: Bucket<(u64, u64)> = Bucket::new();
889        assert_eq!(bucket.tag.load(Ordering::Relaxed), 0);
890    }
891
892    #[test]
893    fn many_entries() {
894        let cache: Cache<u64, u64> = new_cache(1024);
895        let n = iters(500);
896
897        for i in 0..n as u64 {
898            cache.insert(i, i * 2);
899        }
900
901        let mut hits = 0;
902        for i in 0..n as u64 {
903            if cache.get(&i) == Some(i * 2) {
904                hits += 1;
905            }
906        }
907        assert!(hits > 0);
908    }
909
910    #[test]
911    fn string_keys() {
912        let cache: Cache<String, i32> = new_cache(1024);
913
914        cache.insert("alpha".to_string(), 1);
915        cache.insert("beta".to_string(), 2);
916        cache.insert("gamma".to_string(), 3);
917
918        assert_eq!(cache.get("alpha"), Some(1));
919        assert_eq!(cache.get("beta"), Some(2));
920        assert_eq!(cache.get("gamma"), Some(3));
921    }
922
923    #[test]
924    fn zero_values() {
925        let cache: Cache<u64, u64> = new_cache(64);
926
927        cache.insert(0, 0);
928        assert_eq!(cache.get(&0), Some(0));
929
930        cache.insert(1, 0);
931        assert_eq!(cache.get(&1), Some(0));
932    }
933
934    #[test]
935    fn clone_value() {
936        #[derive(Clone, PartialEq, Debug)]
937        struct MyValue(u64);
938
939        let cache: Cache<u64, MyValue> = new_cache(64);
940
941        cache.insert(1, MyValue(123));
942        let v = cache.get(&1);
943        assert_eq!(v, Some(MyValue(123)));
944    }
945
946    fn run_concurrent<F>(num_threads: usize, f: F)
947    where
948        F: Fn(usize) + Send + Sync,
949    {
950        thread::scope(|s| {
951            for t in 0..num_threads {
952                let f = &f;
953                s.spawn(move || f(t));
954            }
955        });
956    }
957
958    #[test]
959    fn concurrent_reads() {
960        let cache: Cache<u64, u64> = new_cache(1024);
961        let n = iters(100);
962
963        for i in 0..n as u64 {
964            cache.insert(i, i * 10);
965        }
966
967        run_concurrent(4, |_| {
968            for i in 0..n as u64 {
969                let _ = cache.get(&i);
970            }
971        });
972    }
973
974    #[test]
975    fn concurrent_writes() {
976        let cache: Cache<u64, u64> = new_cache(1024);
977        let n = iters(100);
978
979        run_concurrent(4, |t| {
980            for i in 0..n {
981                cache.insert((t * 1000 + i) as u64, i as u64);
982            }
983        });
984    }
985
986    #[test]
987    fn concurrent_read_write() {
988        let cache: Cache<u64, u64> = new_cache(256);
989        let n = iters(1000);
990
991        run_concurrent(2, |t| {
992            for i in 0..n as u64 {
993                if t == 0 {
994                    cache.insert(i % 100, i);
995                } else {
996                    let _ = cache.get(&(i % 100));
997                }
998            }
999        });
1000    }
1001
1002    #[test]
1003    fn concurrent_get_or_insert() {
1004        let cache: Cache<u64, u64> = new_cache(1024);
1005        let n = iters(100);
1006
1007        run_concurrent(8, |_| {
1008            for i in 0..n as u64 {
1009                let _ = cache.get_or_insert_with(i, |&k| k * 2);
1010            }
1011        });
1012
1013        for i in 0..n as u64 {
1014            if let Some(v) = cache.get(&i) {
1015                assert_eq!(v, i * 2);
1016            }
1017        }
1018    }
1019
1020    #[test]
1021    #[should_panic = "power of two"]
1022    fn non_power_of_two() {
1023        let _ = new_cache::<u64, u64>(100);
1024    }
1025
1026    #[test]
1027    #[should_panic = "len must have its bottom N bits set to zero"]
1028    fn small_cache() {
1029        let _ = new_cache::<u64, u64>(2);
1030    }
1031
1032    #[test]
1033    fn power_of_two_sizes() {
1034        for shift in 2..10 {
1035            let size = 1 << shift;
1036            let cache = new_cache::<u64, u64>(size);
1037            assert_eq!(cache.capacity(), size);
1038        }
1039    }
1040
1041    #[test]
1042    fn equivalent_key_lookup() {
1043        let cache: Cache<String, i32> = new_cache(64);
1044
1045        cache.insert("hello".to_string(), 42);
1046
1047        assert_eq!(cache.get("hello"), Some(42));
1048    }
1049
1050    #[test]
1051    fn large_values() {
1052        let cache: Cache<u64, [u8; 1000]> = new_cache(64);
1053
1054        let large_value = [42u8; 1000];
1055        cache.insert(1, large_value);
1056
1057        assert_eq!(cache.get(&1), Some(large_value));
1058    }
1059
1060    #[test]
1061    fn send_sync() {
1062        fn assert_send<T: Send>() {}
1063        fn assert_sync<T: Sync>() {}
1064
1065        assert_send::<Cache<u64, u64>>();
1066        assert_sync::<Cache<u64, u64>>();
1067        assert_send::<Bucket<(u64, u64)>>();
1068        assert_sync::<Bucket<(u64, u64)>>();
1069    }
1070
1071    #[test]
1072    fn get_or_try_insert_with_ok() {
1073        let cache = new_cache(1024);
1074
1075        let mut computed = false;
1076        let result: Result<u64, &str> = cache.get_or_try_insert_with(42, |&k| {
1077            computed = true;
1078            Ok(k * 2)
1079        });
1080        assert!(computed);
1081        assert_eq!(result, Ok(84));
1082
1083        computed = false;
1084        let result: Result<u64, &str> = cache.get_or_try_insert_with(42, |&k| {
1085            computed = true;
1086            Ok(k * 2)
1087        });
1088        assert!(!computed);
1089        assert_eq!(result, Ok(84));
1090    }
1091
1092    #[test]
1093    fn get_or_try_insert_with_err() {
1094        let cache: Cache<u64, u64> = new_cache(1024);
1095
1096        let result: Result<u64, &str> = cache.get_or_try_insert_with(42, |_| Err("failed"));
1097        assert_eq!(result, Err("failed"));
1098
1099        assert_eq!(cache.get(&42), None);
1100    }
1101
1102    #[test]
1103    fn get_or_try_insert_with_ref_ok() {
1104        let cache: Cache<String, usize> = new_cache(64);
1105
1106        let key = "hello";
1107        let result: Result<usize, &str> =
1108            cache.get_or_try_insert_with_ref(key, |s| Ok(s.len()), |s| s.to_string());
1109        assert_eq!(result, Ok(5));
1110
1111        let result2: Result<usize, &str> =
1112            cache.get_or_try_insert_with_ref(key, |_| Ok(999), |s| s.to_string());
1113        assert_eq!(result2, Ok(5));
1114    }
1115
1116    #[test]
1117    fn get_or_try_insert_with_ref_err() {
1118        let cache: Cache<String, usize> = new_cache(64);
1119
1120        let key = "hello";
1121        let result: Result<usize, &str> =
1122            cache.get_or_try_insert_with_ref(key, |_| Err("failed"), |s| s.to_string());
1123        assert_eq!(result, Err("failed"));
1124
1125        assert_eq!(cache.get(key), None);
1126    }
1127
1128    #[test]
1129    fn drop_on_cache_drop() {
1130        let drops = drops();
1131        {
1132            let cache: super::Cache<DropKey, DropValue, BH> =
1133                super::Cache::new(64, Default::default());
1134            cache.insert(DropKey::new(1, &drops), DropValue::new(100, &drops));
1135            cache.insert(DropKey::new(2, &drops), DropValue::new(200, &drops));
1136            cache.insert(DropKey::new(3, &drops), DropValue::new(300, &drops));
1137            assert_eq!(drops.get(), 0);
1138        }
1139        // 3 keys + 3 values = 6 drops
1140        assert_eq!(drops.get(), 6);
1141    }
1142
1143    #[test]
1144    fn drop_on_eviction() {
1145        let drops = drops();
1146        {
1147            let cache: super::Cache<DropKey, DropValue, BH> =
1148                super::Cache::new(64, Default::default());
1149            cache.insert(DropKey::new(1, &drops), DropValue::new(100, &drops));
1150            assert_eq!(drops.get(), 0);
1151            // Insert same key again - should evict old entry
1152            cache.insert(DropKey::new(1, &drops), DropValue::new(200, &drops));
1153            // Old key + old value dropped = 2
1154            assert_eq!(drops.get(), 2);
1155        }
1156        // Cache dropped: new key + new value = 2 more
1157        assert_eq!(drops.get(), 4);
1158    }
1159
1160    #[test]
1161    fn epoch_clear() {
1162        let cache: EpochCache<u64, u64> = EpochCache::new(4096, Default::default());
1163
1164        assert_eq!(cache.epoch(), 0);
1165
1166        cache.insert(1, 100);
1167        cache.insert(2, 200);
1168        assert_eq!(cache.get(&1), Some(100));
1169        assert_eq!(cache.get(&2), Some(200));
1170
1171        cache.clear();
1172        assert_eq!(cache.epoch(), 1);
1173
1174        assert_eq!(cache.get(&1), None);
1175        assert_eq!(cache.get(&2), None);
1176
1177        cache.insert(1, 101);
1178        assert_eq!(cache.get(&1), Some(101));
1179
1180        cache.clear();
1181        assert_eq!(cache.epoch(), 2);
1182        assert_eq!(cache.get(&1), None);
1183    }
1184
1185    #[test]
1186    fn epoch_wrap_around() {
1187        let cache: EpochCache<u64, u64> = EpochCache::new(4096, Default::default());
1188
1189        for _ in 0..300 {
1190            cache.insert(42, 123);
1191            assert_eq!(cache.get(&42), Some(123));
1192            cache.clear();
1193            assert_eq!(cache.get(&42), None);
1194        }
1195    }
1196
1197    #[test]
1198    fn no_stats_config() {
1199        let cache: NoStatsCache<u64, u64> = NoStatsCache::new(64, Default::default());
1200
1201        cache.insert(1, 100);
1202        assert_eq!(cache.get(&1), Some(100));
1203        assert_eq!(cache.get(&999), None);
1204
1205        cache.insert(1, 200);
1206        assert_eq!(cache.get(&1), Some(200));
1207
1208        cache.remove(&1);
1209        assert_eq!(cache.get(&1), None);
1210
1211        let v = cache.get_or_insert_with(42, |&k| k * 2);
1212        assert_eq!(v, 84);
1213    }
1214
1215    #[test]
1216    fn epoch_wraparound_stays_cleared() {
1217        let cache: EpochCache<u64, u64> = EpochCache::new(4096, Default::default());
1218
1219        cache.insert(42, 123);
1220        assert_eq!(cache.get(&42), Some(123));
1221
1222        for i in 0..2048 {
1223            cache.clear();
1224            assert_eq!(cache.get(&42), None, "failed at clear #{i}");
1225        }
1226    }
1227
1228    #[test]
1229    fn epoch_repeated_wraparound_stays_cleared() {
1230        let cache: EpochCache<u64, u64> = EpochCache::new(4096, Default::default());
1231
1232        for _ in 0..1024 {
1233            cache.clear();
1234        }
1235
1236        cache.insert(42, 123);
1237        assert_eq!(cache.get(&42), Some(123));
1238
1239        for i in 0..1024 {
1240            cache.clear();
1241            assert_eq!(cache.get(&42), None, "failed at clear #{i}");
1242        }
1243    }
1244
1245    #[test]
1246    fn remove_copy_type() {
1247        let cache = new_cache::<u64, u64>(64);
1248
1249        cache.insert(1, 100);
1250        assert_eq!(cache.get(&1), Some(100));
1251
1252        let removed = cache.remove(&1);
1253        assert_eq!(removed, Some(100));
1254        assert_eq!(cache.get(&1), None);
1255
1256        cache.insert(1, 200);
1257        assert_eq!(cache.get(&1), Some(200));
1258    }
1259
1260    #[test]
1261    fn remove_then_reinsert_copy() {
1262        let cache = new_cache::<u64, u64>(64);
1263
1264        for i in 0..100u64 {
1265            cache.insert(1, i);
1266            assert_eq!(cache.get(&1), Some(i));
1267            assert_eq!(cache.remove(&1), Some(i));
1268            assert_eq!(cache.get(&1), None);
1269        }
1270    }
1271
1272    #[test]
1273    fn epoch_with_needs_drop() {
1274        let cache: EpochCache<String, String> = EpochCache::new(4096, Default::default());
1275
1276        cache.insert("key".to_string(), "value".to_string());
1277        assert_eq!(cache.get("key"), Some("value".to_string()));
1278
1279        cache.clear();
1280        assert_eq!(cache.get("key"), None);
1281
1282        cache.insert("key".to_string(), "value2".to_string());
1283        assert_eq!(cache.get("key"), Some("value2".to_string()));
1284    }
1285
1286    #[test]
1287    fn clear_slow_drops_entries() {
1288        let drops = drops();
1289        {
1290            let cache: Cache<DropKey, DropValue> = new_cache(64);
1291            cache.insert(DropKey::new(1, &drops), DropValue::new(100, &drops));
1292            assert_eq!(drops.get(), 0);
1293
1294            cache.clear_slow();
1295            assert_eq!(drops.get(), 2);
1296
1297            cache.insert(DropKey::new(1, &drops), DropValue::new(200, &drops));
1298            assert_eq!(drops.get(), 2);
1299        }
1300        assert_eq!(drops.get(), 4);
1301    }
1302
1303    #[test]
1304    fn clear_slow_panic_does_not_double_drop_entry() {
1305        use std::panic::{AssertUnwindSafe, catch_unwind};
1306
1307        #[derive(Clone)]
1308        struct PanicOnFirstDrop(Drops);
1309        impl Drop for PanicOnFirstDrop {
1310            fn drop(&mut self) {
1311                let prev = self.0.get();
1312                self.0.set(prev + 1);
1313                if prev == 0 {
1314                    panic!("intentional panic from drop");
1315                }
1316            }
1317        }
1318
1319        let drops = drops();
1320        {
1321            let cache: Cache<u64, PanicOnFirstDrop> = new_cache(64);
1322            cache.insert(1, PanicOnFirstDrop(drops.clone()));
1323
1324            let result = catch_unwind(AssertUnwindSafe(|| cache.clear_slow()));
1325            assert!(result.is_err());
1326            assert_eq!(drops.get(), 1);
1327        }
1328        assert_eq!(drops.get(), 1);
1329    }
1330
1331    #[test]
1332    fn epoch_remove() {
1333        let cache: EpochCache<u64, u64> = EpochCache::new(4096, Default::default());
1334
1335        cache.insert(1, 100);
1336        assert_eq!(cache.remove(&1), Some(100));
1337        assert_eq!(cache.get(&1), None);
1338
1339        cache.insert(1, 200);
1340        assert_eq!(cache.get(&1), Some(200));
1341
1342        cache.clear();
1343        assert_eq!(cache.get(&1), None);
1344        assert_eq!(cache.remove(&1), None);
1345    }
1346
1347    #[test]
1348    fn no_stats_needs_drop() {
1349        let cache: NoStatsCache<String, String> = NoStatsCache::new(64, Default::default());
1350
1351        cache.insert("a".to_string(), "b".to_string());
1352        assert_eq!(cache.get("a"), Some("b".to_string()));
1353
1354        cache.insert("a".to_string(), "c".to_string());
1355        assert_eq!(cache.get("a"), Some("c".to_string()));
1356
1357        cache.remove(&"a".to_string());
1358        assert_eq!(cache.get("a"), None);
1359    }
1360
1361    #[test]
1362    fn no_stats_get_or_insert() {
1363        let cache: NoStatsCache<String, usize> = NoStatsCache::new(64, Default::default());
1364
1365        let v = cache.get_or_insert_with_ref("hello", |s| s.len(), |s| s.to_string());
1366        assert_eq!(v, 5);
1367
1368        let v2 = cache.get_or_insert_with_ref("hello", |_| 999, |s| s.to_string());
1369        assert_eq!(v2, 5);
1370    }
1371
1372    #[test]
1373    fn epoch_concurrent() {
1374        let cache: EpochCache<u64, u64> = EpochCache::new(4096, Default::default());
1375        let n = iters(10_000);
1376
1377        run_concurrent(4, |t| {
1378            for i in 0..n as u64 {
1379                match t {
1380                    0 => {
1381                        cache.insert(i % 50, i);
1382                    }
1383                    1 => {
1384                        let _ = cache.get(&(i % 50));
1385                    }
1386                    2 => {
1387                        if i % 100 == 0 {
1388                            cache.clear();
1389                        }
1390                    }
1391                    _ => {
1392                        let _ = cache.remove(&(i % 50));
1393                    }
1394                }
1395            }
1396        });
1397    }
1398}