use std::sync::Arc;
use anyhow::Result;
use ronn_core::{DataType, MemoryInfo, TensorAllocator, TensorBuffer};
use super::simd::{SimdCapabilities, detect_simd_capabilities, get_optimal_vector_width};
use crate::allocator::{AlignedMemoryAllocator, PoolConfig, PooledMemoryAllocator};
/// CPU tensor allocator that combines a pooled allocator with an aligned allocator.
///
/// Besides the two backing allocators it stores the SIMD capabilities detected at
/// construction time and a NUMA node identifier; a negative identifier means that no
/// node has been selected.
#[derive(Debug)]
pub struct CpuMemoryAllocator {
    /// Pool-backed allocator, used for sizes accepted by `should_use_pooling`.
    pooled_allocator: PooledMemoryAllocator,
    /// Allocator used for every size that is not served from the pool.
    aligned_allocator: AlignedMemoryAllocator,
    /// SIMD capabilities captured when the allocator was constructed.
    simd_capabilities: SimdCapabilities,
    /// Selected NUMA node, or a negative value when none is selected.
    numa_node: i32,
}
impl Default for CpuMemoryAllocator {
fn default() -> Self {
Self::new()
}
}
impl CpuMemoryAllocator {
    /// Creates an allocator with the built-in pool configuration.
    ///
    /// SIMD capabilities are detected once and the optimal vector width derived from them
    /// is used as the pool's bucket granularity. No NUMA node is selected (`-1`).
    pub fn new() -> Self {
        let simd_capabilities = detect_simd_capabilities();
        let pool_config = PoolConfig {
            max_buffers_per_bucket: 128,
            // 512 MiB
            max_pool_size: 512 * 1024 * 1024,
            bucket_granularity: get_optimal_vector_width(&simd_capabilities),
        };

        Self {
            pooled_allocator: PooledMemoryAllocator::new(pool_config),
            aligned_allocator: AlignedMemoryAllocator::new(),
            simd_capabilities,
            numa_node: -1,
        }
    }

    /// Creates an allocator whose pool uses the caller-supplied `config`.
    ///
    /// SIMD capabilities are still detected and stored, but `config` is passed to the
    /// pooled allocator unchanged. No NUMA node is selected (`-1`).
    pub fn with_config(config: PoolConfig) -> Self {
        Self {
            pooled_allocator: PooledMemoryAllocator::new(config),
            aligned_allocator: AlignedMemoryAllocator::new(),
            simd_capabilities: detect_simd_capabilities(),
            numa_node: -1,
        }
    }

    /// Creates a default allocator (see [`CpuMemoryAllocator::new`]) with `numa_node`
    /// recorded as its NUMA node.
    pub fn with_numa_node(numa_node: i32) -> Self {
        let mut allocator = Self::new();
        allocator.set_numa_node(numa_node);
        allocator
    }

    /// Returns the SIMD capabilities detected when this allocator was constructed.
    pub fn get_simd_capabilities(&self) -> &SimdCapabilities {
        &self.simd_capabilities
    }

    /// Returns the recorded NUMA node; a negative value means none is selected.
    pub fn get_numa_node(&self) -> i32 {
        self.numa_node
    }

    /// Records `numa_node` as the NUMA node for this allocator.
    pub fn set_numa_node(&mut self, numa_node: i32) {
        self.numa_node = numa_node;
    }

    /// Returns `true` when a request of `size` bytes should go through the pool.
    ///
    /// Sizes from 1 KiB to 16 MiB (both inclusive) are pooled; anything smaller or
    /// larger is not.
    fn should_use_pooling(&self, size: usize) -> bool {
        const SMALLEST_POOLED: usize = 1024;
        const LARGEST_POOLED: usize = 16 * 1024 * 1024;
        (SMALLEST_POOLED..=LARGEST_POOLED).contains(&size)
    }

    /// NUMA hook applied to every freshly allocated buffer (Linux variant).
    ///
    /// No binding is performed yet: the buffer is handed back untouched whether or not
    /// a NUMA node is selected. Node-specific placement would be added here.
    #[cfg(target_os = "linux")]
    fn allocate_numa(&self, buffer: TensorBuffer) -> Result<TensorBuffer> {
        Ok(buffer)
    }

    /// NUMA hook applied to every freshly allocated buffer (non-Linux variant).
    ///
    /// Always returns the buffer unchanged.
    #[cfg(not(target_os = "linux"))]
    fn allocate_numa(&self, buffer: TensorBuffer) -> Result<TensorBuffer> {
        Ok(buffer)
    }

