use crate::error::{Result, ZiporaError};
use crate::memory::{SecureMemoryPool, SecurePoolConfig};
use std::alloc::{Layout, alloc, dealloc};
use std::cell::RefCell;
use std::collections::HashMap;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, Weak};
use std::thread::{self, ThreadId};
/// Default size of each per-thread bump arena (2 MiB).
const DEFAULT_ARENA_SIZE: usize = 2 * 1024 * 1024;
/// Fragmentation delta (bytes) that triggers a sync with the global pool.
const SYNC_THRESHOLD: isize = 256 * 1024;
/// Upper bound on chunks held per size-class free list.
const MAX_CACHED_CHUNKS: usize = 64;
/// Size classes (bytes) served by the thread-local free lists.
const TLS_SIZE_CLASSES: &[usize] = &[
    16, 32, 48, 64, 96, 128, 192, 256, 384, 512, 768, 1024, 1536, 2048, 3072, 4096,
];

/// Configuration for a thread-local memory pool.
#[derive(Debug, Clone)]
pub struct ThreadLocalPoolConfig {
    /// Size in bytes of each per-thread bump arena.
    pub arena_size: usize,
    /// Expected upper bound on concurrently active threads.
    pub max_threads: usize,
    /// Whether to collect allocation statistics.
    pub enable_stats: bool,
    /// Fragmentation threshold (bytes) before syncing with the global pool.
    pub sync_threshold: isize,
    /// Maximum cached chunks per size-class free list.
    pub max_cached_chunks: usize,
    /// Route fallback allocations through the secure global pool.
    pub use_secure_memory: bool,
}

impl Default for ThreadLocalPoolConfig {
    /// Balanced defaults: 2 MiB arenas, stats on, secure memory on.
    fn default() -> Self {
        Self {
            use_secure_memory: true,
            max_cached_chunks: MAX_CACHED_CHUNKS,
            sync_threshold: SYNC_THRESHOLD,
            enable_stats: true,
            max_threads: 256,
            arena_size: DEFAULT_ARENA_SIZE,
        }
    }
}

impl ThreadLocalPoolConfig {
    /// Throughput-oriented preset: large arenas, stats and secure memory off.
    pub fn high_performance() -> Self {
        Self {
            use_secure_memory: false,
            max_cached_chunks: 128,
            sync_threshold: 1024 * 1024,
            enable_stats: false,
            max_threads: 1024,
            arena_size: 8 * 1024 * 1024,
        }
    }

    /// Footprint-oriented preset: small arenas, stats and secure memory on.
    pub fn compact() -> Self {
        Self {
            use_secure_memory: true,
            max_cached_chunks: 32,
            sync_threshold: 64 * 1024,
            enable_stats: true,
            max_threads: 64,
            arena_size: 512 * 1024,
        }
    }
}
/// Counters describing pool behavior.
///
/// All counters use relaxed atomics, so readings are approximate under
/// concurrent updates.
#[derive(Debug, Default)]
pub struct ThreadLocalPoolStats {
    /// Allocations served from a thread-local free list.
    pub cache_hits: AtomicU64,
    /// Allocations that fell through to the global pool.
    pub cache_misses: AtomicU64,
    /// Number of fresh arenas carved out.
    pub arena_allocations: AtomicU64,
    /// Bump allocations served from the hot arena.
    pub hot_allocations: AtomicU64,
    /// Batch synchronizations with the global pool.
    pub batch_syncs: AtomicU64,
    /// Bytes currently held in thread-local caches.
    pub cached_memory: AtomicU64,
}

impl ThreadLocalPoolStats {
    /// Fraction of allocations served from the free-list cache.
    /// Returns 0.0 when no allocations have been recorded.
    pub fn hit_ratio(&self) -> f64 {
        let hits = self.cache_hits.load(Ordering::Relaxed);
        let total = hits + self.cache_misses.load(Ordering::Relaxed);
        if total == 0 {
            0.0
        } else {
            hits as f64 / total as f64
        }
    }

