1use std::borrow::Borrow;
2use std::hash::{BuildHasher, Hash, Hasher};
3use std::sync::{
4 atomic::{AtomicUsize, Ordering},
5 Arc,
6};
7
8pub use hashbrown::hash_map::DefaultHashBuilder;
9use hashbrown::hash_map::{HashMap, RawEntryMut};
10
11use tokio::sync::{OwnedRwLockMappedWriteGuard, OwnedRwLockReadGuard, OwnedRwLockWriteGuard, RwLock};
12
13pub mod lru;
14
/// A concurrent, sharded hash map.
///
/// Each key is hashed once with `hash_builder`; the hash selects a shard
/// (`hash as usize % shards.len()`), and each shard is an independent
/// `hashbrown::HashMap` behind its own `tokio` `RwLock`, so operations on
/// different shards do not contend.
#[derive(Debug)]
pub struct CHashMap<K, T, S = DefaultHashBuilder> {
    /// Hasher shared by all shards; also used to compute the shard index.
    hash_builder: S,
    /// Fixed set of shards. `Arc` allows owned guards
    /// (`read_owned`/`write_owned`) to be returned from methods.
    shards: Vec<Arc<RwLock<HashMap<K, T, S>>>>,
    /// Approximate total entry count, maintained with `SeqCst` atomics.
    /// May be momentarily stale under concurrent modification.
    size: AtomicUsize,
}
21
22impl<K, T> CHashMap<K, T, DefaultHashBuilder> {
23 pub fn new(num_shards: usize) -> Self {
24 Self::with_hasher(num_shards, DefaultHashBuilder::default())
25 }
26}
27
28impl<K, T> Default for CHashMap<K, T, DefaultHashBuilder> {
29 fn default() -> Self {
30 Self::new(num_cpus::get())
31 }
32}
33
/// Marker trait used to erase the concrete guard-target type in the
/// `ReadHandle<impl Erased, T>` / `WriteHandle<impl Erased, T>` return
/// types below. Blanket-implemented for every type, so it carries no
/// behavior of its own.
#[doc(hidden)]
pub trait Erased {}
impl<T> Erased for T {}
37
/// Owned read guard mapped down to a single value inside a shard.
pub type ReadHandle<T, U> = OwnedRwLockReadGuard<T, U>;
/// Owned write guard mapped down to a single value inside a shard.
pub type WriteHandle<T, U> = OwnedRwLockMappedWriteGuard<T, U>;

