Skip to main content

tantivy_stacker/
shared_arena_hashmap.rs

1use std::iter::{Cloned, Filter};
2use std::mem;
3
4use super::{Addr, MemoryArena};
5use crate::fastcpy::fast_short_slice_copy;
6use crate::memory_arena::store;
7
8/// Returns the actual memory size in bytes
9/// required to create a table with a given capacity.
10/// required to create a table of size
11pub fn compute_table_memory_size(capacity: usize) -> usize {
12    capacity * mem::size_of::<KeyValue>()
13}
14
// Type of the key hash cached in every bucket.
//
// With `compare_hash_only`, two keys with equal hashes are treated as the same
// key, so a wider 64-bit hash is used to make collisions rarer. Otherwise the
// hash only pre-filters buckets before a byte-wise key comparison, and a
// 32-bit hash is used.
#[cfg(feature = "compare_hash_only")]
type HashType = u64;

#[cfg(not(feature = "compare_hash_only"))]
type HashType = u32;
20
21/// `KeyValue` is the item stored in the hash table.
22/// The key is actually a `BytesRef` object stored in an external memory arena.
23/// The `value_addr` also points to an address in the memory arena.
24#[derive(Copy, Clone)]
25struct KeyValue {
26    key_value_addr: Addr,
27    hash: HashType,
28}
29
30impl Default for KeyValue {
31    fn default() -> Self {
32        KeyValue {
33            key_value_addr: Addr::null_pointer(),
34            hash: 0,
35        }
36    }
37}
38
39impl KeyValue {
40    #[inline]
41    fn is_empty(&self) -> bool {
42        self.key_value_addr.is_null()
43    }
44    #[inline]
45    fn is_not_empty_ref(&self) -> bool {
46        !self.key_value_addr.is_null()
47    }
48}
49
/// Customized `HashMap` with `&[u8]` keys.
///
/// Its main particularity is that rather than storing its
/// keys in the heap, keys are stored in a memory arena
/// inline with the values.
///
/// The quirky API has the benefit of avoiding
/// the computation of the hash of the key twice,
/// or copying the key as long as there is no insert.
///
/// `SharedArenaHashMap` is like `ArenaHashMap`, but the memory arena is
/// passed as an argument to its methods. A single `MemoryArena` can therefore
/// be shared by several `SharedArenaHashMap`s. Buckets store arena addresses,
/// so later calls must receive the same arena that the keys were inserted
/// into.
pub struct SharedArenaHashMap {
    // Open-addressing bucket table; its length is always a power of two
    // (established by `with_capacity`, preserved by `resize`).
    table: Vec<KeyValue>,
    // Always `table.len() - 1`; turns a probe position into a bucket index.
    mask: usize,
    // Number of occupied buckets, i.e. the number of distinct keys.
    len: usize,
}
68
69struct LinearProbing {
70    pos: usize,
71    mask: usize,
72}
73
74impl LinearProbing {
75    #[inline]
76    fn compute(hash: HashType, mask: usize) -> LinearProbing {
77        LinearProbing {
78            pos: hash as usize,
79            mask,
80        }
81    }
82
83    #[inline]
84    fn next_probe(&mut self) -> usize {
85        // Not saving the masked version removes a dependency.
86        self.pos = self.pos.wrapping_add(1);
87        self.pos & self.mask
88    }
89}
90
91type IterNonEmpty<'a> = Filter<Cloned<std::slice::Iter<'a, KeyValue>>, fn(&KeyValue) -> bool>;
92
93pub struct Iter<'a> {
94    hashmap: &'a SharedArenaHashMap,
95    memory_arena: &'a MemoryArena,
96    inner: IterNonEmpty<'a>,
97}
98
99impl<'a> Iterator for Iter<'a> {
100    type Item = (&'a [u8], Addr);
101
102    fn next(&mut self) -> Option<Self::Item> {
103        self.inner.next().map(move |kv| {
104            let (key, offset): (&'a [u8], Addr) = self
105                .hashmap
106                .get_key_value(kv.key_value_addr, self.memory_arena);
107            (key, offset)
108        })
109    }
110}
111
/// Returns the greatest power of two lower than or equal to `n`.
///
/// # Panics
///
/// Panics if `n == 0`, since no power of two is lower than or equal to zero.
fn compute_previous_power_of_two(n: usize) -> usize {
    assert!(n > 0);
    // `ilog2` is the index of the most significant set bit and works for any
    // `usize` width, so no hard-coded 64-bit arithmetic is needed.
    1 << n.ilog2()
}
121
122impl Default for SharedArenaHashMap {
123    fn default() -> Self {
124        SharedArenaHashMap::with_capacity(4)
125    }
126}
127
128impl SharedArenaHashMap {
129    pub fn with_capacity(table_size: usize) -> SharedArenaHashMap {
130        let table_size_power_of_2 = compute_previous_power_of_two(table_size);
131        let table = vec![KeyValue::default(); table_size_power_of_2];
132
133        SharedArenaHashMap {
134            table,
135            mask: table_size_power_of_2 - 1,
136            len: 0,
137        }
138    }
139
140    #[inline]
141    #[cfg(not(feature = "compare_hash_only"))]
142    fn get_hash(&self, key: &[u8]) -> HashType {
143        murmurhash32::murmurhash2(key)
144    }
145
146    #[inline]
147    #[cfg(feature = "compare_hash_only")]
148    fn get_hash(&self, key: &[u8]) -> HashType {
149        /// Since we compare only the hash we need a high quality hash.
150        use std::hash::Hasher;
151        let mut hasher = ahash::AHasher::default();
152        hasher.write(key);
153        hasher.finish() as HashType
154    }
155
156    #[inline]
157    fn probe(&self, hash: HashType) -> LinearProbing {
158        LinearProbing::compute(hash, self.mask)
159    }
160
161    #[inline]
162    pub fn mem_usage(&self) -> usize {
163        self.table.len() * mem::size_of::<KeyValue>()
164    }
165
166    #[inline]
167    fn is_saturated(&self) -> bool {
168        self.table.len() <= self.len * 2
169    }
170
171    #[inline]
172    fn get_key_value<'a>(&'a self, addr: Addr, memory_arena: &'a MemoryArena) -> (&'a [u8], Addr) {
173        let data = memory_arena.slice_from(addr);
174        let key_bytes_len_bytes = unsafe { data.get_unchecked(..2) };
175        let key_bytes_len = u16::from_le_bytes(key_bytes_len_bytes.try_into().unwrap());
176        let key_bytes: &[u8] = unsafe { data.get_unchecked(2..2 + key_bytes_len as usize) };
177        (key_bytes, addr.offset(2 + key_bytes_len as u32))
178    }
179
180    #[inline]
181    #[cfg(not(feature = "compare_hash_only"))]
182    fn get_value_addr_if_key_match(
183        &self,
184        target_key: &[u8],
185        addr: Addr,
186        memory_arena: &MemoryArena,
187    ) -> Option<Addr> {
188        use crate::fastcmp::fast_short_slice_compare;
189
190        let (stored_key, value_addr) = self.get_key_value(addr, memory_arena);
191        if fast_short_slice_compare(stored_key, target_key) {
192            Some(value_addr)
193        } else {
194            None
195        }
196    }
197    #[inline]
198    #[cfg(feature = "compare_hash_only")]
199    fn get_value_addr_if_key_match(
200        &self,
201        _target_key: &[u8],
202        addr: Addr,
203        memory_arena: &MemoryArena,
204    ) -> Option<Addr> {
205        // For the compare_hash_only feature, it would make sense to store the keys at a different
206        // memory location. Here they will just pollute the cache.
207        let data = memory_arena.slice_from(addr);
208        let key_bytes_len_bytes = &data[..2];
209        let key_bytes_len = u16::from_le_bytes(key_bytes_len_bytes.try_into().unwrap());
210        let value_addr = addr.offset(2 + key_bytes_len as u32);
211
212        Some(value_addr)
213    }
214
215    #[inline]
216    fn set_bucket(&mut self, hash: HashType, key_value_addr: Addr, bucket: usize) {
217        self.len += 1;
218
219        self.table[bucket] = KeyValue {
220            key_value_addr,
221            hash,
222        };
223    }
224
225    #[inline]
226    pub fn is_empty(&self) -> bool {
227        self.len() == 0
228    }
229
230    #[inline]
231    pub fn len(&self) -> usize {
232        self.len
233    }
234
235    #[inline]
236    pub fn iter<'a>(&'a self, memory_arena: &'a MemoryArena) -> Iter<'a> {
237        Iter {
238            inner: self
239                .table
240                .iter()
241                .cloned()
242                .filter(KeyValue::is_not_empty_ref),
243            hashmap: self,
244            memory_arena,
245        }
246    }
247
248    fn resize(&mut self) {
249        let new_len = (self.table.len() * 2).max(1 << 3);
250        let mask = new_len - 1;
251        self.mask = mask;
252        let new_table = vec![KeyValue::default(); new_len];
253        let old_table = mem::replace(&mut self.table, new_table);
254        for key_value in old_table.into_iter().filter(KeyValue::is_not_empty_ref) {
255            let mut probe = LinearProbing::compute(key_value.hash, mask);
256            loop {
257                let bucket = probe.next_probe();
258                if self.table[bucket].is_empty() {
259                    self.table[bucket] = key_value;
260                    break;
261                }
262            }
263        }
264    }
265
266    /// Get a value associated to a key.
267    #[inline]
268    pub fn get<V>(&self, key: &[u8], memory_arena: &MemoryArena) -> Option<V>
269    where V: Copy + 'static {
270        let hash = self.get_hash(key);
271        let mut probe = self.probe(hash);
272        loop {
273            let bucket = probe.next_probe();
274            let kv: KeyValue = self.table[bucket];
275            if kv.is_empty() {
276                return None;
277            } else if kv.hash == hash
278                && let Some(val_addr) =
279                    self.get_value_addr_if_key_match(key, kv.key_value_addr, memory_arena)
280            {
281                let v = memory_arena.read(val_addr);
282                return Some(v);
283            }
284        }
285    }
286
287    /// `update` create a new entry for a given key if it does not exist
288    /// or updates the existing entry.
289    ///
290    /// The actual logic for this update is define in the `updater`
291    /// argument.
292    ///
293    /// If the key is not present, `updater` will receive `None` and
294    /// will be in charge of returning a default value.
295    /// If the key already as an associated value, then it will be passed
296    /// `Some(previous_value)`.
297    ///
298    /// The key will be truncated to u16::MAX bytes.
299    #[inline]
300    pub fn mutate_or_create<V>(
301        &mut self,
302        key: &[u8],
303        memory_arena: &mut MemoryArena,
304        mut updater: impl FnMut(Option<V>) -> V,
305    ) -> V
306    where
307        V: Copy + 'static,
308    {
309        if self.is_saturated() {
310            self.resize();
311        }
312        // Limit the key size to u16::MAX
313        let key = &key[..std::cmp::min(key.len(), u16::MAX as usize)];
314        let hash = self.get_hash(key);
315        let mut probe = self.probe(hash);
316        let mut bucket = probe.next_probe();
317        let mut kv: KeyValue = self.table[bucket];
318        loop {
319            if kv.is_empty() {
320                // The key does not exist yet.
321                let val = updater(None);
322                let num_bytes = std::mem::size_of::<u16>() + key.len() + std::mem::size_of::<V>();
323                let key_addr = memory_arena.allocate_space(num_bytes);
324                {
325                    let data = memory_arena.slice_mut(key_addr, num_bytes);
326                    let key_len_bytes: [u8; 2] = (key.len() as u16).to_le_bytes();
327                    data[..2].copy_from_slice(&key_len_bytes);
328                    let stop = 2 + key.len();
329                    fast_short_slice_copy(key, &mut data[2..stop]);
330                    store(&mut data[stop..], val);
331                }
332
333                self.set_bucket(hash, key_addr, bucket);
334                return val;
335            }
336            if kv.hash == hash
337                && let Some(val_addr) =
338                    self.get_value_addr_if_key_match(key, kv.key_value_addr, memory_arena)
339            {
340                let v = memory_arena.read(val_addr);
341                let new_v = updater(Some(v));
342                memory_arena.write_at(val_addr, new_v);
343                return new_v;
344            }
345            // This allows fetching the next bucket before the loop jmp
346            bucket = probe.next_probe();
347            kv = self.table[bucket];
348        }
349    }
350}
351
#[cfg(test)]
mod tests {