    /// Fraction of arena-backed allocations served by the hot area.
    /// Returns 0.0 when no such allocations have been recorded.
    pub fn locality_score(&self) -> f64 {
        let hot = self.hot_allocations.load(Ordering::Relaxed);
        let total = hot + self.arena_allocations.load(Ordering::Relaxed);
        if total == 0 {
            0.0
        } else {
            hot as f64 / total as f64
        }
    }
}
/// Per-thread allocation cache: a bump-allocated hot arena plus
/// size-classed free lists, with a weak link back to the owning pool.
struct ThreadLocalCache {
    // Id of the owning thread, recorded at construction.
    // NOTE(review): not read anywhere in this module — appears vestigial.
    thread_id: ThreadId,
    // Current bump-allocation arena, if one has been carved out.
    hot_area: Option<HotArea>,
    // One free list per entry in TLS_SIZE_CLASSES, parallel by index.
    free_lists: Vec<Vec<NonNull<u8>>>,
    // Signed running fragmentation estimate: decremented on cached frees,
    // reset to 0 when it crosses the configured sync threshold.
    frag_inc: isize,
    // Weak back-reference to the owning pool (weak to avoid a cycle).
    global_pool: Weak<ThreadLocalMemoryPool>,
    // Shared stats sink; present only when stats are enabled.
    stats: Option<Arc<ThreadLocalPoolStats>>,
}
/// A bump-allocated arena: a single raw allocation consumed front-to-back.
struct HotArea {
    /// Base pointer of the backing allocation.
    start: NonNull<u8>,
    /// Current bump offset from `start`, in bytes.
    pos: usize,
    /// Total capacity in bytes (exclusive end offset).
    end: usize,
    /// Layout used for the backing allocation; needed by `dealloc`.
    layout: Layout,
}

impl HotArea {
    /// Allocates a fresh 8-byte-aligned arena of `size` bytes.
    ///
    /// # Errors
    /// Fails for a zero size (calling `alloc` with a zero-size layout is
    /// undefined behavior), an invalid layout, or allocation failure.
    fn new(size: usize) -> Result<Self> {
        if size == 0 {
            return Err(ZiporaError::invalid_data("arena size must be non-zero"));
        }
        let layout = Layout::from_size_align(size, 8)
            .map_err(|e| ZiporaError::invalid_data(&format!("Invalid layout: {}", e)))?;
        // SAFETY: `layout` has non-zero size (checked above) and valid
        // alignment; a null return is handled via the Option below.
        let start = NonNull::new(unsafe { alloc(layout) })
            .ok_or_else(|| ZiporaError::out_of_memory(size))?;
        Ok(Self {
            start,
            pos: 0,
            end: size,
            layout,
        })
    }

    /// Bump-allocates `size` bytes rounded up to an 8-byte multiple, or
    /// `None` if the arena lacks room.
    ///
    /// Uses checked arithmetic throughout: the previous unchecked form
    /// `(size + 7) & !7` / `pos + aligned_size` could wrap for
    /// pathological sizes and hand out an under-sized region.
    fn try_allocate(&mut self, size: usize) -> Option<NonNull<u8>> {
        let aligned_size = size.checked_add(7)? & !7;
        let new_pos = self.pos.checked_add(aligned_size)?;
        if new_pos <= self.end {
            // SAFETY: `pos <= end <= layout.size()`, so the offset stays
            // inside the original allocation and the result is non-null.
            let ptr = unsafe { NonNull::new_unchecked(self.start.as_ptr().add(self.pos)) };
            self.pos = new_pos;
            Some(ptr)
        } else {
            None
        }
    }

    /// Bytes still available for bump allocation.
    fn remaining(&self) -> usize {
        self.end - self.pos
    }
}

