//! memkit 0.2.0-beta.1
//!
//! Deterministic, intent-driven memory allocation for systems requiring
//! predictable performance. See the crate-level documentation for details.
//! Thread-local storage for per-thread allocator state.

use std::cell::UnsafeCell;
use std::sync::Arc;

use super::arena::FrameArena;
use super::global::GlobalState;
use super::slab::SlabAllocator;

/// Per-thread allocator state.
///
/// One instance lives in each thread's TLS slot (see `init_tls`). It bundles
/// the thread's frame arena, its slab allocator for small objects, a shared
/// handle to the global state, and per-thread allocation statistics.
pub struct ThreadLocalState {
    /// Frame arena for this thread; reset wholesale at frame boundaries.
    arena: FrameArena,
    /// Slab allocator for small objects.
    slab: SlabAllocator,
    /// Reference to global state, shared across threads via `Arc`.
    global: Arc<GlobalState>,
    /// Current frame number, advanced by `begin_frame`.
    frame: u64,
    /// Total allocations in this thread (allocation events; never decremented by frees).
    alloc_count: u64,
    /// Total bytes allocated in this thread.
    total_bytes: usize,
}

impl ThreadLocalState {
    /// Create new thread-local state.
    ///
    /// # Panics
    ///
    /// Panics if the backing frame arena cannot be created with `arena_size` bytes.
    pub fn new(arena_size: usize, global: Arc<GlobalState>) -> Self {
        Self {
            arena: FrameArena::new(arena_size).expect("Failed to create frame arena"),
            slab: SlabAllocator::new(),
            global,
            frame: 0,
            alloc_count: 0,
            total_bytes: 0,
        }
    }

    /// Get the frame arena.
    pub fn arena(&self) -> &FrameArena {
        &self.arena
    }

    /// Get the slab allocator.
    pub fn slab(&self) -> &SlabAllocator {
        &self.slab
    }

    /// Get mutable slab allocator.
    pub fn slab_mut(&mut self) -> &mut SlabAllocator {
        &mut self.slab
    }

    /// Begin a new frame.
    #[inline(always)]
    pub fn begin_frame(&mut self) {
        self.frame += 1;
    }

    /// End the current frame, resetting the arena.
    ///
    /// All pointers handed out by the frame-allocation methods during this
    /// frame are invalidated by the reset.
    #[inline(always)]
    pub fn end_frame(&mut self) {
        self.arena.reset();
    }

    /// Get the current frame number.
    pub fn frame(&self) -> u64 {
        self.frame
    }

    /// Get the frame head position (for checkpointing).
    pub fn frame_head(&self) -> usize {
        self.arena.head()
    }

    /// Reset frame to a checkpoint.
    ///
    /// `pos` should be a value previously returned by [`Self::frame_head`];
    /// arena allocations made after that checkpoint are invalidated.
    pub fn reset_frame_to(&self, pos: usize) {
        self.arena.reset_to(pos);
    }

    /// Record an allocation event in thread-local and global statistics.
    ///
    /// Saturating arithmetic keeps the counters well-defined: the previous
    /// unchecked `+=` panicked in debug builds and silently wrapped in
    /// release builds once a long-lived thread overflowed them.
    fn record_alloc(&mut self, bytes: usize) {
        self.alloc_count = self.alloc_count.saturating_add(1);
        self.total_bytes = self.total_bytes.saturating_add(bytes);
        self.global.record_alloc(bytes);
    }

    /// Allocate from the frame arena.
    ///
    /// Returns a null pointer if the arena cannot satisfy the request; the
    /// returned memory is not initialized by this call.
    #[inline(always)]
    pub fn frame_alloc<T>(&mut self) -> *mut T {
        let layout = std::alloc::Layout::new::<T>();
        let ptr = self.arena.alloc(layout) as *mut T;
        if !ptr.is_null() {
            self.record_alloc(layout.size());
        }
        ptr
    }

    /// Allocate and initialize a value in the frame arena.
    ///
    /// Returns a null pointer on arena exhaustion; how `value` is handled on
    /// failure is determined by `FrameArena::alloc_value` — TODO confirm.
    #[inline(always)]
    pub fn frame_alloc_value<T>(&mut self, value: T) -> *mut T {
        let ptr = self.arena.alloc_value(value);
        if !ptr.is_null() {
            self.record_alloc(std::mem::size_of::<T>());
        }
        ptr
    }

    /// Allocate a slice in the frame arena.
    ///
    /// Returns a null pointer on arena exhaustion; the memory is not
    /// initialized by this call.
    #[inline(always)]
    pub fn frame_alloc_slice<T>(&mut self, len: usize) -> *mut T {
        let ptr: *mut T = self.arena.alloc_slice(len);
        if !ptr.is_null() {
            // Saturate: `size_of::<T>() * len` can wrap for huge `len`,
            // which would corrupt the byte statistics (and panic in debug
            // builds). The arena itself sizes the real allocation.
            self.record_alloc(std::mem::size_of::<T>().saturating_mul(len));
        }
        ptr
    }

