use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use crate::error::Result;
/// Tuning knobs for a `MemoryPool`.
#[derive(Clone, Debug)]
pub struct PoolConfig {
    /// Byte budget used to cap how many returned buffers one size class keeps
    /// (`max_pool_size / class_size` buffers).
    pub max_pool_size: usize,
    /// Smallest request size, in bytes, that is served from the pool.
    pub min_pooled_size: usize,
    /// Largest request size, in bytes, that is served from the pool.
    pub max_pooled_size: usize,
    /// Number of buffers created up front for each common size class.
    pub prealloc_count: usize,
}

impl Default for PoolConfig {
    /// 16 MiB budget, pooling for requests between 1 KiB and 4 MiB, and eight
    /// pre-allocated buffers per common size class.
    fn default() -> Self {
        const KIB: usize = 1024;
        const MIB: usize = 1024 * KIB;

        PoolConfig {
            max_pool_size: 16 * MIB,
            min_pooled_size: KIB,
            max_pooled_size: 4 * MIB,
            prealloc_count: 8,
        }
    }
}
/// Pool of reusable, zero-filled byte buffers grouped by power-of-two size class.
#[derive(Debug)]
pub struct MemoryPool {
    /// Idle buffers, keyed by size class in bytes.
    pools: Arc<Mutex<HashMap<usize, Vec<Vec<u8>>>>>,
    /// Limits that decide which requests are pooled and how many buffers are kept.
    config: PoolConfig,
    /// Usage counters, guarded by their own lock.
    stats: Arc<Mutex<PoolStats>>,
}
/// Counters describing how a `MemoryPool` has been used.
#[derive(Clone, Debug, Default)]
pub struct PoolStats {
    /// Number of `allocate` calls.
    pub total_allocations: u64,
    /// Requests satisfied by a buffer taken from the pool.
    pub cache_hits: u64,
    /// Requests that needed a fresh allocation (including sizes that are never pooled).
    pub cache_misses: u64,
    /// Sum of the requested sizes, in bytes.
    pub total_bytes_allocated: u64,
    /// Sum of the size-class sizes of all cache hits, in bytes.
    pub pooled_bytes_served: u64,
    /// Highest value `current_memory_usage` has reached.
    pub peak_memory_usage: usize,
    /// Bytes added on every miss and removed when a returned buffer is not kept.
    pub current_memory_usage: usize,
}
impl MemoryPool {
    /// Creates a pool configured with [`PoolConfig::default`].
    pub fn new() -> Self {
        Self::with_config(PoolConfig::default())
    }

    /// Creates a pool with the given configuration and pre-fills the common size classes.
    pub fn with_config(config: PoolConfig) -> Self {
        let pool = Self {
            pools: Arc::new(Mutex::new(HashMap::new())),
            config,
            stats: Arc::new(Mutex::new(PoolStats::default())),
        };
        pool.preallocate_common_sizes();
        pool
    }

    /// Seeds every common size class that lies inside the pooled range with
    /// `prealloc_count` zeroed buffers.
    fn preallocate_common_sizes(&self) {
        const COMMON_SIZES: [usize; 8] = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072];