impl Drop for HotArea {
    fn drop(&mut self) {
        // SAFETY: `start` was produced by `alloc(self.layout)` in `new`
        // and is freed exactly once, here.
        unsafe {
            dealloc(self.start.as_ptr(), self.layout);
        }
    }
}
impl ThreadLocalCache {
    /// Creates an empty cache bound to the current thread.
    fn new(global_pool: Weak<ThreadLocalMemoryPool>, stats: Option<Arc<ThreadLocalPoolStats>>) -> Self {
        Self {
            thread_id: thread::current().id(),
            hot_area: None,
            free_lists: vec![Vec::new(); TLS_SIZE_CLASSES.len()],
            frag_inc: 0,
            global_pool,
            stats,
        }
    }

    /// Allocates `size` bytes: size-class free list first, then the hot
    /// bump arena, then a fresh arena or the global pool.
    fn allocate(&mut self, size: usize, config: &ThreadLocalPoolConfig) -> Result<NonNull<u8>> {
        // 1. Size-class free list (fastest path).
        if let Some(list_index) = self.size_to_list_index(size) {
            if let Some(ptr) = self.free_lists[list_index].pop() {
                if let Some(stats) = &self.stats {
                    stats.cache_hits.fetch_add(1, Ordering::Relaxed);
                    stats
                        .cached_memory
                        .fetch_sub(TLS_SIZE_CLASSES[list_index] as u64, Ordering::Relaxed);
                }
                return Ok(ptr);
            }
        }
        // Round class-able requests up to their class size before bump
        // allocation, so a chunk later cached under this class always has
        // enough capacity for ANY request in the class. (Previously a
        // 17-byte request got a 24-byte chunk cached under the 32-byte
        // class — reuse could overflow it.)
        let bump_size = self
            .size_to_list_index(size)
            .map_or(size, |i| TLS_SIZE_CLASSES[i]);
        // 2. Hot bump arena.
        if let Some(hot_area) = self.hot_area.as_mut() {
            if let Some(ptr) = hot_area.try_allocate(bump_size) {
                if let Some(stats) = &self.stats {
                    stats.hot_allocations.fetch_add(1, Ordering::Relaxed);
                }
                return Ok(ptr);
            }
        }
        // 3. Fresh arena, or global fallback.
        self.allocate_new_area_or_fallback(size, bump_size, config)
    }

    /// Returns `size` bytes at `ptr` to the cache, preferring the matching
    /// size-class free list; falls through to the global path when the
    /// list is full or no class matches.
    fn deallocate(&mut self, ptr: NonNull<u8>, size: usize, config: &ThreadLocalPoolConfig) -> Result<()> {
        if let Some(list_index) = self.size_to_list_index(size) {
            let free_list = &mut self.free_lists[list_index];
            if free_list.len() < config.max_cached_chunks {
                // NOTE(review): hot-arena chunks are class-sized by
                // construction (see `allocate`), but chunks that came from
                // the global fallback may be smaller than this class —
                // verify that fallback chunks cannot reach this path
                // before relying on cached reuse.
                free_list.push(ptr);
                if let Some(stats) = &self.stats {
                    stats
                        .cached_memory
                        .fetch_add(TLS_SIZE_CLASSES[list_index] as u64, Ordering::Relaxed);
                }
                self.frag_inc -= size as isize;
                if self.frag_inc < -config.sync_threshold {
                    self.sync_with_global()?;
                }
                return Ok(());
            }
        }
        self.deallocate_to_global(ptr, size)
    }

