1extern crate alloc;
20
21#[cfg(feature = "std")]
22extern crate std;
23
24use alloc::boxed::Box;
25use alloc::vec::Vec;
26use core::borrow::Borrow;
27use core::hash::{BuildHasher, Hash};
28use core::sync::atomic::Ordering;
29use foldhash::fast::FixedState;
30use kovan::{Atomic, RetiredNode, Shared, pin, retire};
31
/// Fixed number of buckets (2^19). Must stay a power of two: bucket
/// selection uses `hash & (BUCKET_COUNT - 1)` instead of a modulo.
const BUCKET_COUNT: usize = 524_288;
38
/// Truncated exponential backoff for CAS retry loops.
///
/// Each call to `spin` busy-waits `2^min(step, 6)` iterations (so at most
/// 64 spins) and then escalates `step`, which saturates once the cap is
/// reached.
struct Backoff {
    step: u32,
}

impl Backoff {
    /// Starts at the shortest wait (a single spin).
    #[inline(always)]
    fn new() -> Self {
        Backoff { step: 0 }
    }

    /// Busy-waits for the current round length, then doubles the next
    /// round up to the cap.
    #[inline(always)]
    fn spin(&mut self) {
        let rounds = 1u32 << self.step.min(6);
        for _ in 0..rounds {
            core::hint::spin_loop();
        }
        if self.step <= 6 {
            self.step += 1;
        }
    }
}
60
/// Intrusive singly linked-list node holding one key/value pair.
///
/// `#[repr(C)]` pins the field order; `retired` sits at offset 0,
/// presumably so kovan's reclamation machinery can treat a `*mut Node` as
/// a `*mut RetiredNode` — TODO confirm against kovan's `retire` contract.
#[repr(C)]
struct Node<K, V> {
    retired: RetiredNode,      // reclamation bookkeeping owned by kovan
    hash: u64,                 // cached full hash; cheap pre-filter before key compare
    key: K,
    value: V,
    next: Atomic<Node<K, V>>,  // next node in the bucket chain (null = end)
}
70
/// A concurrent hash map with a fixed number of lock-free bucket chains.
///
/// Never resizes: `buckets` always holds `BUCKET_COUNT` chain heads, and
/// `mask` is `BUCKET_COUNT - 1` for power-of-two index masking. Reads and
/// writes return cloned values rather than references, so `K` and `V`
/// must be `Clone` for most operations.
pub struct HashMap<K: 'static, V: 'static, S = FixedState> {
    buckets: Box<[Atomic<Node<K, V>>]>,  // chain heads, one per bucket
    mask: usize,                         // BUCKET_COUNT - 1
    hasher: S,                           // BuildHasher state shared by all ops
}
77
#[cfg(feature = "std")]
impl<K, V> HashMap<K, V, FixedState>
where
    K: Hash + Eq + Clone + 'static,
    V: Clone + 'static,
{
    /// Creates an empty map using the default `foldhash` fixed hasher
    /// state. Equivalent to `with_hasher(FixedState::default())`.
    pub fn new() -> Self {
        HashMap::with_hasher(FixedState::default())
    }
}
89
90impl<K, V, S> HashMap<K, V, S>
91where
92 K: Hash + Eq + Clone + 'static,
93 V: Clone + 'static,
94 S: BuildHasher,
95{
96 pub fn with_hasher(hasher: S) -> Self {
98 let mut buckets = Vec::with_capacity(BUCKET_COUNT);
99 for _ in 0..BUCKET_COUNT {
100 buckets.push(Atomic::null());
101 }
102
103 Self {
104 buckets: buckets.into_boxed_slice(),
105 mask: BUCKET_COUNT - 1,
106 hasher,
107 }
108 }
109
110 #[inline(always)]
111 fn get_bucket_idx(&self, hash: u64) -> usize {
112 (hash as usize) & self.mask
113 }
114
115 #[inline(always)]
116 fn get_bucket(&self, idx: usize) -> &Atomic<Node<K, V>> {
117 unsafe { self.buckets.get_unchecked(idx) }
118 }
119
120 pub fn get<Q>(&self, key: &Q) -> Option<V>
122 where
123 K: Borrow<Q>,
124 Q: Hash + Eq + ?Sized,
125 {
126 let hash = self.hasher.hash_one(key);
127 let idx = self.get_bucket_idx(hash);
128 let bucket = self.get_bucket(idx);
129
130 let guard = pin();
131 let mut current = bucket.load(Ordering::Acquire, &guard);
132
133 while !current.is_null() {
134 unsafe {
135 let node = current.deref();
136 if node.hash == hash && node.key.borrow() == key {
138 return Some(node.value.clone());
139 }
140 current = node.next.load(Ordering::Acquire, &guard);
141 }
142 }
143 None
144 }
145
146 pub fn contains_key<Q>(&self, key: &Q) -> bool
148 where
149 K: Borrow<Q>,
150 Q: Hash + Eq + ?Sized,
151 {
152 self.get(key).is_some()
153 }
154
155 pub fn insert(&self, key: K, value: V) -> Option<V> {
157 let hash = self.hasher.hash_one(&key);
158 let idx = self.get_bucket_idx(hash);
159 let bucket = self.get_bucket(idx);
160 let mut backoff = Backoff::new();
161
162 let guard = pin();
163
164 'outer: loop {
165 let mut prev_link = bucket;
167 let mut current = prev_link.load(Ordering::Acquire, &guard);
168
169 while !current.is_null() {
170 unsafe {
171 let node = current.deref();
172
173 if node.hash == hash && node.key == key {
174 let next = node.next.load(Ordering::Relaxed, &guard);
176 let old_value = node.value.clone();
177
178 let new_node = Box::into_raw(Box::new(Node {
180 retired: RetiredNode::new(),
181 hash,
182 key: key.clone(),
183 value: value.clone(),
184 next: Atomic::new(next.as_raw()),
185 }));
186
187 match prev_link.compare_exchange(
188 current,
189 Shared::from_raw(new_node),
190 Ordering::Release,
191 Ordering::Relaxed,
192 &guard,
193 ) {
194 Ok(_) => {
195 retire(current.as_raw());
198 return Some(old_value);
199 }
200 Err(_) => {
201 drop(Box::from_raw(new_node));
203 backoff.spin();
204 continue 'outer;
205 }
206 }
207 }
208
209 prev_link = &node.next;
210 current = node.next.load(Ordering::Acquire, &guard);
211 }
212 }
213
214 let new_node_ptr = Box::into_raw(Box::new(Node {
217 retired: RetiredNode::new(),
218 hash,
219 key: key.clone(),
220 value: value.clone(),
221 next: Atomic::null(),
222 }));
223
224 match prev_link.compare_exchange(
226 unsafe { Shared::from_raw(core::ptr::null_mut()) },
227 unsafe { Shared::from_raw(new_node_ptr) },
228 Ordering::Release,
229 Ordering::Relaxed,
230 &guard,
231 ) {
232 Ok(_) => return None,
233 Err(actual_val) => {
234 unsafe {
238 let actual_node = actual_val.deref();
239 if actual_node.hash == hash && actual_node.key == key {
240 drop(Box::from_raw(new_node_ptr));
242 backoff.spin();
243 continue 'outer;
244 }
245 }
246
247 unsafe {
249 drop(Box::from_raw(new_node_ptr));
250 }
251 backoff.spin();
252 continue 'outer;
253 }
254 }
255 }
256 }
257
258 pub fn insert_if_absent(&self, key: K, value: V) -> Option<V> {
261 let hash = self.hasher.hash_one(&key);
262 let idx = self.get_bucket_idx(hash);
263 let bucket = self.get_bucket(idx);
264 let mut backoff = Backoff::new();
265
266 let guard = pin();
267
268 'outer: loop {
269 let mut prev_link = bucket;
271 let mut current = prev_link.load(Ordering::Acquire, &guard);
272
273 while !current.is_null() {
274 unsafe {
275 let node = current.deref();
276
277 if node.hash == hash && node.key == key {
278 return Some(node.value.clone());
280 }
281
282 prev_link = &node.next;
283 current = node.next.load(Ordering::Acquire, &guard);
284 }
285 }
286
287 let new_node_ptr = Box::into_raw(Box::new(Node {
289 retired: RetiredNode::new(),
290 hash,
291 key: key.clone(),
292 value: value.clone(),
293 next: Atomic::null(),
294 }));
295
296 match prev_link.compare_exchange(
298 unsafe { Shared::from_raw(core::ptr::null_mut()) },
299 unsafe { Shared::from_raw(new_node_ptr) },
300 Ordering::Release,
301 Ordering::Relaxed,
302 &guard,
303 ) {
304 Ok(_) => return None,
305 Err(actual_val) => {
306 unsafe {
308 let actual_node = actual_val.deref();
309 if actual_node.hash == hash && actual_node.key == key {
310 drop(Box::from_raw(new_node_ptr));
312 return Some(actual_node.value.clone());
313 }
314 }
315
316 unsafe {
318 drop(Box::from_raw(new_node_ptr));
319 }
320 backoff.spin();
321 continue 'outer;
322 }
323 }
324 }
325 }
326
327 pub fn get_or_insert(&self, key: K, value: V) -> V {
333 match self.insert_if_absent(key, value.clone()) {
336 Some(existing) => existing,
337 None => value,
338 }
339 }
340
341 pub fn remove<Q>(&self, key: &Q) -> Option<V>
343 where
344 K: Borrow<Q>,
345 Q: Hash + Eq + ?Sized,
346 {
347 let hash = self.hasher.hash_one(key);
348 let idx = self.get_bucket_idx(hash);
349 let bucket = self.get_bucket(idx);
350 let mut backoff = Backoff::new();
351
352 let guard = pin();
353
354 loop {
355 let mut prev_link = bucket;
356 let mut current = prev_link.load(Ordering::Acquire, &guard);
357
358 while !current.is_null() {
359 unsafe {
360 let node = current.deref();
361
362 if node.hash == hash && node.key.borrow() == key {
363 let next = node.next.load(Ordering::Acquire, &guard);
364 let old_value = node.value.clone();
365
366 match prev_link.compare_exchange(
367 current,
368 next,
369 Ordering::Release,
370 Ordering::Relaxed,
371 &guard,
372 ) {
373 Ok(_) => {
374 retire(current.as_raw());
377 return Some(old_value);
378 }
379 Err(_) => {
380 backoff.spin();
381 break; }
383 }
384 }
385
386 prev_link = &node.next;
387 current = node.next.load(Ordering::Acquire, &guard);
388 }
389 }
390
391 if current.is_null() {
392 return None;
393 }
394 }
395 }
396
397 pub fn clear(&self) {
399 let guard = pin();
400
401 for bucket in self.buckets.iter() {
402 loop {
403 let head = bucket.load(Ordering::Acquire, &guard);
404 if head.is_null() {
405 break;
406 }
407
408 match bucket.compare_exchange(
410 head,
411 unsafe { Shared::from_raw(core::ptr::null_mut()) },
412 Ordering::Release,
413 Ordering::Relaxed,
414 &guard,
415 ) {
416 Ok(_) => {
417 unsafe {
419 let mut current = head;
420 while !current.is_null() {
421 let node = current.deref();
422 let next = node.next.load(Ordering::Relaxed, &guard);
423 retire(current.as_raw());
426 current = next;
427 }
428 }
429 break;
430 }
431 Err(_) => {
432 continue;
434 }
435 }
436 }
437 }
438 }
439
440 pub fn is_empty(&self) -> bool {
442 self.len() == 0
443 }
444
445 pub fn len(&self) -> usize {
448 let mut count = 0;
449 let guard = pin();
450 for bucket in self.buckets.iter() {
451 let mut current = bucket.load(Ordering::Acquire, &guard);
452 while !current.is_null() {
453 unsafe {
454 let node = current.deref();
455 count += 1;
456 current = node.next.load(Ordering::Acquire, &guard);
457 }
458 }
459 }
460 count
461 }
462
463 pub fn iter(&self) -> Iter<'_, K, V, S> {
466 Iter {
467 map: self,
468 bucket_idx: 0,
469 guard: pin(),
470 current: core::ptr::null(),
471 }
472 }
473
474 pub fn keys(&self) -> Keys<'_, K, V, S> {
477 Keys { iter: self.iter() }
478 }
479
480 pub fn hasher(&self) -> &S {
482 &self.hasher
483 }
484}
485
/// Iterator yielding cloned `(K, V)` pairs from a live map.
///
/// Holds a kovan guard for its entire lifetime so the nodes behind the
/// raw `current` cursor cannot be reclaimed while iteration is in
/// progress. Concurrent mutations may or may not be observed.
pub struct Iter<'a, K: 'static, V: 'static, S> {
    map: &'a HashMap<K, V, S>,
    bucket_idx: usize,            // next bucket index to visit
    guard: kovan::Guard,          // keeps traversed nodes alive
    current: *const Node<K, V>,   // cursor within the current chain (null = advance bucket)
}
499
500impl<'a, K, V, S> Iterator for Iter<'a, K, V, S>
501where
502 K: Clone,
503 V: Clone,
504{
505 type Item = (K, V);
506
507 fn next(&mut self) -> Option<Self::Item> {
508 loop {
509 if !self.current.is_null() {
510 unsafe {
511 let node = &*self.current;
512 self.current = node.next.load(Ordering::Acquire, &self.guard).as_raw();
514 return Some((node.key.clone(), node.value.clone()));
515 }
516 }
517
518 if self.bucket_idx >= self.map.buckets.len() {
520 return None;
521 }
522
523 let bucket = unsafe { self.map.buckets.get_unchecked(self.bucket_idx) };
524 self.bucket_idx += 1;
525 self.current = bucket.load(Ordering::Acquire, &self.guard).as_raw();
526 }
527 }
528}
529
/// Iterator over cloned keys; a thin projection of [`Iter`] that drops
/// each value.
pub struct Keys<'a, K: 'static, V: 'static, S> {
    iter: Iter<'a, K, V, S>,
}
534
535impl<'a, K, V, S> Iterator for Keys<'a, K, V, S>
536where
537 K: Clone,
538 V: Clone,
539{
540 type Item = K;
541
542 fn next(&mut self) -> Option<Self::Item> {
543 self.iter.next().map(|(k, _)| k)
544 }
545}
546
547impl<'a, K, V, S> IntoIterator for &'a HashMap<K, V, S>
548where
549 K: Hash + Eq + Clone + 'static,
550 V: Clone + 'static,
551 S: BuildHasher,
552{
553 type Item = (K, V);
554 type IntoIter = Iter<'a, K, V, S>;
555
556 fn into_iter(self) -> Self::IntoIter {
557 self.iter()
558 }
559}
560
#[cfg(feature = "std")]
impl<K, V> Default for HashMap<K, V, FixedState>
where
    K: Hash + Eq + Clone + 'static,
    V: Clone + 'static,
{
    /// Same as [`HashMap::new`]: an empty map with the default hasher.
    fn default() -> Self {
        HashMap::new()
    }
}
571
// SAFETY: the map owns its nodes through raw pointers, which defeats the
// auto traits; sending the map moves ownership of all nodes, so `Send`
// components suffice — TODO confirm kovan's `Atomic`/`RetiredNode` carry
// no thread-affine state.
unsafe impl<K: Send, V: Send, S: Send> Send for HashMap<K, V, S> {}
// SAFETY: shared access mutates only through atomics/CAS. `K: Send` and
// `V: Send` are presumably required because retired nodes may be dropped
// on a different thread by kovan's reclamation — TODO confirm.
unsafe impl<K: Send + Sync, V: Send + Sync, S: Send + Sync> Sync for HashMap<K, V, S> {}
579
impl<K, V, S> Drop for HashMap<K, V, S> {
    // Frees every node still linked into the table, then drains kovan's
    // deferred queue. `&mut self` guarantees no other thread can touch the
    // map, which is why nodes are freed directly with `Box::from_raw`
    // instead of being retired.
    fn drop(&mut self) {
        let guard = pin();

        for bucket in self.buckets.iter() {
            let mut current = bucket.load(Ordering::Acquire, &guard);

            unsafe {
                while !current.is_null() {
                    let node = current.deref();
                    // Read the successor before freeing the node.
                    let next = node.next.load(Ordering::Relaxed, &guard);
                    drop(Box::from_raw(current.as_raw()));
                    current = next;
                }
            }
        }

        drop(guard);
        // Flush nodes previously handed to `retire` by insert/remove/clear
        // so they are reclaimed now — presumably kovan defers them until a
        // flush or epoch advance; TODO confirm flush semantics.
        kovan::flush();
    }
}
607
#[cfg(test)]
mod tests {
    use super::*;