/// A single shard: a plain `hashbrown` map; guarded by a `RwLock`
/// inside [`CHashMap`].
pub type Shard<K, T, S> = HashMap<K, T, S>;
42
43impl<K, T, S> CHashMap<K, T, S>
44where
45 S: Clone,
46{
47 pub fn with_hasher(num_shards: usize, hash_builder: S) -> Self {
48 CHashMap {
49 shards: (0..num_shards)
50 .into_iter()
51 .map(|_| Arc::new(RwLock::new(HashMap::with_hasher(hash_builder.clone()))))
52 .collect(),
53 hash_builder,
54 size: AtomicUsize::new(0),
55 }
56 }
57}
58
59impl<K, T, S> CHashMap<K, T, S>
60where
61 K: Clone,
62 T: Clone,
63 S: Clone,
64{
65 pub async fn duplicate(&self) -> Self {
67 let mut shards = Vec::with_capacity(self.shards.len());
68 let mut size = 0;
69
70 for shard in &self.shards {
71 let shard = shard.read().await.clone();
72 size += shard.len();
73 shards.push(Arc::new(RwLock::new(shard)));
74 }
75
76 CHashMap {
77 shards,
78 hash_builder: self.hash_builder.clone(),
79 size: AtomicUsize::new(size),
80 }
81 }
82}
83
impl<K, T, S> CHashMap<K, T, S>
where
    K: Hash + Eq,
    S: BuildHasher,
{
    /// Returns a reference to the hash builder shared by all shards.
    pub fn hash_builder(&self) -> &S {
        &self.hash_builder
    }

    /// Hashes `key` and derives the shard index from the hash.
    ///
    /// Returns `(hash, shard_index)`. The hash is reused by the
    /// `*_hashed_nocheck` raw-entry calls below so each key is hashed
    /// exactly once per operation.
    #[inline]
    fn hash_and_shard<Q: ?Sized>(&self, key: &Q) -> (u64, usize)
    where
        Q: Hash + Eq,
    {
        let mut hasher = self.hash_builder.build_hasher();
        key.hash(&mut hasher);
        let hash = hasher.finish();
        // Panics (remainder by zero) if the map was built with zero shards.
        (hash, hash as usize % self.shards.len())
    }

    /// Removes every entry, locking one shard at a time.
    ///
    /// Not atomic across shards: concurrent writers may repopulate
    /// already-cleared shards before this call returns.
    pub async fn clear(&self) {
        for shard in &self.shards {
            let mut shard = shard.write().await;

            let len = shard.len();
            shard.clear();

            // Decrement the global counter by what this shard actually held.
            self.size.fetch_sub(len, Ordering::SeqCst);
        }
    }

    /// Keeps only the entries for which `f` returns `true`, locking one
    /// shard at a time (same non-atomicity caveat as [`Self::clear`]).
    pub async fn retain<F>(&self, f: F)
    where
        F: Fn(&K, &mut T) -> bool,
    {
        for shard in &self.shards {
            let mut shard = shard.write().await;

            let len = shard.len();
            shard.retain(&f);

            // `retain` only removes entries, so `len >= shard.len()` and
            // the subtraction cannot underflow.
            self.size.fetch_sub(len - shard.len(), Ordering::SeqCst);
        }
    }

    /// Iterates over the shard locks themselves; no locking is performed.
    pub fn iter_shards<'a>(&'a self) -> impl Iterator<Item = &'a RwLock<Shard<K, T, S>>> {
        self.shards.iter().map(|s| &**s)
    }

    /// Approximate number of entries across all shards (atomic read; may
    /// be stale under concurrent modification).
    pub fn size(&self) -> usize {
        self.size.load(Ordering::SeqCst)
    }

    /// Number of shards the map was created with.
    pub fn num_shards(&self) -> usize {
        self.shards.len()
    }

    /// Non-blocking probabilistic check for an entry with this hash.
    ///
    /// Returns `true` when *any* entry with `hash` exists (the `|_| true`
    /// predicate accepts hash collisions — keys are never compared).
    /// Returns `false` either when no such entry exists or when the shard
    /// lock could not be acquired immediately, so a `false` here is not
    /// authoritative. Intended only as a cheap pre-filter.
    pub fn try_maybe_contains_hash(&self, hash: u64) -> bool {
        let shard_idx = hash as usize % self.shards.len();
        // SAFETY: `shard_idx` is `hash % shards.len()`, so it is in bounds.
        let shard = unsafe { self.shards.get_unchecked(shard_idx) };

        if let Ok(shard) = shard.try_read() {
            shard.raw_entry().from_hash(hash, |_| true).is_some()
        } else {
            false
        }
    }

    /// Returns whether any entry with exactly this hash exists, waiting on
    /// the shard read lock. Hash collisions count as a match.
    pub async fn contains_hash(&self, hash: u64) -> bool {
        let shard_idx = hash as usize % self.shards.len();
        // SAFETY: `shard_idx` is `hash % shards.len()`, so it is in bounds.
        let shard = unsafe { self.shards.get_unchecked(shard_idx) };

        shard.read().await.raw_entry().from_hash(hash, |_| true).is_some()
    }

    /// Returns whether an entry whose hash matches `key`'s hash exists.
    ///
    /// NOTE(review): this delegates to [`Self::contains_hash`], which never
    /// compares keys — a different key with a colliding hash also reports
    /// `true`. Confirm callers tolerate false positives.
    pub async fn contains<Q: ?Sized>(&self, key: &Q) -> bool
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        self.contains_hash(self.hash_and_shard(key).0).await
    }

    /// Removes and returns the value for `key`, if present, decrementing
    /// the size counter.
    pub async fn remove<Q: ?Sized>(&self, key: &Q) -> Option<T>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        // `&key` here is `&&Q`; `Hash for &T` delegates to the referent,
        // so the hash matches the one used at insert time.
        let (hash, shard_idx) = self.hash_and_shard(&key);
        // SAFETY: `shard_idx` is `hash % shards.len()`, so it is in bounds.
        let mut shard = unsafe { self.shards.get_unchecked(shard_idx).write().await };

        match shard.raw_entry_mut().from_key_hashed_nocheck(hash, key) {
            RawEntryMut::Occupied(occupied) => {
                let value = occupied.remove();
                self.size.fetch_sub(1, Ordering::SeqCst);
                Some(value)
            }
            RawEntryMut::Vacant(_) => None,
        }
    }

    /// Inserts `value` under `key`, returning the previous value if the
    /// key was already present. The size counter grows only on a fresh
    /// insertion, not on replacement.
    pub async fn insert(&self, key: K, value: T) -> Option<T> {
        let (hash, shard_idx) = self.hash_and_shard(&key);
        // SAFETY: `shard_idx` is `hash % shards.len()`, so it is in bounds.
        let mut shard = unsafe { self.shards.get_unchecked(shard_idx).write().await };

        match shard.raw_entry_mut().from_key_hashed_nocheck(hash, &key) {
            RawEntryMut::Occupied(mut occupied) => Some(occupied.insert(value)),
            RawEntryMut::Vacant(vacant) => {
                self.size.fetch_add(1, Ordering::SeqCst);
                vacant.insert_hashed_nocheck(hash, key, value);
                None
            }
        }
    }

    /// Returns an owned read handle to the value for `key`, or `None` if
    /// absent. The handle keeps the whole shard read-locked until dropped.
    pub async fn get<Q: ?Sized>(&self, key: &Q) -> Option<ReadHandle<impl Erased, T>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        let (hash, shard_idx) = self.hash_and_shard(key);
        // SAFETY: `shard_idx` is `hash % shards.len()`, so it is in bounds.
        // The `Arc` clone is required for the owned guard.
        let shard = unsafe { self.shards.get_unchecked(shard_idx).clone().read_owned().await };

        OwnedRwLockReadGuard::try_map(shard, |shard| {
            match shard.raw_entry().from_key_hashed_nocheck(hash, key) {
                Some((_, value)) => Some(value),
                None => None,
            }
        })
        .ok()
    }

    /// Returns a clone of the value for `key`, releasing the shard lock
    /// before returning — preferable to [`Self::get`] when `T` is cheap to
    /// clone and the caller would otherwise hold the handle for long.
    pub async fn get_cloned<Q: ?Sized>(&self, key: &Q) -> Option<T>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
        T: Clone,
    {
        let (hash, shard_idx) = self.hash_and_shard(key);
        // SAFETY: `shard_idx` is `hash % shards.len()`, so it is in bounds.
        let shard = unsafe { self.shards.get_unchecked(shard_idx).clone().read_owned().await };

        match shard.raw_entry().from_key_hashed_nocheck(hash, key) {
            Some((_, value)) => Some(value.clone()),
            None => None,
        }
    }

    /// Returns an owned write handle to the value for `key`, or `None` if
    /// absent. The handle keeps the whole shard write-locked until dropped.
    pub async fn get_mut<Q: ?Sized>(&self, key: &Q) -> Option<WriteHandle<impl Erased, T>>
    where
        K: Borrow<Q>,
        Q: Hash + Eq,
    {
        let (hash, shard_idx) = self.hash_and_shard(key);
        // SAFETY: `shard_idx` is `hash % shards.len()`, so it is in bounds.
        let shard = unsafe { self.shards.get_unchecked(shard_idx).clone().write_owned().await };

        OwnedRwLockWriteGuard::try_map(shard, |shard| {
            match shard.raw_entry_mut().from_key_hashed_nocheck(hash, key) {
                RawEntryMut::Occupied(occupied) => Some(occupied.into_mut()),
                RawEntryMut::Vacant(_) => None,
            }
        })
        .ok()
    }

    /// Returns a read handle to `key`'s value, first inserting
    /// `on_insert()` if the key is absent.
    ///
    /// Takes the shard write lock (so `on_insert` runs while it is held),
    /// then atomically downgrades to a read lock for the returned handle.
    pub async fn get_or_insert(&self, key: &K, on_insert: impl FnOnce() -> T) -> ReadHandle<impl Erased, T>
    where
        K: Clone,
    {
        let (hash, shard_idx) = self.hash_and_shard(key);
        // SAFETY: `shard_idx` is `hash % shards.len()`, so it is in bounds.
        let mut shard = unsafe { self.shards.get_unchecked(shard_idx).clone().write_owned().await };

        if let RawEntryMut::Vacant(vacant) = shard.raw_entry_mut().from_key_hashed_nocheck(hash, key) {
            self.size.fetch_add(1, Ordering::SeqCst);

            vacant.insert_hashed_nocheck(hash, key.clone(), on_insert());
        }

        OwnedRwLockReadGuard::map(OwnedRwLockWriteGuard::downgrade(shard), |shard| {
            match shard.raw_entry().from_key_hashed_nocheck(hash, key) {
                Some((_, value)) => value,
                // The entry was just inserted (or already existed) and the
                // write lock was held continuously through the downgrade.
                None => unreachable!(),
            }
        })
    }

    /// Returns a write handle to `key`'s value, first inserting
    /// `on_insert()` if the key is absent. The shard stays write-locked
    /// until the handle is dropped.
    pub async fn get_mut_or_insert(
        &self,
        key: &K,
        on_insert: impl FnOnce() -> T,
    ) -> WriteHandle<impl Erased, T>
    where
        K: Clone,
    {
        let (hash, shard_idx) = self.hash_and_shard(key);
        // SAFETY: `shard_idx` is `hash % shards.len()`, so it is in bounds.
        let shard = unsafe { self.shards.get_unchecked(shard_idx).clone().write_owned().await };

        OwnedRwLockWriteGuard::map(shard, |shard| {
            shard
                .raw_entry_mut()
                .from_key_hashed_nocheck(hash, key)
                .or_insert_with(|| {
                    self.size.fetch_add(1, Ordering::SeqCst);

                    (key.clone(), on_insert())
                })
                .1
        })
    }

    /// [`Self::get_or_insert`] with `T::default()` as the fallback.
    pub async fn get_or_default(&self, key: &K) -> ReadHandle<impl Erased, T>
    where
        K: Clone,
        T: Default,
    {
        self.get_or_insert(key, Default::default).await
    }

    /// [`Self::get_mut_or_insert`] with `T::default()` as the fallback.
    pub async fn get_mut_or_default(&self, key: &K) -> WriteHandle<impl Erased, T>
    where
        K: Clone,
        T: Default,
    {
        self.get_mut_or_insert(key, Default::default).await
    }

    /// Calls `f(key, entry)` for every key in `keys`, where `entry` is the
    /// matching `(key, value)` pair or `None`.
    ///
    /// Keys are hashed up front, sorted by shard index, and visited shard
    /// by shard so each shard's read lock is acquired at most once.
    /// `cache` may supply a reusable scratch buffer (cleared before and
    /// after use); pass `None` to use a temporary allocation.
    pub async fn batch_read<'a, Q: 'a + ?Sized, I, F>(
        &self,
        keys: I,
        cache: Option<&mut Vec<(&'a Q, u64, usize)>>,
        mut f: F,
    ) where
        K: Borrow<Q>,
        Q: Hash + Eq,
        I: IntoIterator<Item = &'a Q>,
        F: FnMut(&'a Q, Option<(&K, &T)>),
    {
        let mut own_cache = Vec::new();
        let cache = match cache {
            Some(cache) => {
                cache.clear();
                cache
            }
            None => &mut own_cache,
        };

        // Hash every key once and remember its target shard.
        cache.extend(keys.into_iter().map(|key| {
            let (hash, shard) = self.hash_and_shard(key);
            (key, hash, shard)
        }));

        if cache.is_empty() {
            return;
        }

        // Group keys by shard so each lock is taken once.
        cache.sort_unstable_by_key(|(_, _, shard)| *shard);

        let mut i = 0;
        'outer: loop {
            let current_shard = cache[i].2;
            // SAFETY: `current_shard` came from `hash_and_shard`, which
            // reduces modulo `shards.len()`, so it is in bounds.
            let shard = unsafe { self.shards.get_unchecked(current_shard).read().await };

            // Drain the sorted run of keys belonging to this shard.
            while cache[i].2 == current_shard {
                f(
                    cache[i].0,
                    shard.raw_entry().from_key_hashed_nocheck(cache[i].1, cache[i].0),
                );
                i += 1;

                if i >= cache.len() {
                    break 'outer;
                }
            }
        }

        cache.clear();
    }

    /// Calls `f(key, raw_entry)` for every key in `keys`, holding each
    /// shard's write lock once. Same batching strategy and `cache`
    /// semantics as [`Self::batch_read`].
    ///
    /// NOTE(review): insertions or removals performed through the
    /// `RawEntryMut` given to `f` do NOT update `self.size` — confirm
    /// callers account for that, or `size()` will drift.
    pub async fn batch_write<'a, Q: 'a + ?Sized, I, F>(
        &self,
        keys: I,
        cache: Option<&mut Vec<(&'a Q, u64, usize)>>,
        mut f: F,
    ) where
        K: Borrow<Q>,
        Q: Hash + Eq,
        I: IntoIterator<Item = &'a Q>,
        F: FnMut(&'a Q, hashbrown::hash_map::RawEntryMut<K, T, S>),
    {
        let mut own_cache = Vec::new();
        let cache = match cache {
            Some(cache) => {
                cache.clear();
                cache
            }
            None => &mut own_cache,
        };

        // Hash every key once and remember its target shard.
        cache.extend(keys.into_iter().map(|key| {
            let (hash, shard) = self.hash_and_shard(key);
            (key, hash, shard)
        }));

        if cache.is_empty() {
            return;
        }

        // Group keys by shard so each lock is taken once.
        cache.sort_unstable_by_key(|(_, _, shard)| *shard);

        let mut i = 0;
        'outer: loop {
            let current_shard = cache[i].2;
            // SAFETY: `current_shard` came from `hash_and_shard`, which
            // reduces modulo `shards.len()`, so it is in bounds.
            let mut shard = unsafe { self.shards.get_unchecked(current_shard).write().await };

            // Drain the sorted run of keys belonging to this shard.
            while cache[i].2 == current_shard {
                f(
                    cache[i].0,
                    shard
                        .raw_entry_mut()
                        .from_key_hashed_nocheck(cache[i].1, cache[i].0),
                );
                i += 1;

                if i >= cache.len() {
                    break 'outer;
                }
            }
        }

        cache.clear();
    }
}