    /// Carves out a fresh hot arena and serves `bump_size` from it, or
    /// falls back to the global pool for oversized requests or arena
    /// failure. `bump_size` is `size` rounded up to its size class (or
    /// `size` itself when no class applies).
    fn allocate_new_area_or_fallback(
        &mut self,
        size: usize,
        bump_size: usize,
        config: &ThreadLocalPoolConfig,
    ) -> Result<NonNull<u8>> {
        // Oversized requests would waste most of an arena; go global.
        if size > config.arena_size / 4 {
            return self.allocate_from_global(size);
        }
        match HotArea::new(config.arena_size) {
            Ok(mut hot_area) => {
                if let Some(ptr) = hot_area.try_allocate(bump_size) {
                    // The retired arena may still back live allocations and
                    // free-list entries, so it must NOT be dropped (its Drop
                    // frees the whole arena → use-after-free). Leak it
                    // instead: a bounded leak is recoverable, a dangling
                    // pointer is not.
                    if let Some(old_area) = self.hot_area.take() {
                        std::mem::forget(old_area);
                    }
                    self.hot_area = Some(hot_area);
                    if let Some(stats) = &self.stats {
                        stats.arena_allocations.fetch_add(1, Ordering::Relaxed);
                        stats.hot_allocations.fetch_add(1, Ordering::Relaxed);
                    }
                    return Ok(ptr);
                }
                self.allocate_from_global(size)
            }
            // Arena allocation failed (likely OOM); try the global pool.
            Err(_) => self.allocate_from_global(size),
        }
    }

    /// Falls back to the owning pool's cache-bypassing allocator.
    fn allocate_from_global(&self, size: usize) -> Result<NonNull<u8>> {
        if let Some(stats) = &self.stats {
            stats.cache_misses.fetch_add(1, Ordering::Relaxed);
        }
        match self.global_pool.upgrade() {
            // Must use the bypass path here. Calling
            // `ThreadLocalMemoryPool::allocate` would (a) re-borrow the
            // CURRENT_CACHE RefCell that our caller already holds mutably,
            // panicking at runtime, and (b) return an RAII guard whose
            // drop frees the memory, leaving the extracted pointer
            // dangling.
            Some(pool) => pool.allocate_bypass_cache(size),
            None => Err(ZiporaError::invalid_data("Global pool unavailable")),
        }
    }

    /// Hands memory back to the global pool. Currently a stub: returning
    /// memory to the global pool is unimplemented, so this only logs and
    /// leaks the chunk.
    fn deallocate_to_global(&self, _ptr: NonNull<u8>, _size: usize) -> Result<()> {
        if self.global_pool.upgrade().is_some() {
            log::warn!("Bypassing secure pool deallocation - potential leak");
        } else {
            log::warn!("Global pool unavailable during deallocation");
        }
        Ok(())
    }

    /// Records a batch synchronization and resets the fragmentation
    /// counter. Actual chunk migration to the global pool is not yet
    /// implemented.
    fn sync_with_global(&mut self) -> Result<()> {
        if let Some(stats) = &self.stats {
            stats.batch_syncs.fetch_add(1, Ordering::Relaxed);
        }
        self.frag_inc = 0;
        Ok(())
    }

    /// Maps a request size to the smallest size class that fits it, or
    /// `None` when the request exceeds the largest class.
    fn size_to_list_index(&self, size: usize) -> Option<usize> {
        TLS_SIZE_CLASSES.iter().position(|&class_size| size <= class_size)
    }
}
/// Thread-local memory pool front-end.
///
/// Holds the shared configuration, the optional secure global pool used
/// as a fallback, and the shared statistics sink. Per-thread caches live
/// in the `CURRENT_CACHE` thread-local, not in this struct.
pub struct ThreadLocalMemoryPool {
    // Pool configuration, fixed at construction.
    config: ThreadLocalPoolConfig,
    // Secure fallback pool; `None` when `use_secure_memory` is false.
    global_pool: Option<Arc<SecureMemoryPool>>,
    // NOTE(review): this registry is initialized but never populated or
    // read in this module — caches are stored in `CURRENT_CACHE` instead.
    // Appears vestigial; confirm before removing.
    thread_caches: Mutex<HashMap<ThreadId, RefCell<ThreadLocalCache>>>,
    // Shared statistics; present only when `enable_stats` is set.
    stats: Option<Arc<ThreadLocalPoolStats>>,
}
thread_local! {
    /// Per-thread allocation cache, created lazily on the thread's first
    /// allocation through a `ThreadLocalMemoryPool`.
    static CURRENT_CACHE: RefCell<Option<ThreadLocalCache>> = RefCell::new(None);
}
impl ThreadLocalMemoryPool {
    /// Creates a new pool behind an `Arc` (required so per-thread caches
    /// can hold a `Weak` back-reference to it).
    ///
    /// # Errors
    /// Fails if the secure global pool cannot be created.
    pub fn new(config: ThreadLocalPoolConfig) -> Result<Arc<Self>> {
        let global_pool = if config.use_secure_memory {
            let secure_config = SecurePoolConfig::medium_secure();
            Some(SecureMemoryPool::new(secure_config)?)
        } else {
            None
        };
        let stats = if config.enable_stats {
            Some(Arc::new(ThreadLocalPoolStats::default()))
        } else {
            None
        };
        Ok(Arc::new(Self {
            config,
            global_pool,
            thread_caches: Mutex::new(HashMap::new()),
            stats,
        }))
    }

