use std::{alloc::Layout, ptr::NonNull};
use crate::{AllocError, Allocator, chunk::Chunk};
/// A bump allocator over a single contiguous [`Chunk`] of memory.
///
/// `ptr` is the bump cursor: it starts at `store.base()` and only moves
/// forward as allocations are made. NOTE(review): `grow`/`shrink` debug-assert
/// "last allocation only", so this appears to support resizing/freeing in
/// LIFO order only — confirm against the `Allocator` trait contract.
#[derive(Debug)]
pub struct BufferArena {
// Backing memory region; owns the bytes the cursor walks over.
pub(crate) store: Chunk,
// Bump cursor; invariant: `store.base() <= ptr <= store.limit()`.
pub(crate) ptr: NonNull<u8>,
}
impl Default for BufferArena {
    /// An empty arena over the default [`Chunk`]; the cursor starts at its base.
    #[inline(always)]
    fn default() -> Self {
        Chunk::default().into()
    }
}
impl BufferArena {
    /// Builds an arena backed by a freshly allocated chunk of `capacity` bytes.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        Chunk::with_capacity(capacity).into()
    }

    /// Bytes handed out so far: distance from the chunk base to the cursor.
    #[inline]
    pub fn used(&self) -> usize {
        // SAFETY: `ptr` and `base()` derive from the same chunk allocation and
        // `base() <= ptr` is an arena invariant, so the unsigned offset is valid.
        unsafe { self.ptr.byte_offset_from_unsigned(self.store.base()) }
    }

    /// Bytes still available between the cursor and the chunk limit.
    #[inline]
    pub fn remaining(&self) -> usize {
        // SAFETY: `ptr <= limit()` within the same allocation (arena invariant).
        unsafe { self.store.limit().byte_offset_from_unsigned(self.ptr) }
    }

    /// Whether `ptr` lies inside the backing chunk.
    #[inline]
    pub fn contains(&self, ptr: NonNull<u8>) -> bool {
        self.store.contains(ptr)
    }

    /// Whether an allocation of `layout` — alignment padding included — fits
    /// in the remaining space.
    #[inline]
    pub fn sufficient_for(&self, layout: Layout) -> bool {
        let padding = self.ptr.align_offset(layout.align());
        padding + layout.size() <= self.remaining()
    }
}
impl From<Chunk> for BufferArena {
    /// Wraps an existing chunk; the bump cursor starts at the chunk base.
    #[inline]
    fn from(value: Chunk) -> Self {
        Self {
            ptr: value.base(),
            store: value,
        }
    }
}
impl Allocator for BufferArena {
    /// Bump-allocates `layout.size()` bytes aligned to `layout.align()`.
    ///
    /// # Errors
    /// [`AllocError::CapacityExceeded`] when the aligned request does not fit.
    /// `requested` reports `layout.size()` (excluding alignment padding),
    /// i.e. the size as the caller asked for it.
    ///
    /// # Safety
    /// See the [`Allocator`] trait contract.
    #[inline]
    unsafe fn allocate(&mut self, layout: Layout) -> Result<NonNull<[u8]>, crate::AllocError> {
        debug_assert!(layout.align() > 0);
        debug_assert!(layout.align() <= 4096);
        if self.sufficient_for(layout) {
            // SAFETY: `sufficient_for` proved padding + size fits in the
            // remaining bytes, so both adds stay within the backing chunk.
            let ptr = unsafe { self.ptr.add(self.ptr.align_offset(layout.align())) };
            self.ptr = unsafe { ptr.add(layout.size()) };
            Ok(NonNull::slice_from_raw_parts(ptr, layout.size()))
        } else {
            Err(AllocError::CapacityExceeded {
                requested: layout.size(),
                remaining: self.remaining(),
            })
        }
    }

