use crate::error::{Result, ZiporaError};
use std::alloc::{Layout, alloc, dealloc};
use std::collections::VecDeque;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, RwLock};
/// Configuration for a [`MemoryPool`]: fixed chunk size, a cap on how many
/// freed chunks are retained for reuse, and the alignment of every chunk.
#[derive(Debug, Clone)]
pub struct PoolConfig {
    /// Size in bytes of every chunk handed out by the pool.
    pub chunk_size: usize,
    /// Maximum number of freed chunks kept for reuse; beyond this,
    /// deallocated chunks go straight back to the system allocator.
    pub max_chunks: usize,
    /// Alignment of each chunk; must be a non-zero power of two
    /// (validated in `MemoryPool::new`, not here).
    pub alignment: usize,
}
impl PoolConfig {
pub fn new(chunk_size: usize, max_chunks: usize, alignment: usize) -> Self {
Self {
chunk_size,
max_chunks,
alignment,
}
}
pub fn small() -> Self {
Self::new(1024, 100, 8)
}
pub fn medium() -> Self {
Self::new(64 * 1024, 50, 16)
}
pub fn large() -> Self {
Self::new(1024 * 1024, 10, 32)
}
}
/// Point-in-time snapshot of a pool's counters.
///
/// Counter updates are best-effort (the pool uses `try_write` on its stats
/// lock), so values may slightly lag the true state under contention.
///
/// The manual `Default` impl this replaces set every field to zero, which is
/// exactly what `#[derive(Default)]` produces for integer fields.
#[derive(Debug, Clone, Default)]
pub struct PoolStats {
    /// Bytes obtained from the system allocator and not yet returned to it
    /// (includes chunks currently cached in the free list).
    pub allocated: u64,
    /// Bytes sitting in the free list, ready for reuse.
    pub available: u64,
    /// Number of chunks in the free list.
    pub chunks: usize,
    /// Total `allocate` calls.
    pub alloc_count: u64,
    /// Total `deallocate` calls.
    pub dealloc_count: u64,
    /// Allocations satisfied from the free list.
    pub pool_hits: u64,
    /// Allocations that fell through to the system allocator.
    pub pool_misses: u64,
}
/// A fixed-chunk memory pool: `allocate` hands out chunks of
/// `config.chunk_size` bytes, and `deallocate` recycles them into a free
/// list capped at `config.max_chunks`.
pub struct MemoryPool {
    /// Immutable configuration fixed at construction.
    config: PoolConfig,
    /// Chunks returned to the pool and available for reuse. Raw pointers;
    /// contents are whatever the previous user left behind.
    free_chunks: Mutex<VecDeque<*mut u8>>,
    /// Advisory byte/chunk accounting; updated best-effort via `try_write`.
    stats: RwLock<PoolStats>,
    /// Total `allocate` calls.
    alloc_count: AtomicU64,
    /// Total `deallocate` calls.
    dealloc_count: AtomicU64,
    /// Allocations satisfied from `free_chunks`.
    pool_hits: AtomicU64,
    /// Allocations that fell through to the system allocator.
    pool_misses: AtomicU64,
}
// SAFETY: the struct is not auto-Send/Sync only because of the raw pointers
// in `free_chunks`. Those pointers are accessed exclusively behind the
// `Mutex`, and every other field is an atomic or lock-protected, so moving
// the pool across threads and sharing `&MemoryPool` is sound.
unsafe impl Send for MemoryPool {}
unsafe impl Sync for MemoryPool {}
impl MemoryPool {
    /// Creates a pool with the given configuration.
    ///
    /// # Errors
    /// Fails if `chunk_size` is zero or `alignment` is not a non-zero power
    /// of two (both would make `Layout::from_size_align` invalid).
    pub fn new(config: PoolConfig) -> Result<Self> {
        if config.chunk_size == 0 {
            return Err(ZiporaError::invalid_data("chunk_size cannot be zero"));
        }
        if config.alignment == 0 || !config.alignment.is_power_of_two() {
            return Err(ZiporaError::invalid_data(
                "alignment must be a power of two",
            ));
        }
        Ok(Self {
            config,
            free_chunks: Mutex::new(VecDeque::new()),
            stats: RwLock::new(PoolStats::default()),
            alloc_count: AtomicU64::new(0),
            dealloc_count: AtomicU64::new(0),
            pool_hits: AtomicU64::new(0),
            pool_misses: AtomicU64::new(0),
        })
    }

