1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![allow(clippy::new_without_default)]
4
5use core::{
6 cell::UnsafeCell,
7 hash::{BuildHasher, Hash},
8 mem::MaybeUninit,
9 sync::atomic::{AtomicUsize, Ordering},
10};
11use equivalent::Equivalent;
12
/// Bit set in a bucket's tag word while a thread holds that bucket's lock.
// NOTE(review): for capacities below 2^16 the non-index hash bits overlap
// bit 15 — the tag computation must never produce a tag containing this
// bit, or a stored tag would read as "locked"; confirm `tag_mask` excludes it.
const LOCKED_BIT: usize = 0x0000_8000;

// Default hash builder: rapidhash when the optional feature is enabled,
// otherwise the standard library's randomly-seeded SipHash state.
#[cfg(feature = "rapidhash")]
type DefaultBuildHasher = std::hash::BuildHasherDefault<rapidhash::fast::RapidHasher<'static>>;
#[cfg(not(feature = "rapidhash"))]
type DefaultBuildHasher = std::hash::RandomState;
19
/// A fixed-capacity concurrent cache with one entry per hash bucket.
///
/// Inserts overwrite whatever occupied the target bucket and are silently
/// dropped when the bucket is momentarily locked by another thread, so this
/// is a lossy cache rather than a map.
pub struct Cache<K, V, S = DefaultBuildHasher> {
    // Fixed-size bucket array behind a raw pointer so it can alias either a
    // `'static` slice (`new_static`) or a leaked boxed slice (`new`).
    entries: *const [Bucket<(K, V)>],
    // Hash builder used for both the bucket index and the tag bits.
    build_hasher: S,
    // True when `entries` was heap-allocated by `new` and must be freed in
    // `Drop`; false for `'static` storage.
    drop: bool,
}
72
73impl<K, V, S> core::fmt::Debug for Cache<K, V, S> {
74 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
75 f.debug_struct("Cache").finish_non_exhaustive()
76 }
77}
78
// SAFETY: moving the cache moves ownership of its (K, V) entries and of the
// hasher, so all three must be `Send`.
unsafe impl<K: Send, V: Send, S: Send> Send for Cache<K, V, S> {}
// SAFETY: like `Mutex<T>: Sync where T: Send` — shared access to each (K, V)
// slot is serialized by the per-bucket CAS lock, so `K`/`V` only need `Send`;
// the hasher is used by shared reference and must be `Sync`.
// NOTE(review): this relies on the bucket lock actually being exclusive,
// which in turn requires that stored tags never contain LOCKED_BIT — confirm.
unsafe impl<K: Send, V: Send, S: Sync> Sync for Cache<K, V, S> {}
82
83impl<K, V, S> Cache<K, V, S>
84where
85 K: Hash + Eq,
86 S: BuildHasher,
87{
88 pub fn new(num: usize, build_hasher: S) -> Self {
96 assert!(num.is_power_of_two(), "capacity must be a power of two");
97 let entries =
98 Box::into_raw((0..num).map(|_| Bucket::new()).collect::<Vec<_>>().into_boxed_slice());
99 Self::new_inner(entries, build_hasher, true)
100 }
101
102 #[inline]
108 pub const fn new_static(entries: &'static [Bucket<(K, V)>], build_hasher: S) -> Self {
109 Self::new_inner(entries, build_hasher, false)
110 }
111
112 #[inline]
113 const fn new_inner(entries: *const [Bucket<(K, V)>], build_hasher: S, drop: bool) -> Self {
114 const {
115 assert!(!std::mem::needs_drop::<K>(), "dropping keys is not supported yet");
116 assert!(!std::mem::needs_drop::<V>(), "dropping values is not supported yet");
117 }
118 assert!(entries.len().is_power_of_two());
119 Self { entries, build_hasher, drop }
120 }
121
122 #[inline]
123 const fn index_mask(&self) -> usize {
124 let n = self.capacity();
125 unsafe { core::hint::assert_unchecked(n.is_power_of_two()) };
126 n - 1
127 }
128
129 #[inline]
130 const fn tag_mask(&self) -> usize {
131 !self.index_mask()
132 }
133
134 #[inline]
136 pub const fn hasher(&self) -> &S {
137 &self.build_hasher
138 }
139
140 #[inline]
142 pub const fn capacity(&self) -> usize {
143 self.entries.len()
144 }
145}
146
impl<K, V, S> Cache<K, V, S>
where
    K: Hash + Eq,
    V: Clone,
    S: BuildHasher,
{
    /// Looks up `key`, returning a clone of the cached value on a hit.
    ///
    /// Never blocks: if the bucket is currently locked by another thread,
    /// the lookup is treated as a miss.
    pub fn get<Q: ?Sized + Hash + Equivalent<K>>(&self, key: &Q) -> Option<V> {
        let (bucket, tag) = self.calc(key);
        self.get_inner(key, bucket, tag)
    }

    #[inline]
    fn get_inner<Q: ?Sized + Hash + Equivalent<K>>(
        &self,
        key: &Q,
        bucket: &Bucket<(K, V)>,
        tag: usize,
    ) -> Option<V> {
        // Only proceed if the bucket is unlocked AND currently stores our
        // exact tag; any mismatch (other tag, locked, empty) is a miss.
        if bucket.try_lock(Some(tag)) {
            // SAFETY: holding the bucket lock gives exclusive access to the
            // slot. NOTE(review): an empty bucket keeps its initial zeroed
            // bytes and tag word 0, so a key whose tag is 0 reads zeroed
            // memory as an initialized (K, V) here — that is only defined for
            // all-zero-valid K/V types; confirm this invariant.
            let (ck, v) = unsafe { (*bucket.data.get()).assume_init_ref() };
            if key.equivalent(ck) {
                let v = v.clone();
                bucket.unlock(tag);
                return Some(v);
            }
            // Tag matched but the key differs (tag collision): release the
            // lock, restoring the same tag, and report a miss.
            bucket.unlock(tag);
        }

        None
    }

    /// Inserts `key -> value`, overwriting whatever previously occupied the
    /// target bucket. The insert is silently dropped if the bucket is locked.
    pub fn insert(&self, key: K, value: V) {
        let (bucket, tag) = self.calc(&key);
        self.insert_inner(|| key, || value, bucket, tag);
    }

    #[inline]
    fn insert_inner(
        &self,
        make_key: impl FnOnce() -> K,
        make_value: impl FnOnce() -> V,
        bucket: &Bucket<(K, V)>,
        tag: usize,
    ) {
        // Lock regardless of the current tag (we may evict another entry),
        // but give up immediately if another thread holds the lock. The
        // key/value closures only run once the lock is held.
        if bucket.try_lock(None) {
            unsafe {
                // SAFETY: the lock grants exclusive access. The previous
                // contents are overwritten without being dropped, which is
                // fine because `new_inner` statically rejects K/V that need
                // drop. Fields are written in place via raw pointers so the
                // old (possibly invalid) value is never materialized.
                let data = (&mut *bucket.data.get()).as_mut_ptr();
                (&raw mut (*data).0).write(make_key());
                (&raw mut (*data).1).write(make_value());
            }
            // Publish the new entry by storing its tag (this also clears the
            // lock bit).
            bucket.unlock(tag);
        }
    }

    /// Returns the cached value for `key`, computing and inserting it with
    /// `f` on a miss. The computed value is returned even if the insert is
    /// skipped because the bucket was locked.
    #[inline]
    pub fn get_or_insert_with<F>(&self, key: K, f: F) -> V
    where
        F: FnOnce(&K) -> V,
    {
        // Ownership of `key` moves into the cache only if the `cvt` closure
        // actually runs; `ManuallyDrop` plus the `read` flag ensure the key
        // is dropped exactly once either way.
        let mut key = std::mem::ManuallyDrop::new(key);
        let mut read = false;
        let r = self.get_or_insert_with_ref(&*key, f, |k| {
            read = true;
            // SAFETY: `k` points at `key`; ownership moves into the cache and
            // the original is not dropped (guarded by `read` below).
            unsafe { std::ptr::read(k) }
        });
        if !read {
            // The cache never took ownership; drop the key here.
            unsafe { std::mem::ManuallyDrop::drop(&mut key) }
        }
        r
    }

    /// Borrowed-key variant of [`Cache::get_or_insert_with`]: `f` computes
    /// the value on a miss and `cvt` converts the borrowed key into an owned
    /// `K` only when an insert actually happens.
    #[inline]
    pub fn get_or_insert_with_ref<'a, Q, F, Cvt>(&self, key: &'a Q, f: F, cvt: Cvt) -> V
    where
        Q: ?Sized + Hash + Equivalent<K>,
        F: FnOnce(&'a Q) -> V,
        Cvt: FnOnce(&'a Q) -> K,
    {
        let (bucket, tag) = self.calc(key);
        if let Some(v) = self.get_inner(key, bucket, tag) {
            return v;
        }
        let value = f(key);
        // The clone only happens if the bucket lock is acquired; the
        // original `value` is always returned to the caller.
        self.insert_inner(|| cvt(key), || value.clone(), bucket, tag);
        value
    }

    /// Maps a key to its bucket and the tag bits to store/compare.
    #[inline]
    fn calc<Q: ?Sized + Hash + Equivalent<K>>(&self, key: &Q) -> (&Bucket<(K, V)>, usize) {
        let hash = self.hash_key(key);
        // SAFETY: `index_mask` is capacity - 1 and the capacity is a power of
        // two (asserted in `new_inner`), so the masked index is in bounds.
        let bucket = unsafe { (&*self.entries).get_unchecked(hash & self.index_mask()) };
        let tag = hash & self.tag_mask();
        (bucket, tag)
    }

    /// Hashes `key` and narrows the 64-bit hash to `usize`.
    #[inline]
    fn hash_key<Q: ?Sized + Hash + Equivalent<K>>(&self, key: &Q) -> usize {
        let hash = self.build_hasher.hash_one(key);

        // On 32-bit targets fold the high half into the low half so the
        // upper 32 bits of the hash still influence index and tag.
        if cfg!(target_pointer_width = "32") {
            ((hash >> 32) as usize) ^ (hash as usize)
        } else {
            hash as usize
        }
    }
}
272
impl<K, V, S> Drop for Cache<K, V, S> {
    fn drop(&mut self) {
        // Only heap-allocated caches (`Cache::new`) own their bucket slice;
        // `new_static` caches borrow `'static` storage and must not free it.
        // Entry contents never need dropping (enforced in `new_inner`).
        if self.drop {
            // SAFETY: when `self.drop` is true, `entries` came from
            // `Box::into_raw` in `new` and is uniquely owned by this cache.
            drop(unsafe { Box::from_raw(self.entries.cast_mut()) });
        }
    }
}
280
// 128-byte alignment gives each bucket its own cache line(s), presumably to
// avoid false sharing between adjacent buckets — TODO confirm intent.
#[repr(C, align(128))]
#[doc(hidden)]
pub struct Bucket<T> {
    // Tag word: 0 when empty, the entry's tag bits when occupied, with
    // LOCKED_BIT or'ed in while a thread holds the bucket lock.
    tag: AtomicUsize,
    // The stored entry. Starts zeroed (not truly uninitialized) and is only
    // valid to access while holding the bucket lock.
    data: UnsafeCell<MaybeUninit<T>>,
}
295
296impl<T> Bucket<T> {
297 #[inline]
299 pub const fn new() -> Self {
300 Self { tag: AtomicUsize::new(0), data: UnsafeCell::new(MaybeUninit::zeroed()) }
301 }
302
303 #[inline]
304 fn try_lock(&self, expected: Option<usize>) -> bool {
305 let state = self.tag.load(Ordering::Relaxed);
306 if let Some(expected) = expected {
307 if state != expected {
308 return false;
309 }
310 } else if state & LOCKED_BIT != 0 {
311 return false;
312 }
313 self.tag
314 .compare_exchange(state, state | LOCKED_BIT, Ordering::Acquire, Ordering::Relaxed)
315 .is_ok()
316 }
317
318 #[inline]
319 fn unlock(&self, tag: usize) {
320 self.tag.store(tag, Ordering::Release);
321 }
322}
323
// SAFETY: moving a bucket moves ownership of its `T`, so `T: Send` suffices.
unsafe impl<T: Send> Send for Bucket<T> {}
// SAFETY: like `Mutex<T>: Sync where T: Send` — access to `data` is
// serialized by the tag word's CAS lock, so `T` needs `Send` but not `Sync`.
// NOTE(review): this relies on the lock being truly exclusive, i.e. on
// stored tags never containing LOCKED_BIT — confirm at the call sites.
unsafe impl<T: Send> Sync for Bucket<T> {}
327
/// Declares a [`Cache`] backed by a `static` bucket array, avoiding any heap
/// allocation.
///
/// `$size` must be a power of two (checked at construction). When no hasher
/// expression is given, `Default::default()` is used.
#[macro_export]
macro_rules! static_cache {
    ($K:ty, $V:ty, $size:expr) => {
        // Delegate to the 4-argument form with a defaulted hasher.
        $crate::static_cache!($K, $V, $size, Default::default())
    };
    ($K:ty, $V:ty, $size:expr, $hasher:expr) => {{
        static ENTRIES: [$crate::Bucket<($K, $V)>; $size] =
            [const { $crate::Bucket::new() }; $size];
        $crate::Cache::new_static(&ENTRIES, $hasher)
    }};
}
361
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Scales iteration counts down under Miri, which executes far slower
    /// than native code.
    const fn iters(n: usize) -> usize {
        if cfg!(miri) { n / 10 } else { n }
    }

    // Use the crate's default build hasher so the test suite compiles both
    // with and without the optional `rapidhash` feature; the previous alias
    // named `rapidhash` unconditionally and broke builds without it.
    type BH = DefaultBuildHasher;
    type Cache<K, V> = super::Cache<K, V, BH>;

    fn new_cache<K: Hash + Eq, V: Clone>(size: usize) -> Cache<K, V> {
        Cache::new(size, Default::default())
    }

    #[test]
    fn test_basic_get_or_insert() {
        let cache = new_cache(1024);

        // First call misses and must invoke the closure.
        let mut computed = false;
        let value = cache.get_or_insert_with(42, |&k| {
            computed = true;
            k * 2
        });
        assert!(computed);
        assert_eq!(value, 84);

        // Second call hits and must not recompute.
        computed = false;
        let value = cache.get_or_insert_with(42, |&k| {
            computed = true;
            k * 2
        });
        assert!(!computed);
        assert_eq!(value, 84);
    }

    #[test]
    fn test_different_keys() {
        let cache: Cache<&'static str, usize> = static_cache!(&'static str, usize, 1024);

        let v1 = cache.get_or_insert_with("hello", |s| s.len());
        let v2 = cache.get_or_insert_with("world!", |s| s.len());

        assert_eq!(v1, 5);
        assert_eq!(v2, 6);
    }

    #[test]
    fn test_new_dynamic_allocation() {
        let cache: Cache<u32, u32> = new_cache(64);
        assert_eq!(cache.capacity(), 64);

        cache.insert(1, 100);
        assert_eq!(cache.get(&1), Some(100));
    }

    #[test]
    fn test_get_miss() {
        let cache = new_cache::<u64, u64>(64);
        assert_eq!(cache.get(&999), None);
    }

    #[test]
    fn test_insert_and_get() {
        let cache: Cache<u64, &'static str> = new_cache(64);

        cache.insert(1, "one");
        cache.insert(2, "two");
        cache.insert(3, "three");

        assert_eq!(cache.get(&1), Some("one"));
        assert_eq!(cache.get(&2), Some("two"));
        assert_eq!(cache.get(&3), Some("three"));
        assert_eq!(cache.get(&4), None);
    }

    #[test]
    fn test_insert_twice() {
        let cache = new_cache(64);

        cache.insert(42, 1);
        assert_eq!(cache.get(&42), Some(1));

        // A second insert may overwrite or be dropped if the bucket was
        // busy; either stored value is acceptable.
        cache.insert(42, 2);
        let v = cache.get(&42);
        assert!(v == Some(1) || v == Some(2));
    }

    #[test]
    fn test_get_or_insert_with_ref() {
        let cache: Cache<&'static str, usize> = new_cache(64);

        let key = "hello";
        let value = cache.get_or_insert_with_ref(key, |s| s.len(), |s| s);
        assert_eq!(value, 5);

        // A hit must return the cached value, not the new closure's result.
        let value2 = cache.get_or_insert_with_ref(key, |_| 999, |s| s);
        assert_eq!(value2, 5);
    }

    #[test]
    fn test_get_or_insert_with_ref_different_keys() {
        let cache: Cache<&'static str, usize> = new_cache(1024);

        let v1 = cache.get_or_insert_with_ref("foo", |s| s.len(), |s| s);
        let v2 = cache.get_or_insert_with_ref("barbaz", |s| s.len(), |s| s);

        assert_eq!(v1, 3);
        assert_eq!(v2, 6);
    }

    #[test]
    fn test_capacity() {
        let cache = new_cache::<u64, u64>(256);
        assert_eq!(cache.capacity(), 256);

        let cache2 = new_cache::<u64, u64>(128);
        assert_eq!(cache2.capacity(), 128);
    }

    #[test]
    fn test_hasher() {
        let cache = new_cache::<u64, u64>(64);
        let _ = cache.hasher();
    }

    #[test]
    fn test_debug_impl() {
        let cache = new_cache::<u64, u64>(64);
        let debug_str = format!("{:?}", cache);
        assert!(debug_str.contains("Cache"));
    }

    #[test]
    fn test_bucket_new() {
        // A fresh bucket is unlocked with tag word 0.
        let bucket: Bucket<(u64, u64)> = Bucket::new();
        assert_eq!(bucket.tag.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn test_many_entries() {
        let cache: Cache<u64, u64> = new_cache(1024);
        let n = iters(500);

        for i in 0..n as u64 {
            cache.insert(i, i * 2);
        }

        // Bucket collisions may evict entries, so only require that some
        // survive and that every survivor has the correct value.
        let mut hits = 0;
        for i in 0..n as u64 {
            if cache.get(&i) == Some(i * 2) {
                hits += 1;
            }
        }
        assert!(hits > 0);
    }

    #[test]
    fn test_string_keys() {
        let cache: Cache<&'static str, i32> = new_cache(1024);

        cache.insert("alpha", 1);
        cache.insert("beta", 2);
        cache.insert("gamma", 3);

        assert_eq!(cache.get(&"alpha"), Some(1));
        assert_eq!(cache.get(&"beta"), Some(2));
        assert_eq!(cache.get(&"gamma"), Some(3));
    }

    #[test]
    fn test_zero_values() {
        let cache: Cache<u64, u64> = new_cache(64);

        cache.insert(0, 0);
        assert_eq!(cache.get(&0), Some(0));

        cache.insert(1, 0);
        assert_eq!(cache.get(&1), Some(0));
    }

    #[test]
    fn test_clone_value() {
        #[derive(Clone, PartialEq, Debug)]
        struct MyValue(u64);

        let cache: Cache<u64, MyValue> = new_cache(64);

        cache.insert(1, MyValue(123));
        let v = cache.get(&1);
        assert_eq!(v, Some(MyValue(123)));
    }

    /// Runs `f(thread_index)` on `num_threads` scoped threads concurrently.
    fn run_concurrent<F>(num_threads: usize, f: F)
    where
        F: Fn(usize) + Send + Sync,
    {
        thread::scope(|s| {
            for t in 0..num_threads {
                let f = &f;
                s.spawn(move || f(t));
            }
        });
    }

    #[test]
    fn test_concurrent_reads() {
        let cache: Cache<u64, u64> = new_cache(1024);
        let n = iters(100);

        for i in 0..n as u64 {
            cache.insert(i, i * 10);
        }

        run_concurrent(4, |_| {
            for i in 0..n as u64 {
                let _ = cache.get(&i);
            }
        });
    }

    #[test]
    fn test_concurrent_writes() {
        let cache: Cache<u64, u64> = new_cache(1024);
        let n = iters(100);

        run_concurrent(4, |t| {
            for i in 0..n {
                cache.insert((t * 1000 + i) as u64, i as u64);
            }
        });
    }

    #[test]
    fn test_concurrent_read_write() {
        let cache: Cache<u64, u64> = new_cache(256);
        let n = iters(1000);

        run_concurrent(2, |t| {
            for i in 0..n as u64 {
                if t == 0 {
                    cache.insert(i % 100, i);
                } else {
                    let _ = cache.get(&(i % 100));
                }
            }
        });
    }

    #[test]
    fn test_concurrent_get_or_insert() {
        let cache: Cache<u64, u64> = new_cache(1024);
        let n = iters(100);

        run_concurrent(8, |_| {
            for i in 0..n as u64 {
                let _ = cache.get_or_insert_with(i, |&k| k * 2);
            }
        });

        // Whatever survived must carry the deterministically computed value.
        for i in 0..n as u64 {
            if let Some(v) = cache.get(&i) {
                assert_eq!(v, i * 2);
            }
        }
    }

    #[test]
    #[should_panic]
    fn test_non_power_of_two_panics() {
        let _ = new_cache::<u64, u64>(100);
    }

    #[test]
    fn test_power_of_two_sizes() {
        for shift in 1..10 {
            let size = 1 << shift;
            let cache = new_cache::<u64, u64>(size);
            assert_eq!(cache.capacity(), size);
        }
    }

    #[test]
    fn test_small_cache() {
        let cache = new_cache(2);
        assert_eq!(cache.capacity(), 2);

        cache.insert(1, 10);
        cache.insert(2, 20);
        cache.insert(3, 30);

        // With only two buckets, at most two of the three keys can survive.
        let count = [1, 2, 3].iter().filter(|&&k| cache.get(&k).is_some()).count();
        assert!(count <= 2);
    }

    #[test]
    fn test_equivalent_key_lookup() {
        let cache = new_cache(64);

        cache.insert("hello", 42);

        assert_eq!(cache.get(&"hello"), Some(42));
    }

    #[test]
    fn test_large_values() {
        let cache: Cache<u64, [u8; 1000]> = new_cache(64);

        let large_value = [42u8; 1000];
        cache.insert(1, large_value);

        assert_eq!(cache.get(&1), Some(large_value));
    }

    #[test]
    fn test_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}

        assert_send::<Cache<u64, u64>>();
        assert_sync::<Cache<u64, u64>>();
        assert_send::<Bucket<(u64, u64)>>();
        assert_sync::<Bucket<(u64, u64)>>();
    }
}
687}