use crate::VeloxxError;
use std::{
alloc::{alloc_zeroed, dealloc, Layout},
collections::HashMap,
ptr::NonNull,
sync::{Arc, Mutex},
};
/// Cache-line size in bytes.
///
/// NOTE(review): not referenced anywhere in this part of the file; presumably
/// consumed by other modules for padding/blocking decisions — confirm.
pub const CACHE_LINE_SIZE: usize = 64;
/// Alignment in bytes of every layout built by `MemoryPool::allocate_aligned`
/// and `MemoryPool::deallocate_aligned`.
pub const SIMD_ALIGNMENT: usize = 64;
/// Thread-safe memory pool that recycles heap blocks to reduce allocator traffic.
///
/// Two independent free lists are kept, each keyed by block size in bytes:
///
/// * plain `Vec<u8>` buffers, served by [`MemoryPool::allocate`] and returned
///   with [`MemoryPool::deallocate`];
/// * raw blocks aligned to `SIMD_ALIGNMENT`, served by
///   [`MemoryPool::allocate_aligned`] and returned with
///   [`MemoryPool::deallocate_aligned`].
///
/// Every block handed out is zero-filled, whether it is fresh or recycled.
/// Idle aligned blocks still parked in the pool are released to the system
/// allocator when the pool is dropped.
pub struct MemoryPool {
    /// Idle `Vec<u8>` blocks, keyed by their length in bytes.
    pools: Arc<Mutex<HashMap<usize, Vec<Vec<u8>>>>>,
    /// Idle aligned blocks, keyed by their size in bytes. All of them were
    /// allocated with alignment `SIMD_ALIGNMENT`.
    // The raw pointers make the inner mutex `!Send`/`!Sync`, which clippy
    // flags; thread-safety is asserted manually by the `unsafe impl`s below.
    #[allow(clippy::arc_with_non_send_sync)]
    aligned_pools: Arc<Mutex<HashMap<usize, Vec<*mut u8>>>>,
    /// Largest `Vec<u8>` block (in bytes) that is eligible for recycling.
    max_pool_size: usize,
    /// Bytes of aligned memory obtained from the system allocator and not yet
    /// returned to it (idle pooled blocks are included).
    total_allocated: std::sync::atomic::AtomicUsize,
    /// Number of aligned blocks ever requested from the system allocator.
    allocation_count: std::sync::atomic::AtomicUsize,
}

// SAFETY: the only non-thread-safe state is the set of raw pointers stored in
// `aligned_pools`. They are touched exclusively while holding the mutex, and
// each one refers to an idle block that the pool owns outright.
unsafe impl Send for MemoryPool {}
// SAFETY: see the `Send` impl above; all shared mutation goes through a
// `Mutex` or an atomic.
unsafe impl Sync for MemoryPool {}

impl MemoryPool {
    /// Maximum number of idle blocks retained per size class, for both the
    /// `Vec<u8>` and the aligned free lists. Blocks returned beyond this limit
    /// are released instead of pooled, so the pool's footprint stays bounded.
    const MAX_BLOCKS_PER_SIZE: usize = 100;