    /// Allocates `size` bytes via the calling thread's cache, creating
    /// the cache lazily on first use.
    ///
    /// # Errors
    /// Rejects zero-sized requests and propagates allocation failures.
    pub fn allocate(self: &Arc<Self>, size: usize) -> Result<ThreadLocalAllocation> {
        if size == 0 {
            return Err(ZiporaError::invalid_data("Cannot allocate zero bytes"));
        }
        let ptr = CURRENT_CACHE.with(|cache_cell| {
            // The RefCell stays mutably borrowed for the whole allocation:
            // nothing reached from `cache.allocate` may touch
            // CURRENT_CACHE again or it will panic with a BorrowMutError.
            let mut cache_opt = cache_cell.borrow_mut();
            if cache_opt.is_none() {
                let weak_self = Arc::downgrade(self);
                *cache_opt = Some(ThreadLocalCache::new(weak_self, self.stats.clone()));
            }
            if let Some(ref mut cache) = *cache_opt {
                cache.allocate(size, &self.config)
            } else {
                Err(ZiporaError::invalid_data("Failed to initialize thread cache"))
            }
        })?;
        Ok(ThreadLocalAllocation::new(ptr, size, Arc::clone(self)))
    }

    /// Returns `size` bytes at `ptr` to the calling thread's cache;
    /// invoked from `ThreadLocalAllocation::drop`.
    ///
    /// NOTE(review): if an allocation is dropped on a different thread
    /// than it was allocated on, the pointer is cached in THAT thread's
    /// free lists while the backing arena belongs to the original
    /// thread's cache — confirm allocations never cross threads.
    fn deallocate(&self, ptr: NonNull<u8>, size: usize) -> Result<()> {
        CURRENT_CACHE.with(|cache_cell| {
            let mut cache_opt = cache_cell.borrow_mut();
            if let Some(ref mut cache) = *cache_opt {
                cache.deallocate(ptr, size, &self.config)
            } else {
                // No cache on this thread (never allocated, or cleared).
                self.deallocate_bypass_cache(ptr, size)
            }
        })
    }

    /// Allocates directly from the secure global pool (or the system
    /// allocator) without touching the thread-local cache, so it is safe
    /// to call while CURRENT_CACHE is borrowed.
    ///
    /// NOTE(review): `SecureMemoryPool::allocate` takes no size argument —
    /// presumably it returns a fixed-size chunk; confirm the chunk holds
    /// at least `size` bytes. Also confirm `secure_ptr` does not free its
    /// memory when dropped at the end of this scope, otherwise the
    /// returned raw pointer dangles.
    fn allocate_bypass_cache(&self, size: usize) -> Result<NonNull<u8>> {
        if let Some(ref global_pool) = self.global_pool {
            let secure_ptr = global_pool.allocate()?;
            NonNull::new(secure_ptr.as_ptr())
                .ok_or_else(|| ZiporaError::out_of_memory(size))
        } else {
            let layout = Layout::from_size_align(size, 8)
                .map_err(|e| ZiporaError::invalid_data(&format!("Invalid layout: {}", e)))?;
            // SAFETY: layout has non-zero size (public `allocate` rejects
            // size == 0 before any path can reach here).
            NonNull::new(unsafe { alloc(layout) })
                .ok_or_else(|| ZiporaError::out_of_memory(size))
        }
    }

