rostl_datastructures/
sharded_map.rs

1//! Implements map related data structures.
2
3use ahash::RandomState;
4use bytemuck::{Pod, Zeroable};
5use rostl_primitives::{
6  cmov_body, cxchg_body, impl_cmov_for_generic_pod,
7  ooption::OOption,
8  traits::{Cmov, _Cmovbase},
9};
10use rostl_sort::{
11  bitonic::{bitonic_payload_sort, bitonic_sort},
12  compaction::{compact, compact_payload, distribute_payload},
13};
14
15use crate::map::{OHash, UnsortedMap};
16use kanal::{bounded, unbounded, Receiver, Sender};
17// use crossbeam::channel::{bounded, unbounded, Receiver, Sender};
18use std::{
19  io,
20  sync::{Arc, Barrier},
21  thread,
22};
23// use tracing::info;
24
/// Number of partitions the map is split into.
///
/// Every `ShardedMap` owns exactly this many workers (one thread per partition), and each
/// batch operation sends the same number of requests to every one of them.
const P: usize = 15;
27
28/// The replies from the worker thread to the main thread.
29enum Reply<K, V>
30where
31  K: OHash + Pod + Default + std::fmt::Debug + Ord,
32  V: Cmov + Pod + Default + std::fmt::Debug,
33  BatchBlock<K, V>: Ord + Send,
34{
35  Blocks { pid: usize, blocks: Vec<BatchBlock<K, V>> },
36  Unit(()),
37}
38
39enum Replyv2<V>
40where
41  V: Cmov + Pod + Default + std::fmt::Debug,
42{
43  Blocks { pid: usize, offset: usize, values: Vec<OOption<V>> },
44  Unit(()),
45}
46
47/// The command sent to the worker thread to perform a batch operation.
48///
49enum Cmd<K, V>
50where
51  K: OHash + Pod + Default + std::fmt::Debug + Ord,
52  V: Cmov + Pod + Default + std::fmt::Debug + Eq,
53  BatchBlock<K, V>: Ord + Send,
54{
55  /// Get a batch of blocks from the map.
56  Get {
57    blocks: Vec<BatchBlock<K, V>>,
58    ret_tx: Sender<Reply<K, V>>,
59  },
60  /// Insert a batch of blocks into the map.
61  Insert {
62    blocks: Vec<BatchBlock<K, V>>,
63    ret_tx: Sender<Reply<K, V>>,
64  },
65  Getv2 {
66    offset: usize,
67    blocks: Vec<K>,
68  },
69  Insertv2 {
70    blocks: Vec<KeyWithPartValue<K, V>>,
71  },
72  // Shutdown the worker thread.
73  Shutdown,
74}
75
76/// A worker is the thread that manages a partition of the map.
77/// Worker threads are kept hot while while there are new queries to process.
78#[derive(Debug)]
79struct Worker<K, V>
80where
81  K: OHash + Pod + Default + std::fmt::Debug + Ord,
82  V: Cmov + Pod + Default + std::fmt::Debug + Eq,
83  BatchBlock<K, V>: Ord + Send,
84{
85  tx: Sender<Cmd<K, V>>,
86  join_handle: Option<thread::JoinHandle<()>>,
87}
88
89#[allow(unused)]
90fn pin_current_thread_to(cpu: usize) -> io::Result<()> {
91  unsafe {
92    let mut set: libc::cpu_set_t = std::mem::zeroed();
93    libc::CPU_ZERO(&mut set);
94    libc::CPU_SET(cpu, &mut set);
95    let ret = libc::pthread_setaffinity_np(
96      libc::pthread_self(),
97      std::mem::size_of::<libc::cpu_set_t>(),
98      &raw const set,
99    );
100    if ret != 0 {
101      return Err(io::Error::from_raw_os_error(ret));
102    }
103  }
104  Ok(())
105}
106
107fn set_current_thread_rt(priority: i32) -> io::Result<()> {
108  unsafe {
109    // Check range with sched_get_priority_min/max(SCHED_FIFO)
110    let ret = libc::setpriority(libc::PRIO_PROCESS, 0, priority);
111    if ret != 0 {
112      eprintln!("setpriority failed: {}", std::io::Error::last_os_error());
113    }
114  }
115  Ok(())
116}
117
118impl<K, V> Worker<K, V>
119where
120  K: OHash + Pod + Default + std::fmt::Debug + Ord + Send,
121  V: Cmov + Pod + Default + std::fmt::Debug + Send + Eq,
122  BatchBlock<K, V>: Ord + Send,
123{
124  /// Creates a new worker partition `pid`, with max size `n`.
125  ///
126  fn new(
127    n: usize,
128    pid: usize,
129    startup_barrier: Arc<Barrier>,
130    reply_channel: Sender<Replyv2<V>>,
131  ) -> Self {
132    let (tx, rx): (Sender<Cmd<_, _>>, Receiver<_>) = unbounded();
133
134    let handler = thread::Builder::new()
135      .name(format!("partition-{pid}"))
136      .spawn(move || {
137        // pin thread to CPU
138        // pin_current_thread_to(pid).expect("failed to pin thread to CPU");
139        set_current_thread_rt(0).expect("failed to set thread to real-time priority");
140
141        // block until all workers are running
142        startup_barrier.wait();
143
144        // Thread local variables:
145        // Thread-local map for this worker:
146        //
147        let mut map = UnsortedMap::<K, V>::new(n);
148
149        loop {
150          let cmd = match rx.recv() {
151            Ok(cmd) => cmd,
152            Err(_) => {
153              panic!("worker thread command channel disconnected unexpectedly");
154            }
155          };
156
157          match cmd {
158            Cmd::Get { mut blocks, ret_tx } => {
159              for blk in &mut blocks {
160                blk.v = OOption::new(Default::default(), true);
161                blk.v.is_some = map.get(blk.k, &mut blk.v.value);
162              }
163              let _ = ret_tx.send(Reply::Blocks { pid, blocks }); // move blocks back
164            }
165            Cmd::Insert { blocks, ret_tx } => {
166              for blk in &blocks {
167                map.insert_cond(blk.k, blk.v.value, blk.v.is_some);
168              }
169              let _ = ret_tx.send(Reply::Unit(()));
170            }
171            Cmd::Getv2 { offset, blocks } => {
172              let mut values = vec![OOption::<V>::default(); blocks.len()];
173              for (i, k) in blocks.iter().enumerate() {
174                values[i].is_some = map.get(*k, &mut values[i].value);
175              }
176              let _ = reply_channel.send(Replyv2::Blocks { pid, offset, values });
177              // move blocks back
178            }
179            Cmd::Insertv2 { blocks } => {
180              for blk in &blocks {
181                let real = blk.partition == pid;
182                map.insert_cond(blk.key, blk.value, real);
183              }
184              let _ = reply_channel.send(Replyv2::Unit(()));
185            }
186            Cmd::Shutdown => break,
187          }
188        }
189      })
190      .expect("failed to spawn worker thread");
191
192    Self { tx, join_handle: Some(handler) }
193  }
194}
195
196impl<K, V> Drop for Worker<K, V>
197where
198  K: OHash + Pod + Default + std::fmt::Debug + Ord,
199  V: Cmov + Pod + Default + std::fmt::Debug + Eq,
200  BatchBlock<K, V>: Ord + Send,
201{
202  fn drop(&mut self) {
203    // Send a shutdown command to the worker thread.
204    let _ = self.tx.send(Cmd::Shutdown);
205    // Wait for the worker thread to finish.
206    match self.join_handle.take() {
207      Some(handle) => {
208        let _ = handle.join();
209      }
210      None => {
211        panic!("Exception while dropping worker thread, handler was already taken");
212      }
213    }
214  }
215}
216
217/// A sharded hashmap implementation.
218/// The map is split across multiple partitions and each partition is a separate hashmap.
219/// Queries are resolved in batches, to not leak the number of queries that go to each partition.
220/// # Parameters
221/// * `K`: The type of the keys in the map.
222/// * `V`: The type of the values in the map.
223/// * `P`: The number of partitions in the map.
224/// * `B`: The maximum number of non-distinct keys in any partition in a batch.
225#[derive(Debug)]
226pub struct ShardedMap<K, V>
227where
228  K: OHash + Pod + Default + std::fmt::Debug + Send + Ord,
229  V: Cmov + Pod + Default + std::fmt::Debug + Send + Eq,
230  BatchBlock<K, V>: Ord + Send,
231{
232  /// Number of elements in the map
233  size: usize,
234  /// capacity
235  capacity: usize,
236  /// The partitions of the map.
237  workers: [Worker<K, V>; P],
238  /// The random state used for hashing.
239  random_state: RandomState,
240  /// Channel for quickly receiving replies from worker threads.
241  response_channel: Receiver<Replyv2<V>>,
242}
243
244/// A block in a batch, that contains the key, the value and the index of the block in the original full batch.
245#[repr(C)]
246#[derive(Default, Debug, Clone, Copy, Zeroable, PartialEq, Eq, PartialOrd, Ord)]
247pub struct BatchBlock<K, V>
248where
249  K: OHash + Pod + Default + std::fmt::Debug + Ord,
250  V: Cmov + Pod + Default + std::fmt::Debug,
251{
252  index: usize,
253  k: K,
254  v: OOption<V>,
255}
256unsafe impl<K, V> Pod for BatchBlock<K, V>
257where
258  K: OHash + Pod + Default + std::fmt::Debug + Ord,
259  V: Cmov + Pod + Default + std::fmt::Debug,
260{
261}
262impl_cmov_for_generic_pod!(BatchBlock<K, V>; where K: OHash + Pod + Default + std::fmt::Debug + Ord, V: Cmov + Pod + Default + std::fmt::Debug);
263
264#[repr(C)]
265#[derive(Debug, Clone, Copy, Zeroable, PartialEq, Eq)]
266struct KeyWithPart<K>
267where
268  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
269{
270  partition: usize,
271  key: K,
272}
273unsafe impl<K> Pod for KeyWithPart<K> where K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized {}
274impl_cmov_for_generic_pod!(KeyWithPart<K>; where K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized);
275
276impl<K> KeyWithPart<K>
277where
278  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
279{
280  fn cmp_ct(&self, other: &Self) -> std::cmp::Ordering {
281    let part = self.partition.cmp(&other.partition) as i8;
282    let key = self.key.cmp(&other.key) as i8;
283
284    let mut res = part;
285    res.cmov(&key, part == 0);
286
287    res.cmp(&0)
288  }
289}
290
291impl<K> PartialOrd for KeyWithPart<K>
292where
293  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
294{
295  #[allow(clippy::non_canonical_partial_ord_impl)]
296  fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
297    Some(self.cmp_ct(other))
298  }
299}
300
301impl<K> Ord for KeyWithPart<K>
302where
303  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
304{
305  fn cmp(&self, other: &Self) -> std::cmp::Ordering {
306    self.cmp_ct(other)
307  }
308}
309
310#[repr(C)]
311#[derive(Debug, Clone, Copy, Zeroable, PartialEq, Eq)]
312struct KeyWithPartValue<K, V>
313where
314  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
315  V: Cmov + Pod + Default + std::fmt::Debug + Eq,
316{
317  partition: usize,
318  key: K,
319  value: V,
320}
321unsafe impl<K, V> Pod for KeyWithPartValue<K, V>
322where
323  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
324  V: Cmov + Pod + Default + std::fmt::Debug + Eq,
325{
326}
327impl_cmov_for_generic_pod!(KeyWithPartValue<K, V>; where K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized, V: Cmov + Pod + Default + std::fmt::Debug +Eq);
328
329impl<K, V> KeyWithPartValue<K, V>
330where
331  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
332  V: Cmov + Pod + Default + std::fmt::Debug + Eq,
333{
334  fn cmp_ct(&self, other: &Self) -> std::cmp::Ordering {
335    let part = self.partition.cmp(&other.partition) as i8;
336    let key = self.key.cmp(&other.key) as i8;
337
338    let mut res = part;
339    res.cmov(&key, part == 0);
340
341    res.cmp(&0)
342  }
343}
344
345impl<K, V> PartialOrd for KeyWithPartValue<K, V>
346where
347  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
348  V: Cmov + Pod + Default + std::fmt::Debug + Eq,
349{
350  #[allow(clippy::non_canonical_partial_ord_impl)]
351  fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
352    Some(self.cmp_ct(other))
353  }
354}
355
356impl<K, V> Ord for KeyWithPartValue<K, V>
357where
358  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
359  V: Cmov + Pod + Default + std::fmt::Debug + Eq,
360{
361  fn cmp(&self, other: &Self) -> std::cmp::Ordering {
362    self.cmp_ct(other)
363  }
364}
365
366impl<K, V> ShardedMap<K, V>
367where
368  K: OHash + Default + std::fmt::Debug + Send + Ord + Pod + Sized,
369  V: Cmov + Pod + Default + std::fmt::Debug + Send + Eq,
370  BatchBlock<K, V>: Ord + Send,
371{
372  /// Creates a new `ShardedMap` with the given number of partitions.
373  pub fn new(capacity: usize) -> Self {
374    let per_part = capacity.div_ceil(P);
375    let startup = Arc::new(Barrier::new(P + 1));
376
377    let (reply_tx, response_channel) = unbounded::<Replyv2<V>>();
378
379    let workers =
380      std::array::from_fn(|i| Worker::new(per_part, i, startup.clone(), reply_tx.clone()));
381
382    // wait until all workers have reached their barrier
383    startup.wait();
384
385    Self {
386      size: 0,
387      capacity: per_part * P,
388      workers,
389      random_state: RandomState::new(),
390      response_channel,
391    }
392  }
393
394  #[inline(always)]
395  fn get_partition(&self, key: &K) -> usize {
396    (self.random_state.hash_one(key) % P as u64) as usize
397  }
398
399  /// Computes a safe batch size B for a given number of distinct queries N.
400  /// # Preconditions
401  /// * N >= P log P
402  pub const fn compute_safe_batch_size(&self, n: usize) -> usize {
403    let a = n.div_ceil(P) + (n * (P.ilog2() as usize + 1)).div_ceil(P).isqrt() + 20; // Safety margin for small N
404    if a < n {
405      a
406    } else {
407      n
408    }
409  }
410
411  /// Reads N values from the map, leaking only `N` and `B`, but not any information about the keys (doesn't leak the number of keys to each partition).
412  /// # Preconditions
413  /// * No repeated keys in the input array.
414  /// * There are at most `b` queries to each partition.
415  pub fn get_batch_distinct(&mut self, keys: &[K], b: usize) -> Vec<OOption<V>> {
416    let n: usize = keys.len();
417    assert!(b <= n, "batch size b must be <= number of keys");
418
419    // 1. Create P arrays of size N.
420    // let mut per_p: [[BatchBlock<K, V>; N]; P] =
421    //   [unsafe { std::mem::MaybeUninit::<[BatchBlock<K, V>; N]>::uninit().assume_init() }; P];
422    let mut per_p: [Box<Vec<BatchBlock<K, V>>>; P] =
423      std::array::from_fn(|_| Box::new(vec![BatchBlock::default(); n]));
424
425    const INVALID_ID: usize = usize::MAX;
426
427    // 2. Map each key at index i to a partition: to p[h(keys[i])][i],
428    // UNDONE(git-65): this is O(P*n), we could do n log^2 n
429    for (i, k) in keys.iter().enumerate() {
430      let target_p = self.get_partition(k);
431      for (p, partition) in per_p.iter_mut().enumerate() {
432        partition[i].k = *k;
433        partition[i].index = i;
434        partition[i].index.cmov(&INVALID_ID, target_p != p);
435      }
436    }
437
438    // 3. Apply oblivious compaction to each partition.
439    for partition in &mut per_p {
440      let cnt = compact(partition, |x: &BatchBlock<K, V>| x.index == INVALID_ID);
441      // UNDONE(git-64): deal with overflow.
442      assert!(cnt <= b);
443    }
444
445    let (done_tx, done_rx) = bounded::<Reply<K, V>>(P);
446
447    // 4. Read the first B values from each partition in the corresponding partition.
448    for (p, partition) in per_p.iter_mut().enumerate() {
449      let blocks: Vec<BatchBlock<K, V>> = partition[..b].to_vec();
450      self.workers[p].tx.send(Cmd::Get { blocks, ret_tx: done_tx.clone() }).unwrap();
451    }
452
453    // 5. Collect the first B values from each partition into the results array.
454    let mut merged: Vec<BatchBlock<K, V>> = vec![BatchBlock::default(); P * b];
455
456    for _ in 0..P {
457      match done_rx.recv().unwrap() {
458        Reply::Blocks { pid, blocks } => {
459          for i in 0..b {
460            merged[pid * b + i] = blocks[i];
461          }
462        }
463        _ => panic!("unexpected reply from worker thread (probably early termination?)"),
464      }
465    }
466
467    // 6. Oblivious sort according to the index (we actually have P sorted arrays already, so we just need to merge them).
468    bitonic_sort(&mut merged);
469
470    // 7. Return the first n values from the results array.
471    let mut ret: Vec<OOption<V>> = vec![OOption::default(); n];
472
473    for i in 0..n {
474      ret[i] = merged[i].v;
475    }
476
477    ret
478  }
479
480  /// Reads N values from the map, leaking only `N` and `B`, but not any information about the keys (doesn't leak the number of keys to each partition).
481  /// # Preconditions
482  /// * There are at most `b` queries to each partition (statistically likely).
483  pub fn get_batch(&mut self, keys: &[K], b: usize) -> Vec<OOption<V>> {
484    // info!("get_batch called with n = {}, b = {}", keys.len(), b);
485    // let now = std::time::Instant::now();
486    let n: usize = keys.len();
487    assert!(n > 0, "get_batch requires at least one key");
488    assert!(b <= n, "batch size b must be <= number of keys");
489    let bp = b * P;
490    assert!(b >= n.div_ceil(P), "batch size b must be >= n/P to avoid overflow");
491    const SUBTASK_SIZE: usize = 32;
492
493    // 1. Sort the keys by partition.
494    let mut keyinfo = vec![KeyWithPart { partition: P, key: K::default() }; bp];
495    for i in 0..n {
496      keyinfo[i].key = keys[i];
497      keyinfo[i].partition = self.get_partition(&keys[i]);
498    }
499
500    let mut index_map_1 = (0..n).collect::<Vec<usize>>();
501    bitonic_payload_sort::<KeyWithPart<K>, [KeyWithPart<K>], usize>(
502      &mut keyinfo[..n],
503      &mut index_map_1,
504    );
505
506    // 2. Compute unique keys for each partition and remove duplicates to the end.
507    let mut par_load = [0; P];
508    let mut prefix_sum_1 = vec![0; n + 1];
509
510    prefix_sum_1[1] = 1;
511    for (j, load) in par_load.iter_mut().enumerate() {
512      let cond = keyinfo[0].partition == j;
513      load.cmov(&1, cond);
514    }
515    for i in 1..n {
516      let new_key = keyinfo[i].key != keyinfo[i - 1].key;
517      prefix_sum_1[i + 1] = prefix_sum_1[i];
518
519      let alt = prefix_sum_1[i] + 1;
520      prefix_sum_1[i + 1].cmov(&alt, new_key);
521
522      for (j, load) in par_load.iter_mut().enumerate() {
523        let cond = keyinfo[i].partition == j;
524        let alt = *load + 1;
525        load.cmov(&alt, cond & new_key);
526      }
527    }
528    // for i in n..(np + 1) {
529    //   prefix_sum_1[i] = prefix_sum_1[n];
530    // }
531
532    for (j, load) in par_load.iter().enumerate() {
533      assert!(*load <= b, "Too many distinct keys in partition {j}: {}, increase b", *load);
534    }
535
536    compact_payload(&mut keyinfo[..n], &prefix_sum_1);
537
538    // 3. Create a distribution of the unique keys to partitions.
539    let mut par_load_ps = [0; P + 1];
540    for j in 0..P {
541      par_load_ps[j + 1] = par_load_ps[j] + par_load[j];
542    }
543    let mut prefix_sum_2 = vec![0; bp + 1];
544    for j in 0..P {
545      for i in 0..b {
546        let mut rank_in_part = i + 1;
547        rank_in_part.cmov(&par_load[j], rank_in_part > par_load[j]);
548        prefix_sum_2[j * b + i + 1] = par_load_ps[j] + rank_in_part;
549      }
550    }
551    distribute_payload(&mut keyinfo, &prefix_sum_2);
552
553    // info!("get_batch preprocessing took {:?}", now.elapsed());
554    // let now = std::time::Instant::now();
555
556    let mut sent_count = 0;
557    // 4. Read the first B values from each partition in the corresponding partition.
558    for j in 0..P {
559      for k in 0..b.div_ceil(SUBTASK_SIZE) {
560        let offset = k * SUBTASK_SIZE;
561        let low = j * b + offset;
562        let high = (low + SUBTASK_SIZE).min((j + 1) * b);
563        let blocks: Vec<K> = keyinfo[low..high].iter().map(|x| x.key).collect();
564        self.workers[j].tx.send(Cmd::Getv2 { offset, blocks }).unwrap();
565        sent_count += 1;
566      }
567    }
568
569    let mut res = vec![OOption::<V>::default(); bp];
570
571    for _ in 0..sent_count {
572      match self.response_channel.recv().unwrap() {
573        Replyv2::Blocks { pid, offset, values } => {
574          for (val, res) in
575            values.iter().zip(res.iter_mut().skip(pid * b + offset)).take(SUBTASK_SIZE)
576          {
577            *res = *val;
578          }
579        }
580        _ => panic!("unexpected reply from worker thread (probably early termination?)"),
581      }
582    }
583    // info!("get_batch querying took {:?}", now.elapsed());
584    // let now = std::time::Instant::now();
585
586    // 5. Undo compaction and distribution of the results.
587    compact_payload(&mut res, &prefix_sum_2);
588    distribute_payload(&mut res[..n], &prefix_sum_1);
589
590    for i in 1..n {
591      let cond = prefix_sum_1[i] == prefix_sum_1[i - 1];
592      let copy = res[i - 1];
593      res[i].cmov(&copy, cond);
594    }
595
596    res.truncate(n);
597    bitonic_payload_sort(&mut index_map_1[..n], &mut res);
598
599    // info!("get_batch postprocessing took {:?}", now.elapsed());
600
601    res
602  }
603
604  /// Leaky version of `get_batch_distinct`, which will return values for repeated keys and leak the size of the largest partition.
605  /// # Safety
606  /// * This function will leak the size of the largest partition, which with repeated queries can be used to infer the mapping of keys to partitions.
607  #[deprecated(
608    note = "This function is unsafe because it can potentially leak information about keys to partition mapping. Use get_batch_distinct instead."
609  )]
610  pub unsafe fn get_batch_leaky(&mut self, keys: &[K]) -> Vec<OOption<V>> {
611    let n: usize = keys.len();
612    let mut b = 0;
613
614    // 1. Create P arrays of size N.
615    // let mut per_p: [[BatchBlock<K, V>; N]; P] =
616    //   [unsafe { std::mem::MaybeUninit::<[BatchBlock<K, V>; N]>::uninit().assume_init() }; P];
617    let mut per_p: [Box<Vec<BatchBlock<K, V>>>; P] =
618      std::array::from_fn(|_| Box::new(vec![BatchBlock::default(); n]));
619
620    const INVALID_ID: usize = usize::MAX;
621
622    // 2. Map each key at index i to a partition: to p[h(keys[i])][i],
623    // UNDONE(git-65): this is O(P*n), we could do n log^2 n
624    for (i, k) in keys.iter().enumerate() {
625      let target_p = self.get_partition(k);
626      for (p, partition) in per_p.iter_mut().enumerate() {
627        partition[i].k = *k;
628        partition[i].index = i;
629        partition[i].index.cmov(&INVALID_ID, target_p != p);
630      }
631    }
632
633    // 3. Apply oblivious compaction to each partition.
634    for partition in &mut per_p {
635      let cnt = compact(partition, |x: &BatchBlock<K, V>| x.index == INVALID_ID);
636      b = b.max(cnt);
637    }
638
639    let (done_tx, done_rx) = bounded::<Reply<K, V>>(P);
640
641    // 4. Read the first B values from each partition in the corresponding partition.
642    for (p, partition) in per_p.iter_mut().enumerate() {
643      let blocks: Vec<BatchBlock<K, V>> = partition[..b].to_vec();
644      self.workers[p].tx.send(Cmd::Get { blocks, ret_tx: done_tx.clone() }).unwrap();
645    }
646
647    // 5. Collect the first B values from each partition into the results array.
648    let mut merged: Vec<BatchBlock<K, V>> = vec![BatchBlock::default(); P * b];
649
650    for _ in 0..P {
651      match done_rx.recv().unwrap() {
652        Reply::Blocks { pid, blocks } => {
653          for i in 0..b {
654            merged[pid * b + i] = blocks[i];
655          }
656        }
657        _ => panic!("unexpected reply from worker thread (probably early termination?)"),
658      }
659    }
660
661    // 6. Oblivious sort according to the index (we actually have P sorted arrays already, so we just need to merge them).
662    bitonic_sort(&mut merged);
663
664    // 7. Return the first n values from the results array.
665    let mut ret: Vec<OOption<V>> = vec![OOption::default(); n];
666
667    for i in 0..n {
668      ret[i] = merged[i].v;
669    }
670
671    ret
672  }
673
674  /// Inserts a batch of N distinct key-value pairs into the map, distributing them across partitions.
675  ///
676  /// # Preconditions
677  /// * No repeated keys in the input array.
678  /// * All of the inserted keys are not already present in the map.
679  /// * There is enough space in the map to insert all `N` keys.
680  /// * There are at most `b` queries to each partition.
681  pub fn insert_batch_distinct(&mut self, keys: &[K], values: &[V], b: usize) {
682    let n = keys.len();
683    assert!(n == values.len(), "Invalid input: keys and values must have the same length");
684    assert!(self.size + n <= self.capacity, "Map is full, cannot insert more elements.");
685    assert!(b <= n, "batch size b must be <= number of keys");
686
687    // 1. Create P arrays of size N.
688    let mut per_p: [Box<Vec<BatchBlock<K, V>>>; P] =
689      std::array::from_fn(|_| Box::new(vec![BatchBlock::default(); n]));
690
691    const INVALID_ID: usize = usize::MAX;
692
693    // 2. Map each key at index i to a partition: to p[h(keys[i])][i],
694    // UNDONE(git-65): this is O(P*N), we could do N log^2 N
695    for (i, k) in keys.iter().enumerate() {
696      let target_p = self.get_partition(k);
697      for (p, partition) in per_p.iter_mut().enumerate() {
698        partition[i].k = *k;
699        partition[i].v = OOption::new(values[i], true);
700        partition[i].index = i;
701        partition[i].index.cmov(&INVALID_ID, target_p != p);
702      }
703    }
704
705    // 3. Apply oblivious compaction to each partition.
706    for partition in &mut per_p {
707      let cnt = compact(partition, |x| x.index == INVALID_ID);
708
709      // UNDONE(git-64): deal with overflow.
710      assert!(cnt <= b);
711    }
712
713    let (done_tx, done_rx) = bounded::<Reply<K, V>>(P);
714
715    // 4. Insert the first b values from each partition in the corresponding partition.
716    for (p, partition) in per_p.iter_mut().enumerate() {
717      let blocks: Vec<BatchBlock<K, V>> = partition[..b].to_vec();
718      self.workers[p].tx.send(Cmd::Insert { blocks, ret_tx: done_tx.clone() }).unwrap();
719    }
720
721    // 5. Receive the write receipts from the worker threads.
722    for _i in 0..P {
723      match done_rx.recv().unwrap() {
724        Reply::Unit(()) => {}
725        _ => {
726          panic!("unexpected reply from worker thread (probably early termination?)");
727        }
728      }
729    }
730
731    // 6. Update the size of the map.
732    self.size += n;
733  }
734
735  /// Reads N values from the map, leaking only `N` and `B`, but not any information about the keys (doesn't leak the number of keys to each partition).
736  /// # Preconditions
737  /// * There are at most `b` queries to each partition (statistically likely).
738  /// # Behavior
739  /// * If a key appears multiple times in the input array, only the value corresponding to its first occurrence is used.
740  pub fn insert_batch(&mut self, keys: &[K], values: &[V], b: usize) {
741    let n: usize = keys.len();
742    assert!(n > 0, "get_batch requires at least one key");
743    assert!(b <= n, "batch size b must be <= number of keys");
744    let bp = b * P;
745    assert!(b >= n.div_ceil(P), "batch size b must be >= n/P to avoid overflow");
746
747    // 1. Sort the keys by partition.
748    let mut keyinfo =
749      vec![KeyWithPartValue { partition: P, key: K::default(), value: V::default() }; bp];
750    for i in 0..n {
751      keyinfo[i].key = keys[i];
752      keyinfo[i].value = values[i];
753      keyinfo[i].partition = self.get_partition(&keys[i]);
754    }
755
756    let mut index_map_1 = (0..n).collect::<Vec<usize>>();
757    bitonic_payload_sort::<KeyWithPartValue<K, V>, [KeyWithPartValue<K, V>], usize>(
758      &mut keyinfo[..n],
759      &mut index_map_1,
760    );
761
762    // 2. Compute unique keys for each partition and remove duplicates to the end.
763    let mut par_load = [0; P];
764    let mut prefix_sum_1 = vec![0; n + 1];
765
766    prefix_sum_1[1] = 1;
767    for (j, load) in par_load.iter_mut().enumerate() {
768      let cond = keyinfo[0].partition == j;
769      load.cmov(&1, cond);
770    }
771    for i in 1..n {
772      let new_key = keyinfo[i].key != keyinfo[i - 1].key;
773      prefix_sum_1[i + 1] = prefix_sum_1[i];
774
775      let alt = prefix_sum_1[i] + 1;
776      prefix_sum_1[i + 1].cmov(&alt, new_key);
777      keyinfo[i].partition.cmov(&P, !new_key); // Mark duplicate keys as belonging to an invalid partition
778
779      for (j, load) in par_load.iter_mut().enumerate() {
780        let cond = keyinfo[i].partition == j;
781        let alt = *load + 1;
782        load.cmov(&alt, cond & new_key);
783      }
784    }
785    // for i in n..(np + 1) {
786    //   prefix_sum_1[i] = prefix_sum_1[n];
787    // }
788
789    for (j, load) in par_load.iter().enumerate() {
790      assert!(*load <= b, "Too many distinct keys in partition {j}: {}, increase b", *load);
791    }
792
793    compact_payload(&mut keyinfo[..n], &prefix_sum_1);
794
795    // 3. Create a distribution of the unique keys to partitions.
796    let mut par_load_ps = [0; P + 1];
797    for j in 0..P {
798      par_load_ps[j + 1] = par_load_ps[j] + par_load[j];
799    }
800    self.size += par_load_ps[P];
801
802    let mut prefix_sum_2 = vec![0; bp + 1];
803    for j in 0..P {
804      for i in 0..b {
805        let mut rank_in_part = i + 1;
806        rank_in_part.cmov(&par_load[j], rank_in_part > par_load[j]);
807        prefix_sum_2[j * b + i + 1] = par_load_ps[j] + rank_in_part;
808      }
809    }
810    distribute_payload(&mut keyinfo, &prefix_sum_2);
811
812    // 4. Read the first B values from each partition in the corresponding partition.
813    const SUBTASK_SIZE: usize = 32;
814    let mut sent_count = 0;
815    for j in 0..P {
816      for k in 0..b.div_ceil(SUBTASK_SIZE) {
817        let low = j * b + k * SUBTASK_SIZE;
818        let high = (low + SUBTASK_SIZE).min((j + 1) * b);
819        let blocks: Vec<KeyWithPartValue<K, V>> = keyinfo[low..high].to_vec();
820        self.workers[j].tx.send(Cmd::Insertv2 { blocks }).unwrap();
821        sent_count += 1;
822      }
823    }
824
825    for _ in 0..sent_count {
826      match self.response_channel.recv().unwrap() {
827        Replyv2::Unit(()) => {}
828        _ => panic!("unexpected reply from worker thread (probably early termination?)"),
829      }
830    }
831  }
832}
833
#[cfg(test)]
mod tests {
  use super::*;

