fixed_cache/
lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![allow(clippy::new_without_default)]
4
5use core::{
6    cell::UnsafeCell,
7    hash::{BuildHasher, Hash},
8    mem::MaybeUninit,
9    sync::atomic::{AtomicUsize, Ordering},
10};
11use equivalent::Equivalent;
12
/// Bit OR-ed into a bucket's tag word by `Bucket::try_lock` while the bucket is locked.
///
/// `Bucket::unlock` overwrites the whole word with the caller-supplied tag, which is how the bit
/// gets cleared again.
// NOTE(review): a tag that itself contains this bit would leave the bucket looking locked after
// `unlock` — confirm that tags derived from the hash can never have it set.
const LOCKED_BIT: usize = 0x0000_8000;
14
// Hash builder used when the `S` parameter of `Cache` is left at its default: rapidhash when the
// `rapidhash` feature is enabled, the standard library's `RandomState` otherwise.
#[cfg(feature = "rapidhash")]
type DefaultBuildHasher = std::hash::BuildHasherDefault<rapidhash::fast::RapidHasher<'static>>;
#[cfg(not(feature = "rapidhash"))]
type DefaultBuildHasher = std::hash::RandomState;
19
/// A concurrent, fixed-size, direct-mapped (one-way set-associative) cache.
///
/// This cache maps keys to values using a fixed number of buckets. Each key hashes to exactly
/// one bucket, and collisions are resolved by eviction (the new value replaces the old one).
///
/// # Thread Safety
///
/// The cache is safe to share across threads (`Send + Sync`). All operations use atomic
/// instructions and never block, making it suitable for high-contention scenarios.
///
/// Because nothing ever waits, operations are best-effort: an insert is silently skipped and a
/// lookup reports a miss when the target bucket is locked by another thread at that moment.
///
/// # Limitations
///
/// - **No `Drop` support**: Key and value types must not implement `Drop`. Use `Copy` types,
///   primitives, or `&'static` references.
/// - **Eviction on collision**: When two keys hash to the same bucket, the older entry is lost.
/// - **No iteration or removal**: Individual entries cannot be enumerated or explicitly removed.
///
/// # Type Parameters
///
/// - `K`: The key type. Must implement [`Hash`] + [`Eq`] and must not implement [`Drop`].
/// - `V`: The value type. Must implement [`Clone`] and must not implement [`Drop`].
/// - `S`: The hash builder type. Must implement [`BuildHasher`]. Defaults to [`RandomState`] or
///   [`rapidhash`] if the `rapidhash` feature is enabled.
///
/// # Example
///
/// ```
/// use fixed_cache::Cache;
///
/// let cache: Cache<u64, u64> = Cache::new(256, Default::default());
///
/// // Insert a value
/// cache.insert(42, 100);
/// assert_eq!(cache.get(&42), Some(100));
///
/// // Get or compute a value
/// let value = cache.get_or_insert_with(123, |&k| k * 2);
/// assert_eq!(value, 246);
/// ```
///
/// [`Hash`]: core::hash::Hash
/// [`Eq`]: core::cmp::Eq
/// [`Clone`]: core::clone::Clone
/// [`Drop`]: core::ops::Drop
/// [`BuildHasher`]: core::hash::BuildHasher
/// [`RandomState`]: std::hash::RandomState
/// [`rapidhash`]: https://crates.io/crates/rapidhash
pub struct Cache<K, V, S = DefaultBuildHasher> {
    // Bucket storage. Either a leaked `Box<[_]>` (created by `new`) or a `'static` slice
    // (passed to `new_static`); its length is always a power of two.
    entries: *const [Bucket<(K, V)>],
    // Builds the hasher used to map keys to a bucket index and tag.
    build_hasher: S,
    // `true` when `entries` was heap-allocated by `new` and must be freed in `Drop`.
    drop: bool,
}
72
73impl<K, V, S> core::fmt::Debug for Cache<K, V, S> {
74    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
75        f.debug_struct("Cache").finish_non_exhaustive()
76    }
77}
78
// SAFETY: `Cache` is safe to share across threads because `Bucket` uses atomic operations:
// a bucket's `(K, V)` payload is only read or written between a successful `try_lock` and the
// matching `unlock`. Entries written by one thread can be cloned out by another, hence the
// `K: Send` and `V: Send` bounds. The hash builder is used through `&self` from every thread
// that shares the cache, hence `S: Sync` for the `Sync` impl.
unsafe impl<K: Send, V: Send, S: Send> Send for Cache<K, V, S> {}
unsafe impl<K: Send, V: Send, S: Sync> Sync for Cache<K, V, S> {}
82
83impl<K, V, S> Cache<K, V, S>
84where
85    K: Hash + Eq,
86    S: BuildHasher,
87{
88    /// Create a new cache with the specified number of entries and hasher.
89    ///
90    /// Dynamically allocates memory for the cache entries.
91    ///
92    /// # Panics
93    ///
94    /// Panics if `num` is not a power of two.
95    pub fn new(num: usize, build_hasher: S) -> Self {
96        assert!(num.is_power_of_two(), "capacity must be a power of two");
97        let entries =
98            Box::into_raw((0..num).map(|_| Bucket::new()).collect::<Vec<_>>().into_boxed_slice());
99        Self::new_inner(entries, build_hasher, true)
100    }
101
102    /// Creates a new cache with the specified entries and hasher.
103    ///
104    /// # Panics
105    ///
106    /// Panics if `entries.len()` is not a power of two.
107    #[inline]
108    pub const fn new_static(entries: &'static [Bucket<(K, V)>], build_hasher: S) -> Self {
109        Self::new_inner(entries, build_hasher, false)
110    }
111
112    #[inline]
113    const fn new_inner(entries: *const [Bucket<(K, V)>], build_hasher: S, drop: bool) -> Self {
114        const {
115            assert!(!std::mem::needs_drop::<K>(), "dropping keys is not supported yet");
116            assert!(!std::mem::needs_drop::<V>(), "dropping values is not supported yet");
117        }
118        assert!(entries.len().is_power_of_two());
119        Self { entries, build_hasher, drop }
120    }
121
122    #[inline]
123    const fn index_mask(&self) -> usize {
124        let n = self.capacity();
125        unsafe { core::hint::assert_unchecked(n.is_power_of_two()) };
126        n - 1
127    }
128
129    #[inline]
130    const fn tag_mask(&self) -> usize {
131        !self.index_mask()
132    }
133
134    /// Returns the hash builder used by this cache.
135    #[inline]
136    pub const fn hasher(&self) -> &S {
137        &self.build_hasher
138    }
139
140    /// Returns the number of entries in this cache.
141    #[inline]
142    pub const fn capacity(&self) -> usize {
143        self.entries.len()
144    }
145}
146
147impl<K, V, S> Cache<K, V, S>
148where
149    K: Hash + Eq,
150    V: Clone,
151    S: BuildHasher,
152{
153    /// Get an entry from the cache.
154    pub fn get<Q: ?Sized + Hash + Equivalent<K>>(&self, key: &Q) -> Option<V> {
155        let (bucket, tag) = self.calc(key);
156        self.get_inner(key, bucket, tag)
157    }
158
159    #[inline]
160    fn get_inner<Q: ?Sized + Hash + Equivalent<K>>(
161        &self,
162        key: &Q,
163        bucket: &Bucket<(K, V)>,
164        tag: usize,
165    ) -> Option<V> {
166        if bucket.try_lock(Some(tag)) {
167            // SAFETY: We hold the lock, so we have exclusive access.
168            let (ck, v) = unsafe { (*bucket.data.get()).assume_init_ref() };
169            if key.equivalent(ck) {
170                let v = v.clone();
171                bucket.unlock(tag);
172                return Some(v);
173            }
174            bucket.unlock(tag);
175            // Hash collision: same hash but different key.
176        }
177
178        None
179    }
180
181    /// Insert an entry into the cache.
182    pub fn insert(&self, key: K, value: V) {
183        let (bucket, tag) = self.calc(&key);
184        self.insert_inner(|| key, || value, bucket, tag);
185    }
186
187    #[inline]
188    fn insert_inner(
189        &self,
190        make_key: impl FnOnce() -> K,
191        make_value: impl FnOnce() -> V,
192        bucket: &Bucket<(K, V)>,
193        tag: usize,
194    ) {
195        if bucket.try_lock(None) {
196            // SAFETY: We hold the lock, so we have exclusive access.
197            unsafe {
198                let data = (&mut *bucket.data.get()).as_mut_ptr();
199                (&raw mut (*data).0).write(make_key());
200                (&raw mut (*data).1).write(make_value());
201            }
202            bucket.unlock(tag);
203        }
204    }
205
206    /// Gets a value from the cache, or inserts one computed by `f` if not present.
207    ///
208    /// If the key is found in the cache, returns a clone of the cached value.
209    /// Otherwise, calls `f` to compute the value, attempts to insert it, and returns it.
210    #[inline]
211    pub fn get_or_insert_with<F>(&self, key: K, f: F) -> V
212    where
213        F: FnOnce(&K) -> V,
214    {
215        let mut key = std::mem::ManuallyDrop::new(key);
216        let mut read = false;
217        let r = self.get_or_insert_with_ref(&*key, f, |k| {
218            read = true;
219            unsafe { std::ptr::read(k) }
220        });
221        if !read {
222            unsafe { std::mem::ManuallyDrop::drop(&mut key) }
223        }
224        r
225    }
226
227    /// Gets a value from the cache, or inserts one computed by `f` if not present.
228    ///
229    /// If the key is found in the cache, returns a clone of the cached value.
230    /// Otherwise, calls `f` to compute the value, attempts to insert it, and returns it.
231    ///
232    /// This is the same as [`get_or_insert_with`], but takes a reference to the key, and a function
233    /// to get the key reference to an owned key.
234    ///
235    /// [`get_or_insert_with`]: Self::get_or_insert_with
236    #[inline]
237    pub fn get_or_insert_with_ref<'a, Q, F, Cvt>(&self, key: &'a Q, f: F, cvt: Cvt) -> V
238    where
239        Q: ?Sized + Hash + Equivalent<K>,
240        F: FnOnce(&'a Q) -> V,
241        Cvt: FnOnce(&'a Q) -> K,
242    {
243        let (bucket, tag) = self.calc(key);
244        if let Some(v) = self.get_inner(key, bucket, tag) {
245            return v;
246        }
247        let value = f(key);
248        self.insert_inner(|| cvt(key), || value.clone(), bucket, tag);
249        value
250    }
251
252    #[inline]
253    fn calc<Q: ?Sized + Hash + Equivalent<K>>(&self, key: &Q) -> (&Bucket<(K, V)>, usize) {
254        let hash = self.hash_key(key);
255        // SAFETY: index is masked to be within bounds.
256        let bucket = unsafe { (&*self.entries).get_unchecked(hash & self.index_mask()) };
257        let tag = hash & self.tag_mask();
258        (bucket, tag)
259    }
260
261    #[inline]
262    fn hash_key<Q: ?Sized + Hash + Equivalent<K>>(&self, key: &Q) -> usize {
263        let hash = self.build_hasher.hash_one(key);
264
265        if cfg!(target_pointer_width = "32") {
266            ((hash >> 32) as usize) ^ (hash as usize)
267        } else {
268            hash as usize
269        }
270    }
271}
272
273impl<K, V, S> Drop for Cache<K, V, S> {
274    fn drop(&mut self) {
275        if self.drop {
276            drop(unsafe { Box::from_raw(self.entries.cast_mut()) });
277        }
278    }
279}
280
/// A single cache bucket that holds one key-value pair.
///
/// Buckets are aligned to 128 bytes to avoid false sharing between cache lines.
/// Each bucket contains an atomic tag for lock-free synchronization and uninitialized
/// storage for the data.
///
/// This type is public to allow use with the [`static_cache!`] macro for compile-time
/// cache initialization. You typically don't need to interact with it directly.
#[repr(C, align(128))]
#[doc(hidden)]
pub struct Bucket<T> {
    // State word: 0 for a fresh bucket, otherwise the tag last passed to `unlock`, with
    // `LOCKED_BIT` OR-ed in while a thread holds the bucket.
    tag: AtomicUsize,
    // Payload storage; zero-filled on construction and only accessed while the lock is held.
    data: UnsafeCell<MaybeUninit<T>>,
}
295
296impl<T> Bucket<T> {
297    /// Creates a new zeroed bucket.
298    #[inline]
299    pub const fn new() -> Self {
300        Self { tag: AtomicUsize::new(0), data: UnsafeCell::new(MaybeUninit::zeroed()) }
301    }
302
303    #[inline]
304    fn try_lock(&self, expected: Option<usize>) -> bool {
305        let state = self.tag.load(Ordering::Relaxed);
306        if let Some(expected) = expected {
307            if state != expected {
308                return false;
309            }
310        } else if state & LOCKED_BIT != 0 {
311            return false;
312        }
313        self.tag
314            .compare_exchange(state, state | LOCKED_BIT, Ordering::Acquire, Ordering::Relaxed)
315            .is_ok()
316    }
317
318    #[inline]
319    fn unlock(&self, tag: usize) {
320        self.tag.store(tag, Ordering::Release);
321    }
322}
323
// SAFETY: `Bucket` is a specialized `Mutex<T>` that never blocks: the payload is only accessed
// between a successful `try_lock` and the matching `unlock`, so at most one thread touches it at
// a time. As with a mutex, `T: Send` is enough for both impls because the payload is handed
// from thread to thread rather than accessed concurrently.
unsafe impl<T: Send> Send for Bucket<T> {}
unsafe impl<T: Send> Sync for Bucket<T> {}
327
/// Declares a static cache with the given key type, value type, and size.
///
/// The macro expands to an expression that creates a [`Cache`] backed by a `static` array of
/// `size` buckets, so no heap allocation takes place.
///
/// The size must be a power of two.
///
/// An optional fourth argument supplies the hash builder. When it is omitted,
/// `Default::default()` is used; since that call is not `const`, the three-argument form cannot
/// be used directly as the initializer of a `static` item.
///
/// # Example
///
/// ```
/// # #[cfg(feature = "rapidhash")] {
/// use fixed_cache::{Cache, static_cache};
///
/// type BuildHasher = std::hash::BuildHasherDefault<rapidhash::fast::RapidHasher<'static>>;
///
/// static MY_CACHE: Cache<u64, &'static str, BuildHasher> =
///     static_cache!(u64, &'static str, 1024, BuildHasher::new());
///
/// let value = MY_CACHE.get_or_insert_with(42, |_k| "hi");
/// assert_eq!(value, "hi");
///
/// // The key is now cached, so the closure is not called and the first value is returned.
/// let new_value = MY_CACHE.get_or_insert_with(42, |_k| "not hi");
/// assert_eq!(new_value, "hi");
/// # }
/// ```
#[macro_export]
macro_rules! static_cache {
    // Three-argument form: delegate with the default hash builder.
    ($K:ty, $V:ty, $size:expr) => {
        $crate::static_cache!($K, $V, $size, Default::default())
    };
    ($K:ty, $V:ty, $size:expr, $hasher:expr) => {{
        // Each expansion gets its own private backing storage.
        static ENTRIES: [$crate::Bucket<($K, $V)>; $size] =
            [const { $crate::Bucket::new() }; $size];
        $crate::Cache::new_static(&ENTRIES, $hasher)
    }};
}
361
// Unit tests for the cache. Many assertions are deliberately loose (e.g. "at least one hit")
// because inserts may evict each other when keys share a bucket.
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Scales an iteration count down by 10x when running under Miri.
    const fn iters(n: usize) -> usize {
        if cfg!(miri) { n / 10 } else { n }
    }

    // All tests use the rapidhash-based hash builder.
    type BH = std::hash::BuildHasherDefault<rapidhash::fast::RapidHasher<'static>>;
    type Cache<K, V> = super::Cache<K, V, BH>;

    /// Creates a heap-allocated cache with `size` buckets and the default test hasher.
    fn new_cache<K: Hash + Eq, V: Clone>(size: usize) -> Cache<K, V> {
        Cache::new(size, Default::default())
    }

    // The closure runs on the first call only; the second call is served from the cache.
    #[test]
    fn test_basic_get_or_insert() {
        let cache = new_cache(1024);

        let mut computed = false;
        let value = cache.get_or_insert_with(42, |&k| {
            computed = true;
            k * 2
        });
        assert!(computed);
        assert_eq!(value, 84);

        computed = false;
        let value = cache.get_or_insert_with(42, |&k| {
            computed = true;
            k * 2
        });
        assert!(!computed);
        assert_eq!(value, 84);
    }

    // Exercises the three-argument `static_cache!` form in a non-`static` binding.
    #[test]
    fn test_different_keys() {
        let cache: Cache<&'static str, usize> = static_cache!(&'static str, usize, 1024);

        let v1 = cache.get_or_insert_with("hello", |s| s.len());
        let v2 = cache.get_or_insert_with("world!", |s| s.len());

        assert_eq!(v1, 5);
        assert_eq!(v2, 6);
    }

    #[test]
    fn test_new_dynamic_allocation() {
        let cache: Cache<u32, u32> = new_cache(64);
        assert_eq!(cache.capacity(), 64);

        cache.insert(1, 100);
        assert_eq!(cache.get(&1), Some(100));
    }

    #[test]
    fn test_get_miss() {
        let cache = new_cache::<u64, u64>(64);
        assert_eq!(cache.get(&999), None);
    }

    #[test]
    fn test_insert_and_get() {
        let cache: Cache<u64, &'static str> = new_cache(64);

        cache.insert(1, "one");
        cache.insert(2, "two");
        cache.insert(3, "three");

        assert_eq!(cache.get(&1), Some("one"));
        assert_eq!(cache.get(&2), Some("two"));
        assert_eq!(cache.get(&3), Some("three"));
        assert_eq!(cache.get(&4), None);
    }

    // Accepts either value after the second insert: inserts are best-effort.
    #[test]
    fn test_insert_twice() {
        let cache = new_cache(64);

        cache.insert(42, 1);
        assert_eq!(cache.get(&42), Some(1));

        cache.insert(42, 2);
        let v = cache.get(&42);
        assert!(v == Some(1) || v == Some(2));
    }

    // The second call must return the cached value, not the new closure's result.
    #[test]
    fn test_get_or_insert_with_ref() {
        let cache: Cache<&'static str, usize> = new_cache(64);

        let key = "hello";
        let value = cache.get_or_insert_with_ref(key, |s| s.len(), |s| s);
        assert_eq!(value, 5);

        let value2 = cache.get_or_insert_with_ref(key, |_| 999, |s| s);
        assert_eq!(value2, 5);
    }

    #[test]
    fn test_get_or_insert_with_ref_different_keys() {
        let cache: Cache<&'static str, usize> = new_cache(1024);

        let v1 = cache.get_or_insert_with_ref("foo", |s| s.len(), |s| s);
        let v2 = cache.get_or_insert_with_ref("barbaz", |s| s.len(), |s| s);

        assert_eq!(v1, 3);
        assert_eq!(v2, 6);
    }

    #[test]
    fn test_capacity() {
        let cache = new_cache::<u64, u64>(256);
        assert_eq!(cache.capacity(), 256);

        let cache2 = new_cache::<u64, u64>(128);
        assert_eq!(cache2.capacity(), 128);
    }

    // Smoke test: the accessor is callable.
    #[test]
    fn test_hasher() {
        let cache = new_cache::<u64, u64>(64);
        let _ = cache.hasher();
    }

    #[test]
    fn test_debug_impl() {
        let cache = new_cache::<u64, u64>(64);
        let debug_str = format!("{:?}", cache);
        assert!(debug_str.contains("Cache"));
    }

    // A fresh bucket starts with a zero tag.
    #[test]
    fn test_bucket_new() {
        let bucket: Bucket<(u64, u64)> = Bucket::new();
        assert_eq!(bucket.tag.load(Ordering::Relaxed), 0);
    }

    // Only requires that some entries survive; colliding keys may evict each other.
    #[test]
    fn test_many_entries() {
        let cache: Cache<u64, u64> = new_cache(1024);
        let n = iters(500);

        for i in 0..n as u64 {
            cache.insert(i, i * 2);
        }

        let mut hits = 0;
        for i in 0..n as u64 {
            if cache.get(&i) == Some(i * 2) {
                hits += 1;
            }
        }
        assert!(hits > 0);
    }

    #[test]
    fn test_string_keys() {
        let cache: Cache<&'static str, i32> = new_cache(1024);

        cache.insert("alpha", 1);
        cache.insert("beta", 2);
        cache.insert("gamma", 3);

        assert_eq!(cache.get(&"alpha"), Some(1));
        assert_eq!(cache.get(&"beta"), Some(2));
        assert_eq!(cache.get(&"gamma"), Some(3));
    }

    // Zero keys and zero values must be stored and returned like any other entry.
    #[test]
    fn test_zero_values() {
        let cache: Cache<u64, u64> = new_cache(64);

        cache.insert(0, 0);
        assert_eq!(cache.get(&0), Some(0));

        cache.insert(1, 0);
        assert_eq!(cache.get(&1), Some(0));
    }

    // Values only need `Clone`; they do not have to be `Copy`.
    #[test]
    fn test_clone_value() {
        #[derive(Clone, PartialEq, Debug)]
        struct MyValue(u64);

        let cache: Cache<u64, MyValue> = new_cache(64);

        cache.insert(1, MyValue(123));
        let v = cache.get(&1);
        assert_eq!(v, Some(MyValue(123)));
    }

    /// Runs `f(thread_index)` on `num_threads` scoped threads and waits for all of them.
    fn run_concurrent<F>(num_threads: usize, f: F)
    where
        F: Fn(usize) + Send + Sync,
    {
        thread::scope(|s| {
            for t in 0..num_threads {
                let f = &f;
                s.spawn(move || f(t));
            }
        });
    }

    // The concurrent tests below mostly check for the absence of crashes; results of individual
    // operations are ignored because concurrent operations are best-effort.
    #[test]
    fn test_concurrent_reads() {
        let cache: Cache<u64, u64> = new_cache(1024);
        let n = iters(100);

        for i in 0..n as u64 {
            cache.insert(i, i * 10);
        }

        run_concurrent(4, |_| {
            for i in 0..n as u64 {
                let _ = cache.get(&i);
            }
        });
    }

    #[test]
    fn test_concurrent_writes() {
        let cache: Cache<u64, u64> = new_cache(1024);
        let n = iters(100);

        run_concurrent(4, |t| {
            for i in 0..n {
                cache.insert((t * 1000 + i) as u64, i as u64);
            }
        });
    }

    // Thread 0 writes while thread 1 reads the same 100 keys.
    #[test]
    fn test_concurrent_read_write() {
        let cache: Cache<u64, u64> = new_cache(256);
        let n = iters(1000);

        run_concurrent(2, |t| {
            for i in 0..n as u64 {
                if t == 0 {
                    cache.insert(i % 100, i);
                } else {
                    let _ = cache.get(&(i % 100));
                }
            }
        });
    }

    // Whatever survived the concurrent phase must hold the value computed for its key.
    #[test]
    fn test_concurrent_get_or_insert() {
        let cache: Cache<u64, u64> = new_cache(1024);
        let n = iters(100);

        run_concurrent(8, |_| {
            for i in 0..n as u64 {
                let _ = cache.get_or_insert_with(i, |&k| k * 2);
            }
        });

        for i in 0..n as u64 {
            if let Some(v) = cache.get(&i) {
                assert_eq!(v, i * 2);
            }
        }
    }

    #[test]
    #[should_panic]
    fn test_non_power_of_two_panics() {
        let _ = new_cache::<u64, u64>(100);
    }

    #[test]
    fn test_power_of_two_sizes() {
        for shift in 1..10 {
            let size = 1 << shift;
            let cache = new_cache::<u64, u64>(size);
            assert_eq!(cache.capacity(), size);
        }
    }

    // With two buckets, at most two of the three inserted keys can still be present.
    #[test]
    fn test_small_cache() {
        let cache = new_cache(2);
        assert_eq!(cache.capacity(), 2);

        cache.insert(1, 10);
        cache.insert(2, 20);
        cache.insert(3, 30);

        let count = [1, 2, 3].iter().filter(|&&k| cache.get(&k).is_some()).count();
        assert!(count <= 2);
    }

    #[test]
    fn test_equivalent_key_lookup() {
        let cache = new_cache(64);

        cache.insert("hello", 42);

        assert_eq!(cache.get(&"hello"), Some(42));
    }

    // Values larger than the 128-byte bucket alignment are supported.
    #[test]
    fn test_large_values() {
        let cache: Cache<u64, [u8; 1000]> = new_cache(64);

        let large_value = [42u8; 1000];
        cache.insert(1, large_value);

        assert_eq!(cache.get(&1), Some(large_value));
    }

    // Compile-time check of the `Send`/`Sync` impls.
    #[test]
    fn test_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}

        assert_send::<Cache<u64, u64>>();
        assert_sync::<Cache<u64, u64>>();
        assert_send::<Bucket<(u64, u64)>>();
        assert_sync::<Bucket<(u64, u64)>>();
    }
}