query_graph/
map.rs

1use std::{fmt::Debug, hash::Hash};
2
3use ahash::RandomState;
4use hashbrown::HashMap;
5use parking_lot::RwLock;
6
7pub struct ConcurrentMap<K, V> {
8    shards: Box<[RwLock<HashMap<K, V, RandomState>>]>,
9    num_shards: usize,
10    hasher: RandomState,
11}
12
13impl<K: Eq + Hash, V: Clone> Default for ConcurrentMap<K, V> {
14    fn default() -> Self {
15        Self::new()
16    }
17}
18
19impl<K: Debug + Clone + Eq + Hash, V: Debug + Clone> Debug for ConcurrentMap<K, V> {
20    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21        let debug_map = self
22            .shards
23            .iter()
24            .flat_map(|shard| shard.read().clone())
25            .map(|(k, v)| (k, v))
26            .collect::<HashMap<_, _>>();
27
28        f.debug_map().entries(debug_map.iter()).finish()
29    }
30}
31
32// impl<K: Clone, V: Clone> Clone for ConcurrentMap<K, V> {
33//     fn clone(&self) -> Self {
34//         let shards = self
35//             .shards
36//             .iter()
37//             .map(|shard| RwLock::new(shard.read().clone()))
38//             .collect::<Box<_>>();
39
40//         Self {
41//             shards,
42//             num_shards: self.num_shards,
43//             hasher: self.hasher.clone(),
44//         }
45//     }
46// }
47
48impl<K: Eq + Hash, V: Clone> ConcurrentMap<K, V> {
49    pub fn new() -> Self {
50        let num_shards =
51            (std::thread::available_parallelism().map_or(1, usize::from) * 4).next_power_of_two();
52
53        Self {
54            shards: (0..num_shards)
55                .map(|_| RwLock::new(HashMap::default()))
56                .collect::<Box<_>>(),
57            num_shards,
58            hasher: RandomState::default(),
59        }
60    }
61
62    fn hash(&self, key: &K) -> usize {
63        self.hasher.hash_one(key) as usize
64    }
65
66    fn determine_shard(&self, hash: usize) -> usize {
67        hash % self.num_shards
68    }
69
70    unsafe fn get_read_shard(
71        &self,
72        idx: usize,
73    ) -> parking_lot::lock_api::RwLockReadGuard<parking_lot::RawRwLock, HashMap<K, V, RandomState>>
74    {
75        self.shards.get_unchecked(idx).read()
76    }
77
78    unsafe fn get_write_shard(
79        &self,
80        idx: usize,
81    ) -> parking_lot::lock_api::RwLockWriteGuard<parking_lot::RawRwLock, HashMap<K, V, RandomState>>
82    {
83        self.shards.get_unchecked(idx).write()
84    }
85
86    // pub fn insert(&self, key: K, value: V) {
87    //     let hash = self.hash(&key);
88    //     let idx = self.determine_shard(hash);
89
90    //     let mut shard = unsafe { self.get_write_shard(idx) };
91
92    //     shard.insert(key, value);
93    // }
94
95    pub fn get(&self, key: &K) -> Option<V> {
96        let hash = self.hash(key);
97        let idx = self.determine_shard(hash);
98
99        let shard = unsafe { self.get_read_shard(idx) };
100
101        shard.get(key).cloned()
102    }
103
104    pub fn get_or_insert<F: FnOnce() -> V>(&self, key: K, value: F) -> V {
105        let hash = self.hash(&key);
106        let idx = self.determine_shard(hash);
107
108        // First, read the shard with just a read-lock.
109        let result = {
110            let shard = unsafe { self.get_read_shard(idx) };
111            shard.get(&key).cloned()
112        };
113
114        // If the result is some, return it.
115        if let Some(result) = result {
116            return result;
117        }
118
119        // Getting the value failed with a read lock, so we will try with a write-lock.
120        let mut shard = unsafe { self.get_write_shard(idx) };
121        let result = shard.get(&key);
122
123        // We check that the result is some, this means another thread won and wrote first.
124        if let Some(result) = result {
125            return result.clone();
126        }
127
128        // If this thread won, we get the value and insert it.
129        let result = value();
130        shard.insert(key, result.clone());
131        result
132    }
133}