use super::*;
/// An open-addressing (linear probing) hash map whose slots live in a [`SeqArray`] of
/// `(key, value)` pairs.
///
/// Keys and values must be `Copy`. `insert` and `get` take `&self`, and the tests in this
/// file drive a single map from many rayon worker threads at once.
///
/// `H` is the hasher used to choose a key's starting slot. It defaults to `DefaultHasher`.
#[derive(Clone, PartialEq, Eq, Default)]
pub struct SeqMap<
    K: Hash + PartialEq + Eq + Copy + Clone + Send + Sync + 'static,
    V: Copy + Clone + Send + Sync + 'static,
    H: Hasher + Default = DefaultHasher,
> {
    /// Backing slot storage; each occupied slot holds one `(key, value)` pair.
    arr: SeqArray<(K, V)>,
    /// Binds the hasher type `H` to the map without storing a hasher instance.
    _phantom: PhantomData<H>,
}
/// Outcome of probing the table for a key.
#[derive(Copy, Clone)]
#[must_use]
enum FindSlotResult<V: Copy + Clone + Send + Sync + 'static> {
    /// The key was found at this slot index; carries the value read from that slot.
    Found(usize, V),
    /// The key is absent; this is the first empty slot encountered during the probe.
    Empty(usize),
    /// The backing array reported that a resize is currently in progress.
    Resizing,
    /// The backing array's capacity changed since the probe's capacity was sampled.
    Resized,
    /// Every slot holds some other key and the capacity did not change.
    NeedsResize,
}
impl<
K: Hash + PartialEq + Eq + Copy + Clone + Send + Sync + 'static,
V: Copy + Clone + Send + Sync + 'static,
H: Hasher + Default,
> SeqMap<K, V, H>
{
#[inline(always)]
pub fn capacity(&self) -> usize {
self.arr.capacity()
}
#[inline(always)]
pub fn len(&self) -> usize {
self.arr.len()
}
#[inline(always)]
pub fn is_empty(&self) -> bool {
self.arr.is_empty()
}
#[inline(always)]
fn cluster_size(capacity: usize) -> u8 {
ceil_log2_usize(capacity) as u8
}
#[inline(always)]
fn calculate_hash(key: K) -> u64 {
let mut hasher = H::default();
key.hash(&mut hasher);
hasher.finish()
}
#[inline(always)]
fn find_slot(hash: u64, cap: usize, arr: &SeqArray<(K, V)>, key: K) -> FindSlotResult<V> {
let cluster = Self::cluster_size(cap);
let start = (hash % cap as u64) as usize;
let mut first_empty = None;
for i in 0..cap {
let idx = (start + i) % cap;
let res = arr.get_without_resize(cap, idx);
match res {
Ok((k, v)) if k == key => return FindSlotResult::Found(idx, v),
Err(_) if first_empty.is_none() => first_empty = Some(idx),
Err(Error::Resized) => return FindSlotResult::Resized,
Err(Error::Resizing) => return FindSlotResult::Resizing,
_ => (),
}
}
if arr.capacity() == cap && Self::cluster_size(cap) == cluster {
if let Some(idx) = first_empty {
return FindSlotResult::Empty(idx);
}
}
if cap != arr.capacity() {
return FindSlotResult::Resized;
}
FindSlotResult::NeedsResize
}
#[inline(always)]
pub fn new() -> Self {
SeqMap {
arr: SeqArray::with_capacity(10),
_phantom: PhantomData,
}
}
#[inline(always)]
pub fn with_capacity(size: usize) -> Self {
SeqMap {
arr: SeqArray::with_capacity(size),
_phantom: PhantomData,
}
}
#[inline(always)]
pub fn insert(&self, key: K, value: V)
where
K: core::fmt::Debug,
V: core::fmt::Debug + PartialEq,
{
let hash = Self::calculate_hash(key);
loop {
let cap = self.arr.capacity();
match Self::find_slot(hash, cap, &self.arr, key) {
FindSlotResult::Found(idx, _) => {
if self.arr.set_without_resize(cap, idx, (key, value)).is_ok() {
return;
}
}
FindSlotResult::Empty(idx) => {
if self
.arr
.cas_set_without_resize(cap, idx, (key, value))
.is_ok()
{
return;
}
}
FindSlotResult::Resizing | FindSlotResult::Resized => std::thread::yield_now(),
FindSlotResult::NeedsResize => {
std::thread::yield_now(); if cap != self.arr.capacity() {
continue;
}
self.arr.scale_up();
}
}
}
}
#[inline(always)]
pub fn get(&self, key: K) -> Option<V> {
let hash = Self::calculate_hash(key);
loop {
let cap = self.arr.capacity();
match Self::find_slot(hash, cap, &self.arr, key) {
FindSlotResult::Found(_, v) => return Some(v),
FindSlotResult::Empty(_) => return None,
FindSlotResult::Resizing | FindSlotResult::Resized => {
std::thread::yield_now();
}
FindSlotResult::NeedsResize => {
return None;
}
}
}
}
}
/// Returns `ceil(log2(x))` for `x >= 1`.
///
/// The result is the number of significant bits in `x - 1`. For `x == 0` the subtraction
/// wraps to `usize::MAX`, so the result is `usize::BITS`.
#[inline(always)]
const fn ceil_log2_usize(x: usize) -> usize {
    let predecessor = x.wrapping_sub(1);
    let significant_bits = usize::BITS - predecessor.leading_zeros();
    significant_bits as usize
}
#[cfg(test)]
use rayon::prelude::*;
#[cfg(test)]
use std::collections::HashSet;
/// Single-threaded round trip: every inserted key reads back its value, and an absent key
/// reads back `None`.
#[test]
fn basic_insert_and_get() {
    let map = SeqMap::<u32, u32>::with_capacity(16);
    (0..16).for_each(|n| map.insert(n, n * 10));
    for n in 0..16 {
        let expected = n * 10;
        assert_eq!(map.get(n), Some(expected));
    }
    let absent = 100;
    assert_eq!(map.get(absent), None);
}
/// Parallel inserts of distinct keys: each key must read back its own value, with no
/// value appearing twice.
#[test]
fn concurrent_insert_and_get_unique() {
    const COUNT: u32 = 1000;
    let map = SeqMap::<u32, u32>::with_capacity(COUNT as usize);
    (0..COUNT)
        .into_par_iter()
        .for_each(|key| map.insert(key, key * 2));
    let mut seen = HashSet::new();
    for key in 0..COUNT {
        let v = map.get(key).expect("Missing value");
        assert!(seen.insert(v), "Duplicate value {v}");
        assert_eq!(v, key * 2);
    }
}
/// Many threads race to overwrite the same key in a one-slot map; the survivor must be
/// one of the written values.
#[test]
fn concurrent_update_same_key_contention() {
    let map = SeqMap::<u32, u32>::with_capacity(1);
    (0..1000u32)
        .into_par_iter()
        .for_each(|written| map.insert(0, written));
    let v = map.get(0).unwrap();
    assert!((0..1000).contains(&v), "Unexpected value {v}");
}
/// Parallel inserts into a tiny map force repeated growth. Every key must be readable
/// both immediately after its own insert (A) and after all inserts finish (B).
#[test]
fn concurrent_resize_and_insert_grow() {
    let n = 10000u32;
    let map = SeqMap::<u32, u32>::with_capacity(4);
    (0..n).into_par_iter().for_each(|i| {
        map.insert(i, i + 1);
        assert_eq!(map.get(i), Some(i + 1), "Value mismatch at index {i} (A)");
    });
    (0..n).for_each(|i| {
        assert_eq!(map.get(i), Some(i + 1), "Value mismatch at index {i} (B)");
    });
    assert_eq!(map.len(), n as usize, "Length mismatch after inserts");
}