1use ahash::RandomState;
4use bytemuck::{Pod, Zeroable};
5use rand::{rngs::ThreadRng, Rng};
6use rostl_primitives::{
7 cmov_body, cxchg_body, impl_cmov_for_generic_pod,
8 traits::{Cmov, _Cmovbase},
9};
10
11use seq_macro::seq;
12
13use crate::{array::MultiWayArray, queue::ShortQueue};
14
/// Size parameter of the pending-insertion queue. `insert` and `insert_cond` assert that the
/// queue currently holds fewer than this many elements before pushing.
const INSERTION_QUEUE_MAX_SIZE: usize = 10;
/// Number of pop / try-insert / re-push rounds run by one `deamortize_insertion_queue` call.
const DEAMORTIZED_INSERTIONS: usize = 2;
/// Number of element slots (and validity flags) in each bucket.
const BUCKET_SIZE: usize = 4;
21
22use std::hash::Hash;
23
24pub trait OHash: Cmov + Pod + Hash + PartialEq {}
26impl<K> OHash for K where K: Cmov + Pod + Hash + PartialEq {}
28
29#[repr(align(8))]
31#[repr(C)]
32#[derive(Debug, Default, Clone, Copy, Zeroable)]
33pub struct InlineElement<K, V>
34where
35 K: OHash,
36 V: Cmov + Pod,
37{
38 key: K,
39 value: V,
40}
41unsafe impl<K: OHash, V: Cmov + Pod> Pod for InlineElement<K, V> {}
42impl_cmov_for_generic_pod!(InlineElement<K,V>; where K: OHash, V: Cmov + Pod);
43
44#[derive(Debug, Default, Clone, Copy, Zeroable)]
51#[repr(C)]
52struct Bucket<K, V>
53where
54 K: OHash,
55 V: Cmov + Pod,
56{
57 is_valid: [bool; BUCKET_SIZE],
58 elements: [InlineElement<K, V>; BUCKET_SIZE],
59}
60unsafe impl<K: OHash, V: Cmov + Pod> Pod for Bucket<K, V> {}
61impl_cmov_for_generic_pod!(Bucket<K,V>; where K: OHash, V: Cmov + Pod);
62
63impl<K, V> Bucket<K, V>
64where
65 K: OHash,
66 V: Cmov + Pod,
67{
68 const fn is_empty(&self, i: usize) -> bool {
69 !self.is_valid[i]
70 }
71
72 fn update_if_exists(&mut self, real: bool, element: InlineElement<K, V>) -> bool {
79 let mut updated = false;
80 for i in 0..BUCKET_SIZE {
81 let choice = real & !self.is_empty(i) & (self.elements[i].key == element.key);
82 self.elements[i].value.cmov(&element.value, choice);
83 updated.cmov(&true, choice);
84 }
85 updated
86 }
87
88 fn read_if_exists(&self, key: K, ret: &mut V) -> bool {
89 let mut found = false;
90 for i in 0..BUCKET_SIZE {
91 let choice = !self.is_empty(i) & (self.elements[i].key == key);
92 ret.cmov(&self.elements[i].value, choice);
93 found.cmov(&true, choice);
94 }
95 found
96 }
97
98 fn insert_if_available(&mut self, real: bool, element: InlineElement<K, V>) -> bool {
105 let mut inserted = !real;
106 for i in 0..BUCKET_SIZE {
107 let choice = !inserted & self.is_empty(i);
108 self.is_valid[i].cmov(&true, choice);
109 self.elements[i].cmov(&element, choice);
110 inserted.cmov(&true, choice);
111 }
112 inserted
113 }
114}
115
116#[derive(Debug)]
125pub struct UnsortedMap<K, V>
126where
127 K: OHash + Default + std::fmt::Debug,
128 V: Cmov + Pod + Default + std::fmt::Debug,
129{
130 size: usize,
132 _capacity: usize,
134 table_size: usize,
136 table: MultiWayArray<Bucket<K, V>, 2>,
138 hash_builders: [RandomState; 2],
140 insertion_queue: ShortQueue<InlineElement<K, V>, INSERTION_QUEUE_MAX_SIZE>,
142 rng: ThreadRng,
144}
145
146impl<K, V> UnsortedMap<K, V>
147where
148 K: OHash + Default + std::fmt::Debug,
149 V: Cmov + Pod + Default + std::fmt::Debug,
150{
151 pub fn new(capacity: usize) -> Self {
153 debug_assert!(capacity > 0);
154 let table_size = (capacity * 10).div_ceil(9 * BUCKET_SIZE).max(2);
157 Self {
158 size: 0,
159 _capacity: capacity,
160 table_size,
161 table: MultiWayArray::new(table_size),
162 hash_builders: [RandomState::new(), RandomState::new()],
163 insertion_queue: ShortQueue::new(),
164 rng: rand::rng(),
165 }
166 }
167
168 #[inline(always)]
169 fn hash_key<const TABLE: usize>(&self, key: &K) -> usize {
170 (self.hash_builders[TABLE].hash_one(key) % self.table_size as u64) as usize
171 }
172
173 pub fn get(&mut self, key: K, ret: &mut V) -> bool {
181 let mut found = false;
182 let mut tmp: Bucket<K, V> = Default::default();
183
184 seq!(INDEX in 0..2 {
187 let hash = self.hash_key::<INDEX>(&key);
188 self.table.read(INDEX, hash, &mut tmp);
189 let found_local = tmp.read_if_exists(key, ret);
190 found.cmov(&true, found_local);
191 });
192
193 for i in 0..self.insertion_queue.size {
195 let element = self.insertion_queue.elements.data[i];
196 let found_local = !element.is_empty() & (element.value.key == key);
197 ret.cmov(&element.value.value, found_local);
198 found.cmov(&true, found_local);
199 }
200
201 found
202 }
203
204 fn try_insert_entry(&mut self, real: bool, element: &mut InlineElement<K, V>) -> bool {
209 let mut done = !real;
210
211 seq!(INDEX_REV in 0..2 {{
212 #[allow(clippy::identity_op, clippy::eq_op)] const INDEX: usize = 1 - INDEX_REV;
214
215 let hash = self.hash_key::<INDEX>(&element.key);
216 self.table.update(INDEX, hash, |bucket| {
217 let choice = !done;
218 let inserted = bucket.insert_if_available(choice, *element);
219 done.cmov(&true, inserted);
220 let randidx = self.rng.random_range(0..BUCKET_SIZE);
221 bucket.elements[randidx].cxchg(element, !done);
222 });
223 }});
224 done
225 }
226
227 pub fn deamortize_insertion_queue(&mut self) {
229 for _ in 0..DEAMORTIZED_INSERTIONS {
230 let mut element = InlineElement::default();
231 let real = self.insertion_queue.size > 0;
232
233 self.insertion_queue.maybe_pop(real, &mut element);
235 let has_pending_element = !self.try_insert_entry(real, &mut element);
236 self.insertion_queue.maybe_push(has_pending_element, element);
237 }
238 }
239
240 pub fn insert(&mut self, key: K, value: V) {
244 assert!(self.insertion_queue.size < INSERTION_QUEUE_MAX_SIZE);
246 self.insertion_queue.maybe_push(true, InlineElement { key, value });
247 self.deamortize_insertion_queue();
248 self.size.cmov(&(self.size + 1), true);
249 }
250
251 pub fn insert_cond(&mut self, key: K, value: V, real: bool) {
258 assert!(self.insertion_queue.size < INSERTION_QUEUE_MAX_SIZE);
260 self.insertion_queue.maybe_push(real, InlineElement { key, value });
261 self.deamortize_insertion_queue();
262 self.size.cmov(&(self.size + 1), real);
263 }
264
265 pub fn write(&mut self, key: K, value: V) {
271 let mut updated = false;
272
273 seq!(INDEX in 0..2 {
276 let hash = self.hash_key::<INDEX>(&key);
277 self.table.update(INDEX, hash, |bucket| {
278 let choice = !updated;
279 let updated_local = bucket.update_if_exists(choice, InlineElement { key, value });
280 updated.cmov(&true, updated_local);
281 });
282 });
283
284 for i in 0..self.insertion_queue.size {
286 let element = &mut self.insertion_queue.elements.data[i];
287 let choice = !updated & !element.is_empty() & (element.value.key == key);
288 element.value.value.cmov(&value, choice);
289 updated.cmov(&true, choice);
290 }
291
292 assert!(updated);
293 }
294
295 }
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Single key: miss, insert, hit, overwrite, hit again.
    #[test]
    fn test_unsorted_map() {
        let mut map = UnsortedMap::<u32, u32>::new(2);
        assert_eq!(map.size, 0);

        let mut out = 0;
        assert!(!map.get(1, &mut out));

        map.insert(1, 2);
        assert_eq!(map.size, 1);
        assert!(map.get(1, &mut out));
        assert_eq!(out, 2);

        map.write(1, 3);
        assert!(map.get(1, &mut out));
        assert_eq!(out, 3);
    }

    /// Fills the map up to its requested capacity, checking each key right after it is
    /// inserted and again after it is overwritten.
    #[test]
    fn test_full_map() {
        const SZ: usize = 1024;
        let mut map = UnsortedMap::<u32, u32>::new(SZ);
        assert_eq!(map.size, 0);

        for key in 0..SZ as u32 {
            let expected_len = (key + 1) as usize;

            map.insert(key, key * 2);
            let mut out = 0;
            assert!(map.get(key, &mut out));
            assert_eq!(out, key * 2);
            assert_eq!(map.size, expected_len);

            map.write(key, key * 3);
            assert!(map.get(key, &mut out));
            assert_eq!(out, key * 3);
            assert_eq!(map.size, expected_len);
        }
    }

    /// A conditional insert with `real == false` must leave both the size and the stored
    /// contents untouched.
    #[test]
    fn test_insert_cond() {
        let mut map = UnsortedMap::<u32, u32>::new(8);
        assert_eq!(map.size, 0);

        // Dummy insert into an empty map: nothing becomes visible.
        map.insert_cond(10, 100, false);
        assert_eq!(map.size, 0);
        let mut out = 0;
        assert!(!map.get(10, &mut out));

        // Real insert: the key is now present.
        map.insert_cond(10, 200, true);
        assert_eq!(map.size, 1);
        assert!(map.get(10, &mut out));
        assert_eq!(out, 200);

        // Dummy insert of the same key: the stored value is unchanged.
        map.insert_cond(10, 300, false);
        assert_eq!(map.size, 1);
        assert!(map.get(10, &mut out));
        assert_eq!(out, 200);
    }

    /// Smoke test shared by all key/value type combinations: miss, insert the default key
    /// and value, then hit.
    fn test_map_subtypes<K, V>()
    where
        K: OHash + Default + std::fmt::Debug,
        V: Cmov + Pod + Default + std::fmt::Debug,
    {
        const SZ: usize = 1024;
        let mut map = UnsortedMap::<K, V>::new(SZ);
        assert_eq!(map.size, 0);

        let mut out = V::default();
        assert!(!map.get(K::default(), &mut out));

        map.insert(K::default(), V::default());
        assert_eq!(map.size, 1);
        assert!(map.get(K::default(), &mut out));
    }

    /// Runs the smoke test for several integer widths, signed and unsigned.
    #[test]
    fn test_map_multiple_types() {
        test_map_subtypes::<u32, u32>();
        test_map_subtypes::<u64, u64>();
        test_map_subtypes::<u128, u128>();
        test_map_subtypes::<i32, i32>();
        test_map_subtypes::<i64, i64>();
        test_map_subtypes::<i128, i128>();
    }
}