    use std::collections::HashMap;

    use super::{SharedArenaHashMap, compute_previous_power_of_two};
    use crate::MemoryArena;

    #[test]
    fn test_hash_map() {
        let mut memory_arena = MemoryArena::default();
        let mut hash_map: SharedArenaHashMap = SharedArenaHashMap::default();
        hash_map.mutate_or_create(b"abc", &mut memory_arena, |opt_val: Option<u32>| {
            assert_eq!(opt_val, None);
            3u32
        });
        hash_map.mutate_or_create(b"abcd", &mut memory_arena, |opt_val: Option<u32>| {
            assert_eq!(opt_val, None);
            4u32
        });
        hash_map.mutate_or_create(b"abc", &mut memory_arena, |opt_val: Option<u32>| {
            assert_eq!(opt_val, Some(3u32));
            5u32
        });
        let mut vanilla_hash_map = HashMap::new();
        let iter_values = hash_map.iter(&memory_arena);
        for (key, addr) in iter_values {
            let val: u32 = memory_arena.read(addr);
            vanilla_hash_map.insert(key.to_owned(), val);
        }
        assert_eq!(vanilla_hash_map.len(), 2);
        // The update of "abc" must be visible through both iteration and `get`.
        assert_eq!(vanilla_hash_map.get(b"abc".as_slice()), Some(&5u32));
        assert_eq!(vanilla_hash_map.get(b"abcd".as_slice()), Some(&4u32));
        assert_eq!(hash_map.len(), 2);
        assert_eq!(hash_map.get::<u32>(b"abc", &memory_arena), Some(5u32));
        assert_eq!(hash_map.get::<u32>(b"abcd", &memory_arena), Some(4u32));
        assert_eq!(hash_map.get::<u32>(b"ab", &memory_arena), None);
    }

