use crate::autograd::graph_opt::OpType;
use crate::Tensor;
use std::cell::RefCell;
use std::rc::Rc;
/// Settings for segment-based activation checkpointing.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Master switch; when `false`, `CheckpointManager` marks no segment as checkpointed.
    pub enabled: bool,
    /// Divisor `CheckpointManager` uses to derive its checkpoint interval.
    pub num_segments: usize,
    /// Flag turned on by [`CheckpointConfig::with_selective`].
    pub selective: bool,
}
impl CheckpointConfig {
    /// Builds an enabled configuration targeting `num_segments` segments (non-selective).
    pub fn enabled(num_segments: usize) -> Self {
        Self { num_segments, enabled: true, ..Self::disabled() }
    }
    /// Builds a disabled configuration (one segment, non-selective).
    pub fn disabled() -> Self {
        Self { enabled: false, num_segments: 1, selective: false }
    }
    /// Returns this configuration with `selective` switched on.
    pub fn with_selective(self) -> Self {
        Self { selective: true, ..self }
    }
}
impl Default for CheckpointConfig {
    /// Checkpointing is off by default.
    fn default() -> Self {
        CheckpointConfig::disabled()
    }
}
/// A registered unit of computation: its input tensor, an optionally cached output,
/// and whether it was marked as checkpointed.
pub struct CheckpointedSegment {
    /// Input tensor supplied at construction.
    input: Tensor,
    /// Cached output; interior-mutable so the `&self` setters below can update it.
    output: RefCell<Option<Tensor>>,
    /// Flag supplied at construction (e.g. by `CheckpointManager::register_segment`).
    is_checkpointed: bool,
}
impl CheckpointedSegment {
    /// Creates a segment with no cached output.
    pub fn new(input: Tensor, is_checkpointed: bool) -> Self {
        let output = RefCell::new(None);
        Self { input, output, is_checkpointed }
    }
    /// Borrows the segment's input tensor.
    pub fn input(&self) -> &Tensor {
        &self.input
    }
    /// Whether this segment was marked as checkpointed.
    pub fn is_checkpointed(&self) -> bool {
        self.is_checkpointed
    }
    /// Caches `output`, replacing any previously cached value.
    pub fn set_output(&self, output: Tensor) {
        self.output.replace(Some(output));
    }
    /// Returns a clone of the cached output, or `None` if nothing is cached.
    pub fn output(&self) -> Option<Tensor> {
        self.output.borrow().as_ref().cloned()
    }
    /// Discards the cached output, if any.
    pub fn clear_output(&self) {
        drop(self.output.take());
    }
}
/// Registers [`CheckpointedSegment`]s and decides, at registration time, which of them
/// are marked as checkpointed.
pub struct CheckpointManager {
    /// Settings supplied at construction.
    config: CheckpointConfig,
    /// Segments registered since construction or the last [`CheckpointManager::clear`].
    segments: Vec<Rc<CheckpointedSegment>>,
    /// Current segment index; reset to 0 by `clear`.
    current_segment: RefCell<usize>,
    /// Number of registered segments marked as checkpointed; reset to 0 by `clear`.
    memory_saved: RefCell<usize>,
}
impl CheckpointManager {
    /// Creates an empty manager using `config`.
    pub fn new(config: CheckpointConfig) -> Self {
        Self {
            config,
            segments: Vec::new(),
            current_segment: RefCell::new(0),
            memory_saved: RefCell::new(0),
        }
    }
    /// Returns `config.enabled`.
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }
    /// Returns the configured `num_segments`.
    pub fn num_segments(&self) -> usize {
        self.config.num_segments
    }
    /// Registers a new segment holding `input`, decides immediately whether it is
    /// checkpointed, and returns a shared handle to it (the manager keeps another).
    pub fn register_segment(&mut self, input: Tensor) -> Rc<CheckpointedSegment> {
        let idx = self.segments.len();
        // `should_checkpoint_segment` already returns false when checkpointing is disabled.
        let should_checkpoint = self.should_checkpoint_segment(idx);
        let segment = Rc::new(CheckpointedSegment::new(input, should_checkpoint));
        self.segments.push(Rc::clone(&segment));
        if should_checkpoint {
            *self.memory_saved.get_mut() += 1;
        }
        segment
    }
    /// Decides whether the segment at `segment_idx` is checkpointed.
    ///
    /// The interval is `segments.len().max(1) / num_segments.max(1)`, i.e. it is based on
    /// how many segments have been registered *so far*, not on a known final count. While
    /// that quotient is 0 every segment is checkpointed; otherwise every index that is a
    /// multiple of the interval is.
    ///
    /// NOTE(review): because the interval grows as segments are added, early segments are
    /// checkpointed more densely than later ones — confirm this heuristic is intended.
    fn should_checkpoint_segment(&self, segment_idx: usize) -> bool {
        if !self.config.enabled {
            return false;
        }
        let checkpoint_interval = self.segments.len().max(1) / self.config.num_segments.max(1);
        // `is_multiple_of(0)` is only true for 0, so the zero interval needs its own case.
        checkpoint_interval == 0 || segment_idx.is_multiple_of(checkpoint_interval)
    }
    /// Number of segments marked as checkpointed since construction or the last `clear`.
    pub fn memory_saved_segments(&self) -> usize {
        *self.memory_saved.borrow()
    }
    /// Drops the cached outputs of all registered segments, forgets them, and resets
    /// the counters.
    ///
    /// Outputs are cleared explicitly because callers may still hold `Rc` handles
    /// returned by `register_segment`, which would otherwise keep the tensors alive.
    pub fn clear(&mut self) {
        for segment in &self.segments {
            segment.clear_output();
        }
        self.segments.clear();
        *self.current_segment.get_mut() = 0;
        // Keep the checkpointed count consistent with the now-empty segment list;
        // previously it survived `clear` and could exceed `total_segments()`.
        *self.memory_saved.get_mut() = 0;
    }
    /// Number of segments currently registered.
    pub fn total_segments(&self) -> usize {
        self.segments.len()
    }
}
/// Runs `f` on `input` and returns its output.
///
/// NOTE(review): despite the name, nothing is dropped or recomputed here — `f` is invoked
/// exactly once, eagerly, and its result is returned unchanged.
pub fn checkpoint<F: Fn(&Tensor) -> Tensor>(f: F, input: &Tensor) -> Tensor {
    f(input)
}
/// Runs `f` on `input`, nominally checkpointing only when `should_checkpoint` is set.
///
/// Both settings currently behave the same: `f` is invoked once, eagerly, just like
/// [`checkpoint`]. The flag is kept so that call sites can state their intent.
/// NOTE(review): wire the flag up once real recomputation is implemented.
pub fn checkpoint_if<F>(f: F, input: &Tensor, should_checkpoint: bool) -> Tensor
where
    F: Fn(&Tensor) -> Tensor,
{
    // The former `if`/`else` had identical branches (clippy::if_same_then_else);
    // the flag has no effect yet.
    let _ = should_checkpoint;
    f(input)
}
/// Estimates activation memory, in bytes, for `num_layers` layers without and with
/// checkpointing.
///
/// Each layer's activation is modelled as `batch_size * seq_len * hidden_size` elements of
/// 4 bytes each. Without checkpointing, all `num_layers` activations are kept. With
/// checkpointing, `max(ceil(sqrt(num_layers)), num_checkpoints)` activations are kept. That
/// count is capped at `num_layers`, because checkpointing never retains more activations
/// than exist.
///
/// Returns `(memory_without, memory_with)`. Products saturate at `usize::MAX` rather than
/// overflowing on extreme inputs.
pub fn estimate_memory_savings(
    num_layers: usize,
    hidden_size: usize,
    seq_len: usize,
    batch_size: usize,
    num_checkpoints: usize,
) -> (usize, usize) {
    // Bytes per activation element (presumably f32 activations).
    const BYTES_PER_ELEMENT: usize = 4;
    let activation_size = batch_size
        .saturating_mul(seq_len)
        .saturating_mul(hidden_size)
        .saturating_mul(BYTES_PER_ELEMENT);
    let memory_without = num_layers.saturating_mul(activation_size);
    let sqrt_layers = (num_layers as f64).sqrt().ceil() as usize;
    // Clamp so an oversized `num_checkpoints` can't make the estimate exceed `memory_without`.
    let kept_activations = sqrt_layers.max(num_checkpoints).min(num_layers);
    let memory_with = kept_activations.saturating_mul(activation_size);
    (memory_without, memory_with)
}
/// Square-root heuristic for the checkpoint count: `ceil(sqrt(num_layers))`, never below 1.
pub fn optimal_checkpoints(num_layers: usize) -> usize {
    let root = (num_layers as f64).sqrt().ceil() as usize;
    // Only `num_layers == 0` yields a zero root; still report one checkpoint.
    if root == 0 {
        1
    } else {
        root
    }
}
/// Metadata about one operation's output, inspected by [`CheckpointPolicy`] implementations.
#[derive(Debug, Clone)]
pub struct OperationInfo {
    /// Kind of operation that produced the output.
    pub op_type: OpType,
    /// Size of the output in bytes.
    pub output_bytes: usize,
    /// Whether the output has a batch dimension (read by `SaveUnbatchedMatmuls`).
    pub has_batch_dim: bool,
    /// Index of the layer the operation belongs to (read by `BinomialCheckpointing`).
    pub layer_index: usize,
}
impl OperationInfo {
    /// Creates an entry with `has_batch_dim = false` and `layer_index = 0`.
    pub fn new(op_type: OpType, output_bytes: usize) -> Self {
        Self { op_type, output_bytes, layer_index: 0, has_batch_dim: false }
    }
    /// Builder: sets `has_batch_dim`.
    pub fn with_batch_dim(self, has_batch: bool) -> Self {
        Self { has_batch_dim: has_batch, ..self }
    }
    /// Builder: sets `layer_index`.
    pub fn with_layer_index(self, index: usize) -> Self {
        Self { layer_index: index, ..self }
    }
}
/// Decides which operation outputs are stored and which are left to be recomputed.
///
/// `should_save` takes `&self`, but implementations may still be stateful through interior
/// mutability (as `MemoryBudget` is).
pub trait CheckpointPolicy {
    /// Returns `true` if the output described by `op` should be stored.
    fn should_save(&self, op: &OperationInfo) -> bool;
    /// Relative cost of recomputing an output; the default charges every operation `1.0`.
    fn recompute_cost(&self, _info: &OperationInfo) -> f64 {
        1.0
    }
}
/// Policy that stores every output, so nothing is recomputed.
pub struct SaveAll;
impl CheckpointPolicy for SaveAll {
    /// Always `true`, regardless of the operation.
    fn should_save(&self, _info: &OperationInfo) -> bool {
        true
    }
}
/// Policy that stores no outputs, so every one must be recomputed.
pub struct SaveNothing;
impl CheckpointPolicy for SaveNothing {
    /// Always `false`, regardless of the operation.
    fn should_save(&self, _info: &OperationInfo) -> bool {
        false
    }
}
/// Policy that stores only matmul and attention outputs, which its cost table ranks as
/// the expensive operations to recompute.
pub struct SaveMatmuls;
impl CheckpointPolicy for SaveMatmuls {
    /// `true` for `Attention` and `Matmul`, `false` for everything else.
    fn should_save(&self, op: &OperationInfo) -> bool {
        matches!(op.op_type, OpType::Attention | OpType::Matmul)
    }
    /// Relative costs: attention 150, matmul 100, every other operation 1.
    ///
    /// The match is exhaustive on purpose, so adding an `OpType` forces a cost decision here.
    fn recompute_cost(&self, op: &OperationInfo) -> f64 {
        match op.op_type {
            OpType::Constant
            | OpType::LayerNorm
            | OpType::Softmax
            | OpType::Gelu
            | OpType::Relu
            | OpType::Sum
            | OpType::Scale
            | OpType::Mul
            | OpType::Add => 1.0,
            OpType::Attention => 150.0,
            OpType::Matmul => 100.0,
        }
    }
}
/// Policy that stores matmul/attention outputs only when they carry no batch dimension.
/// It uses the trait's default (uniform) recompute cost.
pub struct SaveUnbatchedMatmuls;
impl CheckpointPolicy for SaveUnbatchedMatmuls {
    fn should_save(&self, op: &OperationInfo) -> bool {
        let is_matmul_like = matches!(op.op_type, OpType::Matmul | OpType::Attention);
        is_matmul_like && !op.has_batch_dim
    }
}
/// Policy that saves outputs only at evenly spaced layer indices. The spacing is derived
/// from [`optimal_checkpoints`].
///
/// NOTE(review): despite the name, the spacing is uniform rather than a binomial
/// (Revolve-style) schedule.
pub struct BinomialCheckpointing {
    /// Total number of layers; indices at or above this are never saved.
    pub num_layers: usize,
}
impl BinomialCheckpointing {
    /// Distance between saved layers: `num_layers / optimal_checkpoints(num_layers)`,
    /// clamped to at least 1.
    fn checkpoint_interval(&self) -> usize {
        let num_checkpoints = optimal_checkpoints(self.num_layers);
        (self.num_layers / num_checkpoints.max(1)).max(1)
    }

