pub mod planner;
pub mod selective;
use std::fmt;
/// Errors reported by the activation-checkpointing types in this module.
#[derive(Debug, Clone, PartialEq)]
pub enum CheckpointError {
    /// No recorded snapshot carries the requested layer name.
    LayerNotFound(String),
    /// Recomputing an activation failed; the payload is the failure message.
    RecomputationFailed(String),
    /// A tensor shape differed from the shape that was expected.
    InvalidShape {
        /// Shape that was required.
        expected: Vec<usize>,
        /// Shape that was actually supplied.
        got: Vec<usize>,
    },
    /// A segment index was rejected as invalid.
    InvalidSegment(usize),
}

impl fmt::Display for CheckpointError {
    /// Writes a short, lowercase, single-line description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::LayerNotFound(name) => write!(f, "checkpoint layer not found: '{name}'"),
            Self::RecomputationFailed(msg) => write!(f, "recomputation failed: {msg}"),
            Self::InvalidShape { expected, got } => {
                write!(f, "shape mismatch: expected {expected:?}, got {got:?}")
            }
            Self::InvalidSegment(idx) => write!(f, "invalid segment index: {idx}"),
        }
    }
}

impl std::error::Error for CheckpointError {}
/// One layer's recorded activation: either the tensor values themselves or a
/// placeholder that keeps only the shape and a recomputation cost.
#[derive(Debug, Clone)]
pub struct ActivationSnapshot {
    /// Name of the layer the snapshot belongs to.
    pub layer_name: String,
    /// Flat tensor values; empty for snapshots built with [`ActivationSnapshot::checkpointed`].
    pub tensor_data: Vec<f32>,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// `true` for snapshots built with [`ActivationSnapshot::checkpointed`].
    pub is_checkpointed: bool,
    /// Caller-supplied cost value; `0.0` for snapshots built with [`ActivationSnapshot::new`].
    pub computation_cost: f32,
}

impl ActivationSnapshot {
    /// Builds a snapshot that owns `data`, with `is_checkpointed` unset and a
    /// `computation_cost` of `0.0`.
    pub fn new(layer_name: &str, data: Vec<f32>, shape: Vec<usize>) -> Self {
        Self {
            is_checkpointed: false,
            computation_cost: 0.0,
            layer_name: String::from(layer_name),
            tensor_data: data,
            shape,
        }
    }

    /// Builds a placeholder snapshot: no tensor data, `is_checkpointed` set and
    /// `cost` stored as the computation cost.
    pub fn checkpointed(layer_name: &str, shape: Vec<usize>, cost: f32) -> Self {
        Self {
            is_checkpointed: true,
            computation_cost: cost,
            ..Self::new(layer_name, Vec::new(), shape)
        }
    }

    /// Number of elements implied by `shape` (the product of its dimensions;
    /// an empty shape yields `1`).
    pub fn numel(&self) -> usize {
        self.shape.iter().copied().product()
    }

    /// Size in bytes implied by `shape`, counting one `f32` per element.
    pub fn memory_bytes(&self) -> usize {
        std::mem::size_of::<f32>() * self.numel()
    }

    /// Returns `true` when `tensor_data` holds at least one value.
    pub fn is_stored(&self) -> bool {
        let holds_no_values = self.tensor_data.is_empty();
        !holds_no_values
    }
}
/// Rule deciding which layers are checkpointed instead of stored.
#[derive(Debug, Clone, PartialEq)]
pub enum CheckpointPolicy {
    /// Checkpoint every layer.
    All,
    /// Checkpoint layers whose index is a multiple of the value; `0` checkpoints nothing.
    EveryN(usize),
    /// Checkpoint layers whose name starts with any of the listed prefixes.
    ByName(Vec<String>),
    /// Never checkpoint.
    None,
    /// Checkpoint every layer whose byte size is non-zero.
    GreedyMemory {
        /// Byte budget carried by the variant; `should_checkpoint` does not consult it.
        budget_bytes: usize,
    },
}

