use std::alloc::{Layout, alloc, dealloc};
use std::cell::UnsafeCell;
use std::ptr::NonNull;
use std::sync::Arc;
use std::sync::atomic::{AtomicPtr, AtomicU64, AtomicUsize, Ordering};
/// Size in bytes (2 MiB) of arena blocks when the caller does not pass an explicit block size.
const DEFAULT_BLOCK_SIZE: usize = 2 * 1024 * 1024;
/// Alignment in bytes of each block's backing allocation and of plain `allocate` requests.
const MIN_ALIGN: usize = 8;
/// Longest key, in bytes, accepted by the `allocate_key` methods; longer keys yield `None`.
const MAX_INLINE_KEY_SIZE: usize = 256;
/// A copyable, non-owning reference to a byte range handed out by an arena.
///
/// A handle stores the start pointer, the length in bytes and the epoch value it was created
/// with. It carries no lifetime and is `Copy`, so the compiler cannot tell whether the memory is
/// still live or whether several copies alias it; every access to the bytes is therefore an
/// `unsafe` operation whose contract the caller must uphold.
#[derive(Clone, Copy, Debug)]
pub struct ArenaHandle {
    ptr: NonNull<u8>,
    // Kept as `usize` so lengths above `u32::MAX` are reported exactly instead of being
    // truncated; on 64-bit targets this does not change the size of the struct.
    len: usize,
    epoch: u64,
}

impl ArenaHandle {
    /// Builds a handle for `len` bytes starting at `ptr`, tagged with `epoch`.
    ///
    /// # Safety
    ///
    /// `ptr` must point to at least `len` bytes that stay allocated for as long as the handle's
    /// slice accessors are used.
    #[inline]
    pub(crate) unsafe fn new(ptr: NonNull<u8>, len: usize, epoch: u64) -> Self {
        Self { ptr, len, epoch }
    }

    /// Returns the epoch value recorded when the handle was created.
    #[inline]
    pub fn epoch(&self) -> u64 {
        self.epoch
    }

    /// Returns the length of the referenced range in bytes.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the handle refers to zero bytes.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Views the referenced bytes as a shared slice.
    ///
    /// # Safety
    ///
    /// The memory behind the handle must still be allocated, must be initialized, and must not
    /// be written through any other handle copy while the returned slice is alive.
    #[inline]
    pub unsafe fn as_slice(&self) -> &[u8] {
        // SAFETY: the caller guarantees the range is live, initialized and not mutated.
        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
    }

    /// Views the referenced bytes as a mutable slice.
    ///
    /// # Safety
    ///
    /// The memory behind the handle must still be allocated and initialized, and no other
    /// reference to the range (including through copies of this handle) may exist while the
    /// returned slice is alive.
    #[inline]
    pub unsafe fn as_mut_slice(&mut self) -> &mut [u8] {
        // SAFETY: the caller guarantees the range is live and exclusively accessed.
        unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
    }
}

// SAFETY: a handle is only a pointer plus two integers. Every dereference goes through the
// `unsafe` accessors above, whose contracts make the caller responsible for synchronization.
unsafe impl Send for ArenaHandle {}
// SAFETY: see the `Send` impl; sharing the handle itself never touches the pointed-to bytes.
unsafe impl Sync for ArenaHandle {}
/// One contiguous heap buffer that is carved up by bumping an atomic offset.
struct MemoryBlock {
    /// Start of the buffer obtained from the global allocator.
    data: NonNull<u8>,
    /// Capacity of the buffer in bytes.
    size: usize,
    /// Number of bytes already handed out (including alignment padding).
    offset: AtomicUsize,
    /// Layout used for the allocation; needed again to free it.
    layout: Layout,
}

impl MemoryBlock {
    /// Allocates a block of `size` bytes aligned to `MIN_ALIGN`.
    ///
    /// Returns `None` when `size` is zero, when the layout is invalid, or when the global
    /// allocator reports failure.
    fn new(size: usize) -> Option<Self> {
        // `alloc` must never be called with a zero-sized layout.
        if size == 0 {
            return None;
        }
        let layout = Layout::from_size_align(size, MIN_ALIGN).ok()?;
        // SAFETY: `layout` has a non-zero size, as checked above.
        let data = NonNull::new(unsafe { alloc(layout) })?;
        Some(Self {
            data,
            size,
            offset: AtomicUsize::new(0),
            layout,
        })
    }

