use crate::{Result, Tensor, TensorError};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Strategy that decides which layers get their activations checkpointed.
#[derive(Debug, Clone, PartialEq)]
pub enum CheckpointPolicy {
    /// Checkpoint every layer whose index is a multiple of the given stride.
    EveryNLayers(usize),
    /// Checkpoint exactly the listed layer indices.
    SpecificLayers(Vec<usize>),
    /// Derive a stride from how many activations fit into a memory budget.
    Automatic {
        /// Total memory available for checkpoints, in bytes.
        memory_budget: usize,
        /// Expected size of one layer's saved activations, in bytes.
        avg_activation_size: usize,
    },
    /// Selection is left to the caller; the manager's built-in decision never checkpoints.
    Custom,
    /// Disable checkpointing entirely.
    None,
}

impl Default for CheckpointPolicy {
    /// Checkpoints every layer (a stride of one).
    fn default() -> Self {
        Self::EveryNLayers(1)
    }
}
/// Settings that control how a `CheckpointManager` selects and stores checkpoints.
#[derive(Debug, Clone)]
pub struct CheckpointingConfig {
    /// Which layers to checkpoint.
    pub policy: CheckpointPolicy,
    /// NOTE(review): not read by any code in this file — confirm where it is consumed.
    pub recompute_on_backward: bool,
    /// Attach an RNG snapshot to every saved checkpoint.
    pub save_rng_state: bool,
    /// Record counters in `CheckpointStatistics`.
    pub enable_statistics: bool,
    /// Optional cap on stored checkpoints; when reached, the oldest checkpoint is
    /// evicted before inserting. `None` means unbounded.
    pub max_checkpoints: Option<usize>,
}

impl Default for CheckpointingConfig {
    /// Default policy, recomputation and RNG capture on, statistics off, no cap.
    fn default() -> Self {
        let policy = CheckpointPolicy::default();
        Self {
            policy,
            recompute_on_backward: true,
            save_rng_state: true,
            enable_statistics: false,
            max_checkpoints: None,
        }
    }
}
/// Counters describing checkpoint usage; only updated when statistics are enabled.
#[derive(Debug, Clone, Default)]
pub struct CheckpointStatistics {
    /// Number of checkpoints saved.
    pub forward_passes: usize,
    /// Number of backward passes reported.
    pub backward_passes: usize,
    /// Number of activation recomputations reported.
    pub recompute_count: usize,
    /// Cumulative bytes of activations that were checkpointed.
    pub memory_saved_bytes: usize,
    /// Cumulative recomputation time, in microseconds.
    pub additional_compute_time_us: u64,
    /// Number of checkpoints currently held.
    pub active_checkpoints: usize,
}

impl CheckpointStatistics {
    /// Mean recomputations per backward pass; `0.0` before any backward pass.
    pub fn avg_recomputations(&self) -> f64 {
        match self.backward_passes {
            0 => 0.0,
            passes => self.recompute_count as f64 / passes as f64,
        }
    }

    /// `memory_saved_bytes` expressed in mebibytes (2^20 bytes).
    pub fn memory_saved_mb(&self) -> f64 {
        const BYTES_PER_MB: f64 = 1024.0 * 1024.0;
        self.memory_saved_bytes as f64 / BYTES_PER_MB
    }

    /// Recomputations as a percentage of `forward_passes`; `0.0` when none were recorded.
    pub fn compute_overhead_percent(&self) -> f64 {
        if self.forward_passes != 0 {
            let ratio = self.recompute_count as f64 / self.forward_passes as f64;
            ratio * 100.0
        } else {
            0.0
        }
    }
}
/// Snapshot of one layer's activations taken during the forward pass.
#[derive(Debug, Clone)]
pub struct Checkpoint<T> {
    /// Index of the layer these activations belong to.
    pub layer_index: usize,
    /// The saved activation tensors.
    pub activations: Vec<Tensor<T>>,
    /// Optional opaque RNG snapshot stored alongside the activations.
    pub rng_state: Option<Vec<u8>>,
    /// Moment the checkpoint was created.
    pub timestamp: std::time::Instant,
    /// Estimated payload size in bytes.
    pub memory_bytes: usize,
}