    // Basic single-threaded insert + hit + miss.
    #[test]
    fn test_insert_and_get() {
        let map = HashMap::new();
        assert_eq!(map.insert(1, 100), None);
        assert_eq!(map.get(&1), Some(100));
        assert_eq!(map.get(&2), None);
    }

    // Re-inserting an existing key must return the old value and make the
    // new one visible.
    #[test]
    fn test_insert_replace() {
        let map = HashMap::new();
        assert_eq!(map.insert(1, 100), None);
        assert_eq!(map.insert(1, 200), Some(100));
        assert_eq!(map.get(&1), Some(200));
    }

    // Four threads insert disjoint key ranges concurrently; afterwards
    // every key must be present. Exercises the CAS append/retry paths
    // under real contention.
    #[test]
    fn test_concurrent_inserts() {
        use alloc::sync::Arc;
        extern crate std;
        use std::thread;

        let map = Arc::new(HashMap::new());
        let mut handles = alloc::vec::Vec::new();

        for thread_id in 0..4 {
            let map_clone = Arc::clone(&map);
            let handle = thread::spawn(move || {
                for i in 0..1000 {
                    let key = thread_id * 1000 + i;
                    map_clone.insert(key, key * 2);
                }
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.join().unwrap();
        }

        for thread_id in 0..4 {
            for i in 0..1000 {
                let key = thread_id * 1000 + i;
                assert_eq!(map.get(&key), Some(key * 2));
            }
        }
    }
}