    /// Creates an empty pool.
    ///
    /// `max_pool_size` is the largest `Vec<u8>` block, in bytes, that
    /// [`allocate`](Self::allocate) / [`deallocate`](Self::deallocate) will
    /// recycle; larger requests bypass the pool entirely.
    #[allow(clippy::arc_with_non_send_sync)]
    pub fn new(max_pool_size: usize) -> Self {
        Self {
            pools: Arc::new(Mutex::new(HashMap::new())),
            aligned_pools: Arc::new(Mutex::new(HashMap::new())),
            max_pool_size,
            total_allocated: std::sync::atomic::AtomicUsize::new(0),
            allocation_count: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Returns a zero-filled buffer of exactly `size` bytes.
    ///
    /// A previously returned block of the same size is reused when available.
    /// Requests larger than `max_pool_size` are always freshly allocated.
    ///
    /// # Errors
    ///
    /// Returns `VeloxxError::MemoryError` if the pool lock is poisoned.
    pub fn allocate(&self, size: usize) -> Result<Vec<u8>, VeloxxError> {
        if size > self.max_pool_size {
            return Ok(vec![0; size]);
        }
        let mut pools = self.pools.lock().map_err(|_| {
            VeloxxError::MemoryError("Failed to acquire memory pool lock".to_string())
        })?;
        let recycled = pools.get_mut(&size).and_then(|blocks| blocks.pop());
        // Release the lock before touching the block's memory.
        drop(pools);

        Ok(match recycled {
            Some(mut block) => {
                // A recycled block still holds its previous owner's data; wipe
                // it so callers see the same contents as a fresh allocation.
                block.fill(0);
                block
            }
            None => vec![0; size],
        })
    }

    /// Returns `block` to the pool so a later [`allocate`](Self::allocate) of
    /// the same length can reuse it.
    ///
    /// The block is simply dropped when it is larger than `max_pool_size`,
    /// when its size class already holds the maximum number of idle blocks,
    /// or when the pool lock is poisoned (best effort: recycling is an
    /// optimisation, not a requirement).
    pub fn deallocate(&self, block: Vec<u8>) {
        let size = block.len();
        if size > self.max_pool_size {
            return;
        }
        let Ok(mut pools) = self.pools.lock() else {
            return;
        };
        let idle = pools.entry(size).or_default();
        if idle.len() < Self::MAX_BLOCKS_PER_SIZE {
            idle.push(block);
        }
    }

    /// Returns a zero-filled block able to hold `count` values of `T`, aligned
    /// to `SIMD_ALIGNMENT` bytes.
    ///
    /// When the request is zero bytes long (`count == 0` or `T` is zero-sized)
    /// no memory is allocated and a well-aligned dangling pointer is returned;
    /// it must not be dereferenced but may be passed back to
    /// [`deallocate_aligned`](Self::deallocate_aligned).
    ///
    /// # Errors
    ///
    /// Returns `VeloxxError::MemoryError` if `T` needs more than
    /// `SIMD_ALIGNMENT` alignment, if the byte size overflows or is not a
    /// valid layout, if the pool lock is poisoned, or if the system allocator
    /// fails.
    pub fn allocate_aligned<T>(&self, count: usize) -> Result<NonNull<T>, VeloxxError> {
        let size = Self::aligned_block_size::<T>(count)?;
        if size == 0 {
            // The global allocator must never be asked for a zero-sized block.
            return Ok(Self::dangling_aligned::<T>());
        }
        let layout = Layout::from_size_align(size, SIMD_ALIGNMENT)
            .map_err(|e| VeloxxError::MemoryError(format!("Invalid layout: {}", e)))?;

        let mut pools = self.aligned_pools.lock().map_err(|_| {
            VeloxxError::MemoryError("Failed to acquire aligned pool lock".to_string())
        })?;
        let recycled = pools.get_mut(&size).and_then(|blocks| blocks.pop());
        // Release the lock before touching the block's memory.
        drop(pools);

        if let Some(block) = recycled.and_then(NonNull::new) {
            // SAFETY: every pointer parked under key `size` was produced by
            // `alloc_zeroed` with exactly this layout and is owned by the pool
            // until handed out, so `size` bytes are writable.
            unsafe { block.as_ptr().write_bytes(0, size) };
            return Ok(block.cast::<T>());
        }

        // SAFETY: `layout` has a non-zero size (checked above).
        let fresh = unsafe { alloc_zeroed(layout) };
        let Some(block) = NonNull::new(fresh) else {
            return Err(VeloxxError::MemoryError(
                "Failed to allocate aligned memory".to_string(),
            ));
        };
        self.total_allocated
            .fetch_add(size, std::sync::atomic::Ordering::Relaxed);
        self.allocation_count
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        Ok(block.cast::<T>())
    }

    /// Gives back a block obtained from
    /// [`allocate_aligned`](Self::allocate_aligned).
    ///
    /// The block is parked for reuse unless its size class is already full, in
    /// which case it is released to the system allocator. Zero-byte blocks
    /// (the dangling pointers returned for empty requests) are ignored.
    ///
    /// `ptr` and `count` must describe a block previously returned by
    /// `allocate_aligned::<T>(count)` on this pool that has not been given
    /// back yet; anything else corrupts the pool.
    ///
    /// # Errors
    ///
    /// Returns `VeloxxError::MemoryError` if the size/alignment is invalid for
    /// this pool or if the pool lock is poisoned.
    pub fn deallocate_aligned<T>(&self, ptr: NonNull<T>, count: usize) -> Result<(), VeloxxError> {
        let size = Self::aligned_block_size::<T>(count)?;
        if size == 0 {
            // Nothing was allocated for an empty request.
            return Ok(());
        }
        let layout = Layout::from_size_align(size, SIMD_ALIGNMENT).map_err(|e| {
            VeloxxError::MemoryError(format!("Invalid layout for deallocation: {}", e))
        })?;
        let raw = ptr.as_ptr().cast::<u8>();
        {
            let mut pools = self.aligned_pools.lock().map_err(|_| {
                VeloxxError::MemoryError("Failed to acquire aligned pool lock".to_string())
            })?;
            let idle = pools.entry(size).or_default();
            if idle.len() < Self::MAX_BLOCKS_PER_SIZE {
                idle.push(raw);
                return Ok(());
            }
        }
        // SAFETY: per this method's contract `raw` came from `alloc_zeroed`
        // with exactly `layout` and is no longer in use.
        unsafe {
            dealloc(raw, layout);
        }
        self.total_allocated
            .fetch_sub(size, std::sync::atomic::Ordering::Relaxed);
        Ok(())
    }

    /// Returns a snapshot of the aligned-allocation counters.
    ///
    /// `pool_count` is the number of size classes in the aligned free list; it
    /// is reported as `0` when the lock is poisoned.
    pub fn stats(&self) -> MemoryPoolStats {
        let pool_count = self
            .aligned_pools
            .lock()
            .map(|pools| pools.len())
            .unwrap_or(0);
        MemoryPoolStats {
            total_allocated: self
                .total_allocated
                .load(std::sync::atomic::Ordering::Relaxed),
            allocation_count: self
                .allocation_count
                .load(std::sync::atomic::Ordering::Relaxed),
            pool_count,
        }
    }

    /// Returns the number of bytes held by idle `Vec<u8>` blocks.
    ///
    /// Aligned blocks are not included. Returns `0` when the lock is poisoned.
    pub fn memory_usage(&self) -> usize {
        let Ok(pools) = self.pools.lock() else {
            return 0;
        };
        pools
            .values()
            .map(|blocks| blocks.iter().map(Vec::len).sum::<usize>())
            .sum()
    }

    /// Computes the byte size of `count` values of `T`, rejecting requests the
    /// aligned free list cannot serve safely.
    fn aligned_block_size<T>(count: usize) -> Result<usize, VeloxxError> {
        let required = std::mem::align_of::<T>();
        if required > SIMD_ALIGNMENT {
            // Blocks are keyed by size only and always use `SIMD_ALIGNMENT`,
            // so a stricter alignment cannot be honoured.
            return Err(VeloxxError::MemoryError(format!(
                "Type alignment {} exceeds pool alignment {}",
                required, SIMD_ALIGNMENT
            )));
        }
        count.checked_mul(std::mem::size_of::<T>()).ok_or_else(|| {
            VeloxxError::MemoryError("Aligned allocation size overflows usize".to_string())
        })
    }

    /// Returns a non-null, `SIMD_ALIGNMENT`-aligned pointer that refers to no
    /// allocation; used for zero-byte requests.
    fn dangling_aligned<T>() -> NonNull<T> {
        let addr = std::ptr::null_mut::<u8>().wrapping_add(SIMD_ALIGNMENT);
        // SAFETY: `SIMD_ALIGNMENT` is non-zero, so `addr` is not null.
        unsafe { NonNull::new_unchecked(addr) }.cast::<T>()
    }
}

impl Drop for MemoryPool {
    /// Releases every idle aligned block still parked in the pool.
    ///
    /// Without this the raw pointers in `aligned_pools` would simply be
    /// forgotten and their memory leaked. Blocks currently handed out are not
    /// affected: they are not in the free list.
    fn drop(&mut self) {
        // A poisoned lock still guards a structurally valid map.
        let mut pools = self
            .aligned_pools
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        for (size, blocks) in pools.drain() {
            if size == 0 {
                continue;
            }
            let Ok(layout) = Layout::from_size_align(size, SIMD_ALIGNMENT) else {
                continue;
            };
            for raw in blocks {
                // SAFETY: every pointer stored under `size` was produced by
                // `alloc_zeroed` with exactly this layout and is idle.
                unsafe { dealloc(raw, layout) };
            }
        }
    }
}
/// Snapshot of the aligned-allocation counters produced by `MemoryPool::stats`.
#[derive(Debug, Clone)]
pub struct MemoryPoolStats {
/// Bytes of aligned memory obtained from the system allocator and not yet
/// released back to it; blocks parked in the pool for reuse still count.
pub total_allocated: usize,
/// Number of aligned blocks requested from the system allocator; reusing a
/// pooled block does not increase it.
pub allocation_count: usize,
/// Number of distinct size keys in the aligned free list (the `Vec<u8>` pool
/// is not counted).
pub pool_count: usize,
}
static GLOBAL_MEMORY_POOL: std::sync::OnceLock<Arc<MemoryPool>> = std::sync::OnceLock::new();
pub fn global_memory_pool() -> &'static Arc<MemoryPool> {
GLOBAL_MEMORY_POOL.get_or_init(|| Arc::new(MemoryPool::default()))
}
/// Owned buffer of `len` elements of `T` whose storage comes from a
/// [`MemoryPool`] aligned allocation and is handed back to that pool on drop.
///
/// The buffer never initialises or drops individual elements itself: the
/// slice accessors expose whatever bytes the pool returned, so `T` should be a
/// type for which those bytes are a valid value.
pub struct AlignedBuffer<T> {
    /// Start of the pool-provided block.
    ptr: NonNull<T>,
    /// Number of `T` elements the block was requested for.
    len: usize,
    /// Pool that owns the block; kept alive for as long as the buffer exists.
    pool: Arc<MemoryPool>,
}

impl<T> AlignedBuffer<T> {
    /// Allocates a buffer of `len` elements from the global memory pool.
    ///
    /// # Errors
    ///
    /// Propagates any error from `MemoryPool::allocate_aligned`.
    pub fn new(len: usize) -> Result<Self, VeloxxError> {
        Self::with_pool(len, Arc::clone(global_memory_pool()))
    }