    /// Allocate from the slab (pool).
    ///
    /// Returns a null pointer if the slab cannot satisfy the request.
    pub fn pool_alloc<T>(&mut self) -> *mut T {
        let ptr: *mut T = self.slab.alloc::<T>();
        if !ptr.is_null() {
            self.record_alloc(std::mem::size_of::<T>());
        }
        ptr
    }

    /// Free to the slab (pool).
    ///
    /// Statistics are intentionally left untouched: `alloc_count` and
    /// `total_bytes` count allocation *events*, not live memory.
    pub fn pool_free<T>(&mut self, ptr: *mut T) {
        self.slab.free(ptr);
    }

    /// Get thread-local statistics.
    pub fn stats(&self) -> ThreadLocalStats {
        ThreadLocalStats {
            allocation_count: self.alloc_count,
            total_allocated: self.total_bytes,
        }
    }
}

/// Thread-local statistics.
///
/// Snapshot of a single thread's allocation counters, produced by
/// `ThreadLocalState::stats`.
#[derive(Debug, Clone, Copy)]
pub struct ThreadLocalStats {
    /// Number of allocation events recorded on this thread (never decremented by frees).
    pub allocation_count: u64,
    /// Total bytes requested by allocations on this thread.
    pub total_allocated: usize,
}

thread_local! {
    /// Thread-local allocator state.
    /// Using UnsafeCell for faster access than RefCell.
    ///
    /// NOTE(review): there is no runtime borrow check here. The accessor
    /// functions below create `&`/`&mut` references directly from this cell;
    /// re-entering `with_tls_mut` (or mixing it with the other accessors)
    /// from inside a TLS closure would create aliasing references to the
    /// same state — undefined behavior. `RefCell` would catch this at
    /// runtime at a small cost.
    static TLS: UnsafeCell<Option<ThreadLocalState>> = const { UnsafeCell::new(None) };
}

/// Initialize thread-local state for the current thread.
///
/// Idempotent: if this thread already has state, the call is a no-op and the
/// provided `global` handle is simply dropped.
pub fn init_tls(arena_size: usize, global: Arc<GlobalState>) {
    TLS.with(|cell| {
        // SAFETY: the cell is thread-local and this exclusive borrow does not
        // escape the closure, so no aliasing reference exists on this thread
        // (assuming `ThreadLocalState::new` does not re-enter the TLS
        // accessors).
        let slot = unsafe { &mut *cell.get() };
        slot.get_or_insert_with(|| ThreadLocalState::new(arena_size, global));
    });
}

/// Check if TLS is initialized for the current thread.
pub fn is_tls_initialized() -> bool {
    // SAFETY: shared borrow of this thread's slot, scoped to the expression;
    // no mutable reference exists unless a caller has re-entered from inside
    // `with_tls_mut` (a documented hazard of this UnsafeCell-based TLS).
    TLS.with(|cell| unsafe { (*cell.get()).is_some() })
}

/// Execute a closure with access to thread-local state.
///
/// # Panics
///
/// Panics if TLS is not initialized for the current thread.
#[inline(always)]
pub fn with_tls<F, R>(f: F) -> R
where
    F: FnOnce(&ThreadLocalState) -> R,
{
    TLS.with(|cell| {
        // SAFETY: shared, closure-scoped borrow of this thread's slot.
        let slot = unsafe { &*cell.get() };
        match slot.as_ref() {
            Some(state) => f(state),
            None => panic!("Thread-local state not initialized. Call init_tls first."),
        }
    })
}

/// Execute a closure with mutable access to thread-local state.
///
/// Panics if TLS is not initialized.
///
/// NOTE(review): `f` must not re-enter `with_tls`, `with_tls_mut`, or
/// `try_with_tls` on this thread — the `&mut` below is handed out with no
/// reentrancy guard, so a nested call would create aliasing references to
/// the same `ThreadLocalState` (undefined behavior).
#[inline(always)]
pub fn with_tls_mut<F, R>(f: F) -> R
where
    F: FnOnce(&mut ThreadLocalState) -> R,
{
    TLS.with(|tls| {
        // SAFETY: exclusive access to this thread's slot; sound only while
        // `f` does not re-enter the TLS accessors (see note above).
        let tls = unsafe { &mut *tls.get() };
        let state = tls.as_mut().expect("Thread-local state not initialized. Call init_tls first.");
        f(state)
    })
}

/// Try to execute a closure with access to thread-local state.
///
/// Returns None if TLS is not initialized.
pub fn try_with_tls<F, R>(f: F) -> Option<R>
where
    F: FnOnce(&ThreadLocalState) -> R,
{
    TLS.with(|cell| {
        // SAFETY: shared, closure-scoped borrow of this thread's slot.
        unsafe { &*cell.get() }.as_ref().map(f)
    })
}