use std::alloc::{alloc, dealloc, Layout};
use std::cell::Cell;
use std::marker::PhantomData;
use std::ptr::NonNull;
/// A fixed-capacity bump allocator backed by one heap allocation.
///
/// Allocation rounds an internal cursor up to the requested alignment and
/// advances it by the requested size. Individual allocations are never
/// freed; the cursor can only be rewound with [`MkFastArena::reset`] or
/// [`MkFastArena::reset_to`]. The backing memory is released when the arena
/// is dropped.
///
/// Values stored in the arena are never dropped by the arena: `reset`,
/// `reset_to` and `Drop` only move the cursor or release the raw memory.
///
/// The cursor lives in a [`Cell`], so every method takes `&self`. The raw
/// pointer fields keep the type from being `Send` or `Sync`.
pub struct MkFastArena {
    /// First byte of the backing allocation.
    base: NonNull<u8>,
    /// One past the last byte of the backing allocation.
    end: *const u8,
    /// Bump cursor: the next free byte. Invariant: `base <= ptr <= end`.
    ptr: Cell<*mut u8>,
    /// Marks the arena as owning raw, thread-unsafe memory.
    _marker: PhantomData<*mut u8>,
}

impl MkFastArena {
    /// Alignment, in bytes, of the backing allocation. `new` and `drop` must
    /// agree on this value because `dealloc` requires the original layout.
    const BACKING_ALIGN: usize = 4096;

    /// Creates an arena with room for exactly `size` bytes.
    ///
    /// # Panics
    ///
    /// Panics if `size` is zero, if `size` cannot form a valid [`Layout`]
    /// with the arena's 4096-byte alignment, or if the global allocator
    /// returns a null pointer.
    #[inline]
    pub fn new(size: usize) -> Self {
        assert!(size > 0, "Arena size must be > 0");
        let layout =
            Layout::from_size_align(size, Self::BACKING_ALIGN).expect("Invalid arena size");
        // SAFETY: `layout` has a non-zero size (asserted above).
        let raw = unsafe { alloc(layout) };
        let base = NonNull::new(raw).expect("Failed to allocate arena memory");
        // SAFETY: the allocation is exactly `size` bytes long, so the
        // one-past-the-end pointer is valid to compute.
        let end = unsafe { base.as_ptr().add(size) };
        Self {
            base,
            end,
            ptr: Cell::new(base.as_ptr()),
            _marker: PhantomData,
        }
    }

    /// Creates an arena of `mb` mebibytes (`mb * 1024 * 1024` bytes).
    ///
    /// # Panics
    ///
    /// Panics if the byte count overflows `usize`, or for any reason
    /// [`MkFastArena::new`] panics (including `mb == 0`).
    #[inline]
    pub fn with_capacity_mb(mb: usize) -> Self {
        // Checked so that an oversized request fails loudly instead of
        // wrapping to a small arena in builds without overflow checks.
        let bytes = mb
            .checked_mul(1024 * 1024)
            .expect("Arena size in bytes overflows usize");
        Self::new(bytes)
    }

    /// Moves `value` into the arena and returns a mutable reference to it.
    ///
    /// Returns `None` (dropping `value`) when the arena has no room left.
    ///
    /// # Safety
    ///
    /// The returned reference borrows the arena only immutably, so the
    /// compiler does not prevent a later `reset` or `reset_to` from rewinding
    /// the cursor over this slot. The caller must not use the reference after
    /// the cursor has been rewound to or before it, because later allocations
    /// will then hand out the same memory.
    // A `&mut T` from `&self` is intentional: every call returns a disjoint
    // slot, tracked by the interior-mutable cursor.
    #[allow(clippy::mut_from_ref)]
    #[inline(always)]
    pub unsafe fn alloc<T>(&self, value: T) -> Option<&mut T> {
        let slot = self.alloc_raw::<T>()?;
        // SAFETY: `slot` is aligned for `T`, points at `size_of::<T>()` bytes
        // inside the arena, and is not aliased by any other live allocation.
        unsafe {
            slot.write(value);
            Some(&mut *slot)
        }
    }

    /// Reserves uninitialized, suitably aligned space for one `T`.
    ///
    /// Returns `None` when the arena has no room left.
    #[inline(always)]
    pub fn alloc_raw<T>(&self) -> Option<*mut T> {
        self.alloc_layout(Layout::new::<T>())
            .map(|bytes| bytes as *mut T)
    }

