use std::cell::UnsafeCell;
use std::mem::{align_of, size_of};
use std::sync::atomic::{AtomicU32, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, Weak};
use crate::error::{Error, Result};
pub use super::{FrameHeader, MAX_PLANES};
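/// A bounded pool of fixed-capacity byte buffers.
///
/// At most `max_arenas` buffers are ever allocated. Leased buffers are wrapped
/// in an [`Arena`] and handed back to the pool when that `Arena` is dropped.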
pub struct ArenaPool {
inner: Mutex<PoolInner>,
cap_per_arena: usize,
max_arenas: usize,
max_alloc_count_per_arena: u32,
}
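/// State behind the pool mutex: buffers waiting to be reused, plus a count of
/// every buffer allocated so far (whether idle or currently leased).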
struct PoolInner {
idle: Vec<Box<[u8]>>,
total_allocated: usize,
}
impl ArenaPool {
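/// Creates a pool of up to `max_arenas` buffers of `cap_per_arena` bytes each,
/// using a default per-arena allocation-count cap of 1,000,000.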
pub fn new(max_arenas: usize, cap_per_arena: usize) -> Arc<Self> {
Self::with_alloc_count_cap(max_arenas, cap_per_arena, 1_000_000)
}
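/// Creates a pool with an explicit cap on how many `alloc` calls each leased
/// arena may serve before further allocations fail.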
pub fn with_alloc_count_cap(
max_arenas: usize,
cap_per_arena: usize,
max_alloc_count_per_arena: u32,
) -> Arc<Self> {
Arc::new(Self {
inner: Mutex::new(PoolInner {
idle: Vec::with_capacity(max_arenas),
total_allocated: 0,
}),
cap_per_arena,
max_arenas,
max_alloc_count_per_arena,
})
}
pub fn cap_per_arena(&self) -> usize {
self.cap_per_arena
}
pub fn max_arenas(&self) -> usize {
self.max_arenas
}
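/// Leases an arena, reusing an idle buffer if one is available and allocating
/// a fresh one otherwise. Returns a resource-exhausted error once all
/// `max_arenas` buffers are checked out.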
pub fn lease(self: &Arc<Self>) -> Result<Arena> {
let buffer = {
let mut inner = self.inner.lock().expect("ArenaPool mutex poisoned");
if let Some(buf) = inner.idle.pop() {
buf
} else if inner.total_allocated < self.max_arenas {
inner.total_allocated += 1;
vec![0u8; self.cap_per_arena].into_boxed_slice()
} else {
return Err(Error::resource_exhausted(format!(
"ArenaPool exhausted: all {} arenas checked out",
self.max_arenas
)));
}
};
Ok(Arena {
buffer: UnsafeCell::new(buffer),
cursor: AtomicUsize::new(0),
alloc_count: AtomicU32::new(0),
cap: self.cap_per_arena,
alloc_count_cap: self.max_alloc_count_per_arena,
pool: Arc::downgrade(self),
})
}
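/// Puts a buffer back on the idle list; called from `Arena::drop`. If the
/// mutex is poisoned the buffer is dropped instead of being reused.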
fn release(&self, buffer: Box<[u8]>) {
if let Ok(mut inner) = self.inner.lock() {
inner.idle.push(buffer);
}
}
}
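/// A bump allocator over one pooled buffer.
///
/// `alloc` only advances a cursor; memory is reclaimed all at once by
/// [`Arena::reset`] or by dropping the arena, which returns its buffer to the
/// owning pool.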
pub struct Arena {
buffer: UnsafeCell<Box<[u8]>>,
cursor: AtomicUsize,
alloc_count: AtomicU32,
cap: usize,
alloc_count_cap: u32,
pool: Weak<ArenaPool>,
}
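// SAFETY: the buffer is reached only through the `UnsafeCell`, and the atomic
// cursor hands out disjoint byte ranges, so concurrent `alloc` calls never
// produce aliasing slices. `reset` takes `&mut self`, which rules out
// outstanding borrows of previously returned slices.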
unsafe impl Send for Arena {}
unsafe impl Sync for Arena {}
impl Arena {
pub fn capacity(&self) -> usize {
self.cap
}
pub fn used(&self) -> usize {
self.cursor.load(Ordering::Acquire)
}
pub fn alloc_count(&self) -> u32 {
self.alloc_count.load(Ordering::Acquire)
}
pub fn alloc_count_exceeded(&self) -> bool {
self.alloc_count.load(Ordering::Acquire) >= self.alloc_count_cap
}
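/// Bump-allocates room for `count` values of `T` and returns it as a slice.
/// Fails once the arena's byte capacity or allocation-count cap would be
/// exceeded.
///
/// The backing bytes are not reinitialized between leases, so this is meant
/// for `Copy` types for which any bit pattern is a valid value (plain integer
/// types and the like); callers should overwrite the slice before reading it.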
#[allow(clippy::mut_from_ref)]
pub fn alloc<T>(&self, count: usize) -> Result<&mut [T]>
where
T: Copy,
{
let prev_count = self.alloc_count.fetch_add(1, Ordering::AcqRel);
if prev_count >= self.alloc_count_cap {
self.alloc_count.fetch_sub(1, Ordering::AcqRel);
return Err(Error::resource_exhausted(format!(
"Arena alloc-count cap of {} exceeded",
self.alloc_count_cap
)));
}
let elem_size = size_of::<T>();
let elem_align = align_of::<T>();
let bytes = elem_size.checked_mul(count).ok_or_else(|| {
self.alloc_count.fetch_sub(1, Ordering::AcqRel);
Error::resource_exhausted("Arena alloc size overflow".to_string())
})?;
// The cursor is a byte offset; align the absolute address rather than the
// offset so the returned pointer is correctly aligned even though a
// `Vec<u8>` allocation only guarantees 1-byte alignment of the base.
let base_addr = unsafe { (*self.buffer.get()).as_ptr() as usize };
let mut current = self.cursor.load(Ordering::Acquire);
let aligned;
loop {
let candidate_aligned = match base_addr
.checked_add(current)
.and_then(|addr| align_up(addr, elem_align))
.map(|addr| addr - base_addr)
{
Some(a) => a,
None => {
self.alloc_count.fetch_sub(1, Ordering::AcqRel);
return Err(Error::resource_exhausted(
"Arena cursor alignment overflow".to_string(),
));
}
};
let candidate_new = match candidate_aligned.checked_add(bytes) {
Some(n) => n,
None => {
self.alloc_count.fetch_sub(1, Ordering::AcqRel);
return Err(Error::resource_exhausted(
"Arena cursor advance overflow".to_string(),
));
}
};
if candidate_new > self.cap {
self.alloc_count.fetch_sub(1, Ordering::AcqRel);
return Err(Error::resource_exhausted(format!(
"Arena cap of {} bytes exceeded (would consume {} bytes)",
self.cap, candidate_new
)));
}
match self.cursor.compare_exchange_weak(
current,
candidate_new,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(_) => {
aligned = candidate_aligned;
break;
}
Err(observed) => {
current = observed;
}
}
}
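// SAFETY: the CAS above reserved `aligned..aligned + bytes` exclusively for
// this call, the range lies within the buffer (checked against `self.cap`),
// and the resulting pointer is `elem_align`-aligned by construction.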
let slice: &mut [T] = unsafe {
let buf_ptr = (*self.buffer.get()).as_mut_ptr();
let elem_ptr = buf_ptr.add(aligned).cast::<T>();
std::slice::from_raw_parts_mut(elem_ptr, count)
};
Ok(slice)
}
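/// Rewinds the arena so its full capacity is available again. Requiring
/// `&mut self` guarantees no slices handed out by `alloc` are still borrowed.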
pub fn reset(&mut self) {
self.cursor.store(0, Ordering::Release);
self.alloc_count.store(0, Ordering::Release);
}
}
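// Dropping an arena hands its buffer back to the pool (if the pool is still
// alive) so a later `lease` can reuse it.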
impl Drop for Arena {
fn drop(&mut self) {
let buffer = std::mem::replace(
unsafe { &mut *self.buffer.get() },
Vec::new().into_boxed_slice(),
);
if let Some(pool) = self.pool.upgrade() {
pool.release(buffer);
}
}
}
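/// Rounds `n` up to the next multiple of `align` (which must be a power of
/// two), returning `None` if the rounding would overflow.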
fn align_up(n: usize, align: usize) -> Option<usize> {
debug_assert!(align.is_power_of_two(), "alignment must be a power of two");
let mask = align - 1;
n.checked_add(mask).map(|m| m & !mask)
}
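/// A frame backed by an owned [`Arena`]: per-plane `(offset, len)` ranges into
/// the arena plus the frame's header metadata.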
pub struct FrameInner {
arena: Arena,
plane_offsets: [(usize, usize); MAX_PLANES],
plane_count: u8,
header: FrameHeader,
}
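/// Frames are shared by reference counting; the arena (and thus its pooled
/// buffer) is released only when the last handle is dropped.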
pub type Frame = Arc<FrameInner>;
impl FrameInner {
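/// Builds a frame from an arena and a list of `(offset, len)` plane ranges,
/// validating that every range lies within the bytes already allocated from
/// the arena.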
pub fn new(arena: Arena, planes: &[(usize, usize)], header: FrameHeader) -> Result<Frame> {
if planes.len() > MAX_PLANES {
return Err(Error::invalid(format!(
"FrameInner supports at most {} planes (got {})",
MAX_PLANES,
planes.len()
)));
}
let used = arena.used();
for (i, (off, len)) in planes.iter().enumerate() {
let end = off
.checked_add(*len)
.ok_or_else(|| Error::invalid(format!("plane {i}: offset+len overflow")))?;
if end > used {
return Err(Error::invalid(format!(
"plane {i}: range {off}..{end} exceeds arena used={used}"
)));
}
}
let mut plane_offsets = [(0usize, 0usize); MAX_PLANES];
for (i, p) in planes.iter().enumerate() {
plane_offsets[i] = *p;
}
Ok(Arc::new(FrameInner {
arena,
plane_offsets,
plane_count: planes.len() as u8,
header,
}))
}
pub fn plane_count(&self) -> usize {
self.plane_count as usize
}
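/// Returns plane `i` as a byte slice, or `None` if `i` is out of range.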
pub fn plane(&self, i: usize) -> Option<&[u8]> {
if i >= self.plane_count as usize {
return None;
}
let (off, len) = self.plane_offsets[i];
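// SAFETY: the range was validated against `arena.used()` in `new`, and the
// arena is owned by this frame, so no allocation or reset can run while the
// frame is alive.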
let buf: &[u8] = unsafe {
let buf_ref = &*self.arena.buffer.get();
&(**buf_ref)[off..off + len]
};
Some(buf)
}
pub fn header(&self) -> &FrameHeader {
&self.header
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::format::PixelFormat;
fn assert_send_sync<T: Send + Sync>() {}
#[test]
fn types_are_send_sync() {
assert_send_sync::<ArenaPool>();
assert_send_sync::<Arc<ArenaPool>>();
assert_send_sync::<Arena>();
assert_send_sync::<FrameInner>();
assert_send_sync::<Frame>();
}
fn small_pool(slots: usize, cap: usize) -> Arc<ArenaPool> {
ArenaPool::new(slots, cap)
}
#[test]
fn pool_lease_returns_err_when_exhausted() {
let pool = small_pool(2, 1024);
let a = pool.lease().expect("first lease");
let b = pool.lease().expect("second lease");
let third = pool.lease();
assert!(matches!(third, Err(Error::ResourceExhausted(_))));
drop((a, b));
}
#[test]
fn arena_alloc_caps_at_size_limit() {
let pool = small_pool(1, 64);
let arena = pool.lease().unwrap();
let _: &mut [u8] = arena.alloc::<u8>(32).unwrap();
let _: &mut [u8] = arena.alloc::<u8>(32).unwrap();
let third = arena.alloc::<u8>(1);
assert!(matches!(third, Err(Error::ResourceExhausted(_))));
}
#[test]
fn arena_alloc_count_cap_fires() {
let pool = ArenaPool::with_alloc_count_cap(1, 1024, 3);
let arena = pool.lease().unwrap();
let _: &mut [u8] = arena.alloc::<u8>(1).unwrap();
let _: &mut [u8] = arena.alloc::<u8>(1).unwrap();
let _: &mut [u8] = arena.alloc::<u8>(1).unwrap();
assert!(arena.alloc_count_exceeded());
let fourth = arena.alloc::<u8>(1);
assert!(matches!(fourth, Err(Error::ResourceExhausted(_))));
assert_eq!(arena.alloc_count(), 3);
}
#[test]
fn arena_returns_to_pool_on_drop() {
let pool = small_pool(1, 256);
{
let arena = pool.lease().expect("first lease");
assert!(matches!(pool.lease(), Err(Error::ResourceExhausted(_))));
drop(arena);
}
let _again = pool.lease().expect("re-lease after drop");
}
#[test]
fn arena_alignment_is_respected() {
let pool = small_pool(1, 64);
let arena = pool.lease().unwrap();
let _: &mut [u8] = arena.alloc::<u8>(1).unwrap();
let s: &mut [u32] = arena.alloc::<u32>(4).unwrap();
let addr = s.as_ptr() as usize;
assert_eq!(addr % align_of::<u32>(), 0);
assert_eq!(s.len(), 4);
}
fn build_simple_frame(pool: &Arc<ArenaPool>) -> Frame {
let arena = pool.lease().unwrap();
let plane0: &mut [u8] = arena.alloc::<u8>(16).unwrap();
for (i, b) in plane0.iter_mut().enumerate() {
*b = i as u8;
}
let header = FrameHeader::new(4, 4, PixelFormat::Gray8, Some(42));
FrameInner::new(arena, &[(0, 16)], header).unwrap()
}
#[test]
fn frame_refcount_keeps_arena_alive() {
let pool = small_pool(1, 256);
let frame = build_simple_frame(&pool);
let clone = Arc::clone(&frame);
drop(frame);
let plane = clone.plane(0).expect("plane 0");
assert_eq!(plane.len(), 16);
for (i, b) in plane.iter().enumerate() {
assert_eq!(*b, i as u8);
}
assert_eq!(clone.header().width, 4);
assert_eq!(clone.header().height, 4);
assert_eq!(clone.header().presentation_timestamp, Some(42));
assert!(matches!(pool.lease(), Err(Error::ResourceExhausted(_))));
}
#[test]
fn last_drop_returns_arena_to_pool() {
let pool = small_pool(1, 256);
let frame = build_simple_frame(&pool);
let clone = Arc::clone(&frame);
drop(frame);
drop(clone);
let _again = pool.lease().expect("lease after last drop");
}
#[test]
fn frame_rejects_too_many_planes() {
let pool = small_pool(1, 256);
let arena = pool.lease().unwrap();
let header = FrameHeader::new(1, 1, PixelFormat::Gray8, None);
let too_many = vec![(0usize, 0usize); MAX_PLANES + 1];
let r = FrameInner::new(arena, &too_many, header);
assert!(matches!(r, Err(Error::InvalidData(_))));
}
#[test]
fn frame_rejects_plane_outside_arena() {
let pool = small_pool(1, 64);
let arena = pool.lease().unwrap();
let header = FrameHeader::new(1, 1, PixelFormat::Gray8, None);
let r = FrameInner::new(arena, &[(0, 16)], header);
assert!(matches!(r, Err(Error::InvalidData(_))));
}
#[test]
fn arena_drop_is_safe_after_pool_is_dropped() {
let pool = small_pool(1, 64);
let arena = pool.lease().unwrap();
drop(pool);
drop(arena);
}
#[test]
fn arena_reset_clears_allocations() {
let pool = small_pool(1, 32);
let mut arena = pool.lease().unwrap();
let _: &mut [u8] = arena.alloc::<u8>(32).unwrap();
assert!(matches!(
arena.alloc::<u8>(1),
Err(Error::ResourceExhausted(_))
));
arena.reset();
let _: &mut [u8] = arena.alloc::<u8>(32).unwrap();
}
#[test]
fn frame_can_be_sent_across_thread_boundary() {
let pool = small_pool(1, 256);
let frame = build_simple_frame(&pool);
let frame_for_worker = Arc::clone(&frame);
let handle = std::thread::spawn(move || {
let plane = frame_for_worker.plane(0).expect("plane 0 on worker");
let mut sum: u32 = 0;
for b in plane {
sum += *b as u32;
}
sum
});
let sum = handle.join().expect("worker joined");
assert_eq!(sum, (0..16u32).sum::<u32>());
assert_eq!(frame.plane(0).unwrap().len(), 16);
}
#[test]
fn concurrent_alloc_produces_disjoint_slices() {
let pool = small_pool(1, 256);
let arena = Arc::new(pool.lease().unwrap());
let a = Arc::clone(&arena);
let b = Arc::clone(&arena);
let h1 = std::thread::spawn(move || {
let s: &mut [u8] = a.alloc::<u8>(64).unwrap();
for x in s.iter_mut() {
*x = 0xAA;
}
(s.as_ptr() as usize, s.len())
});
let h2 = std::thread::spawn(move || {
let s: &mut [u8] = b.alloc::<u8>(64).unwrap();
for x in s.iter_mut() {
*x = 0xBB;
}
(s.as_ptr() as usize, s.len())
});
let (p1, l1) = h1.join().unwrap();
let (p2, l2) = h2.join().unwrap();
let no_overlap = p1 + l1 <= p2 || p2 + l2 <= p1;
assert!(no_overlap, "concurrent alloc returned overlapping slices");
}
}