    /// Reserves `size` bytes whose address is a multiple of `align`.
    ///
    /// Alignment is computed on the absolute address, so alignments larger than the block's own
    /// alignment work as long as the padded request fits. Returns `None` when the request does
    /// not fit in the remaining space, when `align` is not a power of two, or when the offset
    /// arithmetic would overflow.
    #[inline]
    fn allocate(&self, size: usize, align: usize) -> Option<NonNull<u8>> {
        // The mask arithmetic below is only valid for power-of-two alignments (this also
        // rejects zero).
        if !align.is_power_of_two() {
            return None;
        }
        let mask = align - 1;
        let base = self.data.as_ptr() as usize;
        let mut current = self.offset.load(Ordering::Relaxed);
        loop {
            let aligned_addr = base.checked_add(current)?.checked_add(mask)? & !mask;
            let aligned = aligned_addr - base;
            // Checked so that a huge `size` cannot wrap around and pass the bounds test.
            let new_offset = aligned.checked_add(size)?;
            if new_offset > self.size {
                return None;
            }
            match self.offset.compare_exchange_weak(
                current,
                new_offset,
                Ordering::Release,
                Ordering::Relaxed,
            ) {
                Ok(_) => {
                    // SAFETY: `aligned + size <= self.size`, so the offset stays inside the
                    // buffer owned by this block.
                    let ptr = unsafe { self.data.as_ptr().add(aligned) };
                    return NonNull::new(ptr);
                }
                // Another thread moved the offset (or the weak CAS failed spuriously); retry
                // from the value the CAS observed.
                Err(observed) => current = observed,
            }
        }
    }

    /// Returns the number of bytes not yet handed out.
    #[inline]
    // Not called from the code visible in this file, hence the lint allowance.
    #[allow(dead_code)]
    fn remaining(&self) -> usize {
        self.size
            .saturating_sub(self.offset.load(Ordering::Relaxed))
    }

    /// Returns the number of bytes handed out so far, including alignment padding.
    #[inline]
    fn used(&self) -> usize {
        self.offset.load(Ordering::Relaxed)
    }

    /// Rewinds the bump offset to zero so the whole buffer can be handed out again.
    ///
    /// The buffer contents are left untouched; pointers returned earlier now alias future
    /// allocations.
    fn reset(&self) {
        self.offset.store(0, Ordering::Release);
    }
}

impl Drop for MemoryBlock {
    fn drop(&mut self) {
        // SAFETY: `data` was returned by `alloc(self.layout)` in `new` and is freed only here.
        unsafe {
            dealloc(self.data.as_ptr(), self.layout);
        }
    }
}

