use bytemuck::{Pod, Zeroable};
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
};
/// Node index meaning "no node". Node indices are 1-based, so 0 terminates chains and
/// marks an empty bucket.
const SENTINEL: u32 = 0;

/// Index into a node's `registers` array.
#[derive(Copy, Clone)]
enum Register {
    /// Head node index of the bucket whose 0-based index equals this slot's position.
    Bucket = 0,
    /// Index of the following node in the same bucket chain (or in the free list).
    Next = 1,
}

/// Index into the allocator header's `fields` array.
enum Field {
    /// Number of values currently stored.
    Size = 0,
    /// Maximum number of values; also the number of buckets values are hashed into.
    Capacity = 1,
    /// First node index on the free list; equal to `Sequence` when no freed node is available.
    FreeListHead = 2,
    /// Next never-allocated node index.
    Sequence = 3,
}
/// Indexes the slot array by 0-based bucket number (`0..capacity`).
///
/// Expands to a place expression, so the result can be read, borrowed, or mutated.
macro_rules! bucket_node {
    ($slots:expr, $bucket:expr) => {
        $slots[$bucket as usize]
    };
}
/// Indexes the slot array by 1-based node index (`1..=capacity`).
///
/// Index 0 is `SENTINEL` and must not be passed: `0 - 1` underflows.
/// Expands to a place expression, so the result can be read, borrowed, or mutated.
macro_rules! node {
    ($slots:expr, $node:expr) => {
        $slots[($node - 1) as usize]
    };
}
/// Generates the read-only methods shared by `HashSet` and `HashSetMut`.
///
/// `$name` must be a struct generic over `<'a, V>` with an `allocator` field (dereferencing
/// to `Allocator`) and a `nodes` field (a slice of `Node<V>`).
macro_rules! readonly_impl {
    ( $name:tt ) => {
        impl<'a, V: Default + Copy + Clone + Hash + PartialEq + Pod + Zeroable> $name<'a, V> {
            /// Number of bytes a buffer needs to hold a set of `capacity` values: the
            /// `Allocator` header followed by `capacity` node slots.
            pub const fn data_len(capacity: usize) -> usize {
                std::mem::size_of::<Allocator>() + (capacity * std::mem::size_of::<Node<V>>())
            }
            /// Maximum number of values the set can hold (also its number of buckets).
            pub fn capacity(&self) -> usize {
                self.allocator.get_field(Field::Capacity) as usize
            }
            /// Number of values currently stored.
            pub fn size(&self) -> usize {
                self.allocator.get_field(Field::Size) as usize
            }
            /// Returns `true` when no further value can be inserted.
            pub fn is_full(&self) -> bool {
                self.allocator.get_field(Field::Size) >= self.allocator.get_field(Field::Capacity)
            }
            /// Returns `true` when the set holds no values.
            pub fn is_empty(&self) -> bool {
                self.allocator.get_field(Field::Size) == 0
            }
            /// Returns `true` if `value` is stored in the set.
            pub fn contains(&self, value: &V) -> bool {
                // Nothing to find in an empty set. Returning early also keeps the `% capacity`
                // below from panicking on a zero-capacity (e.g. never-initialized) buffer.
                if self.is_empty() {
                    return false;
                }
                // Must match the bucket computation used when values are inserted.
                let mut hasher = DefaultHasher::new();
                value.hash(&mut hasher);
                let index = hasher.finish() as u32 % self.allocator.get_field(Field::Capacity);
                let mut current = bucket_node!(self.nodes, index).get_register(Register::Bucket);
                while current != SENTINEL {
                    let node = &node!(self.nodes, current);
                    if &node.value == value {
                        return true;
                    }
                    current = node.get_register(Register::Next);
                }
                false
            }
        }
    };
}
/// Read-only view of a hash set stored in a byte buffer: an `Allocator` header followed by
/// a slice of `Node<V>` slots.
pub struct HashSet<'a, V>
where
    V: Default + Copy + Clone + Hash + PartialEq + Pod + Zeroable,
{
    allocator: &'a Allocator,
    nodes: &'a [Node<V>],
}
readonly_impl!(HashSet);
impl<'a, V> HashSet<'a, V>
where
    V: Default + Copy + Clone + Hash + PartialEq + Pod + Zeroable,
{
    /// Interprets `bytes` as a set: the first `size_of::<Allocator>()` bytes are the header,
    /// the rest are node slots.
    ///
    /// # Panics
    /// Panics if `bytes` is shorter than the header, if either part is misaligned for its
    /// target type, or if the remainder is not a whole number of `Node<V>` slots.
    pub fn from_bytes(bytes: &'a [u8]) -> Self {
        let header_len = std::mem::size_of::<Allocator>();
        let (header, body) = bytes.split_at(header_len);
        Self {
            allocator: bytemuck::from_bytes::<Allocator>(header),
            nodes: bytemuck::cast_slice(body),
        }
    }
    /// Returns an iterator over the stored values, walking the buckets in order.
    pub fn iter(&self) -> HashSetIterator<'_, V> {
        HashSetIterator {
            hash_set: self,
            bucket: SENTINEL,
            node: SENTINEL,
        }
    }
}
pub struct HashSetIterator<'a, V: Default + Copy + Clone + Hash + PartialEq + Pod + Zeroable> {
hash_set: &'a HashSet<'a, V>,
bucket: u32,
node: u32,
}
impl<'a, V: Default + Copy + Clone + Hash + PartialEq + Pod + Zeroable> Iterator
for HashSetIterator<'a, V>
{
type Item = &'a V;
fn next(&mut self) -> Option<Self::Item> {
if self.bucket <= self.hash_set.capacity() as u32 {
while self.node == SENTINEL {
self.bucket += 1;
if self.bucket > self.hash_set.capacity() as u32 {
return None;
}
self.node = node!(self.hash_set.nodes, self.bucket).get_register(Register::Bucket);
}
let node = &node!(self.hash_set.nodes, self.node);
self.node = node.get_register(Register::Next);
Some(&node.value)
} else {
None
}
}
}
pub struct HashSetMut<'a, V: Default + Copy + Clone + Hash + PartialEq + Pod + Zeroable> {
allocator: &'a mut Allocator,
nodes: &'a mut [Node<V>],
}
readonly_impl!(HashSetMut);
impl<'a, V: Default + Copy + Clone + Hash + PartialEq + Pod + Zeroable> HashSetMut<'a, V> {
pub fn from_bytes_mut(bytes: &'a mut [u8]) -> Self {
let (allocator, nodes) = bytes.split_at_mut(std::mem::size_of::<Allocator>());
let allocator = bytemuck::from_bytes_mut::<Allocator>(allocator);
let nodes = bytemuck::cast_slice_mut(nodes);
Self { allocator, nodes }
}
pub fn initialize(&mut self, capacity: u32) {
self.allocator.initialize(capacity)
}
pub fn insert(&mut self, value: V) -> bool {
if self.size() == self.capacity() {
return false;
}
let mut hasher = DefaultHasher::new();
value.hash(&mut hasher);
let index = hasher.finish() as u32 % self.allocator.get_field(Field::Capacity);
let head = bucket_node!(self.nodes, index).get_register(Register::Bucket);
let mut current = head;
while current != SENTINEL {
let node = node!(self.nodes, current);
if node.value == value {
return false;
}
current = node.get_register(Register::Next);
}
let node = self.add_node(value);
bucket_node!(self.nodes, index).set_register(Register::Bucket, node);
node!(self.nodes, node).set_register(Register::Next, head);
true
}
pub fn remove(&mut self, value: &V) -> bool {
if self.is_empty() {
return false;
}
let mut hasher = DefaultHasher::new();
value.hash(&mut hasher);
let index = hasher.finish() as u32 % self.allocator.get_field(Field::Capacity);
let head = bucket_node!(self.nodes, index).get_register(Register::Bucket);
let mut current = head;
let mut previous = SENTINEL;
while current != SENTINEL {
let node = node!(self.nodes, current);
if &node.value == value {
if previous == SENTINEL {
bucket_node!(self.nodes, index).set_register(Register::Bucket, SENTINEL);
} else {
node!(self.nodes, previous)
.set_register(Register::Next, node.get_register(Register::Next));
}
return self.remove_node(current).is_some();
}
previous = current;
current = node.get_register(Register::Next);
}
false
}
fn add_node(&mut self, value: V) -> u32 {
let free_node = self.allocator.get_field(Field::FreeListHead);
let sequence = self.allocator.get_field(Field::Sequence);
if free_node == sequence {
if (sequence - 1) == self.allocator.get_field(Field::Capacity) {
panic!(
"set is full ({} nodes)",
self.allocator.get_field(Field::Size)
);
}
self.allocator.set_field(Field::Sequence, sequence + 1);
self.allocator.set_field(Field::FreeListHead, sequence + 1);
} else {
self.allocator.set_field(
Field::FreeListHead,
node!(self.nodes, free_node).get_register(Register::Next),
);
}
let entry = &mut node!(self.nodes, free_node);
entry.value = value;
entry.set_register(Register::Next, SENTINEL);
self.allocator
.set_field(Field::Size, self.allocator.get_field(Field::Size) + 1);
free_node
}
fn remove_node(&mut self, index: u32) -> Option<V> {
if index == SENTINEL {
return None;
}
let node = &mut node!(self.nodes, index);
let value = node.value;
node.value = V::default();
let free_list_head = self.allocator.get_field(Field::FreeListHead);
node.set_register(Register::Next, free_list_head);
self.allocator.set_field(Field::FreeListHead, index);
self.allocator
.set_field(Field::Size, self.allocator.get_field(Field::Size) - 1);
Some(value)
}
}
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
pub struct Allocator {
fields: [u32; 4],
}
impl Allocator {
pub fn initialize(&mut self, capacity: u32) {
self.fields = [0, capacity, 1, 1];
}
#[inline(always)]
fn get_field(&self, field: Field) -> u32 {
self.fields[field as usize]
}
#[inline(always)]
fn set_field(&mut self, field: Field, value: u32) {
self.fields[field as usize] = value;
}
}
/// One slot of the node array backing a set.
///
/// A slot serves two roles at once. Its `Bucket` register holds the head of the bucket whose
/// 0-based index equals the slot's position. Its `Next` register and `value` belong to the
/// node whose 1-based index is the slot's position plus one.
#[repr(C)]
#[derive(Clone, Copy, Default)]
pub struct Node<V: Default + Copy + Clone + Hash + PartialEq + Pod + Zeroable> {
    registers: [u32; 2],
    value: V,
}
impl<V: Default + Copy + Clone + Hash + PartialEq + Pod + Zeroable> Node<V> {
    /// Layout guard for the `Pod` impl below.
    ///
    /// `Pod` forbids padding. With `registers` taking 8 bytes, `Node<V>` is padding-free
    /// exactly when `size_of::<V>()` is a multiple of 4 and `align_of::<V>()` is at most 8
    /// (e.g. `u8` or `u16` values would add padding). Referencing this constant from the
    /// register accessors makes the compiler evaluate it for each concrete `V`, turning a
    /// violation into a build error during code generation.
    const NO_PADDING: () = assert!(
        std::mem::size_of::<Self>()
            == std::mem::size_of::<[u32; 2]>() + std::mem::size_of::<V>(),
        "Node<V> would contain padding: size_of::<V>() must be a multiple of 4 and align_of::<V>() at most 8"
    );