    /// Reserves uninitialized, suitably aligned space for `len` values of `T`.
    ///
    /// A zero-length request consumes no space and yields a dangling (but
    /// aligned) pointer. Returns `None` when the array layout is invalid or
    /// the arena has no room left.
    #[inline(always)]
    pub fn alloc_slice_raw<T>(&self, len: usize) -> Option<*mut T> {
        if len == 0 {
            return Some(NonNull::<T>::dangling().as_ptr());
        }
        let layout = Layout::array::<T>(len).ok()?;
        self.alloc_layout(layout).map(|bytes| bytes as *mut T)
    }

    /// Allocates a slice of `len` elements, each a clone of `value`.
    ///
    /// Returns `None` when the arena has no room left.
    ///
    /// # Safety
    ///
    /// Same contract as [`MkFastArena::alloc`]: the slice must not be used
    /// after the cursor has been rewound to or before its storage.
    // See `alloc` for why `&mut` from `&self` is intentional.
    #[allow(clippy::mut_from_ref)]
    #[inline(always)]
    pub unsafe fn alloc_slice_fill<T: Clone>(&self, len: usize, value: T) -> Option<&mut [T]> {
        let first = self.alloc_slice_raw::<T>(len)?;
        // SAFETY: `first` is aligned and valid for `len` elements (or dangling
        // with `len == 0`); every element is written before the slice exists.
        unsafe {
            for i in 0..len {
                first.add(i).write(value.clone());
            }
            Some(std::slice::from_raw_parts_mut(first, len))
        }
    }

    /// Allocates a slice of `len` elements, each set to `T::default()`.
    ///
    /// Returns `None` when the arena has no room left.
    ///
    /// # Safety
    ///
    /// Same contract as [`MkFastArena::alloc`]: the slice must not be used
    /// after the cursor has been rewound to or before its storage.
    // See `alloc` for why `&mut` from `&self` is intentional.
    #[allow(clippy::mut_from_ref)]
    #[inline(always)]
    pub unsafe fn alloc_slice_default<T: Default>(&self, len: usize) -> Option<&mut [T]> {
        let first = self.alloc_slice_raw::<T>(len)?;
        // SAFETY: `first` is aligned and valid for `len` elements (or dangling
        // with `len == 0`); every element is written before the slice exists.
        unsafe {
            for i in 0..len {
                first.add(i).write(T::default());
            }
            Some(std::slice::from_raw_parts_mut(first, len))
        }
    }

    /// Core bump step: aligns the cursor for `layout`, advances it by
    /// `layout.size()`, and returns the aligned start.
    ///
    /// Returns `None`, leaving the cursor untouched, when the request does
    /// not fit. All range arithmetic is done on integers with overflow
    /// checks, and pointers are only formed once the range is known to lie
    /// inside the backing allocation.
    #[inline(always)]
    fn alloc_layout(&self, layout: Layout) -> Option<*mut u8> {
        let base_addr = self.base.as_ptr() as usize;
        let end_addr = self.end as usize;
        let cursor_addr = self.ptr.get() as usize;

        // `Layout` guarantees a non-zero power-of-two alignment.
        let mask = layout.align() - 1;
        let start_addr = cursor_addr.checked_add(mask)? & !mask;
        let stop_addr = start_addr.checked_add(layout.size())?;
        if stop_addr > end_addr {
            return None;
        }

        // SAFETY: `base_addr <= cursor_addr <= start_addr <= stop_addr <=
        // end_addr`, so both offsets stay within the backing allocation (or
        // one past its end). Deriving from `base` keeps pointer provenance.
        let (start, stop) = unsafe {
            (
                self.base.as_ptr().add(start_addr - base_addr),
                self.base.as_ptr().add(stop_addr - base_addr),
            )
        };
        self.ptr.set(stop);
        Some(start)
    }

    /// Rewinds the cursor to the start, making the whole arena available
    /// again. Previously stored values are not dropped.
    #[inline(always)]
    pub fn reset(&self) {
        self.ptr.set(self.base.as_ptr());
    }

    /// Returns the current cursor offset, suitable for a later
    /// [`MkFastArena::reset_to`].
    #[inline(always)]
    pub fn checkpoint(&self) -> usize {
        self.allocated()
    }