impl CheckpointPolicy {
    /// Decides whether the given layer should be checkpointed.
    ///
    /// * `layer_name` - matched by prefix under [`CheckpointPolicy::ByName`].
    /// * `layer_idx` - tested for divisibility under [`CheckpointPolicy::EveryN`].
    /// * `memory_bytes` - must be non-zero under [`CheckpointPolicy::GreedyMemory`].
    ///
    /// Arguments that the active variant does not use are ignored.
    pub fn should_checkpoint(
        &self,
        layer_name: &str,
        layer_idx: usize,
        memory_bytes: usize,
    ) -> bool {
        match self {
            Self::All => true,
            Self::None => false,
            // A stride of zero selects no layer at all.
            Self::EveryN(0) => false,
            Self::EveryN(stride) => layer_idx.is_multiple_of(*stride),
            Self::ByName(prefixes) => prefixes
                .iter()
                .any(|prefix| layer_name.starts_with(prefix.as_str())),
            Self::GreedyMemory { .. } => memory_bytes != 0,
        }
    }
}
/// Running byte and layer counters describing stored versus checkpointed activations.
#[derive(Debug, Clone, Default)]
pub struct CheckpointMemoryStats {
    /// Total bytes of activations that were kept in memory.
    pub stored_bytes: usize,
    /// Total bytes of activations that were checkpointed rather than kept.
    pub checkpointed_bytes: usize,
    /// Total bytes produced by recomputing checkpointed activations.
    pub recomputed_bytes: usize,
    /// Number of layers whose activations were kept.
    pub num_stored_layers: usize,
    /// Number of layers whose activations were checkpointed.
    pub num_checkpointed_layers: usize,
    /// `checkpointed_bytes / (stored_bytes + checkpointed_bytes)`, or `0.0` when both are zero.
    pub memory_savings_ratio: f32,
}

impl CheckpointMemoryStats {
    /// Recomputes `memory_savings_ratio` from the current byte counters.
    fn refresh_ratio(&mut self) {
        let dropped = self.checkpointed_bytes;
        let combined = self.stored_bytes + dropped;
        self.memory_savings_ratio = match combined {
            // Nothing recorded yet: report no savings instead of dividing by zero.
            0 => 0.0,
            _ => dropped as f32 / combined as f32,
        };
    }
}
/// Records per-layer activations under a [`CheckpointPolicy`] and serves them
/// back on request, recomputing the ones that were not kept.
pub struct CheckpointManager {
    /// Policy consulted on every call to [`CheckpointManager::record_activation`].
    pub policy: CheckpointPolicy,
    snapshots: Vec<ActivationSnapshot>,
    recompute_log: Vec<String>,
    stats: CheckpointMemoryStats,
}

impl CheckpointManager {
    /// Size of one tensor element; activations are `f32` values.
    const BYTES_PER_ELEMENT: usize = std::mem::size_of::<f32>();

    /// Creates an empty manager that applies `policy`.
    pub fn new(policy: CheckpointPolicy) -> Self {
        Self {
            policy,
            snapshots: Vec::new(),
            recompute_log: Vec::new(),
            stats: CheckpointMemoryStats::default(),
        }
    }

    /// Records the activation of one layer.
    ///
    /// The byte size is derived from `shape` (four bytes per element) and handed
    /// to the policy together with `layer_name` and `layer_idx`. When the policy
    /// asks for a checkpoint, `data` is discarded and only `shape` and
    /// `computation_cost` are remembered; otherwise `data` is stored and
    /// `computation_cost` is not recorded. The memory statistics are updated in
    /// both cases.
    ///
    /// # Errors
    ///
    /// The current implementation always returns `Ok(())`.
    pub fn record_activation(
        &mut self,
        layer_name: &str,
        layer_idx: usize,
        data: Vec<f32>,
        shape: Vec<usize>,
        computation_cost: f32,
    ) -> Result<(), CheckpointError> {
        let numel: usize = shape.iter().product();
        let bytes = numel * Self::BYTES_PER_ELEMENT;
        let snapshot = if self.policy.should_checkpoint(layer_name, layer_idx, bytes) {
            self.stats.checkpointed_bytes += bytes;
            self.stats.num_checkpointed_layers += 1;
            ActivationSnapshot::checkpointed(layer_name, shape, computation_cost)
        } else {
            self.stats.stored_bytes += bytes;
            self.stats.num_stored_layers += 1;
            ActivationSnapshot::new(layer_name, data, shape)
        };
        self.snapshots.push(snapshot);
        self.stats.refresh_ratio();
        Ok(())
    }