    /// Allocates a buffer of `len` elements from the given `pool`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `MemoryPool::allocate_aligned`.
    pub fn with_pool(len: usize, pool: Arc<MemoryPool>) -> Result<Self, VeloxxError> {
        pool.allocate_aligned::<T>(len)
            .map(|ptr| Self { ptr, len, pool })
    }

    /// Returns a raw const pointer to the first element.
    pub fn as_ptr(&self) -> *const T {
        self.ptr.as_ptr().cast_const()
    }

    /// Returns a raw mutable pointer to the first element.
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.ptr.as_ptr()
    }

    /// Returns the number of elements in the buffer.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the buffer holds no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Views the whole buffer as a shared slice.
    pub fn as_slice(&self) -> &[T] {
        let start = self.as_ptr();
        // SAFETY: relies on `ptr` addressing a block for `len` elements that
        // this buffer exclusively owns for its whole lifetime, as requested
        // from the pool at construction.
        unsafe { std::slice::from_raw_parts(start, self.len) }
    }

    /// Views the whole buffer as a mutable slice.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        let start = self.as_mut_ptr();
        // SAFETY: same ownership argument as `as_slice`; `&mut self`
        // guarantees the view is unique.
        unsafe { std::slice::from_raw_parts_mut(start, self.len) }
    }
}

