use crate::hash::RobinHoodKey;
use crate::memory;
use core::marker::PhantomData;
use core::mem;
use core::ptr::{self, NonNull};
/// Open-addressing hash table storage using Robin Hood probing with
/// backward-shift deletion.
///
/// Metadata, keys and values live in one block obtained from
/// `memory::allocate_block` and are addressed as three arrays of `capacity`
/// slots indexed by the same slot number. A metadata byte of `0` marks an
/// empty slot; any other value is the entry's probe distance from its home
/// bucket plus one.
pub struct RawTable<K, V> {
    /// Start of the backing block; handed back to `memory::deallocate_block`
    /// on drop and when the table is resized.
    base_ptr: NonNull<u8>,
    /// Per-slot metadata bytes (`0` = empty, otherwise probe distance + 1).
    pub(crate) meta_ptr: *mut u8,
    /// Key array; slot `i` holds an initialized key iff its metadata is non-zero.
    pub(crate) keys_ptr: *mut K,
    /// Value array; slot `i` holds an initialized value iff its metadata is non-zero.
    pub(crate) vals_ptr: *mut V,
    /// Number of slots; kept a power of two (at least 1) so `mask` works.
    pub(crate) capacity: usize,
    /// Number of live entries.
    size: usize,
    /// `capacity - 1`, used to wrap probe indices around the arrays.
    mask: usize,
    /// Key-array offset returned by `memory::calc_layout` for this capacity.
    keys_offset: usize,
    /// Value-array offset returned by `memory::calc_layout` for this capacity.
    vals_offset: usize,
    /// Tells the compiler the table logically owns `K` and `V` values.
    _marker: PhantomData<(K, V)>,
}
// SAFETY: the table stores its `K` and `V` values behind raw pointers into a
// block it allocates itself and frees in `Drop`; moving the table to another
// thread therefore moves ownership of those values, which is sound when both
// `K` and `V` are `Send`. NOTE(review): assumes no other handle aliases the
// block through the `pub(crate)` pointer fields.
unsafe impl<K: Send, V: Send> Send for RawTable<K, V> {}
// SAFETY: through `&RawTable` this file only reads keys (for comparison) and
// hands out `&V` (from `get`), so sharing the table across threads is sound
// when `K` and `V` are `Sync`. NOTE(review): confirm no `&self` method elsewhere
// in the crate mutates slots through the raw pointers.
unsafe impl<K: Sync, V: Sync> Sync for RawTable<K, V> {}
impl<K: RobinHoodKey + PartialEq, V> RawTable<K, V> {
    /// Panic message used when a probe distance no longer fits the one-byte
    /// metadata encoding.
    const DFH_OVERFLOW: &'static str = "DfH overflow: probe chain exceeded 255. \
        This indicates a hash distribution problem or adversarial input.";

