use super::GpuBuffer;
use anyhow::Result;
use std::sync::{Arc, Mutex};
/// A cache of reusable [`GpuBuffer`]s for one device, together with allocation statistics.
///
/// NOTE(review): the counters are wrapped in `Arc<Mutex<_>>`, but nothing visible in this
/// file clones those `Arc`s, so whether they are shared elsewhere is unconfirmed.
pub struct GpuMemoryPool {
    /// Device index handed to `GpuBuffer::new` for every fresh allocation.
    device_id: i32,
    /// Capacity value supplied at construction.
    /// NOTE(review): its unit (bytes vs. number of buffers) is not established here.
    pool_size: usize,

    /// Buffers handed back through `return_buffer`, available for reuse.
    available_buffers: Arc<Mutex<Vec<GpuBuffer>>>,

    /// Sum of the sizes requested for freshly allocated buffers since creation
    /// or the last `clear()`.
    total_allocated: Arc<Mutex<usize>>,
    /// Largest value `total_allocated` has reached.
    peak_usage: Arc<Mutex<usize>>,

    /// Usage counter reported by `statistics()`; reset to 0 by `clear()`.
    current_usage: usize,
    /// Failure counter reported by `statistics()`.
    allocation_failures: u64,
}
impl GpuMemoryPool {
    /// Creates an empty pool that allocates its buffers on device `device_id`.
    ///
    /// Nothing is allocated up front. Buffers are created lazily by
    /// [`GpuMemoryPool::get_buffer`] and cached when handed back through
    /// [`GpuMemoryPool::return_buffer`].
    ///
    /// # Errors
    ///
    /// Never fails at present. The `Result` is kept so callers' code stays unchanged.
    pub fn new(pool_size: usize, device_id: i32) -> Result<Self> {
        Ok(Self {
            device_id,
            pool_size,
            available_buffers: Arc::new(Mutex::new(Vec::new())),
            total_allocated: Arc::new(Mutex::new(0)),
            peak_usage: Arc::new(Mutex::new(0)),
            current_usage: 0,
            allocation_failures: 0,
        })
    }

    /// Returns the `pool_size` value this pool was constructed with.
    ///
    /// NOTE(review): the pool does not enforce this value anywhere yet. Its unit
    /// (bytes vs. number of buffers) is not established here; confirm before relying on it.
    pub fn pool_size(&self) -> usize {
        self.pool_size
    }

    /// Returns a buffer whose size is at least `size` (in the units `GpuBuffer::new` takes).
    ///
    /// If a cached buffer is large enough, the smallest such buffer is reused, which keeps
    /// big buffers available for big requests. Otherwise a new buffer is allocated on this
    /// pool's device, and `size` is added to `total_allocated` (and `peak_usage` is raised
    /// if needed). In both cases the handed-out buffer's size is added to `current_usage`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `GpuBuffer::new`, after incrementing `allocation_failures`.
    pub fn get_buffer(&mut self, size: usize) -> Result<GpuBuffer> {
        if let Some(buffer) = self.take_cached_buffer(size) {
            self.current_usage += buffer.size();
            return Ok(buffer);
        }

        let allocated = GpuBuffer::new(size, self.device_id);
        if allocated.is_err() {
            // Previously this counter was never touched, so statistics always showed 0.
            self.allocation_failures += 1;
        }
        let buffer = allocated?;

        {
            let mut total = self
                .total_allocated
                .lock()
                .expect("mutex should not be poisoned");
            *total += size;
            let mut peak = self
                .peak_usage
                .lock()
                .expect("mutex should not be poisoned");
            // NOTE(review): this is the high-water mark of `total_allocated`, not of
            // `current_usage`. Confirm which one callers expect.
            *peak = (*peak).max(*total);
        }

        self.current_usage += buffer.size();
        Ok(buffer)
    }

    /// Removes and returns the smallest cached buffer whose size is at least `size`, if any.
    fn take_cached_buffer(&self, size: usize) -> Option<GpuBuffer> {
        let mut buffers = self
            .available_buffers
            .lock()
            .expect("mutex should not be poisoned");
        let index = buffers
            .iter()
            .enumerate()
            .filter(|(_, buffer)| buffer.size() >= size)
            .min_by_key(|(_, buffer)| buffer.size())
            .map(|(index, _)| index)?;
        // Selection is by size, so cache order is irrelevant and O(1) swap_remove is fine.
        Some(buffers.swap_remove(index))
    }

    /// Hands `buffer` back to the pool so a later `get_buffer` call can reuse it.
    ///
    /// The buffer's size is subtracted from `current_usage`.
    pub fn return_buffer(&mut self, buffer: GpuBuffer) {
        // Saturating because `clear()` zeroes `current_usage` while buffers may still be
        // checked out, and returning one of them afterwards must not underflow.
        self.current_usage = self.current_usage.saturating_sub(buffer.size());
        self.available_buffers
            .lock()
            .expect("mutex should not be poisoned")
            .push(buffer);
    }

    /// Takes a snapshot of the pool's counters and the number of cached buffers.
    pub fn statistics(&self) -> PoolStatistics {
        let total = *self
            .total_allocated
            .lock()
            .expect("mutex should not be poisoned");
        let peak = *self
            .peak_usage
            .lock()
            .expect("mutex should not be poisoned");
        let available_count = self
            .available_buffers
            .lock()
            .expect("mutex should not be poisoned")
            .len();
        PoolStatistics {
            total_allocated: total,
            peak_usage: peak,
            current_usage: self.current_usage,
            available_buffers: available_count,
            allocation_failures: self.allocation_failures,
        }
    }

    /// Drops every cached buffer and resets `total_allocated` and `current_usage` to 0.
    ///
    /// `peak_usage` and `allocation_failures` are left untouched. Buffers that are
    /// currently checked out are not affected and may still be returned later.
    pub fn clear(&mut self) {
        let mut buffers = self
            .available_buffers
            .lock()
            .expect("mutex should not be poisoned");
        buffers.clear();
        let mut total = self
            .total_allocated
            .lock()
            .expect("mutex should not be poisoned");
        *total = 0;
        self.current_usage = 0;
    }
}
/// Point-in-time snapshot of a memory pool's counters.
///
/// This is plain integer data, so it is `Copy` and comparable. That makes it easy to keep
/// a snapshot and diff it against a later one. `Default` is the all-zero snapshot.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct PoolStatistics {
    /// Sum of the sizes requested for freshly allocated buffers.
    pub total_allocated: usize,
    /// High-water mark reported by the pool.
    pub peak_usage: usize,
    /// Value of the pool's current-usage counter at snapshot time.
    pub current_usage: usize,
    /// Number of buffers cached for reuse at snapshot time.
    pub available_buffers: usize,
    /// Value of the pool's allocation-failure counter at snapshot time.
    pub allocation_failures: u64,
}