// SAFETY: the block exclusively owns its buffer, so moving it between threads is fine.
unsafe impl Send for MemoryBlock {}
// SAFETY: shared access only mutates `offset`, which is atomic, and the compare-exchange in
// `allocate` hands each byte range to exactly one caller.
unsafe impl Sync for MemoryBlock {}
pub struct EpochArena {
epoch: AtomicU64,
blocks: UnsafeCell<Vec<MemoryBlock>>,
active_block: AtomicUsize,
block_size: usize,
total_allocated: AtomicUsize,
allocation_count: AtomicUsize,
block_lock: std::sync::Mutex<()>,
}
impl EpochArena {
pub fn new(epoch: u64) -> Self {
Self::with_block_size(epoch, DEFAULT_BLOCK_SIZE)
}
pub fn with_block_size(epoch: u64, block_size: usize) -> Self {
let initial_block = MemoryBlock::new(block_size).expect("Failed to allocate initial block");
Self {
epoch: AtomicU64::new(epoch),
blocks: UnsafeCell::new(vec![initial_block]),
active_block: AtomicUsize::new(0),
block_size,
total_allocated: AtomicUsize::new(0),
allocation_count: AtomicUsize::new(0),
block_lock: std::sync::Mutex::new(()),
}
}
#[inline]
pub fn epoch(&self) -> u64 {
self.epoch.load(Ordering::Relaxed)
}
#[inline]
pub fn allocate(&self, size: usize) -> Option<ArenaHandle> {
self.allocate_aligned(size, MIN_ALIGN)
}
pub fn allocate_aligned(&self, size: usize, align: usize) -> Option<ArenaHandle> {
if size == 0 {
return None;
}
let active_idx = self.active_block.load(Ordering::Acquire);
let blocks = unsafe { &*self.blocks.get() };
if active_idx < blocks.len() {
if let Some(ptr) = blocks[active_idx].allocate(size, align) {
self.total_allocated.fetch_add(size, Ordering::Relaxed);
self.allocation_count.fetch_add(1, Ordering::Relaxed);
return Some(unsafe { ArenaHandle::new(ptr, size, self.epoch()) });
}
}
self.allocate_slow(size, align)
}
#[cold]
fn allocate_slow(&self, size: usize, align: usize) -> Option<ArenaHandle> {
let _guard = self.block_lock.lock().ok()?;
let active_idx = self.active_block.load(Ordering::Acquire);
let blocks = unsafe { &mut *self.blocks.get() };
if active_idx < blocks.len() {
if let Some(ptr) = blocks[active_idx].allocate(size, align) {
self.total_allocated.fetch_add(size, Ordering::Relaxed);
self.allocation_count.fetch_add(1, Ordering::Relaxed);
return Some(unsafe { ArenaHandle::new(ptr, size, self.epoch()) });
}
}
let new_block_size = self.block_size.max(size + align);
let new_block = MemoryBlock::new(new_block_size)?;
let ptr = new_block.allocate(size, align)?;
blocks.push(new_block);
self.active_block.store(blocks.len() - 1, Ordering::Release);
self.total_allocated.fetch_add(size, Ordering::Relaxed);
self.allocation_count.fetch_add(1, Ordering::Relaxed);
Some(unsafe { ArenaHandle::new(ptr, size, self.epoch()) })
}
#[inline]
pub fn allocate_copy(&self, data: &[u8]) -> Option<ArenaHandle> {
let handle = self.allocate(data.len())?;
unsafe {
std::ptr::copy_nonoverlapping(data.as_ptr(), handle.ptr.as_ptr(), data.len());
}
Some(handle)
}
#[inline]
pub fn allocate_key(&self, key: &[u8]) -> Option<ArenaHandle> {
if key.len() > MAX_INLINE_KEY_SIZE {
return None;
}
self.allocate_aligned(key.len(), 16).map(|handle| {
unsafe {
std::ptr::copy_nonoverlapping(key.as_ptr(), handle.ptr.as_ptr(), key.len());
}
handle
})
}
pub fn stats(&self) -> ArenaStats {
let blocks = unsafe { &*self.blocks.get() };
ArenaStats {
epoch: self.epoch(),
block_count: blocks.len(),
total_capacity: blocks.iter().map(|b| b.size).sum(),
total_used: blocks.iter().map(|b| b.used()).sum(),
total_allocated: self.total_allocated.load(Ordering::Relaxed),
allocation_count: self.allocation_count.load(Ordering::Relaxed),
}
}
pub fn reset(&self, new_epoch: u64) {
let _guard = self.block_lock.lock().unwrap();
let blocks = unsafe { &*self.blocks.get() };
for block in blocks {
block.reset();
}
self.epoch.store(new_epoch, Ordering::Release);
self.active_block.store(0, Ordering::Release);
self.total_allocated.store(0, Ordering::Relaxed);
self.allocation_count.store(0, Ordering::Relaxed);
}
}
unsafe impl Send for EpochArena {}
unsafe impl Sync for EpochArena {}
/// Point-in-time usage counters for a single arena, as returned by the `stats` methods.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ArenaStats {
    /// Epoch the arena was stamping onto handles when the snapshot was taken.
    pub epoch: u64,
    /// Number of memory blocks the arena holds.
    pub block_count: usize,
    /// Sum of the sizes of all blocks, in bytes.
    pub total_capacity: usize,
    /// Sum of the bump offsets of all blocks, in bytes (includes alignment padding).
    pub total_used: usize,
    /// Sum of the requested sizes of successful allocations since the last reset, in bytes.
    pub total_allocated: usize,
    /// Number of successful allocations since the last reset.
    pub allocation_count: usize,
}
/// A fixed ring of [`EpochArena`]s, one of which is "current" for the pool's epoch counter.
///
/// Epoch `e` is served by the arena at index `e % pool_size`. Advancing the epoch resets the
/// arena that will serve the new epoch, which recycles the memory used `pool_size` epochs ago.
pub struct ArenaPool {
    /// The ring of arenas; never modified after construction.
    arenas: Vec<Arc<EpochArena>>,
    /// The epoch allocations are currently made under.
    current_epoch: AtomicU64,
    /// Number of arenas in the ring (always at least one).
    pool_size: usize,
    // Recorded at construction but not read by the code in this file.
    #[allow(dead_code)]
    block_size: usize,
    /// Serializes `advance_epoch` so that an arena is fully reset before its epoch is published.
    advance_lock: std::sync::Mutex<()>,
}