  // Every test uses b == N, so the per-partition overflow assertions inside the map cannot fire.
  const N: usize = 4;

  /// Checks that every key in `keys` was found, with the matching entry of `expected`.
  fn assert_found(keys: &[u64], expected: &[u64], results: &[OOption<u64>]) {
    assert!(results.len() >= keys.len());
    for ((key, want), got) in keys.iter().zip(expected).zip(results) {
      assert!(got.is_some(), "key {} missing", key);
      assert_eq!(got.unwrap(), *want);
    }
  }

  /// Checks that no entry of `results` holds a value.
  fn assert_all_absent(results: &[OOption<u64>]) {
    for got in results {
      assert!(!got.is_some());
    }
  }

  #[test]
  fn new_map_rounds_capacity_and_starts_empty() {
    let requested = 100;
    let map: ShardedMap<u64, u64> = ShardedMap::new(requested);

    // Private fields are visible from this child module.
    let rounded_up = requested.div_ceil(P) * P;
    assert_eq!(map.capacity, rounded_up);
    assert_eq!(map.size, 0);
  }

  #[test]
  fn insert_batch_then_get_batch_returns_expected_values() {
    let mut map: ShardedMap<u64, u64> = ShardedMap::new(32);
    let keys: [u64; N] = [1, 2, 3, 4];
    let values: [u64; N] = [10, 20, 30, 40];

    map.insert_batch_distinct(&keys, &values, N);

    assert_found(&keys, &values, &map.get_batch_distinct(&keys, N));

    #[allow(deprecated)]
    {
      let leaky = unsafe { map.get_batch_leaky(&keys) };
      assert_found(&keys, &values, &leaky);
    }

    assert_found(&keys, &values, &map.get_batch(&keys, N));
  }

