use core::{
alloc::Layout,
cell::Cell,
mem::{align_of, size_of},
ptr::{self, NonNull},
sync::atomic::{AtomicPtr, Ordering},
};
use allocator_api2::{AllocError, Allocator};
#[cfg(feature = "sync")]
use parking_lot::RwLock;
/// Reports whether `value` is a multiple of `align`.
///
/// `align` must be a power of two; this is verified in debug builds only.
#[inline(always)]
fn is_aligned_to(value: usize, align: usize) -> bool {
    debug_assert!(align.is_power_of_two());
    // For a power-of-two `align`, the bits below it are exactly the remainder.
    let remainder = value & (align - 1);
    remainder == 0
}
#[inline(always)]
fn align_up(value: usize, align: usize) -> Option<usize> {
debug_assert!(align.is_power_of_two());
let mask = align - 1;
Some(value.checked_add(mask)? & !mask)
}
/// Rounds `value` down to the nearest multiple of `align`.
///
/// `align` must be a power of two; this is verified in debug builds only.
#[inline(always)]
fn align_down(value: usize, align: usize) -> usize {
    debug_assert!(align.is_power_of_two());
    // Subtracting the remainder is the same as clearing the low bits.
    let remainder = value & (align - 1);
    value - remainder
}
/// Returns the layout's size plus the largest padding its alignment can require
/// (`align - 1` bytes), i.e. the space needed at an arbitrarily aligned address.
#[inline(always)]
fn layout_sum(layout: &Layout) -> usize {
    let worst_case_padding = layout.align() - 1;
    layout.size() + worst_case_padding
}
/// Pointer cell that can be updated through a compare-and-exchange operation.
///
/// Implemented in this file for `Cell<*mut u8>` (no synchronization, orderings are
/// ignored) and for `AtomicPtr<u8>` (forwards to the atomic operations).
pub trait CasPtr {
/// Creates a cell holding `value`.
fn new(value: *mut u8) -> Self;
/// Reads the current pointer using memory ordering `order`.
fn load(&self, order: Ordering) -> *mut u8;
/// Overwrites the pointer through exclusive access.
fn set(&mut self, value: *mut u8);
/// Replaces the pointer with `new` if it currently equals `old`.
///
/// Returns `Ok(())` when the value was replaced, otherwise `Err` carrying the value
/// that was found instead.
fn compare_exchange(
&self,
old: *mut u8,
new: *mut u8,
success: Ordering,
failure: Ordering,
) -> Result<(), *mut u8>;
/// Variant of `compare_exchange` meant to be called from a retry loop.
///
/// The `AtomicPtr` implementation forwards to `AtomicPtr::compare_exchange_weak`,
/// which is allowed to fail spuriously; the `Cell` implementation debug-asserts that
/// `old` is the current value and always stores `new`.
fn compare_exchange_weak(
&self,
old: *mut u8,
new: *mut u8,
success: Ordering,
failure: Ordering,
) -> Result<(), *mut u8>;
}
/// Unsynchronized [`CasPtr`] backed by a plain `Cell`; every `Ordering` argument is ignored.
impl CasPtr for Cell<*mut u8> {
    /// Wraps `value` in a new cell.
    #[inline(always)]
    fn new(value: *mut u8) -> Self {
        Cell::new(value)
    }

    /// Returns the stored pointer.
    #[inline(always)]
    fn load(&self, _order: Ordering) -> *mut u8 {
        Cell::get(self)
    }

    /// Stores `value` through exclusive access.
    #[inline(always)]
    fn set(&mut self, value: *mut u8) {
        let slot = Cell::get_mut(self);
        *slot = value;
    }

    /// Stores `new` when the cell currently holds `old`; otherwise reports the current value.
    #[inline(always)]
    fn compare_exchange(
        &self,
        old: *mut u8,
        new: *mut u8,
        _success: Ordering,
        _failure: Ordering,
    ) -> Result<(), *mut u8> {
        let current = Cell::get(self);
        if current != old {
            return Err(current);
        }
        Cell::set(self, new);
        Ok(())
    }

    /// Stores `new` unconditionally.
    ///
    /// The caller is expected to pass the value it just loaded as `old`; this is only
    /// checked in debug builds.
    #[inline(always)]
    fn compare_exchange_weak(
        &self,
        old: *mut u8,
        new: *mut u8,
        _success: Ordering,
        _failure: Ordering,
    ) -> Result<(), *mut u8> {
        debug_assert_eq!(
            old,
            Cell::get(self),
            "Must be used only in loop where `old` is last loaded value"
        );
        Cell::set(self, new);
        Ok(())
    }
}
/// Thread-safe [`CasPtr`] that forwards every operation to the `AtomicPtr` inherent methods.
impl CasPtr for AtomicPtr<u8> {
    /// Wraps `value` in a new atomic pointer.
    #[inline(always)]
    fn new(value: *mut u8) -> Self {
        AtomicPtr::new(value)
    }

    /// Atomically loads the pointer with ordering `order`.
    #[inline(always)]
    fn load(&self, order: Ordering) -> *mut u8 {
        AtomicPtr::load(self, order)
    }

    /// Stores `value` through exclusive access, without an atomic operation.
    #[inline(always)]
    fn set(&mut self, value: *mut u8) {
        let slot = AtomicPtr::get_mut(self);
        *slot = value;
    }