    /// Creates an empty table with 16 slots.
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self::with_capacity(16)
    }

    /// Creates an empty table with at least `capacity` slots.
    ///
    /// The slot count is rounded up to a power of two (and to at least 1) so
    /// probe indices can wrap with a bit mask.
    ///
    /// # Panics
    ///
    /// Panics with `"capacity overflow"` if `capacity` cannot be rounded up to
    /// a power of two within `usize`.
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        // `next_power_of_two` wraps to 0 on overflow in release builds, which
        // would make `capacity - 1` underflow; check explicitly instead.
        let capacity = capacity
            .max(1)
            .checked_next_power_of_two()
            .expect("capacity overflow");
        let k_size = mem::size_of::<K>();
        let v_size = mem::size_of::<V>();
        // SAFETY: NOTE(review): the preconditions of `memory::allocate_block`
        // live in `crate::memory`. The same (capacity, k_size, v_size) triple is
        // passed to `memory::deallocate_block` in `Drop`. Everything below relies
        // on the metadata bytes starting out zeroed (all slots empty) — confirm
        // `allocate_block` guarantees that.
        let base = unsafe { memory::allocate_block(capacity, k_size, v_size) };
        let (_, keys_offset, vals_offset) = memory::calc_layout(capacity, k_size, v_size);
        // SAFETY: the offsets come from `calc_layout` for this very block layout.
        let (meta_ptr, keys_ptr, vals_ptr) =
            unsafe { memory::get_array_ptrs::<K, V>(base, keys_offset, vals_offset) };
        Self {
            base_ptr: base,
            meta_ptr,
            keys_ptr,
            vals_ptr,
            capacity,
            size: 0,
            mask: capacity - 1,
            keys_offset,
            vals_offset,
            _marker: PhantomData,
        }
    }

    /// Number of live entries.
    #[inline]
    pub(crate) fn size(&self) -> usize {
        self.size
    }

    /// Number of slots in the backing arrays.
    #[inline]
    pub(crate) fn capacity(&self) -> usize {
        self.capacity
    }

    /// Returns `true` if the table holds no entries.
    #[inline]
    pub(crate) fn is_empty(&self) -> bool {
        self.size == 0
    }

    /// Returns the home bucket of `key` in a table whose index mask is `mask`.
    #[inline]
    #[allow(clippy::cast_possible_truncation)] // intended: only the low bits survive the mask
    fn home_slot(key: &K, mask: usize) -> usize {
        // A checked `usize::try_from` would panic on 32-bit targets for every
        // 64-bit hash with a non-zero upper half; the mask only needs low bits.
        (key.hash_xxh3() as usize) & mask
    }

    /// Returns the index of the slot holding a key equal to `key`, if any.
    ///
    /// Does not panic on long probe chains: no stored entry can carry a
    /// metadata byte above `u8::MAX`, so the search ends as soon as the
    /// candidate distance would exceed it.
    #[inline]
    fn find_slot(&self, key: &K) -> Option<usize> {
        let mut idx = Self::home_slot(key, self.mask);
        // Metadata byte `key` would carry at `idx` (probe distance + 1).
        let mut dist: u8 = 1;
        loop {
            // SAFETY: `idx <= mask`, so it addresses one of the `capacity` slots.
            let meta_val = unsafe { *self.meta_ptr.add(idx) };
            // Empty slot (0) or an occupant closer to its home than `key` would
            // be here: by the Robin Hood invariant `key` cannot be further on.
            if meta_val < dist {
                return None;
            }
            // Equal keys share a home bucket (assuming `hash_xxh3` agrees with
            // `PartialEq`, as `insert`'s duplicate check also assumes), so a
            // match can only sit at exactly distance `dist`.
            // SAFETY: non-zero metadata means the key slot is initialized.
            if meta_val == dist && unsafe { *self.keys_ptr.add(idx) == *key } {
                return Some(idx);
            }
            idx = (idx + 1) & self.mask;
            dist = dist.checked_add(1)?;
        }
    }

    /// Robin Hood placement of an entry whose key is known to be absent,
    /// starting at slot `idx`, where the entry would carry metadata `dist`.
    ///
    /// Occupants closer to their home bucket than the carried entry are
    /// displaced and carried forward until an empty slot is reached; `size`
    /// is then incremented by one.
    ///
    /// # Safety
    ///
    /// `idx` must be a valid slot index, `dist` must be the metadata value the
    /// entry would have at `idx`, no equal key may be stored, and the table
    /// must contain at least one empty slot.
    ///
    /// # Panics
    ///
    /// Panics with `DFH_OVERFLOW` if a carried entry's probe distance would no
    /// longer fit in the metadata byte.
    unsafe fn place_absent(&mut self, mut idx: usize, mut dist: u8, mut key: K, mut value: V) {
        loop {
            // SAFETY: the caller guarantees the first `idx`; later ones are masked.
            let meta = unsafe { self.meta_ptr.add(idx) };
            let meta_val = unsafe { *meta };
            if meta_val == 0 {
                // SAFETY: the slot is empty, so writing does not leak anything.
                unsafe {
                    ptr::write(self.keys_ptr.add(idx), key);
                    ptr::write(self.vals_ptr.add(idx), value);
                    *meta = dist;
                }
                self.size += 1;
                return;
            }
            if meta_val < dist {
                // SAFETY: non-zero metadata means both slots are initialized;
                // `replace` swaps ownership without dropping anything.
                unsafe {
                    key = ptr::replace(self.keys_ptr.add(idx), key);
                    value = ptr::replace(self.vals_ptr.add(idx), value);
                    *meta = dist;
                }
                dist = meta_val;
            }
            idx = (idx + 1) & self.mask;
            dist = dist.checked_add(1).expect(Self::DFH_OVERFLOW);
        }
    }

    /// Inserts `key` with `value` if no equal key is stored yet.
    ///
    /// Returns `true` when the entry was added. Returns `false` when an equal
    /// key already exists: the stored value is left untouched and the passed
    /// `key` and `value` are dropped.
    ///
    /// Before adding a new key, the table doubles its capacity once it is at
    /// least 80% full.
    ///
    /// # Panics
    ///
    /// Panics if a probe sequence becomes too long for the one-byte metadata
    /// encoding, or if growing the table overflows `usize`.
    #[inline]
    pub fn insert(&mut self, key: K, value: V) -> bool {
        if self.size * 5 >= self.capacity * 4 {
            // Grow only for genuinely new keys; rejecting a duplicate must not
            // reallocate the whole table.
            if self.find_slot(&key).is_some() {
                return false;
            }
            self.resize();
        }
        let mut idx = Self::home_slot(&key, self.mask);
        let mut dist: u8 = 1;
        loop {
            // SAFETY: `idx` is masked into `0..capacity`.
            let meta_val = unsafe { *self.meta_ptr.add(idx) };
            if meta_val < dist {
                // Empty slot or a richer occupant: the key is absent and belongs here.
                break;
            }
            // SAFETY: non-zero metadata means the key slot is initialized.
            if meta_val == dist && unsafe { *self.keys_ptr.add(idx) == key } {
                return false;
            }
            idx = (idx + 1) & self.mask;
            dist = dist.checked_add(1).expect(Self::DFH_OVERFLOW);
        }
        // SAFETY: `idx` is in bounds with matching `dist`; the probe above
        // proved the key absent; and after the load-factor check
        // `size < capacity`, so an empty slot exists.
        unsafe { self.place_absent(idx, dist, key, value) };
        true
    }

    /// Returns a reference to the value stored for `key`, if any.
    #[inline]
    pub fn get(&self, key: &K) -> Option<&V> {
        let idx = self.find_slot(key)?;
        // SAFETY: `find_slot` only returns in-bounds, occupied slots.
        Some(unsafe { &*self.vals_ptr.add(idx) })
    }

    /// Returns a mutable reference to the value stored for `key`, if any.
    #[inline]
    pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
        let idx = self.find_slot(key)?;
        // SAFETY: `find_slot` only returns in-bounds, occupied slots, and
        // `&mut self` guarantees exclusive access.
        Some(unsafe { &mut *self.vals_ptr.add(idx) })
    }

    /// Removes the entry for `key`, dropping its key and value.
    ///
    /// Returns `true` if an entry was removed. Following entries of the same
    /// cluster are shifted back one slot (backward-shift deletion), so no
    /// tombstones are left behind.
    #[inline]
    pub fn remove(&mut self, key: &K) -> bool {
        let idx = match self.find_slot(key) {
            Some(idx) => idx,
            None => return false,
        };
        // Move the entry out instead of dropping it in place. The table is
        // repaired first, so a panicking destructor cannot leave an
        // already-dropped entry marked live, which `Drop` would drop again.
        // SAFETY: `idx` is occupied; its bits are overwritten or marked empty
        // below before anything reads them again.
        let removed = unsafe {
            (
                ptr::read(self.keys_ptr.add(idx)),
                ptr::read(self.vals_ptr.add(idx)),
            )
        };
        let mut hole = idx;
        loop {
            let next = (hole + 1) & self.mask;
            // SAFETY: `next` is masked into `0..capacity`.
            let next_meta = unsafe { *self.meta_ptr.add(next) };
            // Stop at an empty slot (0) or an entry already in its home bucket (1).
            if next_meta <= 1 {
                break;
            }
            // SAFETY: `next` is occupied; its bits move into `hole`, whose
            // previous contents have already been moved out.
            unsafe {
                ptr::copy(self.keys_ptr.add(next), self.keys_ptr.add(hole), 1);
                ptr::copy(self.vals_ptr.add(next), self.vals_ptr.add(hole), 1);
                *self.meta_ptr.add(hole) = next_meta - 1;
            }
            hole = next;
        }
        // SAFETY: `hole` is in bounds and its contents were moved elsewhere.
        unsafe { *self.meta_ptr.add(hole) = 0 };
        self.size -= 1;
        drop(removed);
        true
    }

    /// Drops every entry while keeping the allocation (capacity is unchanged).
    #[inline]
    pub fn clear(&mut self) {
        for i in 0..self.capacity {
            // SAFETY: `i < capacity`; non-zero metadata means both slots are
            // initialized.
            unsafe {
                let meta = self.meta_ptr.add(i);
                if *meta != 0 {
                    // Mark the slot empty before running user destructors, so a
                    // panic cannot make `Drop` drop this entry a second time.
                    *meta = 0;
                    drop((
                        ptr::read(self.keys_ptr.add(i)),
                        ptr::read(self.vals_ptr.add(i)),
                    ));
                }
            }
        }
        self.size = 0;
    }

    /// Grows the table, if needed, so that `additional` more distinct keys can
    /// be inserted without triggering another resize.
    ///
    /// # Panics
    ///
    /// Panics with `"capacity overflow"` if the required capacity does not
    /// fit in `usize`.
    pub fn reserve(&mut self, additional: usize) {
        let scaled = self
            .size
            .checked_add(additional)
            .and_then(|target| target.checked_mul(5))
            .expect("capacity overflow");
        // Same 80% load-factor threshold as `insert`.
        if scaled >= self.capacity * 4 {
            let new_cap = (scaled / 4)
                .checked_next_power_of_two()
                .expect("capacity overflow");
            if new_cap > self.capacity {
                self.resize_to(new_cap);
            }
        }
    }

    /// Doubles the number of slots.
    fn resize(&mut self) {
        let doubled = self.capacity.checked_mul(2).expect("capacity overflow");
        self.resize_to(doubled);
    }

    /// Moves every entry into a fresh block of `new_capacity` slots and frees
    /// the old block.
    ///
    /// `new_capacity` must be a power of two greater than the current `size`
    /// (both callers above guarantee this).
    fn resize_to(&mut self, new_capacity: usize) {
        let expected = self.size;
        let mut fresh = Self::with_capacity(new_capacity);
        for i in 0..self.capacity {
            // SAFETY: `i < capacity`; non-zero metadata means both slots are
            // initialized.
            unsafe {
                let meta = self.meta_ptr.add(i);
                if *meta == 0 {
                    continue;
                }
                // Hash through a reference first: if user hashing panics, the
                // entry is still fully owned by `self`.
                let home = Self::home_slot(&*self.keys_ptr.add(i), fresh.mask);
                // Ownership moves to `fresh`. Clear the old slot first, so that
                // no unwinding path can drop the entry twice.
                *meta = 0;
                self.decrement_size();
                let key = ptr::read(self.keys_ptr.add(i));
                let value = ptr::read(self.vals_ptr.add(i));
                // SAFETY: keys in `self` are unique, so this one is absent from
                // `fresh`, and `fresh` holds fewer than `new_capacity` entries.
                fresh.place_absent(home, 1, key, value);
            }
        }
        debug_assert_eq!(fresh.size, expected, "Migration lost elements!");
        // `self` takes over the new block. The old block, now with every slot
        // empty, is freed when `fresh` goes out of scope.
        mem::swap(self, &mut fresh);
    }

    /// Decrements the entry count, saturating at zero.
    #[inline]
    pub(crate) fn decrement_size(&mut self) {
        self.size = self.size.saturating_sub(1);
    }
}
impl<K: RobinHoodKey + PartialEq, V> Default for RawTable<K, V> {
#[inline]
fn default() -> Self {
Self::new()
}
}
impl<K, V> Drop for RawTable<K, V> {
    /// Drops every live entry, then frees the backing block.
    fn drop(&mut self) {
        let meta = self.meta_ptr;
        let keys = self.keys_ptr;
        let vals = self.vals_ptr;
        // Empty slots (metadata 0) hold no values, so only occupied ones are dropped.
        for slot in 0..self.capacity {
            // SAFETY: `slot < capacity`, and non-zero metadata means the key and
            // value at `slot` are initialized and owned by the table.
            unsafe {
                if *meta.add(slot) != 0 {
                    ptr::drop_in_place(keys.add(slot));
                    ptr::drop_in_place(vals.add(slot));
                }
            }
        }
        let k_size = mem::size_of::<K>();
        let v_size = mem::size_of::<V>();
        // SAFETY: NOTE(review): assumes `base_ptr` came from
        // `memory::allocate_block` with this same (capacity, k_size, v_size)
        // triple, as in `with_capacity`.
        unsafe {
            memory::deallocate_block(self.base_ptr, self.capacity, k_size, v_size);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::rc::Rc;
    use alloc::string::String;
    use core::cell::Cell;

    /// Value type that counts how many times it has been dropped.
    struct DropCounter(Rc<Cell<usize>>);

    impl Drop for DropCounter {
        fn drop(&mut self) {
            self.0.set(self.0.get() + 1);
        }
    }

    #[test]
    fn test_raw_table_basic_ops() {
        let mut table: RawTable<u32, String> = RawTable::new();
        assert!(table.is_empty());
        assert!(table.insert(1, String::from("one")));
        assert_eq!(table.size(), 1);
        assert_eq!(table.get(&1), Some(&String::from("one")));
        // Re-inserting an existing key is rejected and keeps the stored value.
        assert!(!table.insert(1, String::from("updated")));
        assert_eq!(table.get(&1), Some(&String::from("one")));
        assert!(table.remove(&1));
        assert_eq!(table.get(&1), None);
        assert!(table.is_empty());
    }

    #[test]
    fn test_raw_table_get_mut_updates_in_place() {
        let mut table: RawTable<u32, u32> = RawTable::new();
        assert!(table.insert(7, 1));
        *table.get_mut(&7).unwrap() += 41;
        assert_eq!(table.get(&7), Some(&42));
        assert!(table.get_mut(&8).is_none());
    }

    #[test]
    fn test_raw_table_resize_chain() {
        let mut table: RawTable<usize, usize> = RawTable::with_capacity(4);
        for i in 0..20 {
            assert!(table.insert(i, i * 10));
        }
        assert!(table.capacity() >= 25);
        for i in 0..20 {
            assert_eq!(table.get(&i), Some(&(i * 10)));
        }
    }

    #[test]
    fn test_raw_table_remove_keeps_other_keys_reachable() {
        // Removing every other key exercises backward-shift deletion inside
        // clusters; the survivors must all still be found.
        let mut table: RawTable<u64, u64> = RawTable::with_capacity(8);
        for i in 0..200 {
            assert!(table.insert(i, i + 1));
        }
        for i in (0..200).step_by(2) {
            assert!(table.remove(&i));
        }
        assert_eq!(table.size(), 100);
        for i in 0..200 {
            if i % 2 == 0 {
                assert_eq!(table.get(&i), None);
                assert!(!table.remove(&i));
            } else {
                assert_eq!(table.get(&i), Some(&(i + 1)));
            }
        }
    }

    #[test]
    fn test_raw_table_drops_each_value_exactly_once() {
        let drops = Rc::new(Cell::new(0));
        {
            let mut table: RawTable<u32, DropCounter> = RawTable::with_capacity(4);
            for i in 0..10 {
                assert!(table.insert(i, DropCounter(Rc::clone(&drops))));
            }
            // Resizes move values; they must not drop them.
            assert_eq!(drops.get(), 0);
            assert!(table.remove(&3));
            assert_eq!(drops.get(), 1);
            // A rejected duplicate drops the value that was passed in.
            assert!(!table.insert(5, DropCounter(Rc::clone(&drops))));
            assert_eq!(drops.get(), 2);
            table.clear();
            assert_eq!(drops.get(), 11);
            assert!(table.insert(1, DropCounter(Rc::clone(&drops))));
        }
        // Dropping the table drops the one remaining value.
        assert_eq!(drops.get(), 12);
    }

    #[test]
    fn test_raw_table_reserve_prevents_regrowth() {
        let mut table: RawTable<u32, u32> = RawTable::with_capacity(1);
        table.reserve(100);
        let cap = table.capacity();
        assert!(cap >= 100);
        for i in 0..100 {
            assert!(table.insert(i, i));
        }
        assert_eq!(table.capacity(), cap);
    }

    #[test]
    fn test_raw_table_zero_capacity_rounds_up() {
        let mut table: RawTable<u32, u32> = RawTable::with_capacity(0);
        assert_eq!(table.capacity(), 1);
        assert!(table.insert(1, 10));
        assert!(table.insert(2, 20));
        assert_eq!(table.get(&1), Some(&10));
        assert_eq!(table.get(&2), Some(&20));
    }

    #[test]
    fn test_raw_table_clear_retains_capacity() {
        let mut table: RawTable<u64, u64> = RawTable::with_capacity(32);
        for i in 0..10 {
            table.insert(i, i);
        }
        let cap_before = table.capacity();
        table.clear();
        assert_eq!(table.size(), 0);
        assert_eq!(table.capacity(), cap_before);
        // The cleared table must be fully reusable.
        for i in 0..10 {
            assert_eq!(table.get(&i), None);
            assert!(table.insert(i, i * 2));
        }
        assert_eq!(table.get(&9), Some(&18));
    }
}