use ahash::RandomState;
use bytemuck::{Pod, Zeroable};
use rostl_primitives::{
  cmov_body, cxchg_body, impl_cmov_for_generic_pod,
  ooption::OOption,
  traits::{Cmov, _Cmovbase},
};
use rostl_sort::{
  bitonic::{bitonic_payload_sort, bitonic_sort},
  compaction::{compact, compact_payload, distribute_payload},
};

use crate::map::{OHash, UnsortedMap};
use kanal::{bounded, unbounded, Receiver, Sender};
use std::{
  io,
  sync::{Arc, Barrier},
  thread,
};
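/// Number of partitions; each partition is owned by a dedicated worker thread.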
const P: usize = 15;

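/// Reply sent back by a worker for the block-based `Get` and `Insert` commands.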
enum Reply<K, V>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug,
  BatchBlock<K, V>: Ord + Send,
{
  Blocks { pid: usize, blocks: Vec<BatchBlock<K, V>> },
  Unit(()),
}

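/// Reply sent back by a worker for the subtask-based `Getv2` and `Insertv2` commands.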
enum Replyv2<V>
where
  V: Cmov + Pod + Default + std::fmt::Debug,
{
  Blocks { pid: usize, offset: usize, values: Vec<OOption<V>> },
  Unit(()),
}

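/// Commands accepted by a partition worker thread.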
enum Cmd<K, V>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug + Eq,
  BatchBlock<K, V>: Ord + Send,
{
  Get {
    blocks: Vec<BatchBlock<K, V>>,
    ret_tx: Sender<Reply<K, V>>,
  },
  Insert {
    blocks: Vec<BatchBlock<K, V>>,
    ret_tx: Sender<Reply<K, V>>,
  },
  Getv2 {
    offset: usize,
    blocks: Vec<K>,
  },
  Insertv2 {
    blocks: Vec<KeyWithPartValue<K, V>>,
  },
  Shutdown,
}

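/// Handle to a partition worker: the command channel plus the thread's join handle.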
#[derive(Debug)]
struct Worker<K, V>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug + Eq,
  BatchBlock<K, V>: Ord + Send,
{
  tx: Sender<Cmd<K, V>>,
  join_handle: Option<thread::JoinHandle<()>>,
}

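/// Pins the calling thread to the given CPU core via `pthread_setaffinity_np`.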
#[allow(unused)]
fn pin_current_thread_to(cpu: usize) -> io::Result<()> {
  unsafe {
    let mut set: libc::cpu_set_t = std::mem::zeroed();
    libc::CPU_ZERO(&mut set);
    libc::CPU_SET(cpu, &mut set);
    let ret = libc::pthread_setaffinity_np(
      libc::pthread_self(),
      std::mem::size_of::<libc::cpu_set_t>(),
      &raw const set,
    );
    if ret != 0 {
      return Err(io::Error::from_raw_os_error(ret));
    }
  }
  Ok(())
}

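/// Adjusts the calling thread's scheduling priority (nice value) via `setpriority`.
/// Failures are logged rather than propagated.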
fn set_current_thread_rt(priority: i32) -> io::Result<()> {
  unsafe {
    let ret = libc::setpriority(libc::PRIO_PROCESS, 0, priority);
    if ret != 0 {
      eprintln!("setpriority failed: {}", std::io::Error::last_os_error());
    }
  }
  Ok(())
}

impl<K, V> Worker<K, V>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord + Send,
  V: Cmov + Pod + Default + std::fmt::Debug + Send + Eq,
  BatchBlock<K, V>: Ord + Send,
{
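  /// Spawns the worker thread for partition `pid`, backed by an `UnsortedMap` of
  /// capacity `n`, and returns a handle holding its command channel.
  /// `startup_barrier` lets the caller block until every worker has finished setup;
  /// subtask replies are sent through `reply_channel`.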
  fn new(
    n: usize,
    pid: usize,
    startup_barrier: Arc<Barrier>,
    reply_channel: Sender<Replyv2<V>>,
  ) -> Self {
    let (tx, rx): (Sender<Cmd<_, _>>, Receiver<_>) = unbounded();

    let handler = thread::Builder::new()
      .name(format!("partition-{pid}"))
      .spawn(move || {
        set_current_thread_rt(0).expect("failed to set thread to real-time priority");

        startup_barrier.wait();

        let mut map = UnsortedMap::<K, V>::new(n);

        loop {
          let cmd = match rx.recv() {
            Ok(cmd) => cmd,
            Err(_) => {
              panic!("worker thread command channel disconnected unexpectedly");
            }
          };

          match cmd {
            Cmd::Get { mut blocks, ret_tx } => {
              // Every block (real or padding) is looked up unconditionally;
              // `is_some` records whether the key was actually present.
              for blk in &mut blocks {
                blk.v = OOption::new(Default::default(), true);
                blk.v.is_some = map.get(blk.k, &mut blk.v.value);
              }
              let _ = ret_tx.send(Reply::Blocks { pid, blocks });
            }
            Cmd::Insert { blocks, ret_tx } => {
              for blk in &blocks {
                map.insert_cond(blk.k, blk.v.value, blk.v.is_some);
              }
              let _ = ret_tx.send(Reply::Unit(()));
            }
            Cmd::Getv2 { offset, blocks } => {
              let mut values = vec![OOption::<V>::default(); blocks.len()];
              for (i, k) in blocks.iter().enumerate() {
                values[i].is_some = map.get(*k, &mut values[i].value);
              }
              let _ = reply_channel.send(Replyv2::Blocks { pid, offset, values });
            }
            Cmd::Insertv2 { blocks } => {
              // Only blocks addressed to this partition are real; the rest are padding
              // and are skipped by the conditional insert.
              for blk in &blocks {
                let real = blk.partition == pid;
                map.insert_cond(blk.key, blk.value, real);
              }
              let _ = reply_channel.send(Replyv2::Unit(()));
            }
            Cmd::Shutdown => break,
          }
        }
      })
      .expect("failed to spawn worker thread");

    Self { tx, join_handle: Some(handler) }
  }
}

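// Dropping a `Worker` shuts down and joins its thread.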
impl<K, V> Drop for Worker<K, V>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug + Eq,
  BatchBlock<K, V>: Ord + Send,
{
  fn drop(&mut self) {
    let _ = self.tx.send(Cmd::Shutdown);
    match self.join_handle.take() {
      Some(handle) => {
        let _ = handle.join();
      }
      None => {
        panic!("Exception while dropping worker thread, join handle was already taken");
      }
    }
  }
}

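/// A hash map sharded across `P` partitions, each owned by a dedicated worker thread.
/// Keys are routed to partitions with a keyed hash; batched operations pad every
/// partition's request to the same size `b` using oblivious sorting and compaction,
/// so the amount of work sent to each worker does not depend on how the requested
/// keys are distributed across partitions.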
#[derive(Debug)]
pub struct ShardedMap<K, V>
where
  K: OHash + Pod + Default + std::fmt::Debug + Send + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug + Send + Eq,
  BatchBlock<K, V>: Ord + Send,
{
  /// Number of elements currently stored.
  size: usize,
  /// Total capacity across all partitions (a multiple of `P`).
  capacity: usize,
  /// One worker thread per partition.
  workers: [Worker<K, V>; P],
  /// Keyed hasher used to assign keys to partitions.
  random_state: RandomState,
  /// Channel on which workers return `Replyv2` messages.
  response_channel: Receiver<Replyv2<V>>,
}

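/// One slot of a batched request: the position of the key in the caller's slice,
/// the key itself, and the (optional) value filled in by the worker.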
#[repr(C)]
#[derive(Default, Debug, Clone, Copy, Zeroable, PartialEq, Eq, PartialOrd, Ord)]
pub struct BatchBlock<K, V>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug,
{
  index: usize,
  k: K,
  v: OOption<V>,
}
unsafe impl<K, V> Pod for BatchBlock<K, V>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug,
{
}
impl_cmov_for_generic_pod!(BatchBlock<K, V>; where K: OHash + Pod + Default + std::fmt::Debug + Ord, V: Cmov + Pod + Default + std::fmt::Debug);

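/// A key tagged with its target partition; used by `get_batch` to route keys obliviously.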
#[repr(C)]
#[derive(Debug, Clone, Copy, Zeroable, PartialEq, Eq)]
struct KeyWithPart<K>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
{
  partition: usize,
  key: K,
}
unsafe impl<K> Pod for KeyWithPart<K> where K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized {}
impl_cmov_for_generic_pod!(KeyWithPart<K>; where K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized);

impl<K> KeyWithPart<K>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
{
  /// Branchless comparison: orders by `partition` first and breaks ties by `key`,
  /// selecting the tie-break result with a `cmov` instead of a data-dependent branch.
  fn cmp_ct(&self, other: &Self) -> std::cmp::Ordering {
    let part = self.partition.cmp(&other.partition) as i8;
    let key = self.key.cmp(&other.key) as i8;

    let mut res = part;
    res.cmov(&key, part == 0);

    res.cmp(&0)
  }
}

impl<K> PartialOrd for KeyWithPart<K>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
{
  #[allow(clippy::non_canonical_partial_ord_impl)]
  fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
    Some(self.cmp_ct(other))
  }
}

impl<K> Ord for KeyWithPart<K>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
{
  fn cmp(&self, other: &Self) -> std::cmp::Ordering {
    self.cmp_ct(other)
  }
}

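/// A key/value pair tagged with its target partition; used by `insert_batch`.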
#[repr(C)]
#[derive(Debug, Clone, Copy, Zeroable, PartialEq, Eq)]
struct KeyWithPartValue<K, V>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
  V: Cmov + Pod + Default + std::fmt::Debug + Eq,
{
  partition: usize,
  key: K,
  value: V,
}
unsafe impl<K, V> Pod for KeyWithPartValue<K, V>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
  V: Cmov + Pod + Default + std::fmt::Debug + Eq,
{
}
impl_cmov_for_generic_pod!(KeyWithPartValue<K, V>; where K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized, V: Cmov + Pod + Default + std::fmt::Debug + Eq);

impl<K, V> KeyWithPartValue<K, V>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
  V: Cmov + Pod + Default + std::fmt::Debug + Eq,
{
  /// Branchless comparison: orders by `partition` first and breaks ties by `key`,
  /// mirroring `KeyWithPart::cmp_ct`. Values do not participate in the ordering.
  fn cmp_ct(&self, other: &Self) -> std::cmp::Ordering {
    let part = self.partition.cmp(&other.partition) as i8;
    let key = self.key.cmp(&other.key) as i8;

    let mut res = part;
    res.cmov(&key, part == 0);

    res.cmp(&0)
  }
}

impl<K, V> PartialOrd for KeyWithPartValue<K, V>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
  V: Cmov + Pod + Default + std::fmt::Debug + Eq,
{
  #[allow(clippy::non_canonical_partial_ord_impl)]
  fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
    Some(self.cmp_ct(other))
  }
}

impl<K, V> Ord for KeyWithPartValue<K, V>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord + Sized,
  V: Cmov + Pod + Default + std::fmt::Debug + Eq,
{
  fn cmp(&self, other: &Self) -> std::cmp::Ordering {
    self.cmp_ct(other)
  }
}

impl<K, V> ShardedMap<K, V>
where
  K: OHash + Default + std::fmt::Debug + Send + Ord + Pod + Sized,
  V: Cmov + Pod + Default + std::fmt::Debug + Send + Eq,
  BatchBlock<K, V>: Ord + Send,
{
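  /// Creates a map with at least `capacity` slots, rounded up to a multiple of `P`,
  /// and spawns one worker thread per partition.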
  pub fn new(capacity: usize) -> Self {
    let per_part = capacity.div_ceil(P);
    let startup = Arc::new(Barrier::new(P + 1));

    let (reply_tx, response_channel) = unbounded::<Replyv2<V>>();

    let workers =
      std::array::from_fn(|i| Worker::new(per_part, i, startup.clone(), reply_tx.clone()));

    // Wait until every worker thread has finished its setup.
    startup.wait();

    Self {
      size: 0,
      capacity: per_part * P,
      workers,
      random_state: RandomState::new(),
      response_channel,
    }
  }

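  /// Maps a key to its partition using the map's keyed hasher.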
  #[inline(always)]
  fn get_partition(&self, key: &K) -> usize {
    (self.random_state.hash_one(key) % P as u64) as usize
  }

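  /// Suggests a per-partition batch size `b` for a batch of `n` keys:
  /// `ceil(n / P) + isqrt(ceil(n * (log2(P) + 1) / P)) + 20`, capped at `n`.
  /// The slack above the mean `n / P` makes it unlikely that any single partition
  /// receives more than `b` of the batch's keys.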
  pub const fn compute_safe_batch_size(&self, n: usize) -> usize {
    let a = n.div_ceil(P) + (n * (P.ilog2() as usize + 1)).div_ceil(P).isqrt() + 20;
    if a < n {
      a
    } else {
      n
    }
  }

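  /// Batched lookup for distinct keys, returning one `OOption<V>` per key in input order.
  ///
  /// Every key is written into all `P` per-partition buffers and marked invalid in
  /// all but its real partition, each buffer is obliviously compacted and truncated
  /// to exactly `b` blocks, and the `P * b` replies are merged and bitonic-sorted by
  /// original index to restore the input order.
  /// Panics if more than `b` of the keys hash to the same partition.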
  pub fn get_batch_distinct(&mut self, keys: &[K], b: usize) -> Vec<OOption<V>> {
    let n: usize = keys.len();
    assert!(b <= n, "batch size b must be <= number of keys");

    let mut per_p: [Box<Vec<BatchBlock<K, V>>>; P] =
      std::array::from_fn(|_| Box::new(vec![BatchBlock::default(); n]));

    const INVALID_ID: usize = usize::MAX;

    // Write every key into every partition's buffer, but keep the real index only
    // in the key's target partition; everywhere else the block is marked invalid.
    for (i, k) in keys.iter().enumerate() {
      let target_p = self.get_partition(k);
      for (p, partition) in per_p.iter_mut().enumerate() {
        partition[i].k = *k;
        partition[i].index = i;
        partition[i].index.cmov(&INVALID_ID, target_p != p);
      }
    }

    // Obliviously move each partition's real blocks to the front; each partition
    // must hold at most b of them.
    for partition in &mut per_p {
      let cnt = compact(partition, |x: &BatchBlock<K, V>| x.index == INVALID_ID);
      assert!(cnt <= b);
    }

    let (done_tx, done_rx) = bounded::<Reply<K, V>>(P);

    // Send exactly b blocks (real ones plus padding) to every worker.
    for (p, partition) in per_p.iter_mut().enumerate() {
      let blocks: Vec<BatchBlock<K, V>> = partition[..b].to_vec();
      self.workers[p].tx.send(Cmd::Get { blocks, ret_tx: done_tx.clone() }).unwrap();
    }

    let mut merged: Vec<BatchBlock<K, V>> = vec![BatchBlock::default(); P * b];

    for _ in 0..P {
      match done_rx.recv().unwrap() {
        Reply::Blocks { pid, blocks } => {
          for i in 0..b {
            merged[pid * b + i] = blocks[i];
          }
        }
        _ => panic!("unexpected reply from worker thread (probably early termination?)"),
      }
    }

    // Sorting by original index moves the n real replies to the front in input order;
    // padding blocks (index == INVALID_ID) sink to the back.
    bitonic_sort(&mut merged);

    let mut ret: Vec<OOption<V>> = vec![OOption::default(); n];

    for i in 0..n {
      ret[i] = merged[i].v;
    }

    ret
  }

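  /// Batched lookup that tolerates repeated keys, returning one `OOption<V>` per key
  /// in input order.
  ///
  /// Keys are tagged with their partition, bitonic-sorted, and deduplicated with an
  /// oblivious prefix-sum compaction; the distinct keys are then distributed so each
  /// partition receives exactly `b` slots, which are streamed to the workers in
  /// fixed-size subtasks. The replies are compacted back, duplicate positions are
  /// refilled with `cmov`, and a final payload sort restores the caller's order.
  /// Panics if any partition ends up with more than `b` distinct keys.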
  pub fn get_batch(&mut self, keys: &[K], b: usize) -> Vec<OOption<V>> {
    let n: usize = keys.len();
    assert!(n > 0, "get_batch requires at least one key");
    assert!(b <= n, "batch size b must be <= number of keys");
    let bp = b * P;
    assert!(b >= n.div_ceil(P), "batch size b must be >= n/P to avoid overflow");
    const SUBTASK_SIZE: usize = 32;

    // Tag each key with its target partition; the padding tail uses the invalid
    // partition index P.
    let mut keyinfo = vec![KeyWithPart { partition: P, key: K::default() }; bp];
    for i in 0..n {
      keyinfo[i].key = keys[i];
      keyinfo[i].partition = self.get_partition(&keys[i]);
    }

    // Sort by (partition, key) and remember the permutation so the original order
    // can be restored at the end.
    let mut index_map_1 = (0..n).collect::<Vec<usize>>();
    bitonic_payload_sort::<KeyWithPart<K>, [KeyWithPart<K>], usize>(
      &mut keyinfo[..n],
      &mut index_map_1,
    );

    // prefix_sum_1 ranks the distinct keys; par_load counts distinct keys per partition.
    let mut par_load = [0; P];
    let mut prefix_sum_1 = vec![0; n + 1];

    prefix_sum_1[1] = 1;
    for (j, load) in par_load.iter_mut().enumerate() {
      let cond = keyinfo[0].partition == j;
      load.cmov(&1, cond);
    }
    for i in 1..n {
      let new_key = keyinfo[i].key != keyinfo[i - 1].key;
      prefix_sum_1[i + 1] = prefix_sum_1[i];

      let alt = prefix_sum_1[i] + 1;
      prefix_sum_1[i + 1].cmov(&alt, new_key);

      for (j, load) in par_load.iter_mut().enumerate() {
        let cond = keyinfo[i].partition == j;
        let alt = *load + 1;
        load.cmov(&alt, cond & new_key);
      }
    }
    for (j, load) in par_load.iter().enumerate() {
      assert!(*load <= b, "Too many distinct keys in partition {j}: {}, increase b", *load);
    }

    // Deduplicate: keep one slot per distinct key.
    compact_payload(&mut keyinfo[..n], &prefix_sum_1);

    // Spread the distinct keys so partition j owns slots [j * b, (j + 1) * b).
    let mut par_load_ps = [0; P + 1];
    for j in 0..P {
      par_load_ps[j + 1] = par_load_ps[j] + par_load[j];
    }
    let mut prefix_sum_2 = vec![0; bp + 1];
    for j in 0..P {
      for i in 0..b {
        let mut rank_in_part = i + 1;
        rank_in_part.cmov(&par_load[j], rank_in_part > par_load[j]);
        prefix_sum_2[j * b + i + 1] = par_load_ps[j] + rank_in_part;
      }
    }
    distribute_payload(&mut keyinfo, &prefix_sum_2);

    // Stream each partition's b slots to its worker in fixed-size subtasks.
    let mut sent_count = 0;
    for j in 0..P {
      for k in 0..b.div_ceil(SUBTASK_SIZE) {
        let offset = k * SUBTASK_SIZE;
        let low = j * b + offset;
        let high = (low + SUBTASK_SIZE).min((j + 1) * b);
        let blocks: Vec<K> = keyinfo[low..high].iter().map(|x| x.key).collect();
        self.workers[j].tx.send(Cmd::Getv2 { offset, blocks }).unwrap();
        sent_count += 1;
      }
    }

    let mut res = vec![OOption::<V>::default(); bp];

    for _ in 0..sent_count {
      match self.response_channel.recv().unwrap() {
        Replyv2::Blocks { pid, offset, values } => {
          for (val, res) in
            values.iter().zip(res.iter_mut().skip(pid * b + offset)).take(SUBTASK_SIZE)
          {
            *res = *val;
          }
        }
        _ => panic!("unexpected reply from worker thread (probably early termination?)"),
      }
    }
    // Undo the distribution and the deduplication, then copy each value into the
    // positions of its duplicate keys.
    compact_payload(&mut res, &prefix_sum_2);
    distribute_payload(&mut res[..n], &prefix_sum_1);

    for i in 1..n {
      let cond = prefix_sum_1[i] == prefix_sum_1[i - 1];
      let copy = res[i - 1];
      res[i].cmov(&copy, cond);
    }

    // Restore the caller's original key order.
    res.truncate(n);
    bitonic_payload_sort(&mut index_map_1[..n], &mut res);

    res
  }

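  /// Like `get_batch_distinct`, but sizes every partition's request by the largest
  /// per-partition key count instead of a caller-supplied `b`, so the observable
  /// amount of per-partition work reveals how the keys map to partitions.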
  #[deprecated(
    note = "This function is unsafe because it can potentially leak information about keys to partition mapping. Use get_batch_distinct instead."
  )]
  pub unsafe fn get_batch_leaky(&mut self, keys: &[K]) -> Vec<OOption<V>> {
    let n: usize = keys.len();
    let mut b = 0;

    let mut per_p: [Box<Vec<BatchBlock<K, V>>>; P] =
      std::array::from_fn(|_| Box::new(vec![BatchBlock::default(); n]));

    const INVALID_ID: usize = usize::MAX;

    for (i, k) in keys.iter().enumerate() {
      let target_p = self.get_partition(k);
      for (p, partition) in per_p.iter_mut().enumerate() {
        partition[i].k = *k;
        partition[i].index = i;
        partition[i].index.cmov(&INVALID_ID, target_p != p);
      }
    }

    // b becomes the largest number of real keys in any partition (this is what leaks).
    for partition in &mut per_p {
      let cnt = compact(partition, |x: &BatchBlock<K, V>| x.index == INVALID_ID);
      b = b.max(cnt);
    }

    let (done_tx, done_rx) = bounded::<Reply<K, V>>(P);

    for (p, partition) in per_p.iter_mut().enumerate() {
      let blocks: Vec<BatchBlock<K, V>> = partition[..b].to_vec();
      self.workers[p].tx.send(Cmd::Get { blocks, ret_tx: done_tx.clone() }).unwrap();
    }

    let mut merged: Vec<BatchBlock<K, V>> = vec![BatchBlock::default(); P * b];

    for _ in 0..P {
      match done_rx.recv().unwrap() {
        Reply::Blocks { pid, blocks } => {
          for i in 0..b {
            merged[pid * b + i] = blocks[i];
          }
        }
        _ => panic!("unexpected reply from worker thread (probably early termination?)"),
      }
    }

    bitonic_sort(&mut merged);

    let mut ret: Vec<OOption<V>> = vec![OOption::default(); n];

    for i in 0..n {
      ret[i] = merged[i].v;
    }

    ret
  }

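  /// Batched insert for distinct keys, using the same pad-and-compact routing as
  /// `get_batch_distinct`: every worker receives exactly `b` blocks, and only the
  /// real ones are conditionally inserted.
  /// Panics if the map would overflow or if more than `b` keys hash to one partition.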
  pub fn insert_batch_distinct(&mut self, keys: &[K], values: &[V], b: usize) {
    let n = keys.len();
    assert!(n == values.len(), "Invalid input: keys and values must have the same length");
    assert!(self.size + n <= self.capacity, "Map is full, cannot insert more elements.");
    assert!(b <= n, "batch size b must be <= number of keys");

    let mut per_p: [Box<Vec<BatchBlock<K, V>>>; P] =
      std::array::from_fn(|_| Box::new(vec![BatchBlock::default(); n]));

    const INVALID_ID: usize = usize::MAX;

    // Route each key/value pair to its partition; other partitions get invalid copies.
    for (i, k) in keys.iter().enumerate() {
      let target_p = self.get_partition(k);
      for (p, partition) in per_p.iter_mut().enumerate() {
        partition[i].k = *k;
        partition[i].v = OOption::new(values[i], true);
        partition[i].index = i;
        partition[i].index.cmov(&INVALID_ID, target_p != p);
      }
    }

    for partition in &mut per_p {
      let cnt = compact(partition, |x| x.index == INVALID_ID);

      assert!(cnt <= b);
    }

    let (done_tx, done_rx) = bounded::<Reply<K, V>>(P);

    for (p, partition) in per_p.iter_mut().enumerate() {
      let blocks: Vec<BatchBlock<K, V>> = partition[..b].to_vec();
      self.workers[p].tx.send(Cmd::Insert { blocks, ret_tx: done_tx.clone() }).unwrap();
    }

    for _i in 0..P {
      match done_rx.recv().unwrap() {
        Reply::Unit(()) => {}
        _ => {
          panic!("unexpected reply from worker thread (probably early termination?)");
        }
      }
    }

    self.size += n;
  }

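  /// Batched insert that tolerates repeated keys: duplicates are detected after an
  /// oblivious sort and retagged with the invalid partition index `P`, so exactly one
  /// occurrence of each key is inserted and `size` grows by the number of distinct keys.
  /// Panics if any partition would receive more than `b` distinct keys.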
  pub fn insert_batch(&mut self, keys: &[K], values: &[V], b: usize) {
    let n: usize = keys.len();
    assert!(n > 0, "insert_batch requires at least one key");
    assert!(b <= n, "batch size b must be <= number of keys");
    let bp = b * P;
    assert!(b >= n.div_ceil(P), "batch size b must be >= n/P to avoid overflow");

    // Tag each key/value pair with its target partition; the padding tail uses the
    // invalid partition index P.
    let mut keyinfo =
      vec![KeyWithPartValue { partition: P, key: K::default(), value: V::default() }; bp];
    for i in 0..n {
      keyinfo[i].key = keys[i];
      keyinfo[i].value = values[i];
      keyinfo[i].partition = self.get_partition(&keys[i]);
    }

    let mut index_map_1 = (0..n).collect::<Vec<usize>>();
    bitonic_payload_sort::<KeyWithPartValue<K, V>, [KeyWithPartValue<K, V>], usize>(
      &mut keyinfo[..n],
      &mut index_map_1,
    );

    // prefix_sum_1 ranks the distinct keys; par_load counts distinct keys per partition.
    let mut par_load = [0; P];
    let mut prefix_sum_1 = vec![0; n + 1];

    prefix_sum_1[1] = 1;
    for (j, load) in par_load.iter_mut().enumerate() {
      let cond = keyinfo[0].partition == j;
      load.cmov(&1, cond);
    }
    for i in 1..n {
      let new_key = keyinfo[i].key != keyinfo[i - 1].key;
      prefix_sum_1[i + 1] = prefix_sum_1[i];

      let alt = prefix_sum_1[i] + 1;
      prefix_sum_1[i + 1].cmov(&alt, new_key);

      // Duplicate keys are retagged with the invalid partition P so no worker inserts them.
      keyinfo[i].partition.cmov(&P, !new_key);

      for (j, load) in par_load.iter_mut().enumerate() {
        let cond = keyinfo[i].partition == j;
        let alt = *load + 1;
        load.cmov(&alt, cond & new_key);
      }
    }
    for (j, load) in par_load.iter().enumerate() {
      assert!(*load <= b, "Too many distinct keys in partition {j}: {}, increase b", *load);
    }

    compact_payload(&mut keyinfo[..n], &prefix_sum_1);

    let mut par_load_ps = [0; P + 1];
    for j in 0..P {
      par_load_ps[j + 1] = par_load_ps[j] + par_load[j];
    }
    // Only distinct keys are actually stored.
    self.size += par_load_ps[P];

    // Spread the distinct pairs so partition j owns slots [j * b, (j + 1) * b).
    let mut prefix_sum_2 = vec![0; bp + 1];
    for j in 0..P {
      for i in 0..b {
        let mut rank_in_part = i + 1;
        rank_in_part.cmov(&par_load[j], rank_in_part > par_load[j]);
        prefix_sum_2[j * b + i + 1] = par_load_ps[j] + rank_in_part;
      }
    }
    distribute_payload(&mut keyinfo, &prefix_sum_2);

    const SUBTASK_SIZE: usize = 32;
    let mut sent_count = 0;
    for j in 0..P {
      for k in 0..b.div_ceil(SUBTASK_SIZE) {
        let low = j * b + k * SUBTASK_SIZE;
        let high = (low + SUBTASK_SIZE).min((j + 1) * b);
        let blocks: Vec<KeyWithPartValue<K, V>> = keyinfo[low..high].to_vec();
        self.workers[j].tx.send(Cmd::Insertv2 { blocks }).unwrap();
        sent_count += 1;
      }
    }

    for _ in 0..sent_count {
      match self.response_channel.recv().unwrap() {
        Replyv2::Unit(()) => {}
        _ => panic!("unexpected reply from worker thread (probably early termination?)"),
      }
    }
  }
}

#[cfg(test)]
mod tests {
  use super::*;

  const N: usize = 4;

  #[test]
  fn new_map_rounds_capacity_and_starts_empty() {
    let requested = 100;
    let map: ShardedMap<u64, u64> = ShardedMap::new(requested);

    let per_part = requested.div_ceil(P);
    assert_eq!(map.capacity, per_part * P);
    assert_eq!(map.size, 0);
  }

  #[test]
  fn insert_batch_then_get_batch_returns_expected_values() {
    let mut map: ShardedMap<u64, u64> = ShardedMap::new(32);

    let keys: [u64; N] = [1, 2, 3, 4];
    let values: [u64; N] = [10, 20, 30, 40];

    map.insert_batch_distinct(&keys, &values, N);

    let results = map.get_batch_distinct(&keys, N);
    for i in 0..N {
      assert!(results[i].is_some(), "key {} missing", keys[i]);
      assert_eq!(results[i].unwrap(), values[i]);
    }

    #[allow(deprecated)]
    {
      let results = unsafe { map.get_batch_leaky(&keys) };
      for i in 0..N {
        assert!(results[i].is_some(), "key {} missing", keys[i]);
        assert_eq!(results[i].unwrap(), values[i]);
      }
    }

    let results = map.get_batch(&keys, N);
    for i in 0..N {
      assert!(results[i].is_some(), "key {} missing", keys[i]);
      assert_eq!(results[i].unwrap(), values[i]);
    }
  }

  #[test]
  fn querying_absent_keys_returns_none() {
    let mut map: ShardedMap<u64, u64> = ShardedMap::new(16);

    let absent: [u64; N] = [100, 200, 300, 400];
    let results = map.get_batch_distinct(&absent, N);

    for r in &results {
      assert!(!r.is_some());
    }

    #[allow(deprecated)]
    {
      let results = unsafe { map.get_batch_leaky(&absent) };
      for r in &results {
        assert!(!r.is_some());
      }
    }

    let absent: [u64; N] = [100, 200, 300, 400];
    let results = map.get_batch(&absent, N);

    for r in &results {
      assert!(!r.is_some());
    }
  }

  #[test]
  fn size_updates_after_insert() {
    let mut map: ShardedMap<u64, u64> = ShardedMap::new(16);

    let keys: [u64; N] = [11, 22, 33, 44];
    let values: [u64; N] = [111, 222, 333, 444];

    map.insert_batch_distinct(&keys, &values, N);
    assert_eq!(map.size, N);

    let mut map: ShardedMap<u64, u64> = ShardedMap::new(16);
    map.insert_batch(&keys, &values, N);
    assert_eq!(map.size, N);
  }

  #[test]
  fn compute_safe_batch_size_works() {
    let map: ShardedMap<u64, u64> = ShardedMap::new(16);

    assert_eq!(P, 15);

    let n = 100;
    let b = map.compute_safe_batch_size(n);
    assert!(b >= n.div_ceil(P));
    assert_eq!(b, 32);

    let n = 1000;
    let b = map.compute_safe_batch_size(n);
    assert!(b >= n.div_ceil(P));
    assert_eq!(b, 103);

    let n = 4096;
    let b = map.compute_safe_batch_size(n);
    assert!(b >= n.div_ceil(P));
    assert_eq!(b, 327);

    let n = 8192;
    let b = map.compute_safe_batch_size(n);
    assert!(b >= n.div_ceil(P));
    assert_eq!(b, 613);

    for i in 1..100 {
      for j in 1..100 {
        let n = i * j * P;
        let b = map.compute_safe_batch_size(n);
        assert!(b >= n.div_ceil(P));
        assert!(b <= n);
      }
    }
  }
}