use super::SetError;
use super::SetHasher as Hasher;
use core::hint::unreachable_unchecked;
use core::{fmt, mem, ptr};
use hashbrown::HashTable;
use libdd_alloc::{Allocator, ChainAllocator, VirtualAllocator};
use std::ffi::c_void;
use std::hash::{BuildHasher, Hash};
/// Capacity requested for the hash table of a freshly created [`Set`]
/// (`Set::try_new` passes it to `Set::try_with_capacity`).
// NOTE(review): 14 is presumably chosen because it is the largest item count a
// 16-bucket hashbrown table holds at its 7/8 load factor — confirm.
pub const SET_MIN_CAPACITY: usize = 14;
/// Opaque handle to a value stored in a [`Set`].
///
/// A `SetId` is a thin, non-null pointer to the value inside the set's arena
/// (`#[repr(transparent)]` over `NonNull<T>`). Equality, hashing and `Debug`
/// operate on the pointer itself, never on the pointee. They are therefore
/// implemented by hand so that, unlike a `#[derive]`, they place no bounds
/// on `T`.
#[repr(transparent)]
pub struct SetId<T>(pub(crate) ptr::NonNull<T>);

impl<T> PartialEq for SetId<T> {
    #[inline]
    fn eq(&self, other: &Self) -> bool {
        self.0 == other.0
    }
}

impl<T> Eq for SetId<T> {}

impl<T> Hash for SetId<T> {
    // `Hasher` in this module names the set's `BuildHasher`, so the trait is
    // spelled out in full here.
    #[inline]
    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
        self.0.hash(state);
    }
}

impl<T> fmt::Debug for SetId<T> {
    // Same output as the previous `#[derive(Debug)]`: `SetId(0x…)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("SetId").field(&self.0).finish()
    }
}
impl<T> SetId<T> {
#[inline]
#[must_use]
pub fn cast<U>(self) -> SetId<U> {
SetId(self.0.cast())
}
pub fn into_raw(self) -> ptr::NonNull<c_void> {
self.0.cast()
}
pub unsafe fn from_raw(raw: ptr::NonNull<c_void>) -> Self {
Self(raw.cast::<T>())
}
}
// `Copy`/`Clone` are written by hand rather than derived: a derive would
// require `T: Copy`/`T: Clone`, but a `SetId` is just a pointer and is
// trivially copyable for every `T`.
impl<T> Copy for SetId<T> {}

