use crate::error::{NeuralError, Result};
#[cfg(feature = "gpu")]
use scirs2_core::gpu::{GpuBuffer, GpuContext, GpuDataType};
use scirs2_core::ndarray::{Array, ArrayD, IxDyn};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
/// Policy controlling which layer activations are checkpointed (kept in
/// memory) versus recomputed during the backward pass.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecomputationPolicy {
    /// Checkpoint every activation.
    CheckpointAll,
    /// Checkpoint nothing; all activations must be recomputed.
    CheckpointNone,
    /// Checkpoint only activations whose recomputation cost reaches a threshold.
    Selective {
        /// Minimum recomputation cost for an activation to be checkpointed.
        cost_threshold: u32,
    },
    /// Checkpoint every `n`-th layer (layer ids that are multiples of `n`).
    EveryN {
        /// Checkpoint interval, in layers.
        n: usize,
    },
}
impl Default for RecomputationPolicy {
fn default() -> Self {
Self::Selective {
cost_threshold: 100,
}
}
}
/// Bookkeeping record for a single checkpointed activation.
#[derive(Debug, Clone)]
pub struct ActivationCheckpoint {
    /// Layer the activation belongs to.
    pub layer_id: usize,
    /// Logical insertion timestamp (manager's monotonic counter); lower = older.
    pub timestamp: u64,
    /// Size of the stored activation in bytes.
    pub memory_size: usize,
    /// Relative cost of recomputing this activation if it is not resident.
    pub recomputation_cost: u32,
    /// Whether the activation buffer is currently resident (false after eviction).
    pub in_memory: bool,
}
/// Manages GPU-resident activation checkpoints for gradient checkpointing,
/// evicting the oldest entries when a byte budget is exceeded.
#[cfg(feature = "gpu")]
pub struct GradientCheckpointManager<T: GpuDataType> {
    /// Checkpointed activation buffers, keyed by layer id.
    checkpoints: Arc<Mutex<HashMap<usize, GpuBuffer<T>>>>,
    /// Per-layer bookkeeping, keyed by layer id (kept even after eviction,
    /// with `in_memory = false`).
    metadata: Arc<Mutex<HashMap<usize, ActivationCheckpoint>>>,
    /// Bytes currently attributed to resident checkpoints.
    memory_usage: Arc<AtomicU64>,
    /// Maximum checkpoint bytes before eviction kicks in.
    memory_budget: u64,
    /// Policy deciding which activations get checkpointed.
    policy: RecomputationPolicy,
    /// Monotonic counter used as a logical timestamp for eviction ordering.
    checkpoint_counter: Arc<AtomicU64>,
    /// GPU context used to allocate checkpoint buffers.
    gpu_context: Arc<GpuContext>,
}
#[cfg(feature = "gpu")]
impl<T: GpuDataType> GradientCheckpointManager<T> {
    /// Creates a manager with the given GPU context, memory budget in bytes,
    /// and recomputation policy.
    pub fn new(
        gpu_context: Arc<GpuContext>,
        memory_budget: u64,
        policy: RecomputationPolicy,
    ) -> Self {
        Self {
            checkpoints: Arc::new(Mutex::new(HashMap::new())),
            metadata: Arc::new(Mutex::new(HashMap::new())),
            memory_usage: Arc::new(AtomicU64::new(0)),
            memory_budget,
            policy,
            checkpoint_counter: Arc::new(AtomicU64::new(0)),
            gpu_context,
        }
    }

