use super::{GpuBuffer, GpuConfig};
use anyhow::{anyhow, Result};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// Pooled allocator for fixed-size GPU buffers on a single device.
///
/// Idle buffers are recycled through `available_buffers` instead of being
/// freed; the remaining fields carry profiling and leak-detection
/// bookkeeping maintained by the methods in this module.
#[derive(Debug)]
pub struct GpuMemoryPool {
    /// Device ordinal this pool allocates on (taken from `GpuConfig`).
    device_id: i32,
    /// Idle buffers awaiting reuse, handed out FIFO.
    available_buffers: Arc<Mutex<VecDeque<GpuBuffer>>>,
    /// NOTE(review): nothing in this file ever pushes into this Vec —
    /// ownership of a checked-out buffer moves to the caller — so it stays
    /// empty in practice; confirm its intended role before relying on it.
    allocated_buffers: Arc<Mutex<Vec<GpuBuffer>>>,
    /// Configured byte budget for the whole pool.
    total_memory: usize,
    /// Bytes committed to buffers created so far (idle + checked out).
    used_memory: usize,
    /// Capacity of each buffer, in f32 elements (not bytes).
    buffer_size: usize,
    /// Maximum number of buffers this pool may ever create.
    max_buffers: usize,
    /// (buffer pointer, checkout time) pairs for buffers currently held by
    /// callers; consumed by `detect_leaks`.
    allocation_times: Arc<Mutex<Vec<(usize, Instant)>>>,
    /// Per-operation duration samples keyed by operation name.
    operation_timings: Arc<Mutex<HashMap<String, Vec<Duration>>>>,
    /// Count of fresh GPU allocations performed.
    allocation_count: usize,
    /// Count of buffers returned to the pool.
    deallocation_count: usize,
    /// High-water mark of `used_memory`.
    peak_memory_usage: usize,
}
impl GpuMemoryPool {
    /// Creates an empty pool for buffers of `buffer_size` f32 elements on
    /// the device named by `config`.
    ///
    /// The buffer cap (`max_buffers`) is the configured byte budget divided
    /// by the byte size of one buffer.
    ///
    /// # Errors
    /// Fails when `buffer_size` is zero (the original divided by zero and
    /// panicked here instead of returning an error).
    pub fn new(config: &GpuConfig, buffer_size: usize) -> Result<Self> {
        if buffer_size == 0 {
            return Err(anyhow!("buffer_size must be non-zero"));
        }
        let max_buffers = config.memory_pool_size / (buffer_size * std::mem::size_of::<f32>());
        Ok(Self {
            device_id: config.device_id,
            available_buffers: Arc::new(Mutex::new(VecDeque::new())),
            allocated_buffers: Arc::new(Mutex::new(Vec::new())),
            total_memory: config.memory_pool_size,
            used_memory: 0,
            buffer_size,
            max_buffers,
            allocation_times: Arc::new(Mutex::new(Vec::new())),
            operation_timings: Arc::new(Mutex::new(HashMap::new())),
            allocation_count: 0,
            deallocation_count: 0,
            peak_memory_usage: 0,
        })
    }

    /// Number of buffers currently checked out by callers.
    ///
    /// `allocation_times` gains an entry on every checkout and loses it on
    /// return, so its length is the outstanding-buffer count.
    /// Panics if the lock is poisoned (matches the accessors that call it).
    fn outstanding_count(&self) -> usize {
        self.allocation_times.lock().expect("lock poisoned").len()
    }

