use ::alloc::alloc::{self, alloc_zeroed, handle_alloc_error};
use ::alloc::boxed::Box;
use ::alloc::vec::{self, Vec};
use core::cmp;
use core::fmt::{self, Debug, Formatter};
use core::hint::unreachable_unchecked;
use core::iter::FusedIterator;
use core::mem::size_of;
use core::num::NonZeroUsize;
use core::panic::{RefUnwindSafe, UnwindSafe};
use core::ptr;
use core::slice;
use core::sync::atomic;
use crate::loom::atomic::AtomicPtr;
use crate::loom::AtomicMut as _;
/// A fixed set of lazily-allocated, power-of-two-sized storage buckets.
///
/// Each slot holds a raw pointer to a heap-allocated `[T]` slice (null until
/// the bucket is allocated). Buckets are published atomically with a
/// compare-exchange so that `&self` methods (`get_or_alloc`, `reserve`) can
/// allocate concurrently.
pub struct Buckets<T, const BUCKETS: usize> {
    // Null means "not yet allocated". A non-null pointer was produced by
    // `Box::into_raw` on a slice whose length is fixed by the bucket's
    // position (see `BucketIndex::len`).
    buckets: [AtomicPtr<T>; BUCKETS],
}
// SAFETY: `Buckets` owns its `T` values (heap slices reachable through the
// `AtomicPtr`s), so sending the whole structure to another thread moves those
// values with it; this is sound exactly when `T: Send`.
unsafe impl<T: Send, const BUCKETS: usize> Send for Buckets<T, BUCKETS> {}
// SAFETY: sharing `&Buckets` across threads hands out `&T` (requiring
// `T: Sync`), but it also lets a *borrowing* thread create values via
// `get_or_alloc`/`reserve` that the owning thread later drops — that moves
// `T` between threads, so `T: Send` is required as well. The previous
// `T: Sync`-only bound was unsound for `!Send` element types (compare
// `std::sync::RwLock<T>: Sync` requiring `T: Send + Sync`).
unsafe impl<T: Send + Sync, const BUCKETS: usize> Sync for Buckets<T, BUCKETS> {}
// `Buckets` stores raw pointers (inside `AtomicPtr`), which do not implement
// the unwind-safety auto traits on their own; forward unwind safety to the
// element type, the same way an owning container like `Box<T>` would.
impl<T: UnwindSafe, const BUCKETS: usize> UnwindSafe for Buckets<T, BUCKETS> {}
impl<T: RefUnwindSafe, const BUCKETS: usize> RefUnwindSafe for Buckets<T, BUCKETS> {}
impl<T, const BUCKETS: usize> Buckets<T, BUCKETS> {
    // Interior-mutable constant used only as the repeat element for the array
    // initializer in `new`; the clippy lint is a false positive for that use.
    #[cfg(not(loom))]
    #[allow(clippy::declare_interior_mutable_const)]
    const NULL_PTR: AtomicPtr<T> = AtomicPtr::new(ptr::null_mut());
    /// Creates an empty collection: every bucket slot is null (unallocated).
    #[cfg(not(loom))]
    pub const fn new() -> Self {
        Self {
            buckets: [Self::NULL_PTR; BUCKETS],
        }
    }
    /// Loom's `AtomicPtr::new` is not `const`, so under loom the array is
    /// built at runtime instead.
    #[cfg(loom)]
    pub fn new() -> Self {
        Self {
            buckets: [(); BUCKETS].map(|_| AtomicPtr::new(ptr::null_mut())),
        }
    }
    /// Shared access to bucket slot `i`.
    fn bucket(&self, i: BucketIndex<BUCKETS>) -> &AtomicPtr<T> {
        // SAFETY: `BucketIndex` values are only constructed with
        // `i.0 < BUCKETS` (see `BucketCursor::advance` and
        // `Index::location`), so the index is in bounds.
        unsafe { self.buckets.get_unchecked(i.0) }
    }
    /// Exclusive access to bucket slot `i`.
    fn bucket_mut(&mut self, i: BucketIndex<BUCKETS>) -> &mut AtomicPtr<T> {
        // SAFETY: same `i.0 < BUCKETS` invariant as in `bucket`.
        unsafe { self.buckets.get_unchecked_mut(i.0) }
    }
    /// Detaches bucket `i` and returns its boxed slice, leaving the slot
    /// null; returns `None` if the bucket was never allocated.
    fn take_bucket(&mut self, i: BucketIndex<BUCKETS>) -> Option<Box<[T]>> {
        let bucket = self.bucket_mut(i);
        let ptr = bucket.read_mut();
        bucket.write_mut(ptr::null_mut());
        // `?` on `NonNull::new` is a terse null check / early return.
        ptr::NonNull::new(ptr)?;
        // SAFETY: a non-null slot always holds a pointer obtained from
        // `Box::into_raw` on a slice of exactly `i.len()` elements (see the
        // `allocate_slice` call sites), so the box may be reconstituted with
        // that length.
        Some(unsafe { Box::from_raw(ptr::slice_from_raw_parts_mut(ptr, i.len().get())) })
    }
    /// Returns a reference to the entry at `index`, or `None` if its bucket
    /// has not been allocated.
    pub fn get(&self, index: Index<BUCKETS>) -> Option<&T> {
        let location = index.location();
        let bucket = self.bucket(location.bucket);
        // Acquire pairs with the Release CAS in `allocate_race` /
        // `allocate_race_and_get`, making the slice contents visible.
        let ptr = bucket.load(atomic::Ordering::Acquire);
        if ptr.is_null() {
            return None;
        }
        // SAFETY: `location.entry` is always less than the bucket's length,
        // so the offset stays within the allocation.
        Some(unsafe { &*ptr.add(location.entry) })
    }
    /// Like `get`, but with exclusive access (no atomic synchronization
    /// needed, hence the plain `read_mut`).
    pub fn get_mut(&mut self, index: Index<BUCKETS>) -> Option<&mut T> {
        let location = index.location();
        let bucket = self.bucket_mut(location.bucket);
        let ptr = bucket.read_mut();
        if ptr.is_null() {
            return None;
        }
        // SAFETY: in-bounds offset as in `get`; `&mut self` guarantees
        // exclusivity.
        Some(unsafe { &mut *ptr.add(location.entry) })
    }
    /// Returns the entry at `index` without checking that its bucket exists.
    ///
    /// # Safety
    ///
    /// The bucket containing `index` must already be allocated, and that
    /// allocation must be visible to this thread (the load is only Relaxed,
    /// so the caller is responsible for the happens-before edge).
    pub unsafe fn get_unchecked(&self, index: Index<BUCKETS>) -> &T {
        let location = index.location();
        let bucket = self.bucket(location.bucket);
        let ptr = bucket.load(atomic::Ordering::Relaxed);
        // SAFETY: per this function's contract, `ptr` is non-null and the
        // offset is in bounds.
        unsafe { &*ptr.add(location.entry) }
    }
    /// Exclusive variant of `get_unchecked`.
    ///
    /// # Safety
    ///
    /// The bucket containing `index` must already be allocated.
    pub unsafe fn get_unchecked_mut(&mut self, index: Index<BUCKETS>) -> &mut T {
        let location = index.location();
        let bucket = self.bucket_mut(location.bucket);
        let ptr = bucket.read_mut();
        // SAFETY: per this function's contract, `ptr` is non-null and the
        // offset is in bounds.
        unsafe { &mut *ptr.add(location.entry) }
    }
    /// Returns the entry at `index`, allocating its bucket on first use.
    /// Entries come into existence zeroed / default-initialized (see
    /// `allocate_slice`).
    pub fn get_or_alloc(&self, index: Index<BUCKETS>) -> &T
    where
        T: MaybeZeroable,
    {
        let location = index.location();
        // At exactly the 7/8 point of the current bucket (`len - len/8`),
        // eagerly allocate the *next* bucket so callers crossing the bucket
        // boundary later rarely take the allocation path themselves.
        if location.entry == (location.bucket_len.get() - (location.bucket_len.get() >> 3)) {
            self.alloc_bucket_after(index);
        }
        let bucket = self.bucket(location.bucket);
        let mut ptr = bucket.load(atomic::Ordering::Acquire) as *const T;
        if ptr.is_null() {
            ptr = allocate_race_and_get(bucket, location.bucket_len);
        }
        // SAFETY: `ptr` is non-null (just allocated if necessary) and the
        // offset is in bounds.
        unsafe { &*ptr.add(location.entry) }
    }
    // Pre-allocates the bucket after the one holding `index`, if any remain.
    // Cold/outlined to keep `get_or_alloc`'s hot path small.
    #[cold]
    #[inline(never)]
    fn alloc_bucket_after(&self, index: Index<BUCKETS>)
    where
        T: MaybeZeroable,
    {
        if let Some(new_index) = index.after_bucket().advance() {
            allocate_race(self.bucket(new_index), new_index.len());
        }
    }
    /// Exclusive variant of `get_or_alloc`: no CAS race and no prefetch of
    /// the following bucket.
    pub fn get_or_alloc_mut(&mut self, index: Index<BUCKETS>) -> &mut T
    where
        T: MaybeZeroable,
    {
        let location = index.location();
        let bucket = self.bucket_mut(location.bucket);
        let mut ptr = bucket.read_mut();
        if ptr.is_null() {
            ptr = Box::into_raw(allocate_slice::<T>(location.bucket_len)).cast::<T>();
            bucket.write_mut(ptr);
        }
        // SAFETY: `ptr` is non-null and the offset is in bounds.
        unsafe { &mut *ptr.add(location.entry) }
    }
    /// Allocates every bucket up to and including the one holding `index`,
    /// walking downward from `index`'s bucket.
    pub fn reserve(&self, index: Index<BUCKETS>)
    where
        T: MaybeZeroable,
    {
        let mut cursor = index.after_bucket();
        while let Some(index) = cursor.retreat() {
            let bucket = self.bucket(index);
            // Stop at the first bucket that is already allocated.
            // NOTE(review): this break assumes buckets fill bottom-up, but
            // `get_or_alloc` can allocate a bucket without its predecessors,
            // so the early break could skip holes — confirm intended.
            if !bucket.load(atomic::Ordering::Relaxed).is_null() {
                break;
            }
            allocate_race(bucket, index.len());
        }
    }
    /// Exclusive variant of `reserve` (plain writes, no CAS).
    pub fn reserve_mut(&mut self, index: Index<BUCKETS>)
    where
        T: MaybeZeroable,
    {
        let mut cursor = index.after_bucket();
        while let Some(index) = cursor.retreat() {
            let bucket = self.bucket_mut(index);
            if !bucket.read_mut().is_null() {
                break;
            }
            let ptr = Box::into_raw(allocate_slice::<T>(index.len()));
            bucket.write_mut(ptr.cast::<T>());
        }
    }
    /// Frees every bucket whose entries all have index `>= n`. A bucket that
    /// `n` points into the middle of is kept intact (pinned by the
    /// `truncate_exact` test).
    pub fn truncate(&mut self, n: Index<BUCKETS>) {
        let mut cursor = n.after_lower_buckets();
        while let Some(bucket) = cursor.advance() {
            self.take_bucket(bucket);
        }
    }
    /// Iterates over all entries in allocated buckets, in index order.
    pub fn iter(&self) -> Iter<'_, T, BUCKETS> {
        self.into_iter()
    }
    /// Mutable counterpart of `iter`.
    pub fn iter_mut(&mut self) -> IterMut<'_, T, BUCKETS> {
        self.into_iter()
    }
}
impl<T, const BUCKETS: usize> Drop for Buckets<T, BUCKETS> {
    fn drop(&mut self) {
        // Truncating at index 0 deallocates every bucket (and drops its
        // elements).
        self.truncate(Index::new(0).unwrap());
    }
}
impl<T, const BUCKETS: usize> Default for Buckets<T, BUCKETS> {
fn default() -> Self {
Self::new()
}
}
impl<T: Debug, const BUCKETS: usize> Debug for Buckets<T, BUCKETS> {
    /// Formats the populated entries as a map from index to value.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut map = f.debug_map();
        for (index, value) in self.iter() {
            map.entry(&index, value);
        }
        map.finish()
    }
}
/// Types whose buckets may be allocated with `alloc_zeroed`.
///
/// # Safety
///
/// If `zeroable()` returns `true`, the all-zero bit pattern must be a valid,
/// initialized `T`: `allocate_slice` hands zeroed memory out as a `[T]`
/// directly. Since non-zeroable buckets are filled with `T::default()`
/// instead, the zero value should coincide with `T::default()` for the two
/// allocation paths to be observably equivalent.
pub unsafe trait MaybeZeroable: Default {
    /// Whether zeroed memory is a valid `[T]`. Queried at runtime by
    /// `allocate_slice` to pick the allocation strategy.
    fn zeroable() -> bool;
}
// SAFETY: all-zero bytes are the valid value `0u8`, which is `u8::default()`.
unsafe impl MaybeZeroable for u8 {
    fn zeroable() -> bool {
        true
    }
}
// SAFETY: all-zero bytes are the valid value `0u16`, which is
// `u16::default()`.
unsafe impl MaybeZeroable for u16 {
    fn zeroable() -> bool {
        true
    }
}
/// Allocates a bucket of `len` entries and races to install it in `bucket`.
///
/// Returns the winning pointer: ours on success, or the pointer some other
/// thread installed first (in which case our allocation is freed).
#[cold]
#[inline(never)]
#[must_use]
fn allocate_race_and_get<T: MaybeZeroable>(bucket: &AtomicPtr<T>, len: NonZeroUsize) -> *const T {
    let ptr = Box::into_raw(allocate_slice::<T>(len));
    match bucket.compare_exchange(
        ptr::null_mut(),
        ptr.cast::<T>(),
        // Release publishes the initialized slice to Acquire loaders;
        // Acquire on failure synchronizes with the winner's Release so the
        // returned `new_ptr` is safe to dereference.
        atomic::Ordering::Release,
        atomic::Ordering::Acquire,
    ) {
        Ok(_) => ptr.cast::<T>(),
        Err(new_ptr) => {
            // Lost the race: another thread installed a bucket first.
            // SAFETY: `ptr` came from `Box::into_raw` above and was never
            // published (the CAS failed), so we still own the allocation.
            drop(unsafe { Box::from_raw(ptr) });
            new_ptr
        }
    }
}
/// Allocates a bucket of `len` entries and races to install it in `bucket`,
/// discarding our allocation if another thread won. Unlike
/// `allocate_race_and_get`, the resulting pointer is not used afterwards, so
/// the failure ordering can be Relaxed.
fn allocate_race<T: MaybeZeroable>(bucket: &AtomicPtr<T>, len: NonZeroUsize) {
    let ptr = Box::into_raw(allocate_slice::<T>(len));
    match bucket.compare_exchange(
        ptr::null_mut(),
        ptr.cast::<T>(),
        // Release publishes the initialized slice contents to Acquire loads.
        atomic::Ordering::Release,
        atomic::Ordering::Relaxed,
    ) {
        Ok(_) => {}
        // SAFETY: the CAS failed, so `ptr` was never published and we still
        // own the allocation from `Box::into_raw` above.
        Err(_) => drop(unsafe { Box::from_raw(ptr) }),
    }
}
/// Allocates a boxed slice of `len` logically default-initialized `T`s.
fn allocate_slice<T: MaybeZeroable>(len: NonZeroUsize) -> Box<[T]> {
    if size_of::<T>() == 0 {
        // Zero-sized types need no storage: the empty box's dangling-but-
        // aligned pointer is valid for any number of ZST elements, and
        // `take_bucket` reconstitutes the slice with the bucket's nominal
        // length (the `reserve` test relies on exactly that many drops).
        return Box::new([]);
    }
    if T::zeroable() {
        // Fast path: take zeroed memory straight from the allocator; the
        // `MaybeZeroable` contract lets us treat it as initialized `T`s.
        let layout = alloc::Layout::array::<T>(len.get()).unwrap();
        let ptr = unsafe { alloc_zeroed(layout) }.cast::<T>();
        if ptr.is_null() {
            handle_alloc_error(layout);
        }
        // SAFETY: `ptr` is a live allocation of `len` `T`s, all zeroed,
        // which `T::zeroable()` promises is a valid initialized state.
        unsafe { Box::from_raw(ptr::slice_from_raw_parts_mut(ptr, len.get())) }
    } else {
        // Fallback: fill with `T::default()` via a `Vec`;
        // `into_boxed_slice` shrinks capacity to exactly `len`.
        let mut vec = Vec::new();
        vec.resize_with(len.get(), T::default);
        vec.into_boxed_slice()
    }
}
/// An index into a `Buckets<T, BUCKETS>` collection.
///
/// Stored as `index + SKIPPED_ENTRIES + 1` so that the `NonZeroUsize` niche
/// makes `Option<Index>` pointer-sized, and so that `ilog2` of the raw value
/// directly names the bucket (see `location`).
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct Index<const BUCKETS: usize> {
    inner: NonZeroUsize,
}
// Number of small power-of-two buckets omitted entirely: bucket 0 starts at
// capacity `2^SKIPPED_BUCKETS` (32) instead of 1.
const SKIPPED_BUCKETS: usize = 5;
// Total entries the skipped buckets would have held: `2^5 - 1 = 31`.
const SKIPPED_ENTRIES: usize = 2_usize.pow(SKIPPED_BUCKETS as u32) - 1;
impl<const BUCKETS: usize> Index<BUCKETS> {
    // `2^(BUCKETS + SKIPPED_BUCKETS) - 1`: the largest raw value. Written as
    // a summation loop because a single `1 << (BUCKETS + SKIPPED_BUCKETS)`
    // would overflow when the exponent equals `usize::BITS`.
    const ENTRIES_WITH_SKIPPED: usize = {
        let mut total = 0;
        let mut i = 0;
        while i < BUCKETS + SKIPPED_BUCKETS {
            total += 2_usize.pow(i as u32);
            i += 1;
        }
        total
    };
    // Number of addressable entries, i.e. the exclusive upper bound checked
    // by `new`.
    const ENTRIES: usize = Self::ENTRIES_WITH_SKIPPED - SKIPPED_ENTRIES;
    /// Creates an index, returning `None` when `i` is out of range.
    pub const fn new(i: usize) -> Option<Self> {
        if i < Self::ENTRIES {
            // SAFETY: just checked `i < Self::ENTRIES`.
            Some(unsafe { Self::new_unchecked(i) })
        } else {
            None
        }
    }
    /// Creates an index without the range check.
    ///
    /// # Safety
    ///
    /// `i` must be less than the number of addressable entries
    /// (`Self::ENTRIES`).
    pub const unsafe fn new_unchecked(i: usize) -> Self {
        // Let the optimizer exploit the caller's contract.
        if i >= Self::ENTRIES {
            unsafe { unreachable_unchecked() };
        }
        // Offset so raw 0 and the skipped range never occur; cannot overflow
        // given `i < ENTRIES`, hence the unreachable fallbacks (which keep
        // the arithmetic panic-free in `const` evaluation).
        let Some(inner) = i.checked_add(SKIPPED_ENTRIES + 1) else {
            unreachable!()
        };
        let Some(inner) = NonZeroUsize::new(inner) else {
            unreachable!()
        };
        Self { inner }
    }
    /// Creates an index, clamping out-of-range `i` to the maximum.
    pub fn new_saturating(i: usize) -> Self {
        Self::new(cmp::min(i, Self::ENTRIES - 1)).unwrap()
    }
    /// The logical (zero-based) index this value represents.
    pub const fn get(self) -> usize {
        self.inner.get() - (SKIPPED_ENTRIES + 1)
    }
    /// The raw (offset) representation; always in
    /// `SKIPPED_ENTRIES + 1 ..= ENTRIES_WITH_SKIPPED`.
    pub const fn into_raw(self) -> NonZeroUsize {
        debug_assert!(SKIPPED_ENTRIES < self.inner.get());
        debug_assert!(self.inner.get() <= Self::ENTRIES_WITH_SKIPPED);
        self.inner
    }
    /// Rebuilds an index from a raw value without any checks.
    ///
    /// # Safety
    ///
    /// `inner` must be a value previously produced by `into_raw`, i.e.
    /// `SKIPPED_ENTRIES < inner <= ENTRIES_WITH_SKIPPED` (only
    /// debug-asserted here).
    pub const unsafe fn from_raw_unchecked(inner: usize) -> Self {
        debug_assert!(SKIPPED_ENTRIES < inner);
        debug_assert!(inner <= Self::ENTRIES_WITH_SKIPPED);
        // SAFETY: the contract guarantees `inner > SKIPPED_ENTRIES`, so it
        // is non-zero.
        let inner = unsafe { NonZeroUsize::new_unchecked(inner) };
        Self { inner }
    }
    /// Rebuilds an index, checking only the upper bound.
    ///
    /// # Safety
    ///
    /// The caller must still guarantee `SKIPPED_ENTRIES < inner`.
    pub const unsafe fn from_raw_checked_above(inner: usize) -> Option<Self> {
        if inner <= Self::ENTRIES_WITH_SKIPPED {
            // SAFETY: upper bound just checked; the lower bound is the
            // caller's obligation.
            Some(unsafe { Self::from_raw_unchecked(inner) })
        } else {
            None
        }
    }
    /// Fully-checked raw conversion.
    pub const fn from_raw_checked(inner: usize) -> Option<Self> {
        if SKIPPED_ENTRIES < inner {
            // SAFETY: lower bound just checked; the callee checks the upper.
            unsafe { Self::from_raw_checked_above(inner) }
        } else {
            None
        }
    }
    // Decodes the raw value into bucket number, bucket capacity and offset.
    // Because raw values start at `2^SKIPPED_BUCKETS`, `ilog2(raw)` names
    // the power-of-two bucket, `2^ilog2(raw)` is simultaneously that
    // bucket's capacity and its first raw value, and the remainder is the
    // offset within the bucket.
    fn location(self) -> Location<BUCKETS> {
        let b = self.inner.ilog2();
        let b_usize = usize::try_from(b).unwrap();
        let bucket_len = NonZeroUsize::new(1_usize.checked_shl(b).unwrap()).unwrap();
        let entry = self.inner.get() - bucket_len.get();
        let bucket = BucketIndex(b_usize - SKIPPED_BUCKETS);
        Location {
            bucket,
            bucket_len,
            entry,
        }
    }
    /// Whether this index is the first entry of its bucket.
    pub fn is_first_in_bucket(self) -> bool {
        self.location().entry == 0
    }
}
impl<const BUCKETS: usize> Debug for Index<BUCKETS> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
<usize as Debug>::fmt(&self.into_raw().get(), f)
}
}
// The decoded position of an `Index`: which bucket, that bucket's capacity,
// and the offset within it (see `Index::location`).
#[derive(Clone)]
struct Location<const BUCKETS: usize> {
    // Bucket holding the entry (zero-based, `< BUCKETS`).
    bucket: BucketIndex<BUCKETS>,
    // Capacity of that bucket: `2^(bucket + SKIPPED_BUCKETS)`.
    bucket_len: NonZeroUsize,
    // Offset within the bucket; always `< bucket_len`.
    entry: usize,
}
// A bucket number; invariant: the wrapped value is `< BUCKETS`.
#[derive(Clone, Copy)]
struct BucketIndex<const BUCKETS: usize>(usize);
impl<const BUCKETS: usize> BucketIndex<BUCKETS> {
    // The first `Index` stored in this bucket.
    fn first(self) -> Index<BUCKETS> {
        // NOTE(review): this is a *safe* fn that hits
        // `unreachable_unchecked` (UB) if the `< BUCKETS` invariant is ever
        // violated; consider making it `unsafe fn` or downgrading to
        // `unreachable!()`.
        if self.0 >= BUCKETS {
            unsafe { unreachable_unchecked() };
        }
        let b = self.0.checked_add(SKIPPED_BUCKETS).unwrap();
        // The first raw value in power-of-two bucket `b` is `2^b`.
        let inner = 1_usize.checked_shl(b.try_into().unwrap()).unwrap();
        Index::from_raw_checked(inner).unwrap()
    }
    // Number of entries this bucket holds — numerically equal to the raw
    // value of its first index (both are `2^(self.0 + SKIPPED_BUCKETS)`).
    fn len(self) -> NonZeroUsize {
        self.first().into_raw()
    }
}
pub const fn buckets_for_index_bits(bits: u32) -> usize {
assert!(bits <= usize::BITS);
(bits as usize).saturating_sub(SKIPPED_BUCKETS)
}
impl<const BUCKETS: usize> Index<BUCKETS> {
    // Cursor positioned just past the bucket containing `self`.
    fn after_bucket(self) -> BucketCursor<BUCKETS> {
        let location = self.location();
        BucketCursor(location.bucket.0 + 1)
    }
    // Cursor positioned at the first bucket all of whose entries are
    // `>= self`: `self`'s own bucket when `self` is its first entry,
    // otherwise the next bucket. Used by `truncate`; the `truncate_exact`
    // test pins this behavior.
    fn after_lower_buckets(self) -> BucketCursor<BUCKETS> {
        // `ilog2(raw - 1) + 1` rounds the raw value up to the next
        // power-of-two boundary before converting to a bucket number.
        // `checked_sub` cannot fail: raw values are always non-zero.
        let inner_minus_one = self.into_raw().get().checked_sub(1).unwrap();
        let bucket = usize::try_from(inner_minus_one.ilog2()).unwrap() + 1 - SKIPPED_BUCKETS;
        BucketCursor(bucket)
    }
}
// A cursor over bucket numbers in `0..BUCKETS`; the wrapped value is the
// *next* position and may legitimately equal `BUCKETS` (exhausted).
#[derive(Clone, Default)]
struct BucketCursor<const BUCKETS: usize>(usize);
impl<const BUCKETS: usize> BucketCursor<BUCKETS> {
    /// Yields the current bucket and steps forward, or `None` once the
    /// cursor has passed the last bucket.
    fn advance(&mut self) -> Option<BucketIndex<BUCKETS>> {
        let current = self.0;
        if current < BUCKETS {
            self.0 = current + 1;
            Some(BucketIndex(current))
        } else {
            None
        }
    }
    /// Steps backward and yields the bucket now under the cursor, or `None`
    /// when already at position zero.
    fn retreat(&mut self) -> Option<BucketIndex<BUCKETS>> {
        match self.0.checked_sub(1) {
            Some(previous) => {
                self.0 = previous;
                Some(BucketIndex(previous))
            }
            None => None,
        }
    }
}
impl<'a, T, const BUCKETS: usize> IntoIterator for &'a Buckets<T, BUCKETS> {
    type Item = (Index<BUCKETS>, &'a T);
    type IntoIter = Iter<'a, T, BUCKETS>;
    fn into_iter(self) -> Self::IntoIter {
        Iter {
            buckets: self,
            bucket: BucketCursor::default(),
            // Start with an empty slice iterator; `next` lazily loads the
            // first allocated bucket.
            iter: [].iter(),
            // Placeholder; overwritten before the first item is yielded.
            index: 0,
        }
    }
}
/// Borrowing iterator over all entries in allocated buckets, in index order.
#[must_use]
pub struct Iter<'a, T, const BUCKETS: usize> {
    buckets: &'a Buckets<T, BUCKETS>,
    // Next bucket to visit once `iter` is exhausted.
    bucket: BucketCursor<BUCKETS>,
    // Iterator over the current bucket's slice (empty between buckets).
    iter: slice::Iter<'a, T>,
    // Raw `Index` value of the item `iter` will yield next; only meaningful
    // while `iter` is non-empty.
    index: usize,
}
impl<'a, T, const BUCKETS: usize> Iterator for Iter<'a, T, BUCKETS> {
    type Item = (Index<BUCKETS>, &'a T);
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if let Some(item) = self.iter.next() {
                // SAFETY: `self.index` was seeded from
                // `bucket_index.first().into_raw()` below and incremented
                // once per yielded item, so it is a valid raw index within
                // the current bucket.
                let index = unsafe { Index::from_raw_unchecked(self.index) };
                self.index = self.index.wrapping_add(1);
                return Some((index, item));
            }
            // Current bucket exhausted: move on to the next allocated one.
            let bucket_index = self.bucket.advance()?;
            let bucket = self.buckets.bucket(bucket_index);
            // Acquire pairs with the Release publication in
            // `allocate_race*`, making the slice contents visible.
            let ptr = bucket.load(atomic::Ordering::Acquire);
            if !ptr.is_null() {
                // SAFETY: non-null buckets point at exactly
                // `bucket_index.len()` initialized elements.
                let slice = unsafe { slice::from_raw_parts(ptr, bucket_index.len().get()) };
                self.iter = slice.iter();
                self.index = bucket_index.first().into_raw().get();
            }
        }
    }
}
// `next` keeps returning `None` once the cursor passes the last bucket.
impl<T, const BUCKETS: usize> FusedIterator for Iter<'_, T, BUCKETS> {}
// Hand-written because `#[derive(Clone)]` would add an unnecessary
// `T: Clone` bound (only references and cursors are cloned here).
impl<T, const BUCKETS: usize> Clone for Iter<'_, T, BUCKETS> {
    fn clone(&self) -> Self {
        Self {
            buckets: self.buckets,
            bucket: self.bucket.clone(),
            iter: self.iter.clone(),
            index: self.index,
        }
    }
}
impl<'a, T, const BUCKETS: usize> IntoIterator for &'a mut Buckets<T, BUCKETS> {
    type Item = (Index<BUCKETS>, &'a mut T);
    type IntoIter = IterMut<'a, T, BUCKETS>;
    fn into_iter(self) -> Self::IntoIter {
        IterMut {
            buckets: self,
            bucket: BucketCursor::default(),
            // Empty until `next` loads the first allocated bucket.
            iter: [].iter_mut(),
            // Placeholder; overwritten before the first item is yielded.
            index: 0,
        }
    }
}
/// Mutably borrowing iterator over all entries in allocated buckets, in
/// index order.
#[must_use]
pub struct IterMut<'a, T, const BUCKETS: usize> {
    buckets: &'a mut Buckets<T, BUCKETS>,
    // Next bucket to visit once `iter` is exhausted.
    bucket: BucketCursor<BUCKETS>,
    // Iterator over the current bucket's slice (empty between buckets).
    iter: slice::IterMut<'a, T>,
    // Raw `Index` value of the item `iter` will yield next; only meaningful
    // while `iter` is non-empty.
    index: usize,
}
impl<'a, T, const BUCKETS: usize> Iterator for IterMut<'a, T, BUCKETS> {
    type Item = (Index<BUCKETS>, &'a mut T);
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if let Some(item) = self.iter.next() {
                // SAFETY: `self.index` was seeded from the bucket's first
                // raw index below and incremented once per yielded item.
                let index = unsafe { Index::from_raw_unchecked(self.index) };
                self.index = self.index.wrapping_add(1);
                return Some((index, item));
            }
            let bucket_index = self.bucket.advance()?;
            let bucket = self.buckets.bucket_mut(bucket_index);
            // Exclusive borrow of the collection: plain read, no atomics.
            let ptr = bucket.read_mut();
            if !ptr.is_null() {
                // SAFETY: non-null buckets point at exactly
                // `bucket_index.len()` initialized elements, and the `&mut`
                // borrow rules out aliasing.
                let slice = unsafe { slice::from_raw_parts_mut(ptr, bucket_index.len().get()) };
                self.iter = slice.iter_mut();
                self.index = bucket_index.first().into_raw().get();
            }
        }
    }
}
// `next` keeps returning `None` once the cursor passes the last bucket.
impl<T, const BUCKETS: usize> FusedIterator for IterMut<'_, T, BUCKETS> {}
impl<T, const BUCKETS: usize> IntoIterator for Buckets<T, BUCKETS> {
    type Item = (Index<BUCKETS>, T);
    type IntoIter = IntoIter<T, BUCKETS>;
    fn into_iter(self) -> Self::IntoIter {
        IntoIter {
            buckets: self,
            bucket: BucketCursor::default(),
            // Empty until `next` detaches the first allocated bucket.
            iter: Vec::new().into_iter(),
            // Placeholder; overwritten before the first item is yielded.
            index: 0,
        }
    }
}
/// Owning iterator. Buckets are detached (via `take_bucket`) one at a time
/// as iteration reaches them; any buckets not reached are freed by
/// `Buckets::drop` when the iterator is dropped.
#[must_use]
pub struct IntoIter<T, const BUCKETS: usize> {
    buckets: Buckets<T, BUCKETS>,
    // Next bucket to detach once `iter` is exhausted.
    bucket: BucketCursor<BUCKETS>,
    // Drains the current bucket's elements by value.
    iter: vec::IntoIter<T>,
    // Raw `Index` value of the item `iter` will yield next; only meaningful
    // while `iter` is non-empty.
    index: usize,
}
impl<T, const BUCKETS: usize> Iterator for IntoIter<T, BUCKETS> {
    type Item = (Index<BUCKETS>, T);
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if let Some(item) = self.iter.next() {
                // SAFETY: `self.index` was seeded from the bucket's first
                // raw index below and incremented once per yielded item.
                let index = unsafe { Index::from_raw_unchecked(self.index) };
                self.index = self.index.wrapping_add(1);
                return Some((index, item));
            }
            let bucket_index = self.bucket.advance()?;
            // Detach the bucket so its elements can be yielded by value;
            // skips buckets that were never allocated.
            if let Some(bucket) = self.buckets.take_bucket(bucket_index) {
                self.iter = Vec::from(bucket).into_iter();
                self.index = bucket_index.first().into_raw().get();
            }
        }
    }
}
// `next` keeps returning `None` once every bucket has been visited.
impl<T, const BUCKETS: usize> FusedIterator for IntoIter<T, BUCKETS> {}
#[cfg(test)]
mod tests {
    use super::buckets_for_index_bits;
    use super::Buckets;
    use super::Index;
    use super::MaybeZeroable;
    use crate::buckets::SKIPPED_BUCKETS;
    use crate::buckets::SKIPPED_ENTRIES;
    use alloc::vec::Vec;
    use core::cell::Cell;
    // Counts `Helper` drops so tests can observe how many elements a bucket
    // logically held.
    std::thread_local!(static COUNTER: Cell<usize> = const { Cell::new(0) });
    // Zero-sized element type with an observable `Drop`; exercises the ZST
    // path of `allocate_slice`/`take_bucket`.
    #[derive(Default)]
    struct Helper;
    // SAFETY: `Helper` is zero-sized, so "zeroed" memory is trivially a
    // valid value.
    unsafe impl MaybeZeroable for Helper {
        fn zeroable() -> bool {
            true
        }
    }
    impl Drop for Helper {
        fn drop(&mut self) {
            COUNTER.with(|c| c.set(c.get() + 1));
        }
    }
    // Runs `f` and reports how many `Helper`s were dropped during it.
    fn drops_with(f: impl FnOnce()) -> usize {
        COUNTER.with(|c| c.set(0));
        f();
        COUNTER.with(|c| c.get())
    }
    // Drops `buckets` and reports how many `Helper`s that released.
    fn drops<const BUCKETS: usize>(buckets: Buckets<Helper, BUCKETS>) -> usize {
        drops_with(|| drop(buckets))
    }
    #[test]
    fn new() {
        // A fresh collection owns no elements, regardless of bucket count.
        assert_eq!(drops(<Buckets<_, 1>>::new()), 0);
        assert_eq!(
            drops(<Buckets<_, { buckets_for_index_bits(usize::BITS) }>>::new()),
            0
        );
    }
    #[test]
    fn reserve() {
        // Reserving any index in the first bucket allocates exactly that
        // bucket (SKIPPED_ENTRIES + 1 = 32 entries).
        let buckets = <Buckets<Helper, 8>>::new();
        buckets.reserve(Index::new(0).unwrap());
        assert_eq!(drops(buckets), SKIPPED_ENTRIES + 1);
        let buckets = <Buckets<Helper, 8>>::new();
        buckets.reserve(Index::new(SKIPPED_ENTRIES).unwrap());
        assert_eq!(drops(buckets), SKIPPED_ENTRIES + 1);
        // The first index of the second bucket reserves buckets 0 and 1
        // (32 + 64 entries).
        let buckets = <Buckets<Helper, 8>>::new();
        buckets.reserve(Index::new(SKIPPED_ENTRIES + 1).unwrap());
        assert_eq!(drops(buckets), (SKIPPED_ENTRIES + 1) * 3);
        // Reserving the largest valid index allocates every bucket.
        let buckets = <Buckets<Helper, 5>>::new();
        let total = (1 << (5 + SKIPPED_BUCKETS)) - SKIPPED_ENTRIES - 1;
        assert_eq!(<Index<5>>::new(total), None);
        buckets.reserve(Index::new(total - 1).unwrap());
        assert_eq!(drops(buckets), total);
    }
    #[test]
    fn truncate_exact() {
        let mut buckets = <Buckets<Helper, 5>>::new();
        let first_in_second_bucket = Index::new(SKIPPED_ENTRIES + 1).unwrap();
        buckets.reserve(first_in_second_bucket);
        // Truncating into the *middle* of bucket 1 keeps that bucket intact.
        assert_eq!(
            drops_with(|| buckets.truncate(Index::new(SKIPPED_ENTRIES + 2).unwrap())),
            0,
        );
        // Truncating at bucket 1's first entry frees it (64 entries).
        assert_eq!(
            drops_with(|| buckets.truncate(first_in_second_bucket)),
            (SKIPPED_ENTRIES + 1) * 2
        );
        assert!(buckets.get(first_in_second_bucket).is_none());
        // Bucket 0 (32 entries) is released on drop.
        assert_eq!(drops(buckets), SKIPPED_ENTRIES + 1);
    }
    #[test]
    fn get_or_alloc() {
        let buckets = <Buckets<Helper, 8>>::new();
        // Only the bucket containing index 48 (bucket 1, 64 entries) is
        // allocated; lower buckets stay empty.
        buckets.get_or_alloc(Index::new(48).unwrap());
        assert_eq!(drops(buckets), 64);
    }
    #[test]
    fn iter() {
        let mut buckets = <Buckets<u16, 4>>::new();
        // Allocates bucket 1 (index 48) and bucket 3 (index 225):
        // 64 + 256 = 320 iterable entries, initially all zero.
        buckets.get_or_alloc(Index::new(48).unwrap());
        buckets.get_or_alloc(Index::new(225).unwrap());
        let indices = buckets.iter().map(|(i, _)| i).collect::<Vec<_>>();
        assert_eq!(indices.len(), 320);
        for &i in &indices {
            assert_eq!(*buckets.get(i).unwrap(), 0);
        }
        // `iter_mut` visits the same indices and can write through.
        for (i, (index, val)) in buckets.iter_mut().enumerate() {
            assert_eq!(indices[i], index);
            *val = i as u16;
        }
        // `into_iter` yields the written values by value, in the same order.
        for (i, (index, val)) in buckets.into_iter().enumerate() {
            assert_eq!(indices[i], index);
            assert_eq!(val, i as u16);
        }
    }
    #[test]
    fn location() {
        // First bucket: 32 entries starting at logical index 0.
        let index = <Index<12>>::new(0).unwrap();
        assert_eq!(index.location().bucket.0, 0);
        assert_eq!(index.location().bucket_len.get(), 32);
        assert_eq!(index.location().entry, 0);
        let index = <Index<12>>::new(31).unwrap();
        assert_eq!(index.location().bucket.0, 0);
        assert_eq!(index.location().bucket_len.get(), 32);
        assert_eq!(index.location().entry, 31);
        // Second bucket starts at logical index 32.
        let index = <Index<12>>::new(34).unwrap();
        assert_eq!(index.location().bucket.0, 1);
        assert_eq!(index.location().bucket_len.get(), 64);
        assert_eq!(index.location().entry, 2);
        // With the full usize range addressable, the largest representable
        // index sits at the very end of the last bucket.
        let max = usize::MAX - SKIPPED_ENTRIES - 1;
        assert_eq!(
            <Index<{ buckets_for_index_bits(usize::BITS) }>>::new(max + 1),
            None
        );
        let index = <Index<{ buckets_for_index_bits(usize::BITS) }>>::new(max).unwrap();
        assert_eq!(
            index.location().bucket.0,
            (usize::BITS as usize) - SKIPPED_BUCKETS - 1
        );
        assert_eq!(index.location().bucket_len.get(), usize::MAX / 2 + 1);
        assert_eq!(index.location().entry, usize::MAX / 2);
    }
}