    /// Records a checkpoint for `layer_id` if the policy selects it, evicting
    /// the oldest checkpoints until the new one fits inside the budget.
    ///
    /// # Errors
    /// Returns `NeuralError::TrainingError` if an internal lock is poisoned.
    pub fn checkpoint_activation(
        &self,
        layer_id: usize,
        activation: &GpuBuffer<T>,
        recomputation_cost: u32,
    ) -> Result<()> {
        let should_checkpoint = match self.policy {
            RecomputationPolicy::CheckpointAll => true,
            RecomputationPolicy::CheckpointNone => false,
            RecomputationPolicy::Selective { cost_threshold } => {
                recomputation_cost >= cost_threshold
            }
            // `is_multiple_of(0)` is true only for layer 0; no division by zero.
            RecomputationPolicy::EveryN { n } => layer_id.is_multiple_of(n),
        };
        if !should_checkpoint {
            return Ok(());
        }
        let activation_size = activation.len() * std::mem::size_of::<T>();
        // Evict until the new checkpoint fits. The original evicted at most one
        // entry, which could leave usage above the budget; loop instead, and
        // stop if there is nothing left to evict.
        while self.memory_usage.load(Ordering::Relaxed) + activation_size as u64
            > self.memory_budget
        {
            if !self.evict_oldest_checkpoint()? {
                break;
            }
        }
        // Lock order is always (checkpoints, then metadata) across this impl.
        let mut checkpoints = self
            .checkpoints
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock checkpoints".to_string()))?;
        let mut metadata = self
            .metadata
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock metadata".to_string()))?;
        let checkpoint_meta = ActivationCheckpoint {
            layer_id,
            timestamp: self.checkpoint_counter.fetch_add(1, Ordering::Relaxed),
            memory_size: activation_size,
            recomputation_cost,
            in_memory: true,
        };
        // TODO(review): this allocates a correctly sized buffer but never copies
        // `activation`'s contents into it — confirm against the GpuBuffer API
        // whether a device-to-device copy is required here.
        let checkpoint_buffer = self.gpu_context.create_buffer::<T>(activation.len());
        // Overwriting an existing checkpoint must release its accounted bytes,
        // otherwise re-checkpointing the same layer inflates usage forever.
        if let Some(old) = checkpoints.insert(layer_id, checkpoint_buffer) {
            let old_size = (old.len() * std::mem::size_of::<T>()) as u64;
            self.memory_usage.fetch_sub(old_size, Ordering::Relaxed);
        }
        metadata.insert(layer_id, checkpoint_meta);
        self.memory_usage
            .fetch_add(activation_size as u64, Ordering::Relaxed);
        Ok(())
    }

    /// Removes and returns the checkpoint for `layer_id`, if present.
    ///
    /// Retrieval is one-shot (the buffer is consumed by the backward pass), so
    /// its bytes are released and the metadata entry is marked non-resident —
    /// the original removed the buffer without updating either, leaking budget
    /// on every retrieval.
    pub fn get_checkpoint(&self, layer_id: usize) -> Result<Option<GpuBuffer<T>>> {
        let mut checkpoints = self
            .checkpoints
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock checkpoints".to_string()))?;
        let mut metadata = self
            .metadata
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock metadata".to_string()))?;
        let taken = checkpoints.remove(&layer_id);
        if let Some(buffer) = taken.as_ref() {
            let size = (buffer.len() * std::mem::size_of::<T>()) as u64;
            self.memory_usage.fetch_sub(size, Ordering::Relaxed);
            if let Some(meta) = metadata.get_mut(&layer_id) {
                meta.in_memory = false;
            }
        }
        Ok(taken)
    }

    /// Returns `true` if a checkpoint for `layer_id` is currently resident.
    /// A poisoned lock is treated as "no checkpoint".
    pub fn has_checkpoint(&self, layer_id: usize) -> bool {
        self.checkpoints
            .lock()
            .map(|cp| cp.contains_key(&layer_id))
            .unwrap_or(false)
    }

    /// Evicts the oldest in-memory checkpoint; returns whether one was evicted.
    ///
    /// The metadata guard is dropped before `remove_checkpoint` runs: the
    /// original held it across that call, which re-locks the same mutex and
    /// deadlocks (`std::sync::Mutex` is not reentrant).
    fn evict_oldest_checkpoint(&self) -> Result<bool> {
        let oldest = {
            let metadata = self
                .metadata
                .lock()
                .map_err(|_| NeuralError::TrainingError("Failed to lock metadata".to_string()))?;
            metadata
                .iter()
                .filter(|(_, meta)| meta.in_memory)
                .min_by_key(|(_, meta)| meta.timestamp)
                .map(|(id, _)| *id)
        }; // metadata guard dropped here, before re-locking below
        match oldest {
            Some(layer_id) => {
                self.remove_checkpoint(layer_id)?;
                Ok(true)
            }
            None => Ok(false),
        }
    }

