use std::cell::UnsafeCell;
use std::sync::Arc;
use super::arena::FrameArena;
use super::global::GlobalState;
use super::slab::SlabAllocator;
/// Per-thread allocation state: a frame-scoped bump arena, a slab
/// allocator for pooled objects, and local counters that are mirrored
/// into the shared `GlobalState` on every recorded allocation.
pub struct ThreadLocalState {
// Bump arena; reset wholesale by `end_frame`.
arena: FrameArena,
// Slab allocator for pooled (non-frame) allocations.
slab: SlabAllocator,
// Shared cross-thread accounting; `record_alloc` forwards every byte count here.
global: Arc<GlobalState>,
// Monotonic frame counter, incremented by `begin_frame`.
frame: u64,
// Number of successful allocations recorded on this thread.
alloc_count: u64,
// Total bytes recorded on this thread across both allocators.
total_bytes: usize,
}
impl ThreadLocalState {
    /// Builds a fresh per-thread state with a frame arena of `arena_size`
    /// bytes, sharing `global` for cross-thread accounting.
    ///
    /// # Panics
    /// Panics if the frame arena cannot be created.
    pub fn new(arena_size: usize, global: Arc<GlobalState>) -> Self {
        let arena = FrameArena::new(arena_size).expect("Failed to create frame arena");
        Self {
            arena,
            slab: SlabAllocator::new(),
            global,
            frame: 0,
            alloc_count: 0,
            total_bytes: 0,
        }
    }

    /// Shared access to the per-frame bump arena.
    pub fn arena(&self) -> &FrameArena {
        &self.arena
    }

    /// Shared access to the slab allocator.
    pub fn slab(&self) -> &SlabAllocator {
        &self.slab
    }

    /// Exclusive access to the slab allocator.
    pub fn slab_mut(&mut self) -> &mut SlabAllocator {
        &mut self.slab
    }

    /// Advances the frame counter; call once at the start of each frame.
    #[inline(always)]
    pub fn begin_frame(&mut self) {
        self.frame += 1;
    }

    /// Resets the frame arena, invalidating every frame allocation made
    /// since the last reset.
    #[inline(always)]
    pub fn end_frame(&mut self) {
        self.arena.reset();
    }

    /// The current frame number.
    pub fn frame(&self) -> u64 {
        self.frame
    }

    /// The arena's current bump position, usable as a rollback point for
    /// `reset_frame_to`.
    pub fn frame_head(&self) -> usize {
        self.arena.head()
    }

    /// Rolls the frame arena back to `pos`, a value previously obtained
    /// from `frame_head`.
    // NOTE(review): takes `&self` yet logically mutates the arena —
    // presumably `FrameArena::reset_to` uses interior mutability; confirm
    // against the arena implementation.
    pub fn reset_frame_to(&self, pos: usize) {
        self.arena.reset_to(pos);
    }

    /// Bumps the local counters and forwards the byte count to the shared
    /// global accounting. Called only for successful (non-null) allocations.
    fn record_alloc(&mut self, bytes: usize) {
        self.alloc_count += 1;
        self.total_bytes += bytes;
        self.global.record_alloc(bytes);
    }

    /// Allocates uninitialized space for one `T` from the frame arena.
    /// Returns null on arena exhaustion; only successes are recorded.
    #[inline(always)]
    pub fn frame_alloc<T>(&mut self) -> *mut T {
        let layout = std::alloc::Layout::new::<T>();
        let raw = self.arena.alloc(layout).cast::<T>();
        if !raw.is_null() {
            self.record_alloc(layout.size());
        }
        raw
    }

    /// Moves `value` into frame-arena storage. Returns null on exhaustion;
    /// only successes are recorded.
    #[inline(always)]
    pub fn frame_alloc_value<T>(&mut self, value: T) -> *mut T {
        let raw = self.arena.alloc_value(value);
        if raw.is_null() {
            return raw;
        }
        self.record_alloc(std::mem::size_of::<T>());
        raw
    }

    /// Allocates uninitialized space for `len` elements of `T` from the
    /// frame arena. Returns null on exhaustion; only successes are recorded.
    #[inline(always)]
    pub fn frame_alloc_slice<T>(&mut self, len: usize) -> *mut T {
        let raw: *mut T = self.arena.alloc_slice(len);
        if !raw.is_null() {
            self.record_alloc(std::mem::size_of::<T>() * len);
        }
        raw
    }

