1use std::{fmt::Debug, hash::Hash};
2
3use ahash::RandomState;
4use hashbrown::HashMap;
5use parking_lot::RwLock;
6
7pub struct ConcurrentMap<K, V> {
8 shards: Box<[RwLock<HashMap<K, V, RandomState>>]>,
9 num_shards: usize,
10 hasher: RandomState,
11}
12
13impl<K: Eq + Hash, V: Clone> Default for ConcurrentMap<K, V> {
14 fn default() -> Self {
15 Self::new()
16 }
17}
18
19impl<K: Debug + Clone + Eq + Hash, V: Debug + Clone> Debug for ConcurrentMap<K, V> {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 let debug_map = self
22 .shards
23 .iter()
24 .flat_map(|shard| shard.read().clone())
25 .map(|(k, v)| (k, v))
26 .collect::<HashMap<_, _>>();
27
28 f.debug_map().entries(debug_map.iter()).finish()
29 }
30}
31
32impl<K: Eq + Hash, V: Clone> ConcurrentMap<K, V> {
49 pub fn new() -> Self {
50 let num_shards =
51 (std::thread::available_parallelism().map_or(1, usize::from) * 4).next_power_of_two();
52
53 Self {
54 shards: (0..num_shards)
55 .map(|_| RwLock::new(HashMap::default()))
56 .collect::<Box<_>>(),
57 num_shards,
58 hasher: RandomState::default(),
59 }
60 }
61
62 fn hash(&self, key: &K) -> usize {
63 self.hasher.hash_one(key) as usize
64 }
65
66 fn determine_shard(&self, hash: usize) -> usize {
67 hash % self.num_shards
68 }
69
70 unsafe fn get_read_shard(
71 &self,
72 idx: usize,
73 ) -> parking_lot::lock_api::RwLockReadGuard<parking_lot::RawRwLock, HashMap<K, V, RandomState>>
74 {
75 self.shards.get_unchecked(idx).read()
76 }
77
78 unsafe fn get_write_shard(
79 &self,
80 idx: usize,
81 ) -> parking_lot::lock_api::RwLockWriteGuard<parking_lot::RawRwLock, HashMap<K, V, RandomState>>
82 {
83 self.shards.get_unchecked(idx).write()
84 }
85
86 pub fn get(&self, key: &K) -> Option<V> {
96 let hash = self.hash(key);
97 let idx = self.determine_shard(hash);
98
99 let shard = unsafe { self.get_read_shard(idx) };
100
101 shard.get(key).cloned()
102 }
103
104 pub fn get_or_insert<F: FnOnce() -> V>(&self, key: K, value: F) -> V {
105 let hash = self.hash(&key);
106 let idx = self.determine_shard(hash);
107
108 let result = {
110 let shard = unsafe { self.get_read_shard(idx) };
111 shard.get(&key).cloned()
112 };
113
114 if let Some(result) = result {
116 return result;
117 }
118
119 let mut shard = unsafe { self.get_write_shard(idx) };
121 let result = shard.get(&key);
122
123 if let Some(result) = result {
125 return result.clone();
126 }
127
128 let result = value();
130 shard.insert(key, result.clone());
131 result
132 }
133}