    /// Collects a snapshot of allocator statistics.
    ///
    /// `memory_info` is the combined view returned by `get_memory_info`, while
    /// `pool_size` is the pooled allocator's own `allocated_bytes`, so memory obtained
    /// from the aligned allocator is not reported as pool memory.
    pub fn get_detailed_stats(&self) -> CpuAllocatorStats {
        CpuAllocatorStats {
            memory_info: self.get_memory_info(),
            pool_hit_rate: self.pooled_allocator.get_hit_rate(),
            // Query the pool directly; the combined figure also contains the bytes
            // allocated through the aligned allocator.
            pool_size: self.pooled_allocator.get_memory_info().allocated_bytes,
            numa_node: self.numa_node,
            simd_alignment: get_optimal_vector_width(&self.simd_capabilities),
        }
    }

    /// Asks the pooled allocator to clear its pools.
    pub fn clear_pools(&self) {
        self.pooled_allocator.clear_pools();
    }
}
impl TensorAllocator for CpuMemoryAllocator {
    /// Allocates a buffer for a tensor of the given `shape` and `dtype`.
    ///
    /// The byte size of the tensor decides which backing allocator serves the request
    /// (see `should_use_pooling`); the resulting buffer is then passed through the
    /// NUMA hook.
    ///
    /// # Errors
    ///
    /// Propagates any error from the chosen backing allocator or from the NUMA hook.
    fn allocate(&self, shape: &[usize], dtype: DataType) -> Result<TensorBuffer> {
        let byte_len = crate::allocator::calculate_tensor_size(shape, dtype);
        let use_pool = self.should_use_pooling(byte_len);

        let raw = if use_pool {
            self.pooled_allocator.allocate(shape, dtype)?
        } else {
            self.aligned_allocator.allocate(shape, dtype)?
        };

        self.allocate_numa(raw)
    }

    /// Returns `buffer` to the backing allocator selected by its recorded size.
    ///
    /// # Errors
    ///
    /// Propagates any error from the backing allocator.
    fn deallocate(&self, buffer: TensorBuffer) -> Result<()> {
        // NOTE(review): routing relies on `buffer.size` matching the size computed at
        // allocation time; confirm neither allocator adjusts it across the pooling
        // thresholds.
        if self.should_use_pooling(buffer.size) {
            return self.pooled_allocator.deallocate(buffer);
        }
        self.aligned_allocator.deallocate(buffer)
    }

    /// Merges the memory statistics of both backing allocators.
    ///
    /// `total_bytes` is taken from the pooled allocator only, `allocated_bytes` is the
    /// sum of both allocators, and `peak_bytes` is the larger of the two individual
    /// peaks.
    fn get_memory_info(&self) -> MemoryInfo {
        let pooled = self.pooled_allocator.get_memory_info();
        let aligned = self.aligned_allocator.get_memory_info();

        let allocated_bytes = pooled.allocated_bytes + aligned.allocated_bytes;
        let peak_bytes = pooled.peak_bytes.max(aligned.peak_bytes);

        MemoryInfo {
            total_bytes: pooled.total_bytes,
            allocated_bytes,
            peak_bytes,
        }
    }
}
/// Snapshot of allocator statistics produced by `CpuMemoryAllocator::get_detailed_stats`.
#[derive(Debug, Clone)]
pub struct CpuAllocatorStats {
    /// Combined memory statistics of the allocator.
    pub memory_info: MemoryInfo,
    /// Hit rate reported by the pooled allocator.
    pub pool_hit_rate: f64,
    /// Size figure reported for the memory pool, in bytes.
    pub pool_size: usize,
    /// NUMA node recorded on the allocator; negative when none is selected.
    pub numa_node: i32,
    /// Optimal vector width for the allocator's detected SIMD capabilities.
    pub simd_alignment: usize,
}
/// Builds a default [`CpuMemoryAllocator`] and returns it as a shared trait object.
pub fn create_cpu_allocator() -> Arc<dyn TensorAllocator> {
    let allocator = CpuMemoryAllocator::new();
    Arc::new(allocator)
}
/// Builds a [`CpuMemoryAllocator`] with `numa_node` recorded as its NUMA node and
/// returns it as a shared trait object.
pub fn create_numa_cpu_allocator(numa_node: i32) -> Arc<dyn TensorAllocator> {
    let allocator = CpuMemoryAllocator::with_numa_node(numa_node);
    Arc::new(allocator)
}
/// Builds a [`CpuMemoryAllocator`] whose pool uses `config` and returns it as a shared
/// trait object.
pub fn create_cpu_allocator_with_config(config: PoolConfig) -> Arc<dyn TensorAllocator> {
    let allocator = CpuMemoryAllocator::with_config(config);
    Arc::new(allocator)
}
/// Returns the number of NUMA nodes visible to this process (always at least 1).
///
/// On Linux the count is the number of `nodeN` directories under
/// `/sys/devices/system/node`. If that directory cannot be read, or contains no such
/// entries (for example when sysfs is not mounted), the function falls back to 1.
#[cfg(target_os = "linux")]
pub fn detect_numa_nodes() -> usize {
    /// Matches sysfs entries named `node` followed by one or more decimal digits.
    fn is_node_dir(name: &str) -> bool {
        matches!(
            name.strip_prefix("node"),
            Some(id) if !id.is_empty() && id.bytes().all(|b| b.is_ascii_digit())
        )
    }

    let detected = std::fs::read_dir("/sys/devices/system/node")
        .map(|entries| {
            entries
                .filter_map(Result::ok)
                .filter(|entry| {
                    matches!(entry.file_name().to_str(), Some(name) if is_node_dir(name))
                })
                .count()
        })
        // An unreadable directory is treated the same as "nothing found".
        .unwrap_or(0);

    // A running system always has at least one node, so never report zero.
    detected.max(1)
}

