use crate::error::{ModelError, ModelResult};
use scirs2_core::ndarray::Array1;
/// Tunable settings for [`ActivationCheckpointer`].
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Checkpoint stride: a layer is a checkpoint layer when its index is a
    /// multiple of this value. A stride of `0` disables checkpointing.
    pub checkpoint_every_n_layers: usize,
    /// Upper bound on the number of checkpoints held at the same time.
    pub max_checkpoints: usize,
    /// When `true`, activations are rounded to three decimal places before
    /// being stored (they are still kept as `f32`).
    pub use_mixed_precision: bool,
}

impl Default for CheckpointConfig {
    /// Checkpoint every 4th layer, hold at most 64 checkpoints, and store
    /// activations without rounding.
    fn default() -> Self {
        CheckpointConfig {
            use_mixed_precision: false,
            max_checkpoints: 64,
            checkpoint_every_n_layers: 4,
        }
    }
}
/// Keeps a subset of per-layer activations ("checkpoints") so that other
/// activations can be recomputed from the nearest stored one, and tracks
/// simple byte counters for what was stored versus skipped.
#[derive(Debug, Clone)]
pub struct ActivationCheckpointer {
/// Settings that control which layers are checkpointed and how many fit.
config: CheckpointConfig,
/// Checkpoint slots indexed by layer index; `None` means nothing is stored
/// for that layer.
checkpoints: Vec<Option<Array1<f32>>>,
/// Running byte total of activations that `checkpointed_forward` did not
/// store; reported by `memory_saved_bytes`.
bytes_saved: usize,
/// Byte counter maintained by `save` for stored activations; reported by
/// `memory_stored_bytes`.
bytes_stored: usize,
}
impl ActivationCheckpointer {
/// Creates an empty checkpointer that uses `config`.
///
/// No checkpoint slots are allocated up front and both byte counters start
/// at zero.
pub fn new(config: CheckpointConfig) -> Self {
    let checkpoints = Vec::new();
    Self {
        config,
        checkpoints,
        bytes_saved: 0,
        bytes_stored: 0,
    }
}
/// Stores `activation` as the checkpoint for `layer_idx`, replacing any
/// checkpoint already held for that layer.
///
/// When `use_mixed_precision` is enabled the values are rounded to three
/// decimal places before being stored.
///
/// # Errors
///
/// Returns an invalid-config error when `layer_idx` has no checkpoint yet
/// and `max_checkpoints` checkpoints are already held. A rejected save
/// leaves the checkpointer untouched.
pub fn save(&mut self, layer_idx: usize, activation: Array1<f32>) -> ModelResult<()> {
    let elem_size = std::mem::size_of::<f32>();

    // Size of the checkpoint being replaced, if this layer already has one.
    let replaced_bytes = self
        .checkpoints
        .get(layer_idx)
        .and_then(|slot| slot.as_ref())
        .map(|old| old.len() * elem_size);

    // Validate capacity before touching any state so a failed save has no
    // side effects (the slot vector is not grown on error).
    if replaced_bytes.is_none() && self.num_checkpoints() >= self.config.max_checkpoints {
        return Err(ModelError::invalid_config(format!(
            "Maximum checkpoint count ({}) exceeded when saving layer {}",
            self.config.max_checkpoints, layer_idx
        )));
    }

    if layer_idx >= self.checkpoints.len() {
        self.checkpoints.resize(layer_idx + 1, None);
    }

    let byte_size = activation.len() * elem_size;
    let stored = if self.config.use_mixed_precision {
        activation.mapv(|x| (x * 1000.0).round() / 1000.0)
    } else {
        activation
    };

    // Overwriting drops the old activation, so its bytes must leave the
    // counter; otherwise repeated saves to one layer would overcount.
    self.bytes_stored = self
        .bytes_stored
        .saturating_sub(replaced_bytes.unwrap_or(0))
        + byte_size;
    self.checkpoints[layer_idx] = Some(stored);
    Ok(())
}
/// Returns a reference to the checkpoint stored for `layer_idx`.
///
/// # Errors
///
/// * `ModelError::IndexOutOfBounds` when `layer_idx` is beyond the slots
///   allocated so far.
/// * A not-initialized error when the slot exists but holds no checkpoint.
pub fn get(&self, layer_idx: usize) -> ModelResult<&Array1<f32>> {
    match self.checkpoints.get(layer_idx) {
        None => Err(ModelError::IndexOutOfBounds {
            index: layer_idx,
            limit: self.checkpoints.len(),
            context: "ActivationCheckpointer::get".to_string(),
        }),
        Some(slot) => match slot {
            Some(activation) => Ok(activation),
            None => Err(ModelError::not_initialized(format!(
                "No checkpoint stored for layer {}",
                layer_idx
            ))),
        },
    }
}
/// Drops every stored checkpoint and resets both byte counters to zero.
pub fn clear(&mut self) {
    self.bytes_stored = 0;
    self.bytes_saved = 0;
    self.checkpoints.clear();
}
/// Returns the running byte total of activations that
/// `checkpointed_forward` produced but did not store as checkpoints.
pub fn memory_saved_bytes(&self) -> usize {
self.bytes_saved
}
/// Returns the byte counter that `save` maintains for stored activations.
pub fn memory_stored_bytes(&self) -> usize {
self.bytes_stored
}
/// Counts the layers that currently hold a checkpoint.
///
/// This walks every slot, so the cost grows with the highest layer index
/// saved so far.
pub fn num_checkpoints(&self) -> usize {
    // Flattening an iterator of `&Option<_>` yields only the `Some` entries.
    self.checkpoints.iter().flatten().count()
}
/// Reports whether `layer_idx` falls on the configured checkpoint stride.
///
/// Returns `true` when `layer_idx` is a multiple of
/// `checkpoint_every_n_layers` (layer 0 always qualifies for a non-zero
/// stride). A stride of `0` disables checkpointing, so the result is then
/// always `false`.
pub fn is_checkpoint_layer(&self, layer_idx: usize) -> bool {
    let stride = self.config.checkpoint_every_n_layers;
    stride != 0 && layer_idx.is_multiple_of(stride)
}
/// Runs `forward_fn` over `layers` in order, feeding each layer's output to
/// the next, and checkpoints the outputs of checkpoint layers.
///
/// For every layer the output is either stored via `save` or its size is
/// added to the "saved" byte counter:
///
/// * stored when the layer is a checkpoint layer and either there is still
///   room for a new checkpoint or the layer already holds one (which is
///   then refreshed);
/// * counted as saved memory otherwise.
///
/// Returns the output of the last layer, or a clone of `input` when
/// `layers` is empty.
///
/// # Errors
///
/// Propagates the first error returned by `forward_fn` or by `save`.
pub fn checkpointed_forward<F>(
    &mut self,
    input: &Array1<f32>,
    layers: &[usize],
    forward_fn: F,
) -> ModelResult<Array1<f32>>
where
    F: Fn(&Array1<f32>, usize) -> ModelResult<Array1<f32>>,
{
    let mut current = input.clone();
    for &layer_idx in layers {
        current = forward_fn(&current, layer_idx)?;

        let should_store = self.is_checkpoint_layer(layer_idx) && {
            // An occupied slot can always be overwritten without growing the
            // checkpoint count. Refreshing it even at capacity prevents a
            // later recomputation from resuming on a stale activation left
            // over from an earlier pass.
            let slot_occupied = matches!(self.checkpoints.get(layer_idx), Some(Some(_)));
            slot_occupied || self.num_checkpoints() < self.config.max_checkpoints
        };

        if should_store {
            self.save(layer_idx, current.clone())?;
        } else {
            self.bytes_saved += current.len() * std::mem::size_of::<f32>();
        }
    }
    Ok(current)
}
/// Recomputes the activation of `target_layer` starting from the nearest
/// stored checkpoint.
///
/// `layers` is scanned from the back for the first entry that is not greater
/// than `target_layer` and has a stored checkpoint. Starting after that
/// entry, `forward_fn` is replayed over the following entries of `layers`
/// until `target_layer` has been applied.
///
/// If `target_layer` itself holds the checkpoint, a clone of that checkpoint
/// is returned without calling `forward_fn`.
///
/// If `target_layer` does not occur in `layers` after the checkpoint, every
/// remaining layer is applied and the last activation is returned.
///
/// # Errors
///
/// * A not-initialized error when no entry of `layers` at or below
///   `target_layer` has a stored checkpoint.
/// * Any error returned by `forward_fn`.
pub fn recompute_from_checkpoint<F>(
    &self,
    target_layer: usize,
    layers: &[usize],
    forward_fn: F,
) -> ModelResult<Array1<f32>>
where
    F: Fn(&Array1<f32>, usize) -> ModelResult<Array1<f32>>,
{
    // Walk backwards so the first hit is the checkpoint closest to the target.
    let nearest = layers
        .iter()
        .rev()
        .filter(|&&l| l <= target_layer)
        .find_map(|&l| {
            self.checkpoints
                .get(l)
                .and_then(|slot| slot.as_ref())
                .map(|act| (l, act.clone()))
        });

    let Some((start_layer, mut current)) = nearest else {
        return Err(ModelError::not_initialized(format!(
            "No checkpoint found before layer {} for recomputation",
            target_layer
        )));
    };

    // The checkpoint already is the requested activation. Without this guard
    // the replay loop below would never meet `target_layer` again and would
    // run through all remaining layers.
    if start_layer == target_layer {
        return Ok(current);
    }

    let mut started = false;
    for &l in layers {
        if l == start_layer {
            // The checkpoint holds this layer's output, so replay starts
            // with the entry after it.
            started = true;
            continue;
        }
        if !started {
            continue;
        }
        current = forward_fn(&current, l)?;
        if l == target_layer {
            break;
        }
    }
    Ok(current)
}
/// Returns the configuration this checkpointer was created with.
pub fn config(&self) -> &CheckpointConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Toy layer for the tests: scales every value by 1.1 and adds a small
    /// offset that depends on the layer index.
    fn simple_forward(activation: &Array1<f32>, layer_idx: usize) -> ModelResult<Array1<f32>> {
        let offset = layer_idx as f32 * 0.01;
        Ok(activation.mapv(|value| value * 1.1 + offset))
    }

