rostl_datastructures/
map.rs

1//! Implements map related data structures.
2
3use ahash::RandomState;
4use bytemuck::{Pod, Zeroable};
5use rand::{rngs::ThreadRng, Rng};
6use rostl_primitives::{
7  cmov_body, cxchg_body, impl_cmov_for_generic_pod,
8  traits::{Cmov, _Cmovbase},
9};
10
11use seq_macro::seq;
12
13use crate::{array::MultiWayArray, queue::ShortQueue};
14
// Maximum number of elements the insertion queue may hold. The queue stores elements whose
// deamortized insertion into the tables has not completed yet.
const INSERTION_QUEUE_MAX_SIZE: usize = 10;
// Number of queued elements that each call to `deamortize_insertion_queue` tries to place.
const DEAMORTIZED_INSERTIONS: usize = 2;
// Number of element slots in each map bucket.
const BUCKET_SIZE: usize = 4;
21
22use std::hash::Hash;
23
/// Utility trait for types that can be used as keys in the map.
///
/// It has no methods of its own: it only bundles the bound `Cmov + Pod + Hash + PartialEq`,
/// and the blanket implementation below provides it for every type meeting that bound.
pub trait OHash: Cmov + Pod + Hash + PartialEq {}
// UNDONE(git-20): Hashmaps with boolean keys aren't supported until `bool` implements `Pod`.
impl<K> OHash for K where K: Cmov + Pod + Hash + PartialEq {}
28
/// An element in the map: a key stored inline together with its value.
///
/// The layout is `repr(C)` with a minimum alignment of 8 bytes.
#[repr(align(8))]
#[repr(C)]
#[derive(Debug, Default, Clone, Copy, Zeroable)]
pub struct InlineElement<K, V>
where
  K: OHash,
  V: Cmov + Pod,
{
  /// Key the element is looked up by.
  key: K,
  /// Value associated with `key`.
  value: V,
}
// NOTE(review): `Pod` is asserted manually instead of being derived. With `repr(C)` and
// `align(8)`, some `K`/`V` combinations presumably leave padding bytes in the struct, which
// `Pod` is not supposed to allow — confirm which key/value types are meant to be supported.
unsafe impl<K: OHash, V: Cmov + Pod> Pod for InlineElement<K, V> {}
// Presumably generates the `Cmov` implementation that lets a whole element be conditionally
// moved/exchanged in one call.
impl_cmov_for_generic_pod!(InlineElement<K,V>; where K: OHash, V: Cmov + Pod);
43
/// A bucket in the map.
/// The bucket has `BUCKET_SIZE` element slots.
/// # Invariants
/// * The elements in the bucket that have `is_valid == true` are non-empty.
/// * The elements in the bucket that have `is_valid == false` are empty.
/// * No two valid elements have the same key.
#[derive(Debug, Default, Clone, Copy, Zeroable)]
#[repr(C)]
struct Bucket<K, V>
where
  K: OHash,
  V: Cmov + Pod,
{
  /// `is_valid[i]` tells whether `elements[i]` currently holds an element.
  is_valid: [bool; BUCKET_SIZE],
  /// Element slots; slot `i` is meaningful only while `is_valid[i]` is `true`.
  elements: [InlineElement<K, V>; BUCKET_SIZE],
}
// NOTE(review): `Pod` is asserted manually instead of being derived. The struct contains
// `bool` fields, and `bool` does not accept every bit pattern — confirm this impl is sound
// for the way buckets are read and written.
unsafe impl<K: OHash, V: Cmov + Pod> Pod for Bucket<K, V> {}
// Presumably generates the `Cmov` implementation for whole buckets.
impl_cmov_for_generic_pod!(Bucket<K,V>; where K: OHash, V: Cmov + Pod);
62
63impl<K, V> Bucket<K, V>
64where
65  K: OHash,
66  V: Cmov + Pod,
67{
68  const fn is_empty(&self, i: usize) -> bool {
69    !self.is_valid[i]
70  }
71
72  /// Replaces the value of a Key, if:
73  ///  1) `real`
74  ///  2) the bucket has a valid element with the same key as `element`
75  /// # Returns
76  /// * `true` - if a replacement happened (including if it was with the same value)
77  /// * `false` - otherwise (including if the element is empty)
78  fn update_if_exists(&mut self, real: bool, element: InlineElement<K, V>) -> bool {
79    let mut updated = false;
80    for i in 0..BUCKET_SIZE {
81      let choice = real & !self.is_empty(i) & (self.elements[i].key == element.key);
82      self.elements[i].value.cmov(&element.value, choice);
83      updated.cmov(&true, choice);
84    }
85    updated
86  }
87
88  fn read_if_exists(&self, key: K, ret: &mut V) -> bool {
89    let mut found = false;
90    for i in 0..BUCKET_SIZE {
91      let choice = !self.is_empty(i) & (self.elements[i].key == key);
92      ret.cmov(&self.elements[i].value, choice);
93      found.cmov(&true, choice);
94    }
95    found
96  }
97
98  /// Insert an element into the bucket uf it has an empty slot, otherwise does nothing.
99  /// # Preconditions
100  /// * real ==> The same key isn't in the bucket.
101  /// # Returns
102  /// * `true` - if the element was inserted or if `real == false`
103  /// * `false` - otherwise
104  fn insert_if_available(&mut self, real: bool, element: InlineElement<K, V>) -> bool {
105    let mut inserted = !real;
106    for i in 0..BUCKET_SIZE {
107      let choice = !inserted & self.is_empty(i);
108      self.is_valid[i].cmov(&true, choice);
109      self.elements[i].cmov(&element, choice);
110      inserted.cmov(&true, choice);
111    }
112    inserted
113  }
114}
115
/// An unsorted map that is oblivious to the access pattern.
/// The map uses cuckoo hashing with buckets of `BUCKET_SIZE` elements, two tables and a
/// deamortization queue.
/// `INSERTION_QUEUE_MAX_SIZE` is the maximum size of the deamortization queue.
/// `DEAMORTIZED_INSERTIONS` is the number of deamortized insertions to perform per insert call.
/// # Invariants
/// * A key appears at most once in a valid element across the two tables and the insertion queue.
/// * The two tables have the same capacity.
/// * The two tables have different hash functions (different keys on the keyed hash function).
#[derive(Debug)]
pub struct UnsortedMap<K, V>
where
  K: OHash + Default + std::fmt::Debug,
  V: Cmov + Pod + Default + std::fmt::Debug,
{
  /// Number of elements in the map
  size: usize,
  /// Capacity requested at construction; stored but not read by the code in this file.
  _capacity: usize,
  /// Number of bucket indices per table; key hashes are reduced modulo this value.
  table_size: usize,
  /// The two tables
  table: MultiWayArray<Bucket<K, V>, 2>,
  /// One keyed hasher per table, used to hash keys into bucket indices
  hash_builders: [RandomState; 2],
  /// The insertion queue, holding elements that are not yet placed in a table
  insertion_queue: ShortQueue<InlineElement<K, V>, INSERTION_QUEUE_MAX_SIZE>,
  /// Random source used to pick the bucket slot to evict when an insertion finds no empty slot
  rng: ThreadRng,
}
145
impl<K, V> UnsortedMap<K, V>
where
  K: OHash + Default + std::fmt::Debug,
  V: Cmov + Pod + Default + std::fmt::Debug,
{
  /// Creates a new, empty `UnsortedMap` sized for `capacity` elements.
  ///
  /// Each table gets `ceil(capacity * 10 / (9 * BUCKET_SIZE))` bucket indices, with a minimum of 2.
  /// The two tables are given independently created hashers.
  /// # Panics
  /// * In builds with debug assertions enabled, if `capacity == 0`.
  pub fn new(capacity: usize) -> Self {
    debug_assert!(capacity > 0);
    // For load factor of 0.8: cap / (0.8 * BUCKET_SIZE) = cap * 5 / (4 * BUCKET_SIZE)
    // For load factor of 0.9: cap / (0.9 * BUCKET_SIZE) = cap * 10 / (9 * BUCKET_SIZE)
    // The 0.9 variant is the one used below.
    let table_size = (capacity * 10).div_ceil(9 * BUCKET_SIZE).max(2);
    Self {
      size: 0,
      _capacity: capacity,
      table_size,
      table: MultiWayArray::new(table_size),
      hash_builders: [RandomState::new(), RandomState::new()],
      insertion_queue: ShortQueue::new(),
      rng: rand::rng(),
    }
  }

  /// Maps `key` to a bucket index for table `TABLE`: the key is hashed with that table's
  /// hasher and the hash is reduced modulo `table_size`.
  #[inline(always)]
  fn hash_key<const TABLE: usize>(&self, key: &K) -> usize {
    (self.hash_builders[TABLE].hash_one(key) % self.table_size as u64) as usize
  }

  /// Tries to get a value from the map.
  ///
  /// Both candidate buckets and every occupied position of the insertion queue are always
  /// inspected; there is no early exit once the key has been found.
  /// # Returns
  /// * `true` if the key was found
  /// * `false` if the key wasn't found
  /// # Postconditions
  /// * If the key was found, the value is written to `ret`
  /// * If the key wasn't found, `ret` is not modified
  pub fn get(&mut self, key: K, ret: &mut V) -> bool {
    let mut found = false;
    // Scratch bucket that receives a copy of the candidate bucket of each table.
    let mut tmp: Bucket<K, V> = Default::default();

    // Tries to get the element from each table:
    // seq! does manual loop unrolling in rust. We need it to be able to use the constant INDEX in the hash_key function.
    seq!(INDEX in 0..2 {
      let hash = self.hash_key::<INDEX>(&key);
      self.table.read(INDEX, hash, &mut tmp);
      let found_local = tmp.read_if_exists(key, ret);
      found.cmov(&true, found_local);
    });

    // Tries to get the element from the deamortization queue:
    for i in 0..self.insertion_queue.size {
      let element = self.insertion_queue.elements.data[i];
      let found_local = !element.is_empty() & (element.value.key == key);
      ret.cmov(&element.value.value, found_local);
      found.cmov(&true, found_local);
    }

    found
  }

  /// Tries to insert an element into one of the hash tables; if the candidate bucket has no
  /// empty slot, the element is exchanged with a randomly chosen occupant of that bucket.
  ///
  /// The tables are visited in the order 1, then 0. When this function returns `false`,
  /// `element` holds the entry that was displaced last and still needs a home.
  /// # Returns
  /// * `true` if it was possible to insert into an empty slot or `real == false`.
  /// * `false` if the element was replaced and therefore the new element value needs to be inserted into the insertion queue.
  fn try_insert_entry(&mut self, real: bool, element: &mut InlineElement<K, V>) -> bool {
    // A dummy insertion counts as already done, so none of the conditional writes below fire.
    let mut done = !real;

    seq!(INDEX_REV in 0..2 {{
      #[allow(clippy::identity_op, clippy::eq_op)] // False positives due to the seq! macro.
      const INDEX: usize = 1 - INDEX_REV;

      // On the second visit `element` may already be the entry evicted from the first table,
      // so the hash is computed from the element's current key.
      let hash = self.hash_key::<INDEX>(&element.key);
      self.table.update(INDEX, hash, |bucket| {
        let choice = !done;
        let inserted = bucket.insert_if_available(choice, *element);
        done.cmov(&true, inserted);
        // The random slot is drawn on every visit, whether or not it ends up being used.
        let randidx = self.rng.random_range(0..BUCKET_SIZE);
        // Still not placed: swap `element` with the occupant of the random slot
        // (presumably `cxchg` is a conditional exchange — confirm in rostl_primitives).
        bucket.elements[randidx].cxchg(element, !done);
      });
    }});
    done
  }

  /// Deamortizes the insertion queue by trying to insert elements into the tables.
  ///
  /// Performs exactly `DEAMORTIZED_INSERTIONS` rounds. Each round conditionally pops one
  /// element (a dummy round runs when the queue is empty), tries to place it, and pushes
  /// back whatever element is left without a slot.
  pub fn deamortize_insertion_queue(&mut self) {
    for _ in 0..DEAMORTIZED_INSERTIONS {
      let mut element = InlineElement::default();
      let real = self.insertion_queue.size > 0;

      // Use FIFO order so we don't get stuck in a loop in the random graph of cuckoo hashing.
      self.insertion_queue.maybe_pop(real, &mut element);
      let has_pending_element = !self.try_insert_entry(real, &mut element);
      self.insertion_queue.maybe_push(has_pending_element, element);
    }
  }

  /// Inserts an element into the map. If the insertion doesn't finish, the removed element is inserted into the insertion queue.
  ///
  /// The new element is first pushed onto the insertion queue, then one deamortization pass runs.
  /// # Preconditions
  /// * The key is not in the map already.
  /// # Panics
  /// * If the insertion queue already holds `INSERTION_QUEUE_MAX_SIZE` elements.
  pub fn insert(&mut self, key: K, value: V) {
    // UNDONE(git-32): Recover in case the insertion queue is full.
    assert!(self.insertion_queue.size < INSERTION_QUEUE_MAX_SIZE);
    self.insertion_queue.maybe_push(true, InlineElement { key, value });
    self.deamortize_insertion_queue();
    self.size.cmov(&(self.size + 1), true);
  }

  /// Conditionally inserts an element into the map obliviously. If the insertion doesn't finish, the removed element is inserted into the insertion queue.
  ///
  /// The queue-capacity check and the deamortization pass run whether or not `real` is set.
  /// # Preconditions
  /// * If `real` is true, the key is not in the map already.
  /// # Parameters
  /// * `real` - if true, the insertion is performed, if false, the
  ///   insertion is a dummy insertion that doesn't modify the logical map.
  /// # Panics
  /// * If the insertion queue already holds `INSERTION_QUEUE_MAX_SIZE` elements, even when `real` is false.
  pub fn insert_cond(&mut self, key: K, value: V, real: bool) {
    // UNDONE(git-32): Recover in case the insertion queue is full.
    assert!(self.insertion_queue.size < INSERTION_QUEUE_MAX_SIZE);
    self.insertion_queue.maybe_push(real, InlineElement { key, value });
    self.deamortize_insertion_queue();
    self.size.cmov(&(self.size + 1), real);
  }

  /// Updates a value that is already in the map.
  ///
  /// Both candidate buckets and every occupied position of the insertion queue are visited;
  /// only the first match is overwritten.
  /// # Preconditions
  /// * The key is in the map.
  /// # Panics
  /// * If the key was found neither in the tables nor in the insertion queue.
  pub fn write(&mut self, key: K, value: V) {
    let mut updated = false;

    // Tries to update the element in each table:
    // seq! does manual loop unrolling in rust. We need it to be able to use the constant INDEX in the hash_key function.
    seq!(INDEX in 0..2 {
      let hash = self.hash_key::<INDEX>(&key);
      self.table.update(INDEX, hash, |bucket| {
        // Once an update has happened, later buckets are only touched with a dummy update.
        let choice = !updated;
        let updated_local = bucket.update_if_exists(choice, InlineElement { key, value });
        updated.cmov(&true, updated_local);
      });
    });

    // Tries to update the element in the deamortization queue:
    for i in 0..self.insertion_queue.size {
      let element = &mut self.insertion_queue.elements.data[i];
      let choice = !updated & !element.is_empty() & (element.value.key == key);
      element.value.value.cmov(&value, choice);
      updated.cmov(&true, choice);
    }

    // Enforces the precondition: the key must have been present somewhere.
    assert!(updated);
  }

  // UNDONE(git-33): add efficient upsert function (inserts if not present, updates if present)
  // UNDONE(git-33): add delete function
}
298
299// UNDONE(git-35): Add benchmarks for the map.
300
#[cfg(test)]
mod tests {
  use super::*;