/// Returns the number of NUMA nodes visible to this process.
///
/// No detection is implemented for this platform, so a single node is reported.
#[cfg(not(target_os = "linux"))]
pub fn detect_numa_nodes() -> usize {
    1
}
#[cfg(test)]
mod tests {
    use super::*;
    use ronn_core::MemoryType;

    /// A freshly constructed allocator has no NUMA node selected.
    #[test]
    fn test_cpu_allocator_creation() {
        let cpu = CpuMemoryAllocator::new();
        println!("SIMD capabilities: {:?}", cpu.get_simd_capabilities());
        assert_eq!(cpu.get_numa_node(), -1);
    }

    /// The NUMA node can be changed after construction.
    #[test]
    fn test_numa_configuration() {
        let mut cpu = CpuMemoryAllocator::new();
        assert_eq!(cpu.get_numa_node(), -1);

        cpu.set_numa_node(0);
        assert_eq!(cpu.get_numa_node(), 0);
    }

    /// Only mid-sized requests are routed to the pool.
    #[test]
    fn test_pooling_decision() -> Result<()> {
        let cpu = CpuMemoryAllocator::new();

        // Below the lower threshold.
        assert!(!cpu.should_use_pooling(512));
        // Inside the pooled range.
        assert!(cpu.should_use_pooling(64 * 1024));
        // Above the upper threshold.
        assert!(!cpu.should_use_pooling(32 * 1024 * 1024));

        Ok(())
    }

    /// Allocating 1024 `F32` elements yields a 4096-byte system-RAM buffer that is
    /// visible in the statistics until it is released.
    #[test]
    fn test_allocation_and_deallocation() -> Result<()> {
        let cpu = CpuMemoryAllocator::new();

        let tensor = cpu.allocate(&[1024], DataType::F32)?;
        assert_eq!(tensor.size, 4096);
        assert!(tensor.alignment >= 4);
        assert_eq!(tensor.memory_type, MemoryType::SystemRAM);

        let live_stats = cpu.get_detailed_stats();
        assert!(live_stats.memory_info.allocated_bytes > 0);

        cpu.deallocate(tensor)?;
        Ok(())
    }

    /// Re-allocating a just-released size exercises the pool statistics path.
    #[test]
    fn test_pool_statistics() -> Result<()> {
        let cpu = CpuMemoryAllocator::new();

        let first = cpu.allocate(&[256], DataType::F32)?;
        cpu.deallocate(first)?;

        let second = cpu.allocate(&[256], DataType::F32)?;
        let snapshot = cpu.get_detailed_stats();
        println!("Pool hit rate: {}", snapshot.pool_hit_rate);

        cpu.deallocate(second)?;
        Ok(())
    }

    /// A shared allocator starts with nothing allocated.
    #[test]
    fn test_shared_allocator_creation() {
        let shared = create_cpu_allocator();
        assert_eq!(shared.get_memory_info().allocated_bytes, 0);
    }

    /// NUMA detection always reports at least one node.
    #[test]
    fn test_numa_detection() {
        let node_count = detect_numa_nodes();
        assert!(node_count >= 1);
        println!("Detected {} NUMA nodes", node_count);
    }

    /// A custom pool configuration does not affect the reported SIMD alignment.
    #[test]
    fn test_custom_config() {
        let custom = PoolConfig {
            max_buffers_per_bucket: 16,
            max_pool_size: 1024 * 1024,
            bucket_granularity: 32,
        };
        let cpu = CpuMemoryAllocator::with_config(custom);

        let expected_alignment = get_optimal_vector_width(cpu.get_simd_capabilities());
        assert_eq!(cpu.get_detailed_stats().simd_alignment, expected_alignment);
    }
}