    /// Reads one of the node's two link registers.
    #[inline(always)]
    fn get_register(&self, register: Register) -> u32 {
        let () = Self::NO_PADDING;
        self.registers[register as usize]
    }
    /// Overwrites one of the node's two link registers.
    #[inline(always)]
    fn set_register(&mut self, register: Register, value: u32) {
        let () = Self::NO_PADDING;
        self.registers[register as usize] = value;
    }
}
// SAFETY: `Node<V>` is `#[repr(C)]` and consists only of `[u32; 2]` and `V: Zeroable`, so the
// all-zero bit pattern is a valid value.
unsafe impl<V: Default + Copy + Clone + Hash + PartialEq + Pod + Zeroable> Zeroable for Node<V> {}
// SAFETY: `Node<V>` is `#[repr(C)]`, `Copy`, and built only from `[u32; 2]` and `V: Pod`, so every
// bit pattern is valid. `Pod` also requires the absence of padding, which holds only for the
// `V` described on `NO_PADDING`; that constant rejects other `V` wherever the register
// accessors are instantiated. Code that reaches `Node<V>` through `bytemuck` alone, without
// touching the registers, is not covered by that check.
unsafe impl<V: Default + Copy + Clone + Hash + PartialEq + Pod + Zeroable> Pod for Node<V> {}
#[cfg(test)]
mod tests {
    use super::{HashSet, HashSetMut};