    /// Allocates one `T` from the slab pool. Returns null on failure;
    /// only successes are recorded.
    pub fn pool_alloc<T>(&mut self) -> *mut T {
        let raw = self.slab.alloc::<T>();
        if raw.is_null() {
            return raw;
        }
        self.record_alloc(std::mem::size_of::<T>());
        raw
    }

    /// Returns `ptr` to the slab pool. Note: the stats counters are
    /// cumulative allocation totals and are not decremented on free.
    pub fn pool_free<T>(&mut self, ptr: *mut T) {
        self.slab.free(ptr);
    }

    /// Snapshot of this thread's cumulative allocation counters.
    pub fn stats(&self) -> ThreadLocalStats {
        ThreadLocalStats {
            allocation_count: self.alloc_count,
            total_allocated: self.total_bytes,
        }
    }
}
/// Snapshot of a thread's cumulative allocation counters, as returned by
/// `ThreadLocalState::stats`.
#[derive(Debug, Clone, Copy)]
pub struct ThreadLocalStats {
/// Number of successful allocations recorded on the thread.
pub allocation_count: u64,
/// Total bytes recorded on the thread across all allocators.
pub total_allocated: usize,
}
// Per-thread slot holding the optional allocator state. `UnsafeCell` (rather
// than `RefCell`) avoids borrow-tracking overhead, but means there is NO
// runtime check: the accessors below hand out plain `&`/`&mut` references,
// so soundness relies on callers never re-entering `with_tls_mut` (or mixing
// it with the other accessors) on the same thread.
thread_local! {
static TLS: UnsafeCell<Option<ThreadLocalState>> = const { UnsafeCell::new(None) };
}
/// Initializes this thread's allocator state if it has not been set up yet.
///
/// Subsequent calls on an already-initialized thread are no-ops; the extra
/// `global` handle is simply dropped.
pub fn init_tls(arena_size: usize, global: Arc<GlobalState>) {
    TLS.with(|cell| {
        // SAFETY: exclusive borrow of the slot; sound provided this is not
        // called re-entrantly from inside a `with_tls*` closure.
        let slot = unsafe { &mut *cell.get() };
        slot.get_or_insert_with(|| ThreadLocalState::new(arena_size, global));
    });
}
/// Reports whether `init_tls` has been called on the current thread.
pub fn is_tls_initialized() -> bool {
    // SAFETY: shared read of the slot only; sound as long as no exclusive
    // borrow from `with_tls_mut` is live on this thread.
    TLS.with(|cell| unsafe { (*cell.get()).is_some() })
}
/// Runs `f` with a shared reference to this thread's allocator state.
///
/// # Panics
/// Panics if `init_tls` has not been called on this thread.
#[inline(always)]
pub fn with_tls<F, R>(f: F) -> R
where
    F: FnOnce(&ThreadLocalState) -> R,
{
    TLS.with(|cell| {
        // SAFETY: shared borrow of the slot; sound provided no exclusive
        // borrow from `with_tls_mut` is simultaneously live on this thread.
        let state = unsafe { (*cell.get()).as_ref() }
            .expect("Thread-local state not initialized. Call init_tls first.");
        f(state)
    })
}
/// Runs `f` with an exclusive reference to this thread's allocator state.
///
/// # Panics
/// Panics if `init_tls` has not been called on this thread.
#[inline(always)]
pub fn with_tls_mut<F, R>(f: F) -> R
where
    F: FnOnce(&mut ThreadLocalState) -> R,
{
    TLS.with(|cell| {
        // SAFETY: exclusive borrow of the slot. This is sound ONLY if `f`
        // does not re-enter any `with_tls*` accessor on this thread —
        // `UnsafeCell` performs no runtime borrow check.
        let state = unsafe { (*cell.get()).as_mut() }
            .expect("Thread-local state not initialized. Call init_tls first.");
        f(state)
    })
}
/// Runs `f` with a shared reference to this thread's allocator state,
/// returning `None` (instead of panicking) when the thread is uninitialized.
pub fn try_with_tls<F, R>(f: F) -> Option<R>
where
    F: FnOnce(&ThreadLocalState) -> R,
{
    // SAFETY: shared borrow of the slot; sound provided no exclusive borrow
    // from `with_tls_mut` is simultaneously live on this thread.
    TLS.with(|cell| unsafe { (*cell.get()).as_ref() }.map(f))
}