impl<T> Clone for SetId<T> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}
/// Formats the set as its hash table of element pointers; the arena that
/// backs the elements is not included in the output.
impl<T: Hash + Eq + 'static> fmt::Debug for Set<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Set");
        out.field("table", &self.table);
        out.finish()
    }
}
/// A deduplicating set whose elements are stored in an arena.
///
/// Each distinct value is moved into `arena` once; `table` indexes those
/// arena slots by the value's hash, so inserting an equal value again yields
/// the same [`SetId`].
pub struct Set<T: Hash + Eq + 'static> {
/// Backing storage for the elements themselves. Values are written here by
/// `allocate_one`; their destructors are run by `Drop for Set`.
pub(crate) arena: ChainAllocator<VirtualAllocator>,
/// Hash index over the stored elements; each entry is a pointer into
/// `arena`.
pub(crate) table: HashTable<ptr::NonNull<T>>,
}
impl<T: Eq + Hash + 'static> Set<T> {
    /// Size hint (1 MiB) handed to the arena's `ChainAllocator` on creation.
    const SIZE_HINT: usize = 1024 * 1024;

    /// Creates an empty set whose table can hold at least
    /// [`SET_MIN_CAPACITY`] elements without growing.
    ///
    /// # Errors
    ///
    /// Returns an error if the initial table reservation fails.
    pub fn try_new() -> Result<Self, SetError> {
        Self::try_with_capacity(SET_MIN_CAPACITY)
    }

    /// Moves `value` into the arena and returns a pointer to it.
    ///
    /// The hash table is not touched; the caller is responsible for indexing
    /// the returned pointer (otherwise `Drop` will never drop the value).
    ///
    /// # Errors
    ///
    /// Returns an error if the arena cannot satisfy the allocation; `value`
    /// is dropped in that case.
    #[inline]
    pub(crate) fn allocate_one(&mut self, value: T) -> Result<ptr::NonNull<T>, SetError> {
        let layout = core::alloc::Layout::new::<T>();
        // Per the `Allocator` contract the returned block is at least
        // `layout.size()` bytes and aligned to `layout.align()`, so its start
        // can be viewed directly as a `T` slot — no raw-pointer round trip or
        // `new_unchecked` needed.
        let slot = self.arena.allocate(layout)?.cast::<T>();
        // SAFETY: `slot` is freshly allocated, suitably sized and aligned for
        // `T`, uninitialized, and not referenced by anything else yet.
        unsafe { slot.as_ptr().write(value) };
        Ok(slot)
    }

    /// Inserts `value` if no equal value is present and returns its id.
    ///
    /// If an equal value is already stored, `value` is dropped and the id of
    /// the existing value is returned.
    ///
    /// # Errors
    ///
    /// Returns an error if the table cannot grow or the arena cannot
    /// allocate room for the new value.
    pub fn try_insert(&mut self, value: T) -> Result<SetId<T>, SetError> {
        let hash = Hasher::default().hash_one(&value);
        // SAFETY: `hash` was computed from `value` with the set's hasher.
        if let Some(existing) = unsafe { self.find_with_hash(hash, &value) } {
            return Ok(existing);
        }
        // SAFETY: `hash` matches `value`, and the lookup above established
        // that no equal value is stored yet.
        unsafe { self.insert_unique_uncontended_with_hash(hash, value) }
    }

    /// Number of distinct values stored in the set.
    pub fn len(&self) -> usize {
        self.table.len()
    }

    /// Returns `true` if the set holds no values.
    pub fn is_empty(&self) -> bool {
        self.table.is_empty()
    }

    /// Number of values the hash table can hold before it has to grow.
    pub fn capacity(&self) -> usize {
        self.table.capacity()
    }

    /// Returns the id of the stored value equal to `value`, if there is one.
    pub fn find(&self, value: &T) -> Option<SetId<T>> {
        let hash = Hasher::default().hash_one(value);
        // SAFETY: `hash` was computed from `value` with the set's hasher.
        unsafe { self.find_with_hash(hash, value) }
    }

    /// Returns a reference to the value identified by `id`.
    ///
    /// # Safety
    ///
    /// `id` must have been produced by this same set (via `try_insert` or
    /// `find`). An id from a different set, or from a set that has since
    /// been dropped, dangles.
    pub unsafe fn get(&self, id: SetId<T>) -> &T {
        // SAFETY: per the caller's contract `id` points at an initialized `T`
        // in this set's arena, which lives at least as long as `&self`.
        unsafe { id.0.as_ref() }
    }
}
/// Runs the destructor of every stored element. The memory holding them is
/// left to the `arena` field, which is dropped after this method returns.
impl<T: Hash + Eq + 'static> Drop for Set<T> {
    fn drop(&mut self) {
        // Types without drop glue need no per-element work at all.
        if !mem::needs_drop::<T>() {
            return;
        }
        self.table.iter().for_each(|slot| {
            // SAFETY: every pointer in `table` was produced by `allocate_one`,
            // which wrote an initialized `T` there. Each one is dropped
            // exactly once, here, and the table is not read again afterwards.
            unsafe { ptr::drop_in_place(slot.as_ptr()) }
        });
    }
}
impl<T: Hash + Eq + 'static> Set<T> {
fn try_with_capacity(capacity: usize) -> Result<Self, SetError> {
let arena = ChainAllocator::new_in(Self::SIZE_HINT, VirtualAllocator {});
let mut table = HashTable::new();
table.try_reserve(capacity, |_| unsafe { unreachable_unchecked() })?;
Ok(Self { arena, table })
}
unsafe fn find_with_hash(&self, hash: u64, key: &T) -> Option<SetId<T>> {
let found = self
.table
.find(hash, |nn| unsafe { nn.as_ref() == key })?;
Some(SetId(*found))
}
unsafe fn insert_unique_uncontended_with_hash(
&mut self,
hash: u64,
value: T,
) -> Result<SetId<T>, SetError> {
self.table
.try_reserve(1, |nnv| Hasher::default().hash_one(unsafe { nnv.as_ref() }))?;
let nn = self.allocate_one(value)?;
self.table
.insert_unique(hash, nn, |_| unsafe { unreachable_unchecked() });
Ok(SetId(nn))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::collections::HashSet as StdHashSet;
    use std::sync::{Arc, Weak};

    proptest! {
        #![proptest_config(ProptestConfig {
            cases: if cfg!(miri) { 4 } else { 64 },
            .. ProptestConfig::default()
        })]

        // Inserting a sequence of values must behave like `std`'s `HashSet`:
        // the same distinct count, every inserted value findable, stable ids
        // for duplicates, and no false positives on lookup.
        #[test]
        fn proptest_matches_std_hashset(
            values in proptest::collection::vec(any::<u64>(), 0..if cfg!(miri) { 32 } else { 512 }),
            probe in any::<u64>()
        ) {
            let mut set = Set::<u64>::try_new().unwrap();
            let mut shadow = StdHashSet::<u64>::new();
            for &v in &values {
                shadow.insert(v);
                let id = set.try_insert(v).unwrap();
                // The id handed back by `try_insert` (also when `v` was
                // already present) must be the one `find` resolves to.
                prop_assert_eq!(set.find(&v), Some(id));
            }
            prop_assert_eq!(set.len(), shadow.len());
            prop_assert_eq!(set.is_empty(), shadow.is_empty());
            for &v in &shadow {
                let id = set.find(&v).unwrap();
                // SAFETY: `id` was just returned by `find` on this same set.
                let fetched = unsafe { set.get(id) };
                prop_assert_eq!(*fetched, v);
            }
            // Lookups must not report values that were never inserted.
            prop_assert_eq!(set.find(&probe).is_some(), shadow.contains(&probe));
        }
    }

    #[test]
    fn set_drops_elements_on_drop() {
        let mut set = Set::<Arc<u64>>::try_new().unwrap();
        let mut weaks: Vec<Weak<u64>> = Vec::new();
        let total = if cfg!(miri) { 8 } else { 64 };
        for i in 0..total {
            let arc = Arc::new(i as u64);
            weaks.push(Arc::downgrade(&arc));
            let _ = set.try_insert(arc).unwrap();
        }

        // The set holds the only strong references, so everything must still
        // be alive here; otherwise the post-drop check below proves nothing.
        for (idx, w) in weaks.iter().enumerate() {
            assert!(w.upgrade().is_some(), "weak at {idx} dropped too early");
        }

        // A duplicate insert keeps the stored value and drops the new one.
        let dup = Arc::new(0u64);
        let dup_weak = Arc::downgrade(&dup);
        let _ = set.try_insert(dup).unwrap();
        assert!(dup_weak.upgrade().is_none(), "duplicate not dropped on insert");
        assert_eq!(set.len(), total);

        drop(set);
        for (idx, w) in weaks.iter().enumerate() {
            assert!(w.upgrade().is_none(), "weak at {idx} still alive");
        }
    }
}