// rostl_datastructures/sharded_map.rs

1//! Implements map related data structures.
2
3use ahash::RandomState;
4use bytemuck::{Pod, Zeroable};
5use rostl_primitives::{
6  cmov_body, cxchg_body, impl_cmov_for_generic_pod,
7  ooption::OOption,
8  traits::{Cmov, _Cmovbase},
9};
10use rostl_sort::{bitonic::bitonic_sort, compaction::compact};
11
12use crate::map::{OHash, UnsortedMap};
13use kanal::{bounded, Receiver, Sender};
14use std::{
15  sync::{Arc, Barrier},
16  thread,
17};
18
// Number of partitions in the map.
// Each partition is owned by a dedicated worker thread; batched queries always
// touch all P partitions so per-partition access counts are not leaked.
const P: usize = 15;
21
/// The replies from the worker thread to the main thread.
enum Reply<K, V, const B: usize>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug,
  BatchBlock<K, V>: Ord + Send,
{
  /// Response to a `Cmd::Get`: the same `B` blocks with values filled in.
  /// `pid` identifies the partition so the caller can place the blocks back at
  /// offset `pid * B` in the merged batch (replies may arrive in any order).
  Blocks { pid: usize, blocks: Box<[BatchBlock<K, V>; B]> },
  /// Response to a `Cmd::Insert`: a bare acknowledgement (write receipt).
  Unit(()),
}
32
/// The command sent to the worker thread to perform a batch operation.
///
enum Cmd<K, V, const B: usize>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug,
  BatchBlock<K, V>: Ord + Send,
{
  /// Get a batch of blocks from the map.
  /// The worker looks up every block's key, fills the value in place, and
  /// sends the blocks back as `Reply::Blocks` on `ret_tx`.
  Get {
    blocks: Box<[BatchBlock<K, V>; B]>,
    ret_tx: Sender<Reply<K, V, B>>,
  },
  /// Insert a batch of blocks into the map.
  /// The worker inserts every block and acknowledges with `Reply::Unit` on `ret_tx`.
  Insert {
    blocks: Box<[BatchBlock<K, V>; B]>,
    ret_tx: Sender<Reply<K, V, B>>,
  },
  /// Shutdown the worker thread.
  Shutdown,
}
54
/// A worker is the thread that manages a partition of the map.
/// Worker threads are kept hot while there are new queries to process.
#[derive(Debug)]
struct Worker<K, V, const B: usize>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug,
  BatchBlock<K, V>: Ord + Send,
{
  /// Command channel to the worker thread (see `Cmd`).
  tx: Sender<Cmd<K, V, B>>,
  /// Join handle for the worker thread; taken exactly once when dropping.
  join_handle: Option<thread::JoinHandle<()>>,
}
67
68impl<K, V, const B: usize> Worker<K, V, B>
69where
70  K: OHash + Pod + Default + std::fmt::Debug + Ord + Send,
71  V: Cmov + Pod + Default + std::fmt::Debug + Send,
72  BatchBlock<K, V>: Ord + Send,
73{
74  /// Creates a new worker partition `pid`, with max size `n`.
75  ///
76  fn new(n: usize, pid: usize, startup_barrier: Arc<Barrier>) -> Self {
77    // Note: this bound is a bit arbitrary, 2 is enough for the map as it is implemented now.
78    let (tx, rx): (Sender<Cmd<_, _, B>>, Receiver<_>) = bounded(10);
79
80    let handler = thread::Builder::new()
81      .name(format!("partition-{pid}"))
82      .spawn(move || {
83        // block until all workers are running
84        startup_barrier.wait();
85
86        // Thread local variables:
87        // Thread-local map for this worker:
88        //
89        let mut map = UnsortedMap::<K, V>::new(n);
90
91        loop {
92          let cmd = match rx.recv() {
93            Ok(cmd) => cmd,
94            Err(_) => {
95              panic!("worker thread command channel disconnected unexpectedly");
96            }
97          };
98
99          match cmd {
100            Cmd::Get { mut blocks, ret_tx } => {
101              // println!("worker {pid} received get command with {} blocks", blocks.len());
102              for blk in blocks.iter_mut() {
103                blk.v = OOption::new(Default::default(), true);
104                blk.v.is_some = map.get(blk.k, &mut blk.v.value);
105              }
106              let _ = ret_tx.send(Reply::Blocks { pid, blocks }); // move blocks back
107            }
108            Cmd::Insert { blocks, ret_tx } => {
109              // println!("worker {pid} received insert command with {} blocks", blocks.len());
110              for blk in blocks.iter() {
111                map.insert(blk.k, blk.v.unwrap());
112              }
113              let _ = ret_tx.send(Reply::Unit(()));
114            }
115            Cmd::Shutdown => {
116              // We don't need to do anything here, the worker will exit.
117              // println!("worker {pid} received shutdown command, exiting");
118              break;
119            }
120          }
121        }
122      })
123      .expect("failed to spawn worker thread");
124
125    Self { tx, join_handle: Some(handler) }
126  }
127}
128
129impl<K, V, const B: usize> Drop for Worker<K, V, B>
130where
131  K: OHash + Pod + Default + std::fmt::Debug + Ord,
132  V: Cmov + Pod + Default + std::fmt::Debug,
133  BatchBlock<K, V>: Ord + Send,
134{
135  fn drop(&mut self) {
136    // Send a shutdown command to the worker thread.
137    let _ = self.tx.send(Cmd::Shutdown);
138    // Wait for the worker thread to finish.
139    match self.join_handle.take() {
140      Some(handle) => {
141        let _ = handle.join();
142      }
143      None => {
144        panic!("Exception while dropping worker thread, handler was already taken");
145      }
146    }
147  }
148}
149
/// A sharded hashmap implementation.
/// The map is split across multiple partitions and each partition is a separate hashmap.
/// Queries are resolved in batches, to not leak the number of queries that go to each partition.
/// # Parameters
/// * `K`: The type of the keys in the map.
/// * `V`: The type of the values in the map.
/// * `P`: The number of partitions in the map (module-level constant).
/// * `B`: The maximum number of non-distinct keys in any partition in a batch.
#[derive(Debug)]
pub struct ShardedMap<K, V, const B: usize>
where
  K: OHash + Pod + Default + std::fmt::Debug + Send + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug + Send,
  BatchBlock<K, V>: Ord + Send,
{
  /// Number of elements in the map
  size: usize,
  /// capacity (rounded up to a multiple of `P` by `new`)
  capacity: usize,
  /// The partitions of the map; one hot worker thread each.
  workers: [Worker<K, V, B>; P],
  /// The random state used for hashing keys to partitions.
  random_state: RandomState,
}
174
/// A block in a batch, that contains the key, the value and the index of the block in the original full batch.
#[repr(C)]
#[derive(Default, Debug, Clone, Copy, Zeroable, PartialEq, Eq, PartialOrd, Ord)]
pub struct BatchBlock<K, V>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug,
{
  // Position of this block in the original batch; `usize::MAX` marks the
  // block as invalid/padding. Declared first so the derived `Ord` compares it
  // first: sorting restores batch order and sinks invalid blocks to the end.
  index: usize,
  // The key to look up / insert.
  k: K,
  // The (possibly absent) value associated with `k`.
  v: OOption<V>,
}
// SAFETY: NOTE(review) — `Pod` requires every bit pattern to be valid AND no
// padding bytes. The fields are all `Pod` and the struct is `repr(C)`, but
// `repr(C)` can still insert padding between `index`, `k`, and `v` depending
// on the concrete `K`/`V` — confirm the intended instantiations are
// padding-free (and that `OOption<V>` is itself `Pod`).
unsafe impl<K, V> Pod for BatchBlock<K, V>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug,
{
}
impl_cmov_for_generic_pod!(BatchBlock<K, V>; where K: OHash + Pod + Default + std::fmt::Debug + Ord, V: Cmov + Pod + Default + std::fmt::Debug);
194
195impl<K, V, const B: usize> ShardedMap<K, V, B>
196where
197  K: OHash + Default + std::fmt::Debug + Send + Ord,
198  V: Cmov + Pod + Default + std::fmt::Debug + Send,
199  BatchBlock<K, V>: Ord + Send,
200{
201  /// Creates a new `ShardedMap` with the given number of partitions.
202  pub fn new(capacity: usize) -> Self {
203    let per_part = capacity.div_ceil(P);
204    let startup = Arc::new(Barrier::new(P + 1));
205
206    let workers = std::array::from_fn(|i| Worker::new(per_part, i, startup.clone()));
207
208    // wait until all workers have reached their barrier
209    startup.wait();
210
211    Self { size: 0, capacity: per_part * P, workers, random_state: RandomState::new() }
212  }
213
214  #[inline(always)]
215  fn get_partition(&self, key: &K) -> usize {
216    (self.random_state.hash_one(key) % P as u64) as usize
217  }
218
219  /// Reads N values from the map, leaking only `N` and `B`, but not any information about the keys (doesn't leak the number of keys to each partition).
220  /// # Preconditions
221  /// * No repeated keys in the input array.
222  pub fn get_batch_distinct<const N: usize>(&mut self, keys: &[K; N]) -> [OOption<V>; N] {
223    // 1. Create P arrays of size N.
224    // let mut per_p: [[BatchBlock<K, V>; N]; P] =
225    //   [unsafe { std::mem::MaybeUninit::<[BatchBlock<K, V>; N]>::uninit().assume_init() }; P];
226    let mut per_p: [Box<[BatchBlock<K, V>; N]>; P] =
227      std::array::from_fn(|_| Box::new([BatchBlock::default(); N]));
228
229    const INVALID_ID: usize = usize::MAX;
230
231    // 2. Map each key at index i to a partition: to p[h(keys[i])][i],
232    // UNDONE(git-65): this is O(P*N), we could do N log^2 N
233    for (i, k) in keys.iter().enumerate() {
234      let target_p = self.get_partition(k);
235      for (p, partition) in per_p.iter_mut().enumerate() {
236        partition[i].k = *k;
237        partition[i].index = i;
238        partition[i].index.cmov(&INVALID_ID, target_p != p);
239      }
240    }
241
242    // 3. Apply oblivious compaction to each partition.
243    for partition in &mut per_p {
244      let cnt = compact(&mut **partition, |x: &BatchBlock<K, V>| x.index == INVALID_ID);
245      // UNDONE(git-64): deal with overflow.
246      assert!(cnt <= B);
247    }
248
249    let (done_tx, done_rx) = bounded::<Reply<K, V, B>>(P);
250
251    // 4. Read the first B values from each partition in the corresponding partition.
252    for (p, partition) in per_p.iter_mut().enumerate() {
253      let blocks = {
254        let mut new_blocks = [BatchBlock::default(); B];
255        new_blocks.copy_from_slice(&partition[..B]);
256        Box::new(new_blocks)
257      };
258      self.workers[p].tx.send(Cmd::Get { blocks, ret_tx: done_tx.clone() }).unwrap();
259    }
260
261    // 5. Collect the first B values from each partition into the results array.
262    let mut merged: Vec<BatchBlock<K, V>> = vec![BatchBlock::default(); P * B];
263
264    for _ in 0..P {
265      match done_rx.recv().unwrap() {
266        Reply::Blocks { pid, blocks } => {
267          for b in 0..B {
268            merged[pid * B + b] = blocks[b];
269          }
270        }
271        _ => {
272          panic!("unexpected reply from worker thread (probably early termination?)");
273        }
274      }
275    }
276
277    // 6. Oblivious sort according to the index (we actually have P sorted arrays already, so we just need to merge them).
278    bitonic_sort(&mut merged);
279
280    // 7. Return the first N values from the results array.
281    let mut ret: [OOption<V>; N] = [(); N].map(|_| OOption::default());
282    for i in 0..N {
283      ret[i] = merged[i].v;
284    }
285
286    ret
287  }
288
289  /// Inserts a batch of N distinct key-value pairs into the map, distributing them across partitions.
290  ///
291  /// # Preconditions
292  /// * No repeated keys in the input array.
293  /// * All of the inserted keys are not already present in the map.
294  /// * There is enough space in the map to insert all N keys.
295  pub fn insert_batch_distinct<const N: usize>(&mut self, keys: &[K; N], values: &[V; N]) {
296    assert!(self.size + N <= self.capacity, "Map is full, cannot insert more elements.");
297    // 1. Create P arrays of size N.
298    let mut per_p: [Box<[BatchBlock<K, V>; B]>; P] =
299      std::array::from_fn(|_| Box::new([BatchBlock::default(); B]));
300
301    const INVALID_ID: usize = usize::MAX;
302
303    // 2. Map each key at index i to a partition: to p[h(keys[i])][i],
304    // UNDONE(git-65): this is O(P*N), we could do N log^2 N
305    for (i, k) in keys.iter().enumerate() {
306      let target_p = self.get_partition(k);
307      for (p, partition) in per_p.iter_mut().enumerate() {
308        partition[i].k = *k;
309        partition[i].v = OOption::new(values[i], true);
310        partition[i].index = i;
311        partition[i].index.cmov(&INVALID_ID, target_p != p);
312      }
313    }
314
315    // 3. Apply oblivious compaction to each partition.
316    for partition in &mut per_p {
317      let cnt = compact(&mut **partition, |x| x.index == INVALID_ID);
318      // UNDONE(git-64): deal with overflow.
319      assert!(cnt <= B);
320    }
321
322    let (done_tx, done_rx) = bounded::<Reply<K, V, B>>(P);
323
324    // 4. Insert the first B values from each partition in the corresponding partition.
325    for (p, partition) in per_p.iter_mut().enumerate() {
326      let blocks = std::mem::replace(partition, Box::new([BatchBlock::default(); B]));
327      self.workers[p].tx.send(Cmd::Insert { blocks, ret_tx: done_tx.clone() }).unwrap();
328    }
329
330    // 5. Receive the write receipts from the worker threads.
331    for _i in 0..P {
332      match done_rx.recv().unwrap() {
333        Reply::Unit(()) => {}
334        _ => {
335          panic!("unexpected reply from worker thread (probably early termination?)");
336        }
337      }
338    }
339
340    // 6. Update the size of the map.
341    self.size += N;
342  }
343}
344
#[cfg(test)]
mod tests {
  use super::*;