  /// Smoke test: miss on an empty map, then one insert followed by one overwrite.
  #[test]
  fn test_unsorted_map() {
    let mut m = UnsortedMap::<u32, u32>::new(2);
    assert_eq!(m.size, 0);

    let mut out = 0;
    assert!(!m.get(1, &mut out));

    m.insert(1, 2);
    assert_eq!(m.size, 1);
    assert!(m.get(1, &mut out));
    assert_eq!(out, 2);

    m.write(1, 3);
    assert!(m.get(1, &mut out));
    assert_eq!(out, 3);
  }

  /// Fills the map up to its requested capacity, checking reads, overwrites and the
  /// element count after every single insertion.
  #[test]
  fn test_full_map() {
    const SZ: usize = 1024;
    let mut m = UnsortedMap::<u32, u32>::new(SZ);
    assert_eq!(m.size, 0);

    for (already_inserted, key) in (0..SZ as u32).enumerate() {
      let expected_size = already_inserted + 1;

      m.insert(key, key * 2);
      let mut out = 0;
      assert!(m.get(key, &mut out));
      assert_eq!(out, key * 2);
      assert_eq!(m.size, expected_size);

      m.write(key, key * 3);
      assert!(m.get(key, &mut out));
      assert_eq!(out, key * 3);
      assert_eq!(m.size, expected_size);
    }
  }