impl ArenaPool {
    /// Creates a pool of `pool_size` arenas that use `DEFAULT_BLOCK_SIZE` blocks.
    ///
    /// # Panics
    ///
    /// Panics if `pool_size` is zero or if an arena's initial block cannot be allocated.
    pub fn new(pool_size: usize) -> Self {
        Self::with_block_size(pool_size, DEFAULT_BLOCK_SIZE)
    }

    /// Creates a pool of `pool_size` arenas whose blocks are at least `block_size` bytes.
    ///
    /// Arena `i` starts out with epoch `i`, and the pool starts at epoch 0.
    ///
    /// # Panics
    ///
    /// Panics if `pool_size` is zero (the epoch-to-arena mapping would divide by zero) or if an
    /// arena's initial block cannot be allocated.
    pub fn with_block_size(pool_size: usize, block_size: usize) -> Self {
        assert!(pool_size > 0, "ArenaPool requires at least one arena");
        let arenas = (0..pool_size)
            .map(|i| Arc::new(EpochArena::with_block_size(i as u64, block_size)))
            .collect();
        Self {
            arenas,
            current_epoch: AtomicU64::new(0),
            pool_size,
            block_size,
            advance_lock: std::sync::Mutex::new(()),
        }
    }

    /// Maps an epoch to the index of the arena that serves it.
    #[inline]
    fn slot_for(&self, epoch: u64) -> usize {
        // Reduce in `u64` first so the cast cannot truncate on targets with a narrow `usize`.
        (epoch % self.pool_size as u64) as usize
    }

    /// Borrows the arena serving the current epoch without touching its reference count.
    #[inline]
    fn active_arena(&self) -> &EpochArena {
        &self.arenas[self.slot_for(self.current_epoch())]
    }

    /// Returns the pool's current epoch.
    #[inline]
    pub fn current_epoch(&self) -> u64 {
        self.current_epoch.load(Ordering::Acquire)
    }

    /// Returns a shared handle to the arena serving the current epoch.
    #[inline]
    pub fn current_arena(&self) -> Arc<EpochArena> {
        Arc::clone(&self.arenas[self.slot_for(self.current_epoch())])
    }

    /// Allocates `size` bytes from the current arena.
    #[inline]
    pub fn allocate(&self, size: usize) -> Option<ArenaHandle> {
        self.active_arena().allocate(size)
    }

    /// Copies `key` into the current arena.
    #[inline]
    pub fn allocate_key(&self, key: &[u8]) -> Option<ArenaHandle> {
        self.active_arena().allocate_key(key)
    }

    /// Moves the pool to the next epoch and returns it.
    ///
    /// The arena that will serve the new epoch is reset *before* the epoch becomes visible, so
    /// no allocation made through the pool under the new epoch can land in un-reset memory or
    /// be wiped by the reset. Handles from the epoch that previously used that arena are
    /// invalidated by this call.
    ///
    /// # Panics
    ///
    /// Panics if the epoch counter overflows or if the arena's internal lock is poisoned.
    pub fn advance_epoch(&self) -> u64 {
        // The lock guards no data of its own, so a poisoned guard is still usable.
        let _serialized = self
            .advance_lock
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let new_epoch = self
            .current_epoch
            .load(Ordering::Acquire)
            .checked_add(1)
            .expect("epoch counter overflowed");
        self.arenas[self.slot_for(new_epoch)].reset(new_epoch);
        self.current_epoch.store(new_epoch, Ordering::Release);
        new_epoch
    }

    /// Returns `true` while the arena that served `epoch` has not yet been recycled, i.e. while
    /// the current epoch is fewer than `pool_size` epochs ahead of `epoch`.
    ///
    /// Epochs that lie in the future are reported as valid.
    #[inline]
    pub fn is_epoch_valid(&self, epoch: u64) -> bool {
        // Subtracting instead of adding avoids overflow for epochs near `u64::MAX`.
        self.current_epoch().saturating_sub(epoch) < self.pool_size as u64
    }

