#![deny(missing_docs)]
mod lazy_atomic_cell;
cfg_if::cfg_if! {
if #[cfg(unix)] {
mod pthread_mutex;
use pthread_mutex::PthreadMutex as Mutex;
} else if #[cfg(windows)] {
mod windows_mutex;
use windows_mutex::WindowsMutex as Mutex;
} else {
compile_error!("no mutex implementation for this platform");
}
}
#[doc(hidden)]
pub use lazy_atomic_cell::LazyAtomicCell;
use mem::MaybeUninit;
use rand::{rngs::StdRng, Rng, SeedableRng};
use std::{
alloc::{handle_alloc_error, GlobalAlloc, Layout},
mem, ptr,
sync::atomic::{AtomicPtr, Ordering},
};
/// Number of pointer slots in each `ShufflingArray`.
const SHUFFLING_ARRAY_SIZE: usize = 256;
/// A fixed-size pool of same-sized heap blocks.
///
/// Allocation and deallocation swap pointers in and out of random slots, so
/// the address handed back for a request is decoupled from the order in which
/// requests were made.
struct ShufflingArray<A>
where
    A: 'static + GlobalAlloc,
{
    // One heap block per slot; filled with non-null pointers by `new`.
    elems: [AtomicPtr<u8>; SHUFFLING_ARRAY_SIZE],
    // Size in bytes of every block held in `elems`.
    size_class: usize,
    // Underlying allocator that all of the blocks came from.
    allocator: &'static A,
}
impl<A> Drop for ShufflingArray<A>
where
    A: 'static + GlobalAlloc,
{
    /// Returns every block still held in the array to the underlying
    /// allocator.
    fn drop(&mut self) {
        // Every slot was allocated with this same layout: `size_class` bytes,
        // word-aligned.
        let layout =
            unsafe { Layout::from_size_align_unchecked(self.size_class, mem::align_of::<usize>()) };
        for slot in self.elems.iter() {
            let stale = slot.swap(ptr::null_mut(), Ordering::SeqCst);
            if stale.is_null() {
                continue;
            }
            // SAFETY: non-null slots hold live blocks allocated from
            // `self.allocator` with exactly this layout.
            unsafe { self.allocator.dealloc(stale, layout) };
        }
    }
}
impl<A> ShufflingArray<A>
where
    A: 'static + GlobalAlloc,
{
    /// Creates a new array whose `SHUFFLING_ARRAY_SIZE` slots are each filled
    /// with a fresh `size_class`-byte allocation from `allocator`.
    ///
    /// Calls `handle_alloc_error` (and therefore does not return) if any of
    /// the initial allocations fail.
    fn new(size_class: usize, allocator: &'static A) -> Self {
        let elems = unsafe {
            // Initialize the array in place, element by element; `AtomicPtr`
            // is not `Copy`, so `[AtomicPtr::new(..); N]` is unavailable.
            // NOTE: was hard-coded `256` in three places; use the constant so
            // the array length and loop bound can never drift apart.
            let mut elems = MaybeUninit::<[AtomicPtr<u8>; SHUFFLING_ARRAY_SIZE]>::uninit();
            let elems_ptr: *mut AtomicPtr<u8> = elems.as_mut_ptr().cast();
            // All slots share one layout: `size_class` bytes, word-aligned.
            let layout = Layout::from_size_align_unchecked(size_class, mem::align_of::<usize>());
            for i in 0..SHUFFLING_ARRAY_SIZE {
                let p = allocator.alloc(layout);
                if p.is_null() {
                    handle_alloc_error(layout);
                }
                // SAFETY: `i < SHUFFLING_ARRAY_SIZE`, so the write stays
                // inside the uninitialized array.
                ptr::write(elems_ptr.add(i), AtomicPtr::new(p));
            }
            // SAFETY: every element was written by the loop above.
            elems.assume_init()
        };
        ShufflingArray {
            elems,
            size_class,
            allocator,
        }
    }

    /// Returns the layout used for every block held by this array.
    fn elem_layout(&self) -> Layout {
        unsafe {
            debug_assert!(
                Layout::from_size_align(self.size_class, mem::align_of::<usize>()).is_ok()
            );
            Layout::from_size_align_unchecked(self.size_class, mem::align_of::<usize>())
        }
    }
}
/// One lazily-initialized `ShufflingArray` per size class.
struct SizeClasses<A>([LazyAtomicCell<A, ShufflingArray<A>>; NUM_SIZE_CLASSES])
where
    A: 'static + GlobalAlloc;