    /// Strong compare-exchange; the previous value returned on success is discarded.
    #[inline(always)]
    fn compare_exchange(
        &self,
        old: *mut u8,
        new: *mut u8,
        success: Ordering,
        failure: Ordering,
    ) -> Result<(), *mut u8> {
        AtomicPtr::compare_exchange(self, old, new, success, failure).map(|_previous| ())
    }

    /// Weak compare-exchange; the previous value returned on success is discarded.
    #[inline(always)]
    fn compare_exchange_weak(
        &self,
        old: *mut u8,
        new: *mut u8,
        success: Ordering,
        failure: Ordering,
    ) -> Result<(), *mut u8> {
        AtomicPtr::compare_exchange_weak(self, old, new, success, failure).map(|_previous| ())
    }
}
/// Uniform read/write access to a slot holding a `T`.
///
/// Implemented in this file for `&mut T` and `&Cell<T>`, which lets the same code
/// update a value through either kind of reference.
trait Mut<T> {
/// Returns a copy of the stored value.
fn get(&self) -> T;
/// Replaces the stored value with `value`.
fn set(&mut self, value: T);
}
/// Slot access through an exclusive reference.
impl<T: Copy> Mut<T> for &mut T {
    /// Copies the referenced value out.
    #[inline(always)]
    fn get(&self) -> T {
        let slot: &T = self;
        *slot
    }

    /// Overwrites the referenced value.
    #[inline(always)]
    fn set(&mut self, value: T) {
        let slot: &mut T = self;
        *slot = value;
    }
}
/// Slot access through a shared reference to a `Cell`.
impl<T: Copy> Mut<T> for &Cell<T> {
    /// Copies the value out of the cell.
    #[inline(always)]
    fn get(&self) -> T {
        let cell: &Cell<T> = *self;
        Cell::get(cell)
    }