  #[test]
  fn querying_absent_keys_returns_none() {
    let mut map: ShardedMap<u64, u64> = ShardedMap::new(16);
    let absent: [u64; N] = [100, 200, 300, 400];

    assert_all_absent(&map.get_batch_distinct(&absent, N));

    #[allow(deprecated)]
    {
      let leaky = unsafe { map.get_batch_leaky(&absent) };
      assert_all_absent(&leaky);
    }

    assert_all_absent(&map.get_batch(&absent, N));
  }

  #[test]
  fn size_updates_after_insert() {
    let keys: [u64; N] = [11, 22, 33, 44];
    let values: [u64; N] = [111, 222, 333, 444];

    let mut distinct_map: ShardedMap<u64, u64> = ShardedMap::new(16);
    distinct_map.insert_batch_distinct(&keys, &values, N);
    assert_eq!(distinct_map.size, N);

    let mut batch_map: ShardedMap<u64, u64> = ShardedMap::new(16);
    batch_map.insert_batch(&keys, &values, N);
    assert_eq!(batch_map.size, N);
  }

  #[test]
  fn compute_safe_batch_size_works() {
    let map: ShardedMap<u64, u64> = ShardedMap::new(16);

    // The expected values below assume N >= P log P with P == 15.
    assert_eq!(P, 15);

    for (n, expected) in [(100, 32), (1000, 103), (4096, 327), (8192, 613)] {
      let b = map.compute_safe_batch_size(n);
      assert!(b >= n.div_ceil(P));
      assert_eq!(b, expected);
    }

    for (i, j) in (1..100).flat_map(|i| (1..100).map(move |j| (i, j))) {
      let n = i * j * P;
      let b = map.compute_safe_batch_size(n);
      assert!(b >= n.div_ceil(P));
      assert!(b <= n);
    }
  }
}