    #[test]
    fn test_long_key_truncation() {
        // Keys longer than u16::MAX are truncated.
        let mut memory_arena = MemoryArena::default();
        let mut hash_map: SharedArenaHashMap = SharedArenaHashMap::default();
        let key1 = (0..u16::MAX as usize).map(|i| i as u8).collect::<Vec<_>>();
        hash_map.mutate_or_create(&key1, &mut memory_arena, |opt_val: Option<u32>| {
            assert_eq!(opt_val, None);
            4u32
        });
        // Due to truncation, this key is the same as key1
        let key2 = (0..u16::MAX as usize + 1)
            .map(|i| i as u8)
            .collect::<Vec<_>>();
        hash_map.mutate_or_create(&key2, &mut memory_arena, |opt_val: Option<u32>| {
            assert_eq!(opt_val, Some(4));
            3u32
        });
        let mut vanilla_hash_map = HashMap::new();
        let iter_values = hash_map.iter(&memory_arena);
        for (key, addr) in iter_values {
            let val: u32 = memory_arena.read(addr);
            vanilla_hash_map.insert(key.to_owned(), val);
            assert_eq!(key.len(), key1[..].len());
            assert_eq!(key, &key1[..])
        }
        assert_eq!(vanilla_hash_map.len(), 1); // Both map to the same key
        assert_eq!(hash_map.len(), 1);
    }

