use core::mem::{MaybeUninit, align_of, size_of};
use super::aligned_vec::AlignedVec;
use super::alloc::*;
/// Bump allocator over a single pre-allocated, aligned byte buffer.
///
/// Allocations are carved out by advancing `offset`; nothing is freed
/// individually. The bookkeeping fields are `Cell`s so that allocation
/// methods can take `&self`.
pub struct Arena<const ALIGN: usize = DEFAULT_ALIGN> {
/// Backing byte storage; its length is the arena capacity.
buffer: AlignedVec<u8, ALIGN>,
/// Number of bytes from the start of `buffer` that are currently handed out
/// (including alignment padding).
offset: core::cell::Cell<usize>,
/// Largest value `offset` has reached; `reset`/`restore` do not lower it.
high_water_mark: core::cell::Cell<usize>,
}
impl<const ALIGN: usize> Arena<ALIGN> {
    /// Creates an arena whose backing buffer holds `capacity` bytes (obtained
    /// from `AlignedVec::zeros`) with nothing allocated yet.
    pub fn with_capacity(capacity: usize) -> Self {
        Arena {
            buffer: AlignedVec::zeros(capacity),
            offset: core::cell::Cell::new(0),
            high_water_mark: core::cell::Cell::new(0),
        }
    }

    /// Total size of the backing buffer in bytes.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.buffer.len()
    }

    /// Bytes currently consumed, including alignment padding.
    #[inline]
    pub fn used(&self) -> usize {
        self.offset.get()
    }

    /// Bytes between the current offset and the end of the buffer.
    ///
    /// Alignment padding required by the next allocation is not accounted for.
    #[inline]
    pub fn remaining(&self) -> usize {
        self.buffer.len().saturating_sub(self.offset.get())
    }

    /// Largest value `used()` has ever reached; unaffected by `reset`/`restore`.
    #[inline]
    pub fn high_water_mark(&self) -> usize {
        self.high_water_mark.get()
    }

    /// Rewinds the arena to empty so the whole buffer can be handed out again.
    ///
    /// NOTE(review): this takes `&self`, so the borrow checker does not stop
    /// slices returned by earlier allocations from still being alive. Callers
    /// must stop using them before allocating again, otherwise the new
    /// allocation aliases them.
    #[inline]
    pub fn reset(&self) {
        self.offset.set(0);
    }

    /// Reserves room for `count` values of `T` and returns a pointer to it, or
    /// `None` when the request does not fit.
    ///
    /// All size arithmetic is checked: a request whose byte size overflows
    /// `usize` is rejected instead of wrapping around and passing the capacity
    /// test. Alignment is applied to the *address*, not just the offset, so the
    /// pointer is correctly aligned even when `align_of::<T>()` exceeds `ALIGN`.
    fn bump<T>(&self, count: usize) -> Option<*mut MaybeUninit<T>> {
        let align = align_of::<T>().max(ALIGN);
        let size = count.checked_mul(size_of::<T>())?;

        let base = self.buffer.as_ptr() as usize;
        let start_addr = base.checked_add(self.offset.get())?;
        let aligned_offset = round_up_pow2(start_addr, align).checked_sub(base)?;
        let new_offset = aligned_offset.checked_add(size)?;
        if new_offset > self.buffer.len() {
            return None;
        }

        // SAFETY: `aligned_offset <= new_offset <= buffer.len()`, so the
        // resulting pointer stays within (or one past the end of) the buffer.
        let ptr = unsafe { (self.buffer.as_ptr() as *mut u8).add(aligned_offset) }
            as *mut MaybeUninit<T>;

        self.offset.set(new_offset);
        if new_offset > self.high_water_mark.get() {
            self.high_water_mark.set(new_offset);
        }
        Some(ptr)
    }

    /// Zero-fills an uninitialised allocation and reinterprets it as `[T]`.
    fn zero_fill<T: bytemuck::Zeroable>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
        let len = slice.len();
        let ptr = slice.as_mut_ptr();
        // SAFETY: `ptr` is valid for `len` elements, and `T: Zeroable` means
        // the all-zero bit pattern is a valid `T`.
        unsafe {
            core::ptr::write_bytes(ptr, 0, len);
            core::slice::from_raw_parts_mut(ptr as *mut T, len)
        }
    }

    /// Allocates uninitialised storage for `count` values of `T`.
    ///
    /// The slice is aligned to `max(align_of::<T>(), ALIGN)`.
    ///
    /// # Panics
    ///
    /// Panics with an "Arena overflow" message when the request does not fit
    /// in the remaining capacity (including when its byte size overflows
    /// `usize`).
    // Handing out `&mut` from `&self` is the point of a bump allocator: each
    // call returns a disjoint region of the buffer.
    #[allow(clippy::mut_from_ref)]
    pub fn alloc<T>(&self, count: usize) -> &mut [MaybeUninit<T>] {
        match self.try_alloc::<T>(count) {
            Some(slice) => slice,
            None => panic!(
                "Arena overflow: requested {} bytes but only {} available (capacity: {})",
                count.saturating_mul(size_of::<T>()),
                self.remaining(),
                self.capacity()
            ),
        }
    }

    /// Like [`Arena::alloc`], but the returned elements are zero-initialised.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Arena::alloc`].
    #[allow(clippy::mut_from_ref)]
    pub fn alloc_zeroed<T: bytemuck::Zeroable>(&self, count: usize) -> &mut [T] {
        Self::zero_fill(self.alloc::<T>(count))
    }

    /// Allocates `len` zeroed elements and wraps them in an [`ArenaVec`].
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Arena::alloc`].
    pub fn alloc_vec<T: bytemuck::Zeroable>(&self, len: usize) -> ArenaVec<'_, T, ALIGN> {
        let slice = self.alloc_zeroed::<T>(len);
        ArenaVec {
            ptr: slice.as_mut_ptr(),
            len,
            _marker: core::marker::PhantomData,
        }
    }

    /// Non-panicking variant of [`Arena::alloc`]: returns `None` (leaving the
    /// arena untouched) when the request does not fit.
    #[allow(clippy::mut_from_ref)]
    pub fn try_alloc<T>(&self, count: usize) -> Option<&mut [MaybeUninit<T>]> {
        let ptr = self.bump::<T>(count)?;
        // SAFETY: `bump` reserved `count * size_of::<T>()` bytes at `ptr`,
        // aligned for `T`, and that region is not handed out again until the
        // offset is rewound.
        Some(unsafe { core::slice::from_raw_parts_mut(ptr, count) })
    }

    /// Non-panicking variant of [`Arena::alloc_zeroed`].
    #[allow(clippy::mut_from_ref)]
    pub fn try_alloc_zeroed<T: bytemuck::Zeroable>(&self, count: usize) -> Option<&mut [T]> {
        self.try_alloc::<T>(count).map(Self::zero_fill)
    }

    /// Captures the current offset so it can later be passed to
    /// [`Arena::restore`].
    #[inline]
    pub fn save(&self) -> ArenaState {
        ArenaState {
            offset: self.offset.get(),
        }
    }

    /// Rewinds the arena to a previously saved offset, releasing everything
    /// allocated since.
    ///
    /// The same aliasing caveat as [`Arena::reset`] applies to slices
    /// allocated after the snapshot.
    ///
    /// # Panics
    ///
    /// Panics if `state` refers to an offset beyond the current one.
    #[inline]
    pub fn restore(&self, state: ArenaState) {
        let current = self.offset.get();
        assert!(
            state.offset <= current,
            "Invalid arena state: saved offset {} > current offset {}",
            state.offset,
            current
        );
        self.offset.set(state.offset);
    }

    /// Ensures the capacity is at least `min_capacity`.
    ///
    /// Does nothing when the buffer is already large enough. Otherwise a new
    /// buffer of `max(min_capacity, 2 * capacity)` bytes is allocated and the
    /// bytes in use are copied over; `used()` is unchanged.
    pub fn grow(&mut self, min_capacity: usize) {
        let old_capacity = self.buffer.len();
        if min_capacity <= old_capacity {
            return;
        }
        // Saturate so that doubling a huge capacity cannot wrap to a small one.
        let new_capacity = min_capacity.max(old_capacity.saturating_mul(2));
        let mut new_buffer: AlignedVec<u8, ALIGN> = AlignedVec::zeros(new_capacity);
        let live_bytes = self.offset.get();
        // SAFETY: `live_bytes <= old_capacity < new_capacity`, and the two
        // buffers are distinct allocations.
        unsafe {
            core::ptr::copy_nonoverlapping(
                self.buffer.as_ptr(),
                new_buffer.as_mut_ptr(),
                live_bytes,
            );
        }
        self.buffer = new_buffer;
    }
}
impl Default for Arena<DEFAULT_ALIGN> {
fn default() -> Self {
Self::with_capacity(16 * 1024 * 1024)
}
}
/// Snapshot of an arena's allocation offset.
///
/// Produced by `Arena::save` and consumed by `Arena::restore`, which rewinds
/// the arena's offset to the recorded value.
#[derive(Debug, Clone, Copy)]
pub struct ArenaState {
/// Value of the arena offset (in bytes) at the time of the snapshot.
offset: usize,
}
/// Fixed-length view over `len` elements of `T` starting at `ptr`.
///
/// `Arena::alloc_vec` builds one from a freshly zeroed arena allocation. The
/// `'a` lifetime (carried by the `PhantomData<&'a mut [T]>` marker) ties the
/// view to the borrow it was created from. `ALIGN` is not used by any field.
pub struct ArenaVec<'a, T, const ALIGN: usize = DEFAULT_ALIGN> {
/// Pointer to the first element.
ptr: *mut T,
/// Number of elements reachable through `ptr`.
len: usize,
/// Makes the type behave like an exclusive borrow of `[T]` for `'a`.
_marker: core::marker::PhantomData<&'a mut [T]>,
}
impl<'a, T, const ALIGN: usize> ArenaVec<'a, T, ALIGN> {
    /// Number of elements in the view.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the view holds no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Read-only pointer to the first element.
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.ptr as *const T
    }

    /// Mutable pointer to the first element.
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr
    }

    /// Borrows the elements as a shared slice.
    #[inline]
    pub fn as_slice(&self) -> &[T] {
        let raw: *const [T] = core::ptr::slice_from_raw_parts(self.ptr, self.len);
        // SAFETY: `ptr`/`len` describe the elements this view was built over,
        // and `&self` prevents mutation through this view for the borrow.
        unsafe { &*raw }
    }

    /// Borrows the elements as an exclusive slice.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        let raw: *mut [T] = core::ptr::slice_from_raw_parts_mut(self.ptr, self.len);
        // SAFETY: same region as `as_slice`; `&mut self` makes the access
        // exclusive for the duration of the borrow.
        unsafe { &mut *raw }
    }
}
impl<'a, T, const ALIGN: usize> core::ops::Deref for ArenaVec<'a, T, ALIGN> {
    type Target = [T];