    /// Returns a chunk of `config.chunk_size` bytes, reusing a pooled chunk
    /// when one is available. The returned memory is uninitialized (or holds
    /// a previous user's bytes when recycled).
    ///
    /// # Errors
    /// Fails only when a fresh allocation is needed and the system allocator
    /// returns null.
    pub fn allocate(&self) -> Result<NonNull<u8>> {
        self.alloc_count.fetch_add(1, Ordering::Relaxed);
        // try_lock: under contention we prefer a fresh allocation over
        // blocking the caller; the miss counter records that choice.
        if let Ok(mut free_chunks) = self.free_chunks.try_lock() {
            if let Some(chunk) = free_chunks.pop_front() {
                self.pool_hits.fetch_add(1, Ordering::Relaxed);
                // SAFETY: only non-null pointers (from `alloc` success paths)
                // are ever pushed into `free_chunks`.
                return Ok(unsafe { NonNull::new_unchecked(chunk) });
            }
        }
        self.pool_misses.fetch_add(1, Ordering::Relaxed);
        self.allocate_new_chunk()
    }

    /// Returns a chunk to the pool, or frees it outright when the free list
    /// is full (or momentarily contended).
    ///
    /// NOTE(review): `chunk` must have been produced by `allocate` on this
    /// pool (same layout); this safe API cannot enforce that — confirm all
    /// call sites uphold it.
    pub fn deallocate(&self, chunk: NonNull<u8>) -> Result<()> {
        self.dealloc_count.fetch_add(1, Ordering::Relaxed);
        if let Ok(mut free_chunks) = self.free_chunks.try_lock() {
            if free_chunks.len() < self.config.max_chunks {
                free_chunks.push_back(chunk.as_ptr());
                return Ok(());
            }
        }
        // Free list full or lock contended: hand the memory back to the
        // system allocator and drop it from the byte accounting.
        self.deallocate_chunk(chunk);
        self.sub_allocated_bytes();
        Ok(())
    }

    /// Returns a best-effort snapshot of the pool's statistics. Counter
    /// fields come from atomics; free-list figures are refreshed only if the
    /// lock is immediately available.
    pub fn stats(&self) -> PoolStats {
        let mut stats = self.stats.read()
            .map(|s| s.clone())
            .unwrap_or_default();
        stats.alloc_count = self.alloc_count.load(Ordering::Relaxed);
        stats.dealloc_count = self.dealloc_count.load(Ordering::Relaxed);
        stats.pool_hits = self.pool_hits.load(Ordering::Relaxed);
        stats.pool_misses = self.pool_misses.load(Ordering::Relaxed);
        if let Ok(free_chunks) = self.free_chunks.try_lock() {
            stats.chunks = free_chunks.len();
            stats.available = (free_chunks.len() * self.config.chunk_size) as u64;
        }
        stats
    }