        let pooled_range = self.config.min_pooled_size..=self.config.max_pooled_size;
        // Take the lock once for the whole seeding pass instead of once per size.
        let mut pools = self.pools.lock().unwrap();
        for size in COMMON_SIZES
            .iter()
            .copied()
            .filter(|size| pooled_range.contains(size))
        {
            let pool_size = self.round_to_power_of_2(size);
            let pool = pools.entry(pool_size).or_default();
            pool.extend((0..self.config.prealloc_count).map(|_| vec![0u8; pool_size]));
        }
    }

    /// Returns `true` when a request of `size` bytes is served through the pool.
    fn is_poolable(&self, size: usize) -> bool {
        size >= self.config.min_pooled_size && size <= self.config.max_pooled_size
    }

    /// Adds `bytes` to the live-memory counter and updates the recorded peak.
    fn grow_usage(stats: &mut PoolStats, bytes: usize) {
        stats.current_memory_usage = stats.current_memory_usage.saturating_add(bytes);
        stats.peak_memory_usage = stats.peak_memory_usage.max(stats.current_memory_usage);
    }

    /// Returns a zero-filled buffer of exactly `size` bytes.
    ///
    /// Sizes outside `[min_pooled_size, max_pooled_size]` are allocated directly.
    /// Other sizes are rounded up to a power-of-two size class and served from the
    /// pool when an idle buffer is available. Buffers for pooled sizes are created
    /// with the full size-class capacity so they can be returned without reallocating.
    ///
    /// # Panics
    ///
    /// Panics if one of the internal mutexes is poisoned.
    pub fn allocate(&self, size: usize) -> Vec<u8> {
        if !self.is_poolable(size) {
            {
                let mut stats = self.stats.lock().unwrap();
                stats.total_allocations += 1;
                stats.total_bytes_allocated += size as u64;
                stats.cache_misses += 1;
                Self::grow_usage(&mut stats, size);
            }
            return vec![0; size];
        }

        let pool_size = self.round_to_power_of_2(size);

        // The pool lock and the stats lock are never held at the same time, so
        // `allocate` and `deallocate` cannot deadlock on lock ordering.
        let mut pools = self.pools.lock().unwrap();
        let reused = pools.get_mut(&pool_size).and_then(|pool| pool.pop());
        drop(pools);

        {
            let mut stats = self.stats.lock().unwrap();
            stats.total_allocations += 1;
            stats.total_bytes_allocated += size as u64;
            if reused.is_some() {
                stats.cache_hits += 1;
                stats.pooled_bytes_served += pool_size as u64;
            } else {
                stats.cache_misses += 1;
                Self::grow_usage(&mut stats, pool_size);
            }
        }

        match reused {
            Some(mut buffer) => {
                // Pooled buffers are stored zeroed at full class size; trim to the request.
                buffer.resize(size, 0);
                buffer
            }
            None => {
                // Allocate the whole size class so the accounting above matches the
                // real allocation and a later `deallocate` does not have to grow it.
                let mut buffer = vec![0u8; pool_size];
                buffer.truncate(size);
                buffer
            }
        }
    }

    /// Hands `buffer` back to the pool.
    ///
    /// The size class is derived from `buffer.len()`. Buffers whose length is
    /// outside the pooled range are dropped. Others are zero-filled and kept unless
    /// their size class already holds `max_pool_size / class_size` buffers.
    ///
    /// # Panics
    ///
    /// Panics if one of the internal mutexes is poisoned.
    pub fn deallocate(&self, mut buffer: Vec<u8>) {
        let original_size = buffer.len();
        if !self.is_poolable(original_size) {
            let mut stats = self.stats.lock().unwrap();
            stats.current_memory_usage = stats.current_memory_usage.saturating_sub(original_size);
            return;
        }

        let pool_size = self.round_to_power_of_2(original_size);
        let max_buffers = self.config.max_pool_size / pool_size;

        // Scrub and size the buffer before taking the lock.
        buffer.clear();
        buffer.resize(pool_size, 0);

        let mut pools = self.pools.lock().unwrap();
        let pool = pools.entry(pool_size).or_default();
        let retained = pool.len() < max_buffers;
        if retained {
            pool.push(buffer);
        }
        // Release the pool lock before touching the stats lock (see `allocate`).
        drop(pools);

        if !retained {
            let mut stats = self.stats.lock().unwrap();
            stats.current_memory_usage = stats.current_memory_usage.saturating_sub(pool_size);
        }
    }

    /// Rounds `size` up to the next power of two (`0` and `1` both map to `1`).
    ///
    /// If no power of two at or above `size` fits in `usize`, `size` is returned
    /// unchanged instead of looping forever on a wrapped shift.
    fn round_to_power_of_2(&self, size: usize) -> usize {
        size.checked_next_power_of_two().unwrap_or(size)
    }

    /// Returns a snapshot of the usage counters.
    pub fn stats(&self) -> PoolStats {
        self.stats.lock().unwrap().clone()
    }

    /// Returns the share of `allocate` calls served from the pool, as a percentage
    /// in `0.0..=100.0`; `0.0` when nothing has been allocated yet.
    pub fn hit_ratio(&self) -> f64 {
        let stats = self.stats.lock().unwrap();
        if stats.total_allocations == 0 {
            return 0.0;
        }
        (stats.cache_hits as f64 / stats.total_allocations as f64) * 100.0
    }

    /// Drops every idle buffer and resets all counters to zero.
    pub fn clear(&self) {
        self.pools.lock().unwrap().clear();
        let mut stats = self.stats.lock().unwrap();
        *stats = PoolStats::default();
    }

    /// Returns the number of bytes currently held by idle buffers.
    pub fn total_pooled_memory(&self) -> usize {
        let pools = self.pools.lock().unwrap();
        pools.iter().map(|(&size, pool)| size * pool.len()).sum()
    }

    /// Shrinks the per-class buffer lists to their current length.
    pub fn shrink_to_fit(&self) {
        let mut pools = self.pools.lock().unwrap();
        for pool in pools.values_mut() {
            pool.shrink_to_fit();
        }
    }
}
impl Default for MemoryPool {
fn default() -> Self {
Self::new()
}
}
/// Storage for the shared pool; filled in on the first call to `global_pool`.
static GLOBAL_POOL: std::sync::OnceLock<MemoryPool> = std::sync::OnceLock::new();

