use std::{alloc::Layout, ptr::NonNull};
use crate::{Allocator, BufferArena, chunk::Chunk};
/// A LIFO allocator: blocks are handed out in stack order, and releasing a
/// block also releases every block allocated after it.
///
/// Allocations are served from `current`. When `current` cannot fit a request,
/// it is retired into `store` and replaced by a new buffer.
#[derive(Debug)]
pub struct StackArena {
    // Buffers retired from `current`, oldest first.
    store: Vec<Chunk>,
    // Every live allocation, in allocation order (top of stack = last element).
    stack: Vec<NonNull<[u8]>>,
    // The buffer new allocations are taken from.
    current: BufferArena,
    // Minimum capacity of each newly opened buffer.
    default_chunk_size: usize,
}
impl Default for StackArena {
fn default() -> Self {
let default_chunk_size = 1 << 12;
Self {
store: Vec::with_capacity(16),
stack: Vec::with_capacity(256),
current: Default::default(),
default_chunk_size,
}
}
}
impl StackArena {
    /// Creates an arena whose chunks hold at least 4096 bytes.
    ///
    /// Equivalent to `StackArena::with_chunk_size(4096)`.
    #[inline]
    pub fn new() -> Self {
        Self::with_chunk_size(4096)
    }

    /// Creates an arena that immediately allocates a first buffer of
    /// `chunk_size` bytes. `chunk_size` is also the minimum capacity of every
    /// buffer opened later.
    #[inline]
    pub fn with_chunk_size(chunk_size: usize) -> Self {
        let current = BufferArena::with_capacity(chunk_size);
        Self {
            store: Vec::with_capacity(4),
            stack: Vec::with_capacity(32),
            current,
            default_chunk_size: chunk_size,
        }
    }

    /// Returns the number of live allocations.
    #[inline]
    pub fn len(&self) -> usize {
        self.stack.len()
    }

