1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![allow(clippy::new_without_default)]
4
5use core::{
6 cell::UnsafeCell,
7 hash::{BuildHasher, Hash},
8 mem::MaybeUninit,
9 sync::atomic::{AtomicUsize, Ordering},
10};
11use equivalent::Equivalent;
12use std::convert::Infallible;
13
/// Number of low bits of every bucket tag that are reserved for state flags.
///
/// Capacities must be multiples of `1 << NEEDED_BITS`, so the hash bits stored in a tag
/// (everything above the bucket-index bits) can never overlap the flag bits below.
const NEEDED_BITS: usize = 2;
/// Flag bit that is set in a bucket's tag while one thread has exclusive access to it.
const LOCKED_BIT: usize = 0b01;
/// Flag bit that, when set in a bucket's tag, indicates the slot holds an initialized entry.
const ALIVE_BIT: usize = 0b10;
17
/// Hasher builder used when `Cache` is named without an explicit `S` parameter.
///
/// With the `rapidhash` feature enabled this is a `BuildHasherDefault` over rapidhash's
/// `RapidHasher`; otherwise it is the standard library's `RandomState`.
#[cfg(feature = "rapidhash")]
type DefaultBuildHasher = core::hash::BuildHasherDefault<rapidhash::fast::RapidHasher<'static>>;
#[cfg(not(feature = "rapidhash"))]
type DefaultBuildHasher = std::hash::RandomState;
22
/// A fixed-capacity, thread-safe cache in which every key maps to exactly one bucket.
///
/// Inserting into an occupied bucket replaces (and drops) the entry it held. Operations that
/// find a bucket locked by another thread give up instead of waiting: a lookup misses and an
/// insertion is skipped.
pub struct Cache<K, V, S = DefaultBuildHasher> {
    /// Bucket array; its length is checked to be a power of two by `len_assertion`.
    /// Either allocated by `Cache::new` or borrowed `'static` storage from `Cache::new_static`.
    entries: *const [Bucket<(K, V)>],
    /// Hasher builder used to choose a bucket and compute the tag for each key.
    build_hasher: S,
    /// `true` when `entries` was allocated by this cache and must be freed in `Drop`.
    drop: bool,
}
73
74impl<K, V, S> core::fmt::Debug for Cache<K, V, S> {
75 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
76 f.debug_struct("Cache").finish_non_exhaustive()
77 }
78}
79
// SAFETY: a `Cache` carries its `(K, V)` entries (inside its buckets) and its hasher `S`.
// Sending the cache lets another thread read, replace and drop those values, which is sound
// when `K`, `V` and `S` are all `Send`.
unsafe impl<K: Send, V: Send, S: Send> Send for Cache<K, V, S> {}
// SAFETY: through `&Cache`, entries are only touched while holding the owning bucket's lock,
// so no `K`/`V` is accessed by two threads at once (the same argument that makes
// `Mutex<T>: Sync` for `T: Send`). Values leave the cache only as clones moved to the caller.
// The hasher is used and exposed (`hasher()`) by shared reference, hence `S: Sync`.
unsafe impl<K: Send, V: Send, S: Sync> Sync for Cache<K, V, S> {}
83
84impl<K, V, S> Cache<K, V, S>
85where
86 K: Hash + Eq,
87 S: BuildHasher,
88{
89 pub fn new(num: usize, build_hasher: S) -> Self {
100 Self::len_assertion(num);
101 let entries =
102 Box::into_raw((0..num).map(|_| Bucket::new()).collect::<Vec<_>>().into_boxed_slice());
103 Self::new_inner(entries, build_hasher, true)
104 }
105
106 #[inline]
112 pub const fn new_static(entries: &'static [Bucket<(K, V)>], build_hasher: S) -> Self {
113 Self::len_assertion(entries.len());
114 Self::new_inner(entries, build_hasher, false)
115 }
116
117 #[inline]
118 const fn new_inner(entries: *const [Bucket<(K, V)>], build_hasher: S, drop: bool) -> Self {
119 Self { entries, build_hasher, drop }
120 }
121
122 #[inline]
123 const fn len_assertion(len: usize) {
124 assert!(len.is_power_of_two(), "length must be a power of two");
128 assert!(
129 (len & ((1 << NEEDED_BITS) - 1)) == 0,
130 "len must have its bottom N bits set to zero"
131 );
132 }
133
134 #[inline]
135 const fn index_mask(&self) -> usize {
136 let n = self.capacity();
137 unsafe { core::hint::assert_unchecked(n.is_power_of_two()) };
138 n - 1
139 }
140
141 #[inline]
142 const fn tag_mask(&self) -> usize {
143 !self.index_mask()
144 }
145
146 #[inline]
148 pub const fn hasher(&self) -> &S {
149 &self.build_hasher
150 }
151
152 #[inline]
154 pub const fn capacity(&self) -> usize {
155 self.entries.len()
156 }
157}
158
159impl<K, V, S> Cache<K, V, S>
160where
161 K: Hash + Eq,
162 V: Clone,
163 S: BuildHasher,
164{
165 const NEEDS_DROP: bool = Bucket::<(K, V)>::NEEDS_DROP;
166
167 pub fn get<Q: ?Sized + Hash + Equivalent<K>>(&self, key: &Q) -> Option<V> {
169 let (bucket, tag) = self.calc(key);
170 self.get_inner(key, bucket, tag)
171 }
172
173 #[inline]
174 fn get_inner<Q: ?Sized + Hash + Equivalent<K>>(
175 &self,
176 key: &Q,
177 bucket: &Bucket<(K, V)>,
178 tag: usize,
179 ) -> Option<V> {
180 if bucket.try_lock(Some(tag)) {
181 let (ck, v) = unsafe { (*bucket.data.get()).assume_init_ref() };
183 if key.equivalent(ck) {
184 let v = v.clone();
185 bucket.unlock(tag);
186 return Some(v);
187 }
188 bucket.unlock(tag);
189 }
191
192 None
193 }
194
195 pub fn insert(&self, key: K, value: V) {
197 let (bucket, tag) = self.calc(&key);
198 self.insert_inner(|| key, || value, bucket, tag);
199 }
200
201 #[inline]
202 fn insert_inner(
203 &self,
204 make_key: impl FnOnce() -> K,
205 make_value: impl FnOnce() -> V,
206 bucket: &Bucket<(K, V)>,
207 tag: usize,
208 ) {
209 if let Some(prev_tag) = bucket.try_lock_ret(None) {
210 unsafe {
212 let data = (&mut *bucket.data.get()).as_mut_ptr();
213 if Self::NEEDS_DROP && (prev_tag & ALIVE_BIT) != 0 {
215 core::ptr::drop_in_place(data);
216 }
217 (&raw mut (*data).0).write(make_key());
218 (&raw mut (*data).1).write(make_value());
219 }
220 bucket.unlock(tag);
221 }
222 }
223
224 #[inline]
229 pub fn get_or_insert_with<F>(&self, key: K, f: F) -> V
230 where
231 F: FnOnce(&K) -> V,
232 {
233 self.get_or_try_insert_with(key, |key| Ok::<_, Infallible>(f(key))).unwrap()
234 }
235
236 #[inline]
246 pub fn get_or_insert_with_ref<'a, Q, F, Cvt>(&self, key: &'a Q, f: F, cvt: Cvt) -> V
247 where
248 Q: ?Sized + Hash + Equivalent<K>,
249 F: FnOnce(&'a Q) -> V,
250 Cvt: FnOnce(&'a Q) -> K,
251 {
252 self.get_or_try_insert_with_ref(key, |key| Ok::<_, Infallible>(f(key)), cvt).unwrap()
253 }
254
255 #[inline]
261 pub fn get_or_try_insert_with<F, E>(&self, key: K, f: F) -> Result<V, E>
262 where
263 F: FnOnce(&K) -> Result<V, E>,
264 {
265 let mut key = std::mem::ManuallyDrop::new(key);
266 let mut read = false;
267 let r = self.get_or_try_insert_with_ref(&*key, f, |k| {
268 read = true;
269 unsafe { std::ptr::read(k) }
270 });
271 if !read {
272 unsafe { std::mem::ManuallyDrop::drop(&mut key) }
273 }
274 r
275 }
276
277 #[inline]
286 pub fn get_or_try_insert_with_ref<'a, Q, F, Cvt, E>(
287 &self,
288 key: &'a Q,
289 f: F,
290 cvt: Cvt,
291 ) -> Result<V, E>
292 where
293 Q: ?Sized + Hash + Equivalent<K>,
294 F: FnOnce(&'a Q) -> Result<V, E>,
295 Cvt: FnOnce(&'a Q) -> K,
296 {
297 let (bucket, tag) = self.calc(key);
298 if let Some(v) = self.get_inner(key, bucket, tag) {
299 return Ok(v);
300 }
301 let value = f(key)?;
302 self.insert_inner(|| cvt(key), || value.clone(), bucket, tag);
303 Ok(value)
304 }
305
306 #[inline]
307 fn calc<Q: ?Sized + Hash + Equivalent<K>>(&self, key: &Q) -> (&Bucket<(K, V)>, usize) {
308 let hash = self.hash_key(key);
309 let bucket = unsafe { (&*self.entries).get_unchecked(hash & self.index_mask()) };
311 let mut tag = hash & self.tag_mask();
312 if Self::NEEDS_DROP {
313 tag |= ALIVE_BIT;
314 }
315 (bucket, tag)
316 }
317
318 #[inline]
319 fn hash_key<Q: ?Sized + Hash + Equivalent<K>>(&self, key: &Q) -> usize {
320 let hash = self.build_hasher.hash_one(key);
321
322 if cfg!(target_pointer_width = "32") {
323 ((hash >> 32) as usize) ^ (hash as usize)
324 } else {
325 hash as usize
326 }
327 }
328}
329
330impl<K, V, S> Drop for Cache<K, V, S> {
331 fn drop(&mut self) {
332 if self.drop {
333 drop(unsafe { Box::from_raw(self.entries.cast_mut()) });
335 }
336 }
337}
338
/// A single cache slot: an atomic tag word plus storage for one entry.
///
/// The tag combines the stored key's hash bits (those above the bucket-index bits) with the
/// `LOCKED_BIT` and `ALIVE_BIT` flags in its low bits. Cache operations only read or write
/// `data` while holding the bucket's lock (`LOCKED_BIT`).
///
/// NOTE(review): `align(128)` presumably keeps neighbouring buckets on separate cache lines
/// to avoid false sharing between threads — confirm.
///
/// Public (but `doc(hidden)`) because `static_cache!` names it as `$crate::Bucket`.
#[repr(C, align(128))]
#[doc(hidden)]
pub struct Bucket<T> {
    /// Tag word (hash bits plus lock/alive flags); `0` in a freshly created bucket.
    tag: AtomicUsize,
    /// Entry storage; zero-filled in a fresh bucket and written by the cache's insert path.
    data: UnsafeCell<MaybeUninit<T>>,
}
353
354impl<T> Bucket<T> {
355 const NEEDS_DROP: bool = std::mem::needs_drop::<T>();
356
357 #[inline]
359 pub const fn new() -> Self {
360 Self { tag: AtomicUsize::new(0), data: UnsafeCell::new(MaybeUninit::zeroed()) }
361 }
362
363 #[inline]
364 fn try_lock(&self, expected: Option<usize>) -> bool {
365 self.try_lock_ret(expected).is_some()
366 }
367
368 #[inline]
369 fn try_lock_ret(&self, expected: Option<usize>) -> Option<usize> {
370 let state = self.tag.load(Ordering::Relaxed);
371 if let Some(expected) = expected {
372 if state != expected {
373 return None;
374 }
375 } else if state & LOCKED_BIT != 0 {
376 return None;
377 }
378 self.tag
379 .compare_exchange(state, state | LOCKED_BIT, Ordering::Acquire, Ordering::Relaxed)
380 .ok()
381 }
382
383 #[inline]
384 fn is_alive(&self) -> bool {
385 self.tag.load(Ordering::Relaxed) & ALIVE_BIT != 0
386 }
387
388 #[inline]
389 fn unlock(&self, tag: usize) {
390 self.tag.store(tag, Ordering::Release);
391 }
392}
393
// SAFETY: a `Bucket` holds at most one `T`; moving the bucket moves that value.
unsafe impl<T: Send> Send for Bucket<T> {}
// SAFETY: shared access to `data` is serialized by the `LOCKED_BIT` try-lock on `tag`, so a
// `T` is only accessed by one thread at a time (as with `Mutex<T>`); hence `T: Send` suffices.
unsafe impl<T: Send> Sync for Bucket<T> {}
397
398impl<T> Drop for Bucket<T> {
399 fn drop(&mut self) {
400 if Self::NEEDS_DROP && self.is_alive() {
401 unsafe { self.data.get_mut().assume_init_drop() };
403 }
404 }
405}
406
407#[macro_export]
430macro_rules! static_cache {
431 ($K:ty, $V:ty, $size:expr) => {
432 $crate::static_cache!($K, $V, $size, Default::default())
433 };
434 ($K:ty, $V:ty, $size:expr, $hasher:expr) => {{
435 static ENTRIES: [$crate::Bucket<($K, $V)>; $size] =
436 [const { $crate::Bucket::new() }; $size];
437 $crate::Cache::new_static(&ENTRIES, $hasher)
438 }};
439}
440
// Unit tests. They pin the rapidhash-based hasher (`BH`) regardless of the `rapidhash`
// feature. NOTE(review): this presumably relies on `rapidhash` being a dev-dependency.
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Scales an iteration count down by 10x under Miri to keep interpreted runs short.
    const fn iters(n: usize) -> usize {
        if cfg!(miri) { n / 10 } else { n }
    }

    type BH = std::hash::BuildHasherDefault<rapidhash::fast::RapidHasher<'static>>;
    type Cache<K, V> = super::Cache<K, V, BH>;

    /// Builds a heap-allocated test cache with `size` buckets and a default `BH` hasher.
    fn new_cache<K: Hash + Eq, V: Clone>(size: usize) -> Cache<K, V> {
        Cache::new(size, Default::default())
    }

    // The second lookup must be served from the cache without calling the closure.
    #[test]
    fn basic_get_or_insert() {
        let cache = new_cache(1024);

        let mut computed = false;
        let value = cache.get_or_insert_with(42, |&k| {
            computed = true;
            k * 2
        });
        assert!(computed);
        assert_eq!(value, 84);

        computed = false;
        let value = cache.get_or_insert_with(42, |&k| {
            computed = true;
            k * 2
        });
        assert!(!computed);
        assert_eq!(value, 84);
    }

    #[test]
    fn different_keys() {
        let cache: Cache<String, usize> = new_cache(1024);

        let v1 = cache.get_or_insert_with("hello".to_string(), |s| s.len());
        let v2 = cache.get_or_insert_with("world!".to_string(), |s| s.len());

        assert_eq!(v1, 5);
        assert_eq!(v2, 6);
    }

    #[test]
    fn new_dynamic_allocation() {
        let cache: Cache<u32, u32> = new_cache(64);
        assert_eq!(cache.capacity(), 64);

        cache.insert(1, 100);
        assert_eq!(cache.get(&1), Some(100));
    }

    #[test]
    fn get_miss() {
        let cache = new_cache::<u64, u64>(64);
        assert_eq!(cache.get(&999), None);
    }

    #[test]
    fn insert_and_get() {
        let cache: Cache<u64, String> = new_cache(64);

        cache.insert(1, "one".to_string());
        cache.insert(2, "two".to_string());
        cache.insert(3, "three".to_string());

        assert_eq!(cache.get(&1), Some("one".to_string()));
        assert_eq!(cache.get(&2), Some("two".to_string()));
        assert_eq!(cache.get(&3), Some("three".to_string()));
        assert_eq!(cache.get(&4), None);
    }

    // Re-inserting the same key: the test accepts either the old or the new value.
    #[test]
    fn insert_twice() {
        let cache = new_cache(64);

        cache.insert(42, 1);
        assert_eq!(cache.get(&42), Some(1));

        cache.insert(42, 2);
        let v = cache.get(&42);
        assert!(v == Some(1) || v == Some(2));
    }

    // The second call must hit the cache, so its closure's 999 is never returned.
    #[test]
    fn get_or_insert_with_ref() {
        let cache: Cache<String, usize> = new_cache(64);

        let key = "hello";
        let value = cache.get_or_insert_with_ref(key, |s| s.len(), |s| s.to_string());
        assert_eq!(value, 5);

        let value2 = cache.get_or_insert_with_ref(key, |_| 999, |s| s.to_string());
        assert_eq!(value2, 5);
    }

    #[test]
    fn get_or_insert_with_ref_different_keys() {
        let cache: Cache<String, usize> = new_cache(1024);

        let v1 = cache.get_or_insert_with_ref("foo", |s| s.len(), |s| s.to_string());
        let v2 = cache.get_or_insert_with_ref("barbaz", |s| s.len(), |s| s.to_string());

        assert_eq!(v1, 3);
        assert_eq!(v2, 6);
    }

    #[test]
    fn capacity() {
        let cache = new_cache::<u64, u64>(256);
        assert_eq!(cache.capacity(), 256);

        let cache2 = new_cache::<u64, u64>(128);
        assert_eq!(cache2.capacity(), 128);
    }

    #[test]
    fn hasher() {
        let cache = new_cache::<u64, u64>(64);
        let _ = cache.hasher();
    }

    #[test]
    fn debug_impl() {
        let cache = new_cache::<u64, u64>(64);
        let debug_str = format!("{:?}", cache);
        assert!(debug_str.contains("Cache"));
    }

    #[test]
    fn bucket_new() {
        let bucket: Bucket<(u64, u64)> = Bucket::new();
        assert_eq!(bucket.tag.load(Ordering::Relaxed), 0);
    }

    // Distinct keys can share a bucket and evict each other, so only require some hits.
    #[test]
    fn many_entries() {
        let cache: Cache<u64, u64> = new_cache(1024);
        let n = iters(500);

        for i in 0..n as u64 {
            cache.insert(i, i * 2);
        }

        let mut hits = 0;
        for i in 0..n as u64 {
            if cache.get(&i) == Some(i * 2) {
                hits += 1;
            }
        }
        assert!(hits > 0);
    }

    // Lookups use `&str` against `String` keys via `Equivalent`.
    #[test]
    fn string_keys() {
        let cache: Cache<String, i32> = new_cache(1024);

        cache.insert("alpha".to_string(), 1);
        cache.insert("beta".to_string(), 2);
        cache.insert("gamma".to_string(), 3);

        assert_eq!(cache.get("alpha"), Some(1));
        assert_eq!(cache.get("beta"), Some(2));
        assert_eq!(cache.get("gamma"), Some(3));
    }

    #[test]
    fn zero_values() {
        let cache: Cache<u64, u64> = new_cache(64);

        cache.insert(0, 0);
        assert_eq!(cache.get(&0), Some(0));

        cache.insert(1, 0);
        assert_eq!(cache.get(&1), Some(0));
    }

    #[test]
    fn clone_value() {
        #[derive(Clone, PartialEq, Debug)]
        struct MyValue(u64);

        let cache: Cache<u64, MyValue> = new_cache(64);

        cache.insert(1, MyValue(123));
        let v = cache.get(&1);
        assert_eq!(v, Some(MyValue(123)));
    }

    /// Runs `f(t)` for `t` in `0..num_threads`, each on its own scoped thread, and waits
    /// for all of them to finish.
    fn run_concurrent<F>(num_threads: usize, f: F)
    where
        F: Fn(usize) + Send + Sync,
    {
        thread::scope(|s| {
            for t in 0..num_threads {
                let f = &f;
                s.spawn(move || f(t));
            }
        });
    }

    // Smoke test: reads may miss while a bucket is locked, so results are not asserted.
    #[test]
    fn concurrent_reads() {
        let cache: Cache<u64, u64> = new_cache(1024);
        let n = iters(100);

        for i in 0..n as u64 {
            cache.insert(i, i * 10);
        }

        run_concurrent(4, |_| {
            for i in 0..n as u64 {
                let _ = cache.get(&i);
            }
        });
    }

    // Smoke test: concurrent inserts may be skipped on a locked bucket; only no-crash is checked.
    #[test]
    fn concurrent_writes() {
        let cache: Cache<u64, u64> = new_cache(1024);
        let n = iters(100);

        run_concurrent(4, |t| {
            for i in 0..n {
                cache.insert((t * 1000 + i) as u64, i as u64);
            }
        });
    }

    // Smoke test: thread 0 writes while thread 1 reads the same 100 keys.
    #[test]
    fn concurrent_read_write() {
        let cache: Cache<u64, u64> = new_cache(256);
        let n = iters(1000);

        run_concurrent(2, |t| {
            for i in 0..n as u64 {
                if t == 0 {
                    cache.insert(i % 100, i);
                } else {
                    let _ = cache.get(&(i % 100));
                }
            }
        });
    }

    // Entries may be missing afterwards, but any entry that is present must be correct.
    #[test]
    fn concurrent_get_or_insert() {
        let cache: Cache<u64, u64> = new_cache(1024);
        let n = iters(100);

        run_concurrent(8, |_| {
            for i in 0..n as u64 {
                let _ = cache.get_or_insert_with(i, |&k| k * 2);
            }
        });

        for i in 0..n as u64 {
            if let Some(v) = cache.get(&i) {
                assert_eq!(v, i * 2);
            }
        }
    }

    #[test]
    #[should_panic = "power of two"]
    fn non_power_of_two() {
        let _ = new_cache::<u64, u64>(100);
    }

    // 2 is a power of two but too small to keep the tag clear of the flag bits.
    #[test]
    #[should_panic = "len must have its bottom N bits set to zero"]
    fn small_cache() {
        let _ = new_cache::<u64, u64>(2);
    }

    // Sizes 4..=512: the smallest accepted size and up.
    #[test]
    fn power_of_two_sizes() {
        for shift in 2..10 {
            let size = 1 << shift;
            let cache = new_cache::<u64, u64>(size);
            assert_eq!(cache.capacity(), size);
        }
    }

    #[test]
    fn equivalent_key_lookup() {
        let cache: Cache<String, i32> = new_cache(64);

        cache.insert("hello".to_string(), 42);

        assert_eq!(cache.get("hello"), Some(42));
    }

    #[test]
    fn large_values() {
        let cache: Cache<u64, [u8; 1000]> = new_cache(64);

        let large_value = [42u8; 1000];
        cache.insert(1, large_value);

        assert_eq!(cache.get(&1), Some(large_value));
    }

    // Compile-time check of the auto/unsafe `Send`/`Sync` impls.
    #[test]
    fn send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}

        assert_send::<Cache<u64, u64>>();
        assert_sync::<Cache<u64, u64>>();
        assert_send::<Bucket<(u64, u64)>>();
        assert_sync::<Bucket<(u64, u64)>>();
    }

    #[test]
    fn get_or_try_insert_with_ok() {
        let cache = new_cache(1024);

        let mut computed = false;
        let result: Result<u64, &str> = cache.get_or_try_insert_with(42, |&k| {
            computed = true;
            Ok(k * 2)
        });
        assert!(computed);
        assert_eq!(result, Ok(84));

        computed = false;
        let result: Result<u64, &str> = cache.get_or_try_insert_with(42, |&k| {
            computed = true;
            Ok(k * 2)
        });
        assert!(!computed);
        assert_eq!(result, Ok(84));
    }

    // A failed computation must not leave anything in the cache.
    #[test]
    fn get_or_try_insert_with_err() {
        let cache: Cache<u64, u64> = new_cache(1024);

        let result: Result<u64, &str> = cache.get_or_try_insert_with(42, |_| Err("failed"));
        assert_eq!(result, Err("failed"));

        assert_eq!(cache.get(&42), None);
    }

    #[test]
    fn get_or_try_insert_with_ref_ok() {
        let cache: Cache<String, usize> = new_cache(64);

        let key = "hello";
        let result: Result<usize, &str> =
            cache.get_or_try_insert_with_ref(key, |s| Ok(s.len()), |s| s.to_string());
        assert_eq!(result, Ok(5));

        let result2: Result<usize, &str> =
            cache.get_or_try_insert_with_ref(key, |_| Ok(999), |s| s.to_string());
        assert_eq!(result2, Ok(5));
    }

    #[test]
    fn get_or_try_insert_with_ref_err() {
        let cache: Cache<String, usize> = new_cache(64);

        let key = "hello";
        let result: Result<usize, &str> =
            cache.get_or_try_insert_with_ref(key, |_| Err("failed"), |s| s.to_string());
        assert_eq!(result, Err("failed"));

        assert_eq!(cache.get(key), None);
    }

    // Three entries (key + value each) are dropped only when the cache is dropped: 6 drops.
    // NOTE(review): assumes the three keys land in distinct buckets with this hasher.
    #[test]
    fn drop_on_cache_drop() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        #[derive(Clone, Hash, Eq, PartialEq)]
        struct DropKey(u64);
        impl Drop for DropKey {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        #[derive(Clone)]
        struct DropValue(#[allow(dead_code)] u64);
        impl Drop for DropValue {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        DROP_COUNT.store(0, Ordering::SeqCst);
        {
            let cache: super::Cache<DropKey, DropValue, BH> =
                super::Cache::new(64, Default::default());
            cache.insert(DropKey(1), DropValue(100));
            cache.insert(DropKey(2), DropValue(200));
            cache.insert(DropKey(3), DropValue(300));
            assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 0);
        }
        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 6);
    }

    // Re-inserting a key drops the replaced key and value (2). Dropping the cache then
    // drops the surviving entry (4 total).
    #[test]
    fn drop_on_eviction() {
        use std::sync::atomic::{AtomicUsize, Ordering};

        static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

        #[derive(Clone, Hash, Eq, PartialEq)]
        struct DropKey(u64);
        impl Drop for DropKey {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        #[derive(Clone)]
        struct DropValue(#[allow(dead_code)] u64);
        impl Drop for DropValue {
            fn drop(&mut self) {
                DROP_COUNT.fetch_add(1, Ordering::SeqCst);
            }
        }

        DROP_COUNT.store(0, Ordering::SeqCst);
        {
            let cache: super::Cache<DropKey, DropValue, BH> =
                super::Cache::new(64, Default::default());
            cache.insert(DropKey(1), DropValue(100));
            assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 0);
            cache.insert(DropKey(1), DropValue(200));
            assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 2);
        }
        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 4);
    }
}