    /// Dereferences to the shared slice returned by `as_slice`.
    fn deref(&self) -> &[T] {
        Self::as_slice(self)
    }
}
impl<'a, T, const ALIGN: usize> core::ops::DerefMut for ArenaVec<'a, T, ALIGN> {
    /// Dereferences to the exclusive slice returned by `as_mut_slice`.
    fn deref_mut(&mut self) -> &mut [T] {
        Self::as_mut_slice(self)
    }
}
impl<'a, T, const ALIGN: usize> core::ops::Index<usize> for ArenaVec<'a, T, ALIGN> {
    type Output = T;

    /// Returns the element at `index`; panics like slice indexing when
    /// `index` is out of bounds.
    fn index(&self, index: usize) -> &T {
        let items = Self::as_slice(self);
        &items[index]
    }
}
impl<'a, T, const ALIGN: usize> core::ops::IndexMut<usize> for ArenaVec<'a, T, ALIGN> {
    /// Returns the element at `index` for mutation; panics like slice
    /// indexing when `index` is out of bounds.
    fn index_mut(&mut self, index: usize) -> &mut T {
        let items = Self::as_mut_slice(self);
        &mut items[index]
    }
}
/// Runs `f` with exclusive access to a per-thread scratch arena.
///
/// The arena is created lazily with a 32 MiB capacity the first time a thread
/// calls this function and is reset before every invocation of `f`, so `f`
/// always starts from an empty arena.
///
/// # Panics
///
/// Panics if called again from inside `f` on the same thread, because the
/// thread-local `RefCell` is already mutably borrowed.
#[cfg(feature = "std")]
pub fn with_blas_arena<F, R>(f: F) -> R
where
    F: FnOnce(&mut Arena) -> R,
{
    thread_local! {
        static ARENA: std::cell::RefCell<Arena> =
            std::cell::RefCell::new(Arena::with_capacity(32 * 1024 * 1024));
    }

    ARENA.with(|slot| {
        let mut guard = slot.borrow_mut();
        let arena: &mut Arena = &mut guard;
        arena.reset();
        f(arena)
    })
}
/// Sizing parameters for a BLAS scratch arena.
///
/// NOTE(review): nothing visible in this file consumes these fields; the
/// field descriptions below are inferred from their names and from the
/// presets (`default`, `small`, `large`) — confirm against the code that
/// reads them.
#[derive(Debug, Clone, Copy)]
pub struct BlasArenaConfig {
/// Presumably the initial arena capacity in bytes.
pub capacity: usize,
/// Presumably whether the arena may be enlarged on demand.
pub auto_grow: bool,
/// Presumably the upper bound, in bytes, for automatic growth.
pub max_capacity: usize,
}
impl Default for BlasArenaConfig {
fn default() -> Self {
BlasArenaConfig {
capacity: 32 * 1024 * 1024, auto_grow: true,
max_capacity: 512 * 1024 * 1024, }
}
}
impl BlasArenaConfig {
pub const fn small() -> Self {
BlasArenaConfig {
capacity: 4 * 1024 * 1024,
auto_grow: true,
max_capacity: 32 * 1024 * 1024,
}
}
pub const fn large() -> Self {
BlasArenaConfig {
capacity: 128 * 1024 * 1024,
auto_grow: true,
max_capacity: 1024 * 1024 * 1024,
}
}
pub const fn gemm_arena_size(m: usize, k: usize, n: usize, elem_size: usize) -> usize {
let mc = if m < 512 { m } else { 512 };
let kc = if k < 256 { k } else { 256 };
let nc = if n < 2048 { n } else { 2048 };
let pack_a_size = mc * kc * elem_size;
let pack_b_size = kc * nc * elem_size;
(pack_a_size + pack_b_size) * 12 / 10
}
}
// Unit tests for `Arena`, `ArenaVec`, `with_blas_arena` and `BlasArenaConfig`.
#[cfg(test)]
mod arena_tests {
use super::*;
// A zeroed allocation is writable, and 10 `f64` values consume 80 bytes.
#[test]
fn test_arena_basic() {
let arena: Arena = Arena::with_capacity(1024);
{
let slice: &mut [f64] = arena.alloc_zeroed(10);
assert_eq!(slice.len(), 10);
slice[0] = 1.0;
slice[9] = 9.0;
assert_eq!(slice[0], 1.0);
assert_eq!(slice[9], 9.0);
}
assert_eq!(arena.used(), 80); }
// `reset` rewinds `used()` to 0 but leaves the high-water mark untouched.
#[test]
fn test_arena_reset() {
let arena: Arena = Arena::with_capacity(1024);
{
let _slice1: &mut [f64] = arena.alloc_zeroed(100);
}
assert_eq!(arena.used(), 800);
arena.reset();
assert_eq!(arena.used(), 0);
assert_eq!(arena.high_water_mark(), 800);
{
let _slice2: &mut [f64] = arena.alloc_zeroed(50);
}
assert_eq!(arena.used(), 400);
}
// Several live allocations of different element types can be written
// independently of each other.
#[test]
fn test_arena_multiple_allocs() {
let arena: Arena = Arena::with_capacity(4096);
let mut buf1 = arena.alloc_vec::<f64>(10);
let mut buf2 = arena.alloc_vec::<f32>(20);
let mut buf3 = arena.alloc_vec::<u8>(100);
buf1[0] = 1.0;
buf2[0] = 2.0;
buf3[0] = 3;
assert_eq!(buf1[0], 1.0);
assert_eq!(buf2[0], 2.0);
assert_eq!(buf3[0], 3);
}
// `restore` brings `used()` back to the value captured by `save`.
#[test]
fn test_arena_save_restore() {
let arena: Arena = Arena::with_capacity(4096);
{
let _buf1: &mut [f64] = arena.alloc_zeroed(10);
}
let saved_offset = arena.used();
let state = arena.save();
assert!(saved_offset > 0);
{
let _buf2: &mut [f64] = arena.alloc_zeroed(10);
}
let after_second = arena.used();
assert!(after_second > saved_offset);
arena.restore(state);
assert_eq!(arena.used(), saved_offset);
}
// In a 100-byte arena, 10 `f64` values (80 bytes) fit but 100 (800 bytes)
// do not; the fallible API reports that with `None`.
#[test]
fn test_arena_try_alloc() {
let arena: Arena = Arena::with_capacity(100);
{
let result: Option<&mut [f64]> = arena.try_alloc_zeroed(10);
assert!(result.is_some());
}
let result: Option<&mut [f64]> = arena.try_alloc_zeroed(100);
assert!(result.is_none());
}
// `ArenaVec` reports its length and supports indexing and `as_slice`.
#[test]
fn test_arena_vec() {
let arena: Arena = Arena::with_capacity(1024);
let mut vec = arena.alloc_vec::<f64>(10);
assert_eq!(vec.len(), 10);
assert!(!vec.is_empty());
vec[0] = 1.0;
vec[9] = 9.0;
assert_eq!(vec[0], 1.0);
assert_eq!(vec[9], 9.0);
assert_eq!(vec.as_slice()[0], 1.0);
}
// With `ALIGN = 128` the first allocation starts on a 128-byte boundary.
#[test]
fn test_arena_alignment() {
let arena: Arena<128> = Arena::with_capacity(4096);
let buf: &mut [f64] = arena.alloc_zeroed(10);
let ptr = buf.as_ptr() as usize;
assert_eq!(ptr % 128, 0);
}
// `grow` enlarges the capacity while keeping `used()` unchanged.
#[test]
fn test_arena_grow() {
let mut arena: Arena = Arena::with_capacity(100);
let _buf1: &mut [f64] = arena.alloc_zeroed(10);
assert_eq!(arena.capacity(), 100);
arena.grow(500);
assert!(arena.capacity() >= 500);
assert_eq!(arena.used(), 80); }
// The thread-local arena is usable and starts empty on every call.
#[cfg(feature = "std")]
#[test]
fn test_with_blas_arena() {
with_blas_arena(|arena| {
let buf: &mut [f64] = arena.alloc_zeroed(1000);
buf[0] = 42.0;
assert_eq!(buf[0], 42.0);
});
with_blas_arena(|arena| {
assert_eq!(arena.used(), 0);
});
}
// The three presets expose the expected capacities.
#[test]
fn test_blas_arena_config() {
let config = BlasArenaConfig::default();
assert_eq!(config.capacity, 32 * 1024 * 1024);
assert!(config.auto_grow);
let small = BlasArenaConfig::small();
assert_eq!(small.capacity, 4 * 1024 * 1024);
let large = BlasArenaConfig::large();
assert_eq!(large.capacity, 128 * 1024 * 1024);
}
// For a 1024^3 problem with 8-byte elements, m and k are capped at 512 and
// 256 while n = 1024 stays below its 2048 cap.
#[test]
fn test_gemm_arena_size() {
let size = BlasArenaConfig::gemm_arena_size(1024, 1024, 1024, 8);
assert!(size > 0);
let expected = ((512 * 256 + 256 * 1024) * 8) * 12 / 10;
assert_eq!(size, expected);
}
// Requesting 800 bytes from a 100-byte arena panics with "Arena overflow".
#[test]
#[should_panic(expected = "Arena overflow")]
fn test_arena_overflow() {
let arena: Arena = Arena::with_capacity(100);
let _buf: &mut [f64] = arena.alloc_zeroed(100); }
}