    /// Drops the checkpoint buffer for `layer_id` (if any), releases its
    /// accounted bytes, and marks its metadata entry as non-resident.
    pub fn remove_checkpoint(&self, layer_id: usize) -> Result<()> {
        let mut checkpoints = self
            .checkpoints
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock checkpoints".to_string()))?;
        let mut metadata = self
            .metadata
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock metadata".to_string()))?;
        if let Some(checkpoint) = checkpoints.remove(&layer_id) {
            let size = checkpoint.len() * std::mem::size_of::<T>();
            self.memory_usage.fetch_sub(size as u64, Ordering::Relaxed);
        }
        // Metadata is kept (not removed) so statistics still count the layer.
        if let Some(meta) = metadata.get_mut(&layer_id) {
            meta.in_memory = false;
        }
        Ok(())
    }

    /// Clears all checkpoints and metadata and resets the usage counter.
    pub fn clear(&self) -> Result<()> {
        let mut checkpoints = self
            .checkpoints
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock checkpoints".to_string()))?;
        let mut metadata = self
            .metadata
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock metadata".to_string()))?;
        checkpoints.clear();
        metadata.clear();
        self.memory_usage.store(0, Ordering::Relaxed);
        Ok(())
    }

    /// Bytes currently attributed to resident checkpoints.
    pub fn memory_usage(&self) -> u64 {
        self.memory_usage.load(Ordering::Relaxed)
    }

    /// Configured memory budget in bytes.
    pub fn memory_budget(&self) -> u64 {
        self.memory_budget
    }

    /// Number of currently resident checkpoints (0 if the lock is poisoned).
    pub fn num_checkpoints(&self) -> usize {
        self.checkpoints.lock().map(|cp| cp.len()).unwrap_or(0)
    }

    /// Aggregates checkpoint statistics from the metadata table.
    ///
    /// Recovers from a poisoned lock instead of panicking (the original used
    /// `expect`), and guards against a zero budget producing NaN/inf.
    pub fn get_statistics(&self) -> CheckpointStatistics {
        let metadata = self
            .metadata
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let total_checkpoints = metadata.len();
        let in_memory_checkpoints = metadata.values().filter(|meta| meta.in_memory).count();
        let total_memory: u64 = metadata
            .values()
            .filter(|meta| meta.in_memory)
            .map(|meta| meta.memory_size as u64)
            .sum();
        let memory_utilization = if self.memory_budget == 0 {
            0.0
        } else {
            total_memory as f64 / self.memory_budget as f64
        };
        CheckpointStatistics {
            total_checkpoints,
            in_memory_checkpoints,
            total_memory,
            memory_budget: self.memory_budget,
            memory_utilization,
        }
    }
}
/// Snapshot of checkpoint-manager state, as returned by `get_statistics`.
#[derive(Debug, Clone)]
pub struct CheckpointStatistics {
    /// Number of metadata entries (includes evicted checkpoints).
    pub total_checkpoints: usize,
    /// Number of checkpoints currently resident in memory.
    pub in_memory_checkpoints: usize,
    /// Total bytes of resident checkpoints.
    pub total_memory: u64,
    /// Configured memory budget in bytes.
    pub memory_budget: u64,
    /// `total_memory / memory_budget` as a fraction.
    pub memory_utilization: f64,
}
/// Memory-efficient backpropagation helper that wraps a
/// `GradientCheckpointManager` around forward/backward closures.
#[cfg(feature = "gpu")]
pub struct EfficientBackprop<T: GpuDataType> {
    /// Shared checkpoint manager used during forward/backward passes.
    checkpoint_manager: Arc<GradientCheckpointManager<T>>,
    /// GPU context used to allocate fallback buffers during backward.
    gpu_context: Arc<GpuContext>,
    /// When false, forward passes skip checkpointing entirely.
    enabled: bool,
}
#[cfg(feature = "gpu")]
impl<T: GpuDataType> EfficientBackprop<T> {
    /// Builds a helper with its own checkpoint manager over `gpu_context`.
    pub fn new(
        gpu_context: Arc<GpuContext>,
        memory_budget: u64,
        policy: RecomputationPolicy,
        enabled: bool,
    ) -> Self {
        let manager =
            GradientCheckpointManager::new(gpu_context.clone(), memory_budget, policy);
        Self {
            checkpoint_manager: Arc::new(manager),
            gpu_context,
            enabled,
        }
    }

