rostl_datastructures/
map.rs

1//! Implements map related data structures.
2
3use ahash::RandomState;
4use bytemuck::{Pod, Zeroable};
5use rand::{rngs::ThreadRng, Rng};
6use rostl_primitives::{
7  cmov_body, cxchg_body, impl_cmov_for_generic_pod,
8  traits::{Cmov, _Cmovbase},
9};
10
11use seq_macro::seq;
12
13use crate::{array::DynamicArray, queue::ShortQueue};
14
/// Maximum number of elements the insertion (deamortization) queue can hold while their
/// cuckoo insertion has not completed yet.
const INSERTION_QUEUE_MAX_SIZE: usize = 20;
/// Number of deamortized insertion steps performed on every `insert` call.
const DEAMORTIZED_INSERTIONS: usize = 2;
/// Number of element slots in each map bucket.
const BUCKET_SIZE: usize = 2;
21
22use std::hash::Hash;
23
24/// Utility trait for types that can be uses as keys in the map.
25pub trait OHash: Cmov + Pod + Hash + PartialEq {}
26// UNDONE(git-20): Hashmaps with boolean key aren't supported until boolean implements pod.s
27impl<K> OHash for K where K: Cmov + Pod + Hash + PartialEq {}
28
29/// An element in the map.
30#[repr(align(16))]
31#[derive(Debug, Default, Clone, Copy, Zeroable)]
32pub struct InlineElement<K, V>
33where
34  K: OHash,
35  V: Cmov + Pod,
36{
37  key: K,
38  value: V,
39}
40unsafe impl<K: OHash, V: Cmov + Pod> Pod for InlineElement<K, V> {}
41impl_cmov_for_generic_pod!(InlineElement<K,V>; where K: OHash, V: Cmov + Pod);
42
43#[derive(Debug, Default, Clone, Copy, Zeroable)]
44/// A struct that represents an element in a bucket.
45pub struct BucketElement<K, V>
46where
47  K: OHash,
48  V: Cmov + Pod,
49{
50  is_valid: bool,
51  element: InlineElement<K, V>,
52}
53unsafe impl<K: OHash, V: Cmov + Pod> Pod for BucketElement<K, V> {}
54impl_cmov_for_generic_pod!(BucketElement<K,V>; where K: OHash, V: Cmov + Pod);
55
56impl<K, V> BucketElement<K, V>
57where
58  K: OHash,
59  V: Cmov + Pod,
60{
61  #[inline(always)]
62  const fn is_empty(&self) -> bool {
63    !self.is_valid
64  }
65}
66
67/// A bucket in the map.
68/// The bucket has `BUCKET_SIZE` elements.
69/// # Invariants
70/// * The elements in the bucket that have `is_valid == true` are non empty.
71/// * The elements in the bucket that have `is_valid == false` are empty.
72/// * No two valid elements have the same key.
73#[repr(align(64))]
74#[derive(Debug, Default, Clone, Copy, Zeroable)]
75struct Bucket<K, V>
76where
77  K: OHash,
78  V: Cmov + Pod,
79{
80  elements: [BucketElement<K, V>; BUCKET_SIZE],
81}
82unsafe impl<K: OHash, V: Cmov + Pod> Pod for Bucket<K, V> {}
83impl_cmov_for_generic_pod!(Bucket<K,V>; where K: OHash, V: Cmov + Pod);
84
85impl<K, V> Bucket<K, V>
86where
87  K: OHash,
88  V: Cmov + Pod,
89{
90  /// Replaces the value of a Key, if:
91  ///  1) `real`
92  ///  2) the bucket has a valid element with the same key as `element`
93  /// # Returns
94  /// * `true` - if a replacement happened (including if it was with the same value)
95  /// * `false` - otherwise (including if the element is empty)
96  fn update_if_exists(&mut self, real: bool, element: InlineElement<K, V>) -> bool {
97    let mut updated = false;
98    for i in 0..BUCKET_SIZE {
99      let element_i = &mut self.elements[i];
100      let choice = real & !element_i.is_empty() & (element_i.element.key == element.key);
101      element_i.element.value.cmov(&element.value, choice);
102      updated.cmov(&true, choice);
103    }
104    updated
105  }
106
107  fn read_if_exists(&self, key: K, ret: &mut V) -> bool {
108    let mut found = false;
109    for i in 0..BUCKET_SIZE {
110      let element_i = &self.elements[i];
111      let choice = !element_i.is_empty() & (element_i.element.key == key);
112      ret.cmov(&element_i.element.value, choice);
113      found.cmov(&true, choice);
114    }
115    found
116  }
117
118  /// Insert an element into the bucket uf it has an empty slot, otherwise does nothing.
119  /// # Preconditions
120  /// * real ==> The same key isn't in the bucket.
121  /// # Returns
122  /// * `true` - if the element was inserted or if `real == false`
123  /// * `false` - otherwise
124  fn insert_if_available(&mut self, real: bool, element: InlineElement<K, V>) -> bool {
125    let mut inserted = !real;
126    for i in 0..BUCKET_SIZE {
127      let element_i = &mut self.elements[i];
128      let choice = !inserted & element_i.is_empty();
129      element_i.is_valid.cmov(&true, choice);
130      element_i.element.cmov(&element, choice);
131      inserted.cmov(&true, choice);
132    }
133    inserted
134  }
135}
136
137/// An unsorted map that is oblivious to the access pattern.
138/// The map uses cuckoo hashing with size-2 buckets, two tables and a deamortization queue.
139/// `INSERTION_QUEUE_MAX_SIZE` is the maximum size of the deamortization queue.
140/// `DEAMORTIZED_INSERTIONS` is the number of deamortized insertions to perform per insert call.
141/// # Invariants
142/// * A key appears at most once in a valid element in between the two tables and the insertion queue.
143/// * The two tables have the same capacity.
144/// * The two tables have a different hash functions (different keys on the keyed hash function).
145#[derive(Debug)]
146pub struct UnsortedMap<K, V>
147where
148  K: OHash + Default + std::fmt::Debug,
149  V: Cmov + Pod + Default + std::fmt::Debug,
150{
151  /// Number of elements in the map
152  size: usize,
153  /// capacity
154  capacity: usize,
155  /// The two tables
156  table: [DynamicArray<Bucket<K, V>>; 2],
157  /// The hasher used to hash keys
158  hash_builders: [RandomState; 2],
159  /// The insertion queue
160  insertion_queue: ShortQueue<InlineElement<K, V>, INSERTION_QUEUE_MAX_SIZE>,
161  /// Random source for random indices
162  rng: ThreadRng,
163}
164
165impl<K, V> UnsortedMap<K, V>
166where
167  K: OHash + Default + std::fmt::Debug,
168  V: Cmov + Pod + Default + std::fmt::Debug,
169{
170  /// Creates a new `UnsortedMap` with the given capacity `n`.
171  pub fn new(capacity: usize) -> Self {
172    debug_assert!(capacity > 0);
173    Self {
174      size: 0,
175      capacity,
176      table: [DynamicArray::new(capacity), DynamicArray::new(capacity)],
177      hash_builders: [RandomState::new(), RandomState::new()],
178      insertion_queue: ShortQueue::new(),
179      rng: rand::rng(),
180    }
181  }
182
183  #[inline(always)]
184  fn hash_key<const TABLE: usize>(&self, key: &K) -> usize {
185    (self.hash_builders[TABLE].hash_one(key) % self.capacity as u64) as usize
186  }
187
188  /// Tries to get a value from the map.
189  /// # Returns
190  /// * `true` if the key was found
191  /// * `false` if the key wasn't found
192  /// # Postconditions
193  /// * If the key was found, the value is written to `ret`
194  /// * If the key wasn't found, `ret` is not modified
195  pub fn get(&mut self, key: K, ret: &mut V) -> bool {
196    let mut found = false;
197    let mut tmp: Bucket<K, V> = Default::default();
198
199    // Tries to get the element from each table:
200    // seq! does manual loop unrolling in rust. We need it to be able to use the constant INDEX in the hash_key function.
201    seq!(INDEX in 0..2 {
202      let hash = self.hash_key::<INDEX>(&key);
203      let table = &mut self.table[INDEX];
204      table.read(hash, &mut tmp);
205      let found_local = tmp.read_if_exists(key, ret);
206      found.cmov(&true, found_local);
207    });
208
209    // Tries to get the element from the deamortization queue:
210    for i in 0..self.insertion_queue.size {
211      let element = self.insertion_queue.elements.data[i];
212      let found_local = !element.is_empty() & (element.value.key == key);
213      ret.cmov(&element.value.value, found_local);
214      found.cmov(&true, found_local);
215    }
216
217    found
218  }
219
220  /// Tries to insert an element into some of the hash tables, in case of collisions, the element is replaced.
221  /// # Returns
222  /// * `true` if it was possible to insert into an empty slot or `real == false`.
223  /// * `false` if the element was replaced and therefore the new element value needs to be inserted into the insertion queue.
224  fn try_insert_entry(&mut self, real: bool, element: &mut InlineElement<K, V>) -> bool {
225    let mut done = !real;
226
227    seq!(INDEX_REV in 0..2 {{
228      #[allow(clippy::identity_op, clippy::eq_op)] // False positives due to the seq! macro.
229      const INDEX: usize = 1 - INDEX_REV;
230
231      let hash = self.hash_key::<INDEX>(&element.key);
232      let table = &mut self.table[INDEX];
233      table.update(hash, |bucket| {
234        let choice = !done;
235        let inserted = bucket.insert_if_available(choice, *element);
236        done.cmov(&true, inserted);
237        let randidx = self.rng.random_range(0..BUCKET_SIZE);
238        bucket.elements[randidx].element.cxchg(element, !done);
239      });
240    }});
241    done
242  }
243
244  /// Deamortizes the insertion queue by trying to insert elements into the tables.
245  pub fn deamortize_insertion_queue(&mut self) {
246    for _ in 0..DEAMORTIZED_INSERTIONS {
247      let mut element = InlineElement::default();
248      let real = self.insertion_queue.size > 0;
249
250      // Use FIFO order so we don't get stuck in a loop in the random graph of cuckoo hashing.
251      self.insertion_queue.maybe_pop(real, &mut element);
252      let has_pending_element = !self.try_insert_entry(real, &mut element);
253      self.insertion_queue.maybe_push(has_pending_element, element);
254    }
255  }
256
257  /// Inserts an elementn into the map. If the insertion doesn't finish, the removed element is inserted into the insertion queue.
258  /// # Preconditions
259  /// * The key is not in the map already.
260  pub fn insert(&mut self, key: K, value: V) {
261    // UNDONE(git-32): Recover in case the insertion queue is full.
262    assert!(self.insertion_queue.size < INSERTION_QUEUE_MAX_SIZE);
263    self.insertion_queue.maybe_push(true, InlineElement { key, value });
264    self.deamortize_insertion_queue();
265    self.size.cmov(&(self.size + 1), true);
266  }
267
268  /// Updates a value that is already in the map.
269  /// # Preconditions
270  /// * The key is in the map.
271  /// # Returns
272  /// * `true` if the value was updated
273  pub fn write(&mut self, key: K, value: V) {
274    let mut updated = false;
275
276    // Tries to get the element from each table:
277    // seq! does manual loop unrolling in rust. We need it to be able to use the constant INDEX in the hash_key function.
278    seq!(INDEX in 0..2 {
279      let hash = self.hash_key::<INDEX>(&key);
280      let table = &mut self.table[INDEX];
281      table.update(hash, |bucket| {
282        let choice = !updated;
283        let updated_local = bucket.update_if_exists(choice, InlineElement { key, value });
284        updated.cmov(&true, updated_local);
285      });
286    });
287
288    // Tries to get the element from the deamortization queue:
289    for i in 0..self.insertion_queue.size {
290      let element = &mut self.insertion_queue.elements.data[i];
291      let choice = !updated & !element.is_empty() & (element.value.key == key);
292      element.value.value.cmov(&value, choice);
293      updated.cmov(&true, choice);
294    }
295
296    assert!(updated);
297  }
298
299  // UNDONE(git-33): add efficient upsert function (inserts if not present, updates if present)
300  // UNDONE(git-33): add delete function
301}
302
303// UNDONE(git-35): Add benchmarks for the map.
304
#[cfg(test)]
mod tests {
  use super::*;

  /// Exercises the basic insert / get / write cycle on a tiny map.
  #[test]
  fn test_unsorted_map() {
    let mut map: UnsortedMap<u32, u32> = UnsortedMap::new(2);
    let mut out: u32 = 0;

    // A fresh map is empty and finds nothing.
    assert_eq!(map.size, 0);
    assert!(!map.get(1, &mut out));

    // After an insertion the key is visible with its value.
    map.insert(1, 2);
    assert_eq!(map.size, 1);
    assert!(map.get(1, &mut out));
    assert_eq!(out, 2);

    // Writing replaces the stored value.
    map.write(1, 3);
    assert!(map.get(1, &mut out));
    assert_eq!(out, 3);
  }

  // UNDONE(git-34): Add further tests for the map.
}