1use ahash::RandomState;
4use bytemuck::{Pod, Zeroable};
5use rand::{rngs::ThreadRng, Rng};
6use rostl_primitives::{
7 cmov_body, cxchg_body, impl_cmov_for_generic_pod,
8 traits::{Cmov, _Cmovbase},
9};
10
11use seq_macro::seq;
12
13use crate::{array::MultiWayArray, queue::ShortQueue};
14
15const INSERTION_QUEUE_MAX_SIZE: usize = 10;
17const DEAMORTIZED_INSERTIONS: usize = 2;
19const BUCKET_SIZE: usize = 4;
21
22use std::hash::Hash;
23
24pub trait OHash: Cmov + Pod + Hash + PartialEq {}
26impl<K> OHash for K where K: Cmov + Pod + Hash + PartialEq {}
28
29#[repr(align(8))]
31#[derive(Debug, Default, Clone, Copy, Zeroable)]
32pub struct InlineElement<K, V>
33where
34 K: OHash,
35 V: Cmov + Pod,
36{
37 key: K,
38 value: V,
39}
40unsafe impl<K: OHash, V: Cmov + Pod> Pod for InlineElement<K, V> {}
41impl_cmov_for_generic_pod!(InlineElement<K,V>; where K: OHash, V: Cmov + Pod);
42
43#[derive(Debug, Default, Clone, Copy, Zeroable)]
45pub struct BucketElement<K, V>
46where
47 K: OHash,
48 V: Cmov + Pod,
49{
50 is_valid: bool,
51 element: InlineElement<K, V>,
52}
53unsafe impl<K: OHash, V: Cmov + Pod> Pod for BucketElement<K, V> {}
54impl_cmov_for_generic_pod!(BucketElement<K,V>; where K: OHash, V: Cmov + Pod);
55
56impl<K, V> BucketElement<K, V>
57where
58 K: OHash,
59 V: Cmov + Pod,
60{
61 #[inline(always)]
62 const fn is_empty(&self) -> bool {
63 !self.is_valid
64 }
65}
66
67#[derive(Debug, Default, Clone, Copy, Zeroable)]
74struct Bucket<K, V>
75where
76 K: OHash,
77 V: Cmov + Pod,
78{
79 elements: [BucketElement<K, V>; BUCKET_SIZE],
80}
81unsafe impl<K: OHash, V: Cmov + Pod> Pod for Bucket<K, V> {}
82impl_cmov_for_generic_pod!(Bucket<K,V>; where K: OHash, V: Cmov + Pod);
83
84impl<K, V> Bucket<K, V>
85where
86 K: OHash,
87 V: Cmov + Pod,
88{
89 fn update_if_exists(&mut self, real: bool, element: InlineElement<K, V>) -> bool {
96 let mut updated = false;
97 for i in 0..BUCKET_SIZE {
98 let element_i = &mut self.elements[i];
99 let choice = real & !element_i.is_empty() & (element_i.element.key == element.key);
100 element_i.element.value.cmov(&element.value, choice);
101 updated.cmov(&true, choice);
102 }
103 updated
104 }
105
106 fn read_if_exists(&self, key: K, ret: &mut V) -> bool {
107 let mut found = false;
108 for i in 0..BUCKET_SIZE {
109 let element_i = &self.elements[i];
110 let choice = !element_i.is_empty() & (element_i.element.key == key);
111 ret.cmov(&element_i.element.value, choice);
112 found.cmov(&true, choice);
113 }
114 found
115 }
116
117 fn insert_if_available(&mut self, real: bool, element: InlineElement<K, V>) -> bool {
124 let mut inserted = !real;
125 for i in 0..BUCKET_SIZE {
126 let element_i = &mut self.elements[i];
127 let choice = !inserted & element_i.is_empty();
128 element_i.is_valid.cmov(&true, choice);
129 element_i.element.cmov(&element, choice);
130 inserted.cmov(&true, choice);
131 }
132 inserted
133 }
134}
135
136#[derive(Debug)]
145pub struct UnsortedMap<K, V>
146where
147 K: OHash + Default + std::fmt::Debug,
148 V: Cmov + Pod + Default + std::fmt::Debug,
149{
150 size: usize,
152 _capacity: usize,
154 table_size: usize,
156 table: MultiWayArray<Bucket<K, V>, 2>,
158 hash_builders: [RandomState; 2],
160 insertion_queue: ShortQueue<InlineElement<K, V>, INSERTION_QUEUE_MAX_SIZE>,
162 rng: ThreadRng,
164}
165
166impl<K, V> UnsortedMap<K, V>
167where
168 K: OHash + Default + std::fmt::Debug,
169 V: Cmov + Pod + Default + std::fmt::Debug,
170{
171 pub fn new(capacity: usize) -> Self {
173 debug_assert!(capacity > 0);
174 let table_size = (capacity * 5).div_ceil(4 * BUCKET_SIZE).max(2);
176 Self {
177 size: 0,
178 _capacity: capacity,
179 table_size,
180 table: MultiWayArray::new(table_size),
181 hash_builders: [RandomState::new(), RandomState::new()],
182 insertion_queue: ShortQueue::new(),
183 rng: rand::rng(),
184 }
185 }
186
187 #[inline(always)]
188 fn hash_key<const TABLE: usize>(&self, key: &K) -> usize {
189 (self.hash_builders[TABLE].hash_one(key) % self.table_size as u64) as usize
190 }
191
192 pub fn get(&mut self, key: K, ret: &mut V) -> bool {
200 let mut found = false;
201 let mut tmp: Bucket<K, V> = Default::default();
202
203 seq!(INDEX in 0..2 {
206 let hash = self.hash_key::<INDEX>(&key);
207 self.table.read(INDEX, hash, &mut tmp);
208 let found_local = tmp.read_if_exists(key, ret);
209 found.cmov(&true, found_local);
210 });
211
212 for i in 0..self.insertion_queue.size {
214 let element = self.insertion_queue.elements.data[i];
215 let found_local = !element.is_empty() & (element.value.key == key);
216 ret.cmov(&element.value.value, found_local);
217 found.cmov(&true, found_local);
218 }
219
220 found
221 }
222
223 fn try_insert_entry(&mut self, real: bool, element: &mut InlineElement<K, V>) -> bool {
228 let mut done = !real;
229
230 seq!(INDEX_REV in 0..2 {{
231 #[allow(clippy::identity_op, clippy::eq_op)] const INDEX: usize = 1 - INDEX_REV;
233
234 let hash = self.hash_key::<INDEX>(&element.key);
235 self.table.update(INDEX, hash, |bucket| {
236 let choice = !done;
237 let inserted = bucket.insert_if_available(choice, *element);
238 done.cmov(&true, inserted);
239 let randidx = self.rng.random_range(0..BUCKET_SIZE);
240 bucket.elements[randidx].element.cxchg(element, !done);
241 });
242 }});
243 done
244 }
245
246 pub fn deamortize_insertion_queue(&mut self) {
248 for _ in 0..DEAMORTIZED_INSERTIONS {
249 let mut element = InlineElement::default();
250 let real = self.insertion_queue.size > 0;
251
252 self.insertion_queue.maybe_pop(real, &mut element);
254 let has_pending_element = !self.try_insert_entry(real, &mut element);
255 self.insertion_queue.maybe_push(has_pending_element, element);
256 }
257 }
258
259 pub fn insert(&mut self, key: K, value: V) {
263 assert!(self.insertion_queue.size < INSERTION_QUEUE_MAX_SIZE);
265 self.insertion_queue.maybe_push(true, InlineElement { key, value });
266 self.deamortize_insertion_queue();
267 self.size.cmov(&(self.size + 1), true);
268 }
269
270 pub fn write(&mut self, key: K, value: V) {
276 let mut updated = false;
277
278 seq!(INDEX in 0..2 {
281 let hash = self.hash_key::<INDEX>(&key);
282 self.table.update(INDEX, hash, |bucket| {
283 let choice = !updated;
284 let updated_local = bucket.update_if_exists(choice, InlineElement { key, value });
285 updated.cmov(&true, updated_local);
286 });
287 });
288
289 for i in 0..self.insertion_queue.size {
291 let element = &mut self.insertion_queue.elements.data[i];
292 let choice = !updated & !element.is_empty() & (element.value.key == key);
293 element.value.value.cmov(&value, choice);
294 updated.cmov(&true, choice);
295 }
296
297 assert!(updated);
298 }
299
300 }
303
304#[cfg(test)]
307mod tests {
308 use super::*;
309
310 #[test]
311 fn test_unsorted_map() {
312 let mut map: UnsortedMap<u32, u32> = UnsortedMap::new(2);
313 assert_eq!(map.size, 0);
314 let mut value = 0;
315 assert!(!map.get(1, &mut value));
316 map.insert(1, 2);
317 assert_eq!(map.size, 1);
318 assert!(map.get(1, &mut value));
319 assert_eq!(value, 2);
320 map.write(1, 3);
321 assert!(map.get(1, &mut value));
322 assert_eq!(value, 3);
323 }
324
325 #[test]
326 fn test_full_map() {
327 const SZ: usize = 1024;
328 let mut map: UnsortedMap<u32, u32> = UnsortedMap::new(SZ);
329 assert_eq!(map.size, 0);
330 for i in 0..SZ as u32 {
331 map.insert(i, i * 2);
332 let mut value = 0;
333 assert!(map.get(i, &mut value));
334 assert_eq!(value, i * 2);
335 assert_eq!(map.size, (i + 1) as usize);
336 map.write(i, i * 3);
337 assert!(map.get(i, &mut value));
338 assert_eq!(value, i * 3);
339 assert_eq!(map.size, (i + 1) as usize);
340 }
341 }
342
343 fn test_map_subtypes<
344 K: OHash + Default + std::fmt::Debug,
345 V: Cmov + Pod + Default + std::fmt::Debug,
346 >() {
347 const SZ: usize = 1024;
348 let mut map: UnsortedMap<K, V> = UnsortedMap::new(SZ);
349 assert_eq!(map.size, 0);
350 let mut value = V::default();
351 assert!(!map.get(K::default(), &mut value));
352 map.insert(K::default(), V::default());
353 assert_eq!(map.size, 1);
354 assert!(map.get(K::default(), &mut value));
355 }
356
357 #[test]
358 fn test_map_multiple_types() {
359 test_map_subtypes::<u32, u32>();
360 test_map_subtypes::<u64, u64>();
361 test_map_subtypes::<u128, u128>();
362 test_map_subtypes::<i32, i32>();
363 test_map_subtypes::<i64, i64>();
364 test_map_subtypes::<i128, i128>();
365 }
366
367 }