quick_hash_cache/
lib.rs

use std::borrow::Borrow;
use std::hash::{BuildHasher, Hash, Hasher};
use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};

pub use hashbrown::hash_map::DefaultHashBuilder;
use hashbrown::hash_map::{HashMap, RawEntryMut};

use tokio::sync::{
    OwnedRwLockMappedWriteGuard, OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock,
};

pub mod lru;

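/// A concurrent hash map that splits its entries across multiple `hashbrown`
/// shards, each guarded by its own async `tokio` `RwLock`, so that operations
/// on different shards can proceed in parallel.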
#[derive(Debug)]
pub struct CHashMap<K, T, S = DefaultHashBuilder> {
    hash_builder: S,
    shards: Vec<Arc<RwLock<HashMap<K, T, S>>>>,
    // Total entry count across all shards, maintained by the insert/remove
    // paths so that `size()` never needs to lock a shard.
    size: AtomicUsize,
}

impl<K, T> CHashMap<K, T, DefaultHashBuilder> {
    /// Creates an empty map with `num_shards` shards and the default hasher.
    pub fn new(num_shards: usize) -> Self {
        Self::with_hasher(num_shards, DefaultHashBuilder::default())
    }
}

impl<K, T> Default for CHashMap<K, T, DefaultHashBuilder> {
    /// Defaults to one shard per logical CPU.
    fn default() -> Self {
        Self::new(num_cpus::get())
    }
}

// `get`, `get_mut` and friends return lock guards whose first type parameter
// is the concrete shard type; `impl Erased` hides that type from the public
// signatures.
#[doc(hidden)]
pub trait Erased {}
impl<T> Erased for T {}

pub type ReadHandle<T, U> = OwnedRwLockReadGuard<T, U>;
pub type WriteHandle<T, U> = OwnedRwLockMappedWriteGuard<T, U>;

pub type Shard<K, T, S> = HashMap<K, T, S>;

impl<K, T, S> CHashMap<K, T, S>
where
    S: Clone,
{
    /// Creates an empty map with `num_shards` shards, each using a clone of
    /// `hash_builder`.
    pub fn with_hasher(num_shards: usize, hash_builder: S) -> Self {
        // Shard selection divides by the shard count, so an empty shard list
        // would panic on first use; fail loudly at construction instead.
        assert!(num_shards > 0, "CHashMap requires at least one shard");

        CHashMap {
            shards: (0..num_shards)
                .map(|_| Arc::new(RwLock::new(HashMap::with_hasher(hash_builder.clone()))))
                .collect(),
            hash_builder,
            size: AtomicUsize::new(0),
        }
    }
}

impl<K, T, S> CHashMap<K, T, S>
where
    K: Clone,
    T: Clone,
    S: Clone,
{
    /// Creates a deep copy of this `CHashMap`. A regular `Clone` impl cannot
    /// be provided because copying requires acquiring each shard's async lock.
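    ///
    /// A minimal usage sketch (assuming a Tokio runtime; not compiled as a
    /// doctest):
    ///
    /// ```ignore
    /// let map: CHashMap<String, usize> = CHashMap::default();
    /// map.insert("a".to_owned(), 1).await;
    ///
    /// let copy = map.duplicate().await;
    /// assert_eq!(copy.size(), 1);
    /// ```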
    pub async fn duplicate(&self) -> Self {
        let mut shards = Vec::with_capacity(self.shards.len());
        let mut size = 0;

        for shard in &self.shards {
            let shard = shard.read().await.clone();
            size += shard.len();
            shards.push(Arc::new(RwLock::new(shard)));
        }

        CHashMap {
            shards,
            hash_builder: self.hash_builder.clone(),
            size: AtomicUsize::new(size),
        }
    }
}

impl<K, T, S> CHashMap<K, T, S>
where
    K: Hash + Eq,
    S: BuildHasher,
{
    pub fn hash_builder(&self) -> &S {
        &self.hash_builder
    }

    /// Hashes `key` and returns the hash together with the index of the shard
    /// responsible for it.
    #[inline]
    fn hash_and_shard<Q: ?Sized>(&self, key: &Q) -> (u64, usize)
    where
        Q: Hash + Eq,
    {
        let mut hasher = self.hash_builder.build_hasher();
        key.hash(&mut hasher);
        let hash = hasher.finish();
        (hash, hash as usize % self.shards.len())
    }

    /// Removes all entries, one shard at a time.
    pub async fn clear(&self) {
        for shard in &self.shards {
            let mut shard = shard.write().await;

            let len = shard.len();
            shard.clear();

            self.size.fetch_sub(len, Ordering::SeqCst);
        }
    }

    /// Retains only the entries for which `f` returns `true`.
    pub async fn retain<F>(&self, f: F)
    where
        F: Fn(&K, &mut T) -> bool,
    {
        for shard in &self.shards {
            let mut shard = shard.write().await;

            let len = shard.len();
            shard.retain(&f);

            self.size.fetch_sub(len - shard.len(), Ordering::SeqCst);
        }
    }

    pub fn iter_shards(&self) -> impl Iterator<Item = &RwLock<Shard<K, T, S>>> {
        self.shards.iter().map(|s| &**s)
    }

    /// Number of entries in the map. Under concurrent modification this is a
    /// best-effort snapshot of the shared counter.
    pub fn size(&self) -> usize {
        self.size.load(Ordering::SeqCst)
    }

    pub fn num_shards(&self) -> usize {
        self.shards.len()
    }

    /// Checks whether any entry with this hash exists, without blocking.
    /// Returns `false` if the shard lock could not be acquired immediately,
    /// and may report false positives on hash collisions since only the hash
    /// is compared.
    pub fn try_maybe_contains_hash(&self, hash: u64) -> bool {
        let shard_idx = hash as usize % self.shards.len();
        let shard = unsafe { self.shards.get_unchecked(shard_idx) };

        if let Ok(shard) = shard.try_read() {
            shard.raw_entry().from_hash(hash, |_| true).is_some()
        } else {
            false
        }
    }

    /// Checks whether any entry with this hash exists. May report false
    /// positives on hash collisions since only the hash is compared.
    pub async fn contains_hash(&self, hash: u64) -> bool {
        let shard_idx = hash as usize % self.shards.len();
        let shard = unsafe { self.shards.get_unchecked(shard_idx) };

        shard.read().await.raw_entry().from_hash(hash, |_| true).is_some()
    }

    /// Checks whether `key` is present in the map.
    pub async fn contains<Q: ?Sized>(&self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        let (hash, shard_idx) = self.hash_and_shard(key);
        let shard = unsafe { self.shards.get_unchecked(shard_idx).read().await };

        // Compare full keys rather than delegating to `contains_hash`, which
        // would report false positives on hash collisions.
        shard.raw_entry().from_key_hashed_nocheck(hash, key).is_some()
    }

    /// Removes `key` from the map, returning its value if it was present.
    pub async fn remove<Q: ?Sized>(&self, key: &Q) -> Option<T>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        let (hash, shard_idx) = self.hash_and_shard(key);
        let mut shard = unsafe { self.shards.get_unchecked(shard_idx).write().await };

        match shard.raw_entry_mut().from_key_hashed_nocheck(hash, key) {
            RawEntryMut::Occupied(occupied) => {
                let value = occupied.remove();
                self.size.fetch_sub(1, Ordering::SeqCst);
                Some(value)
            }
            RawEntryMut::Vacant(_) => None,
        }
    }

    /// Inserts `key`/`value`, returning the previous value if the key was
    /// already present.
    pub async fn insert(&self, key: K, value: T) -> Option<T> {
        let (hash, shard_idx) = self.hash_and_shard(&key);
        let mut shard = unsafe { self.shards.get_unchecked(shard_idx).write().await };

        match shard.raw_entry_mut().from_key_hashed_nocheck(hash, &key) {
            RawEntryMut::Occupied(mut occupied) => Some(occupied.insert(value)),
            RawEntryMut::Vacant(vacant) => {
                self.size.fetch_add(1, Ordering::SeqCst);
                vacant.insert_hashed_nocheck(hash, key, value);
                None
            }
        }
    }

    /// Returns a read guard over the value for `key`, holding the shard's
    /// read lock until the guard is dropped.
    pub async fn get<Q: ?Sized>(&self, key: &Q) -> Option<ReadHandle<impl Erased, T>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        let (hash, shard_idx) = self.hash_and_shard(key);
        let shard = unsafe { self.shards.get_unchecked(shard_idx).clone().read_owned().await };

        OwnedRwLockReadGuard::try_map(shard, |shard| {
            shard
                .raw_entry()
                .from_key_hashed_nocheck(hash, key)
                .map(|(_, value)| value)
        })
        .ok()
    }

    /// Returns a clone of the value for `key`, releasing the shard lock
    /// before returning.
    pub async fn get_cloned<Q: ?Sized>(&self, key: &Q) -> Option<T>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
        T: Clone,
    {
        let (hash, shard_idx) = self.hash_and_shard(key);
        let shard = unsafe { self.shards.get_unchecked(shard_idx).clone().read_owned().await };

        shard
            .raw_entry()
            .from_key_hashed_nocheck(hash, key)
            .map(|(_, value)| value.clone())
    }

    /// Returns a write guard over the value for `key`, holding the shard's
    /// write lock until the guard is dropped.
    pub async fn get_mut<Q: ?Sized>(&self, key: &Q) -> Option<WriteHandle<impl Erased, T>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        let (hash, shard_idx) = self.hash_and_shard(key);
        let shard = unsafe { self.shards.get_unchecked(shard_idx).clone().write_owned().await };

        OwnedRwLockWriteGuard::try_map(shard, |shard| {
            match shard.raw_entry_mut().from_key_hashed_nocheck(hash, key) {
                RawEntryMut::Occupied(occupied) => Some(occupied.into_mut()),
                RawEntryMut::Vacant(_) => None,
            }
        })
        .ok()
    }

    /// Returns a read guard over the value for `key`, inserting the result of
    /// `on_insert` first if the key is missing.
    pub async fn get_or_insert(
        &self,
        key: &K,
        on_insert: impl FnOnce() -> T,
    ) -> ReadHandle<impl Erased, T>
    where
        K: Clone,
    {
        let (hash, shard_idx) = self.hash_and_shard(key);
        let mut shard = unsafe { self.shards.get_unchecked(shard_idx).clone().write_owned().await };

        if let RawEntryMut::Vacant(vacant) = shard.raw_entry_mut().from_key_hashed_nocheck(hash, key) {
            self.size.fetch_add(1, Ordering::SeqCst);

            vacant.insert_hashed_nocheck(hash, key.clone(), on_insert());
        }

        // TODO: Having to do another lookup for a read-reference is wasteful, maybe use an alternate custom ReadHandle?
        OwnedRwLockReadGuard::map(OwnedRwLockWriteGuard::downgrade(shard), |shard| {
            match shard.raw_entry().from_key_hashed_nocheck(hash, key) {
                Some((_, value)) => value,
                // The entry was either found or inserted above, so it must exist.
                None => unreachable!(),
            }
        })
    }

    /// Returns a write guard over the value for `key`, inserting the result
    /// of `on_insert` first if the key is missing.
    pub async fn get_mut_or_insert(
        &self,
        key: &K,
        on_insert: impl FnOnce() -> T,
    ) -> WriteHandle<impl Erased, T>
    where
        K: Clone,
    {
        let (hash, shard_idx) = self.hash_and_shard(key);
        let shard = unsafe { self.shards.get_unchecked(shard_idx).clone().write_owned().await };

        OwnedRwLockWriteGuard::map(shard, |shard| {
            shard
                .raw_entry_mut()
                .from_key_hashed_nocheck(hash, key)
                .or_insert_with(|| {
                    self.size.fetch_add(1, Ordering::SeqCst);

                    (key.clone(), on_insert())
                })
                .1
        })
    }

    /// Like [`CHashMap::get_or_insert`], but inserts `T::default()`.
    pub async fn get_or_default(&self, key: &K) -> ReadHandle<impl Erased, T>
    where
        K: Clone,
        T: Default,
    {
        self.get_or_insert(key, Default::default).await
    }

    /// Like [`CHashMap::get_mut_or_insert`], but inserts `T::default()`.
    pub async fn get_mut_or_default(&self, key: &K) -> WriteHandle<impl Erased, T>
    where
        K: Clone,
        T: Default,
    {
        self.get_mut_or_insert(key, Default::default).await
    }

    /*
    pub async fn shard_mut<Q: ?Sized>(&self, key: &Q) -> WriteLock<K, T, S, Shard<K, T, S>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        let (_, shard_idx) = self.hash_and_shard(key);
        let shard = unsafe { self.shards.get_unchecked(shard_idx).clone().write_owned().await };

        OwnedRwLockWriteGuard::map(shard, |shard| shard)
    }

    pub async fn entry<Q: ?Sized>(&self, key: &Q) -> WriteHandle<impl Erased, Entry<'_, K, T, S>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        let (hash, shard_idx) = self.hash_and_shard(key);
        let shard = unsafe { self.shards.get_unchecked(shard_idx).clone().write_owned().await };

        OwnedRwLockWriteGuard::map(shard, |shard| {
            shard.raw_entry_mut().from_key_hashed_nocheck(hash, key)
        })
    }
    */

    /// Aggregates all the provided keys and batches together access to the underlying shards,
    /// reducing locking overhead at the cost of the memory used to buffer keys/hashes.
    ///
    /// Keys are grouped by shard so that each shard's read lock is acquired at most once;
    /// `f` is invoked once per key with the matching entry, if any. The buffer is taken from
    /// `cache` when provided, otherwise a temporary allocation is used.
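    ///
    /// A minimal usage sketch (assuming a Tokio runtime and `String` keys; not
    /// compiled as a doctest):
    ///
    /// ```ignore
    /// let map: CHashMap<String, usize> = CHashMap::new(8);
    /// map.insert("a".to_owned(), 1).await;
    ///
    /// let mut cache = Vec::new();
    /// map.batch_read(["a", "b"], Some(&mut cache), |key, entry| match entry {
    ///     Some((_, value)) => println!("{key} => {value}"),
    ///     None => println!("{key} is missing"),
    /// })
    /// .await;
    /// ```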
    pub async fn batch_read<'a, Q: 'a + ?Sized, I, F>(
        &self,
        keys: I,
        cache: Option<&mut Vec<(&'a Q, u64, usize)>>,
        mut f: F,
    ) where
        K: Borrow<Q>,
        Q: Hash + Eq,
        I: IntoIterator<Item = &'a Q>,
        F: FnMut(&'a Q, Option<(&K, &T)>),
    {
        // Use the caller-provided buffer if there is one, otherwise a local one.
        let mut own_cache = Vec::new();
        let cache = match cache {
            Some(cache) => {
                cache.clear();
                cache
            }
            None => &mut own_cache,
        };

        cache.extend(keys.into_iter().map(|key| {
            let (hash, shard) = self.hash_and_shard(key);
            (key, hash, shard)
        }));

        if cache.is_empty() {
            return;
        }

        // Group keys by shard so each shard's read lock is acquired only once.
        cache.sort_unstable_by_key(|(_, _, shard)| *shard);

        let mut i = 0;
        'outer: loop {
            let current_shard = cache[i].2;
            let shard = unsafe { self.shards.get_unchecked(current_shard).read().await };

            while cache[i].2 == current_shard {
                f(
                    cache[i].0,
                    shard.raw_entry().from_key_hashed_nocheck(cache[i].1, cache[i].0),
                );
                i += 1;

                if i >= cache.len() {
                    break 'outer;
                }
            }
        }

        cache.clear();
    }

    /// Aggregates all the provided keys and batches together access to the underlying shards,
    /// reducing locking overhead at the cost of the memory used to buffer keys/hashes.
    ///
    /// Keys are grouped by shard so that each shard's write lock is acquired at most once;
    /// `f` receives the raw entry for each key. Note that insertions or removals performed
    /// through the raw entry are not reflected in [`CHashMap::size`]. The buffer is taken
    /// from `cache` when provided, otherwise a temporary allocation is used.
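    ///
    /// A minimal usage sketch (assuming a Tokio runtime and `String` keys; not
    /// compiled as a doctest):
    ///
    /// ```ignore
    /// let map: CHashMap<String, usize> = CHashMap::new(8);
    ///
    /// // Upsert a counter for every key, taking each shard's write lock once.
    /// map.batch_write(["a", "b", "a"], None, |key, entry| {
    ///     *entry.or_insert_with(|| (key.to_owned(), 0)).1 += 1;
    /// })
    /// .await;
    /// ```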
    pub async fn batch_write<'a, Q: 'a + ?Sized, I, F>(
        &self,
        keys: I,
        cache: Option<&mut Vec<(&'a Q, u64, usize)>>,
        mut f: F,
    ) where
        K: Borrow<Q>,
        Q: Hash + Eq,
        I: IntoIterator<Item = &'a Q>,
        F: FnMut(&'a Q, hashbrown::hash_map::RawEntryMut<K, T, S>),
    {
        // Use the caller-provided buffer if there is one, otherwise a local one.
        let mut own_cache = Vec::new();
        let cache = match cache {
            Some(cache) => {
                cache.clear();
                cache
            }
            None => &mut own_cache,
        };

        cache.extend(keys.into_iter().map(|key| {
            let (hash, shard) = self.hash_and_shard(key);
            (key, hash, shard)
        }));

        if cache.is_empty() {
            return;
        }

        // Group keys by shard so each shard's write lock is acquired only once.
        cache.sort_unstable_by_key(|(_, _, shard)| *shard);

        let mut i = 0;
        'outer: loop {
            let current_shard = cache[i].2;
            let mut shard = unsafe { self.shards.get_unchecked(current_shard).write().await };

            while cache[i].2 == current_shard {
                f(
                    cache[i].0,
                    shard
                        .raw_entry_mut()
                        .from_key_hashed_nocheck(cache[i].1, cache[i].0),
                );
                i += 1;

                if i >= cache.len() {
                    break 'outer;
                }
            }
        }

        cache.clear();
    }
}
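
// A minimal smoke test sketching typical usage. This is an illustrative
// sketch: it assumes `tokio` is available with the `macros` and `rt`
// features, which is not verified against this crate's Cargo.toml.
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn insert_get_remove() {
        let map: CHashMap<String, usize> = CHashMap::new(4);

        assert_eq!(map.insert("a".to_owned(), 1).await, None);
        assert_eq!(map.insert("a".to_owned(), 2).await, Some(1));
        assert_eq!(map.size(), 1);

        assert_eq!(map.get_cloned("a").await, Some(2));
        assert!(map.contains("a").await);

        // `get_or_default` inserts `T::default()` for missing keys and hands
        // back a read guard over the stored value.
        {
            let value = map.get_or_default(&"b".to_owned()).await;
            assert_eq!(*value, 0);
        }
        assert_eq!(map.size(), 2);

        // `get_mut` hands back a write guard; mutations through it are
        // visible to later reads once the guard is dropped.
        if let Some(mut value) = map.get_mut("b").await {
            *value += 10;
        }
        assert_eq!(map.get_cloned("b").await, Some(10));

        assert_eq!(map.remove("a").await, Some(2));
        assert_eq!(map.size(), 1);
    }
}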