scherben_map/
lib.rs

1use fnv::FnvHasher;
2use hashbrown::raw::RawTable;
3use std::{
4    borrow::Borrow,
5    fmt,
6    hash::{Hash, Hasher},
7    sync::Arc,
8};
9
10type RwLock<T> = parking_lot_utils::RwLock<T>;
11
/// A hashable, comparable wrapper that exposes the key of type `K` it
/// stands for.
pub trait Key<K>: Eq + Hash {
    /// Returns a reference to the underlying key.
    fn key(&self) -> &K;
}
16
/// Contract every map key has to fulfil: it must expose a byte view of
/// itself. The map passes these bytes to `make_hash` to pick a shard and a
/// table slot.
pub trait IKey<K> {
    /// Returns the bytes that represent this key.
    fn as_bytes(&self) -> &[u8];
}
21
22/// RwKey is read/write key.
23struct RwKey<'a, K: IKey<K>> {
24    key: &'a K,
25}
26
27/// IntoIter creates an iterator for shards.
28pub struct IntoIter<K, V> {
29    shards: std::vec::IntoIter<Arc<RwLock<Shard<K, V>>>>,
30    item: Option<Arc<RwLock<Shard<K, V>>>>,
31}
32
33/// HashMap is a sharded hashmap which uses `N` buckets,
34/// each protected with an Read/Write lock. `N` gets
35/// rounded to nearest power of two.
36pub struct HashMap<K, V, const N: usize> {
37    shards: [Arc<RwLock<Shard<K, V>>>; N],
38    shards_size: u64,
39}
40
41/// Shard embeds a table that contains key/value pair.
42pub struct Shard<K, V> {
43    pub table: RawTable<(K, V)>,
44}
45
46impl<K: Clone + Send + Sync, V: Clone + Send + Sync, const N: usize> Default for HashMap<K, V, N> {
47    fn default() -> Self {
48        Self::with_shard(N.next_power_of_two())
49    }
50}
51
52impl<K: Clone + Send + Sync, V: Clone + Send + Sync> Shard<K, V> {
53    /// get fetches a option containing a reference to the
54    /// value associated with the given key.
55    pub fn get<'a>(&'a self, key: &K, hash: u64) -> Option<&'a V>
56    where
57        K: Hash + Eq + IKey<K>,
58    {
59        match self.table.get(hash, move |x| key.eq(x.0.borrow())) {
60            Some(&(_, ref value)) => Some(value),
61            None => None,
62        }
63    }
64
65    /// insert inserts the given key/value pair into the table.
66    pub fn insert(&mut self, key: &K, hash: u64, value: V)
67    where
68        K: Hash + Eq + IKey<K> + Clone,
69        V: Clone,
70    {
71        if let Some((_, item)) = self.table.get_mut(hash, move |x| key.eq(x.0.borrow())) {
72            _ = std::mem::replace(item, value);
73        } else {
74            self.table.insert(hash, (key.clone(), value), |x| {
75                make_hash(x.0.borrow().as_bytes())
76            });
77        }
78    }
79
80    /// remove remove the entry associated with `key` and `hash`
81    /// from the table.
82    pub fn remove(&mut self, key: &K, hash: u64) -> Option<V>
83    where
84        K: Hash + Eq + IKey<K> + Clone,
85        V: Clone,
86    {
87        match self.table.remove_entry(hash, move |x| key.eq(x.0.borrow())) {
88            Some((_, value)) => Some(value),
89            None => None,
90        }
91    }
92
93    /// fil_pairs_into fills the `buffer` with cloned entries
94    /// of key/value pairs.
95    pub fn fill_pairs_into(&self, buffer: &mut Vec<(K, V)>) {
96        unsafe {
97            for entry in self.table.iter() {
98                let value = entry.as_ref().clone();
99                buffer.push((value.0.clone(), value.1.clone()));
100            }
101        }
102    }
103
104    /// len returns the length of the table.
105    #[inline(always)]
106    pub fn len(&self) -> usize {
107        self.table.len()
108    }
109}
110
111impl<'a, K: Clone + Send + Sync, V: Clone + Send + Sync, const N: usize> HashMap<K, V, N> {
112    /// new returns an instance with `N` shards.
113    /// `N` gets rounded to nearest power of two.
114    pub fn new() -> Self {
115        Self::with_shard(N.next_power_of_two())
116    }
117
118    /// new_arc returns a new instance contained in
119    /// an arc pointer.
120    pub fn new_arc() -> std::sync::Arc<Self> {
121        std::sync::Arc::new(Self::new())
122    }
123
124    /// with_shard returns a new instance with `N` shards.
125    /// `N` gets rounded to nearest power of two.
126    pub fn with_shard(shards_size: usize) -> Self {
127        let shards = std::iter::repeat(|| RawTable::with_capacity(shards_size))
128            .map(|f| f())
129            .take(shards_size)
130            .map(|table| Arc::new(RwLock::new(Shard { table })))
131            .collect::<Vec<_>>()
132            .try_into()
133            .unwrap();
134
135        Self {
136            shards,
137            shards_size: shards_size as u64,
138        }
139    }
140
141    /// get_shard returns the shard lock along with
142    /// its hash.
143    pub fn get_shard(&'a self, key: K) -> (&'a Arc<RwLock<Shard<K, V>>>, u64)
144    where
145        K: Hash + Eq + IKey<K>,
146        V: Clone,
147    {
148        let hash: u64 = make_hash(key.as_bytes());
149        let bin = hash % self.shards_size;
150        match self.shards.get(bin as usize) {
151            Some(lock) => (lock, hash),
152            None => panic!("index out of bounds"),
153        }
154    }
155
156    /// get_owned returns the cloned value associated
157    /// with the given `key`.
158    pub fn get_owned(&'a self, key: K) -> Option<V>
159    where
160        K: Hash + Eq + IKey<K>,
161        V: Clone,
162    {
163        let hash: u64 = make_hash(key.as_bytes());
164        let bin = hash % self.shards_size;
165        let shard = match self.shards.get(bin as usize) {
166            Some(lock) => lock.read(),
167            None => panic!("index out of bounds"),
168        };
169        match shard.get(&key, hash) {
170            Some(result) => Some(result.clone()),
171            None => None,
172        }
173    }
174
175    /// insert inserts the given `key`, `value`
176    /// pair into a shard based on the hash value
177    /// of `key`.
178    pub fn insert(&self, key: K, value: V)
179    where
180        K: Hash + Eq + IKey<K>,
181    {
182        let hash: u64 = make_hash(key.as_bytes());
183        let bin = hash % self.shards_size;
184        let mut shard = match self.shards.get(bin as usize) {
185            Some(lock) => lock.write(),
186            None => panic!("index out of bounds"),
187        };
188        shard.insert(&key, hash, value);
189    }
190
191    /// remove removes the entry associated with `key`
192    /// from the corresponding shard.
193    pub fn remove(&self, key: K) -> Option<V>
194    where
195        K: Hash + Eq + IKey<K>,
196    {
197        let hash: u64 = make_hash(key.as_bytes());
198        let bin = hash % self.shards_size;
199        let mut shard = match self.shards.get(bin as usize) {
200            Some(lock) => lock.write(),
201            None => panic!("index out of bounds"),
202        };
203
204        shard.remove(&key, hash)
205    }
206
207    /// contains returns whether any entry is
208    /// associated with the given `key`.
209    pub fn contains(&self, key: K) -> bool
210    where
211        K: Hash + Eq + IKey<K>,
212    {
213        let hash: u64 = make_hash(key.as_bytes());
214        let bin = hash % self.shards_size;
215        let shard = match self.shards.get(bin as usize) {
216            Some(lock) => lock.read(),
217            None => panic!("index out of bounds"),
218        };
219        match shard.get(&key, hash) {
220            Some(_) => true,
221            None => false,
222        }
223    }
224
225    /// len returns the sum of total number of all entries
226    /// stored in each shard table.
227    pub fn len(&self) -> usize {
228        self.shards.iter().map(|x| x.read().len()).sum()
229    }
230
231    /// into_iter creates an iterator for shards.
232    pub fn into_iter(&self) -> IntoIter<K, V> {
233        let mut shards: Vec<Arc<RwLock<Shard<K, V>>>> =
234            Vec::with_capacity(self.shards_size as usize);
235        for i in 0..self.shards.len() {
236            shards.push(self.shards[i].clone());
237        }
238
239        IntoIter {
240            shards: shards.into_iter(),
241            item: None,
242        }
243    }
244}
245
246impl<K, V> Iterator for IntoIter<K, V> {
247    type Item = Arc<RwLock<Shard<K, V>>>;
248    fn size_hint(&self) -> (usize, Option<usize>) {
249        (self.shards.size_hint().0, None)
250    }
251
252    fn next(&mut self) -> Option<Arc<RwLock<Shard<K, V>>>> {
253        match self.shards.next() {
254            Some(ref result) => {
255                self.item = Some(result.clone());
256                return self.item.clone();
257            }
258            None => {
259                self.item = None;
260                return None;
261            }
262        }
263    }
264}
265
266impl<K, V> fmt::Debug for Shard<K, V> {
267    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> fmt::Result {
268        write!(fmt, "Shard{{table: [{}]}}", self.table.len())
269    }
270}
271
272impl<'a, K: IKey<K> + std::cmp::PartialEq> PartialEq for RwKey<'a, K> {
273    fn eq(&self, other: &Self) -> bool {
274        *self.key == *other.key
275    }
276}
277
278impl<'a, K: IKey<K> + std::cmp::PartialEq> Eq for RwKey<'a, K> {}
279
280impl<K> IKey<K> for String {
281    fn as_bytes(&self) -> &[u8] {
282        self.as_bytes()
283    }
284}
285
286impl<K> IKey<K> for &str {
287    fn as_bytes(&self) -> &[u8] {
288        (*self).as_bytes()
289    }
290}
291
292impl<'a, K: IKey<K>> Hash for RwKey<'a, K> {
293    fn hash<H: Hasher>(&self, state: &mut H) {
294        state.write_u64(make_hash(self.key.as_bytes()));
295        state.finish();
296    }
297}
298
299impl<'a, K: IKey<K> + Eq> Key<K> for RwKey<'a, K> {
300    fn key(&self) -> &K {
301        &self.key
302    }
303}
304
305/// make_hash hashes the `key`.
306fn make_hash(key: &[u8]) -> u64 {
307    let mut hasher: Box<dyn Hasher> = Box::new(FnvHasher::default());
308    hasher.write(key);
309    hasher.finish()
310}
311
#[cfg(test)]
mod test {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Value type used by the threading test: a string shared between threads.
    type SharedString = Arc<Mutex<String>>;