    /// Returns the activation recorded for `layer_name`.
    ///
    /// The first snapshot with a matching name is used. If it still holds its
    /// data, a copy is returned and `recompute_fn` is never called. Otherwise
    /// `recompute_fn` is called exactly once; on success the produced bytes are
    /// added to `recomputed_bytes`, the layer name is appended to the recompute
    /// log and the data is returned. The recomputed data is not written back
    /// into the snapshot.
    ///
    /// `recompute_fn` only needs to be callable once, so closures that move
    /// owned data out are accepted as well as reusable ones.
    ///
    /// # Errors
    ///
    /// * [`CheckpointError::LayerNotFound`] when no snapshot has that name.
    /// * [`CheckpointError::RecomputationFailed`] when `recompute_fn` fails. An
    ///   error that already is `RecomputationFailed` is forwarded unchanged; any
    ///   other error is wrapped with its display text as the message.
    pub fn get_activation(
        &mut self,
        layer_name: &str,
        recompute_fn: impl FnOnce() -> Result<Vec<f32>, CheckpointError>,
    ) -> Result<Vec<f32>, CheckpointError> {
        let Some(snapshot) = self.snapshots.iter().find(|s| s.layer_name == layer_name) else {
            return Err(CheckpointError::LayerNotFound(layer_name.to_owned()));
        };
        if snapshot.is_stored() {
            return Ok(snapshot.tensor_data.clone());
        }

        let data = recompute_fn().map_err(|err| match err {
            // Forward as-is: wrapping again would yield
            // "recomputation failed: recomputation failed: ...".
            already @ CheckpointError::RecomputationFailed(_) => already,
            other => CheckpointError::RecomputationFailed(other.to_string()),
        })?;
        self.stats.recomputed_bytes += data.len() * Self::BYTES_PER_ELEMENT;
        self.recompute_log.push(layer_name.to_owned());
        Ok(data)
    }

    /// All snapshots in the order they were recorded.
    pub fn snapshots(&self) -> &[ActivationSnapshot] {
        &self.snapshots
    }

    /// Current memory statistics.
    pub fn memory_stats(&self) -> &CheckpointMemoryStats {
        &self.stats
    }

    /// Names of the layers that were recomputed, one entry per recomputation.
    pub fn recompute_log(&self) -> &[String] {
        &self.recompute_log
    }

    /// Drops every snapshot and log entry and resets the statistics; the policy is kept.
    pub fn clear(&mut self) {
        self.snapshots.clear();
        self.recompute_log.clear();
        self.stats = CheckpointMemoryStats::default();
    }

    /// Sum of the shape-derived byte sizes of all snapshots that still hold data.
    pub fn current_memory_bytes(&self) -> usize {
        self.snapshots
            .iter()
            .filter(|s| s.is_stored())
            .map(ActivationSnapshot::memory_bytes)
            .sum()
    }
}
/// Splits a stack of layers into segments and keeps one optional output buffer per segment.
pub struct SegmentCheckpointer {
    /// Number of segments; [`SegmentCheckpointer::new`] raises `0` to `1`.
    pub num_segments: usize,
    /// Output recorded for each segment; `None` until that segment is recorded.
    pub segment_outputs: Vec<Option<Vec<f32>>>,
    total_layers: usize,
}