  // Keep B == N throughout so the per-partition overflow `assert!` inside the
  // map can never fire.
  const N: usize = 4;
  const B: usize = N;

  #[test]
  fn new_map_rounds_capacity_and_starts_empty() {
    let requested = 100;
    let map: ShardedMap<u64, u64, B> = ShardedMap::new(requested);

    // Private fields are visible from inside the same module.
    let expected_capacity = requested.div_ceil(P) * P; // rounded up
    assert_eq!(map.capacity, expected_capacity);
    assert_eq!(map.size, 0);
  }

  #[test]
  fn insert_batch_then_get_batch_returns_expected_values() {
    let mut map: ShardedMap<u64, u64, B> = ShardedMap::new(32);

    let keys: [u64; N] = [1, 2, 3, 4];
    let values: [u64; N] = [10, 20, 30, 40];

    map.insert_batch_distinct::<N>(&keys, &values);

    let results = map.get_batch_distinct::<N>(&keys);
    for ((result, key), value) in results.into_iter().zip(keys).zip(values) {
      assert!(result.is_some(), "key {} missing", key);
      assert_eq!(result.unwrap(), value);
    }
  }

  #[test]
  fn querying_absent_keys_returns_none() {
    let mut map: ShardedMap<u64, u64, B> = ShardedMap::new(16);

    let absent: [u64; N] = [100, 200, 300, 400];
    for result in map.get_batch_distinct::<N>(&absent) {
      assert!(!result.is_some());
    }
  }

  #[test]
  fn size_updates_after_insert() {
    let mut map: ShardedMap<u64, u64, B> = ShardedMap::new(16);

    let keys: [u64; N] = [11, 22, 33, 44];
    let values: [u64; N] = [111, 222, 333, 444];

    map.insert_batch_distinct::<N>(&keys, &values);
    assert_eq!(map.size, N);
  }
}
403}