impl<T> Checkpoint<T>
where
    T: Clone + Default,
{
    /// Wraps `activations` for `layer_index`, stamping the current time.
    ///
    /// `memory_bytes` is the sum over tensors of `shape().size() * size_of::<T>()`
    /// (assumes `shape().size()` is the element count — TODO confirm).
    pub fn new(layer_index: usize, activations: Vec<Tensor<T>>) -> Self {
        let element_size = std::mem::size_of::<T>();
        let memory_bytes = activations
            .iter()
            .fold(0usize, |total, tensor| total + tensor.shape().size() * element_size);
        Self {
            layer_index,
            activations,
            rng_state: None,
            timestamp: std::time::Instant::now(),
            memory_bytes,
        }
    }

    /// Returns this checkpoint with `rng_state` attached.
    pub fn with_rng_state(self, rng_state: Vec<u8>) -> Self {
        Self {
            rng_state: Some(rng_state),
            ..self
        }
    }

    /// Time elapsed since the checkpoint was created.
    pub fn age(&self) -> std::time::Duration {
        let created = self.timestamp;
        created.elapsed()
    }
}
/// Stores per-layer checkpoints and optional usage statistics.
///
/// Both stores sit behind `Arc<Mutex<..>>`, so clones of the manager share them.
pub struct CheckpointManager<T> {
    /// Behavior settings fixed at construction.
    config: CheckpointingConfig,
    /// Saved checkpoints keyed by layer index.
    checkpoints: Arc<Mutex<HashMap<usize, Checkpoint<T>>>>,
    /// Usage counters, updated only when `config.enable_statistics` is set.
    statistics: Arc<Mutex<CheckpointStatistics>>,
}
impl<T> CheckpointManager<T>
where
T: Clone + Default + Send + Sync,
{
pub fn new(config: CheckpointingConfig) -> Self {
Self {
config,
checkpoints: Arc::new(Mutex::new(HashMap::new())),
statistics: Arc::new(Mutex::new(CheckpointStatistics::default())),
}
}
pub fn with_memory_budget(memory_budget_mb: usize, avg_activation_mb: usize) -> Self {
Self::new(CheckpointingConfig {
policy: CheckpointPolicy::Automatic {
memory_budget: memory_budget_mb * 1024 * 1024,
avg_activation_size: avg_activation_mb * 1024 * 1024,
},
..Default::default()
})
}
pub fn should_checkpoint(&self, layer_index: usize, total_layers: usize) -> bool {
match &self.config.policy {
CheckpointPolicy::EveryNLayers(n) => layer_index % n == 0,
CheckpointPolicy::SpecificLayers(indices) => indices.contains(&layer_index),
CheckpointPolicy::Automatic {
memory_budget,
avg_activation_size,
} => {
let max_checkpoints = memory_budget / avg_activation_size;
if max_checkpoints == 0 {
return false;
}
let checkpoint_every = total_layers / max_checkpoints.max(1);
layer_index % checkpoint_every.max(1) == 0
}
CheckpointPolicy::Custom => {
false
}
CheckpointPolicy::None => false,
}
}
pub fn save_checkpoint(&self, layer_index: usize, activations: Vec<Tensor<T>>) -> Result<()> {
let mut checkpoint = Checkpoint::new(layer_index, activations);
if self.config.save_rng_state {
checkpoint = checkpoint.with_rng_state(self.capture_rng_state());
}
let memory_bytes = checkpoint.memory_bytes;
let mut checkpoints = self.checkpoints.lock().map_err(|_| {
TensorError::invalid_operation_simple("Failed to acquire checkpoint lock".to_string())
})?;
if let Some(max_cp) = self.config.max_checkpoints {
if checkpoints.len() >= max_cp {
if let Some(oldest_key) = checkpoints
.iter()
.min_by_key(|(_, cp)| cp.timestamp)
.map(|(k, _)| *k)
{
checkpoints.remove(&oldest_key);
}
}
}
checkpoints.insert(layer_index, checkpoint);
if self.config.enable_statistics {
let mut stats = self.statistics.lock().map_err(|_| {
TensorError::invalid_operation_simple(
"Failed to acquire statistics lock".to_string(),
)
})?;
stats.forward_passes += 1;
stats.memory_saved_bytes += memory_bytes;
stats.active_checkpoints = checkpoints.len();
}
Ok(())
}
pub fn get_checkpoint(&self, layer_index: usize) -> Result<Option<Checkpoint<T>>> {
let checkpoints = self.checkpoints.lock().map_err(|_| {
TensorError::invalid_operation_simple("Failed to acquire checkpoint lock".to_string())
})?;
Ok(checkpoints.get(&layer_index).cloned())
}
pub fn restore_rng_state(&self, rng_state: &[u8]) -> Result<()> {
if rng_state.is_empty() {
return Err(TensorError::invalid_argument("Empty RNG state".to_string()));
}
Ok(())
}
fn capture_rng_state(&self) -> Vec<u8> {
vec![0u8; 32]
}
pub fn record_recomputation(&self, compute_time_us: u64) {
if !self.config.enable_statistics {
return;
}
if let Ok(mut stats) = self.statistics.lock() {
stats.recompute_count += 1;
stats.additional_compute_time_us += compute_time_us;
}
}
pub fn record_backward_pass(&self) {
if !self.config.enable_statistics {
return;
}
if let Ok(mut stats) = self.statistics.lock() {
stats.backward_passes += 1;
}
}
pub fn get_statistics(&self) -> Result<CheckpointStatistics> {
let stats = self.statistics.lock().map_err(|_| {
TensorError::invalid_operation_simple("Failed to acquire statistics lock".to_string())
})?;
Ok(stats.clone())
}
pub fn clear(&self) -> Result<()> {
let mut checkpoints = self.checkpoints.lock().map_err(|_| {
TensorError::invalid_operation_simple("Failed to acquire checkpoint lock".to_string())
})?;
checkpoints.clear();
if self.config.enable_statistics {
if let Ok(mut stats) = self.statistics.lock() {
stats.active_checkpoints = 0;
}
}
Ok(())
}
pub fn checkpoint_count(&self) -> usize {
self.checkpoints.lock().map(|cp| cp.len()).unwrap_or(0)
}
pub fn total_memory_bytes(&self) -> usize {
self.checkpoints
.lock()
.map(|cp| cp.values().map(|c| c.memory_bytes).sum())
.unwrap_or(0)
}
}
/// Cloning yields another handle to the same checkpoint store and statistics;
/// only the configuration is copied.
///
/// No bounds on `T` are required because only the `Arc`s are cloned, never the
/// stored tensors.
impl<T> Clone for CheckpointManager<T> {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            checkpoints: Arc::clone(&self.checkpoints),
            statistics: Arc::clone(&self.statistics),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_checkpoint_policy_every_n() {
        let manager = CheckpointManager::<f32>::new(CheckpointingConfig {
            policy: CheckpointPolicy::EveryNLayers(2),
            ..Default::default()
        });
        assert!(manager.should_checkpoint(0, 10));
        assert!(!manager.should_checkpoint(1, 10));
        assert!(manager.should_checkpoint(2, 10));
        assert!(!manager.should_checkpoint(3, 10));
        assert!(manager.should_checkpoint(4, 10));
    }