    /// Hands out a buffer: reuses an idle pooled one when possible,
    /// otherwise performs a fresh GPU allocation.
    ///
    /// # Errors
    /// Fails when `max_buffers` buffers are already checked out, when a
    /// lock is poisoned, or when the underlying allocation fails.
    pub fn get_buffer(&mut self) -> Result<GpuBuffer> {
        let start_time = Instant::now();
        // Fast path: recycle an idle buffer.
        {
            let mut available = self
                .available_buffers
                .lock()
                .map_err(|e| anyhow!("Failed to lock available buffers: {}", e))?;
            if let Some(buffer) = available.pop_front() {
                self.record_operation_time("buffer_acquire_reuse", start_time.elapsed());
                // Record the checkout so leak detection can see it.
                let ptr_value = buffer.ptr() as usize;
                self.allocation_times
                    .lock()
                    .map_err(|e| anyhow!("Failed to lock allocation times: {}", e))?
                    .push((ptr_value, Instant::now()));
                return Ok(buffer);
            }
        }
        // BUG FIX: the original gated new allocations on
        // `allocated_buffers.len()`, but nothing ever populates that Vec
        // (ownership of a checked-out buffer moves to the caller), so the
        // exhaustion check could never fire. Count outstanding buffers via
        // `allocation_times` instead. `available` is empty at this point
        // (pop_front just failed), so outstanding == total existing buffers.
        let outstanding = self
            .allocation_times
            .lock()
            .map_err(|e| anyhow!("Failed to lock allocation times: {}", e))?
            .len();
        if outstanding >= self.max_buffers {
            self.record_operation_time("buffer_acquire_failed", start_time.elapsed());
            return Err(anyhow!("Memory pool exhausted"));
        }
        let alloc_start = Instant::now();
        let buffer = GpuBuffer::new(self.buffer_size, self.device_id)?;
        self.record_operation_time("buffer_alloc", alloc_start.elapsed());
        self.used_memory += self.buffer_size * std::mem::size_of::<f32>();
        self.allocation_count += 1;
        self.peak_memory_usage = self.peak_memory_usage.max(self.used_memory);
        let ptr_value = buffer.ptr() as usize;
        self.allocation_times
            .lock()
            .map_err(|e| anyhow!("Failed to lock allocation times: {}", e))?
            .push((ptr_value, Instant::now()));
        self.record_operation_time("buffer_acquire_new", start_time.elapsed());
        Ok(buffer)
    }

    /// Appends one duration sample for `operation`.
    ///
    /// Best-effort: a poisoned lock silently drops the sample rather than
    /// failing the caller's real work.
    fn record_operation_time(&self, operation: &str, duration: Duration) {
        if let Ok(mut timings) = self.operation_timings.lock() {
            timings.entry(operation.to_string()).or_default().push(duration);
        }
    }

    /// Returns `buffer` to the idle pool for reuse.
    ///
    /// `used_memory` is intentionally not decreased: the buffer stays alive
    /// inside the pool, so its GPU memory is still committed.
    ///
    /// # Errors
    /// Fails only when one of the internal locks is poisoned.
    pub fn return_buffer(&mut self, buffer: GpuBuffer) -> Result<()> {
        let start_time = Instant::now();
        let ptr_value = buffer.ptr() as usize;
        // Defensive: drop any stale tracking entries for this pointer.
        {
            let mut allocated = self
                .allocated_buffers
                .lock()
                .map_err(|e| anyhow!("Failed to lock allocated buffers: {}", e))?;
            allocated.retain(|b| b.ptr() != buffer.ptr());
        }
        {
            let mut alloc_times = self
                .allocation_times
                .lock()
                .map_err(|e| anyhow!("Failed to lock allocation times: {}", e))?;
            alloc_times.retain(|(ptr, _)| *ptr != ptr_value);
        }
        self.deallocation_count += 1;
        self.available_buffers
            .lock()
            .map_err(|e| anyhow!("Failed to lock available buffers: {}", e))?
            .push_back(buffer);
        self.record_operation_time("buffer_return", start_time.elapsed());
        Ok(())
    }

    /// Point-in-time occupancy snapshot.
    ///
    /// BUG FIX: the checked-out count is derived from `allocation_times`;
    /// the original read `allocated_buffers`, which is never populated and
    /// therefore always reported zero allocated buffers.
    pub fn stats(&self) -> MemoryPoolStats {
        let allocated_count = self.outstanding_count();
        let available_count = self.available_buffers.lock().expect("lock poisoned").len();
        MemoryPoolStats {
            total_buffers: allocated_count + available_count,
            allocated_buffers: allocated_count,
            available_buffers: available_count,
            total_memory: self.total_memory,
            used_memory: self.used_memory,
            buffer_size: self.buffer_size,
            utilization: self.utilization(),
        }
    }