impl<T> Drop for AlignedBuffer<T> {
    /// Hands the block back to its pool. Element destructors are not run.
    fn drop(&mut self) {
        // A destructor has no way to report failure, so the outcome is
        // deliberately discarded.
        let _outcome = self.pool.deallocate_aligned(self.ptr, self.len);
    }
}

// SAFETY: the buffer uniquely owns its block, so moving it to another thread
// is sound whenever moving the `T` values themselves is.
unsafe impl<T: Send> Send for AlignedBuffer<T> {}
// SAFETY: shared access only yields `&[T]`, which is sound to share whenever
// `T` is `Sync`.
unsafe impl<T: Sync> Sync for AlignedBuffer<T> {}
/// Hook for NUMA-aware allocation strategies.
///
/// The provided defaults ignore NUMA topology: allocation goes through
/// `AlignedBuffer::new` and the current node is always reported as `0`.
pub trait NumaAware {
    /// Allocates a buffer of `count` elements of `T`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `AlignedBuffer::new`.
    fn allocate_numa_local<T>(count: usize) -> Result<AlignedBuffer<T>, VeloxxError> {
        AlignedBuffer::<T>::new(count)
    }

    /// Returns the NUMA node of the calling thread; the default is always `0`.
    fn current_numa_node() -> usize {
        const FALLBACK_NODE: usize = 0;
        FALLBACK_NODE
    }
}

