use crate::{error::AutogradError, Float, Result};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Size-bucketed pool of reusable element buffers with a hard byte budget.
///
/// Buffers are keyed by their element count: a buffer handed back through
/// [`GpuMemoryPool::deallocate`] is zeroed and reused by the next
/// [`GpuMemoryPool::allocate`] request of the same length. The buffers themselves
/// are plain `Vec<T>` values.
pub struct GpuMemoryPool<T: Float> {
    /// Idle buffers waiting for reuse, keyed by element count.
    available: Arc<Mutex<HashMap<usize, Vec<Vec<T>>>>>,
    /// Bytes freshly allocated by `allocate`, minus idle buffers dropped by `clear`.
    /// Both handed-out and idle pooled buffers count toward this total.
    total_allocated: Arc<Mutex<usize>>,
    /// Upper bound on `total_allocated`, in bytes.
    max_memory: usize,
    /// Running counters describing pool behaviour.
    stats: Arc<Mutex<PoolStatistics>>,
}

/// Counters collected by [`GpuMemoryPool`].
#[derive(Debug, Clone, Default)]
pub struct PoolStatistics {
    /// Number of `allocate` calls, including calls that failed the budget check.
    pub allocations: usize,
    /// Number of `deallocate` calls.
    pub deallocations: usize,
    /// Allocations satisfied by reusing an idle pooled buffer.
    pub pool_hits: usize,
    /// Allocations that found no idle buffer of the requested length.
    pub pool_misses: usize,
    /// Highest value the pool's allocated-byte total has reached.
    pub peak_memory: usize,
}

impl<T: Float> GpuMemoryPool<T> {
    /// Creates an empty pool that will allocate at most `max_memory` bytes of fresh buffers.
    pub fn new(max_memory: usize) -> Self {
        Self {
            available: Arc::new(Mutex::new(HashMap::new())),
            total_allocated: Arc::new(Mutex::new(0)),
            max_memory,
            stats: Arc::new(Mutex::new(PoolStatistics::default())),
        }
    }

    /// Returns a zero-filled buffer of `size` elements.
    ///
    /// An idle pooled buffer of exactly `size` elements is reused when available.
    /// Otherwise a new buffer is allocated and its bytes are charged against
    /// `max_memory`.
    ///
    /// # Errors
    ///
    /// Returns a memory error in two cases: the new buffer would push the pool past
    /// `max_memory`, or its byte size does not fit in `usize`. Returns an internal
    /// error if one of the pool's mutexes is poisoned.
    pub fn allocate(&self, size: usize) -> Result<Vec<T>> {
        // Lock order in this type is always stats -> available -> total_allocated.
        let mut stats = self
            .stats
            .lock()
            .map_err(|_| AutogradError::internal_error("Failed to lock statistics"))?;
        stats.allocations += 1;
        let available = self
            .available
            .lock()
            .map_err(|_| AutogradError::internal_error("Failed to lock available buffers"));
        let mut available = available?;
        if let Some(buffer) = available.get_mut(&size).and_then(|bucket| bucket.pop()) {
            stats.pool_hits += 1;
            return Ok(buffer);
        }
        stats.pool_misses += 1;
        let mut total = self
            .total_allocated
            .lock()
            .map_err(|_| AutogradError::internal_error("Failed to lock total allocated"))?;
        // Checked arithmetic: a wrapped product or sum would silently slip past the
        // budget check in release builds (and panic in debug builds).
        let bytes = size.checked_mul(std::mem::size_of::<T>()).ok_or_else(|| {
            AutogradError::memory_error(format!(
                "GPU memory request of {} elements overflows usize",
                size
            ))
        })?;
        let new_total = match total.checked_add(bytes) {
            Some(sum) if sum <= self.max_memory => sum,
            _ => {
                return Err(AutogradError::memory_error(format!(
                    "GPU memory limit exceeded: {} + {} > {}",
                    *total, bytes, self.max_memory
                )));
            }
        };
        *total = new_total;
        stats.peak_memory = stats.peak_memory.max(new_total);
        // Release every lock before zero-filling a potentially large buffer.
        drop(total);
        drop(available);
        drop(stats);
        Ok(vec![T::zero(); size])
    }

