use std::cell::UnsafeCell;
use std::sync::atomic::{AtomicUsize, Ordering};
use fastarena::{Arena, ArenaStats};
#[allow(dead_code)]
/// Default initial block size (64 KiB) for each shard's first arena block.
const DEFAULT_INITIAL_BLOCK: usize = 64 * 1024;
/// A sharded wrapper around `Arena` allowing concurrent allocation: each
/// thread is lazily assigned its own shard (see `local()`), so threads never
/// contend on a single underlying arena.
pub(crate) struct ConcurrentArena {
// One single-threaded `Arena` per shard. A shard is only ever handed out
// (as `&mut`) to the one thread that claimed it, which is what makes the
// `UnsafeCell` access sound.
shards: Vec<UnsafeCell<Arena>>,
// Monotonic counter handing out the next unclaimed shard index.
shard_assign: AtomicUsize,
// Approximate running byte total maintained via `record_alloc` (Relaxed).
bytes_allocated: AtomicUsize,
}
/// Maximum number of distinct `ConcurrentArena` instances a single thread can
/// hold cached shard assignments for at once.
const MAX_CACHED_ARENAS: usize = 8;
// Per-thread cache mapping an arena's address -> the shard index this thread
// was assigned for that arena. `(0, usize::MAX)` marks an empty slot: a live
// arena's address is never 0, and `usize::MAX` is never a valid shard index.
thread_local! {
static SHARD_CACHE: std::cell::RefCell<[(usize, usize); MAX_CACHED_ARENAS]> =
const { std::cell::RefCell::new([(0, usize::MAX); MAX_CACHED_ARENAS]) };
}
// SAFETY: the shards behind the `UnsafeCell`s are only ever handed out as
// `&mut` to the single thread that claimed each shard through the thread-local
// cache in `local()`, so sharing/sending the container does not by itself
// create aliased mutable access. NOTE(review): this soundness argument rests
// entirely on `local()`'s one-shard-per-thread assignment — confirm nothing
// else reaches into `shards` directly.
unsafe impl Send for ConcurrentArena {}
unsafe impl Sync for ConcurrentArena {}
impl ConcurrentArena {
    /// Creates a sharded arena with `num_shards` shards, each seeded with the
    /// default 64 KiB initial block.
    ///
    /// # Panics
    ///
    /// Panics if `num_shards` is zero (see [`Self::with_block_size`]).
    #[allow(dead_code)]
    pub fn new(num_shards: usize) -> Self {
        ConcurrentArena::with_block_size(num_shards, DEFAULT_INITIAL_BLOCK)
    }
/// Creates a sharded arena with `num_shards` shards, each seeded with an
/// initial block of `block_size` bytes.
///
/// # Panics
///
/// Panics if `num_shards` is zero.
pub fn with_block_size(num_shards: usize, block_size: usize) -> Self {
    assert!(num_shards > 0, "num_shards must be > 0");
    // Build every shard up front; assignment to threads happens lazily.
    let shards = (0..num_shards)
        .map(|_| UnsafeCell::new(Arena::with_capacity(block_size)))
        .collect();
    ConcurrentArena {
        shards,
        shard_assign: AtomicUsize::new(0),
        bytes_allocated: AtomicUsize::new(0),
    }
}
#[inline(always)]
/// Adds `size` bytes to the global approximate allocation counter.
/// Relaxed ordering is sufficient: the counter is a statistic, not a
/// synchronization point.
pub fn record_alloc(&self, size: usize) {
self.bytes_allocated.fetch_add(size, Ordering::Relaxed);
}
#[inline(always)]
/// Returns the approximate total recorded via `record_alloc`.
/// A Relaxed load may lag concurrent writers; treat the value as a hint.
pub fn bytes_allocated_fast(&self) -> usize {
self.bytes_allocated.load(Ordering::Relaxed)
}
/// Returns a mutable reference to the shard assigned to the calling thread.
///
/// The assignment is cached in a small per-thread table keyed by the arena's
/// address, so repeated calls are cheap. The first call from a thread claims
/// a fresh shard via `shard_assign`; claiming more shards than exist panics.
///
/// NOTE(review): handing out `&mut Arena` from `&self` is sound only while
/// each shard is reached by exactly one thread — the thread-local cache is
/// what upholds that invariant.
#[allow(clippy::mut_from_ref)]
pub fn local(&self) -> &mut Arena {
    let self_ptr = self as *const Self as usize;
    SHARD_CACHE.with(|cache| {
        let mut cache = cache.borrow_mut();
        // Pass 1: look for an existing entry for this arena in the WHOLE
        // table before claiming anything. This must not be fused with the
        // free-slot search: `Drop` punches holes into the table, and a
        // combined scan that claims the first hole would assign a second
        // shard to an arena whose entry sits after the hole — burning
        // shards until the exhaustion assert fires spuriously.
        if let Some(&(_, idx)) = cache.iter().find(|entry| entry.0 == self_ptr) {
            // SAFETY: `idx` was claimed by this thread, so no other thread
            // can hold a reference into this shard.
            return unsafe { &mut *self.shards[idx].get() };
        }
        // Pass 2: no entry yet — claim a shard and record it in the first
        // free slot, evicting slot 0 only if the table is full.
        let slot = cache
            .iter()
            .position(|entry| entry.1 == usize::MAX)
            .unwrap_or(0);
        let assigned = self.claim_shard();
        cache[slot] = (self_ptr, assigned);
        // SAFETY: `assigned` came from a fetch_add, so it is unique to this
        // thread; nobody else can reference this shard.
        unsafe { &mut *self.shards[assigned].get() }
    })
}

/// Atomically claims the next unused shard index, panicking when more
/// threads have claimed shards than shards exist.
fn claim_shard(&self) -> usize {
    let assigned = self.shard_assign.fetch_add(1, Ordering::Relaxed);
    assert!(
        assigned < self.shards.len(),
        "more concurrent threads ({}) than arena shards ({}); \
        increase shard count via with_shards() or with_capacity_and_shards()",
        assigned + 1,
        self.shards.len()
    );
    assigned
}
/// Aggregates allocation statistics across every shard.
///
/// NOTE(review): this reads each shard through its `UnsafeCell` without any
/// synchronization with threads that may hold `&mut` from `local()` —
/// confirm callers only invoke `stats()` at quiescent points.
pub fn stats(&self) -> ArenaStats {
    self.shards
        .iter()
        .fold(ArenaStats::default(), |mut total, cell| {
            let s = unsafe { &*cell.get() }.stats();
            total.bytes_allocated += s.bytes_allocated;
            total.bytes_reserved += s.bytes_reserved;
            total.block_count += s.block_count;
            total
        })
}
pub unsafe fn reset_all(&mut self) {
for shard in &mut self.shards {
shard.get_mut().reset();
}
self.bytes_allocated.store(0, Ordering::Relaxed);
let self_ptr = self as *const Self as usize;
SHARD_CACHE.with(|cache| {
let mut cache = cache.borrow_mut();
for i in 0..MAX_CACHED_ARENAS {
if cache[i].0 == self_ptr {
cache[i] = (0, usize::MAX);
break;
}
}
});
}
#[allow(dead_code)]
pub fn reset_local() {
SHARD_CACHE.with(|cache| {
*cache.borrow_mut() = [(0, usize::MAX); MAX_CACHED_ARENAS];
});
}
#[allow(dead_code)]
/// Returns the number of shards backing this arena (the maximum number of
/// threads that can allocate concurrently).
pub fn num_shards(&self) -> usize {
self.shards.len()
}
}
impl Drop for ConcurrentArena {
/// Purges this arena's entries from the *calling* thread's shard cache so a
/// later arena allocated at the same address cannot hit a stale entry.
///
/// NOTE(review): only the dropping thread's cache is cleared — other threads
/// that called `local()` keep stale `(address, shard)` entries. If a new
/// `ConcurrentArena` is later created at the same address, those threads
/// would reuse a stale shard index that was never claimed through
/// `shard_assign` — worth confirming/fixing (e.g. with a generation tag).
fn drop(&mut self) {
let self_ptr = self as *const Self as usize;
SHARD_CACHE.with(|cache| {
let mut cache = cache.borrow_mut();
for i in 0..MAX_CACHED_ARENAS {
if cache[i].0 == self_ptr {
// No `break`: clear every matching slot in case duplicates exist.
cache[i] = (0, usize::MAX);
}
}
});
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created arena hands out a usable shard on the same thread.
    #[test]
    fn test_concurrent_arena_basic() {
        let arena = ConcurrentArena::new(4);
        let shard = arena.local();
        assert_eq!(*shard.alloc(42u64), 42);
    }

    /// Allocations made through a shard show up in the aggregated stats.
    #[test]
    fn test_stats_aggregation() {
        let arena = ConcurrentArena::new(2);
        arena.local().alloc(1u64);
        let stats = arena.stats();
        assert!(stats.bytes_allocated > 0);
    }

    /// Two arenas used from the same thread keep independent cache entries.
    #[test]
    fn test_multiple_arenas_same_thread() {
        let arena_a = ConcurrentArena::new(4);
        let arena_b = ConcurrentArena::new(4);
        arena_a.local().alloc(1u64);
        arena_b.local().alloc(2u64);
        assert_eq!(*arena_a.local().alloc(3u64), 3);
    }
}