use std::collections::{BTreeMap, HashMap, VecDeque};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
use crate::error::{Error, Result};
use crate::gpu::{GpuError, GpuManager};
use crate::{lock_safe, read_lock_safe, write_lock_safe};
#[cfg(cuda_available)]
use cudarc::driver::{CudaContext as CudarcContext, CudaSlice, CudaStream};
/// Tunable parameters controlling a GPU memory pool's sizing, growth and
/// maintenance behaviour.
#[derive(Debug, Clone)]
pub struct MemoryPoolConfig {
/// Bytes pre-allocated (as fixed-size chunks) when the pool is created.
pub initial_size: usize,
/// Hard upper bound on the pool's total size in bytes.
pub max_size: usize,
/// Allocation granularity: request sizes are rounded up to a multiple of this.
pub min_allocation_size: usize,
/// Whether `cleanup` also runs the compaction pass.
pub enable_compaction: bool,
/// Seconds between automatic cleanup passes triggered from `allocate`.
pub cleanup_interval: u64,
/// Free blocks idle longer than this many seconds may be evicted by cleanup.
pub max_allocation_age: u64,
/// Intended pool growth multiplier; not read anywhere in this file — TODO confirm use.
pub growth_factor: f64,
}
impl Default for MemoryPoolConfig {
fn default() -> Self {
Self {
initial_size: 256 * 1024 * 1024, max_size: 2 * 1024 * 1024 * 1024, min_allocation_size: 4096, enable_compaction: true,
cleanup_interval: 30, max_allocation_age: 300, growth_factor: 1.5,
}
}
}
/// Bookkeeping attached to a block while it is handed out to a caller.
#[derive(Debug, Clone)]
struct AllocationInfo {
/// Requested (aligned) size in bytes for this allocation.
size: usize,
/// When the block was handed out.
allocated_at: Instant,
/// Last time the allocation was touched; `cleanup` ages free blocks by this.
last_accessed: Instant,
/// True while the block is held by a caller.
in_use: bool,
/// Reference count; only ever set to 1 in this file — TODO confirm external use.
ref_count: usize,
}
/// One pooled chunk of device memory.
///
/// With CUDA compiled in, `ptr` owns a real device buffer; otherwise `ptr`
/// is a plain integer standing in for a device address (set to the block
/// size in `allocate_gpu_memory`, never dereferenced).
#[derive(Debug)]
struct MemoryBlock {
#[cfg(cuda_available)]
ptr: Box<CudaSlice<u8>>,
#[cfg(not(cuda_available))]
ptr: usize,
/// Capacity of this block in bytes.
size: usize,
/// True while the block sits in the free list.
is_free: bool,
/// Present while allocated; cleared again on deallocation.
allocation_info: Option<AllocationInfo>,
}
/// Size-bucketed free-list allocator for a single GPU device.
///
/// Not internally synchronized; callers share it as
/// `Arc<Mutex<GpuMemoryPool>>` (see `GlobalMemoryPoolManager`).
pub struct GpuMemoryPool {
config: MemoryPoolConfig,
device_id: i32,
// Free blocks bucketed by size; the BTreeMap enables "smallest bucket >= n" scans.
free_blocks: BTreeMap<usize, VecDeque<Arc<Mutex<MemoryBlock>>>>,
// Live allocations keyed by their opaque pointer address.
allocated_blocks: HashMap<usize, Arc<Mutex<MemoryBlock>>>,
// Total bytes the pool currently owns (free + allocated).
current_size: usize,
// High-water mark of `current_size`.
peak_usage: usize,
// Cumulative counters; see `MemoryPoolStats`.
stats: MemoryPoolStats,
// Timestamp of the last cleanup pass; drives `maybe_cleanup`.
last_cleanup: Instant,
// CUDA context handle; assigned from `get_cuda_context`, which currently returns None.
#[cfg(cuda_available)]
context: Option<Arc<CudarcContext>>,
}
/// Cumulative counters describing pool behaviour since creation.
#[derive(Debug, Clone, Default)]
pub struct MemoryPoolStats {
/// Number of successful `allocate` calls.
pub total_allocations: u64,
/// Number of successful `deallocate` calls.
pub total_deallocations: u64,
/// Allocations served from the free list.
pub cache_hits: u64,
/// Allocations that required fresh device memory.
pub cache_misses: u64,
/// Sum of all (aligned) allocation sizes in bytes.
pub total_bytes_allocated: u64,
/// Sum of all deallocated block sizes in bytes.
pub total_bytes_deallocated: u64,
/// Times the compaction pass has run.
pub compaction_count: u64,
/// Derived in `get_stats`: bytes allocated / allocation count.
pub avg_allocation_size: f64,
}
impl GpuMemoryPool {
    /// Creates a pool for `device_id` and pre-populates it with
    /// `config.initial_size` bytes of fixed-size free chunks.
    ///
    /// # Errors
    /// Fails when context acquisition or the initial expansion fails.
    pub fn new(device_id: i32, config: MemoryPoolConfig) -> Result<Self> {
        #[cfg(cuda_available)]
        let context = Self::get_cuda_context(device_id)?;
        // Without CUDA there is no `context` field to populate, but the probe
        // still runs so both builds share the same error behaviour. (The
        // previous code initialized `context: None` under
        // `#[cfg(not(cuda_available))]`, a field that does not exist in that
        // build — a compile error.)
        #[cfg(not(cuda_available))]
        let _ = Self::get_cuda_context(device_id)?;
        let mut pool = Self {
            config,
            device_id,
            free_blocks: BTreeMap::new(),
            allocated_blocks: HashMap::new(),
            current_size: 0,
            peak_usage: 0,
            stats: MemoryPoolStats::default(),
            last_cleanup: Instant::now(),
            #[cfg(cuda_available)]
            context,
        };
        let initial_size = pool.config.initial_size;
        pool.expand_pool(initial_size)?;
        Ok(pool)
    }
    /// Acquires a CUDA context for `device_id`.
    ///
    /// Currently a stub that always yields `None`; real context acquisition is
    /// presumably wired up elsewhere — TODO confirm.
    #[cfg(cuda_available)]
    fn get_cuda_context(_device_id: i32) -> Result<Option<Arc<CudarcContext>>> {
        Ok(None)
    }
    /// Non-CUDA stand-in for `get_cuda_context`; always `None`.
    #[cfg(not(cuda_available))]
    fn get_cuda_context(_device_id: i32) -> Result<Option<()>> {
        Ok(None)
    }
    /// Allocates `size` bytes (rounded up to the alignment granularity),
    /// preferring a cached free block over fresh device memory.
    pub fn allocate(&mut self, size: usize) -> Result<GpuAllocation> {
        let aligned_size = self.align_size(size);
        self.maybe_cleanup();
        if let Some(block) = self.find_free_block(aligned_size) {
            self.stats.cache_hits += 1;
            self.use_block(block, aligned_size)
        } else {
            self.stats.cache_misses += 1;
            self.allocate_new_block(aligned_size)
        }
    }
    /// Returns an allocation to the pool's free list.
    ///
    /// # Errors
    /// Fails when `allocation` does not belong to this pool (unknown address).
    pub fn deallocate(&mut self, allocation: GpuAllocation) -> Result<()> {
        let ptr_addr = allocation.ptr_address();
        if let Some(block) = self.allocated_blocks.remove(&ptr_addr) {
            let size = {
                let mut block_guard =
                    lock_safe!(block, "memory block lock for deallocation")?;
                block_guard.is_free = true;
                // Keep the metadata (marked idle, timestamp refreshed) so that
                // `cleanup` can later age this block out of the free list; the
                // previous code erased it, which made age-based eviction a no-op.
                if let Some(ref mut info) = block_guard.allocation_info {
                    info.in_use = false;
                    info.last_accessed = Instant::now();
                }
                block_guard.size
            };
            self.free_blocks.entry(size).or_default().push_back(block);
            self.stats.total_deallocations += 1;
            self.stats.total_bytes_deallocated += size as u64;
            Ok(())
        } else {
            Err(Error::from(GpuError::DeviceError(
                "Invalid allocation pointer".to_string(),
            )))
        }
    }
    /// Returns a snapshot of the pool statistics with the average allocation
    /// size derived on the fly.
    pub fn get_stats(&self) -> MemoryPoolStats {
        let mut stats = self.stats.clone();
        if stats.total_allocations > 0 {
            stats.avg_allocation_size =
                stats.total_bytes_allocated as f64 / stats.total_allocations as f64;
        }
        stats
    }
    /// Total bytes the pool currently owns (free + allocated).
    pub fn current_size(&self) -> usize {
        self.current_size
    }
    /// High-water mark of `current_size` over the pool's lifetime.
    pub fn peak_usage(&self) -> usize {
        self.peak_usage
    }
    /// Evicts free blocks that have been idle longer than
    /// `config.max_allocation_age` seconds, then optionally compacts.
    pub fn cleanup(&mut self) -> Result<()> {
        let now = Instant::now();
        let max_age = Duration::from_secs(self.config.max_allocation_age);
        let mut removed_size = 0usize;
        self.free_blocks.retain(|&size, blocks| {
            blocks.retain(|block| {
                let block_guard = match lock_safe!(block, "memory block lock for cleanup") {
                    Ok(guard) => guard,
                    // Keep any block we cannot inspect rather than lose track of it.
                    Err(_) => return true,
                };
                match block_guard.allocation_info {
                    // Dropping the last Arc releases the block's memory.
                    Some(ref info) if now.duration_since(info.last_accessed) > max_age => {
                        removed_size += size;
                        false
                    }
                    _ => true,
                }
            });
            // Drop size buckets that became empty.
            !blocks.is_empty()
        });
        // saturating_sub guards against any accounting drift underflowing.
        self.current_size = self.current_size.saturating_sub(removed_size);
        self.last_cleanup = now;
        if self.config.enable_compaction {
            self.compact_memory()?;
        }
        Ok(())
    }
    /// Placeholder compaction pass: only records that it ran.
    fn compact_memory(&mut self) -> Result<()> {
        self.stats.compaction_count += 1;
        log::info!("Memory compaction completed for device {}", self.device_id);
        Ok(())
    }
    /// Pops the smallest free block that can hold `size` bytes, if any.
    fn find_free_block(&mut self, size: usize) -> Option<Arc<Mutex<MemoryBlock>>> {
        // Best-fit search: buckets are keyed by block size, so the first
        // non-empty bucket at or above `size` holds the smallest usable
        // block. (This also covers the exact-size bucket, which the previous
        // code checked separately; its empty block-splitting branch was dead
        // code and has been removed — oversized blocks are handed out whole.)
        for (_block_size, blocks) in self.free_blocks.range_mut(size..) {
            if let Some(block) = blocks.pop_front() {
                return Some(block);
            }
        }
        None
    }
    /// Marks a recycled free block as allocated and registers it under its
    /// stable pointer address.
    fn use_block(&mut self, block: Arc<Mutex<MemoryBlock>>, size: usize) -> Result<GpuAllocation> {
        let ptr_addr = {
            let mut block_guard = lock_safe!(block, "memory block lock for use_block")?;
            block_guard.is_free = false;
            block_guard.allocation_info = Some(AllocationInfo {
                size,
                allocated_at: Instant::now(),
                last_accessed: Instant::now(),
                in_use: true,
                ref_count: 1,
            });
            self.get_ptr_address(&block_guard)
        };
        self.allocated_blocks.insert(ptr_addr, block);
        self.stats.total_allocations += 1;
        self.stats.total_bytes_allocated += size as u64;
        Ok(GpuAllocation::new(ptr_addr, size))
    }
    /// Allocates a brand-new block from the device (or mock backend) when no
    /// cached block fits.
    fn allocate_new_block(&mut self, size: usize) -> Result<GpuAllocation> {
        if self.current_size + size > self.config.max_size {
            return Err(Error::from(GpuError::DeviceError(
                "Memory pool size limit exceeded".to_string(),
            )));
        }
        let block = Arc::new(Mutex::new(self.allocate_gpu_memory(size)?));
        // Take the address only after the block has moved into the Arc so it
        // is heap-stable and unique. The previous code read the address of a
        // stack temporary before wrapping it, so keys could collide between
        // allocations and silently overwrite `allocated_blocks` entries.
        let ptr_addr = {
            let block_guard = lock_safe!(block, "memory block lock for allocate_new_block")?;
            self.get_ptr_address(&block_guard)
        };
        self.allocated_blocks.insert(ptr_addr, block.clone());
        self.current_size += size;
        self.peak_usage = self.peak_usage.max(self.current_size);
        self.stats.total_allocations += 1;
        self.stats.total_bytes_allocated += size as u64;
        Ok(GpuAllocation::new(ptr_addr, size))
    }
    /// Performs the raw memory allocation for a new block.
    ///
    /// With CUDA this zero-fills `size` bytes on the device via the context's
    /// default stream; without CUDA it fabricates a host-side mock block.
    fn allocate_gpu_memory(&self, size: usize) -> Result<MemoryBlock> {
        #[cfg(cuda_available)]
        {
            if let Some(ref context) = self.context {
                let stream = context.default_stream();
                match stream.alloc_zeros::<u8>(size) {
                    Ok(ptr) => Ok(MemoryBlock {
                        ptr: Box::new(ptr),
                        size,
                        is_free: false,
                        allocation_info: Some(AllocationInfo {
                            size,
                            allocated_at: Instant::now(),
                            last_accessed: Instant::now(),
                            in_use: true,
                            ref_count: 1,
                        }),
                    }),
                    Err(e) => Err(Error::from(GpuError::DeviceError(format!(
                        "GPU memory allocation failed: {}",
                        e
                    )))),
                }
            } else {
                Err(Error::from(GpuError::DeviceError(
                    "CUDA context not available for memory allocation".to_string(),
                )))
            }
        }
        #[cfg(not(cuda_available))]
        {
            Ok(MemoryBlock {
                // Mock "device pointer"; never dereferenced. The handle
                // exposed to callers is the block's own heap address.
                ptr: size,
                size,
                is_free: false,
                allocation_info: Some(AllocationInfo {
                    size,
                    allocated_at: Instant::now(),
                    last_accessed: Instant::now(),
                    in_use: true,
                    ref_count: 1,
                }),
            })
        }
    }
    /// Derives the opaque handle used to key `allocated_blocks`.
    ///
    /// The heap address of the `MemoryBlock` inside its `Arc<Mutex<..>>` is
    /// unique and stable for the block's lifetime in both build modes. (The
    /// old non-CUDA path returned `block.ptr`, which equals the block size
    /// and therefore collided for equal-sized blocks.)
    fn get_ptr_address(&self, block: &MemoryBlock) -> usize {
        block as *const MemoryBlock as usize
    }
    /// Grows the pool by carving `additional_size` into fixed-size free
    /// chunks (16x the minimum allocation size each); a remainder smaller
    /// than one chunk is skipped.
    fn expand_pool(&mut self, additional_size: usize) -> Result<()> {
        if self.current_size + additional_size > self.config.max_size {
            return Err(Error::from(GpuError::DeviceError(
                "Cannot expand pool beyond maximum size".to_string(),
            )));
        }
        let chunk_size = self.config.min_allocation_size * 16;
        let num_chunks = additional_size / chunk_size;
        for _ in 0..num_chunks {
            let block = self.allocate_gpu_memory(chunk_size)?;
            // Re-tag the fresh block as free before parking it in the pool.
            let block = Arc::new(Mutex::new(MemoryBlock {
                is_free: true,
                allocation_info: None,
                ..block
            }));
            self.free_blocks
                .entry(chunk_size)
                .or_default()
                .push_back(block);
        }
        self.current_size += num_chunks * chunk_size;
        Ok(())
    }
    /// Rounds `size` up to the next multiple of the minimum allocation size;
    /// a zero-byte request still occupies one minimum-sized slot.
    fn align_size(&self, size: usize) -> usize {
        let min_size = self.config.min_allocation_size;
        (size.max(1) + min_size - 1) / min_size * min_size
    }
    /// Runs `cleanup` when the configured interval has elapsed. Failures are
    /// logged rather than propagated so this best-effort maintenance pass
    /// never makes `allocate` fail.
    fn maybe_cleanup(&mut self) {
        let now = Instant::now();
        if now.duration_since(self.last_cleanup).as_secs() >= self.config.cleanup_interval {
            if let Err(e) = self.cleanup() {
                log::warn!("Memory pool cleanup failed: {}", e);
            }
        }
    }
}
/// Opaque handle to a pooled GPU allocation, returned by
/// `GpuMemoryPool::allocate` and consumed by `deallocate`.
pub struct GpuAllocation {
/// Address used as the lookup key inside the owning pool.
ptr_address: usize,
/// Requested (aligned) size in bytes.
size: usize,
}
impl GpuAllocation {
fn new(ptr_address: usize, size: usize) -> Self {
Self { ptr_address, size }
}
pub fn ptr_address(&self) -> usize {
self.ptr_address
}
pub fn size(&self) -> usize {
self.size
}
#[cfg(cuda_available)]
pub fn as_device_ptr<T>(&self) -> *mut T {
self.ptr_address as *mut T
}
}
/// Lazily builds and caches one `GpuMemoryPool` per device id.
pub struct GlobalMemoryPoolManager {
// Device id -> shared pool; RwLock favours the read-mostly lookup path.
pools: RwLock<HashMap<i32, Arc<Mutex<GpuMemoryPool>>>>,
// Config applied to every pool created by this manager.
default_config: MemoryPoolConfig,
}
impl GlobalMemoryPoolManager {
    /// Creates a manager whose per-device pools are built lazily from `config`.
    pub fn new(config: MemoryPoolConfig) -> Self {
        Self {
            pools: RwLock::new(HashMap::new()),
            default_config: config,
        }
    }
    /// Returns the shared pool for `device_id`, creating it on first use.
    ///
    /// # Errors
    /// Fails when a lock is poisoned or pool construction fails.
    pub fn get_pool(&self, device_id: i32) -> Result<Arc<Mutex<GpuMemoryPool>>> {
        // Fast path: existing pool under the shared read lock.
        {
            let pools = read_lock_safe!(self.pools, "memory pool manager pools read")?;
            if let Some(pool) = pools.get(&device_id) {
                return Ok(pool.clone());
            }
        }
        // Slow path: take the write lock and re-check before creating. The
        // previous version built the pool outside any lock, so two racing
        // callers could each construct a pool for the same device and the
        // later insert silently replaced the earlier one, leaving two live
        // pools for one device.
        let mut pools = write_lock_safe!(self.pools, "memory pool manager pools write")?;
        if let Some(pool) = pools.get(&device_id) {
            return Ok(pool.clone());
        }
        let pool = Arc::new(Mutex::new(GpuMemoryPool::new(
            device_id,
            self.default_config.clone(),
        )?));
        pools.insert(device_id, pool.clone());
        Ok(pool)
    }
    /// Collects a statistics snapshot from every managed pool.
    pub fn get_all_stats(&self) -> Result<HashMap<i32, MemoryPoolStats>> {
        let pools = read_lock_safe!(self.pools, "memory pool manager pools read for stats")?;
        let mut stats = HashMap::new();
        for (&device_id, pool) in pools.iter() {
            let pool_guard = lock_safe!(pool, "memory pool lock for stats")?;
            stats.insert(device_id, pool_guard.get_stats());
        }
        Ok(stats)
    }
    /// Runs `cleanup` on every managed pool, stopping at the first failure.
    pub fn cleanup_all(&self) -> Result<()> {
        let pools = read_lock_safe!(self.pools, "memory pool manager pools read for cleanup")?;
        for pool in pools.values() {
            let mut pool_guard = lock_safe!(pool, "memory pool lock for cleanup")?;
            pool_guard.cleanup()?;
        }
        Ok(())
    }
}
// Process-wide singleton: one pool manager, built lazily on first access
// with the default configuration, shared by the free functions below.
lazy_static::lazy_static! {
    static ref GLOBAL_MEMORY_POOL: GlobalMemoryPoolManager =
        GlobalMemoryPoolManager::new(MemoryPoolConfig::default());
}
/// Accessor for the process-wide memory pool manager singleton.
pub fn get_memory_pool_manager() -> &'static GlobalMemoryPoolManager {
    &*GLOBAL_MEMORY_POOL
}
/// Convenience wrapper: allocates `size` bytes from the global pool of
/// `device_id`.
pub fn gpu_alloc(device_id: i32, size: usize) -> Result<GpuAllocation> {
    let manager = get_memory_pool_manager();
    let pool = manager.get_pool(device_id)?;
    let mut guard = lock_safe!(pool, "global memory pool lock for allocation")?;
    guard.allocate(size)
}
/// Convenience wrapper: returns `allocation` to the global pool of
/// `device_id`.
pub fn gpu_dealloc(device_id: i32, allocation: GpuAllocation) -> Result<()> {
    let manager = get_memory_pool_manager();
    let pool = manager.get_pool(device_id)?;
    let mut guard = lock_safe!(pool, "global memory pool lock for deallocation")?;
    guard.deallocate(allocation)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gpu::init_gpu;

    /// True when the GPU runtime reports an available device; tests skip
    /// themselves otherwise.
    fn is_gpu_available_for_testing() -> bool {
        init_gpu().map(|status| status.available).unwrap_or(false)
    }

    #[test]
    fn test_memory_pool_creation() {
        // Creation may legitimately fail on machines without GPU memory.
        match GpuMemoryPool::new(0, MemoryPoolConfig::default()) {
            Ok(_) => {}
            Err(e) => println!(
                "Pool creation failed (expected if no GPU memory available): {:?}",
                Some(e)
            ),
        }
    }

    #[test]
    fn test_memory_allocation() {
        if !is_gpu_available_for_testing() {
            println!("Skipping test_memory_allocation - no GPU available");
            return;
        }
        let config = MemoryPoolConfig {
            initial_size: 1024 * 1024,
            ..MemoryPoolConfig::default()
        };
        let mut pool = match GpuMemoryPool::new(0, config) {
            Ok(p) => p,
            Err(_) => {
                println!("Skipping test_memory_allocation - pool creation failed");
                return;
            }
        };
        // Both requests round up to the 4 KiB alignment granularity.
        let first = pool.allocate(1024).expect("operation should succeed");
        assert_eq!(first.size(), 4096);
        let second = pool.allocate(2048).expect("operation should succeed");
        assert_eq!(second.size(), 4096);
        pool.deallocate(first).expect("operation should succeed");
        pool.deallocate(second).expect("operation should succeed");
        let stats = pool.get_stats();
        assert_eq!(stats.total_allocations, 2);
        assert_eq!(stats.total_deallocations, 2);
    }

    #[test]
    fn test_memory_pool_stats() {
        if !is_gpu_available_for_testing() {
            println!("Skipping test_memory_pool_stats - no GPU available");
            return;
        }
        let mut pool = match GpuMemoryPool::new(0, MemoryPoolConfig::default()) {
            Ok(p) => p,
            Err(_) => {
                println!("Skipping test_memory_pool_stats - pool creation failed");
                return;
            }
        };
        let alloc = pool.allocate(1024).expect("operation should succeed");
        let before = pool.get_stats();
        assert_eq!(before.total_allocations, 1);
        assert!(before.avg_allocation_size > 0.0);
        pool.deallocate(alloc).expect("operation should succeed");
        let after = pool.get_stats();
        assert_eq!(after.total_deallocations, 1);
    }

    #[test]
    fn test_global_memory_pool() {
        if !is_gpu_available_for_testing() {
            println!("Skipping test_global_memory_pool - no GPU available");
            return;
        }
        let manager = get_memory_pool_manager();
        let first = match manager.get_pool(0) {
            Ok(p) => p,
            Err(_) => {
                println!("Skipping test_global_memory_pool - pool retrieval failed");
                return;
            }
        };
        // Repeated lookups for the same device must yield the same pool.
        let second = manager.get_pool(0).expect("operation should succeed");
        assert_eq!(Arc::as_ptr(&first), Arc::as_ptr(&second));
    }
}