    /// Releases an allocation by rewinding the bump cursor to `ptr`.
    ///
    /// NOTE(review): this is only a true per-allocation free for the most
    /// recent allocation (LIFO order); rewinding to an older pointer also
    /// discards everything allocated after it.
    #[inline]
    unsafe fn deallocate(&mut self, ptr: NonNull<u8>, layout: Layout) {
        debug_assert!(self.store.contains(ptr.cast()));
        debug_assert!(unsafe { ptr.add(layout.size()) } <= self.store.limit());
        self.ptr = ptr;
    }

    /// Grows the most recent allocation in place; the returned pointer equals
    /// `ptr`, so existing contents stay put.
    ///
    /// # Errors
    /// [`AllocError::CapacityExceeded`] when the extra bytes do not fit;
    /// `requested` reports only the additional bytes needed.
    ///
    /// # Safety
    /// `ptr` must be the most recent allocation and the alignments must match
    /// (both checked with `debug_assert!` only). `old_layout.size()` must not
    /// exceed `new_layout.size()`.
    #[inline]
    unsafe fn grow(
        &mut self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, crate::AllocError> {
        debug_assert_eq!(
            unsafe { ptr.add(old_layout.size()) },
            self.ptr,
            "last allocation only"
        );
        debug_assert_eq!(old_layout.align(), new_layout.align());
        match old_layout.size().cmp(&new_layout.size()) {
            std::cmp::Ordering::Less => {
                // Check capacity by size arithmetic BEFORE forming the grown
                // end pointer: `ptr.add(n)` landing past `limit()` would
                // create an out-of-bounds pointer, which is UB even unread.
                let extra = new_layout.size() - old_layout.size();
                if extra > self.remaining() {
                    return Err(AllocError::CapacityExceeded {
                        requested: extra,
                        remaining: self.remaining(),
                    });
                }
                // SAFETY: the capacity check above keeps the new end pointer
                // within the backing chunk.
                self.ptr = unsafe { ptr.add(new_layout.size()) };
                Ok(NonNull::slice_from_raw_parts(ptr, new_layout.size()))
            }
            std::cmp::Ordering::Equal => Ok(NonNull::slice_from_raw_parts(ptr, new_layout.size())),
            std::cmp::Ordering::Greater => unreachable!("use shrink instead"),
        }
    }

    /// Shrinks the most recent allocation in place, reclaiming the tail bytes
    /// for subsequent allocations.
    ///
    /// # Safety
    /// `ptr` must be the most recent allocation and the alignments must match
    /// (both checked with `debug_assert!` only). `new_layout.size()` must not
    /// exceed `old_layout.size()`.
    #[inline]
    unsafe fn shrink(
        &mut self,
        ptr: NonNull<u8>,
        old_layout: Layout,
        new_layout: Layout,
    ) -> Result<NonNull<[u8]>, crate::AllocError> {
        debug_assert_eq!(
            unsafe { ptr.add(old_layout.size()) },
            self.ptr,
            "last allocation only"
        );
        debug_assert_eq!(old_layout.align(), new_layout.align());
        match old_layout.size().cmp(&new_layout.size()) {
            std::cmp::Ordering::Greater => {
                // SAFETY: the new end lies strictly inside the old allocation.
                self.ptr = unsafe { ptr.add(new_layout.size()) };
                Ok(NonNull::slice_from_raw_parts(ptr, new_layout.size()))
            }
            std::cmp::Ordering::Equal => Ok(NonNull::slice_from_raw_parts(ptr, new_layout.size())),
            std::cmp::Ordering::Less => unreachable!("use grow instead"),
        }
    }
}
// Implement `From` rather than a hand-written `Into`: the standard blanket
// impl derives `Into<Chunk> for BufferArena` automatically, so existing
// `arena.into()` call sites keep working, and `Chunk::from(arena)` works too.
impl From<BufferArena> for Chunk {
    /// Recovers the backing chunk, discarding the bump cursor.
    #[inline]
    fn from(arena: BufferArena) -> Self {
        arena.store
    }
}
#[cfg(test)]
mod tests {
use super::*;
// A fresh arena reports zero usage and full capacity.
#[test]
fn test_buffer_arena_new_and_capacity() {
let cap = 4096;
let arena = BufferArena::with_capacity(cap);
assert_eq!(arena.used(), 0);
assert_eq!(arena.remaining(), cap);
}
// Round-trip Chunk -> BufferArena -> Chunk preserves capacity.
#[test]
fn test_chunk_conversion() {
let chunk = Chunk::with_capacity(4096);
let arena: BufferArena = chunk.into();
assert_eq!(arena.used(), 0);
assert_eq!(arena.remaining(), 4096);
let to_chunk: Chunk = arena.into();
assert_eq!(to_chunk.capacity(), 4096);
}
// Deallocating the (only) allocation rewinds the cursor to zero usage.
#[test]
fn test_allocate_and_deallocate() {
let mut arena = BufferArena::with_capacity(4096);
let layout = Layout::from_size_align(8, 1).unwrap();
let ptr = unsafe { arena.allocate(layout).unwrap() };
assert_eq!(ptr.len(), 8);
assert_eq!(arena.used(), 8);
unsafe { arena.deallocate(ptr.cast(), layout) };
assert_eq!(arena.used(), 0);
}
// Allocations honor their alignment, never overlap, and the padding bytes
// between them stay zeroed; the trailing byte-dump pins the exact layout.
#[test]
fn test_alignment() {
let mut arena = BufferArena::with_capacity(4096);
unsafe { arena.ptr.write_bytes(0, arena.remaining()) };
let mut prev_end = arena.ptr;
// Descending alignments 4096..1 with growing sizes 1..7 force padding.
for (i, align) in [1, 2, 4, 8, 16, 32, 4096].into_iter().rev().enumerate() {
let size = i + 1;
let layout = Layout::from_size_align(size, align).unwrap();
let ptr = unsafe { arena.allocate_zeroed(layout).unwrap() };
let addr = ptr.cast::<u8>().as_ptr() as usize;
assert_eq!(addr % align, 0, "addr {ptr:?} not aligned to {align}");
let fill = size as u8;
unsafe { ptr.cast::<u8>().write_bytes(fill, layout.size()) };
let data = unsafe { ptr.as_ref() };
assert_eq!(data, vec![fill; size].as_slice());
assert!(ptr.cast() >= prev_end, "Allocation overlapped previous");
prev_end = unsafe { ptr.cast().add(layout.size()) };
}
// 79 = sum of sizes 1..=7 plus all alignment padding.
assert_eq!(arena.used(), 79);
let written =
unsafe { std::slice::from_raw_parts(arena.store.base().as_ptr(), arena.used()) };
// Exact expected memory image: fill runs separated by zeroed padding.
assert_eq!(
written,
[
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 0, 0, 0, 0, 0,
4, 4, 4, 4, 5, 5, 5, 5, 5, 0, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7
]
.as_ref()
);
// Rewinding to the base frees everything at once.
unsafe { arena.deallocate(arena.store.base(), Layout::from_size_align_unchecked(8, 1)) };
assert_eq!(arena.used(), 0);
}
// A capacity-sized allocation succeeds; one more byte fails.
#[test]
fn test_allocate_full_arena() {
let mut arena = BufferArena::with_capacity(4096);
let layout = Layout::from_size_align(4096, 1).unwrap();
let ptr = unsafe { arena.allocate(layout).unwrap() };
assert_eq!(ptr.len(), 4096);
assert_eq!(arena.used(), 4096);
let layout2 = Layout::from_size_align(1, 1).unwrap();
assert!(unsafe { arena.allocate(layout2) }.is_err());
}
// Growing the last allocation extends it in place.
#[test]
fn test_grow_allocation() {
let mut arena = BufferArena::with_capacity(4096);
let layout = Layout::from_size_align(8, 1).unwrap();
let ptr = unsafe { arena.allocate(layout).unwrap() };
let new_layout = Layout::from_size_align(16, 1).unwrap();
let grown = unsafe { arena.grow(ptr.cast(), layout, new_layout).unwrap() };
assert_eq!(grown.len(), 16);
assert_eq!(arena.used(), 16);
}
// Shrinking the last allocation reclaims the tail bytes.
#[test]
fn test_shrink_allocation() {
let mut arena = BufferArena::with_capacity(4096);
let layout = Layout::from_size_align(16, 1).unwrap();
let ptr = unsafe { arena.allocate(layout).unwrap() };
let new_layout = Layout::from_size_align(8, 1).unwrap();
let shrunk = unsafe { arena.shrink(ptr.cast(), layout, new_layout).unwrap() };
assert_eq!(shrunk.len(), 8);
assert_eq!(arena.used(), 8);
}
// Repeated alloc/free cycles reuse the same space; stacked allocations
// must be freed in reverse (LIFO) order.
#[test]
fn test_multiple_allocate_and_deallocate() {
use std::alloc::Layout;
let mut arena = BufferArena::with_capacity(4096);
let layout = Layout::from_size_align(8, 1).unwrap();
for _ in 0..5 {
let ptr = unsafe { arena.allocate(layout).unwrap() };
assert_eq!(ptr.len(), 8);
assert_eq!(arena.used(), 8);
unsafe { ptr.cast::<u8>().write_bytes(0xAA, layout.size()) };
assert_eq!(unsafe { ptr.as_ref() }, [0xAA; 8].as_ref());
unsafe { arena.deallocate(ptr.cast(), layout) };
assert_eq!(arena.used(), 0);
}
let mut ptrs = Vec::new();
for _ in 0..4 {
let ptr = unsafe { arena.allocate(layout).unwrap() };
unsafe { ptr.cast::<u8>().write_bytes(0xAA, layout.size()) };
assert_eq!(unsafe { ptr.as_ref() }, [0xAA; 8].as_ref());
ptrs.push(ptr);
}
assert_eq!(arena.used(), 32);
// LIFO release: newest first.
for ptr in ptrs.into_iter().rev() {
unsafe { arena.deallocate(ptr.cast(), layout) };
}
assert_eq!(arena.used(), 0);
}
// Mixed sequence: allocate, grow the live allocation, then allocate again;
// checks that written data survives the grow and accounting stays exact.
#[test]
fn test_multi_alloc() {
let mut arena = BufferArena::with_capacity(4096);
let layout = Layout::from_size_align(8, 8).unwrap();
let ptr = unsafe { arena.allocate(layout) }.unwrap();
unsafe { ptr.cast::<u8>().write_bytes(0xAA, layout.size()) };
let data1 = unsafe { ptr.as_ref() };
assert_eq!(data1, [0xAA; 8].as_slice());
let remaining = arena.remaining();
assert_eq!(remaining, 4088);
let new_layout = Layout::from_size_align(layout.size() + 4, layout.align()).unwrap();
let grown_ptr = unsafe { arena.grow(ptr.cast(), layout, new_layout) }.unwrap();
unsafe {
grown_ptr
.cast::<u8>()
.add(layout.size())
.write_bytes(0xBB, 4)
};
assert_eq!(arena.remaining(), 4084);
let grown_data = unsafe {
std::slice::from_raw_parts(grown_ptr.as_ptr() as *const u8, new_layout.size())
};
let mut expected = vec![0xAA; layout.size()];
expected.extend_from_slice(&[0xBB; 4]);
assert_eq!(grown_data, expected.as_slice());
let layout = unsafe { Layout::from_size_align_unchecked(4, 4) };
let ptr = unsafe { arena.allocate(layout).unwrap() };
assert_eq!(ptr.len(), layout.size());
assert_eq!(arena.remaining(), 4080);
}
}