    /// Frees every chunk cached in the free list and zeroes the free-list
    /// statistics.
    ///
    /// # Errors
    /// Fails if either lock is poisoned; in that case the cached chunks are
    /// left untouched (leaked rather than double-freed).
    pub fn clear(&self) -> Result<()> {
        let mut free_chunks = self.free_chunks.lock()
            .map_err(|e| ZiporaError::resource_busy(format!("Free chunks mutex poisoned: {}", e)))?;
        let freed_chunks = free_chunks.len() as u64;
        while let Some(chunk_ptr) = free_chunks.pop_front() {
            // SAFETY: every pointer in `free_chunks` is non-null (see
            // `allocate`) and was allocated with this pool's layout.
            let chunk = unsafe { NonNull::new_unchecked(chunk_ptr) };
            self.deallocate_chunk(chunk);
        }
        let mut stats = self.stats.write()
            .map_err(|e| ZiporaError::resource_busy(format!("Stats RwLock poisoned: {}", e)))?;
        stats.chunks = 0;
        stats.available = 0;
        // Bug fix: the freed chunks were still counted in `allocated`, so
        // repeated clear() calls made the counter drift upward permanently.
        stats.allocated = stats
            .allocated
            .saturating_sub(freed_chunks * self.config.chunk_size as u64);
        Ok(())
    }

    /// Returns the pool's configuration.
    pub fn config(&self) -> &PoolConfig {
        &self.config
    }

    /// Allocates a fresh chunk from the system allocator and records it in
    /// the byte accounting.
    fn allocate_new_chunk(&self) -> Result<NonNull<u8>> {
        let layout = Layout::from_size_align(self.config.chunk_size, self.config.alignment)
            .map_err(|_| ZiporaError::invalid_data("invalid layout for chunk allocation"))?;
        // SAFETY: `layout` has a non-zero size (validated in `new`).
        let ptr = unsafe { alloc(layout) };
        if ptr.is_null() {
            return Err(ZiporaError::out_of_memory(self.config.chunk_size));
        }
        self.add_allocated_bytes();
        // SAFETY: null was rejected just above.
        Ok(unsafe { NonNull::new_unchecked(ptr) })
    }

    /// Returns one chunk to the system allocator.
    fn deallocate_chunk(&self, chunk: NonNull<u8>) {
        let layout = Layout::from_size_align(self.config.chunk_size, self.config.alignment)
            .expect("Layout invariant violated: config was validated during chunk allocation");
        // SAFETY: the chunk was allocated with exactly this layout.
        unsafe {
            dealloc(chunk.as_ptr(), layout);
        }
    }

    /// Adds one chunk's worth of bytes to `stats.allocated`. Best-effort:
    /// skipped (rather than blocking) if the stats lock is contended.
    fn add_allocated_bytes(&self) {
        if let Ok(mut stats) = self.stats.try_write() {
            stats.allocated += self.config.chunk_size as u64;
        }
    }

    /// Subtracts one chunk's worth of bytes from `stats.allocated`,
    /// saturating at zero. Best-effort, like `add_allocated_bytes`.
    fn sub_allocated_bytes(&self) {
        if let Ok(mut stats) = self.stats.try_write() {
            stats.allocated = stats
                .allocated
                .saturating_sub(self.config.chunk_size as u64);
        }
    }
}
impl Drop for MemoryPool {
    /// Frees every chunk still cached in the free list.
    fn drop(&mut self) {
        if self.clear().is_err() {
            // A poisoned lock makes `clear` fail; leak the cached chunks
            // rather than risk touching state another thread corrupted.
        }
    }
}
/// Process-wide pools, created lazily on first access.
/// (Passing the function directly avoids the redundant `|| ...` closure —
/// clippy::redundant_closure.)
static GLOBAL_POOLS: once_cell::sync::Lazy<GlobalPools> =
    once_cell::sync::Lazy::new(GlobalPools::new);