    /// Returns a statistics snapshot for every arena, in ring order.
    pub fn stats(&self) -> Vec<ArenaStats> {
        self.arenas.iter().map(|a| a.stats()).collect()
    }
}
pub struct ThreadLocalArena {
pool: Arc<ArenaPool>,
cached_arena: AtomicPtr<EpochArena>,
cached_epoch: AtomicU64,
}
impl ThreadLocalArena {
pub fn new(pool: Arc<ArenaPool>) -> Self {
let arena = pool.current_arena();
let epoch = arena.epoch();
Self {
pool,
cached_arena: AtomicPtr::new(Arc::into_raw(arena) as *mut _),
cached_epoch: AtomicU64::new(epoch),
}
}
#[inline]
pub fn allocate(&self, size: usize) -> Option<ArenaHandle> {
let current_epoch = self.pool.current_epoch();
let cached_epoch = self.cached_epoch.load(Ordering::Relaxed);
if current_epoch == cached_epoch {
let arena_ptr = self.cached_arena.load(Ordering::Acquire);
if !arena_ptr.is_null() {
let arena = unsafe { &*arena_ptr };
return arena.allocate(size);
}
}
self.allocate_slow(size, current_epoch)
}
#[cold]
fn allocate_slow(&self, size: usize, _current_epoch: u64) -> Option<ArenaHandle> {
let new_arena = self.pool.current_arena();
let new_epoch = new_arena.epoch();
let old_ptr = self
.cached_arena
.swap(Arc::into_raw(new_arena.clone()) as *mut _, Ordering::AcqRel);
self.cached_epoch.store(new_epoch, Ordering::Release);
if !old_ptr.is_null() {
unsafe { Arc::from_raw(old_ptr as *const EpochArena) };
}
new_arena.allocate(size)
}
#[inline]
pub fn allocate_key(&self, key: &[u8]) -> Option<ArenaHandle> {
if key.len() > MAX_INLINE_KEY_SIZE {
return None;
}
self.allocate(key.len()).map(|handle| {
unsafe {
std::ptr::copy_nonoverlapping(key.as_ptr(), handle.ptr.as_ptr(), key.len());
}
handle
})
}
}
impl Drop for ThreadLocalArena {
fn drop(&mut self) {
let ptr = self.cached_arena.load(Ordering::Acquire);
if !ptr.is_null() {
unsafe { Arc::from_raw(ptr as *const EpochArena) };
}
}
}
unsafe impl Send for ThreadLocalArena {}
unsafe impl Sync for ThreadLocalArena {}
/// A key whose bytes live in an arena, represented by the handle to those bytes.
#[derive(Clone, Copy)]
pub struct ArenaKey {
    handle: ArenaHandle,
}

impl ArenaKey {
    /// Wraps an existing arena handle as a key.
    #[inline]
    pub fn new(handle: ArenaHandle) -> Self {
        ArenaKey { handle }
    }

    /// Returns the key's bytes.
    ///
    /// # Safety
    ///
    /// Same contract as `ArenaHandle::as_slice`: the memory behind the handle must still be
    /// valid for reads for as long as the returned slice is used.
    #[inline]
    pub unsafe fn as_bytes(&self) -> &[u8] {
        let ArenaKey { handle } = self;
        // SAFETY: forwarded to the caller, see the `# Safety` section above.
        unsafe { handle.as_slice() }
    }

    /// Returns the epoch recorded in the underlying handle.
    #[inline]
    pub fn epoch(&self) -> u64 {
        let ArenaKey { handle } = self;
        handle.epoch()
    }

    /// Returns the key length in bytes.
    #[inline]
    pub fn len(&self) -> usize {
        let ArenaKey { handle } = self;
        handle.len()
    }