    /// Returns `true` when the arena holds no live allocations.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.stack.is_empty()
    }

    /// Releases the most recent allocation and returns the slice it occupied.
    /// Returns `None` if the arena holds no allocations.
    ///
    /// The returned pointer refers to memory that has already been handed back
    /// to the arena, so its contents must not be relied on after further
    /// allocations.
    #[inline]
    pub fn pop(&mut self) -> Option<NonNull<[u8]>> {
        let ptr = *self.stack.last()?;
        // Derive the layout from the slice metadata rather than
        // `Layout::for_value`. That would require creating a `&[u8]` over bytes
        // that may never have been initialised.
        let layout = Layout::array::<u8>(ptr.len())
            .expect("a live allocation never exceeds isize::MAX bytes");
        // SAFETY: `ptr` is the top stack entry, i.e. a live allocation handed out
        // by this arena with exactly `ptr.len()` bytes.
        unsafe { self.deallocate(ptr.cast(), layout) };
        Some(ptr)
    }

    /// Returns the most recent live allocation without releasing it.
    /// Returns `None` if the arena is empty.
    #[inline]
    pub fn top(&self) -> Option<NonNull<[u8]>> {
        self.stack.last().copied()
    }

    /// Releases `data` together with every allocation made after it.
    ///
    /// `data` must be a slice exactly as returned by this arena (same start and
    /// length). Otherwise the lookup inside `deallocate` panics.
    #[inline]
    pub fn rollback(&mut self, data: &[u8]) {
        let layout = Layout::for_value(data);
        // SAFETY: the caller passes a slice obtained from this arena; a shared
        // reference is never null, so `NonNull::from` is always valid.
        unsafe { self.deallocate(NonNull::from(data).cast(), layout) };
    }
}
impl Allocator for StackArena {
unsafe fn allocate(
&mut self,
layout: std::alloc::Layout,
) -> Result<std::ptr::NonNull<[u8]>, crate::AllocError> {
if !self.current.sufficient_for(layout) {
let capacity = layout.size().max(self.default_chunk_size);
let prev = std::mem::replace(&mut self.current, BufferArena::with_capacity(capacity));
self.store.push(prev.into());
}
unsafe {
let object = self.current.allocate(layout)?;
self.stack.push(object);
Ok(object)
}
}
#[inline]
unsafe fn deallocate(&mut self, ptr: std::ptr::NonNull<u8>, layout: std::alloc::Layout) {
let object = NonNull::slice_from_raw_parts(ptr, layout.size());
let pos = self
.stack
.iter()
.rposition(|item| std::ptr::eq(item.as_ptr(), object.as_ptr()))
.unwrap();
self.stack.truncate(pos);
if !self.current.contains(ptr) {
let pos = self
.store
.iter()
.rposition(|item| item.contains(ptr))
.unwrap();
std::mem::swap(&mut self.current.store, &mut self.store[pos]);
self.store.truncate(pos);
}
debug_assert!(self.current.contains(ptr));
unsafe { self.current.deallocate(ptr, layout) };
}
unsafe fn grow(
&mut self,
ptr: NonNull<u8>,
old_layout: std::alloc::Layout,
new_layout: std::alloc::Layout,
) -> Result<NonNull<[u8]>, crate::AllocError> {
let top = self.stack.pop().unwrap();
debug_assert_eq!(
top.cast().as_ptr(),
ptr.as_ptr(),
"this operation is only supported for the last allocation"
);
let object = if self.current.remaining() >= new_layout.size() - old_layout.size() {
unsafe { self.current.grow(ptr, old_layout, new_layout) }?
} else {
let capacity = new_layout.size().max(self.default_chunk_size);
let prev = std::mem::replace(&mut self.current, BufferArena::with_capacity(capacity));
self.store.push(prev.into());
unsafe {
let object = self.current.allocate(new_layout)?;
object
.cast()
.copy_from_nonoverlapping(ptr, old_layout.size());
object
}
};
self.stack.push(object);
Ok(object)
}
#[inline]
unsafe fn shrink(
&mut self,
ptr: NonNull<u8>,
old_layout: Layout,
new_layout: Layout,
) -> Result<NonNull<[u8]>, crate::AllocError> {
let top = self.stack.pop().unwrap();
debug_assert_eq!(
top.cast().as_ptr(),
ptr.as_ptr(),
"this operation is only supported for the last allocation"
);
debug_assert_eq!(top.cast().as_ptr(), ptr.as_ptr());
debug_assert_eq!(self.current.ptr, unsafe { ptr.add(old_layout.size()) });
let object = unsafe { self.current.shrink(ptr, old_layout, new_layout) }?;
self.stack.push(object);
Ok(object)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Copies the first `len` bytes behind `ptr` into an owned vector.
    ///
    /// # Safety
    /// `ptr` must point to at least `len` initialised, readable bytes.
    unsafe fn read_bytes(ptr: NonNull<[u8]>, len: usize) -> Vec<u8> {
        unsafe { std::slice::from_raw_parts(ptr.cast::<u8>().as_ptr(), len) }.to_vec()
    }

    /// Allocates `size` bytes with 8-byte alignment and fills them with `0xAA`.
    fn allocate_and_write<A: Allocator>(
        allocator: &mut A,
        size: usize,
    ) -> (NonNull<[u8]>, Layout) {
        let layout = Layout::from_size_align(size, 8).unwrap();
        let ptr = unsafe { allocator.allocate(layout) }.unwrap();
        unsafe { ptr.cast::<u8>().write_bytes(0xAA, size) };
        (ptr, layout)
    }

    /// Performs `count` allocations whose sizes cycle through 1..=8 bytes.
    fn perform_consecutive_allocations<A: Allocator>(
        allocator: &mut A,
        count: usize,
    ) -> Vec<(NonNull<[u8]>, Layout)> {
        (0..count)
            .map(|i| allocate_and_write(allocator, i % 8 + 1))
            .collect()
    }

    #[test]
    fn test_new() {
        let stack = StackArena::new();
        assert_eq!(stack.len(), 0);
        assert!(stack.is_empty());
    }

    #[test]
    fn test_allocate_deallocate() {
        let mut stack = StackArena::new();
        let layout = Layout::from_size_align(16, 8).unwrap();
        let ptr = unsafe { stack.allocate(layout) }.unwrap();
        unsafe { ptr.cast::<u8>().write_bytes(0xAA, 16) };
        assert_eq!(unsafe { read_bytes(ptr, 16) }, [0xAA; 16]);
        unsafe { stack.deallocate(ptr.cast(), layout) };
        assert_eq!(stack.len(), 0);
    }

    #[test]
    fn test_multi_allocations() {
        let mut stack = StackArena::new();
        let n = 10;
        let mut allocations = Vec::with_capacity(n);
        for i in 0..n {
            let layout = Layout::array::<u8>(i + 1).unwrap();
            let ptr = unsafe { stack.allocate(layout) }.unwrap();
            unsafe { ptr.cast::<u8>().write_bytes(i as u8, i + 1) };
            allocations.push((ptr, layout));
        }
        assert_eq!(stack.len(), n);
        for (i, &(ptr, _)) in allocations.iter().enumerate() {
            assert_eq!(unsafe { read_bytes(ptr, i + 1) }, vec![i as u8; i + 1]);
        }
        for (ptr, layout) in allocations.into_iter().rev() {
            unsafe { stack.deallocate(ptr.cast(), layout) };
        }
        assert_eq!(stack.len(), 0);
    }

    #[test]
    fn test_grow_shrink() {
        let mut stack = StackArena::new();
        let initial_layout = Layout::from_size_align(8, 8).unwrap();
        let ptr = unsafe { stack.allocate(initial_layout) }.unwrap();
        unsafe { ptr.cast::<u8>().write_bytes(0xAA, 8) };

        let new_layout = Layout::from_size_align(16, 8).unwrap();
        let grown_ptr = unsafe { stack.grow(ptr.cast(), initial_layout, new_layout) }.unwrap();
        assert_eq!(unsafe { read_bytes(grown_ptr, 8) }, [0xAA; 8]);
        unsafe { grown_ptr.cast::<u8>().add(8).write_bytes(0xBB, 8) };

        let shrunk_ptr =
            unsafe { stack.shrink(grown_ptr.cast(), new_layout, initial_layout) }.unwrap();
        assert_eq!(unsafe { read_bytes(shrunk_ptr, 8) }, [0xAA; 8]);

        unsafe { stack.deallocate(shrunk_ptr.cast(), initial_layout) };
        assert_eq!(stack.len(), 0);
    }

    #[test]
    fn test_consecutive_allocations() {
        let mut stack = StackArena::new();
        let n = 10000;
        let allocations = perform_consecutive_allocations(&mut stack, n);
        assert_eq!(stack.len(), n);
        for (ptr, layout) in allocations.into_iter().rev() {
            unsafe { stack.deallocate(ptr.cast(), layout) };
        }
        assert_eq!(stack.len(), 0);
    }

    #[test]
    fn test_custom_chunk_size() {
        let mut stack = StackArena::with_chunk_size(256);
        let small_layout = Layout::from_size_align(64, 8).unwrap();
        let small_ptr = unsafe { stack.allocate(small_layout) }.unwrap();
        let medium_layout = Layout::from_size_align(128, 8).unwrap();
        let medium_ptr = unsafe { stack.allocate(medium_layout) }.unwrap();
        // Larger than the chunk size: forces a dedicated chunk.
        let large_layout = Layout::from_size_align(512, 8).unwrap();
        let large_ptr = unsafe { stack.allocate(large_layout) }.unwrap();
        assert_eq!(stack.len(), 3);
        unsafe {
            stack.deallocate(large_ptr.cast(), large_layout);
            stack.deallocate(medium_ptr.cast(), medium_layout);
            stack.deallocate(small_ptr.cast(), small_layout);
        }
        assert_eq!(stack.len(), 0);
    }

    #[test]
    fn test_cross_chunk_allocation_deallocation() {
        let mut stack = StackArena::with_chunk_size(32);
        // (size, align, fill byte); the total exceeds a single 32-byte chunk.
        let specs: [(usize, usize, u8); 5] = [
            (16, 8, 0xAA),
            (8, 8, 0xBB),
            (16, 8, 0xCC),
            (8, 8, 0xDD),
            (39, 4, 0xEE),
        ];
        let mut allocations = Vec::with_capacity(specs.len());
        for (size, align, fill) in specs {
            let layout = Layout::from_size_align(size, align).unwrap();
            let ptr = unsafe { stack.allocate(layout) }.unwrap();
            unsafe { ptr.cast::<u8>().write_bytes(fill, size) };
            allocations.push((ptr, layout));
        }
        assert_eq!(stack.len(), specs.len());
        for (&(size, _, fill), &(ptr, _)) in specs.iter().zip(&allocations) {
            assert_eq!(ptr.len(), size);
            assert_eq!(unsafe { read_bytes(ptr, size) }, vec![fill; size]);
        }
        for (ptr, layout) in allocations.into_iter().rev() {
            unsafe { stack.deallocate(ptr.cast(), layout) };
        }
        assert_eq!(stack.len(), 0);

        // The arena must remain usable after unwinding across chunks.
        let layout = Layout::from_size_align(24, 8).unwrap();
        let ptr = unsafe { stack.allocate(layout) }.unwrap();
        unsafe { ptr.cast::<u8>().write_bytes(0xFF, layout.size()) };
        assert_eq!(stack.len(), 1);
        assert_eq!(unsafe { read_bytes(ptr, layout.size()) }, [0xFF; 24]);
    }

    #[test]
    fn test_stack_arena_grow_with_new_chunk() {
        let mut stack = StackArena::with_chunk_size(32);
        let layout1 = Layout::from_size_align(8, 8).unwrap();
        let ptr1 = unsafe { stack.allocate(layout1) }.unwrap();
        unsafe { ptr1.cast::<u8>().write_bytes(0xAA, layout1.size()) };
        let layout2 = Layout::from_size_align(16, 8).unwrap();
        let ptr2 = unsafe { stack.allocate(layout2) }.unwrap();
        unsafe { ptr2.cast::<u8>().write_bytes(0xBB, layout2.size()) };
        assert_eq!(unsafe { read_bytes(ptr1, layout1.size()) }, [0xAA; 8]);
        assert_eq!(unsafe { read_bytes(ptr2, layout2.size()) }, [0xBB; 16]);

        unsafe { stack.deallocate(ptr2.cast(), layout2) };

        let layout1_grown = Layout::from_size_align(24, 8).unwrap();
        let ptr1_grown = unsafe { stack.grow(ptr1.cast(), layout1, layout1_grown) }.unwrap();
        let extra = layout1_grown.size() - layout1.size();
        unsafe {
            ptr1_grown
                .cast::<u8>()
                .add(layout1.size())
                .write_bytes(0xCC, extra)
        };
        let mut expected = vec![0xAA; layout1.size()];
        expected.resize(layout1_grown.size(), 0xCC);
        assert_eq!(unsafe { read_bytes(ptr1_grown, layout1_grown.size()) }, expected);

        unsafe { stack.deallocate(ptr1_grown.cast(), layout1_grown) };
        assert_eq!(stack.len(), 0);
    }

    /// `std::ptr::eq` on slice pointers compares both address and length. The
    /// stack lookup in `StackArena::deallocate` relies on this.
    #[test]
    fn test_ptr_eq() {
        let a: NonNull<[u8]> = NonNull::slice_from_raw_parts(NonNull::dangling(), 2);
        let b: NonNull<[u8]> = NonNull::slice_from_raw_parts(NonNull::dangling(), 2);
        let c: NonNull<[u8]> = NonNull::slice_from_raw_parts(NonNull::dangling(), 3);
        assert!(std::ptr::eq(a.as_ptr(), b.as_ptr()));
        assert!(!std::ptr::eq(a.as_ptr(), c.as_ptr()));
    }
}