    /// Moves the cursor to `checkpoint` bytes from the start of the arena.
    ///
    /// Values stored past the checkpoint are not dropped.
    ///
    /// # Panics
    ///
    /// Panics if `checkpoint` is greater than [`MkFastArena::capacity`].
    #[inline(always)]
    pub fn reset_to(&self, checkpoint: usize) {
        let capacity = self.capacity();
        // Checked in every build profile: this is a safe function, so an
        // out-of-range offset must never reach the pointer arithmetic below.
        assert!(
            checkpoint <= capacity,
            "checkpoint {} exceeds arena capacity {}",
            checkpoint,
            capacity
        );
        // SAFETY: `checkpoint <= capacity`, so the result is inside the
        // backing allocation or one past its end.
        let target = unsafe { self.base.as_ptr().add(checkpoint) };
        self.ptr.set(target);
    }

    /// Number of bytes between the start of the arena and the cursor,
    /// including alignment padding.
    #[inline]
    pub fn allocated(&self) -> usize {
        self.ptr.get() as usize - self.base.as_ptr() as usize
    }

    /// Number of bytes between the cursor and the end of the arena.
    #[inline]
    pub fn remaining(&self) -> usize {
        self.end as usize - self.ptr.get() as usize
    }

    /// Total size of the arena in bytes.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.end as usize - self.base.as_ptr() as usize
    }
}

impl Drop for MkFastArena {
    /// Releases the backing allocation. Stored values are not dropped.
    fn drop(&mut self) {
        let size = self.capacity();
        // SAFETY: `base` came from `alloc` in `new` with this exact size and
        // alignment, a combination `Layout::from_size_align` already
        // validated there, and it has not been freed before.
        unsafe {
            let layout = Layout::from_size_align_unchecked(size, Self::BACKING_ALIGN);
            dealloc(self.base.as_ptr(), layout);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Values placed with `alloc` read back correctly and can be mutated.
    #[test]
    fn test_basic_alloc() {
        let arena = MkFastArena::new(4096);
        // SAFETY: the arena is not reset while the references are in use.
        unsafe {
            let first = arena.alloc(42u64).unwrap();
            assert_eq!(*first, 42);

            let second = arena.alloc(123u32).unwrap();
            assert_eq!(*second, 123);

            *first = 100;
            assert_eq!(*first, 100);
        }
    }

    /// `alloc_slice_fill` yields a slice of the requested length whose first
    /// and last elements hold the fill value.
    #[test]
    fn test_slice_alloc() {
        let arena = MkFastArena::new(4096);
        // SAFETY: the arena is not reset while the slice is in use.
        let filled = unsafe { arena.alloc_slice_fill(100, 42u64) }.unwrap();
        assert_eq!(filled.len(), 100);
        assert_eq!(filled[0], 42);
        assert_eq!(filled[99], 42);
    }

    /// `reset` brings the allocated byte count back to zero.
    #[test]
    fn test_reset() {
        let arena = MkFastArena::new(4096);
        // SAFETY: the returned references are discarded immediately.
        unsafe {
            arena.alloc(42u64).unwrap();
            arena.alloc(123u64).unwrap();
        }
        assert!(arena.allocated() > 0);

        arena.reset();
        assert_eq!(arena.allocated(), 0);
    }

    /// `reset_to` restores the cursor to a previously taken checkpoint.
    #[test]
    fn test_checkpoint() {
        let arena = MkFastArena::new(4096);
        // SAFETY: the returned references are discarded immediately.
        unsafe {
            arena.alloc(42u64).unwrap();
            let mark = arena.checkpoint();

            arena.alloc(123u64).unwrap();
            arena.alloc(456u64).unwrap();
            assert!(arena.allocated() > mark);

            arena.reset_to(mark);
            assert_eq!(arena.checkpoint(), mark);
        }
    }

    /// Once every byte is handed out, further allocations return `None`.
    #[test]
    fn test_full() {
        let arena = MkFastArena::new(64);
        // Consume the whole arena; the pointer itself is not needed.
        let _ = arena.alloc_slice_raw::<u8>(64);
        // SAFETY: no reference is produced because the allocation fails.
        let overflow = unsafe { arena.alloc(42u64) };
        assert!(overflow.is_none());
    }
}