use ahash::RandomState;
use bytemuck::{Pod, Zeroable};
use rostl_primitives::{
  cmov_body, cxchg_body, impl_cmov_for_generic_pod,
  ooption::OOption,
  traits::{Cmov, _Cmovbase},
};
use rostl_sort::{bitonic::bitonic_sort, compaction::compact};

use crate::map::{OHash, UnsortedMap};
use kanal::{bounded, Receiver, Sender};
use std::{
  sync::{Arc, Barrier},
  thread,
};

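/// Number of partitions the map is sharded into; each partition is served by its own worker thread.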
const P: usize = 15;

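/// Reply sent back from a worker thread to the caller of a batch operation.
/// `Blocks` returns the processed request blocks for a `Get` batch; `Unit` acknowledges an `Insert` batch.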
enum Reply<K, V, const B: usize>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug,
  BatchBlock<K, V>: Ord + Send,
{
  Blocks { pid: usize, blocks: Box<[BatchBlock<K, V>; B]> },
  Unit(()),
}

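/// Command sent from `ShardedMap` to a worker thread. `Get` and `Insert` carry a fixed-size
/// block of requests plus a reply channel; `Shutdown` asks the worker to exit.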
enum Cmd<K, V, const B: usize>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug,
  BatchBlock<K, V>: Ord + Send,
{
  Get {
    blocks: Box<[BatchBlock<K, V>; B]>,
    ret_tx: Sender<Reply<K, V, B>>,
  },
  Insert {
    blocks: Box<[BatchBlock<K, V>; B]>,
    ret_tx: Sender<Reply<K, V, B>>,
  },
  Shutdown,
}

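/// Handle to one partition's worker thread: the command channel plus the join handle used on drop.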
#[derive(Debug)]
struct Worker<K, V, const B: usize>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug,
  BatchBlock<K, V>: Ord + Send,
{
  tx: Sender<Cmd<K, V, B>>,
  join_handle: Option<thread::JoinHandle<()>>,
}

impl<K, V, const B: usize> Worker<K, V, B>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord + Send,
  V: Cmov + Pod + Default + std::fmt::Debug + Send,
  BatchBlock<K, V>: Ord + Send,
{
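  /// Spawns the worker thread for partition `pid`, backed by an `UnsortedMap` of capacity `n`.
  /// The thread waits on `startup_barrier` before it starts serving commands.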
  fn new(n: usize, pid: usize, startup_barrier: Arc<Barrier>) -> Self {
    let (tx, rx): (Sender<Cmd<_, _, B>>, Receiver<_>) = bounded(10);

    let handler = thread::Builder::new()
      .name(format!("partition-{pid}"))
      .spawn(move || {
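        // Wait until all partition threads (and the constructor) have reached the barrier.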
        startup_barrier.wait();

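        // The partition's map is owned exclusively by this thread.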
        let mut map = UnsortedMap::<K, V>::new(n);

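        // Serve commands until a `Shutdown` is received.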
        loop {
          let cmd = match rx.recv() {
            Ok(cmd) => cmd,
            Err(_) => {
              panic!("worker thread command channel disconnected unexpectedly");
            }
          };

          match cmd {
            Cmd::Get { mut blocks, ret_tx } => {
              for blk in blocks.iter_mut() {
                blk.v = OOption::new(Default::default(), true);
                blk.v.is_some = map.get(blk.k, &mut blk.v.value);
              }
              let _ = ret_tx.send(Reply::Blocks { pid, blocks });
            }
            Cmd::Insert { blocks, ret_tx } => {
              for blk in blocks.iter() {
                map.insert(blk.k, blk.v.unwrap());
              }
              let _ = ret_tx.send(Reply::Unit(()));
            }
            Cmd::Shutdown => {
              break;
            }
          }
        }
      })
      .expect("failed to spawn worker thread");

    Self { tx, join_handle: Some(handler) }
  }
}

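// Dropping a `Worker` asks its thread to shut down and joins it, so no partition thread outlives the map.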
impl<K, V, const B: usize> Drop for Worker<K, V, B>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug,
  BatchBlock<K, V>: Ord + Send,
{
  fn drop(&mut self) {
    let _ = self.tx.send(Cmd::Shutdown);
    match self.join_handle.take() {
      Some(handle) => {
        let _ = handle.join();
      }
      None => {
        panic!("worker thread join handle was already taken before drop");
      }
    }
  }
}

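/// A map sharded over `P` worker threads, each owning an `UnsortedMap` partition.
/// Requests are issued in fixed-size batches: keys are scattered to their partitions,
/// processed in parallel, and the replies are gathered back in request order.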
#[derive(Debug)]
pub struct ShardedMap<K, V, const B: usize>
where
  K: OHash + Pod + Default + std::fmt::Debug + Send + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug + Send,
  BatchBlock<K, V>: Ord + Send,
{
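  /// Number of elements inserted so far.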
  size: usize,
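  /// Total capacity across all partitions (the requested capacity rounded up to a multiple of `P`).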
  capacity: usize,
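  /// One worker thread per partition.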
  workers: [Worker<K, V, B>; P],
  random_state: RandomState,
}

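/// One slot of a batched request: the slot's position in the caller's batch (`index`), the key,
/// and an optional value. Padding slots are marked by setting `index` to `usize::MAX`.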
#[repr(C)]
#[derive(Default, Debug, Clone, Copy, Zeroable, PartialEq, Eq, PartialOrd, Ord)]
pub struct BatchBlock<K, V>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug,
{
  index: usize,
  k: K,
  v: OOption<V>,
}
unsafe impl<K, V> Pod for BatchBlock<K, V>
where
  K: OHash + Pod + Default + std::fmt::Debug + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug,
{
}
impl_cmov_for_generic_pod!(BatchBlock<K, V>; where K: OHash + Pod + Default + std::fmt::Debug + Ord, V: Cmov + Pod + Default + std::fmt::Debug);

impl<K, V, const B: usize> ShardedMap<K, V, B>
where
  K: OHash + Pod + Default + std::fmt::Debug + Send + Ord,
  V: Cmov + Pod + Default + std::fmt::Debug + Send,
  BatchBlock<K, V>: Ord + Send,
{
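  /// Creates a map that can hold at least `capacity` elements.
  /// The capacity is split evenly across `P` partitions, so the effective capacity is
  /// `capacity` rounded up to the next multiple of `P`.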
  pub fn new(capacity: usize) -> Self {
    let per_part = capacity.div_ceil(P);
    let startup = Arc::new(Barrier::new(P + 1));

    let workers = std::array::from_fn(|i| Worker::new(per_part, i, startup.clone()));

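    // The barrier has `P + 1` participants: the `P` worker threads plus this constructor.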
    startup.wait();

    Self { size: 0, capacity: per_part * P, workers, random_state: RandomState::new() }
  }

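  /// Maps a key to the index of the partition responsible for it.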
  #[inline(always)]
  fn get_partition(&self, key: &K) -> usize {
    (self.random_state.hash_one(key) % P as u64) as usize
  }

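  /// Looks up a batch of `N` distinct keys and returns their values in the same order as `keys`.
  /// `B` is the per-partition block size: no more than `B` of the batch's keys may hash to the
  /// same partition (checked with an assertion).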
  pub fn get_batch_distinct<const N: usize>(&mut self, keys: &[K; N]) -> [OOption<V>; N] {
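    // One request buffer per partition; every partition receives a full-size block, and
    // copies that do not belong to a partition are marked invalid below.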
    let mut per_p: [Box<[BatchBlock<K, V>; N]>; P] =
      std::array::from_fn(|_| Box::new([BatchBlock::default(); N]));

    const INVALID_ID: usize = usize::MAX;

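    // Write every key into every partition's buffer, then conditionally overwrite the index
    // with INVALID_ID for all partitions except the key's target partition.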
    for (i, k) in keys.iter().enumerate() {
      let target_p = self.get_partition(k);
      for (p, partition) in per_p.iter_mut().enumerate() {
        partition[i].k = *k;
        partition[i].index = i;
        partition[i].index.cmov(&INVALID_ID, target_p != p);
      }
    }

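    // Compact each partition's buffer so that its real requests fit within the first `B` slots.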
    for partition in &mut per_p {
      let cnt = compact(&mut **partition, |x: &BatchBlock<K, V>| x.index == INVALID_ID);
      assert!(cnt <= B);
    }

    let (done_tx, done_rx) = bounded::<Reply<K, V, B>>(P);

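    // Send each partition the first `B` blocks of its buffer as a `Get` command.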
    for (p, partition) in per_p.iter_mut().enumerate() {
      let blocks = {
        let mut new_blocks = [BatchBlock::default(); B];
        new_blocks.copy_from_slice(&partition[..B]);
        Box::new(new_blocks)
      };
      self.workers[p].tx.send(Cmd::Get { blocks, ret_tx: done_tx.clone() }).unwrap();
    }

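    // Gather the replies into one flat buffer; each worker's blocks are placed by partition id.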
    let mut merged: Vec<BatchBlock<K, V>> = vec![BatchBlock::default(); P * B];

    for _ in 0..P {
      match done_rx.recv().unwrap() {
        Reply::Blocks { pid, blocks } => {
          for b in 0..B {
            merged[pid * B + b] = blocks[b];
          }
        }
        _ => {
          panic!("unexpected reply from worker thread (probably early termination?)");
        }
      }
    }

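    // Invalid blocks carry `index == usize::MAX`, so sorting by `index` moves the `N` real
    // results to the front, ordered by their position in the original batch.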
    bitonic_sort(&mut merged);

    let mut ret: [OOption<V>; N] = [(); N].map(|_| OOption::default());
    for i in 0..N {
      ret[i] = merged[i].v;
    }

    ret
  }

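  /// Inserts a batch of `N` distinct keys with their corresponding `values`.
  /// Panics if the map does not have room for `N` more elements, or if more than `B` of the
  /// batch's keys hash to the same partition.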
  pub fn insert_batch_distinct<const N: usize>(&mut self, keys: &[K; N], values: &[V; N]) {
    assert!(self.size + N <= self.capacity, "Map is full, cannot insert more elements.");
    let mut per_p: [Box<[BatchBlock<K, V>; B]>; P] =
      std::array::from_fn(|_| Box::new([BatchBlock::default(); B]));

    const INVALID_ID: usize = usize::MAX;

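    // Same scatter as in `get_batch_distinct`: every partition gets a copy of each pair, and
    // copies outside the key's target partition are marked invalid.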
    for (i, k) in keys.iter().enumerate() {
      let target_p = self.get_partition(k);
      for (p, partition) in per_p.iter_mut().enumerate() {
        partition[i].k = *k;
        partition[i].v = OOption::new(values[i], true);
        partition[i].index = i;
        partition[i].index.cmov(&INVALID_ID, target_p != p);
      }
    }

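    // Compact each partition's buffer and check that its real insertions fit within `B` slots.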
    for partition in &mut per_p {
      let cnt = compact(&mut **partition, |x| x.index == INVALID_ID);
      assert!(cnt <= B);
    }

    let (done_tx, done_rx) = bounded::<Reply<K, V, B>>(P);

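    // Move each partition's buffer out and hand it to the corresponding worker as an `Insert`.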
    for (p, partition) in per_p.iter_mut().enumerate() {
      let blocks = std::mem::replace(partition, Box::new([BatchBlock::default(); B]));
      self.workers[p].tx.send(Cmd::Insert { blocks, ret_tx: done_tx.clone() }).unwrap();
    }

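    // Wait for every worker to acknowledge its insert before updating the element count.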
    for _i in 0..P {
      match done_rx.recv().unwrap() {
        Reply::Unit(()) => {}
        _ => {
          panic!("unexpected reply from worker thread (probably early termination?)");
        }
      }
    }

    self.size += N;
  }
}

#[cfg(test)]
mod tests {
  use super::*;

  const N: usize = 4;
  const B: usize = N;

  #[test]
  fn new_map_rounds_capacity_and_starts_empty() {
    let requested = 100;
    let map: ShardedMap<u64, u64, B> = ShardedMap::new(requested);

    let per_part = requested.div_ceil(P);
    assert_eq!(map.capacity, per_part * P);
    assert_eq!(map.size, 0);
  }

  #[test]
  fn insert_batch_then_get_batch_returns_expected_values() {
    let mut map: ShardedMap<u64, u64, B> = ShardedMap::new(32);

    let keys: [u64; N] = [1, 2, 3, 4];
    let values: [u64; N] = [10, 20, 30, 40];

    map.insert_batch_distinct::<N>(&keys, &values);

    let results = map.get_batch_distinct::<N>(&keys);
    for i in 0..N {
      assert!(results[i].is_some(), "key {} missing", keys[i]);
      assert_eq!(results[i].unwrap(), values[i]);
    }
  }

  #[test]
  fn querying_absent_keys_returns_none() {
    let mut map: ShardedMap<u64, u64, B> = ShardedMap::new(16);

    let absent: [u64; N] = [100, 200, 300, 400];
    let results = map.get_batch_distinct::<N>(&absent);

    for r in &results {
      assert!(!r.is_some());
    }
  }

  #[test]
  fn size_updates_after_insert() {
    let mut map: ShardedMap<u64, u64, B> = ShardedMap::new(16);

    let keys: [u64; N] = [11, 22, 33, 44];
    let values: [u64; N] = [111, 222, 333, 444];

    map.insert_batch_distinct::<N>(&keys, &values);
    assert_eq!(map.size, N);
  }
}