    /// Zeroes `buffer` and parks it for reuse by a later same-length `allocate`.
    ///
    /// The byte total is not reduced, because idle pooled buffers stay charged to
    /// the pool until `clear`. NOTE(review): buffers that did not originate from
    /// this pool are accepted too, and are later handed out without ever having
    /// been charged.
    ///
    /// # Errors
    ///
    /// Returns an internal error if one of the pool's mutexes is poisoned.
    pub fn deallocate(&self, mut buffer: Vec<T>) -> Result<()> {
        // Zero outside the locks so other threads are not blocked on a large fill.
        buffer.fill(T::zero());
        let mut stats = self
            .stats
            .lock()
            .map_err(|_| AutogradError::internal_error("Failed to lock statistics"))?;
        stats.deallocations += 1;
        let mut available = self
            .available
            .lock()
            .map_err(|_| AutogradError::internal_error("Failed to lock available buffers"))?;
        available.entry(buffer.len()).or_default().push(buffer);
        Ok(())
    }

    /// Returns a snapshot of the pool counters.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the statistics mutex is poisoned.
    pub fn statistics(&self) -> Result<PoolStatistics> {
        self.stats
            .lock()
            .map(|s| s.clone())
            .map_err(|_| AutogradError::internal_error("Failed to lock statistics"))
    }

    /// Returns the bytes currently charged to the pool (handed-out plus idle buffers).
    ///
    /// # Errors
    ///
    /// Returns an internal error if the byte-total mutex is poisoned.
    pub fn memory_usage(&self) -> Result<usize> {
        self.total_allocated
            .lock()
            .map(|t| *t)
            .map_err(|_| AutogradError::internal_error("Failed to lock total allocated"))
    }

    /// Drops every idle pooled buffer and releases its bytes from the budget.
    ///
    /// Buffers that are currently handed out remain charged. Without that, they
    /// could later be returned and re-issued while the pool believed it had
    /// room for new allocations up to `max_memory`.
    ///
    /// # Errors
    ///
    /// Returns an internal error if one of the pool's mutexes is poisoned.
    pub fn clear(&self) -> Result<()> {
        let mut available = self
            .available
            .lock()
            .map_err(|_| AutogradError::internal_error("Failed to lock available buffers"))?;
        let elem_bytes = std::mem::size_of::<T>();
        let pooled_bytes = available
            .iter()
            .map(|(&len, buffers)| len.saturating_mul(elem_bytes).saturating_mul(buffers.len()))
            .fold(0usize, usize::saturating_add);
        available.clear();
        let mut total = self
            .total_allocated
            .lock()
            .map_err(|_| AutogradError::internal_error("Failed to lock total allocated"))?;
        // Saturating: `deallocate` accepts foreign buffers that were never charged.
        *total = total.saturating_sub(pooled_bytes);
        Ok(())
    }
}
/// Selects how a gradient tracker obtains a fresh zero-filled buffer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AllocationStrategy {
    /// Build a brand-new zero-filled vector for every request.
    Eager,
    /// Currently treated exactly like `Eager`.
    Lazy,
    /// Recycle same-length buffers through a size-bucketed pool.
    Pooled,
}
/// Tracks per-id gradient buffers, optionally recycling them through a [`GpuMemoryPool`].
pub struct GpuGradientMemory<T: Float> {
    /// Pool consulted when `strategy` is [`AllocationStrategy::Pooled`].
    pool: Arc<GpuMemoryPool<T>>,
    /// How new gradient buffers are obtained.
    strategy: AllocationStrategy,
    /// The tracker's own copy of each live gradient buffer, keyed by caller-supplied id.
    active: Arc<Mutex<HashMap<usize, Vec<T>>>>,
}

