rostl_datastructures/
map.rs

1//! Implements map related data structures.
2
3use ahash::RandomState;
4use bytemuck::{Pod, Zeroable};
5use rand::{rngs::ThreadRng, Rng};
6use rostl_primitives::{
7  cmov_body, cxchg_body, impl_cmov_for_generic_pod,
8  traits::{Cmov, _Cmovbase},
9};
10
11use seq_macro::seq;
12
13use crate::{array::MultiWayArray, queue::ShortQueue};
14
/// Number of slots in every map bucket.
const BUCKET_SIZE: usize = 4;
/// Capacity of the insertion queue, which holds elements whose deamortized insertion has not
/// completed yet.
const INSERTION_QUEUE_MAX_SIZE: usize = 10;
/// Number of queued insertions attempted on each call to `deamortize_insertion_queue`
/// (which `insert` runs once per inserted element).
const DEAMORTIZED_INSERTIONS: usize = 2;
21
22use std::hash::Hash;
23
24/// Utility trait for types that can be uses as keys in the map.
25pub trait OHash: Cmov + Pod + Hash + PartialEq {}
26// UNDONE(git-20): Hashmaps with boolean key aren't supported until boolean implements pod.s
27impl<K> OHash for K where K: Cmov + Pod + Hash + PartialEq {}
28
29/// An element in the map.
30#[repr(align(8))]
31#[derive(Debug, Default, Clone, Copy, Zeroable)]
32pub struct InlineElement<K, V>
33where
34  K: OHash,
35  V: Cmov + Pod,
36{
37  key: K,
38  value: V,
39}
40unsafe impl<K: OHash, V: Cmov + Pod> Pod for InlineElement<K, V> {}
41impl_cmov_for_generic_pod!(InlineElement<K,V>; where K: OHash, V: Cmov + Pod);
42
43/// A struct that represents an element in a bucket.
44#[derive(Debug, Default, Clone, Copy, Zeroable)]
45pub struct BucketElement<K, V>
46where
47  K: OHash,
48  V: Cmov + Pod,
49{
50  is_valid: bool,
51  element: InlineElement<K, V>,
52}
53unsafe impl<K: OHash, V: Cmov + Pod> Pod for BucketElement<K, V> {}
54impl_cmov_for_generic_pod!(BucketElement<K,V>; where K: OHash, V: Cmov + Pod);
55
56impl<K, V> BucketElement<K, V>
57where
58  K: OHash,
59  V: Cmov + Pod,
60{
61  #[inline(always)]
62  const fn is_empty(&self) -> bool {
63    !self.is_valid
64  }
65}
66
67/// A bucket in the map.
68/// The bucket has `BUCKET_SIZE` elements.
69/// # Invariants
70/// * The elements in the bucket that have `is_valid == true` are non empty.
71/// * The elements in the bucket that have `is_valid == false` are empty.
72/// * No two valid elements have the same key.
73#[derive(Debug, Default, Clone, Copy, Zeroable)]
74struct Bucket<K, V>
75where
76  K: OHash,
77  V: Cmov + Pod,
78{
79  elements: [BucketElement<K, V>; BUCKET_SIZE],
80}
81unsafe impl<K: OHash, V: Cmov + Pod> Pod for Bucket<K, V> {}
82impl_cmov_for_generic_pod!(Bucket<K,V>; where K: OHash, V: Cmov + Pod);
83
84impl<K, V> Bucket<K, V>
85where
86  K: OHash,
87  V: Cmov + Pod,
88{
89  /// Replaces the value of a Key, if:
90  ///  1) `real`
91  ///  2) the bucket has a valid element with the same key as `element`
92  /// # Returns
93  /// * `true` - if a replacement happened (including if it was with the same value)
94  /// * `false` - otherwise (including if the element is empty)
95  fn update_if_exists(&mut self, real: bool, element: InlineElement<K, V>) -> bool {
96    let mut updated = false;
97    for i in 0..BUCKET_SIZE {
98      let element_i = &mut self.elements[i];
99      let choice = real & !element_i.is_empty() & (element_i.element.key == element.key);
100      element_i.element.value.cmov(&element.value, choice);
101      updated.cmov(&true, choice);
102    }
103    updated
104  }
105
106  fn read_if_exists(&self, key: K, ret: &mut V) -> bool {
107    let mut found = false;
108    for i in 0..BUCKET_SIZE {
109      let element_i = &self.elements[i];
110      let choice = !element_i.is_empty() & (element_i.element.key == key);
111      ret.cmov(&element_i.element.value, choice);
112      found.cmov(&true, choice);
113    }
114    found
115  }
116
117  /// Insert an element into the bucket uf it has an empty slot, otherwise does nothing.
118  /// # Preconditions
119  /// * real ==> The same key isn't in the bucket.
120  /// # Returns
121  /// * `true` - if the element was inserted or if `real == false`
122  /// * `false` - otherwise
123  fn insert_if_available(&mut self, real: bool, element: InlineElement<K, V>) -> bool {
124    let mut inserted = !real;
125    for i in 0..BUCKET_SIZE {
126      let element_i = &mut self.elements[i];
127      let choice = !inserted & element_i.is_empty();
128      element_i.is_valid.cmov(&true, choice);
129      element_i.element.cmov(&element, choice);
130      inserted.cmov(&true, choice);
131    }
132    inserted
133  }
134}
135
136/// An unsorted map that is oblivious to the access pattern.
137/// The map uses cuckoo hashing with size-2 buckets, two tables and a deamortization queue.
138/// `INSERTION_QUEUE_MAX_SIZE` is the maximum size of the deamortization queue.
139/// `DEAMORTIZED_INSERTIONS` is the number of deamortized insertions to perform per insert call.
140/// # Invariants
141/// * A key appears at most once in a valid element in between the two tables and the insertion queue.
142/// * The two tables have the same capacity.
143/// * The two tables have a different hash functions (different keys on the keyed hash function).
144#[derive(Debug)]
145pub struct UnsortedMap<K, V>
146where
147  K: OHash + Default + std::fmt::Debug,
148  V: Cmov + Pod + Default + std::fmt::Debug,
149{
150  /// Number of elements in the map
151  size: usize,
152  /// Maximum number of elements for perfect load `(max_size / load_factor)`
153  _capacity: usize,
154  /// Maximum number of entries in each table `(buckets / load_factor)`
155  table_size: usize,
156  /// The two tables
157  table: MultiWayArray<Bucket<K, V>, 2>,
158  /// The hasher used to hash keys
159  hash_builders: [RandomState; 2],
160  /// The insertion queue
161  insertion_queue: ShortQueue<InlineElement<K, V>, INSERTION_QUEUE_MAX_SIZE>,
162  /// Random source for random indices
163  rng: ThreadRng,
164}
165
166impl<K, V> UnsortedMap<K, V>
167where
168  K: OHash + Default + std::fmt::Debug,
169  V: Cmov + Pod + Default + std::fmt::Debug,
170{
171  /// Creates a new `UnsortedMap` with the given capacity `n`.
172  pub fn new(capacity: usize) -> Self {
173    debug_assert!(capacity > 0);
174    // For load factor of 0.8: cap / (0.8 * BUCKET_SIZE) = cap * 5 / (4 * BUCKET_SIZE)
175    let table_size = (capacity * 5).div_ceil(4 * BUCKET_SIZE).max(2);
176    Self {
177      size: 0,
178      _capacity: capacity,
179      table_size,
180      table: MultiWayArray::new(table_size),
181      hash_builders: [RandomState::new(), RandomState::new()],
182      insertion_queue: ShortQueue::new(),
183      rng: rand::rng(),
184    }
185  }
186
187  #[inline(always)]
188  fn hash_key<const TABLE: usize>(&self, key: &K) -> usize {
189    (self.hash_builders[TABLE].hash_one(key) % self.table_size as u64) as usize
190  }
191
192  /// Tries to get a value from the map.
193  /// # Returns
194  /// * `true` if the key was found
195  /// * `false` if the key wasn't found
196  /// # Postconditions
197  /// * If the key was found, the value is written to `ret`
198  /// * If the key wasn't found, `ret` is not modified
199  pub fn get(&mut self, key: K, ret: &mut V) -> bool {
200    let mut found = false;
201    let mut tmp: Bucket<K, V> = Default::default();
202
203    // Tries to get the element from each table:
204    // seq! does manual loop unrolling in rust. We need it to be able to use the constant INDEX in the hash_key function.
205    seq!(INDEX in 0..2 {
206      let hash = self.hash_key::<INDEX>(&key);
207      self.table.read(INDEX, hash, &mut tmp);
208      let found_local = tmp.read_if_exists(key, ret);
209      found.cmov(&true, found_local);
210    });
211
212    // Tries to get the element from the deamortization queue:
213    for i in 0..self.insertion_queue.size {
214      let element = self.insertion_queue.elements.data[i];
215      let found_local = !element.is_empty() & (element.value.key == key);
216      ret.cmov(&element.value.value, found_local);
217      found.cmov(&true, found_local);
218    }
219
220    found
221  }
222
223  /// Tries to insert an element into some of the hash tables, in case of collisions, the element is replaced.
224  /// # Returns
225  /// * `true` if it was possible to insert into an empty slot or `real == false`.
226  /// * `false` if the element was replaced and therefore the new element value needs to be inserted into the insertion queue.
227  fn try_insert_entry(&mut self, real: bool, element: &mut InlineElement<K, V>) -> bool {
228    let mut done = !real;
229
230    seq!(INDEX_REV in 0..2 {{
231      #[allow(clippy::identity_op, clippy::eq_op)] // False positives due to the seq! macro.
232      const INDEX: usize = 1 - INDEX_REV;
233
234      let hash = self.hash_key::<INDEX>(&element.key);
235      self.table.update(INDEX, hash, |bucket| {
236        let choice = !done;
237        let inserted = bucket.insert_if_available(choice, *element);
238        done.cmov(&true, inserted);
239        let randidx = self.rng.random_range(0..BUCKET_SIZE);
240        bucket.elements[randidx].element.cxchg(element, !done);
241      });
242    }});
243    done
244  }
245
246  /// Deamortizes the insertion queue by trying to insert elements into the tables.
247  pub fn deamortize_insertion_queue(&mut self) {
248    for _ in 0..DEAMORTIZED_INSERTIONS {
249      let mut element = InlineElement::default();
250      let real = self.insertion_queue.size > 0;
251
252      // Use FIFO order so we don't get stuck in a loop in the random graph of cuckoo hashing.
253      self.insertion_queue.maybe_pop(real, &mut element);
254      let has_pending_element = !self.try_insert_entry(real, &mut element);
255      self.insertion_queue.maybe_push(has_pending_element, element);
256    }
257  }
258
259  /// Inserts an elementn into the map. If the insertion doesn't finish, the removed element is inserted into the insertion queue.
260  /// # Preconditions
261  /// * The key is not in the map already.
262  pub fn insert(&mut self, key: K, value: V) {
263    // UNDONE(git-32): Recover in case the insertion queue is full.
264    assert!(self.insertion_queue.size < INSERTION_QUEUE_MAX_SIZE);
265    self.insertion_queue.maybe_push(true, InlineElement { key, value });
266    self.deamortize_insertion_queue();
267    self.size.cmov(&(self.size + 1), true);
268  }
269
270  /// Updates a value that is already in the map.
271  /// # Preconditions
272  /// * The key is in the map.
273  /// # Returns
274  /// * `true` if the value was updated
275  pub fn write(&mut self, key: K, value: V) {
276    let mut updated = false;
277
278    // Tries to get the element from each table:
279    // seq! does manual loop unrolling in rust. We need it to be able to use the constant INDEX in the hash_key function.
280    seq!(INDEX in 0..2 {
281      let hash = self.hash_key::<INDEX>(&key);
282      self.table.update(INDEX, hash, |bucket| {
283        let choice = !updated;
284        let updated_local = bucket.update_if_exists(choice, InlineElement { key, value });
285        updated.cmov(&true, updated_local);
286      });
287    });
288
289    // Tries to get the element from the deamortization queue:
290    for i in 0..self.insertion_queue.size {
291      let element = &mut self.insertion_queue.elements.data[i];
292      let choice = !updated & !element.is_empty() & (element.value.key == key);
293      element.value.value.cmov(&value, choice);
294      updated.cmov(&true, choice);
295    }
296
297    assert!(updated);
298  }
299
300  // UNDONE(git-33): add efficient upsert function (inserts if not present, updates if present)
301  // UNDONE(git-33): add delete function
302}
303
304// UNDONE(git-35): Add benchmarks for the map.
305
#[cfg(test)]
mod tests {
  use super::*;