impl SegmentCheckpointer {
    /// Creates a checkpointer for `total_layers` layers split into `num_segments`
    /// segments (at least one), with every segment output initially empty.
    pub fn new(num_segments: usize, total_layers: usize) -> Self {
        let cap = num_segments.max(1);
        Self {
            num_segments: cap,
            segment_outputs: vec![None; cap],
            total_layers,
        }
    }

    /// Returns `true` when `layer_idx` is a multiple of the segment size
    /// (`total_layers` divided by `num_segments`, rounded up) or is the index of
    /// the last layer.
    ///
    /// Always `false` when there are no layers or no segments.
    pub fn is_boundary(&self, layer_idx: usize) -> bool {
        // `num_segments` is a public field and may have been set to zero after
        // construction; dividing by it would panic.
        if self.total_layers == 0 || self.num_segments == 0 {
            return false;
        }
        // Both operands are at least one here, so the segment size is too.
        let segment_size = self.total_layers.div_ceil(self.num_segments);
        let last_layer = self.total_layers - 1;
        layer_idx.is_multiple_of(segment_size) || layer_idx == last_layer
    }

    /// Stores `output` as the result of segment `segment_idx`, replacing any
    /// earlier value.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointError::InvalidSegment`] when `segment_idx` is not
    /// below `num_segments`, or when `segment_outputs` has no slot for it.
    pub fn record_segment(
        &mut self,
        segment_idx: usize,
        output: Vec<f32>,
    ) -> Result<(), CheckpointError> {
        if segment_idx >= self.num_segments {
            return Err(CheckpointError::InvalidSegment(segment_idx));
        }
        // Checked access: the two public fields can be changed independently, so
        // the vector may be shorter than `num_segments`.
        let slot = self
            .segment_outputs
            .get_mut(segment_idx)
            .ok_or(CheckpointError::InvalidSegment(segment_idx))?;
        *slot = Some(output);
        Ok(())
    }

    /// Returns the recorded output of segment `segment_idx`, or `None` when the
    /// index is out of range or nothing was recorded for it.
    pub fn get_segment_output(&self, segment_idx: usize) -> Option<&[f32]> {
        self.segment_outputs
            .get(segment_idx)
            .and_then(|slot| slot.as_deref())
    }