impl<T: Float> GpuGradientMemory<T> {
    /// Creates a tracker whose internal pool is capped at `max_memory` bytes.
    pub fn new(max_memory: usize, strategy: AllocationStrategy) -> Self {
        Self {
            pool: Arc::new(GpuMemoryPool::new(max_memory)),
            strategy,
            active: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Allocates a zero-filled gradient buffer of `size` elements for `id`.
    ///
    /// The returned vector is a clone: the tracker keeps its own copy under `id`
    /// until [`free_gradient`](Self::free_gradient). If `id` already held a
    /// buffer, that buffer is released rather than silently overwritten. Under
    /// `Pooled` it goes back to the pool.
    ///
    /// # Errors
    ///
    /// Propagates pool allocation/deallocation errors. Returns an internal error if
    /// the active-allocation mutex is poisoned.
    pub fn allocate_gradient(&self, id: usize, size: usize) -> Result<Vec<T>> {
        let buffer = match self.strategy {
            // NOTE(review): `Eager` and `Lazy` currently behave identically.
            AllocationStrategy::Eager | AllocationStrategy::Lazy => vec![T::zero(); size],
            AllocationStrategy::Pooled => self.pool.allocate(size)?,
        };
        // The guard is a temporary, so the lock is released before `release` runs.
        let replaced = self
            .active
            .lock()
            .map_err(|_| AutogradError::internal_error("Failed to lock active allocations"))?
            .insert(id, buffer.clone());
        // Dropping an overwritten pooled buffer would leave its bytes charged to the
        // pool's budget permanently, so hand it back instead.
        if let Some(old) = replaced {
            self.release(old)?;
        }
        Ok(buffer)
    }

    /// Stops tracking `id`, returning its buffer to the pool under `Pooled`.
    /// Unknown ids are ignored.
    ///
    /// # Errors
    ///
    /// Propagates pool deallocation errors. Returns an internal error if the
    /// active-allocation mutex is poisoned.
    pub fn free_gradient(&self, id: usize) -> Result<()> {
        let removed = self
            .active
            .lock()
            .map_err(|_| AutogradError::internal_error("Failed to lock active allocations"))?
            .remove(&id);
        match removed {
            Some(buffer) => self.release(buffer),
            None => Ok(()),
        }
    }

    /// Returns a snapshot of the internal pool's counters.
    pub fn statistics(&self) -> Result<PoolStatistics> {
        self.pool.statistics()
    }

    /// Returns the total size in bytes of all buffers currently tracked in `active`.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the active-allocation mutex is poisoned.
    pub fn active_memory(&self) -> Result<usize> {
        let active = self
            .active
            .lock()
            .map_err(|_| AutogradError::internal_error("Failed to lock active allocations"))?;
        let bytes: usize = active
            .values()
            .map(|v| v.len() * std::mem::size_of::<T>())
            .sum();
        Ok(bytes)
    }

    /// Disposes of a buffer that is no longer tracked. Under `Pooled` it goes back
    /// to the pool; other strategies simply drop it.
    fn release(&self, buffer: Vec<T>) -> Result<()> {
        if matches!(self.strategy, AllocationStrategy::Pooled) {
            self.pool.deallocate(buffer)?;
        }
        Ok(())
    }
}
/// Policy deciding when [`GpuCheckpointing::should_checkpoint`] returns `true`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckpointStrategy {
    /// Never checkpoint.
    None,
    /// Checkpoint operations whose id is a multiple of the given period.
    /// A period of 0 matches only operation 0.
    Periodic(usize),
    /// Checkpoint once `memory_used` strictly exceeds 4/5 of the memory budget.
    Dynamic,
    /// Checkpoint when the operation id is a multiple of the floor of its own square
    /// root (computed in `f64`, clamped to at least 1).
    Optimal,
}

/// Stores checkpoint data keyed by id and decides, per operation, whether to take one.
pub struct GpuCheckpointing<T: Float> {
    /// Active checkpointing policy.
    strategy: CheckpointStrategy,
    /// Saved checkpoint data keyed by caller-supplied id.
    checkpoints: Arc<Mutex<HashMap<usize, Vec<T>>>>,
    /// Budget compared against `memory_used` under [`CheckpointStrategy::Dynamic`].
    /// Presumably in the same units as `memory_used`; confirm against callers.
    memory_budget: usize,
}

impl<T: Float> GpuCheckpointing<T> {
    /// Creates an empty checkpoint store using `strategy`.
    pub fn new(strategy: CheckpointStrategy, memory_budget: usize) -> Self {
        Self {
            strategy,
            checkpoints: Arc::new(Mutex::new(HashMap::new())),
            memory_budget,
        }
    }