    #[test]
    fn test_checkpoint_policy_specific() {
        let manager = CheckpointManager::<f32>::new(CheckpointingConfig {
            policy: CheckpointPolicy::SpecificLayers(vec![1, 3, 7]),
            ..Default::default()
        });
        assert!(!manager.should_checkpoint(0, 10));
        assert!(manager.should_checkpoint(1, 10));
        assert!(!manager.should_checkpoint(2, 10));
        assert!(manager.should_checkpoint(3, 10));
        assert!(manager.should_checkpoint(7, 10));
    }

    #[test]
    fn test_checkpoint_save_and_retrieve() {
        let manager = CheckpointManager::<f32>::new(CheckpointingConfig::default());
        let tensor = Tensor::from_array(array![[1.0, 2.0], [3.0, 4.0]].into_dyn());
        manager
            .save_checkpoint(5, vec![tensor.clone()])
            .expect("test: operation should succeed");
        let checkpoint = manager
            .get_checkpoint(5)
            .expect("test: get_checkpoint should succeed");
        assert!(checkpoint.is_some());
        let cp = checkpoint.expect("test: operation should succeed");
        assert_eq!(cp.layer_index, 5);
        assert_eq!(cp.activations.len(), 1);
        // The default config captures RNG state alongside activations.
        assert!(cp.rng_state.is_some());
    }