    /// Allocates a zeroed backing buffer for a `u64` set with room for `capacity` values.
    ///
    /// The storage is a `Vec<u64>` rather than a `[u8; N]` so the bytes are 8-byte aligned.
    /// `bytemuck` panics when casting misaligned bytes to `Allocator` or `Node<u64>`, and a
    /// byte array only guarantees 1-byte alignment. This also keeps large tests off the stack.
    fn buffer(capacity: usize) -> Vec<u64> {
        let len = HashSetMut::<u64>::data_len(capacity);
        let word = std::mem::size_of::<u64>();
        assert_eq!(len % word, 0, "data_len must be a whole number of words");
        vec![0u64; len / word]
    }

    /// Initializes a set over `data` and inserts `1..=count`, asserting every insert succeeds.
    fn filled(data: &mut [u64], capacity: usize, count: usize) -> HashSetMut<'_, u64> {
        let mut set = HashSetMut::<u64>::from_bytes_mut(bytemuck::cast_slice_mut(data));
        set.initialize(capacity as u32);
        assert_eq!(set.capacity(), capacity);
        for value in 1..=count as u64 {
            assert!(set.insert(value));
        }
        set
    }

    #[test]
    fn test_insert() {
        const CAPACITY: usize = 10;
        let mut data = buffer(CAPACITY);
        let set = filled(&mut data, CAPACITY, CAPACITY);
        assert_eq!(set.size(), CAPACITY);
        for value in 1..=CAPACITY as u64 {
            assert!(set.contains(&value));
        }
    }

    #[test]
    fn test_large_insert() {
        const CAPACITY: usize = 10_000;
        let mut data = buffer(CAPACITY);
        let set = filled(&mut data, CAPACITY, CAPACITY);
        assert_eq!(set.size(), CAPACITY);
        for value in 1..=CAPACITY as u64 {
            assert!(set.contains(&value));
        }
    }

    #[test]
    fn test_large_remove() {
        const CAPACITY: usize = 10_000;
        let mut data = buffer(CAPACITY);
        let mut set = filled(&mut data, CAPACITY, CAPACITY);
        assert_eq!(set.size(), CAPACITY);
        for value in 1..=CAPACITY as u64 {
            assert!(set.remove(&value));
        }
        assert_eq!(set.size(), 0);
    }

    #[test]
    fn test_large_remove_insert() {
        const CAPACITY: usize = 10_000;
        let mut data = buffer(CAPACITY);
        let mut set = filled(&mut data, CAPACITY, CAPACITY);
        assert_eq!(set.size(), CAPACITY);
        for value in 1..=CAPACITY as u64 {
            assert!(set.remove(&value));
        }
        assert_eq!(set.size(), 0);
        for value in 1..=CAPACITY as u64 {
            assert!(set.insert(value));
        }
        assert_eq!(set.size(), CAPACITY);
        for value in 1..=CAPACITY as u64 {
            assert!(set.contains(&value));
        }
    }

    #[test]
    fn test_insert_when_full() {
        const CAPACITY: usize = 10;
        let mut data = buffer(CAPACITY);
        let mut set = filled(&mut data, CAPACITY, CAPACITY);
        assert_eq!(set.size(), CAPACITY);
        assert!(set.is_full());
        // Already present.
        assert!(!set.insert(10));
        // Not present, but there is no room for it.
        assert!(!set.insert(100));
        assert!(set.remove(&1));
        assert!(!set.is_full());
        assert!(set.insert(11));
        assert!(set.is_full());
        assert!(!set.insert(20));
    }

    #[test]
    fn test_empty() {
        const CAPACITY: usize = 4;
        let mut data = buffer(CAPACITY);
        let mut set = filled(&mut data, CAPACITY, 0);
        assert!(set.is_empty());
        assert!(!set.contains(&1));
        assert!(!set.remove(&1));
        assert_eq!(set.size(), 0);
    }

    #[test]
    fn test_iter() {
        const CAPACITY: usize = 10;
        let mut data = buffer(CAPACITY);
        filled(&mut data, CAPACITY, CAPACITY);
        let set = HashSet::<u64>::from_bytes(bytemuck::cast_slice(data.as_slice()));
        let mut values: Vec<u64> = set.iter().copied().collect();
        values.sort_unstable();
        assert_eq!(values, (1..=CAPACITY as u64).collect::<Vec<_>>());
    }
}