    /// Eagerly allocates up to `count` idle buffers.
    ///
    /// The request is clamped so existing buffers (idle + checked out) plus
    /// the new ones never exceed `max_buffers`; the original clamped against
    /// `max_buffers` alone and could overshoot the cap on repeated calls.
    ///
    /// # Errors
    /// Fails when a lock is poisoned or a GPU allocation fails; buffers
    /// created before the failure remain in the pool.
    pub fn preallocate(&mut self, count: usize) -> Result<()> {
        let existing = self
            .available_buffers
            .lock()
            .map_err(|e| anyhow!("Failed to lock available buffers: {}", e))?
            .len()
            + self.outstanding_count();
        let effective_count = count.min(self.max_buffers.saturating_sub(existing));
        for _ in 0..effective_count {
            let buffer = GpuBuffer::new(self.buffer_size, self.device_id)?;
            self.used_memory += self.buffer_size * std::mem::size_of::<f32>();
            // BUG FIX: the original skipped peak tracking on this path.
            self.peak_memory_usage = self.peak_memory_usage.max(self.used_memory);
            self.available_buffers
                .lock()
                .map_err(|e| anyhow!("Failed to lock available buffers: {}", e))?
                .push_back(buffer);
        }
        Ok(())
    }

    /// Drops every tracked buffer and resets usage accounting.
    pub fn clear(&mut self) {
        self.available_buffers.lock().expect("lock poisoned").clear();
        self.allocated_buffers.lock().expect("lock poisoned").clear();
        // BUG FIX: also forget checkout timestamps, otherwise detect_leaks()
        // keeps reporting buffers that no longer exist.
        self.allocation_times.lock().expect("lock poisoned").clear();
        self.used_memory = 0;
    }

    /// True when the pool may still create another buffer.
    ///
    /// BUG FIX: counts checked-out buffers via `allocation_times` instead of
    /// the never-populated `allocated_buffers`.
    pub fn has_capacity(&self) -> bool {
        let idle = self.available_buffers.lock().expect("lock poisoned").len();
        idle + self.outstanding_count() < self.max_buffers
    }

    /// Bytes currently committed by this pool (idle + checked out).
    pub fn memory_usage(&self) -> usize {
        self.used_memory
    }

    /// Fraction of the configured byte budget in use (0.0 when the budget
    /// is zero).
    pub fn utilization(&self) -> f64 {
        if self.total_memory > 0 {
            self.used_memory as f64 / self.total_memory as f64
        } else {
            0.0
        }
    }

    /// Orders the idle buffers by device address.
    ///
    /// NOTE(review): this only reorders the queue — it cannot compact GPU
    /// memory — so the benefit is limited to handing out address-adjacent
    /// buffers first.
    ///
    /// # Errors
    /// Fails when the idle-buffer lock is poisoned.
    pub fn defragment(&mut self) -> Result<()> {
        let start_time = Instant::now();
        {
            let mut available = self
                .available_buffers
                .lock()
                .map_err(|e| anyhow!("Failed to lock available buffers: {}", e))?;
            // Sort in place; cheaper than draining into a Vec and refilling.
            available
                .make_contiguous()
                .sort_unstable_by_key(|b| b.ptr() as usize);
        }
        self.record_operation_time("pool_defrag", start_time.elapsed());
        Ok(())
    }

    /// Reports buffers checked out for longer than `threshold_secs` seconds.
    pub fn detect_leaks(&self, threshold_secs: u64) -> Vec<MemoryLeak> {
        let now = Instant::now();
        let alloc_times = self.allocation_times.lock().expect("lock poisoned");
        alloc_times
            .iter()
            .filter_map(|&(ptr, alloc_time)| {
                let held = now.duration_since(alloc_time);
                (held.as_secs() > threshold_secs).then(|| MemoryLeak {
                    ptr_address: ptr,
                    allocated_for_secs: held.as_secs(),
                    buffer_size: self.buffer_size,
                })
            })
            .collect()
    }

    /// Human-readable summary of every recorded operation timing.
    pub fn profiling_report(&self) -> String {
        let timings = self.operation_timings.lock().expect("lock poisoned");
        let mut report = String::from("GPU Memory Pool Performance Report:\n");
        for (operation, durations) in timings.iter() {
            if durations.is_empty() {
                continue;
            }
            let total: Duration = durations.iter().sum();
            let avg = total / durations.len() as u32;
            let min = durations.iter().min().expect("non-empty durations");
            let max = durations.iter().max().expect("non-empty durations");
            // BUG FIX: `{:.2}` precision is ignored for the integers that
            // `as_micros()` returns; convert to f64 so it actually applies.
            report.push_str(&format!(
                " {}: {} calls, avg={:.2}µs, min={:.2}µs, max={:.2}µs\n",
                operation,
                durations.len(),
                avg.as_secs_f64() * 1e6,
                min.as_secs_f64() * 1e6,
                max.as_secs_f64() * 1e6
            ));
        }
        report
    }