/// `MemoryPool` uses the trait's default behaviour unchanged.
impl NumaAware for MemoryPool {}
impl Default for MemoryPool {
    /// Builds a pool via `MemoryPool::new` with a 1 MiB `max_pool_size`.
    fn default() -> Self {
        const DEFAULT_MAX_POOL_SIZE: usize = 1 << 20;
        Self::new(DEFAULT_MAX_POOL_SIZE)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A returned block is parked in the pool and handed out again.
    #[test]
    fn test_memory_pool_allocation() {
        let pool = MemoryPool::new(1024);
        let block = pool.allocate(100).unwrap();
        assert_eq!(block.len(), 100);
        assert!(block.iter().all(|&byte| byte == 0));
        pool.deallocate(block);
        // The block must now be sitting in the pool...
        assert_eq!(pool.memory_usage(), 100);
        let block2 = pool.allocate(100).unwrap();
        assert_eq!(block2.len(), 100);
        // ...and the second request must have taken it out again.
        assert_eq!(pool.memory_usage(), 0);
    }

    /// Buffers from the global pool are aligned and writable.
    #[test]
    fn test_aligned_buffer() {
        let mut buffer = AlignedBuffer::<f64>::new(100).unwrap();
        assert_eq!(buffer.len(), 100);
        assert!(!buffer.is_empty());
        assert_eq!(buffer.as_ptr() as usize % SIMD_ALIGNMENT, 0);
        let slice = buffer.as_mut_slice();
        assert_eq!(slice.len(), 100);
        slice[0] = 42.0;
        assert_eq!(slice[0], 42.0);
        assert_eq!(buffer.as_slice()[0], 42.0);
    }

    /// A freed aligned block is reused instead of hitting the allocator again.
    #[test]
    fn test_aligned_memory_pool_reuse() {
        let pool = MemoryPool::new(1024 * 1024);
        let ptr1 = pool.allocate_aligned::<f64>(100).unwrap();
        let first_addr = ptr1.as_ptr() as usize;
        assert_eq!(first_addr % SIMD_ALIGNMENT, 0);
        pool.deallocate_aligned(ptr1, 100).unwrap();

        let ptr2 = pool.allocate_aligned::<f64>(100).unwrap();
        // Same block comes back, and the system allocator was asked only once.
        assert_eq!(ptr2.as_ptr() as usize, first_addr);
        let stats = pool.stats();
        assert_eq!(stats.allocation_count, 1);
        assert_eq!(stats.total_allocated, 100 * std::mem::size_of::<f64>());

        pool.deallocate_aligned(ptr2, 100).unwrap();
        assert_eq!(pool.stats().pool_count, 1);
    }

    /// Requests above `max_pool_size` bypass the pool in both directions.
    #[test]
    fn test_large_allocation() {
        let pool = MemoryPool::new(1024);
        let block = pool.allocate(2048).unwrap();
        assert_eq!(block.len(), 2048);
        pool.deallocate(block);
        // Oversized blocks are dropped, never retained.
        assert_eq!(pool.memory_usage(), 0);
    }

    /// `memory_usage` counts only idle blocks held by the pool.
    #[test]
    fn test_memory_usage() {
        let pool = MemoryPool::new(1024);
        assert_eq!(pool.memory_usage(), 0);
        let block = pool.allocate(100).unwrap();
        assert_eq!(pool.memory_usage(), 0);
        pool.deallocate(block);
        assert_eq!(pool.memory_usage(), 100);
    }
}