  /// Conditional insert must take effect when `real` is true and be invisible otherwise.
  #[test]
  fn test_insert_cond() {
    let mut m = UnsortedMap::<u32, u32>::new(8);
    assert_eq!(m.size, 0);

    // A dummy insertion neither grows the logical size nor makes the key visible.
    m.insert_cond(10, 100, false);
    assert_eq!(m.size, 0);
    let mut out = 0;
    assert!(!m.get(10, &mut out));

    // A real insertion stores the value and grows the size.
    m.insert_cond(10, 200, true);
    assert_eq!(m.size, 1);
    assert!(m.get(10, &mut out));
    assert_eq!(out, 200);

    // A later dummy insertion carrying another value leaves the stored value alone.
    m.insert_cond(10, 300, false);
    assert_eq!(m.size, 1);
    assert!(m.get(10, &mut out));
    assert_eq!(out, 200);
  }

  /// Generic smoke test: the default key is absent at first and present after one insert.
  fn test_map_subtypes<K, V>()
  where
    K: OHash + Default + std::fmt::Debug,
    V: Cmov + Pod + Default + std::fmt::Debug,
  {
    const SZ: usize = 1024;
    let mut m = UnsortedMap::<K, V>::new(SZ);
    assert_eq!(m.size, 0);

    let mut out = V::default();
    assert!(!m.get(K::default(), &mut out));

    m.insert(K::default(), V::default());
    assert_eq!(m.size, 1);
    assert!(m.get(K::default(), &mut out));
  }

  /// Runs the generic smoke test for several integer key/value widths.
  #[test]
  fn test_map_multiple_types() {
    // Unsigned widths.
    test_map_subtypes::<u32, u32>();
    test_map_subtypes::<u64, u64>();
    test_map_subtypes::<u128, u128>();
    // Signed widths.
    test_map_subtypes::<i32, i32>();
    test_map_subtypes::<i64, i64>();
    test_map_subtypes::<i128, i128>();
  }

  // UNDONE(git-34): Add further tests for the map.
}
386
387  // UNDONE(git-34): Add further tests for the map.
388}