    /// Lifetime allocation counters plus derived efficiency.
    ///
    /// `memory_efficiency` is returns-per-allocation; values well below 1.0
    /// suggest buffers are being held rather than returned.
    pub fn get_metrics(&self) -> PoolMetrics {
        PoolMetrics {
            allocation_count: self.allocation_count,
            deallocation_count: self.deallocation_count,
            peak_memory_usage: self.peak_memory_usage,
            current_memory_usage: self.used_memory,
            memory_efficiency: if self.allocation_count > 0 {
                self.deallocation_count as f64 / self.allocation_count as f64
            } else {
                0.0
            },
            active_allocations: self.outstanding_count(),
        }
    }

    /// Heuristic resize hint: halve when buffers are recycled well but the
    /// budget is mostly idle, double when recycling is poor and the budget
    /// is nearly full, otherwise keep the current size.
    pub fn suggest_optimal_buffer_size(&self) -> usize {
        let metrics = self.get_metrics();
        if metrics.memory_efficiency > 0.95 && self.utilization() < 0.5 {
            // Never suggest a zero-element buffer (fix for buffer_size == 1).
            (self.buffer_size / 2).max(1)
        } else if metrics.memory_efficiency < 0.7 && self.utilization() > 0.8 {
            self.buffer_size * 2
        } else {
            self.buffer_size
        }
    }

    /// Discards all recorded operation timings.
    pub fn reset_profiling(&mut self) {
        if let Ok(mut timings) = self.operation_timings.lock() {
            timings.clear();
        }
    }

    /// Average duration of `operation` in microseconds, or `None` when the
    /// operation has no samples (or the timing lock is poisoned).
    pub fn get_avg_operation_time(&self, operation: &str) -> Option<f64> {
        let timings = self.operation_timings.lock().ok()?;
        let durations = timings.get(operation)?;
        if durations.is_empty() {
            return None;
        }
        let total: Duration = durations.iter().sum();
        let avg = total / durations.len() as u32;
        // Keep sub-microsecond precision instead of truncating as_micros().
        Some(avg.as_secs_f64() * 1e6)
    }
}
/// Report entry for a buffer held past the leak-detection threshold.
#[derive(Debug, Clone)]
pub struct MemoryLeak {
    /// Device pointer of the suspect buffer.
    pub ptr_address: usize,
    /// How long the buffer has been checked out, in whole seconds.
    pub allocated_for_secs: u64,
    /// Size of the suspect buffer (same units as the owning pool).
    pub buffer_size: usize,
}

impl MemoryLeak {
    /// Renders the leak as a single human-readable sentence.
    pub fn description(&self) -> String {
        format!(
            "Memory leak at 0x{addr:x}: {size} bytes held for {secs} seconds",
            addr = self.ptr_address,
            size = self.buffer_size,
            secs = self.allocated_for_secs,
        )
    }
}
/// Lifetime counters exported by a pool for health monitoring.
#[derive(Debug, Clone)]
pub struct PoolMetrics {
    /// Fresh allocations performed by the pool.
    pub allocation_count: usize,
    /// Buffers returned to the pool.
    pub deallocation_count: usize,
    /// High-water mark of committed bytes.
    pub peak_memory_usage: usize,
    /// Bytes currently committed.
    pub current_memory_usage: usize,
    /// Returns-per-allocation ratio (0.0 when nothing was allocated).
    pub memory_efficiency: f64,
    /// Buffers currently checked out.
    pub active_allocations: usize,
}

impl PoolMetrics {
    /// Heuristic leak flag: many buffers outstanding while fewer than half
    /// of all allocations have ever been returned.
    pub fn has_potential_leak(&self) -> bool {
        self.active_allocations > 100 && self.memory_efficiency < 0.5
    }