  #[test]
  fn test_unsorted_map() {
    let mut map: UnsortedMap<u32, u32> = UnsortedMap::new(2);
    assert_eq!(map.size, 0);
    let mut value = 0;
    assert!(!map.get(1, &mut value));
    map.insert(1, 2);
    assert_eq!(map.size, 1);
    assert!(map.get(1, &mut value));
    assert_eq!(value, 2);
    map.write(1, 3);
    assert!(map.get(1, &mut value));
    assert_eq!(value, 3);
  }

  #[test]
  fn test_full_map() {
    const SZ: usize = 1024;
    let mut map: UnsortedMap<u32, u32> = UnsortedMap::new(SZ);
    assert_eq!(map.size, 0);
    for i in 0..SZ as u32 {
      map.insert(i, i * 2);
      let mut value = 0;
      assert!(map.get(i, &mut value));
      assert_eq!(value, i * 2);
      assert_eq!(map.size, (i + 1) as usize);
      map.write(i, i * 3);
      assert!(map.get(i, &mut value));
      assert_eq!(value, i * 3);
      assert_eq!(map.size, (i + 1) as usize);
    }

    // Later insertions can evict earlier elements (into the other table or into the insertion
    // queue), so check that every key is still retrievable once all insertions are done.
    for i in 0..SZ as u32 {
      let mut value = 0;
      assert!(map.get(i, &mut value), "key {i} lost after later insertions");
      assert_eq!(value, i * 3);
    }

    // A key that was never inserted is reported missing and leaves `ret` untouched.
    let mut value = u32::MAX;
    assert!(!map.get(SZ as u32, &mut value));
    assert_eq!(value, u32::MAX);
  }

  #[test]
  #[should_panic]
  fn test_write_missing_key_panics() {
    let mut map: UnsortedMap<u32, u32> = UnsortedMap::new(4);
    map.insert(1, 2);
    // `write` requires the key to be present and asserts otherwise.
    map.write(2, 3);
  }

  fn test_map_subtypes<
    K: OHash + Default + std::fmt::Debug,
    V: Cmov + Pod + Default + std::fmt::Debug,
  >() {
    const SZ: usize = 1024;
    let mut map: UnsortedMap<K, V> = UnsortedMap::new(SZ);
    assert_eq!(map.size, 0);
    let mut value = V::default();
    assert!(!map.get(K::default(), &mut value));
    map.insert(K::default(), V::default());
    assert_eq!(map.size, 1);
    assert!(map.get(K::default(), &mut value));
  }

  #[test]
  fn test_map_multiple_types() {
    test_map_subtypes::<u32, u32>();
    test_map_subtypes::<u64, u64>();
    test_map_subtypes::<u128, u128>();
    test_map_subtypes::<i32, i32>();
    test_map_subtypes::<i64, i64>();
    test_map_subtypes::<i128, i128>();
  }

  // UNDONE(git-34): Add further tests for the map.
}