use crate::error::{Result, RuvLLMError};
use parking_lot::Mutex;
#[cfg(not(target_arch = "wasm32"))]
use parking_lot::RwLock;
use std::alloc::{alloc_zeroed, dealloc, Layout};
use std::cell::UnsafeCell;
#[cfg(not(target_arch = "wasm32"))]
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
#[cfg(not(target_arch = "wasm32"))]
use std::thread::ThreadId;
/// Cache-line size assumed when padding to avoid false sharing.
pub const CACHE_LINE_SIZE: usize = 64;
/// Alignment required for 128-bit ARM NEON loads and stores.
pub const NEON_ALIGNMENT: usize = 16;
/// Default allocation alignment: one cache line, which also satisfies SIMD requirements.
pub const DEFAULT_ALIGNMENT: usize = 64;
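/// Lock-free bump-pointer arena for transient per-step inference buffers.
///
/// Allocation only advances an atomic offset; individual frees are not
/// supported, and `reset` reclaims everything at once between steps.
///
/// A minimal usage sketch (the 1 MiB capacity is an arbitrary example):
///
/// ```ignore
/// let arena = InferenceArena::new(1024 * 1024)?;
/// let hidden: &mut [f32] = arena.alloc(4096).expect("arena exhausted");
/// hidden[0] = 1.0;
/// // ... run one decode step ...
/// arena.reset(); // invalidates all previously returned slices
/// ```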
#[derive(Debug)]
pub struct InferenceArena {
memory: *mut u8,
offset: AtomicUsize,
capacity: usize,
layout: Layout,
high_water_mark: AtomicUsize,
allocation_count: AtomicUsize,
}
// SAFETY: the arena hands out disjoint regions through an atomic bump
// pointer, so the raw base pointer may be shared and sent across threads.
unsafe impl Send for InferenceArena {}
unsafe impl Sync for InferenceArena {}
impl InferenceArena {
pub fn new(capacity: usize) -> Result<Self> {
// Round the requested capacity up to a multiple of DEFAULT_ALIGNMENT.
let aligned_capacity = (capacity + DEFAULT_ALIGNMENT - 1) & !(DEFAULT_ALIGNMENT - 1);
let layout =
Layout::from_size_align(aligned_capacity, DEFAULT_ALIGNMENT).map_err(|_| {
RuvLLMError::OutOfMemory(format!(
"Invalid arena layout: size={}, align={}",
aligned_capacity, DEFAULT_ALIGNMENT
))
})?;
let memory = unsafe { alloc_zeroed(layout) };
if memory.is_null() {
return Err(RuvLLMError::OutOfMemory(format!(
"Failed to allocate arena of {} bytes",
aligned_capacity
)));
}
Ok(Self {
memory,
offset: AtomicUsize::new(0),
capacity: aligned_capacity,
layout,
high_water_mark: AtomicUsize::new(0),
allocation_count: AtomicUsize::new(0),
})
}
pub fn for_model(hidden_dim: usize, vocab_size: usize, batch_size: usize) -> Result<Self> {
let activations = hidden_dim * batch_size * std::mem::size_of::<f32>();
let logits = vocab_size * batch_size * std::mem::size_of::<f32>();
let scratch = hidden_dim * 4 * std::mem::size_of::<f32>();
// Double the estimate to leave headroom for intermediate buffers.
let total = (activations + logits + scratch) * 2;
Self::new(total)
}
#[inline]
pub fn alloc<T: Copy + Default>(&self, count: usize) -> Option<&mut [T]> {
    let size = count * std::mem::size_of::<T>();
    let align = std::mem::align_of::<T>().max(DEFAULT_ALIGNMENT);
    let mut current = self.offset.load(Ordering::Acquire);
    loop {
        let aligned_offset = (current + align - 1) & !(align - 1);
        let new_offset = aligned_offset + size;
        if new_offset > self.capacity {
            return None;
        }
        match self.offset.compare_exchange_weak(
            current,
            new_offset,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            Ok(_) => {
                self.allocation_count.fetch_add(1, Ordering::Relaxed);
                let _ = self
                    .high_water_mark
                    .fetch_max(new_offset, Ordering::Relaxed);
                unsafe {
                    let ptr = self.memory.add(aligned_offset) as *mut T;
                    // Zero the region: it may hold stale data from before a reset.
                    std::ptr::write_bytes(ptr, 0, count);
                    return Some(std::slice::from_raw_parts_mut(ptr, count));
                }
            }
            // Another thread advanced the offset; retry from its value
            // instead of spuriously failing the allocation.
            Err(actual) => current = actual,
        }
    }
}
/// Allocates `count` elements without zero-initialization.
///
/// # Safety
///
/// The returned slice may contain stale bytes; callers must write every
/// element before reading, and `T` must be valid for any bit pattern.
#[inline]
pub unsafe fn alloc_uninit<T>(&self, count: usize) -> Option<&mut [T]> {
    let size = count * std::mem::size_of::<T>();
    let align = std::mem::align_of::<T>().max(DEFAULT_ALIGNMENT);
    let mut current = self.offset.load(Ordering::Acquire);
    loop {
        let aligned_offset = (current + align - 1) & !(align - 1);
        let new_offset = aligned_offset + size;
        if new_offset > self.capacity {
            return None;
        }
        match self.offset.compare_exchange_weak(
            current,
            new_offset,
            Ordering::AcqRel,
            Ordering::Acquire,
        ) {
            Ok(_) => {
                self.allocation_count.fetch_add(1, Ordering::Relaxed);
                let _ = self
                    .high_water_mark
                    .fetch_max(new_offset, Ordering::Relaxed);
                let ptr = self.memory.add(aligned_offset) as *mut T;
                return Some(std::slice::from_raw_parts_mut(ptr, count));
            }
            // Lost the race; retry with the freshly observed offset.
            Err(actual) => current = actual,
        }
    }
}
/// Resets the bump pointer, logically freeing every allocation at once.
///
/// Callers must ensure no slices returned by `alloc` or `alloc_uninit`
/// are still live, since they would alias memory handed out afterwards.
#[inline]
pub fn reset(&self) {
self.offset.store(0, Ordering::Release);
self.allocation_count.store(0, Ordering::Relaxed);
}
#[inline]
pub fn used(&self) -> usize {
self.offset.load(Ordering::Acquire)
}
#[inline]
pub fn capacity(&self) -> usize {
self.capacity
}
#[inline]
pub fn remaining(&self) -> usize {
self.capacity - self.used()
}
#[inline]
pub fn high_water_mark(&self) -> usize {
self.high_water_mark.load(Ordering::Relaxed)
}
#[inline]
pub fn allocation_count(&self) -> usize {
self.allocation_count.load(Ordering::Relaxed)
}
pub fn stats(&self) -> ArenaStats {
ArenaStats {
capacity: self.capacity,
used: self.used(),
remaining: self.remaining(),
high_water_mark: self.high_water_mark(),
allocation_count: self.allocation_count(),
utilization: self.used() as f64 / self.capacity as f64,
}
}
/// Returns the arena base pointer.
///
/// # Safety
///
/// The pointer must not be used to create references that alias slices
/// previously handed out by `alloc` or `alloc_uninit`.
#[inline]
pub unsafe fn as_ptr(&self) -> *const u8 {
    self.memory
}
/// Returns the mutable arena base pointer.
///
/// # Safety
///
/// See [`Self::as_ptr`]; writes must not touch live allocations.
#[inline]
pub unsafe fn as_mut_ptr(&self) -> *mut u8 {
    self.memory
}
}
impl Drop for InferenceArena {
fn drop(&mut self) {
unsafe {
dealloc(self.memory, self.layout);
}
}
}
#[derive(Debug, Clone, Default)]
pub struct ArenaStats {
pub capacity: usize,
pub used: usize,
pub remaining: usize,
pub high_water_mark: usize,
pub allocation_count: usize,
pub utilization: f64,
}
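/// Size classes used by [`BufferPool`]. Requests are rounded up to the
/// smallest class that fits; sizes above 256 KiB are not pooled.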
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(usize)]
pub enum BufferSize {
KB1 = 0,
KB4 = 1,
KB16 = 2,
KB64 = 3,
KB256 = 4,
}
impl BufferSize {
#[inline]
pub const fn bytes(self) -> usize {
match self {
Self::KB1 => 1024,
Self::KB4 => 4096,
Self::KB16 => 16384,
Self::KB64 => 65536,
Self::KB256 => 262144,
}
}
#[inline]
pub const fn index(self) -> usize {
self as usize
}
pub fn for_size(bytes: usize) -> Option<Self> {
if bytes <= 1024 {
Some(Self::KB1)
} else if bytes <= 4096 {
Some(Self::KB4)
} else if bytes <= 16384 {
Some(Self::KB16)
} else if bytes <= 65536 {
Some(Self::KB64)
} else if bytes <= 262144 {
Some(Self::KB256)
} else {
None
}
}
pub const fn all() -> [BufferSize; 5] {
[Self::KB1, Self::KB4, Self::KB16, Self::KB64, Self::KB256]
}
}
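/// An owned buffer checked out of a [`BufferPool`]. Dropping it returns the
/// backing storage to the pool's free list rather than freeing it.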
pub struct PooledBuffer {
data: Box<[u8]>,
size_class: BufferSize,
pool: Arc<BufferPoolInner>,
}
impl PooledBuffer {
#[inline]
pub fn as_bytes(&self) -> &[u8] {
&self.data
}
#[inline]
pub fn as_bytes_mut(&mut self) -> &mut [u8] {
&mut self.data
}
#[inline]
pub fn as_slice<T: Copy>(&self) -> &[T] {
    let size = std::mem::size_of::<T>();
    assert!(
        self.data.len() % size == 0,
        "Buffer size is not a multiple of the element size"
    );
    assert!(
        self.data.as_ptr() as usize % std::mem::align_of::<T>() == 0,
        "Buffer is not sufficiently aligned for the element type"
    );
    unsafe {
        std::slice::from_raw_parts(self.data.as_ptr() as *const T, self.data.len() / size)
    }
}
#[inline]
pub fn as_slice_mut<T: Copy>(&mut self) -> &mut [T] {
    let size = std::mem::size_of::<T>();
    assert!(
        self.data.len() % size == 0,
        "Buffer size is not a multiple of the element size"
    );
    assert!(
        self.data.as_ptr() as usize % std::mem::align_of::<T>() == 0,
        "Buffer is not sufficiently aligned for the element type"
    );
    unsafe {
        std::slice::from_raw_parts_mut(self.data.as_mut_ptr() as *mut T, self.data.len() / size)
    }
}
#[inline]
pub fn capacity(&self) -> usize {
self.data.len()
}
#[inline]
pub fn size_class(&self) -> BufferSize {
self.size_class
}
#[inline]
pub fn as_ptr(&self) -> *const u8 {
self.data.as_ptr()
}
#[inline]
pub fn as_mut_ptr(&mut self) -> *mut u8 {
self.data.as_mut_ptr()
}
#[inline]
pub fn clear(&mut self) {
self.data.fill(0);
}
}
impl Drop for PooledBuffer {
fn drop(&mut self) {
// Swap in an empty boxed slice so the real buffer can be handed back to
// the pool instead of being freed here.
let data = std::mem::replace(&mut self.data, Box::new([]));
self.pool.return_buffer(self.size_class, data);
}
}
impl std::fmt::Debug for PooledBuffer {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PooledBuffer")
.field("size_class", &self.size_class)
.field("capacity", &self.data.len())
.finish()
}
}
struct SizeClassPool {
free_list: Vec<Box<[u8]>>,
max_buffers: usize,
}
struct BufferPoolInner {
pools: [Mutex<SizeClassPool>; 5],
stats: PoolStatistics,
}
impl BufferPoolInner {
fn new(max_buffers_per_class: usize) -> Self {
    Self {
        // One free list per `BufferSize` variant, in index order.
        pools: std::array::from_fn(|_| {
            Mutex::new(SizeClassPool {
                free_list: Vec::with_capacity(max_buffers_per_class),
                max_buffers: max_buffers_per_class,
            })
        }),
        stats: PoolStatistics::new(),
    }
}
fn acquire(&self, size_class: BufferSize) -> Result<Box<[u8]>> {
let mut pool = self.pools[size_class.index()].lock();
if let Some(buf) = pool.free_list.pop() {
self.stats.hits.fetch_add(1, Ordering::Relaxed);
Ok(buf)
} else {
self.stats.misses.fetch_add(1, Ordering::Relaxed);
self.stats.allocations.fetch_add(1, Ordering::Relaxed);
Self::allocate_buffer(size_class)
}
}
fn return_buffer(&self, size_class: BufferSize, buf: Box<[u8]>) {
if buf.is_empty() {
return;
}
let mut pool = self.pools[size_class.index()].lock();
if pool.free_list.len() < pool.max_buffers {
self.stats.returns.fetch_add(1, Ordering::Relaxed);
pool.free_list.push(buf);
} else {
self.stats.drops.fetch_add(1, Ordering::Relaxed);
}
}
fn allocate_buffer(size_class: BufferSize) -> Result<Box<[u8]>> {
    let size = size_class.bytes();
    // `Box<[u8]>` deallocates with the layout of `[u8]` (alignment 1), so
    // pairing `Box::from_raw` with a manual `alloc_zeroed(.., DEFAULT_ALIGNMENT)`
    // would free with a mismatched layout, which is undefined behavior.
    // Allocate through `Vec` instead; `try_reserve_exact` keeps allocation
    // failure recoverable.
    let mut buf = Vec::new();
    buf.try_reserve_exact(size).map_err(|_| {
        RuvLLMError::OutOfMemory(format!("Failed to allocate buffer of {} bytes", size))
    })?;
    buf.resize(size, 0u8);
    Ok(buf.into_boxed_slice())
}
}
struct PoolStatistics {
hits: AtomicU64,
misses: AtomicU64,
allocations: AtomicU64,
returns: AtomicU64,
drops: AtomicU64,
}
impl PoolStatistics {
fn new() -> Self {
Self {
hits: AtomicU64::new(0),
misses: AtomicU64::new(0),
allocations: AtomicU64::new(0),
returns: AtomicU64::new(0),
drops: AtomicU64::new(0),
}
}
}
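/// Thread-safe pool of fixed-size buffers, bucketed by [`BufferSize`].
///
/// A minimal usage sketch; the prewarm count is an arbitrary example:
///
/// ```ignore
/// let pool = BufferPool::new();
/// pool.prewarm(BufferSize::KB4, 8)?; // optional: avoid first-use misses
/// let mut buf = pool.acquire(BufferSize::KB4)?;
/// buf.as_slice_mut::<f32>()[0] = 1.0;
/// drop(buf); // buffer returns to the KB4 free list
/// ```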
#[derive(Clone)]
pub struct BufferPool {
inner: Arc<BufferPoolInner>,
}
impl BufferPool {
pub fn new() -> Self {
Self::with_capacity(32)
}
pub fn with_capacity(max_buffers_per_class: usize) -> Self {
Self {
inner: Arc::new(BufferPoolInner::new(max_buffers_per_class)),
}
}
pub fn acquire(&self, size_class: BufferSize) -> Result<PooledBuffer> {
let data = self.inner.acquire(size_class)?;
Ok(PooledBuffer {
data,
size_class,
pool: Arc::clone(&self.inner),
})
}
pub fn acquire_for_size(&self, bytes: usize) -> Result<Option<PooledBuffer>> {
match BufferSize::for_size(bytes) {
Some(size_class) => Ok(Some(self.acquire(size_class)?)),
None => Ok(None),
}
}
pub fn prewarm(&self, size_class: BufferSize, count: usize) -> Result<()> {
for _ in 0..count {
let buf = BufferPoolInner::allocate_buffer(size_class)?;
self.inner.return_buffer(size_class, buf);
}
Ok(())
}
pub fn prewarm_all(&self, count_per_class: usize) -> Result<()> {
for size_class in BufferSize::all() {
self.prewarm(size_class, count_per_class)?;
}
Ok(())
}
pub fn stats(&self) -> BufferPoolStats {
    let mut free_counts = [0usize; 5];
    for (i, pool) in self.inner.pools.iter().enumerate() {
        free_counts[i] = pool.lock().free_list.len();
    }
    // Load each counter once so the derived hit rate stays consistent with
    // the reported hit/miss counts.
    let hits = self.inner.stats.hits.load(Ordering::Relaxed);
    let misses = self.inner.stats.misses.load(Ordering::Relaxed);
    let total = hits + misses;
    BufferPoolStats {
        hits,
        misses,
        allocations: self.inner.stats.allocations.load(Ordering::Relaxed),
        returns: self.inner.stats.returns.load(Ordering::Relaxed),
        drops: self.inner.stats.drops.load(Ordering::Relaxed),
        free_buffers: free_counts,
        hit_rate: if total > 0 {
            hits as f64 / total as f64
        } else {
            0.0
        },
    }
}
pub fn clear(&self) {
for pool in &self.inner.pools {
pool.lock().free_list.clear();
}
}
}
impl Default for BufferPool {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Debug for BufferPool {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("BufferPool")
.field("stats", &self.stats())
.finish()
}
}
#[derive(Debug, Clone, Default)]
pub struct BufferPoolStats {
pub hits: u64,
pub misses: u64,
pub allocations: u64,
pub returns: u64,
pub drops: u64,
pub free_buffers: [usize; 5],
pub hit_rate: f64,
}
#[cfg(not(target_arch = "wasm32"))]
struct ThreadScratch {
data: Box<[u8]>,
used: usize,
}
#[cfg(not(target_arch = "wasm32"))]
impl ThreadScratch {
fn new(size: usize) -> Result<Self> {
    // Allocate through `Vec` rather than `alloc_zeroed` + `Box::from_raw`:
    // `Box<[u8]>` frees with alignment 1, so a manually aligned allocation
    // would be deallocated with a mismatched layout (undefined behavior).
    let mut data = Vec::new();
    data.try_reserve_exact(size).map_err(|_| {
        RuvLLMError::OutOfMemory(format!(
            "Failed to allocate scratch buffer of {} bytes",
            size
        ))
    })?;
    data.resize(size, 0u8);
    Ok(Self {
        data: data.into_boxed_slice(),
        used: 0,
    })
}
fn reset(&mut self) {
self.used = 0;
}
}
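/// Per-thread scratch buffers for temporaries inside a kernel call.
///
/// Each OS thread lazily gets its own fixed-size buffer, so kernels can grab
/// scratch memory without contending on the hot path.
///
/// A minimal usage sketch (4 KiB per thread, up to 8 threads, as an example):
///
/// ```ignore
/// let manager = ScratchSpaceManager::new(4096, 8)?;
/// let mut scratch = manager.get_scratch()?;
/// let tmp: &mut [f32] = scratch.get(256).expect("scratch exhausted");
/// // ... use tmp within this call ...
/// manager.reset_all(); // between steps, once no handles are live
/// ```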
#[cfg(not(target_arch = "wasm32"))]
pub struct ScratchSpaceManager {
// Boxed cells give each scratch a stable heap address, so references handed
// out by `get_scratch` survive map rehashing when new threads register.
scratches: RwLock<HashMap<ThreadId, Box<UnsafeCell<ThreadScratch>>>>,
scratch_size: usize,
max_threads: usize,
}
// SAFETY: each thread only dereferences the scratch keyed by its own
// `ThreadId`, and the boxed cells have stable addresses; the `RwLock`
// serializes changes to the map's structure.
#[cfg(not(target_arch = "wasm32"))]
unsafe impl Send for ScratchSpaceManager {}
#[cfg(not(target_arch = "wasm32"))]
unsafe impl Sync for ScratchSpaceManager {}
#[cfg(not(target_arch = "wasm32"))]
impl ScratchSpaceManager {
pub fn new(scratch_size: usize, max_threads: usize) -> Result<Self> {
Ok(Self {
scratches: RwLock::new(HashMap::with_capacity(max_threads)),
scratch_size,
max_threads,
})
}
pub fn for_model(hidden_dim: usize, max_threads: usize) -> Result<Self> {
let scratch_size = hidden_dim * 4 * std::mem::size_of::<f32>();
Self::new(scratch_size, max_threads)
}
/// Returns the calling thread's scratch space, allocating it on first use.
///
/// Must not be called again on the same thread while a previously returned
/// `ScratchSpace` is still live, as both handles would alias one buffer.
pub fn get_scratch(&self) -> Result<ScratchSpace<'_>> {
    let thread_id = std::thread::current().id();
    {
        let scratches = self.scratches.read();
        if let Some(scratch_cell) = scratches.get(&thread_id) {
            // SAFETY: the cell is boxed, so its address is stable across map
            // rehashes, and only this thread dereferences its own entry.
            return Ok(ScratchSpace {
                scratch: unsafe { &mut *scratch_cell.get() },
            });
        }
    }
    let mut scratches = self.scratches.write();
    if !scratches.contains_key(&thread_id) {
        if scratches.len() >= self.max_threads {
            return Err(RuvLLMError::OutOfMemory(format!(
                "Exceeded maximum thread count ({}) for scratch space",
                self.max_threads
            )));
        }
        scratches.insert(
            thread_id,
            Box::new(UnsafeCell::new(ThreadScratch::new(self.scratch_size)?)),
        );
    }
    let scratch_cell = scratches.get(&thread_id).unwrap();
    // SAFETY: as above; the boxed cell outlives the `&self` borrow.
    Ok(ScratchSpace {
        scratch: unsafe { &mut *scratch_cell.get() },
    })
}
/// Resets every thread's scratch offset.
///
/// Callers must ensure no `ScratchSpace` handles are live on any thread.
pub fn reset_all(&self) {
let scratches = self.scratches.read();
for scratch_cell in scratches.values() {
unsafe {
(*scratch_cell.get()).reset();
}
}
}
pub fn scratch_size(&self) -> usize {
self.scratch_size
}
pub fn active_threads(&self) -> usize {
self.scratches.read().len()
}
pub fn stats(&self) -> ScratchStats {
let scratches = self.scratches.read();
let mut total_used = 0;
let mut max_used = 0;
for scratch_cell in scratches.values() {
let used = unsafe { (*scratch_cell.get()).used };
total_used += used;
max_used = max_used.max(used);
}
ScratchStats {
scratch_size: self.scratch_size,
active_threads: scratches.len(),
max_threads: self.max_threads,
total_allocated: scratches.len() * self.scratch_size,
total_used,
max_thread_usage: max_used,
}
}
}
#[cfg(not(target_arch = "wasm32"))]
impl std::fmt::Debug for ScratchSpaceManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ScratchSpaceManager")
.field("scratch_size", &self.scratch_size)
.field("max_threads", &self.max_threads)
.field("active_threads", &self.scratches.read().len())
.finish()
}
}
#[cfg(target_arch = "wasm32")]
struct WasmScratch {
data: Box<[u8]>,
used: usize,
}
#[cfg(target_arch = "wasm32")]
impl WasmScratch {
fn new(size: usize) -> Result<Self> {
    // Allocate through `Vec` rather than `alloc_zeroed` + `Box::from_raw`:
    // `Box<[u8]>` frees with alignment 1, so a manually aligned allocation
    // would be deallocated with a mismatched layout (undefined behavior).
    let mut data = Vec::new();
    data.try_reserve_exact(size).map_err(|_| {
        RuvLLMError::OutOfMemory(format!(
            "Failed to allocate scratch buffer of {} bytes",
            size
        ))
    })?;
    data.resize(size, 0u8);
    Ok(Self {
        data: data.into_boxed_slice(),
        used: 0,
    })
}
fn reset(&mut self) {
self.used = 0;
}
}
#[cfg(target_arch = "wasm32")]
pub struct ScratchSpaceManager {
scratch: UnsafeCell<WasmScratch>,
scratch_size: usize,
max_threads: usize,
}
// SAFETY: wasm32 without threads runs single-threaded, so no concurrent
// access to the cell is possible.
#[cfg(target_arch = "wasm32")]
unsafe impl Send for ScratchSpaceManager {}
#[cfg(target_arch = "wasm32")]
unsafe impl Sync for ScratchSpaceManager {}
#[cfg(target_arch = "wasm32")]
impl ScratchSpaceManager {
pub fn new(scratch_size: usize, _max_threads: usize) -> Result<Self> {
    Ok(Self {
        scratch: UnsafeCell::new(WasmScratch::new(scratch_size)?),
        scratch_size,
        // WASM builds are single-threaded; ignore the requested count.
        max_threads: 1,
    })
}
pub fn for_model(hidden_dim: usize, _max_threads: usize) -> Result<Self> {
let scratch_size = hidden_dim * 4 * std::mem::size_of::<f32>();
Self::new(scratch_size, 1)
}
pub fn get_scratch(&self) -> Result<ScratchSpace<'_>> {
Ok(ScratchSpace {
scratch: unsafe { &mut *self.scratch.get() },
})
}
pub fn reset_all(&self) {
unsafe {
(*self.scratch.get()).reset();
}
}
pub fn scratch_size(&self) -> usize {
self.scratch_size
}
pub fn active_threads(&self) -> usize {
1
}
pub fn stats(&self) -> ScratchStats {
let used = unsafe { (*self.scratch.get()).used };
ScratchStats {
scratch_size: self.scratch_size,
active_threads: 1,
max_threads: 1,
total_allocated: self.scratch_size,
total_used: used,
max_thread_usage: used,
}
}
}
#[cfg(target_arch = "wasm32")]
impl std::fmt::Debug for ScratchSpaceManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ScratchSpaceManager")
.field("scratch_size", &self.scratch_size)
.field("max_threads", &self.max_threads)
.field("active_threads", &1)
.finish()
}
}
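/// A borrowed view of one thread's scratch buffer. `get` bump-allocates from
/// the buffer; `reset` rewinds it for reuse within the same step.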
#[cfg(not(target_arch = "wasm32"))]
pub struct ScratchSpace<'a> {
scratch: &'a mut ThreadScratch,
}
#[cfg(not(target_arch = "wasm32"))]
impl<'a> ScratchSpace<'a> {
pub fn get<T: Copy + Default>(&mut self, count: usize) -> Option<&mut [T]> {
    let size = count * std::mem::size_of::<T>();
    let align = std::mem::align_of::<T>().max(DEFAULT_ALIGNMENT);
    // Align the absolute address rather than the offset alone: the backing
    // buffer is only guaranteed the allocator's natural alignment.
    let base = self.scratch.data.as_ptr() as usize;
    let aligned_used = ((base + self.scratch.used + align - 1) & !(align - 1)) - base;
    let new_used = aligned_used + size;
    if new_used > self.scratch.data.len() {
        return None;
    }
    self.scratch.used = new_used;
    unsafe {
        let ptr = self.scratch.data.as_mut_ptr().add(aligned_used) as *mut T;
        std::ptr::write_bytes(ptr, 0, count);
        Some(std::slice::from_raw_parts_mut(ptr, count))
    }
}
pub fn as_bytes(&self) -> &[u8] {
&self.scratch.data
}
pub fn as_bytes_mut(&mut self) -> &mut [u8] {
&mut self.scratch.data
}
pub fn reset(&mut self) {
self.scratch.reset();
}
pub fn used(&self) -> usize {
self.scratch.used
}
pub fn remaining(&self) -> usize {
self.scratch.data.len() - self.scratch.used
}
pub fn capacity(&self) -> usize {
self.scratch.data.len()
}
}
#[cfg(target_arch = "wasm32")]
pub struct ScratchSpace<'a> {
scratch: &'a mut WasmScratch,
}
#[cfg(target_arch = "wasm32")]
impl<'a> ScratchSpace<'a> {
pub fn get<T: Copy + Default>(&mut self, count: usize) -> Option<&mut [T]> {
    let size = count * std::mem::size_of::<T>();
    let align = std::mem::align_of::<T>().max(DEFAULT_ALIGNMENT);
    // Align the absolute address rather than the offset alone: the backing
    // buffer is only guaranteed the allocator's natural alignment.
    let base = self.scratch.data.as_ptr() as usize;
    let aligned_used = ((base + self.scratch.used + align - 1) & !(align - 1)) - base;
    let new_used = aligned_used + size;
    if new_used > self.scratch.data.len() {
        return None;
    }
    self.scratch.used = new_used;
    unsafe {
        let ptr = self.scratch.data.as_mut_ptr().add(aligned_used) as *mut T;
        std::ptr::write_bytes(ptr, 0, count);
        Some(std::slice::from_raw_parts_mut(ptr, count))
    }
}
pub fn as_bytes(&self) -> &[u8] {
&self.scratch.data
}
pub fn as_bytes_mut(&mut self) -> &mut [u8] {
&mut self.scratch.data
}
pub fn reset(&mut self) {
self.scratch.reset();
}
pub fn used(&self) -> usize {
self.scratch.used
}
pub fn remaining(&self) -> usize {
self.scratch.data.len() - self.scratch.used
}
pub fn capacity(&self) -> usize {
self.scratch.data.len()
}
}
#[derive(Debug, Clone, Default)]
pub struct ScratchStats {
pub scratch_size: usize,
pub active_threads: usize,
pub max_threads: usize,
pub total_allocated: usize,
pub total_used: usize,
pub max_thread_usage: usize,
}
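/// Sizing knobs for [`MemoryManager`]. `for_model` derives rough capacities
/// from model dimensions; the defaults suit small models.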
#[derive(Debug, Clone)]
pub struct MemoryManagerConfig {
pub arena_capacity: usize,
pub pool_buffers_per_class: usize,
pub scratch_size: usize,
pub max_threads: usize,
}
impl Default for MemoryManagerConfig {
fn default() -> Self {
Self {
arena_capacity: 16 * 1024 * 1024, // 16 MiB
pool_buffers_per_class: 32,
scratch_size: 64 * 1024, // 64 KiB
max_threads: 16,
}
}
}
impl MemoryManagerConfig {
pub fn for_model(hidden_dim: usize, vocab_size: usize, batch_size: usize) -> Self {
let arena_capacity = {
    let activations = hidden_dim * batch_size * std::mem::size_of::<f32>();
    let logits = vocab_size * batch_size * std::mem::size_of::<f32>();
    // 4x headroom for intermediate activations and double buffering.
    (activations + logits) * 4
};
let scratch_size = hidden_dim * 4 * std::mem::size_of::<f32>();
Self {
arena_capacity,
pool_buffers_per_class: 32,
scratch_size,
max_threads: 16,
}
}
}
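/// Bundles the arena, buffer pool, and scratch manager behind one facade.
///
/// A minimal per-step usage sketch with illustrative model dimensions:
///
/// ```ignore
/// let manager = MemoryManager::for_model(4096, 32000, 1)?;
/// let logits: &mut [f32] = manager.arena.alloc(32000).expect("arena exhausted");
/// // ... decode one token ...
/// manager.reset_step(); // reclaim arena and scratch for the next token
/// ```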
pub struct MemoryManager {
pub arena: InferenceArena,
pub pool: BufferPool,
pub scratch: ScratchSpaceManager,
config: MemoryManagerConfig,
}
impl MemoryManager {
pub fn new() -> Result<Self> {
Self::with_config(MemoryManagerConfig::default())
}
pub fn with_config(config: MemoryManagerConfig) -> Result<Self> {
let arena = InferenceArena::new(config.arena_capacity)?;
let pool = BufferPool::with_capacity(config.pool_buffers_per_class);
let scratch = ScratchSpaceManager::new(config.scratch_size, config.max_threads)?;
Ok(Self {
arena,
pool,
scratch,
config,
})
}
pub fn for_model(hidden_dim: usize, vocab_size: usize, batch_size: usize) -> Result<Self> {
let config = MemoryManagerConfig::for_model(hidden_dim, vocab_size, batch_size);
Self::with_config(config)
}
#[inline]
pub fn reset_step(&self) {
self.arena.reset();
self.scratch.reset_all();
}
pub fn prewarm_pool(&self, count_per_class: usize) -> Result<()> {
self.pool.prewarm_all(count_per_class)
}
pub fn stats(&self) -> MemoryManagerStats {
MemoryManagerStats {
arena: self.arena.stats(),
pool: self.pool.stats(),
scratch: self.scratch.stats(),
}
}
pub fn config(&self) -> &MemoryManagerConfig {
&self.config
}
}
impl std::fmt::Debug for MemoryManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MemoryManager")
.field("config", &self.config)
.field("arena_stats", &self.arena.stats())
.field("pool_stats", &self.pool.stats())
.field("scratch_stats", &self.scratch.stats())
.finish()
}
}
#[derive(Debug, Clone, Default)]
pub struct MemoryManagerStats {
pub arena: ArenaStats,
pub pool: BufferPoolStats,
pub scratch: ScratchStats,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_arena_basic() {
let arena = InferenceArena::new(4096).expect("arena creation failed");
let buf1: &mut [f32] = arena.alloc(100).expect("alloc failed");
assert_eq!(buf1.len(), 100);
let buf2: &mut [f32] = arena.alloc(200).expect("alloc failed");
assert_eq!(buf2.len(), 200);
let stats = arena.stats();
assert_eq!(stats.allocation_count, 2);
assert!(stats.used > 0);
arena.reset();
assert_eq!(arena.used(), 0);
assert_eq!(arena.allocation_count(), 0);
}
#[test]
fn test_arena_alignment() {
let arena = InferenceArena::new(4096).expect("arena creation failed");
let _: &mut [u8] = arena.alloc(1).unwrap();
let buf: &mut [f32] = arena.alloc(10).unwrap();
assert!(buf.as_ptr() as usize % DEFAULT_ALIGNMENT == 0);
}
#[test]
fn test_arena_out_of_memory() {
let arena = InferenceArena::new(1024).expect("arena creation failed");
let result: Option<&mut [f32]> = arena.alloc(1000);
assert!(result.is_none());
}
#[test]
fn test_buffer_pool_basic() {
let pool = BufferPool::new();
let buf1 = pool.acquire(BufferSize::KB4).expect("acquire failed");
assert_eq!(buf1.capacity(), 4096);
drop(buf1);
let buf2 = pool.acquire(BufferSize::KB4).expect("acquire failed");
assert_eq!(buf2.capacity(), 4096);
let stats = pool.stats();
assert!(stats.hits > 0 || stats.misses > 0);
}
#[test]
fn test_buffer_pool_size_classes() {
let pool = BufferPool::new();
for size in BufferSize::all() {
let buf = pool.acquire(size).expect("acquire failed");
assert_eq!(buf.capacity(), size.bytes());
}
}
#[test]
fn test_buffer_pool_typed_access() {
let pool = BufferPool::new();
let mut buf = pool.acquire(BufferSize::KB1).expect("acquire failed");
let floats = buf.as_slice_mut::<f32>();
assert_eq!(floats.len(), 256);
floats[0] = 1.0;
floats[1] = 2.0;
assert_eq!(buf.as_slice::<f32>()[0], 1.0);
}
#[test]
fn test_buffer_pool_prewarm() {
let pool = BufferPool::new();
pool.prewarm(BufferSize::KB4, 5).expect("prewarm failed");
let stats = pool.stats();
assert_eq!(stats.free_buffers[BufferSize::KB4.index()], 5);
}
#[test]
fn test_scratch_space_basic() {
let manager = ScratchSpaceManager::new(4096, 4).expect("manager creation failed");
let mut scratch = manager.get_scratch().expect("get_scratch failed");
let buf1: &mut [f32] = scratch.get(100).expect("alloc failed");
assert_eq!(buf1.len(), 100);
let buf2: &mut [f32] = scratch.get(50).expect("alloc failed");
assert_eq!(buf2.len(), 50);
assert!(scratch.used() > 0);
scratch.reset();
assert_eq!(scratch.used(), 0);
}
#[cfg(not(target_arch = "wasm32"))]
#[test]
fn test_scratch_space_per_thread() {
use std::sync::Arc;
use std::thread;
let manager = Arc::new(ScratchSpaceManager::new(4096, 4).expect("manager creation failed"));
let handles: Vec<_> = (0..4)
.map(|_| {
let manager = Arc::clone(&manager);
thread::spawn(move || {
let mut scratch = manager.get_scratch().expect("get_scratch failed");
let _: &mut [f32] = scratch.get(100).unwrap();
})
})
.collect();
for handle in handles {
handle.join().unwrap();
}
assert_eq!(manager.active_threads(), 4);
}
#[test]
fn test_memory_manager_basic() {
let manager = MemoryManager::new().expect("manager creation failed");
let arena_buf: &mut [f32] = manager.arena.alloc(100).unwrap();
assert_eq!(arena_buf.len(), 100);
let pool_buf = manager
.pool
.acquire(BufferSize::KB4)
.expect("acquire failed");
assert_eq!(pool_buf.capacity(), 4096);
let mut scratch = manager.scratch.get_scratch().expect("get_scratch failed");
let scratch_buf: &mut [f32] = scratch.get(50).unwrap();
assert_eq!(scratch_buf.len(), 50);
manager.reset_step();
assert_eq!(manager.arena.used(), 0);
}
#[test]
fn test_memory_manager_for_model() {
let manager = MemoryManager::for_model(4096, 32000, 1).expect("manager creation failed");
let stats = manager.stats();
assert!(stats.arena.capacity > 0);
}
#[test]
fn test_buffer_size_for_size() {
assert_eq!(BufferSize::for_size(512), Some(BufferSize::KB1));
assert_eq!(BufferSize::for_size(1024), Some(BufferSize::KB1));
assert_eq!(BufferSize::for_size(2000), Some(BufferSize::KB4));
assert_eq!(BufferSize::for_size(4096), Some(BufferSize::KB4));
assert_eq!(BufferSize::for_size(10000), Some(BufferSize::KB16));
assert_eq!(BufferSize::for_size(50000), Some(BufferSize::KB64));
assert_eq!(BufferSize::for_size(200000), Some(BufferSize::KB256));
assert_eq!(BufferSize::for_size(300000), None);
}
}