    /// Multi-line summary of the counters, memory in MB.
    pub fn report(&self) -> String {
        let mut out = String::from("Pool Metrics:\n");
        out.push_str(&format!("- Allocations: {}\n", self.allocation_count));
        out.push_str(&format!("- Deallocations: {}\n", self.deallocation_count));
        out.push_str(&format!("- Active: {}\n", self.active_allocations));
        out.push_str(&format!(
            "- Peak memory: {:.2} MB\n",
            self.peak_memory_usage as f64 / 1024.0 / 1024.0
        ));
        out.push_str(&format!(
            "- Current memory: {:.2} MB\n",
            self.current_memory_usage as f64 / 1024.0 / 1024.0
        ));
        out.push_str(&format!(
            "- Efficiency: {:.1}%",
            self.memory_efficiency * 100.0
        ));
        out
    }
}
/// Point-in-time occupancy snapshot of a single memory pool.
#[derive(Debug, Clone)]
pub struct MemoryPoolStats {
    /// Idle plus checked-out buffers.
    pub total_buffers: usize,
    /// Buffers currently checked out.
    pub allocated_buffers: usize,
    /// Idle buffers ready for reuse.
    pub available_buffers: usize,
    /// Configured byte budget.
    pub total_memory: usize,
    /// Bytes committed so far.
    pub used_memory: usize,
    /// Capacity of each buffer, in f32 elements.
    pub buffer_size: usize,
    /// used_memory / total_memory, precomputed by the pool.
    pub utilization: f64,
}

impl MemoryPoolStats {
    /// True when the pool is close to exhaustion: fewer than two idle
    /// buffers left, or utilization above 80%.
    pub fn is_under_pressure(&self) -> bool {
        self.available_buffers < 2 || self.utilization > 0.8
    }

    /// How many more buffers fit in the unused part of the byte budget.
    pub fn remaining_capacity(&self) -> usize {
        if self.used_memory >= self.total_memory {
            return 0;
        }
        let spare_bytes = self.total_memory - self.used_memory;
        spare_bytes / (self.buffer_size * std::mem::size_of::<f32>())
    }

    /// Dumps the snapshot to stdout, flagging pressure when detected.
    pub fn print(&self) {
        println!("GPU Memory Pool Statistics:");
        println!(" Total buffers: {}", self.total_buffers);
        println!(
            " Allocated: {}, Available: {}",
            self.allocated_buffers, self.available_buffers
        );
        let used_mb = self.used_memory as f64 / 1024.0 / 1024.0;
        let total_mb = self.total_memory as f64 / 1024.0 / 1024.0;
        println!(
            " Memory usage: {:.2} MB / {:.2} MB ({:.1}%)",
            used_mb,
            total_mb,
            self.utilization * 100.0
        );
        // buffer_size counts f32 elements, hence the ×4 bytes-per-element.
        println!(
            " Buffer size: {:.2} KB",
            self.buffer_size as f64 * 4.0 / 1024.0
        );
        println!(
            " Remaining capacity: {} buffers",
            self.remaining_capacity()
        );
        if self.is_under_pressure() {
            println!(" ⚠️ Memory pool is under pressure!");
        }
    }
}
/// Multi-size pool: one `GpuMemoryPool` per entry in `buffer_sizes`, all on
/// the same device. Requests are routed to a pool whose size class can hold
/// them.
#[derive(Debug)]
pub struct AdvancedGpuMemoryPool {
    /// One pool per size class; index-aligned with `buffer_sizes`.
    pools: Vec<GpuMemoryPool>,
    /// Size class (in f32 elements) of the pool at the same index.
    buffer_sizes: Vec<usize>,
    /// NOTE(review): stored at construction but not read by any method in
    /// this file — confirm it is needed before removing.
    device_id: i32,
}
impl AdvancedGpuMemoryPool {
    /// Builds one sub-pool per requested size class, all on the device named
    /// by `config`.
    ///
    /// # Errors
    /// Propagates any sub-pool construction failure.
    pub fn new(config: &GpuConfig, buffer_sizes: Vec<usize>) -> Result<Self> {
        let mut pools = Vec::with_capacity(buffer_sizes.len());
        for &size in &buffer_sizes {
            pools.push(GpuMemoryPool::new(config, size)?);
        }
        Ok(Self {
            pools,
            // Move the Vec instead of the original's redundant clone; the
            // borrow taken by the loop above has ended.
            buffer_sizes,
            device_id: config.device_id,
        })
    }