    /// Stores `value` into the cell.
    #[inline(always)]
    fn set(&mut self, value: T) {
        let cell: &Cell<T> = *self;
        Cell::set(cell, value);
    }
}
/// Lower bound, in bytes, that `alloc_slow` applies to the size of every new chunk
/// (combined there with the size of 16 chunk headers, whichever is larger).
const CHUNK_START_SIZE: usize = 256;
/// Chunk sizes below this value are rounded up to the next power of two by `alloc_slow`.
const CHUNK_POWER_OF_TWO_THRESHOLD: usize = 1 << 14;
/// Chunk sizes at or above `CHUNK_POWER_OF_TWO_THRESHOLD` are rounded up to a multiple
/// of this value by `alloc_slow`.
const CHUNK_PAGE_SIZE_THRESHOLD: usize = 1 << 12;
/// Added to the current chunk capacity to form the minimum suggested size of the next
/// chunk when the address computation for an allocation overflows.
const CHUNK_MIN_GROW_STEP: usize = 64;
/// Header written at the start of every chunk; the usable bytes begin right after it.
///
/// `T` is the cursor cell type (see the `CasPtr` implementations in this file).
#[repr(C)]
pub struct ChunkHeader<T> {
/// Bump cursor; initialised to the first byte after the header.
cursor: T,
/// One-past-the-end address of the memory returned by the allocator for this chunk.
end: *mut u8,
/// Chunk that was current before this one was allocated, if any.
prev: Option<NonNull<Self>>,
/// Sum of the capacities of the chunks reachable through `prev`
/// (computed once when the chunk is initialised, cleared by `reset`).
cumulative_size: usize,
}
impl<T> ChunkHeader<T>
where
    T: CasPtr,
{
    /// Allocates a new chunk from `allocator` and links it in front of `prev`.
    ///
    /// `size` is the total chunk size in bytes, header included; it is rounded up to the
    /// header alignment. Returns `AllocError` when the size cannot form a valid
    /// `Layout` or when the allocator fails.
    ///
    /// # Safety
    ///
    /// `size` must be larger than `size_of::<Self>()`, and `prev`, if any, must point to
    /// a live, initialised chunk header.
    #[inline]
    unsafe fn alloc_chunk(
        size: usize,
        allocator: &impl Allocator,
        prev: Option<NonNull<Self>>,
    ) -> Result<NonNull<Self>, AllocError> {
        let Some(size) = align_up(size, align_of::<Self>()) else {
            return Err(AllocError);
        };
        // Use the checked constructor: a rounded size above `isize::MAX` must be reported
        // as an allocation failure rather than building an invalid layout.
        let Ok(layout) = Layout::from_size_align(size, align_of::<Self>()) else {
            return Err(AllocError);
        };
        let slice = allocator.allocate(layout)?;
        // SAFETY: `slice` was just allocated with the header's alignment and the caller
        // guarantees it is larger than the header.
        Ok(unsafe { Self::init_chunk(slice, prev) })
    }

    /// Returns the chunk's memory to `allocator` and yields the previous chunk in the list.
    ///
    /// # Safety
    ///
    /// `chunk` must have been produced by `alloc_chunk` with the same allocator and must
    /// not be used afterwards.
    #[inline]
    unsafe fn dealloc_chunk(
        chunk: NonNull<Self>,
        allocator: &impl Allocator,
    ) -> Option<NonNull<Self>> {
        // SAFETY: the caller guarantees `chunk` points to a live header.
        let me = unsafe { chunk.as_ref() };
        let prev = me.prev;
        // `end` was derived from the length of the slice the allocator returned, so this
        // is the actual size of the memory block.
        let size = unsafe { me.end.offset_from(chunk.as_ptr().cast()) } as usize;
        // SAFETY: size and alignment describe an allocation that already exists.
        let layout = unsafe { Layout::from_size_align_unchecked(size, align_of::<Self>()) };
        // SAFETY: the block was allocated by `allocator` and `layout` fits it.
        unsafe { allocator.deallocate(chunk.cast(), layout) };
        prev
    }

    /// Writes a header at the start of `slice` and returns a pointer to it.
    ///
    /// The cursor starts right after the header; `cumulative_size` is the capacity of
    /// `prev` plus everything `prev` already accounted for.
    ///
    /// # Safety
    ///
    /// `slice` must be writable, aligned for `Self` and longer than `size_of::<Self>()`;
    /// `prev`, if any, must point to a live header.
    #[inline]
    unsafe fn init_chunk(slice: NonNull<[u8]>, prev: Option<NonNull<Self>>) -> NonNull<Self> {
        let len = slice.len();
        let header_ptr = slice.as_ptr().cast::<u8>();
        debug_assert!(is_aligned_to(
            sptr::Strict::addr(header_ptr),
            align_of::<Self>()
        ));
        debug_assert!(len > size_of::<Self>());
        let end = header_ptr.add(len);
        let header_ptr = header_ptr.cast::<Self>();
        let base = header_ptr.add(1).cast::<u8>();
        let cumulative_size = match prev {
            None => 0,
            Some(prev) => {
                let prev = unsafe { prev.as_ref() };
                prev.cap() + prev.cumulative_size
            }
        };
        ptr::write(
            header_ptr,
            ChunkHeader {
                cursor: T::new(base),
                end,
                prev,
                cumulative_size,
            },
        );
        NonNull::new_unchecked(header_ptr)
    }

    /// Address of the first usable byte (directly after the header).
    #[inline(always)]
    fn base(&self) -> *const u8 {
        unsafe { <*const Self>::add(self, 1).cast() }
    }

    /// Mutable variant of [`Self::base`].
    #[inline(always)]
    fn base_mut(&mut self) -> *mut u8 {
        unsafe { <*mut Self>::add(self, 1).cast() }
    }

    /// Number of bytes between `ptr` and the end of the chunk.
    ///
    /// # Safety
    ///
    /// `ptr` must point into this chunk, at or before `end`.
    #[inline(always)]
    unsafe fn offset_from_end(&self, ptr: *const u8) -> usize {
        let offset = unsafe { self.end.offset_from(ptr) };
        offset as usize
    }

    /// Number of bytes between the first usable byte and `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must point into this chunk, at or after the base.
    #[inline(always)]
    unsafe fn offset_from_base(&self, ptr: *const u8) -> usize {
        let offset = unsafe { ptr.offset_from(self.base()) };
        offset as usize
    }

    /// Usable capacity of the chunk in bytes (header excluded).
    #[inline(always)]
    fn cap(&self) -> usize {
        unsafe { self.offset_from_end(self.base()) }
    }

    /// Performs one attempt to carve `layout` out of the chunk, starting from `cursor`.
    ///
    /// `exchange` is called with the proposed new cursor and must publish it.
    ///
    /// Returns:
    /// * `Ok(Ok(slice))` - the allocation succeeded; `slice` starts at an address aligned
    ///   to `layout.align()` and is `layout.size()` bytes long;
    /// * `Ok(Err(updated))` - `exchange` failed and reported `updated` as the current cursor;
    /// * `Err(Some(size))` - the chunk is too small; `size` is the suggested payload size
    ///   for a new chunk;
    /// * `Err(None)` - a size computation overflowed.
    #[inline(always)]
    fn alloc_round(
        &self,
        cursor: *mut u8,
        layout: Layout,
        exchange: impl FnOnce(*mut u8) -> Result<(), *mut u8>,
    ) -> Result<Result<NonNull<[u8]>, *mut u8>, Option<usize>> {
        let cursor_addr = sptr::Strict::addr(cursor);
        let layout_sum = layout_sum(&layout);
        let Some(unaligned) = cursor_addr.checked_add(layout_sum) else {
            // The address computation overflowed: ask for a chunk that can hold
            // everything used so far plus this request, growing by at least the minimal step.
            let used = unsafe { self.offset_from_base(cursor) };
            let Some(total_used) = used.checked_add(self.cumulative_size) else {
                return Err(None);
            };
            let Some(next_size) = layout_sum.checked_add(total_used) else {
                return Err(None);
            };
            let Some(min_grow) = self.cap().checked_add(CHUNK_MIN_GROW_STEP) else {
                return Err(None);
            };
            return Err(Some(next_size.max(min_grow)));
        };
        // `unaligned - size` is `cursor + align - 1`; rounding it down yields the first
        // aligned address at or after the cursor.
        let aligned_addr = align_down(unaligned - layout.size(), layout.align());
        debug_assert!(
            aligned_addr >= cursor_addr,
            "aligned_addr addr must not be less than cursor"
        );
        debug_assert!(
            (aligned_addr - cursor_addr) < layout.align(),
            "Cannot waste space more than alignment size"
        );
        let next_addr = aligned_addr + layout.size();
        let end_addr = sptr::Strict::addr(self.end);
        if next_addr > end_addr {
            let Some(overused) = (next_addr - end_addr).checked_add(self.cap()) else {
                return Err(None);
            };
            let Some(required) = overused.checked_add(self.cumulative_size) else {
                return Err(None);
            };
            return Err(Some(required));
        }
        // Both offsets stay within the chunk because `next_addr <= end_addr`.
        let aligned = unsafe { cursor.add(aligned_addr - cursor_addr) };
        let next = unsafe { aligned.add(layout.size()) };
        if let Err(updated) = exchange(next) {
            return Ok(Err(updated));
        };
        // The block handed out starts at the ALIGNED address; the padding between the old
        // cursor and `aligned` is not part of it. Returning `cursor` here would hand out
        // a pointer that may violate `layout.align()`.
        let len = next_addr - aligned_addr;
        debug_assert!(len >= layout.size());
        unsafe {
            let slice = core::ptr::slice_from_raw_parts_mut(aligned, len);
            Ok(Ok(NonNull::new_unchecked(slice)))
        }
    }

    /// Allocates `layout` from the chunk, retrying while the cursor exchange fails.
    ///
    /// When `ZEROED` is set the returned block is filled with zeroes. On failure the
    /// error has the same meaning as for [`Self::alloc_round`].
    ///
    /// # Safety
    ///
    /// `chunk` must point to a live, initialised chunk header.
    #[inline(always)]
    unsafe fn alloc<const ZEROED: bool>(
        chunk: NonNull<Self>,
        layout: Layout,
    ) -> Result<NonNull<[u8]>, Option<usize>> {
        let me = unsafe { chunk.as_ref() };
        let mut cursor = me.cursor.load(Ordering::Relaxed);
        loop {
            let result = me.alloc_round(cursor, layout, |aligned| {
                me.cursor.compare_exchange_weak(
                    cursor,
                    aligned,
                    // The memory may be touched only after the exchange succeeded.
                    Ordering::Acquire,
                    Ordering::Relaxed,
                )
            });
            match result {
                Ok(Ok(slice)) => {
                    if ZEROED {
                        unsafe { ptr::write_bytes(slice.as_ptr().cast::<u8>(), 0, slice.len()) }
                    }
                    return Ok(slice);
                }
                // Lost the race (or failed spuriously): retry from the observed cursor.
                Ok(Err(updated)) => cursor = updated,
                Err(next_size) => return Err(next_size),
            }
        }
    }

    /// Resizes the block `ptr` (currently `old_size` bytes) to `new_layout`.
    ///
    /// Tries, in order: shrinking in place, growing in place when the block is the last
    /// one in the chunk, and finally allocating a new block in the same chunk and copying.
    /// With `ZEROED`, bytes beyond `old_size` in the result are zero-filled. Errors have
    /// the same meaning as for [`Self::alloc_round`].
    ///
    /// # Safety
    ///
    /// `chunk` must point to a live header and `ptr..ptr + old_size` must be a block
    /// previously handed out from that chunk.
    #[inline]
    unsafe fn resize<const ZEROED: bool>(
        chunk: NonNull<Self>,
        ptr: NonNull<u8>,
        old_size: usize,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, Option<usize>> {
        let me = unsafe { chunk.as_ref() };
        let old_end = unsafe { ptr.as_ptr().add(old_size) };
        let addr = sptr::Strict::addr(ptr.as_ptr());

        // Shrink in place when the (possibly re-aligned) block still fits in the old one.
        if new_layout.size() <= old_size {
            if let Some(aligned_addr) = align_up(addr, new_layout.align()) {
                let max_shift = old_size - new_layout.size();
                if addr + max_shift >= aligned_addr {
                    let aligned = ptr.as_ptr().add(aligned_addr - addr);
                    memmove(ptr.as_ptr(), aligned, new_layout.size());
                    // Give the tail back if this block is the last one; failure is harmless.
                    let _ = me.cursor.compare_exchange(
                        old_end,
                        aligned.add(new_layout.size()),
                        Ordering::Release,
                        Ordering::Relaxed,
                    );
                    let slice = core::ptr::slice_from_raw_parts_mut(aligned, new_layout.size());
                    return Ok(NonNull::new_unchecked(slice));
                }
            }
        }

        // Grow in place when the block ends exactly at the cursor.
        if me.cursor.load(Ordering::Relaxed) == old_end {
            let Some(unaligned) = addr.checked_add(layout_sum(&new_layout)) else {
                let used = unsafe { me.offset_from_base(ptr.as_ptr()) };
                let next_size = layout_sum(&new_layout).checked_add(used);
                let min_grow = me.cap().checked_add(CHUNK_MIN_GROW_STEP);
                // `Option::max` would silently ignore an overflowed term; report the
                // overflow instead, as `alloc_round` does.
                return Err(match (next_size, min_grow) {
                    (Some(next_size), Some(min_grow)) => Some(next_size.max(min_grow)),
                    _ => None,
                });
            };
            let aligned_addr = align_down(unaligned - new_layout.size(), new_layout.align());
            debug_assert!(
                aligned_addr >= addr,
                "aligned_addr addr must not be less than cursor"
            );
            debug_assert!(
                (aligned_addr - addr) < new_layout.align(),
                "Cannot waste space more than alignment size"
            );
            let next_addr = aligned_addr + new_layout.size();
            let end_addr = sptr::Strict::addr(me.end);
            if next_addr > end_addr {
                return Err((next_addr - end_addr).checked_add(me.cap()));
            }
            let aligned = unsafe { ptr.as_ptr().add(aligned_addr - addr) };
            let next = unsafe { aligned.add(new_layout.size()) };
            let result = me.cursor.compare_exchange(
                old_end,
                next,
                Ordering::Acquire,
                Ordering::Relaxed,
            );
            if let Ok(()) = result {
                memmove(ptr.as_ptr(), aligned, new_layout.size().min(old_size));
                if ZEROED && old_size < new_layout.size() {
                    // The data now lives at `aligned`, so the fresh tail starts at
                    // `aligned + old_size` (not `ptr + old_size`).
                    core::ptr::write_bytes(
                        aligned.add(old_size),
                        0,
                        new_layout.size() - old_size,
                    );
                }
                let slice = core::ptr::slice_from_raw_parts_mut(aligned, new_layout.size());
                return Ok(NonNull::new_unchecked(slice));
            }
        }

        // Fall back to a new block in this chunk; the old block cannot be reclaimed.
        let new_ptr = ChunkHeader::alloc::<false>(chunk, new_layout)?;
        core::ptr::copy_nonoverlapping(
            ptr.as_ptr(),
            new_ptr.as_ptr().cast(),
            new_layout.size().min(old_size),
        );
        if ZEROED && old_size < new_layout.size() {
            // Zero only the bytes after the copied prefix; zeroing from the start of the
            // block would wipe the data that was just copied.
            core::ptr::write_bytes(
                new_ptr.as_ptr().cast::<u8>().add(old_size),
                0,
                new_layout.size() - old_size,
            );
        }
        Ok(new_ptr)
    }

    /// Rewinds the chunk to empty, detaches it from the list and returns the former `prev`.
    ///
    /// # Safety
    ///
    /// `chunk` must point to a live header to which the caller has exclusive access.
    #[inline(always)]
    unsafe fn reset(mut chunk: NonNull<Self>) -> Option<NonNull<Self>> {
        let me = chunk.as_mut();
        let base = me.base_mut();
        me.cursor.set(base);
        me.cumulative_size = 0;
        me.prev.take()
    }

    /// Gives the block `ptr..ptr + size` back to the chunk if it is the most recent one.
    ///
    /// A single attempt is made to move the cursor from the end of the block back to its
    /// start; when anything was allocated after the block the attempt fails and the
    /// memory stays in use until the chunk is reset.
    ///
    /// # Safety
    ///
    /// `chunk` must point to a live header and `ptr..ptr + size` must be a block handed
    /// out from that chunk.
    #[inline(always)]
    unsafe fn dealloc(chunk: NonNull<Self>, ptr: NonNull<u8>, size: usize) {
        let me = unsafe { chunk.as_ref() };
        let block_end = unsafe { ptr.as_ptr().add(size) };
        // Expected value is the END of the block, replacement is its START: this rolls
        // the cursor back. (The reverse order would never release memory.)
        let _ = me.cursor.compare_exchange(
            block_end,
            ptr.as_ptr(),
            Ordering::Release,
            Ordering::Relaxed,
        );
    }
}
/// Tries to allocate `layout` from the current chunk without touching the backing allocator.
///
/// On failure the error carries the payload size to request for a new chunk, or `None`
/// when that size cannot be computed. With no chunk yet, the size is the worst-case
/// footprint of `layout` plus one chunk header.
///
/// # Safety
///
/// `root`, if present, must point to a live, initialised chunk header.
#[inline(always)]
unsafe fn alloc_fast<T, const ZEROED: bool>(
    root: Option<NonNull<ChunkHeader<T>>>,
    layout: Layout,
) -> Result<NonNull<[u8]>, Option<usize>>
where
    T: CasPtr,
{
    let Some(chunk) = root else {
        // `checked_add` already yields `None` on overflow and `Some(size)` otherwise.
        return Err(layout_sum(&layout).checked_add(size_of::<ChunkHeader<T>>()));
    };
    unsafe { ChunkHeader::alloc::<ZEROED>(chunk, layout) }
}
#[cold]
unsafe fn alloc_slow<T, A, const ZEROED: bool>(
mut root: impl Mut<Option<NonNull<ChunkHeader<T>>>>,
chunk_size: usize,
layout: Layout,
allocator: &A,
) -> Result<NonNull<[u8]>, AllocError>
where
T: CasPtr,
A: Allocator,
{
let Some(mut chunk_size) = chunk_size.checked_add(size_of::<ChunkHeader<T>>()) else {
return Err(AllocError);
};
if chunk_size < CHUNK_POWER_OF_TWO_THRESHOLD {
chunk_size = chunk_size.next_power_of_two();
} else {
chunk_size = align_up(chunk_size, CHUNK_PAGE_SIZE_THRESHOLD).unwrap_or(chunk_size);
}
let min_chunk_size = CHUNK_START_SIZE.max(size_of::<[ChunkHeader<T>; 16]>());
chunk_size = chunk_size.max(min_chunk_size);
let new_chunk = ChunkHeader::alloc_chunk(chunk_size, allocator, root.get())?;
let res = unsafe { ChunkHeader::alloc::<ZEROED>(new_chunk, layout) };
let Ok(ptr) = res else {
unsafe { unreachable_unchecked() }
};
root.set(Some(new_chunk));
Ok(ptr)
}
/// Tries to resize the block `ptr` within the current chunk.
///
/// Errors have the same meaning as for `alloc_fast`.
///
/// # Safety
///
/// `root` must be `Some` and point to a live chunk header (a block can only be resized
/// after something was allocated); `ptr..ptr + old_size` must be a block handed out by
/// this arena.
#[inline(always)]
unsafe fn resize_fast<T, const ZEROED: bool>(
    root: Option<NonNull<ChunkHeader<T>>>,
    ptr: NonNull<u8>,
    old_size: usize,
    new_layout: Layout,
) -> Result<NonNull<[u8]>, Option<usize>>
where
    T: CasPtr,
{
    let Some(chunk) = root else {
        // Resizing requires an earlier allocation, which implies a chunk exists.
        unsafe { unreachable_unchecked() }
    };
    unsafe { ChunkHeader::resize::<ZEROED>(chunk, ptr, old_size, new_layout) }
}
/// Resizes a block by allocating a new chunk and moving the data into it.
///
/// The first `min(old_size, new_layout.size())` bytes are copied into the new block.
/// With `ZEROED`, the bytes of the NEW block beyond `old_size` are zero-filled.
///
/// # Safety
///
/// Same requirements as `alloc_slow`; additionally `ptr` must be valid for reads of
/// `old_size` bytes.
#[inline(always)]
unsafe fn resize_slow<T, A, const ZEROED: bool>(
    root: impl Mut<Option<NonNull<ChunkHeader<T>>>>,
    chunk_size: usize,
    ptr: NonNull<u8>,
    old_size: usize,
    new_layout: Layout,
    allocator: &A,
) -> Result<NonNull<[u8]>, AllocError>
where
    T: CasPtr,
    A: Allocator,
{
    let new_ptr = alloc_slow::<_, _, false>(root, chunk_size, new_layout, allocator)?;
    let new_start = new_ptr.as_ptr().cast::<u8>();
    core::ptr::copy_nonoverlapping(
        ptr.as_ptr(),
        new_start,
        new_layout.size().min(old_size),
    );
    if ZEROED && old_size < new_layout.size() {
        // Zero the tail of the new block. Writing at `ptr + old_size` would run past the
        // end of the old block and leave the new tail uninitialised.
        core::ptr::write_bytes(new_start.add(old_size), 0, new_layout.size() - old_size);
    }
    Ok(new_ptr)
}
/// Hands the block `ptr..ptr + size` back to the current chunk; does nothing when there
/// is no chunk.
///
/// # Safety
///
/// `root`, if present, must point to a live chunk header and the block must satisfy the
/// requirements of `ChunkHeader::dealloc`.
#[inline(always)]
unsafe fn dealloc<T>(root: Option<NonNull<ChunkHeader<T>>>, ptr: NonNull<u8>, size: usize)
where
    T: CasPtr,
{
    let Some(chunk) = root else {
        return;
    };
    unsafe { ChunkHeader::dealloc(chunk, ptr, size) }
}
/// Releases the arena's chunks back to `allocator`.
///
/// With `keep_last` the current chunk is rewound and kept while all earlier chunks are
/// deallocated; otherwise every chunk is deallocated and `root` becomes `None`.
///
/// # Safety
///
/// Every chunk reachable from `root` must have been allocated with `allocator`, and no
/// block handed out from those chunks may be used afterwards.
#[inline(always)]
pub unsafe fn reset<T, A>(
    root: &mut Option<NonNull<ChunkHeader<T>>>,
    keep_last: bool,
    allocator: &A,
) where
    T: CasPtr,
    A: Allocator,
{
    let mut pending = if !keep_last {
        root.take()
    } else {
        match *root {
            None => return,
            // Rewinding the kept chunk detaches and returns the rest of the list.
            Some(last) => unsafe { ChunkHeader::reset(last) },
        }
    };
    loop {
        let Some(chunk) = pending else {
            break;
        };
        pending = unsafe { ChunkHeader::dealloc_chunk(chunk, allocator) };
    }
}
/// Resets the arena without returning any memory to an allocator.
///
/// With `keep_last` the current chunk is rewound and the earlier chunks are detached
/// from it; otherwise `root` is simply cleared. Detached chunks are not deallocated.
#[inline(always)]
pub fn reset_leak<T>(root: &mut Option<NonNull<ChunkHeader<T>>>, keep_last: bool)
where
    T: CasPtr,
{
    match (keep_last, *root) {
        (true, Some(last)) => unsafe {
            // The returned tail of the list is intentionally dropped.
            ChunkHeader::reset(last);
        },
        (true, None) => {}
        (false, _) => *root = None,
    }
}
/// Common interface of the arena implementations in this file.
pub trait Arena {
/// Allocates a block for `layout`; `allocator` is used when a new chunk is needed.
///
/// The `ZEROED` parameter requests a zero-filled block.
unsafe fn alloc<const ZEROED: bool>(
&self,
layout: Layout,
allocator: &impl Allocator,
) -> Result<NonNull<[u8]>, AllocError>;
/// Resizes the block at `ptr` from `old_size` bytes to `new_layout`, using `allocator`
/// when a new chunk is needed.
///
/// The `ZEROED` parameter requests zero-filling of the bytes beyond `old_size`.
unsafe fn resize<const ZEROED: bool>(
&self,
ptr: NonNull<u8>,
old_size: usize,
new_layout: Layout,
allocator: &impl Allocator,
) -> Result<NonNull<[u8]>, AllocError>;
/// Hands the block `ptr..ptr + size` back to the arena.
unsafe fn dealloc(&self, ptr: NonNull<u8>, size: usize);
/// Resets the arena, returning chunks to `allocator`; `keep_last` retains the current chunk.
unsafe fn reset(&mut self, keep_last: bool, allocator: &impl Allocator);
/// Resets the arena without deallocating chunks; `keep_last` retains the current chunk.
fn reset_leak(&mut self, keep_last: bool);
}
/// Arena whose chunk cursor is a plain `Cell`, i.e. without any synchronization.
pub struct ArenaLocal {
/// Current chunk, or `None` before the first allocation and after a full reset.
root: Cell<Option<NonNull<ChunkHeader<Cell<*mut u8>>>>>,
/// Lower bound applied to the payload size requested for new chunks.
min_chunk_size: Cell<usize>,
}
// The raw chunk pointer suppresses the automatic `Send` impl, so it is asserted manually.
// NOTE(review): presumably sound because the arena is the only user of its chunks — confirm.
unsafe impl Send for ArenaLocal {}
/// Dropping does not free chunks; debug builds check that the owner already reset the arena.
impl Drop for ArenaLocal {
    #[inline(always)]
    fn drop(&mut self) {
        debug_assert!(
            matches!(self.root.get(), None),
            "Owner must reset `ArenaLocal` with `keep_last` set to `false` before drop"
        );
    }
}
impl ArenaLocal {
#[inline(always)]
pub const fn new(min_chunk_size: usize) -> Self {
ArenaLocal {
root: Cell::new(None),
min_chunk_size: Cell::new(min_chunk_size),
}
}
#[inline(always)]
#[cfg(feature = "sync")]
pub fn last_chunk_size(&self) -> usize {
match self.root.get() {
None => 0,
Some(root) => {
unsafe { root.as_ref().cap() }
}
}
}
}
impl Arena for ArenaLocal {
    /// Allocates from the current chunk, falling back to a new chunk when it is too small.
    unsafe fn alloc<const ZEROED: bool>(
        &self,
        layout: Layout,
        allocator: &impl Allocator,
    ) -> Result<NonNull<[u8]>, AllocError> {
        let requested = match alloc_fast::<_, ZEROED>(self.root.get(), layout) {
            Ok(block) => return Ok(block),
            Err(requested) => requested,
        };
        // `None` means the required size could not be computed.
        let Some(chunk_size) = requested else {
            return Err(AllocError);
        };
        alloc_slow::<_, _, ZEROED>(
            &self.root,
            chunk_size.max(self.min_chunk_size.get()),
            layout,
            allocator,
        )
    }

    /// Resizes within the current chunk, falling back to a new chunk when it is too small.
    unsafe fn resize<const ZEROED: bool>(
        &self,
        ptr: NonNull<u8>,
        old_size: usize,
        new_layout: Layout,
        allocator: &impl Allocator,
    ) -> Result<NonNull<[u8]>, AllocError> {
        let requested = match resize_fast::<_, ZEROED>(self.root.get(), ptr, old_size, new_layout)
        {
            Ok(block) => return Ok(block),
            Err(requested) => requested,
        };
        // `None` means the required size could not be computed.
        let Some(chunk_size) = requested else {
            return Err(AllocError);
        };
        resize_slow::<_, _, ZEROED>(
            &self.root,
            chunk_size.max(self.min_chunk_size.get()),
            ptr,
            old_size,
            new_layout,
            allocator,
        )
    }

    /// Hands a block back to the current chunk.
    #[inline(always)]
    unsafe fn dealloc(&self, ptr: NonNull<u8>, size: usize) {
        let current = self.root.get();
        dealloc(current, ptr, size)
    }

    /// Resets the arena, returning chunks to `allocator`.
    #[inline(always)]
    unsafe fn reset(&mut self, keep_last: bool, allocator: &impl Allocator) {
        let root = self.root.get_mut();
        unsafe { reset(root, keep_last, allocator) }
    }

    /// Resets the arena without deallocating chunks.
    #[inline(always)]
    fn reset_leak(&mut self, keep_last: bool) {
        let root = self.root.get_mut();
        reset_leak(root, keep_last)
    }
}
#[cfg(feature = "sync")]
mod sync {
    use super::*;

    /// State guarded by the arena lock.
    struct Inner {
        /// Current chunk. The cursor MUST be atomic: allocations, resizes and
        /// deallocations run concurrently under the shared (read) lock and all of them
        /// update the cursor, so a plain `Cell` cursor would be a data race.
        root: Option<NonNull<ChunkHeader<AtomicPtr<u8>>>>,
        /// Lower bound applied to the payload size requested for new chunks.
        min_chunk_size: usize,
    }

    // The raw chunk pointer suppresses the auto traits. Concurrent access to the chunk
    // goes through the atomic cursor; changes to the chunk list require the write lock.
    unsafe impl Send for Inner {}
    unsafe impl Sync for Inner {}

    /// Arena that can be shared between threads.
    ///
    /// The fast paths take the read lock and bump the atomic cursor of the current
    /// chunk; installing a new chunk takes the write lock.
    pub struct ArenaSync {
        inner: RwLock<Inner>,
    }

    /// Dropping does not free chunks; debug builds check that the owner already reset the arena.
    impl Drop for ArenaSync {
        #[inline(always)]
        fn drop(&mut self) {
            debug_assert!(
                self.inner.get_mut().root.is_none(),
                "Owner must reset `ArenaSync` with `keep_last` set to `false` before drop"
            );
        }
    }

    impl ArenaSync {
        /// Creates an empty arena; `min_chunk_size` is the smallest payload size that
        /// will be requested for a new chunk.
        #[inline(always)]
        pub const fn new(min_chunk_size: usize) -> Self {
            ArenaSync {
                inner: RwLock::new(Inner {
                    root: None,
                    min_chunk_size,
                }),
            }
        }
    }

    impl Arena for ArenaSync {
        /// Allocates from the current chunk under the read lock; when the chunk is too
        /// small, re-locks for writing and allocates from a new chunk.
        unsafe fn alloc<const ZEROED: bool>(
            &self,
            layout: Layout,
            allocator: &impl Allocator,
        ) -> Result<NonNull<[u8]>, AllocError> {
            let inner = self.inner.read();
            match alloc_fast::<_, ZEROED>(inner.root, layout) {
                Ok(ptr) => Ok(ptr),
                Err(None) => Err(AllocError),
                Err(Some(chunk_size)) => {
                    // The read guard must be released before the write lock is taken.
                    drop(inner);
                    let mut guard = self.inner.write();
                    let inner = &mut *guard;
                    alloc_slow::<_, _, ZEROED>(
                        &mut inner.root,
                        chunk_size.max(inner.min_chunk_size),
                        layout,
                        allocator,
                    )
                }
            }
        }

        /// Resizes within the current chunk under the read lock; when the chunk is too
        /// small, re-locks for writing and moves the block into a new chunk.
        unsafe fn resize<const ZEROED: bool>(
            &self,
            ptr: NonNull<u8>,
            old_size: usize,
            new_layout: Layout,
            allocator: &impl Allocator,
        ) -> Result<NonNull<[u8]>, AllocError> {
            let inner = self.inner.read();
            match resize_fast::<_, ZEROED>(inner.root, ptr, old_size, new_layout) {
                Ok(ptr) => Ok(ptr),
                Err(None) => Err(AllocError),
                Err(Some(chunk_size)) => {
                    // The read guard must be released before the write lock is taken.
                    drop(inner);
                    let mut guard = self.inner.write();
                    let inner = &mut *guard;
                    resize_slow::<_, _, ZEROED>(
                        &mut inner.root,
                        chunk_size.max(inner.min_chunk_size),
                        ptr,
                        old_size,
                        new_layout,
                        allocator,
                    )
                }
            }
        }

        /// Hands a block back to the current chunk under the read lock.
        unsafe fn dealloc(&self, ptr: NonNull<u8>, size: usize) {
            dealloc(self.inner.read().root, ptr, size)
        }

        /// Resets the arena, returning chunks to `allocator`; exclusive access makes
        /// locking unnecessary.
        #[inline(always)]
        unsafe fn reset(&mut self, keep_last: bool, allocator: &impl Allocator) {
            unsafe { reset(&mut self.inner.get_mut().root, keep_last, allocator) }
        }

        /// Resets the arena without deallocating chunks.
        #[inline(always)]
        fn reset_leak(&mut self, keep_last: bool) {
            reset_leak(&mut self.inner.get_mut().root, keep_last)
        }
    }
}
#[cfg(feature = "sync")]
pub use self::sync::ArenaSync;
/// Marks a code path the caller guarantees is never taken.
///
/// Debug-assertion builds: panics via `unreachable!()`, reporting the caller's location.
#[cfg(debug_assertions)]
#[track_caller]
unsafe fn unreachable_unchecked() -> ! {
unreachable!()
}
/// Marks a code path the caller guarantees is never taken.
///
/// Builds without debug assertions: forwards to `core::hint::unreachable_unchecked`,
/// so actually reaching this call is undefined behaviour.
#[cfg(not(debug_assertions))]
unsafe fn unreachable_unchecked() -> ! {
unsafe { core::hint::unreachable_unchecked() }
}
/// Copies `size` bytes from `src` to `dst`; the ranges may overlap.
///
/// Nothing is done when `src` and `dst` are the same address; the actual copy is kept
/// in a separate function marked `#[cold]`.
///
/// # Safety
///
/// `src` must be valid for reads and `dst` valid for writes of `size` bytes.
#[inline(always)]
unsafe fn memmove(src: *mut u8, dst: *mut u8, size: usize) {
    #[cold]
    #[inline(always)]
    unsafe fn cold_copy(src: *mut u8, dst: *mut u8, size: usize) {
        core::ptr::copy(src, dst, size);
    }

    if src != dst {
        cold_copy(src, dst, size)
    }
}