    /// Builds a checkpointer without rounding, using the given stride and
    /// checkpoint capacity.
    fn checkpointer(every_n_layers: usize, capacity: usize) -> ActivationCheckpointer {
        ActivationCheckpointer::new(CheckpointConfig {
            checkpoint_every_n_layers: every_n_layers,
            max_checkpoints: capacity,
            use_mixed_precision: false,
        })
    }

    /// A saved activation can be read back; an unknown layer is an error.
    #[test]
    fn test_gradient_checkpoint_save_get() {
        let mut cp = checkpointer(2, 10);
        cp.save(2, Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]))
            .expect("save should succeed");

        let retrieved = cp.get(2).expect("get should succeed");
        assert_eq!(retrieved.len(), 4);
        assert!((retrieved[0] - 1.0).abs() < 1e-6);
        assert!((retrieved[3] - 4.0).abs() < 1e-6);

        assert!(cp.get(5).is_err());
    }

    /// A forward pass updates both byte counters and checkpoints only the
    /// layers on the stride.
    #[test]
    fn test_gradient_checkpoint_memory_accounting() {
        let mut cp = checkpointer(3, 10);
        let input = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let layers: Vec<usize> = (0..6).collect();

        cp.checkpointed_forward(&input, &layers, simple_forward)
            .expect("forward should succeed");

        assert!(
            cp.memory_saved_bytes() > 0,
            "should have saved some memory, got 0"
        );
        assert!(
            cp.memory_stored_bytes() > 0,
            "should have stored some activations"
        );
        assert_eq!(
            cp.num_checkpoints(),
            2,
            "should have 2 checkpoints (layers 0, 3)"
        );
    }

    /// `clear` removes all checkpoints and zeroes the counters.
    #[test]
    fn test_gradient_checkpoint_clear() {
        let mut cp = checkpointer(2, 10);
        cp.save(0, Array1::from_vec(vec![1.0, 2.0]))
            .expect("save should succeed");
        cp.save(2, Array1::from_vec(vec![3.0, 4.0]))
            .expect("save should succeed");
        assert_eq!(cp.num_checkpoints(), 2);

        cp.clear();

        assert_eq!(cp.num_checkpoints(), 0);
        assert_eq!(cp.memory_saved_bytes(), 0);
        assert_eq!(cp.memory_stored_bytes(), 0);
        assert!(cp.get(0).is_err());
    }

    /// The checkpointed forward pass yields the same output as applying the
    /// layers directly.
    #[test]
    fn test_gradient_checkpoint_forward() {
        let mut cp = checkpointer(2, 20);
        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let layers: Vec<usize> = (0..8).collect();

        let checkpointed_output = cp
            .checkpointed_forward(&input, &layers, simple_forward)
            .expect("checkpointed forward should succeed");

        let direct = layers.iter().fold(input.clone(), |acc, &l| {
            simple_forward(&acc, l).expect("forward should succeed")
        });

        assert_eq!(checkpointed_output.len(), direct.len());
        for (a, b) in checkpointed_output.iter().zip(direct.iter()) {
            assert!(
                (a - b).abs() < 1e-4,
                "mismatch: checkpointed={}, direct={}",
                a,
                b
            );
        }
        assert!(cp.num_checkpoints() > 0, "should have saved checkpoints");
    }
}