    /// Checks out a buffer able to hold `required_size` elements.
    ///
    /// BUG FIX: the original took the *first* size >= required, which is the
    /// best fit only when `buffer_sizes` happens to be sorted ascending.
    /// Select the smallest adequate size class regardless of ordering.
    ///
    /// # Errors
    /// Fails when no size class is large enough, or the chosen pool fails.
    pub fn get_buffer(&mut self, required_size: usize) -> Result<GpuBuffer> {
        let pool_index = self
            .buffer_sizes
            .iter()
            .enumerate()
            .filter(|&(_, &size)| size >= required_size)
            .min_by_key(|&(_, &size)| size)
            .map(|(index, _)| index)
            .ok_or_else(|| anyhow!("No buffer size large enough for request"))?;
        self.pools[pool_index].get_buffer()
    }

    /// Routes `buffer` back to the pool whose size class matches exactly.
    ///
    /// # Errors
    /// Fails when the buffer's size matches no configured size class, or
    /// when the target pool's return fails.
    pub fn return_buffer(&mut self, buffer: GpuBuffer) -> Result<()> {
        let buffer_size = buffer.size();
        let pool_index = self
            .buffer_sizes
            .iter()
            .position(|&size| size == buffer_size)
            .ok_or_else(|| anyhow!("Buffer size does not match any pool"))?;
        self.pools[pool_index].return_buffer(buffer)
    }

    /// Aggregates every sub-pool's stats plus per-pool snapshots.
    pub fn combined_stats(&self) -> AdvancedMemoryPoolStats {
        let mut total_buffers = 0;
        let mut total_allocated = 0;
        let mut total_available = 0;
        let mut total_memory = 0;
        let mut total_used = 0;
        let mut pool_stats = Vec::with_capacity(self.pools.len());
        for pool in &self.pools {
            let stats = pool.stats();
            total_buffers += stats.total_buffers;
            total_allocated += stats.allocated_buffers;
            total_available += stats.available_buffers;
            total_memory += stats.total_memory;
            total_used += stats.used_memory;
            pool_stats.push(stats);
        }
        AdvancedMemoryPoolStats {
            pool_stats,
            total_buffers,
            total_allocated,
            total_available,
            total_memory,
            total_used,
            utilization: if total_memory > 0 {
                total_used as f64 / total_memory as f64
            } else {
                0.0
            },
        }
    }

    /// Eagerly fills every sub-pool with up to `buffers_per_pool` buffers.
    ///
    /// # Errors
    /// Propagates the first sub-pool failure; earlier pools stay filled.
    pub fn preallocate_all(&mut self, buffers_per_pool: usize) -> Result<()> {
        for pool in &mut self.pools {
            pool.preallocate(buffers_per_pool)?;
        }
        Ok(())
    }
}
/// Aggregate view across every size-class pool, with per-pool breakdowns.
#[derive(Debug, Clone)]
pub struct AdvancedMemoryPoolStats {
    /// One snapshot per size-class pool, in pool order.
    pub pool_stats: Vec<MemoryPoolStats>,
    /// Sum of each pool's total buffer count.
    pub total_buffers: usize,
    /// Sum of checked-out buffers across pools.
    pub total_allocated: usize,
    /// Sum of idle buffers across pools.
    pub total_available: usize,
    /// Sum of the pools' byte budgets.
    pub total_memory: usize,
    /// Sum of committed bytes across pools.
    pub total_used: usize,
    /// total_used / total_memory (0.0 when the combined budget is zero).
    pub utilization: f64,
}

impl AdvancedMemoryPoolStats {
    /// Prints an overall summary followed by one line per size-class pool.
    pub fn print_detailed(&self) {
        println!("Advanced GPU Memory Pool Statistics:");
        let overall_pct = self.utilization * 100.0;
        println!(
            " Overall: {} buffers, {:.1}% utilization",
            self.total_buffers, overall_pct
        );
        let total_mb = self.total_memory as f64 / 1024.0 / 1024.0;
        println!(" Total memory: {:.2} MB", total_mb);
        for (pool_idx, pool) in self.pool_stats.iter().enumerate() {
            // buffer_size counts f32 elements, hence the ×4 bytes-per-element.
            let buffer_kb = pool.buffer_size as f64 * 4.0 / 1024.0;
            println!(
                " Pool {}: {:.2} KB buffers, {} total, {:.1}% util",
                pool_idx,
                buffer_kb,
                pool.total_buffers,
                pool.utilization * 100.0
            );
        }
    }
}