    #[test]
    fn test_max_checkpoints_limit() {
        let manager = CheckpointManager::<f32>::new(CheckpointingConfig {
            max_checkpoints: Some(2),
            ..Default::default()
        });
        let tensor = Tensor::from_array(array![1.0, 2.0].into_dyn());
        manager
            .save_checkpoint(0, vec![tensor.clone()])
            .expect("test: operation should succeed");
        manager
            .save_checkpoint(1, vec![tensor.clone()])
            .expect("test: operation should succeed");
        manager
            .save_checkpoint(2, vec![tensor.clone()])
            .expect("test: operation should succeed");
        assert_eq!(manager.checkpoint_count(), 2);
    }

    #[test]
    fn test_resave_same_layer_replaces() {
        let manager = CheckpointManager::<f32>::new(CheckpointingConfig::default());
        let tensor = Tensor::from_array(array![1.0, 2.0].into_dyn());
        manager
            .save_checkpoint(4, vec![tensor.clone()])
            .expect("test: operation should succeed");
        manager
            .save_checkpoint(4, vec![tensor.clone(), tensor.clone()])
            .expect("test: operation should succeed");
        assert_eq!(manager.checkpoint_count(), 1);
        let cp = manager
            .get_checkpoint(4)
            .expect("test: get_checkpoint should succeed")
            .expect("test: checkpoint should exist");
        assert_eq!(cp.activations.len(), 2);
    }

    #[test]
    fn test_statistics_tracking() {
        let manager = CheckpointManager::<f32>::new(CheckpointingConfig {
            enable_statistics: true,
            ..Default::default()
        });
        let tensor = Tensor::from_array(array![1.0, 2.0, 3.0].into_dyn());
        manager
            .save_checkpoint(0, vec![tensor.clone()])
            .expect("test: operation should succeed");
        manager.record_backward_pass();
        manager.record_recomputation(1000);
        manager.record_recomputation(2000);
        let stats = manager
            .get_statistics()
            .expect("test: get_statistics should succeed");
        assert_eq!(stats.forward_passes, 1);
        assert_eq!(stats.backward_passes, 1);
        assert_eq!(stats.recompute_count, 2);
        assert_eq!(stats.additional_compute_time_us, 3000);
    }

    #[test]
    fn test_checkpoint_clear() {
        let manager = CheckpointManager::<f32>::new(CheckpointingConfig::default());
        let tensor = Tensor::from_array(array![1.0, 2.0].into_dyn());
        manager
            .save_checkpoint(0, vec![tensor.clone()])
            .expect("test: operation should succeed");
        manager
            .save_checkpoint(1, vec![tensor.clone()])
            .expect("test: operation should succeed");
        assert_eq!(manager.checkpoint_count(), 2);
        manager.clear().expect("test: clear should succeed");
        assert_eq!(manager.checkpoint_count(), 0);
    }

    #[test]
    fn test_cloned_manager_shares_state() {
        let manager = CheckpointManager::<f32>::new(CheckpointingConfig::default());
        let handle = manager.clone();
        let tensor = Tensor::from_array(array![1.0, 2.0].into_dyn());
        handle
            .save_checkpoint(3, vec![tensor])
            .expect("test: operation should succeed");
        assert_eq!(manager.checkpoint_count(), 1);
        assert!(manager
            .get_checkpoint(3)
            .expect("test: get_checkpoint should succeed")
            .is_some());
    }

    #[test]
    fn test_automatic_policy() {
        let manager = CheckpointManager::<f32>::with_memory_budget(100, 10);
        assert!(manager.should_checkpoint(0, 50));
        assert!(manager.should_checkpoint(5, 50));
        assert!(manager.should_checkpoint(10, 50));
        assert!(!manager.should_checkpoint(1, 50));
        assert!(!manager.should_checkpoint(7, 50));
    }

    #[test]
    fn test_checkpoint_statistics_calculations() {
        let stats = CheckpointStatistics {
            forward_passes: 100,
            backward_passes: 100,
            recompute_count: 300,
            memory_saved_bytes: 1024 * 1024 * 500,
            additional_compute_time_us: 1_000_000,
            active_checkpoints: 10,
        };
        assert_eq!(stats.avg_recomputations(), 3.0);
        assert_eq!(stats.memory_saved_mb(), 500.0);
        assert_eq!(stats.compute_overhead_percent(), 300.0);
    }

    #[test]
    fn test_checkpoint_statistics_empty() {
        let stats = CheckpointStatistics::default();
        assert_eq!(stats.avg_recomputations(), 0.0);
        assert_eq!(stats.compute_overhead_percent(), 0.0);
        assert_eq!(stats.memory_saved_mb(), 0.0);
    }
}