    /// Returns `true` when the key has no bytes.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

// SAFETY: the key is only a wrapper around `ArenaHandle`, which is itself `Send`.
unsafe impl Send for ArenaKey {}
// SAFETY: the key is only a wrapper around `ArenaHandle`, which is itself `Sync`.
unsafe impl Sync for ArenaKey {}
// Unit tests for the arena types defined in this file.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::thread;
// Three allocations report their requested lengths, carry the arena's epoch and are counted.
#[test]
fn test_epoch_arena_basic() {
let arena = EpochArena::new(1);
let h1 = arena.allocate(16).unwrap();
let h2 = arena.allocate(32).unwrap();
let h3 = arena.allocate(64).unwrap();
assert_eq!(h1.len(), 16);
assert_eq!(h2.len(), 32);
assert_eq!(h3.len(), 64);
assert_eq!(h1.epoch(), 1);
let stats = arena.stats();
assert_eq!(stats.allocation_count, 3);
}
// `allocate_copy` stores a byte-for-byte copy of the input.
#[test]
fn test_allocate_copy() {
let arena = EpochArena::new(1);
let data = b"hello world";
let handle = arena.allocate_copy(data).unwrap();
assert_eq!(handle.len(), data.len());
let slice = unsafe { handle.as_slice() };
assert_eq!(slice, data);
}
// `allocate_key` stores a byte-for-byte copy of the key.
#[test]
fn test_allocate_key() {
let arena = EpochArena::new(1);
let key = b"my_test_key";
let handle = arena.allocate_key(key).unwrap();
let slice = unsafe { handle.as_slice() };
assert_eq!(slice, key);
}
// `reset` switches the arena to the new epoch and clears the allocation counter.
#[test]
fn test_arena_reset() {
let arena = EpochArena::new(1);
for _ in 0..1000 {
arena.allocate(64).unwrap();
}
let stats_before = arena.stats();
assert!(stats_before.total_allocated > 0);
arena.reset(2);
let stats_after = arena.stats();
assert_eq!(stats_after.epoch, 2);
assert_eq!(stats_after.allocation_count, 0);
}
// Handles carry the pool's epoch at allocation time; with four arenas, epochs 0 and 1 are
// both still valid after a single advance.
#[test]
fn test_arena_pool() {
let pool = ArenaPool::new(4);
let h1 = pool.allocate(16).unwrap();
assert_eq!(h1.epoch(), 0);
pool.advance_epoch();
let h2 = pool.allocate(16).unwrap();
assert_eq!(h2.epoch(), 1);
assert!(pool.is_epoch_valid(0));
assert!(pool.is_epoch_valid(1));
}
// Plain and key allocations through the caching front end report the requested lengths.
#[test]
fn test_thread_local_arena() {
let pool = Arc::new(ArenaPool::new(4));
let tla = ThreadLocalArena::new(pool.clone());
let h1 = tla.allocate(32).unwrap();
assert_eq!(h1.len(), 32);
let h2 = tla.allocate_key(b"test").unwrap();
assert_eq!(h2.len(), 4);
}
// Eight threads make 10,000 allocations each from a shared pool; every one must succeed and
// be counted (8 * 10,000 = 80,000).
#[test]
fn test_concurrent_allocation() {
let pool = Arc::new(ArenaPool::new(4));
let mut handles = vec![];
for _ in 0..8 {
let pool_clone = pool.clone();
handles.push(thread::spawn(move || {
for i in 0..10000 {
let size = (i % 64) + 8;
pool_clone.allocate(size).expect("allocation failed");
}
}));
}
for handle in handles {
handle.join().unwrap();
}
let stats = pool.stats();
let total_allocs: usize = stats.iter().map(|s| s.allocation_count).sum();
assert_eq!(total_allocs, 80000);
}
// A 3 MiB request does not fit in the 2 MiB default block, so the arena must add a block.
#[test]
fn test_large_allocation() {
let arena = EpochArena::new(1);
let large_size = 3 * 1024 * 1024;
let handle = arena.allocate(large_size).unwrap();
assert_eq!(handle.len(), large_size);
let stats = arena.stats();
assert!(stats.block_count >= 2); }
// Returned addresses honour the requested 16- and 64-byte alignments.
#[test]
fn test_alignment() {
let arena = EpochArena::new(1);
let h1 = arena.allocate_aligned(17, 16).unwrap();
assert!((h1.ptr.as_ptr() as usize) % 16 == 0);
let h2 = arena.allocate_aligned(65, 64).unwrap();
assert!((h2.ptr.as_ptr() as usize) % 64 == 0);
}
// An `ArenaKey` exposes the length, epoch and bytes of the handle it wraps.
#[test]
fn test_arena_key() {
let arena = EpochArena::new(42);
let key_data = b"user:12345:profile";
let handle = arena.allocate_key(key_data).unwrap();
let key = ArenaKey::new(handle);
assert_eq!(key.len(), key_data.len());
assert_eq!(key.epoch(), 42);
let bytes = unsafe { key.as_bytes() };
assert_eq!(bytes, key_data);
}
// `advance_epoch` returns consecutive epochs; at epoch 10 with four arenas, epoch 0 has been
// recycled while epochs 7 and 10 are still valid.
#[test]
fn test_epoch_advancement() {
let pool = ArenaPool::new(4);
for expected_epoch in 1..=10 {
let new_epoch = pool.advance_epoch();
assert_eq!(new_epoch, expected_epoch);
}
assert!(!pool.is_epoch_valid(0)); assert!(pool.is_epoch_valid(7)); assert!(pool.is_epoch_valid(10)); }
}