    #[test]
    fn test_empty_hashmap() {
        let memory_arena = MemoryArena::default();
        let hash_map: SharedArenaHashMap = SharedArenaHashMap::default();
        assert!(hash_map.is_empty());
        assert_eq!(hash_map.get::<u32>(b"abc", &memory_arena), None);
    }

    #[test]
    fn test_compute_previous_power_of_two() {
        assert_eq!(compute_previous_power_of_two(1), 1);
        assert_eq!(compute_previous_power_of_two(8), 8);
        assert_eq!(compute_previous_power_of_two(9), 8);
        assert_eq!(compute_previous_power_of_two(7), 4);
        // Written in terms of `usize` so the test also compiles on 32-bit targets.
        assert_eq!(
            compute_previous_power_of_two(usize::MAX),
            1usize << (usize::BITS - 1)
        );
    }

    #[test]
    #[should_panic]
    fn test_compute_previous_power_of_two_rejects_zero() {
        compute_previous_power_of_two(0);
    }

    #[test]
    fn test_many_terms() {
        let mut memory_arena = MemoryArena::default();
        let mut terms: Vec<String> = (0..20_000).map(|val| val.to_string()).collect();
        let mut hash_map: SharedArenaHashMap = SharedArenaHashMap::default();
        for term in terms.iter() {
            hash_map.mutate_or_create(
                term.as_bytes(),
                &mut memory_arena,
                |_opt_val: Option<u32>| 5u32,
            );
        }
        assert_eq!(hash_map.len(), terms.len());
        for term in terms.iter() {
            assert_eq!(hash_map.get::<u32>(term.as_bytes(), &memory_arena), Some(5u32));
        }
        let mut terms_back: Vec<String> = hash_map
            .iter(&memory_arena)
            .map(|(bytes, _)| String::from_utf8(bytes.to_vec()).unwrap())
            .collect();
        terms_back.sort_unstable();
        terms.sort_unstable();

        // Whole-vector comparison also catches missing or duplicated keys.
        assert_eq!(terms, terms_back);
    }
}