    /// Reports whether operation `operation_id` should be checkpointed under the
    /// configured strategy. `memory_used` is consulted only by `Dynamic`.
    pub fn should_checkpoint(&self, operation_id: usize, memory_used: usize) -> bool {
        match self.strategy {
            CheckpointStrategy::None => false,
            CheckpointStrategy::Periodic(n) => operation_id.is_multiple_of(n),
            CheckpointStrategy::Dynamic => {
                // Exactly `memory_used > memory_budget * 4 / 5` (integer division), but
                // evaluated in u128 so large budgets such as `usize::MAX` cannot overflow.
                (memory_used as u128) * 5 > (self.memory_budget as u128) * 4
            }
            CheckpointStrategy::Optimal => {
                // NOTE(review): sqrt(N) checkpointing normally derives the interval from
                // the total operation count, not the current id; confirm this is intended.
                let interval = ((operation_id as f64).sqrt() as usize).max(1);
                operation_id.is_multiple_of(interval)
            }
        }
    }

    /// Stores `data` under `id`, replacing any previous checkpoint with that id.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the checkpoint mutex is poisoned.
    pub fn save_checkpoint(&self, id: usize, data: Vec<T>) -> Result<()> {
        self.lock_checkpoints()?.insert(id, data);
        Ok(())
    }

    /// Returns a copy of the checkpoint stored under `id`, if any.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the checkpoint mutex is poisoned.
    pub fn load_checkpoint(&self, id: usize) -> Result<Option<Vec<T>>> {
        Ok(self.lock_checkpoints()?.get(&id).cloned())
    }

    /// Removes every stored checkpoint.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the checkpoint mutex is poisoned.
    pub fn clear(&self) -> Result<()> {
        self.lock_checkpoints()?.clear();
        Ok(())
    }

    /// Returns how many checkpoints are currently stored.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the checkpoint mutex is poisoned.
    pub fn checkpoint_count(&self) -> Result<usize> {
        Ok(self.lock_checkpoints()?.len())
    }

    /// Locks the checkpoint map, mapping a poisoned mutex to an internal error.
    fn lock_checkpoints(&self) -> Result<std::sync::MutexGuard<'_, HashMap<usize, Vec<T>>>> {
        self.checkpoints
            .lock()
            .map_err(|_| AutogradError::internal_error("Failed to lock checkpoints"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A released buffer is handed back out on the next same-sized request.
    #[test]
    fn test_memory_pool() {
        let pool = GpuMemoryPool::<f32>::new(1 << 20);

        let first = pool.allocate(256).expect("Should allocate");
        assert_eq!(first.len(), 256);
        pool.deallocate(first).expect("Should deallocate");

        let second = pool.allocate(256).expect("Should allocate from pool");
        assert_eq!(second.len(), 256);

        let PoolStatistics {
            allocations,
            pool_hits,
            pool_misses,
            ..
        } = pool.statistics().expect("Should get stats");
        assert_eq!(allocations, 2);
        assert_eq!(pool_hits, 1);
        assert_eq!(pool_misses, 1);
    }

    /// `Periodic(5)` fires only on multiples of five.
    #[test]
    fn test_checkpointing_strategy() {
        let checkpointing = GpuCheckpointing::<f32>::new(CheckpointStrategy::Periodic(5), 1024);
        for (op, expected) in [(1, false), (4, false), (5, true), (10, true)] {
            assert_eq!(
                checkpointing.should_checkpoint(op, 0),
                expected,
                "operation {op}"
            );
        }
    }

    /// Freeing a pooled gradient leaves no tracked bytes behind.
    #[test]
    fn test_allocation_strategies() {
        let memory = GpuGradientMemory::<f64>::new(1 << 20, AllocationStrategy::Pooled);

        let grad = memory.allocate_gradient(0, 100).expect("Should allocate");
        assert_eq!(grad.len(), 100);

        memory.free_gradient(0).expect("Should free");
        assert_eq!(
            memory.active_memory().expect("Should get active memory"),
            0
        );
    }
}