    /// One thread mutates the stored value while another reads it; the
    /// mutation must be visible through the map after both have finished.
    #[test]
    fn two_threads_performing_read_write() {
        let map = HashMap::<String, SharedString, 16>::new_arc();
        map.insert(
            "test".to_string(),
            Arc::new(Mutex::new("result".to_string())),
        );

        let map_a = Arc::clone(&map);
        let handle_a = std::thread::spawn(move || match map_a.get_owned("test".to_string()) {
            Some(result) => {
                let mut value = result.lock().unwrap();
                value.push_str(" + mutation");
                println!("result: {}", &value);
            }
            None => println!("found nothing ( from thread A )"),
        });

        let map_b = Arc::clone(&map);
        let handle_b = std::thread::spawn(move || match map_b.get_owned("test".to_string()) {
            Some(result) => {
                let value = result.lock().unwrap();
                println!("result: {}", &value);
            }
            None => println!("found nothing ( from thread B )"),
        });

        handle_a.join().unwrap();
        handle_b.join().unwrap();

        assert_eq!(
            *map.get_owned("test".to_string()).unwrap().lock().unwrap(),
            "result + mutation".to_string()
        );
    }

    /// An inserted entry can be read back and is gone after removal.
    #[test]
    fn write_and_remove() {
        let map: HashMap<&str, i64, 8> = Default::default();
        map.insert("test", 6);
        assert!(map.contains("test"));
        assert_eq!(map.get_owned("test"), Some(6));

        assert_eq!(map.remove("test"), Some(6));
        assert!(!map.contains("test"));
        assert_eq!(map.get_owned("test"), None);
        assert_eq!(map.len(), 0);
    }

    /// Iterating the shards repeatedly always visits all eight of them and
    /// together they hold every inserted entry.
    #[test]
    fn map_iterate_shards() {
        let map: HashMap<String, i64, 8> = Default::default();
        for i in 0..1000 {
            let item = format!("Hallo, Welt {}!", i);
            map.insert(item, i);
        }
        assert_eq!(map.len(), 1000);

        for _ in 0..4 {
            let mut shard_count = 0;
            let mut entries = 0;
            for shard in map.into_iter() {
                shard_count += 1;
                entries += shard.read().len();
            }
            assert_eq!(shard_count, 8);
            assert_eq!(entries, 1000);
        }
    }

    /// Inserting under an existing key overwrites the value instead of adding
    /// a second entry.
    #[test]
    fn map_replace_item() {
        let map: HashMap<String, i64, 8> = Default::default();
        map.insert("Test".to_string(), 0);
        assert_eq!(map.get_owned("Test".to_string()), Some(0));

        map.insert("Test".to_string(), 64);
        assert_eq!(map.get_owned("Test".to_string()), Some(64));
        assert_eq!(map.len(), 1);
    }
}