/// Result of a size-class lookup.
struct SizeClassInfo {
    // Index of this class within `SizeClasses`.
    index: usize,
    // Rounded-up allocation size, in bytes, for this class.
    size_class: usize,
}
/// Maps an allocation `size` to its size class, or `None` when `size` is
/// larger than the biggest class.
///
/// Classes start at one machine word and grow by a stride that doubles after
/// every fourth class (in words: 1, 2, 3, 4, 5, 7, 9, 11, 13, 17, ...), the
/// same thresholds the previous hand-unrolled version produced.
#[inline]
fn size_class_info(size: usize) -> Option<SizeClassInfo> {
    let mut size_class = mem::size_of::<usize>();
    let mut stride = mem::size_of::<usize>();
    for index in 0..NUM_SIZE_CLASSES {
        if size <= size_class {
            return Some(SizeClassInfo { index, size_class });
        }
        size_class += stride;
        // The stride doubles after indices 3, 7, 11, ...
        if index % 4 == 3 {
            stride *= 2;
        }
    }
    None
}
/// Total number of size classes handled by `size_class_info`.
const NUM_SIZE_CLASSES: usize = 32;
/// A global-allocator adapter that randomizes the addresses returned by an
/// underlying allocator by round-tripping them through per-size-class
/// `ShufflingArray`s.
///
/// Construct values of this type with the `wrap!` macro. (Also required under
/// the crate's `#![deny(missing_docs)]`.)
pub struct ShufflingAllocator<A>
where
    A: 'static + GlobalAlloc,
{
    /// The wrapped allocator that actually provides the memory.
    #[doc(hidden)]
    pub inner: &'static A,
    /// Lazily-created shuffling state (RNG plus size-class arrays).
    #[doc(hidden)]
    pub state: LazyAtomicCell<A, State<A>>,
}
/// Internal state for a `ShufflingAllocator`: the RNG used to pick random
/// slots and the lazily-created per-size-class shuffling arrays.
#[doc(hidden)]
pub struct State<A>
where
    A: 'static + GlobalAlloc,
{
    // Mutex-protected RNG; allocation can happen concurrently from many
    // threads.
    rng: Mutex<A, StdRng>,
    // Created on first use by `ShufflingAllocator::size_classes`.
    size_classes: LazyAtomicCell<A, SizeClasses<A>>,
}
/// Expands to a `ShufflingAllocator` wrapping the given `&'static` underlying
/// allocator, with its lazy internal state initially unset (null).
///
/// The expansion is a plain struct literal, so it is suitable for `static`
/// initializers.
#[macro_export]
macro_rules! wrap {
    ($inner:expr) => {
        $crate::ShufflingAllocator {
            inner: $inner,
            state: $crate::LazyAtomicCell {
                ptr: ::std::sync::atomic::AtomicPtr::new(::std::ptr::null_mut()),
                allocator: $inner,
            },
        }
    };
}
impl<A> ShufflingAllocator<A>
where
    A: 'static + GlobalAlloc,
{
    /// Returns the shared state, creating it on first use.
    #[inline]
    fn state(&self) -> &State<A> {
        self.state.get_or_create(|| State {
            rng: Mutex::new(&self.inner, StdRng::from_entropy()),
            size_classes: LazyAtomicCell::new(self.inner),
        })
    }

    /// Picks a uniformly random slot index into a shuffling array.
    #[inline]
    fn random_index(&self) -> usize {
        let mut rng = self.state().rng.lock();
        rng.gen_range(0..SHUFFLING_ARRAY_SIZE)
    }

    /// Returns the per-size-class table, creating it on first use.
    #[inline]
    fn size_classes(&self) -> &SizeClasses<A> {
        self.state().size_classes.get_or_create(|| {
            // Build the array of cells in place, element by element;
            // `LazyAtomicCell` is not `Copy`, so `[LazyAtomicCell::new(..); N]`
            // is unavailable.
            let mut classes =
                MaybeUninit::<[LazyAtomicCell<A, ShufflingArray<A>>; NUM_SIZE_CLASSES]>::uninit();
            unsafe {
                for i in 0..NUM_SIZE_CLASSES {
                    // SAFETY: `i < NUM_SIZE_CLASSES`, and every element is
                    // written before `assume_init` below.
                    ptr::write(
                        classes
                            .as_mut_ptr()
                            .cast::<LazyAtomicCell<A, ShufflingArray<A>>>()
                            .offset(i as _),
                        LazyAtomicCell::new(self.inner),
                    );
                }
                SizeClasses(classes.assume_init())
            }
        })
    }

    /// Returns the shuffling array for `size`, creating it on first use, or
    /// `None` when `size` has no size class (too large).
    #[inline]
    fn shuffling_array(&self, size: usize) -> Option<&ShufflingArray<A>> {
        let SizeClassInfo { index, size_class } = size_class_info(size)?;
        let size_classes = self.size_classes();
        Some(size_classes.0[index].get_or_create(|| ShufflingArray::new(size_class, self.inner)))
    }
}
unsafe impl<A> GlobalAlloc for ShufflingAllocator<A>
where
    A: GlobalAlloc,
{
    #[inline]
    unsafe fn alloc(&self, layout: std::alloc::Layout) -> *mut u8 {
        // Over-aligned requests cannot share the word-aligned size-class
        // slots; pass them straight through to the wrapped allocator.
        if layout.align() > mem::align_of::<usize>() {
            return self.inner.alloc(layout);
        }
        match self.shuffling_array(layout.size()) {
            // No size class for this size (too large): pass through.
            None => self.inner.alloc(layout),
            Some(array) => {
                // Allocate a fresh block, swap it into a random slot, and hand
                // out whatever pointer that slot previously held. The block
                // returned is at least `layout.size()` bytes because size
                // classes round sizes up.
                let replacement_ptr = self.inner.alloc(array.elem_layout());
                if replacement_ptr.is_null() {
                    return ptr::null_mut();
                }
                let index = self.random_index();
                array.elems[index].swap(replacement_ptr, Ordering::SeqCst)
            }
        }
    }

    #[inline]
    unsafe fn dealloc(&self, ptr: *mut u8, layout: std::alloc::Layout) {
        if ptr.is_null() {
            return;
        }
        // Mirror `alloc`: over-aligned blocks were passed straight through.
        if layout.align() > mem::align_of::<usize>() {
            self.inner.dealloc(ptr, layout);
            return;
        }
        match self.shuffling_array(layout.size()) {
            None => self.inner.dealloc(ptr, layout),
            Some(array) => {
                // Stash the freed block in a random slot and release whatever
                // pointer it displaces. Slots are filled non-null by
                // `ShufflingArray::new` and only ever swapped with non-null
                // pointers here and in `alloc`, so `old_ptr` is always a live
                // block of this size class.
                let index = self.random_index();
                let old_ptr = array.elems[index].swap(ptr, Ordering::SeqCst);
                self.inner.dealloc(old_ptr, array.elem_layout());
            }
        }
    }
}