/// The three global size-class pools; requests are routed to the smallest
/// class whose chunk can hold them (see `get_pool_for_size`).
struct GlobalPools {
    /// Backs requests up to the small preset's chunk size (1 KiB).
    small_pool: Arc<MemoryPool>,
    /// Backs requests up to the medium preset's chunk size (64 KiB).
    medium_pool: Arc<MemoryPool>,
    /// Backs everything larger; chunk size is the large preset's (1 MiB).
    large_pool: Arc<MemoryPool>,
}
impl GlobalPools {
    /// Builds the three size-class pools from the presets.
    ///
    /// # Panics
    /// Only if a preset configuration is invalid, which would be a
    /// programming error in the presets themselves.
    fn new() -> Self {
        Self {
            small_pool: Arc::new(MemoryPool::new(PoolConfig::small()).expect("small pool creation")),
            medium_pool: Arc::new(MemoryPool::new(PoolConfig::medium()).expect("medium pool creation")),
            large_pool: Arc::new(MemoryPool::new(PoolConfig::large()).expect("large pool creation")),
        }
    }

    /// Routes a request of `size` bytes to the smallest pool whose chunks
    /// can hold it.
    ///
    /// The thresholds are read from the pools' own configurations instead of
    /// repeating the preset chunk sizes as magic numbers, so they can no
    /// longer fall out of sync with `PoolConfig::small`/`medium`.
    fn get_pool_for_size(&self, size: usize) -> &Arc<MemoryPool> {
        if size <= self.small_pool.config().chunk_size {
            &self.small_pool
        } else if size <= self.medium_pool.config().chunk_size {
            &self.medium_pool
        } else {
            &self.large_pool
        }
    }
}
/// Validates parameters for global-pool initialization.
///
/// The pools themselves are created lazily on first use; this function only
/// checks the arguments and records them in the debug log.
///
/// # Errors
/// Fails when either argument is zero.
pub fn init_global_pools(chunk_size: usize, max_memory: usize) -> Result<()> {
    match (chunk_size, max_memory) {
        (0, _) => Err(ZiporaError::invalid_data("chunk_size cannot be zero")),
        (_, 0) => Err(ZiporaError::invalid_data("max_memory cannot be zero")),
        (chunk, limit) => {
            log::debug!(
                "Global pools initialized with chunk_size={}, max_memory={}",
                chunk,
                limit
            );
            Ok(())
        }
    }
}
pub fn get_global_pool_stats() -> PoolStats {
let small_stats = GLOBAL_POOLS.small_pool.stats();
let medium_stats = GLOBAL_POOLS.medium_pool.stats();
let large_stats = GLOBAL_POOLS.large_pool.stats();
PoolStats {
allocated: small_stats.allocated + medium_stats.allocated + large_stats.allocated,
available: small_stats.available + medium_stats.available + large_stats.available,
chunks: small_stats.chunks + medium_stats.chunks + large_stats.chunks,
alloc_count: small_stats.alloc_count + medium_stats.alloc_count + large_stats.alloc_count,
dealloc_count: small_stats.dealloc_count
+ medium_stats.dealloc_count
+ large_stats.dealloc_count,
pool_hits: small_stats.pool_hits + medium_stats.pool_hits + large_stats.pool_hits,
pool_misses: small_stats.pool_misses + medium_stats.pool_misses + large_stats.pool_misses,
}
}
/// A fixed-capacity vector whose storage is a single chunk borrowed from
/// the global pools; capacity is however many `T`s fit in one chunk, and
/// `push` fails once it is reached.
pub struct PooledVec<T> {
    /// Start of the chunk, viewed as `T`s; slots `0..len` are initialized.
    ptr: NonNull<T>,
    /// Number of initialized elements.
    len: usize,
    /// Maximum number of elements the chunk can hold.
    capacity: usize,
    /// Pool that owns the chunk; the chunk is returned to it on drop.
    pool: Arc<MemoryPool>,
}
impl<T> PooledVec<T> {
    /// Creates an empty vector backed by one chunk from the global pools.
    ///
    /// # Errors
    /// Fails for zero-sized element types (the capacity computation would
    /// divide by zero — previously a panic), for element types whose
    /// alignment exceeds the backing pool's chunk alignment (storing them
    /// in the chunk would be UB — previously unchecked), or when the pool
    /// cannot allocate.
    pub fn new() -> Result<Self> {
        let element_size = std::mem::size_of::<T>();
        if element_size == 0 {
            return Err(ZiporaError::invalid_data(
                "PooledVec does not support zero-sized element types",
            ));
        }
        let pool = GLOBAL_POOLS.get_pool_for_size(element_size).clone();
        // Chunks are only guaranteed to be aligned to the pool's configured
        // alignment; reject element types that need more. Checked before
        // allocating so no chunk is leaked on the error path.
        if std::mem::align_of::<T>() > pool.config().alignment {
            return Err(ZiporaError::invalid_data(
                "element alignment exceeds pool chunk alignment",
            ));
        }
        let chunk = pool.allocate()?;
        let capacity = pool.config().chunk_size / element_size;
        Ok(Self {
            ptr: chunk.cast(),
            len: 0,
            capacity,
            pool,
        })
    }