    /// Frees `ptr` when no thread cache exists. With a secure global pool
    /// the memory is intentionally leaked (logged); otherwise it is
    /// returned to the system allocator.
    fn deallocate_bypass_cache(&self, ptr: NonNull<u8>, size: usize) -> Result<()> {
        if self.global_pool.is_some() {
            log::warn!("Bypassing cache deallocation - potential leak");
        } else {
            let layout = Layout::from_size_align(size, 8)
                .map_err(|e| ZiporaError::invalid_data(&format!("Invalid layout: {}", e)))?;
            // SAFETY: assumes `ptr` came from `alloc` with this exact
            // layout. NOTE(review): pointers that originated inside a bump
            // arena must never reach this branch — verify.
            unsafe {
                dealloc(ptr.as_ptr(), layout);
            }
        }
        Ok(())
    }

    /// Shared statistics handle, if stats are enabled.
    pub fn stats(&self) -> Option<Arc<ThreadLocalPoolStats>> {
        self.stats.clone()
    }

    /// Approximate bytes held in thread-local caches; 0 when stats are
    /// disabled. NOTE(review): confirm `cached_memory` is actually
    /// maintained by the cache paths — otherwise this always reports 0.
    #[inline]
    pub fn memory_usage(&self) -> usize {
        if let Some(stats) = &self.stats {
            stats.cached_memory.load(Ordering::Relaxed) as usize
        } else {
            0
        }
    }

    /// Drops the calling thread's cache.
    ///
    /// NOTE(review): dropping the cache frees its hot arena; any live
    /// `ThreadLocalAllocation` backed by that arena would dangle —
    /// confirm callers only invoke this with no outstanding allocations.
    pub fn clear_caches(&self) {
        CURRENT_CACHE.with(|cache_cell| {
            *cache_cell.borrow_mut() = None;
        });
    }
}
/// RAII handle for memory obtained from a `ThreadLocalMemoryPool`.
///
/// The backing memory is handed back to the pool when the handle drops.
pub struct ThreadLocalAllocation {
    ptr: NonNull<u8>,
    size: usize,
    pool: Arc<ThreadLocalMemoryPool>,
}

impl ThreadLocalAllocation {
    /// Wraps a pool-owned pointer spanning `size` bytes.
    fn new(ptr: NonNull<u8>, size: usize, pool: Arc<ThreadLocalMemoryPool>) -> Self {
        Self { ptr, size, pool }
    }

    /// Raw pointer to the start of the allocation.
    #[inline]
    pub fn as_ptr(&self) -> *mut u8 {
        self.ptr.as_ptr()
    }

    /// Size of the allocation in bytes.
    #[inline]
    pub fn size(&self) -> usize {
        self.size
    }

    /// Exclusive byte view of the allocation.
    #[inline]
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        // SAFETY: `ptr` is non-null and owned by this handle for `size`
        // bytes; `&mut self` guarantees exclusive access for the lifetime
        // of the returned slice.
        unsafe { std::slice::from_raw_parts_mut(self.as_ptr(), self.size) }
    }

    /// Shared byte view of the allocation.
    #[inline]
    pub fn as_slice(&self) -> &[u8] {
        // SAFETY: `ptr` is non-null and owned by this handle for `size`
        // bytes.
        unsafe { std::slice::from_raw_parts(self.as_ptr(), self.size) }
    }
}