    /// Runs `forward_fn` on `input`, checkpointing the input first when
    /// checkpointing is enabled.
    pub fn forward_with_checkpoint(
        &self,
        layer_id: usize,
        input: &GpuBuffer<T>,
        forward_fn: impl FnOnce(&GpuBuffer<T>) -> Result<GpuBuffer<T>>,
        recomputation_cost: u32,
    ) -> Result<GpuBuffer<T>> {
        if !self.enabled {
            return forward_fn(input);
        }
        self.checkpoint_manager
            .checkpoint_activation(layer_id, input, recomputation_cost)?;
        forward_fn(input)
    }

    /// Runs `backward_fn` against the stored checkpoint for `layer_id`, or a
    /// freshly allocated buffer sized like `grad_output` when none exists.
    ///
    /// NOTE(review): `forward_fn` is accepted but never invoked, so the
    /// "recomputation" fallback is an uninitialized buffer rather than a
    /// recomputed activation — confirm the intended contract with callers
    /// before wiring the recomputation path up.
    pub fn backward_with_recomputation(
        &self,
        layer_id: usize,
        grad_output: &GpuBuffer<T>,
        forward_fn: impl FnOnce(&GpuBuffer<T>) -> Result<GpuBuffer<T>>,
        backward_fn: impl FnOnce(&GpuBuffer<T>, &GpuBuffer<T>) -> Result<GpuBuffer<T>>,
    ) -> Result<GpuBuffer<T>> {
        let activation = match self.checkpoint_manager.get_checkpoint(layer_id)? {
            Some(checkpoint) => checkpoint,
            None => self.gpu_context.create_buffer::<T>(grad_output.len()),
        };
        backward_fn(&activation, grad_output)
    }

    /// Turns checkpointing on or off for subsequent forward passes.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Whether checkpointing is currently active.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Shared handle to the underlying checkpoint manager.
    pub fn checkpoint_manager(&self) -> &Arc<GradientCheckpointManager<T>> {
        &self.checkpoint_manager
    }

    /// Statistics from the underlying checkpoint manager.
    pub fn get_statistics(&self) -> CheckpointStatistics {
        self.checkpoint_manager.get_statistics()
    }