/// Returns the shared pool, building it with the default configuration on first use.
pub fn global_pool() -> &'static MemoryPool {
    GLOBAL_POOL.get_or_init(MemoryPool::default)
}
/// Allocates `size` zeroed bytes from the shared global pool.
pub fn allocate(size: usize) -> Vec<u8> {
    let pool = global_pool();
    pool.allocate(size)
}
/// Returns `buffer` to the shared global pool.
pub fn deallocate(buffer: Vec<u8>) {
    let pool = global_pool();
    pool.deallocate(buffer);
}
/// Returns a snapshot of the shared global pool's usage counters.
pub fn global_stats() -> PoolStats {
    let pool = global_pool();
    pool.stats()
}
pub struct KernelMemoryManager {
pool: Arc<MemoryPool>,
allocations: Mutex<HashMap<*const u8, usize>>,
}
impl KernelMemoryManager {
pub fn new() -> Self {
Self {
pool: Arc::new(MemoryPool::new()),
allocations: Mutex::new(HashMap::new()),
}
}
pub fn allocate_kernel_memory(&self, size: usize, alignment: usize) -> Result<*mut u8> {
let buffer = self.pool.allocate(size + alignment - 1);
let ptr = buffer.as_ptr() as *mut u8;
{
let mut allocations = self.allocations.lock().unwrap();
allocations.insert(ptr, size);
}
std::mem::forget(buffer);
Ok(ptr)
}
pub unsafe fn deallocate_kernel_memory(&self, ptr: *mut u8) -> Result<()> {
let size = {
let mut allocations = self.allocations.lock().unwrap();
allocations.remove(&(ptr as *const u8))
.ok_or_else(|| crate::error::CudaRustError::MemoryError("Invalid pointer for deallocation".to_string()))?
};
let buffer = Vec::from_raw_parts(ptr, size, size);
self.pool.deallocate(buffer);
Ok(())
}
pub fn total_kernel_memory(&self) -> usize {
let allocations = self.allocations.lock().unwrap();
allocations.values().sum()
}
}
impl Default for KernelMemoryManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pooled-size allocations keep their requested length and register cache hits.
    #[test]
    fn test_memory_pool_basic() {
        let pool = MemoryPool::new();

        let first = pool.allocate(1024);
        assert_eq!(first.len(), 1024);
        pool.deallocate(first);

        let second = pool.allocate(1024);
        assert_eq!(second.len(), 1024);

        assert!(pool.hit_ratio() > 0.0);
    }

    /// Sizes round up to the next power of two; exact powers are unchanged.
    #[test]
    fn test_power_of_2_rounding() {
        let pool = MemoryPool::new();
        let cases = [(1000, 1024), (1024, 1024), (1500, 2048)];
        for &(input, expected) in cases.iter() {
            assert_eq!(pool.round_to_power_of_2(input), expected);
        }
    }

    /// The free functions go through the shared pool and show up in its counters.
    #[test]
    fn test_global_pool() {
        let buffer = allocate(2048);
        assert_eq!(buffer.len(), 2048);
        deallocate(buffer);

        assert!(global_stats().total_allocations > 0);
    }

    /// A kernel allocation yields a non-null pointer that can be released again.
    #[test]
    fn test_kernel_memory_manager() {
        let manager = KernelMemoryManager::new();
        let ptr = manager.allocate_kernel_memory(4096, 16).unwrap();
        assert!(!ptr.is_null());
        // SAFETY: `ptr` was just returned by `manager` and is not used afterwards.
        unsafe {
            manager.deallocate_kernel_memory(ptr).unwrap();
        }
    }
}