impl Drop for ThreadLocalAllocation {
    fn drop(&mut self) {
        // Return the memory to the owning pool. Errors cannot propagate
        // out of `drop`, so a failure is only logged.
        if let Err(e) = self.pool.deallocate(self.ptr, self.size) {
            log::error!("Failed to deallocate thread-local memory: {}", e);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_threadlocal_pool_creation() {
        let pool = ThreadLocalMemoryPool::new(ThreadLocalPoolConfig::default()).unwrap();
        assert!(pool.stats.is_some());
    }

    #[test]
    fn test_basic_allocation_deallocation() {
        let pool = ThreadLocalMemoryPool::new(ThreadLocalPoolConfig::default()).unwrap();
        let allocation = pool.allocate(64).unwrap();
        assert_eq!(allocation.size(), 64);
        assert!(!allocation.as_ptr().is_null());
    }

    #[test]
    fn test_thread_local_caching() {
        let pool = ThreadLocalMemoryPool::new(ThreadLocalPoolConfig::default()).unwrap();
        // Allocate and drop a few chunks so later allocations can hit the
        // free-list cache.
        {
            let _a = pool.allocate(64).unwrap();
            let _b = pool.allocate(128).unwrap();
            let _c = pool.allocate(64).unwrap();
        }
        if let Some(stats) = pool.stats() {
            println!("Cache hits: {}", stats.cache_hits.load(Ordering::Relaxed));
        }
    }

    #[test]
    fn test_hot_area_allocation() {
        let config = ThreadLocalPoolConfig {
            arena_size: 4096,
            ..ThreadLocalPoolConfig::default()
        };
        let pool = ThreadLocalMemoryPool::new(config).unwrap();
        let _allocations: Vec<_> = (0..10).map(|i| pool.allocate(32 + i).unwrap()).collect();
        if let Some(stats) = pool.stats() {
            let hot_allocs = stats.hot_allocations.load(Ordering::Relaxed);
            let arena_allocs = stats.arena_allocations.load(Ordering::Relaxed);
            println!("Hot allocations: {}, Arena allocations: {}", hot_allocs, arena_allocs);
            assert!(hot_allocs > 0);
        }
    }

    #[test]
    fn test_concurrent_thread_local_allocation() {
        let pool = ThreadLocalMemoryPool::new(ThreadLocalPoolConfig::high_performance()).unwrap();
        let _allocations: Vec<_> = (0..10).map(|i| pool.allocate(64 + i).unwrap()).collect();
        if let Some(stats) = pool.stats() {
            let hit_ratio = stats.hit_ratio();
            let locality = stats.locality_score();
            println!("Hit ratio: {:.2}, Locality score: {:.2}", hit_ratio, locality);
        }
    }

    #[test]
    fn test_size_class_mapping() {
        let cache = ThreadLocalCache::new(Weak::new(), None);
        assert_eq!(cache.size_to_list_index(8), Some(0));
        assert_eq!(cache.size_to_list_index(16), Some(0));
        assert_eq!(cache.size_to_list_index(17), Some(1));
        assert_eq!(cache.size_to_list_index(64), Some(3));
        assert_eq!(cache.size_to_list_index(5000), None);
    }

    #[test]
    fn test_cache_overflow() {
        let config = ThreadLocalPoolConfig {
            max_cached_chunks: 2,
            ..ThreadLocalPoolConfig::default()
        };
        let pool = ThreadLocalMemoryPool::new(config).unwrap();
        // Allocate/drop more chunks than the cache may hold; the overflow
        // must be handled without panicking.
        for _ in 0..5 {
            drop(pool.allocate(64).unwrap());
        }
    }

    #[test]
    fn test_different_configurations() {
        let hp_pool = ThreadLocalMemoryPool::new(ThreadLocalPoolConfig::high_performance()).unwrap();
        assert!(hp_pool.stats.is_none());

        let compact_pool = ThreadLocalMemoryPool::new(ThreadLocalPoolConfig::compact()).unwrap();
        assert!(compact_pool.stats.is_some());
        assert_eq!(compact_pool.config.arena_size, 512 * 1024);
    }
}