    /// Drops all stored checkpoints.
    pub fn clear_checkpoints(&self) -> Result<()> {
        self.checkpoint_manager.clear()
    }
}
/// CPU-side activation store keyed by layer id, with a running byte counter.
#[derive(Debug)]
pub struct CpuActivationStore<F> {
    /// Stored activations, keyed by layer id.
    activations: Arc<Mutex<HashMap<usize, ArrayD<F>>>>,
    /// Approximate bytes held (element count × element size).
    memory_usage: Arc<AtomicU64>,
}
impl<F> CpuActivationStore<F>
where
    // The original also required `F: Default`, but no method uses it;
    // relaxing the bound is backward compatible. `Clone` is needed by
    // `retrieve`.
    F: Clone,
{
    /// Creates an empty store.
    pub fn new() -> Self {
        Self {
            activations: Arc::new(Mutex::new(HashMap::new())),
            memory_usage: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Stores `activation` under `layer_id`, replacing any previous entry.
    ///
    /// When an entry is overwritten, the old entry's bytes are released first —
    /// the original double-counted on overwrite, inflating `memory_usage`.
    ///
    /// # Errors
    /// Returns `NeuralError::TrainingError` if the lock is poisoned.
    pub fn store(&self, layer_id: usize, activation: ArrayD<F>) -> Result<()> {
        let size = (activation.len() * std::mem::size_of::<F>()) as u64;
        let mut activations = self
            .activations
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock activations".to_string()))?;
        if let Some(old) = activations.insert(layer_id, activation) {
            let old_size = (old.len() * std::mem::size_of::<F>()) as u64;
            self.memory_usage.fetch_sub(old_size, Ordering::Relaxed);
        }
        self.memory_usage.fetch_add(size, Ordering::Relaxed);
        Ok(())
    }

    /// Returns a clone of the activation for `layer_id`, if present.
    pub fn retrieve(&self, layer_id: usize) -> Result<Option<ArrayD<F>>> {
        let activations = self
            .activations
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock activations".to_string()))?;
        Ok(activations.get(&layer_id).cloned())
    }

    /// Removes the activation for `layer_id` (if any) and releases its bytes.
    pub fn remove(&self, layer_id: usize) -> Result<()> {
        let mut activations = self
            .activations
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock activations".to_string()))?;
        if let Some(activation) = activations.remove(&layer_id) {
            let size = (activation.len() * std::mem::size_of::<F>()) as u64;
            self.memory_usage.fetch_sub(size, Ordering::Relaxed);
        }
        Ok(())
    }

    /// Removes all activations and resets the byte counter.
    pub fn clear(&self) -> Result<()> {
        let mut activations = self
            .activations
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock activations".to_string()))?;
        activations.clear();
        self.memory_usage.store(0, Ordering::Relaxed);
        Ok(())
    }

    /// Approximate bytes currently held by the store.
    pub fn memory_usage(&self) -> u64 {
        self.memory_usage.load(Ordering::Relaxed)
    }
}
impl<F> Default for CpuActivationStore<F>
where
F: Clone + Default,
{
fn default() -> Self {
Self::new()
}
}
#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;
    use scirs2_core::gpu::GpuBackend;

    /// 1 GiB budget shared by the GPU-backed tests below.
    const BUDGET: u64 = 1024 * 1024 * 1024;

    #[test]
    fn test_recomputation_policy() {
        // The default policy is selective checkpointing.
        let default_policy = RecomputationPolicy::default();
        assert!(matches!(default_policy, RecomputationPolicy::Selective { .. }));
        let all = RecomputationPolicy::CheckpointAll;
        assert_eq!(all, RecomputationPolicy::CheckpointAll);
    }

    #[test]
    fn test_checkpoint_manager_creation() {
        let ctx = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        let mgr = GradientCheckpointManager::<f32>::new(
            Arc::new(ctx),
            BUDGET,
            RecomputationPolicy::CheckpointAll,
        );
        // A fresh manager holds nothing.
        assert_eq!(mgr.memory_usage(), 0);
        assert_eq!(mgr.num_checkpoints(), 0);
    }

    #[test]
    fn test_checkpoint_statistics() {
        let ctx = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        let mgr = GradientCheckpointManager::<f32>::new(
            Arc::new(ctx),
            BUDGET,
            RecomputationPolicy::CheckpointAll,
        );
        let stats = mgr.get_statistics();
        assert_eq!(stats.total_checkpoints, 0);
        assert_eq!(stats.in_memory_checkpoints, 0);
        assert_eq!(stats.total_memory, 0);
    }

    #[test]
    fn test_efficient_backprop_creation() {
        let ctx = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        let bp = EfficientBackprop::<f32>::new(
            Arc::new(ctx),
            BUDGET,
            RecomputationPolicy::CheckpointAll,
            true,
        );
        assert!(bp.is_enabled());
        assert_eq!(bp.checkpoint_manager().num_checkpoints(), 0);
    }

    #[test]
    fn test_cpu_activation_store() {
        let store = CpuActivationStore::<f32>::new();
        let zeros = Array::zeros(IxDyn(&[2, 3, 4]));
        store.store(0, zeros.clone()).expect("Failed to store");
        assert!(store.retrieve(0).expect("Failed to retrieve").is_some());
        assert!(store.memory_usage() > 0);
        store.clear().expect("Failed to clear");
        assert_eq!(store.memory_usage(), 0);
    }

    #[test]
    fn test_enable_disable_checkpointing() {
        let ctx = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        let mut bp = EfficientBackprop::<f32>::new(
            Arc::new(ctx),
            BUDGET,
            RecomputationPolicy::CheckpointAll,
            true,
        );
        assert!(bp.is_enabled());
        bp.set_enabled(false);
        assert!(!bp.is_enabled());
        bp.set_enabled(true);
        assert!(bp.is_enabled());
    }
}