    /// Appends an element.
    ///
    /// # Errors
    /// Fails when the fixed chunk capacity is exhausted.
    #[inline]
    pub fn push(&mut self, item: T) -> Result<()> {
        if self.len >= self.capacity {
            return Err(ZiporaError::invalid_data("vector capacity exceeded"));
        }
        // SAFETY: len < capacity, so the slot is inside the chunk; `write`
        // does not drop the (uninitialized) previous contents.
        unsafe {
            self.ptr.as_ptr().add(self.len).write(item);
        }
        self.len += 1;
        Ok(())
    }

    /// Number of initialized elements.
    #[inline]
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when no elements have been pushed.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Maximum number of elements this vector can ever hold.
    #[inline]
    pub fn capacity(&self) -> usize {
        self.capacity
    }

    /// Views the initialized elements as a slice.
    #[inline]
    pub fn as_slice(&self) -> &[T] {
        // SAFETY: slots 0..len were initialized by `push` and the pointer is
        // valid for the chunk's lifetime (owned until drop).
        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
    }
}
impl<T> Drop for PooledVec<T> {
    /// Runs destructors for the initialized prefix in order, then hands the
    /// chunk back to its pool.
    fn drop(&mut self) {
        for idx in 0..self.len {
            // SAFETY: `push` initialized every slot in 0..len exactly once.
            unsafe { std::ptr::drop_in_place(self.ptr.as_ptr().add(idx)) };
        }
        // Best-effort return; the pool may free the chunk if its list is full.
        let _ = self.pool.deallocate(self.ptr.cast());
    }
}
/// A fixed-length byte buffer backed by a single chunk from the global
/// pools; its length is set at construction and never changes.
pub struct PooledBuffer {
    /// Start of the backing chunk.
    ptr: NonNull<u8>,
    /// Logical length in bytes exposed through the slice views.
    len: usize,
    // Chunk size of the backing pool; kept for bookkeeping but not read.
    #[allow(dead_code)]
    capacity: usize,
    /// Pool that owns the chunk; the chunk is returned to it on drop.
    pool: Arc<MemoryPool>,
}
impl PooledBuffer {
pub fn new(size: usize) -> Result<Self> {
let pool = GLOBAL_POOLS.get_pool_for_size(size).clone();
let chunk = pool.allocate()?;
Ok(Self {
ptr: chunk,
len: size,
capacity: pool.config().chunk_size,
pool,
})
}
#[inline]
pub fn as_slice(&self) -> &[u8] {
unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.len) }
}
#[inline]
pub fn as_mut_slice(&mut self) -> &mut [u8] {
unsafe { std::slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len) }
}
#[inline]
pub fn len(&self) -> usize {
self.len
}
#[inline]
pub fn is_empty(&self) -> bool {
self.len == 0
}
}
impl Drop for PooledBuffer {
    /// Returns the backing chunk to its pool (the pool frees it outright if
    /// its free list is already full). The result is intentionally ignored.
    fn drop(&mut self) {
        drop(self.pool.deallocate(self.ptr));
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructor stores fields verbatim; presets match documented sizes.
    #[test]
    fn test_pool_config() {
        let config = PoolConfig::new(1024, 100, 8);
        assert_eq!(config.chunk_size, 1024);
        assert_eq!(config.max_chunks, 100);
        assert_eq!(config.alignment, 8);
        let small_config = PoolConfig::small();
        assert_eq!(small_config.chunk_size, 1024);
        assert_eq!(small_config.max_chunks, 100);
    }

    /// A fresh pool has no chunks and no allocated bytes.
    #[test]
    fn test_memory_pool_creation() {
        let config = PoolConfig::new(1024, 10, 8);
        let pool = MemoryPool::new(config).unwrap();
        let stats = pool.stats();
        assert_eq!(stats.chunks, 0);
        assert_eq!(stats.allocated, 0);
    }

    /// Distinct live chunks, and alloc/dealloc counters track each call.
    #[test]
    fn test_memory_pool_allocation() {
        let config = PoolConfig::new(1024, 10, 8);
        let pool = MemoryPool::new(config).unwrap();
        let chunk1 = pool.allocate().unwrap();
        let chunk2 = pool.allocate().unwrap();
        assert_ne!(chunk1.as_ptr(), chunk2.as_ptr());
        pool.deallocate(chunk1).unwrap();
        pool.deallocate(chunk2).unwrap();
        let stats = pool.stats();
        assert_eq!(stats.alloc_count, 2);
        assert_eq!(stats.dealloc_count, 2);
    }

    /// A deallocated chunk is handed back on the next allocation (pool hit).
    #[test]
    fn test_memory_pool_reuse() {
        let config = PoolConfig::new(1024, 10, 8);
        let pool = MemoryPool::new(config).unwrap();
        let chunk1 = pool.allocate().unwrap();
        let addr1 = chunk1.as_ptr();
        pool.deallocate(chunk1).unwrap();
        let chunk2 = pool.allocate().unwrap();
        let addr2 = chunk2.as_ptr();
        assert_eq!(addr1, addr2);
        pool.deallocate(chunk2).unwrap();
        let stats = pool.stats();
        assert!(stats.pool_hits > 0);
    }

    /// Pushed elements are readable back through the slice view.
    #[test]
    fn test_pooled_vec() {
        let mut vec = PooledVec::<i32>::new().unwrap();
        assert_eq!(vec.len(), 0);
        assert!(vec.is_empty());
        assert!(vec.capacity() > 0);
        vec.push(42).unwrap();
        vec.push(84).unwrap();
        assert_eq!(vec.len(), 2);
        assert!(!vec.is_empty());
        let slice = vec.as_slice();
        assert_eq!(slice[0], 42);
        assert_eq!(slice[1], 84);
    }

    /// Writes through the mutable view are visible through the shared view.
    #[test]
    fn test_pooled_buffer() {
        let mut buffer = PooledBuffer::new(100).unwrap();
        assert_eq!(buffer.len(), 100);
        assert!(!buffer.is_empty());
        let slice = buffer.as_mut_slice();
        slice[0] = 42;
        slice[99] = 84;
        let slice = buffer.as_slice();
        assert_eq!(slice[0], 42);
        assert_eq!(slice[99], 84);
    }

    /// Smoke test: aggregating global stats must not panic.
    #[test]
    fn test_global_pool_stats() {
        let _stats = get_global_pool_stats();
    }

    /// Every invalid configuration is rejected at construction.
    #[test]
    fn test_invalid_pool_config() {
        // Zero chunk size.
        let result = MemoryPool::new(PoolConfig::new(0, 10, 8));
        assert!(result.is_err());
        // Zero alignment.
        let result = MemoryPool::new(PoolConfig::new(1024, 10, 0));
        assert!(result.is_err());
        // Non-power-of-two alignment (was crammed onto one line with its
        // assertion; split for readability).
        let result = MemoryPool::new(PoolConfig::new(1024, 10, 3));
        assert!(result.is_err());
    }
}