    /// Layer indices whose outputs are saved: `0, k, 2k, ...` below `num_layers`, where
    /// `k` is the checkpoint interval.
    pub fn checkpoint_indices(&self) -> Vec<usize> {
        (0..self.num_layers).step_by(self.checkpoint_interval()).collect()
    }
}
impl CheckpointPolicy for BinomialCheckpointing {
    /// `true` iff `op.layer_index` appears in [`BinomialCheckpointing::checkpoint_indices`].
    ///
    /// Evaluated arithmetically, in O(1). The previous version rebuilt the full index `Vec`
    /// on every call, which made scanning all layers O(n²) with an allocation per query.
    fn should_save(&self, op: &OperationInfo) -> bool {
        let idx = op.layer_index;
        idx < self.num_layers && idx.is_multiple_of(self.checkpoint_interval())
    }
}
/// Greedy byte budget: outputs are accepted (via its `CheckpointPolicy` impl) until
/// `max_bytes` would be exceeded.
pub struct MemoryBudget {
    /// Upper bound on the total bytes accepted.
    pub max_bytes: usize,
    /// Bytes accepted so far; a `RefCell` because the policy method only gets `&self`.
    used_bytes: RefCell<usize>,
}
impl MemoryBudget {
    /// Creates an empty budget allowing up to `max_bytes`.
    pub fn new(max_bytes: usize) -> Self {
        let used_bytes = RefCell::new(0);
        Self { max_bytes, used_bytes }
    }
    /// Bytes accepted since construction or the last `reset`.
    pub fn used_bytes(&self) -> usize {
        let used = self.used_bytes.borrow();
        *used
    }
    /// Forgets all accepted bytes, making the full budget available again.
    pub fn reset(&self) {
        self.used_bytes.replace(0);
    }
}
impl CheckpointPolicy for MemoryBudget {
    /// Accepts `op` if its output still fits within `max_bytes`, and charges
    /// `op.output_bytes` to the budget when it does.
    ///
    /// The sum uses checked addition. Plain `+` panics in debug builds, and in release it
    /// wraps around and can wrongly report that a huge output fits. An overflowing request
    /// is now simply rejected.
    fn should_save(&self, op: &OperationInfo) -> bool {
        let mut used = self.used_bytes.borrow_mut();
        match used.checked_add(op.output_bytes) {
            Some(total) if total <= self.max_bytes => {
                *used = total;
                true
            }
            _ => false,
        }
    }
}
/// Policy that delegates the save decision to a user-supplied predicate.
///
/// The `Fn(&OperationInfo) -> bool` bound lives on the impls, not on the struct, so the
/// bound is not forced on every mention of the type.
pub struct CustomPolicy<F> {
    /// Called once per `should_save` query.
    predicate: F,
}
impl<F: Fn(&OperationInfo) -> bool> CustomPolicy<F> {
    /// Wraps `predicate` as a [`CheckpointPolicy`].
    pub fn new(predicate: F) -> Self {
        Self { predicate }
    }
}
impl<F: Fn(&OperationInfo) -> bool> CheckpointPolicy for CustomPolicy<F> {
    /// Returns whatever the predicate returns for `op`.
    fn should_save(&self, op: &OperationInfo) -> bool {
        (self.predicate)(op)
    }
}
/// Holds one optional saved activation per layer, as selected by a [`CheckpointPolicy`].
pub struct PolicyCheckpointManager {
    /// One slot per layer: the saved activation and the byte size it was recorded with.
    saved: Vec<Option<(Tensor, usize)>>,
    /// Sum of the recorded byte sizes of the activations currently held in `saved`.
    total_bytes_saved: usize,
    /// Number of layer slots; always equal to `saved.len()`.
    num_layers: usize,
}
impl PolicyCheckpointManager {
    /// Creates a manager with `num_layers` empty slots.
    pub fn new(num_layers: usize) -> Self {
        Self { saved: vec![None; num_layers], total_bytes_saved: 0, num_layers }
    }
    /// Stores a clone of `activation` in slot `layer_index` if the index is in range and
    /// `policy` agrees to save it.
    ///
    /// The bounds check comes first. Stateful policies such as `MemoryBudget` charge their
    /// budget inside `should_save`, so they must not be consulted for an activation that
    /// cannot be stored. Re-recording a slot replaces its activation, and the old size is
    /// released from the running total instead of being counted twice.
    pub fn record<P: CheckpointPolicy>(
        &mut self,
        layer_index: usize,
        activation: &Tensor,
        op_info: &OperationInfo,
        policy: &P,
    ) {
        let Some(slot) = self.saved.get_mut(layer_index) else {
            return;
        };
        if !policy.should_save(op_info) {
            return;
        }
        let replaced = slot.replace((activation.clone(), op_info.output_bytes));
        if let Some((_, previous_bytes)) = replaced {
            self.total_bytes_saved -= previous_bytes;
        }
        self.total_bytes_saved += op_info.output_bytes;
    }
    /// Saved activation for `layer_index`, or `None` if unsaved or out of range.
    pub fn get(&self, layer_index: usize) -> Option<&Tensor> {
        self.saved
            .get(layer_index)
            .and_then(|slot| slot.as_ref())
            .map(|(tensor, _)| tensor)
    }
    /// Whether slot `layer_index` exists and holds an activation.
    pub fn is_saved(&self, layer_index: usize) -> bool {
        self.saved.get(layer_index).is_some_and(Option::is_some)
    }
    /// Total recorded bytes of the activations currently held.
    pub fn total_bytes(&self) -> usize {
        self.total_bytes_saved
    }
    /// Number of slots currently holding an activation.
    pub fn num_saved(&self) -> usize {
        self.saved.iter().filter(|s| s.is_some()).count()
    }
    /// Empties every slot and resets the byte total.
    pub fn clear(&mut self) {
        self.saved.iter_mut().for_each(|s| *s = None);
        self.total_bytes_saved = 0;
    }
    /// Number of layer slots.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }
}
/// Simulates `policy` over `layer_infos`, in order, and summarises the result.
///
/// Returns `(bytes_saved, bytes_used, recompute_overhead)`:
/// - `bytes_saved` is the bytes of outputs the policy declined to store.
/// - `bytes_used` is the bytes of outputs it stored.
/// - `recompute_overhead` sums `recompute_cost` over the declined outputs.
///
/// `should_save` is called exactly once per entry, so stateful policies advance as usual.
pub fn estimate_policy_tradeoff<P: CheckpointPolicy>(
    policy: &P,
    layer_infos: &[OperationInfo],
) -> (usize, usize, f64) {
    layer_infos.iter().fold((0usize, 0usize, 0.0f64), |(dropped, kept, overhead), info| {
        if policy.should_save(info) {
            (dropped, kept + info.output_bytes, overhead)
        } else {
            (dropped + info.output_bytes, kept, overhead + policy.recompute_cost(info))
        }
    })
}
// Unit tests for checkpoint configuration, segment management and checkpoint policies.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::scale;

    #[test]
    fn test_checkpoint_config_enabled() {
        let config = CheckpointConfig::enabled(4);
        assert!(config.enabled);
        assert_eq!(config.num_segments, 4);
        assert!(!config.selective);
    }

    #[test]
    fn test_checkpoint_config_disabled() {
        let config = CheckpointConfig::disabled();
        assert!(!config.enabled);
    }

    #[test]
    fn test_checkpoint_config_default() {
        let config = CheckpointConfig::default();
        assert!(!config.enabled);
    }

    #[test]
    fn test_checkpoint_config_selective() {
        let config = CheckpointConfig::enabled(4).with_selective();
        assert!(config.selective);
    }

    #[test]
    fn test_checkpointed_segment_new() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let segment = CheckpointedSegment::new(input, true);
        assert!(segment.is_checkpointed());
        assert!(segment.output().is_none());
    }

    #[test]
    fn test_checkpointed_segment_output() {
        let input = Tensor::from_vec(vec![1.0, 2.0], true);
        let segment = CheckpointedSegment::new(input, true);
        let output = Tensor::from_vec(vec![2.0, 4.0], true);
        segment.set_output(output.clone());
        assert!(segment.output().is_some());
        assert_eq!(segment.output().expect("operation should succeed").len(), 2);
    }

    #[test]
    fn test_checkpointed_segment_clear() {
        let input = Tensor::from_vec(vec![1.0], true);
        let segment = CheckpointedSegment::new(input, true);
        segment.set_output(Tensor::from_vec(vec![2.0], true));
        segment.clear_output();
        assert!(segment.output().is_none());
    }

    #[test]
    fn test_checkpoint_manager_new() {
        let config = CheckpointConfig::enabled(4);
        let manager = CheckpointManager::new(config);
        assert!(manager.is_enabled());
        assert_eq!(manager.num_segments(), 4);
    }

    #[test]
    fn test_checkpoint_manager_disabled() {
        let config = CheckpointConfig::disabled();
        let manager = CheckpointManager::new(config);
        assert!(!manager.is_enabled());
    }

    #[test]
    fn test_checkpoint_manager_register() {
        let config = CheckpointConfig::enabled(2);
        let mut manager = CheckpointManager::new(config);
        let input1 = Tensor::from_vec(vec![1.0], true);
        let input2 = Tensor::from_vec(vec![2.0], true);
        let seg1 = manager.register_segment(input1);
        let seg2 = manager.register_segment(input2);
        assert_eq!(manager.total_segments(), 2);
        assert_eq!(seg1.input().len(), 1);
        assert_eq!(seg2.input().len(), 1);
    }

    #[test]
    fn test_checkpoint_manager_clear() {
        let config = CheckpointConfig::enabled(2);
        let mut manager = CheckpointManager::new(config);
        manager.register_segment(Tensor::from_vec(vec![1.0], true));
        manager.register_segment(Tensor::from_vec(vec![2.0], true));
        manager.clear();
        assert_eq!(manager.total_segments(), 0);
    }

    #[test]
    fn test_checkpoint_function() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let output = checkpoint(|x| scale(x, 2.0), &input);
        assert_eq!(output.len(), 3);
        assert_eq!(output.data()[0], 2.0);
    }

    #[test]
    fn test_checkpoint_if_enabled() {
        let input = Tensor::from_vec(vec![1.0, 2.0], true);
        let output = checkpoint_if(|x| scale(x, 3.0), &input, true);
        assert_eq!(output.data()[0], 3.0);
    }

    #[test]
    fn test_checkpoint_if_disabled() {
        let input = Tensor::from_vec(vec![1.0, 2.0], true);
        let output = checkpoint_if(|x| scale(x, 3.0), &input, false);
        assert_eq!(output.data()[0], 3.0);
    }

    #[test]
    fn test_estimate_memory_savings() {
        let (without, with) = estimate_memory_savings(32, 4096, 512, 1, 6);
        assert!(with < without);
        assert_eq!(without, 32 * 512 * 4096 * 4);
    }

    #[test]
    fn test_optimal_checkpoints() {
        assert_eq!(optimal_checkpoints(1), 1);
        assert_eq!(optimal_checkpoints(4), 2);
        assert_eq!(optimal_checkpoints(16), 4);
        assert_eq!(optimal_checkpoints(32), 6);
        assert_eq!(optimal_checkpoints(64), 8);
    }

    #[test]
    fn test_memory_savings_formula() {
        let num_layers = 32;
        let checkpoints = optimal_checkpoints(num_layers);
        let (without, with) = estimate_memory_savings(num_layers, 1024, 128, 1, checkpoints);
        let ratio = without as f64 / with as f64;
        assert!(ratio > 4.0);
    }

    #[test]
    fn test_checkpoint_preserves_computation() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], true);
        let direct = scale(&input, 2.5);
        let checkpointed = checkpoint(|x| scale(x, 2.5), &input);
        for i in 0..4 {
            assert_eq!(direct.data()[i], checkpointed.data()[i]);
        }
    }

    #[test]
    fn test_nested_checkpoints() {
        let input = Tensor::from_vec(vec![1.0, 2.0], true);
        let output = checkpoint(
            |x| {
                let h1 = scale(x, 2.0);
                checkpoint(|y| scale(y, 3.0), &h1)
            },
            &input,
        );
        assert_eq!(output.data()[0], 6.0);
    }

    #[test]
    fn test_checkpoint_manager_memory_tracking() {
        let config = CheckpointConfig::enabled(2);
        let mut manager = CheckpointManager::new(config);
        for i in 0..4 {
            manager.register_segment(Tensor::from_vec(vec![i as f32], true));
        }
        assert!(manager.memory_saved_segments() > 0);
    }

    fn make_op(op_type: OpType, bytes: usize) -> OperationInfo {
        OperationInfo::new(op_type, bytes)
    }

    #[test]
    fn test_operation_info_builder() {
        let info =
            OperationInfo::new(OpType::Matmul, 1024).with_batch_dim(true).with_layer_index(5);
        assert_eq!(info.op_type, OpType::Matmul);
        assert_eq!(info.output_bytes, 1024);
        assert!(info.has_batch_dim);
        assert_eq!(info.layer_index, 5);
    }

    #[test]
    fn test_save_all_policy() {
        let policy = SaveAll;
        assert!(policy.should_save(&make_op(OpType::Add, 100)));
        assert!(policy.should_save(&make_op(OpType::Matmul, 10000)));
        assert!(policy.should_save(&make_op(OpType::Relu, 50)));
    }

    #[test]
    fn test_save_nothing_policy() {
        let policy = SaveNothing;
        assert!(!policy.should_save(&make_op(OpType::Add, 100)));
        assert!(!policy.should_save(&make_op(OpType::Matmul, 10000)));
        assert!(!policy.should_save(&make_op(OpType::Relu, 50)));
    }

    #[test]
    fn test_save_matmuls_policy() {
        let policy = SaveMatmuls;
        assert!(policy.should_save(&make_op(OpType::Matmul, 1000)));
        assert!(policy.should_save(&make_op(OpType::Attention, 2000)));
        assert!(!policy.should_save(&make_op(OpType::Add, 100)));
        assert!(!policy.should_save(&make_op(OpType::Relu, 50)));
        assert!(!policy.should_save(&make_op(OpType::Softmax, 100)));
    }

    #[test]
    fn test_save_matmuls_recompute_cost() {
        let policy = SaveMatmuls;
        assert!((policy.recompute_cost(&make_op(OpType::Matmul, 0)) - 100.0).abs() < f64::EPSILON);
        assert!(
            (policy.recompute_cost(&make_op(OpType::Attention, 0)) - 150.0).abs() < f64::EPSILON
        );
        assert!((policy.recompute_cost(&make_op(OpType::Add, 0)) - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_save_unbatched_matmuls_policy() {
        let policy = SaveUnbatchedMatmuls;
        let unbatched = OperationInfo::new(OpType::Matmul, 1000).with_batch_dim(false);
        assert!(policy.should_save(&unbatched));
        let batched = OperationInfo::new(OpType::Matmul, 1000).with_batch_dim(true);
        assert!(!policy.should_save(&batched));
        let add = OperationInfo::new(OpType::Add, 100).with_batch_dim(false);
        assert!(!policy.should_save(&add));
    }

    #[test]
    fn test_binomial_checkpointing_indices() {
        let policy = BinomialCheckpointing { num_layers: 16 };
        let indices = policy.checkpoint_indices();
        assert_eq!(indices, vec![0, 4, 8, 12]);
    }

    #[test]
    fn test_binomial_checkpointing_policy() {
        let policy = BinomialCheckpointing { num_layers: 16 };
        let at_checkpoint = OperationInfo::new(OpType::Add, 100).with_layer_index(0);
        assert!(policy.should_save(&at_checkpoint));
        let not_at_checkpoint = OperationInfo::new(OpType::Add, 100).with_layer_index(1);
        // Previously mis-encoded as `¬_at_checkpoint`, which does not compile.
        assert!(!policy.should_save(&not_at_checkpoint));
        let at_checkpoint_4 = OperationInfo::new(OpType::Add, 100).with_layer_index(4);
        assert!(policy.should_save(&at_checkpoint_4));
    }

    #[test]
    fn test_memory_budget_policy() {
        let policy = MemoryBudget::new(500);
        let op1 = make_op(OpType::Matmul, 200);
        assert!(policy.should_save(&op1));
        assert_eq!(policy.used_bytes(), 200);
        let op2 = make_op(OpType::Add, 200);
        assert!(policy.should_save(&op2));
        assert_eq!(policy.used_bytes(), 400);
        let op3 = make_op(OpType::Relu, 200);
        assert!(!policy.should_save(&op3));
        assert_eq!(policy.used_bytes(), 400);
        policy.reset();
        assert_eq!(policy.used_bytes(), 0);
        assert!(policy.should_save(&op3));
    }

    #[test]
    fn test_custom_policy() {
        let policy = CustomPolicy::new(|op: &OperationInfo| op.output_bytes > 500);
        assert!(!policy.should_save(&make_op(OpType::Add, 100)));
        assert!(policy.should_save(&make_op(OpType::Matmul, 1000)));
        assert!(!policy.should_save(&make_op(OpType::Relu, 500)));
        assert!(policy.should_save(&make_op(OpType::Softmax, 501)));
    }

    #[test]
    fn test_policy_checkpoint_manager_basic() {
        let mut manager = PolicyCheckpointManager::new(4);
        let policy = SaveAll;
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let info = make_op(OpType::Matmul, 12);
        manager.record(0, &tensor, &info, &policy);
        assert!(manager.is_saved(0));
        assert!(!manager.is_saved(1));
        assert_eq!(manager.num_saved(), 1);
        assert_eq!(manager.total_bytes(), 12);
        let saved = manager.get(0).expect("key should exist");
        assert_eq!(saved.len(), 3);
    }

    #[test]
    fn test_policy_checkpoint_manager_selective() {
        let mut manager = PolicyCheckpointManager::new(4);
        let policy = SaveMatmuls;
        let t1 = Tensor::from_vec(vec![1.0], true);
        let t2 = Tensor::from_vec(vec![2.0], true);
        manager.record(0, &t1, &make_op(OpType::Matmul, 4), &policy);
        manager.record(1, &t2, &make_op(OpType::Add, 4), &policy);
        assert!(manager.is_saved(0));
        assert!(!manager.is_saved(1));
        assert_eq!(manager.num_saved(), 1);
    }

    #[test]
    fn test_policy_checkpoint_manager_clear() {
        let mut manager = PolicyCheckpointManager::new(2);
        let policy = SaveAll;
        let t = Tensor::from_vec(vec![1.0], true);
        manager.record(0, &t, &make_op(OpType::Add, 4), &policy);
        manager.clear();
        assert_eq!(manager.num_saved(), 0);
        assert_eq!(manager.total_bytes(), 0);
        assert!(!manager.is_saved(0));
    }

    #[test]
    fn test_policy_checkpoint_manager_out_of_bounds() {
        let mut manager = PolicyCheckpointManager::new(2);
        let policy = SaveAll;
        let t = Tensor::from_vec(vec![1.0], true);
        manager.record(5, &t, &make_op(OpType::Add, 4), &policy);
        assert_eq!(manager.num_saved(), 0);
    }

    #[test]
    fn test_estimate_policy_tradeoff_save_all() {
        let policy = SaveAll;
        let infos = vec![
            make_op(OpType::Matmul, 1000),
            make_op(OpType::Add, 200),
            make_op(OpType::Relu, 200),
        ];
        let (saved, used, overhead) = estimate_policy_tradeoff(&policy, &infos);
        assert_eq!(saved, 0);
        assert_eq!(used, 1400);
        assert!((overhead - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_estimate_policy_tradeoff_save_nothing() {
        let policy = SaveNothing;
        let infos = vec![make_op(OpType::Matmul, 1000), make_op(OpType::Add, 200)];
        let (saved, used, overhead) = estimate_policy_tradeoff(&policy, &infos);
        assert_eq!(saved, 1200);
        assert_eq!(used, 0);
        assert!(overhead > 0.0);
    }

    #[test]
    fn test_estimate_policy_tradeoff_save_matmuls() {
        let policy = SaveMatmuls;
        let infos = vec![
            make_op(OpType::Matmul, 1000),
            make_op(OpType::Add, 200),
            make_op(OpType::Relu, 200),
        ];
        let (saved, used, overhead) = estimate_policy_tradeoff(&policy, &infos);
        assert_eq!(used, 1000);
        assert_eq!(saved, 400);
        assert!(overhead > 0.0);
    }

    #[test]
    fn test_policy_checkpoint_manager_num_layers() {
        let manager = PolicyCheckpointManager::new(8);
        assert_eq!(manager.num_layers(), 8);
    }

    #[test]
    fn test_binomial_single_layer() {
        let policy = BinomialCheckpointing { num_layers: 1 };
        let indices = policy.checkpoint_indices();
        assert_eq!(indices, vec![0]);
    }

    #[test]
    fn test_default_recompute_cost() {
        let policy = SaveAll;
        let info = make_op(OpType::Add, 100);
        assert!((policy.recompute_cost(&info) - 1.0).abs() < f64::EPSILON);
    }
}