    /// Ratio `num_segments / total_layers`, or `1.0` when there are no layers.
    pub fn memory_savings_factor(&self) -> f32 {
        if self.total_layers == 0 {
            return 1.0;
        }
        self.num_segments as f32 / self.total_layers as f32
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` keeps the supplied data and yields a stored, non-checkpointed snapshot.
    #[test]
    fn test_activation_snapshot_new() {
        let values = vec![1.0_f32, 2.0, 3.0, 4.0];
        let dims = vec![2, 2];
        let snapshot = ActivationSnapshot::new("layer0", values.clone(), dims.clone());
        assert_eq!(snapshot.layer_name, "layer0");
        assert_eq!(snapshot.tensor_data, values);
        assert_eq!(snapshot.shape, dims);
        assert!(!snapshot.is_checkpointed);
        assert!(snapshot.is_stored());
    }

    /// `checkpointed` keeps shape and cost but no tensor data.
    #[test]
    fn test_activation_snapshot_checkpointed() {
        let dims = vec![4, 8];
        let snapshot = ActivationSnapshot::checkpointed("attn", dims.clone(), 1024.0);
        assert_eq!(snapshot.layer_name, "attn");
        assert!(snapshot.tensor_data.is_empty());
        assert_eq!(snapshot.shape, dims);
        assert!(snapshot.is_checkpointed);
        assert!(!snapshot.is_stored());
        assert_eq!(snapshot.computation_cost, 1024.0);
    }

    /// A 3x4 tensor has 12 elements and occupies 48 bytes.
    #[test]
    fn test_activation_snapshot_memory_bytes() {
        let dims = vec![3, 4];
        let snapshot = ActivationSnapshot::new("fc", vec![0.0; 12], dims);
        assert_eq!(snapshot.numel(), 12);
        assert_eq!(snapshot.memory_bytes(), 48);
    }

    /// `All` checkpoints regardless of name, index or size.
    #[test]
    fn test_checkpoint_policy_all() {
        let all = CheckpointPolicy::All;
        assert!(all.should_checkpoint("any", 0, 100));
        assert!(all.should_checkpoint("layer99", 99, 0));
    }

    /// `EveryN(3)` checkpoints exactly the indices divisible by three.
    #[test]
    fn test_checkpoint_policy_every_n() {
        let every_third = CheckpointPolicy::EveryN(3);
        assert!(every_third.should_checkpoint("layer0", 0, 100));
        assert!(!every_third.should_checkpoint("layer1", 1, 100));
        assert!(!every_third.should_checkpoint("layer2", 2, 100));
        assert!(every_third.should_checkpoint("layer3", 3, 100));
        assert!(every_third.should_checkpoint("layer6", 6, 100));
        assert!(!every_third.should_checkpoint("layer7", 7, 100));
    }

    /// `ByName` matches layer names by prefix.
    #[test]
    fn test_checkpoint_policy_by_name() {
        let by_prefix = CheckpointPolicy::ByName(vec!["attention".to_owned(), "ffn".to_owned()]);
        assert!(by_prefix.should_checkpoint("attention_0", 0, 100));
        assert!(by_prefix.should_checkpoint("ffn_layer", 1, 100));
        assert!(!by_prefix.should_checkpoint("norm", 2, 100));
        assert!(!by_prefix.should_checkpoint("embed", 3, 100));
    }

    /// `None` never checkpoints.
    #[test]
    fn test_checkpoint_policy_none() {
        let never = CheckpointPolicy::None;
        assert!(!never.should_checkpoint("any", 0, 999));
        assert!(!never.should_checkpoint("layer", 5, 0));
    }

    /// Under `None` the recorded data is stored verbatim.
    #[test]
    fn test_checkpoint_manager_record_stored() {
        let mut manager = CheckpointManager::new(CheckpointPolicy::None);
        let values = vec![1.0_f32; 16];
        manager
            .record_activation("layer0", 0, values.clone(), vec![4, 4], 0.0)
            .expect("record should succeed");
        let recorded = manager.snapshots();
        assert_eq!(recorded.len(), 1);
        assert!(recorded[0].is_stored());
        assert_eq!(recorded[0].tensor_data, values);
    }

    /// Under `All` the recorded snapshot is a checkpoint without data.
    #[test]
    fn test_checkpoint_manager_record_checkpointed() {
        let mut manager = CheckpointManager::new(CheckpointPolicy::All);
        manager
            .record_activation("layer0", 0, vec![1.0; 8], vec![2, 4], 50.0)
            .expect("record should succeed");
        let recorded = manager.snapshots();
        assert_eq!(recorded.len(), 1);
        assert!(!recorded[0].is_stored());
        assert!(recorded[0].is_checkpointed);
    }

    /// A stored activation is returned without invoking the recompute closure.
    #[test]
    fn test_checkpoint_manager_get_activation_stored() {
        let mut manager = CheckpointManager::new(CheckpointPolicy::None);
        let values = vec![3.0_f32, 1.0, 4.0, 1.0];
        manager
            .record_activation("fc", 0, values.clone(), vec![4], 0.0)
            .expect("record should succeed");
        let fetched = manager
            .get_activation("fc", || {
                Err(CheckpointError::RecomputationFailed(
                    "should not call".into(),
                ))
            })
            .expect("get should succeed");
        assert_eq!(fetched, values);
        assert_eq!(manager.recompute_log().len(), 0);
    }

    /// A checkpointed activation is rebuilt through the closure and logged.
    #[test]
    fn test_checkpoint_manager_get_activation_recompute() {
        let mut manager = CheckpointManager::new(CheckpointPolicy::All);
        manager
            .record_activation("attn", 0, vec![0.0; 4], vec![2, 2], 100.0)
            .expect("record should succeed");
        let expected = vec![9.0_f32, 8.0, 7.0, 6.0];
        let recomputed = expected.clone();
        let fetched = manager
            .get_activation("attn", move || Ok(recomputed.clone()))
            .expect("recompute should succeed");
        assert_eq!(fetched, expected);
        assert_eq!(manager.recompute_log(), &["attn"]);
    }

    /// With `EveryN(2)`, layer 0 is checkpointed and layer 1 stored, giving a 50% ratio.
    #[test]
    fn test_checkpoint_manager_memory_stats() {
        let mut manager = CheckpointManager::new(CheckpointPolicy::EveryN(2));
        manager
            .record_activation("l0", 0, vec![0.0; 4], vec![4], 10.0)
            .expect("ok");
        manager
            .record_activation("l1", 1, vec![1.0; 4], vec![4], 10.0)
            .expect("ok");
        let stats = manager.memory_stats();
        assert_eq!(stats.num_checkpointed_layers, 1);
        assert_eq!(stats.num_stored_layers, 1);
        assert_eq!(stats.checkpointed_bytes, 16);
        assert_eq!(stats.stored_bytes, 16);
        assert!((stats.memory_savings_ratio - 0.5).abs() < 1e-6);
    }

    /// `clear` removes all snapshots and resets the statistics.
    #[test]
    fn test_checkpoint_manager_clear() {
        let mut manager = CheckpointManager::new(CheckpointPolicy::None);
        manager
            .record_activation("l0", 0, vec![1.0; 8], vec![8], 0.0)
            .expect("ok");
        assert_eq!(manager.snapshots().len(), 1);
        manager.clear();
        assert_eq!(manager.snapshots().len(), 0);
        assert_eq!(manager.memory_stats().stored_bytes, 0);
    }

    /// Twelve layers in three segments: boundaries at 0, 4, 8 and the last layer (11).
    #[test]
    fn test_segment_checkpointer_boundaries() {
        let segments = SegmentCheckpointer::new(3, 12);
        assert!(segments.is_boundary(0));
        assert!(!segments.is_boundary(1));
        assert!(!segments.is_boundary(2));
        assert!(!segments.is_boundary(3));
        assert!(segments.is_boundary(4));
        assert!(segments.is_boundary(8));
        assert!(segments.is_boundary(11));
    }

    /// A recorded segment output can be read back; unrecorded segments yield `None`.
    #[test]
    fn test_segment_checkpointer_record_get() {
        let mut segments = SegmentCheckpointer::new(4, 16);
        let output = vec![1.0_f32, 2.0, 3.0];
        segments
            .record_segment(1, output.clone())
            .expect("record ok");
        let fetched = segments.get_segment_output(1).expect("should exist");
        assert_eq!(fetched, output.as_slice());
        assert!(segments.get_segment_output(2).is_none());
    }

    /// Four segments over sixteen layers give a factor of 0.25.
    #[test]
    fn test_segment_checkpointer_memory_savings() {
        let segments = SegmentCheckpointer::new(4, 16);
        let factor = segments.memory_savings_factor();
        assert!((factor - 0.25).abs() < 1e-6);
    }

    /// Every error variant mentions its payload (or "mismatch") in its display text.
    #[test]
    fn test_checkpoint_error_display() {
        let not_found = CheckpointError::LayerNotFound("block3".into());
        assert!(not_found.to_string().contains("block3"));
        let failed = CheckpointError::RecomputationFailed("oom".into());
        assert!(failed.to_string().contains("oom"));
        let bad_shape = CheckpointError::InvalidShape {
            expected: vec![4, 8],
            got: vec![4, 4],
        };
        assert!(bad_shape.to_string().contains("mismatch"));
        let bad_segment = CheckpointError::InvalidSegment(7);
        assert!(bad_segment.to_string().contains('7'));
    }
}