Skip to main content

entrenar/autograd/
checkpoint.rs

//! Gradient checkpointing for memory-efficient training
//!
//! Gradient checkpointing trades compute for memory by recomputing intermediate
//! activations during the backward pass instead of storing them.
//!
//! ## How It Works
//!
//! 1. During the forward pass, only the inputs to checkpointed segments are saved.
//! 2. During the backward pass, the forward pass of each segment is recomputed to
//!    recover its activations.
//! 3. With about √N evenly spaced checkpoints, activation memory scales as O(√N)
//!    instead of O(N) for N layers, at the cost of O(1) extra forward passes.
//!
//! Note: the [`checkpoint`] and [`checkpoint_if`] helpers in this module currently
//! evaluate the wrapped closure once and do not record anything for recomputation;
//! actual recomputation is expected to be driven by the training loop.
//!
//! ## Example
//!
//! ```ignore
//! use entrenar::autograd::checkpoint::checkpoint;
//!
//! // Wrap a computation in a checkpoint; the closure receives the input tensor.
//! let output = checkpoint(|x| {
//!     let h1 = layer1.forward(x);
//!     let h2 = layer2.forward(&h1);
//!     layer3.forward(&h2)
//! }, &input);
//! ```
24
25use crate::autograd::graph_opt::OpType;
26use crate::Tensor;
27use std::cell::RefCell;
28use std::rc::Rc;
29
/// Settings that control whether and how gradient checkpointing is applied.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// `true` when checkpointing should be applied
    pub enabled: bool,
    /// How many segments the model is split into for checkpointing
    pub num_segments: usize,
    /// Intended to restrict checkpointing to attention layers only
    /// (nothing in this module reads it yet)
    pub selective: bool,
}

impl CheckpointConfig {
    /// Builds a config with checkpointing switched on and `num_segments` segments.
    pub fn enabled(num_segments: usize) -> Self {
        let base = Self::disabled();
        Self { enabled: true, num_segments, ..base }
    }

    /// Builds a config with checkpointing switched off and a single segment.
    pub fn disabled() -> Self {
        Self { selective: false, num_segments: 1, enabled: false }
    }

    /// Returns this config with selective (attention-only) checkpointing turned on.
    pub fn with_selective(self) -> Self {
        Self { selective: true, ..self }
    }
}

impl Default for CheckpointConfig {
    /// The default configuration has checkpointing disabled.
    fn default() -> Self {
        Self::disabled()
    }
}
64
65/// A checkpointed computation segment
66///
67/// Stores the input tensor and a function to recompute the forward pass.
68/// During backward, the forward pass is recomputed to recover activations.
69pub struct CheckpointedSegment {
70    /// Input tensor (saved for recomputation)
71    input: Tensor,
72    /// Output tensor (computed lazily or cached)
73    output: RefCell<Option<Tensor>>,
74    /// Whether this segment has been checkpointed
75    is_checkpointed: bool,
76}
77
78impl CheckpointedSegment {
79    /// Create a new checkpointed segment
80    pub fn new(input: Tensor, is_checkpointed: bool) -> Self {
81        Self { input, output: RefCell::new(None), is_checkpointed }
82    }
83
84    /// Get the input tensor
85    pub fn input(&self) -> &Tensor {
86        &self.input
87    }
88
89    /// Check if this segment is checkpointed
90    pub fn is_checkpointed(&self) -> bool {
91        self.is_checkpointed
92    }
93
94    /// Set the output (used during forward pass)
95    pub fn set_output(&self, output: Tensor) {
96        *self.output.borrow_mut() = Some(output);
97    }
98
99    /// Get the output (returns None if not computed yet)
100    pub fn output(&self) -> Option<Tensor> {
101        contract_pre_output!();
102        self.output.borrow().clone()
103    }
104
105    /// Clear the output to free memory
106    pub fn clear_output(&self) {
107        *self.output.borrow_mut() = None;
108    }
109}
110
111/// Checkpoint manager for coordinating checkpointed segments
112pub struct CheckpointManager {
113    /// Configuration
114    config: CheckpointConfig,
115    /// Segments in order
116    segments: Vec<Rc<CheckpointedSegment>>,
117    /// Current segment index during forward pass
118    current_segment: RefCell<usize>,
119    /// Memory saved (estimated bytes)
120    memory_saved: RefCell<usize>,
121}
122
123impl CheckpointManager {
124    /// Create a new checkpoint manager
125    pub fn new(config: CheckpointConfig) -> Self {
126        Self {
127            config,
128            segments: Vec::new(),
129            current_segment: RefCell::new(0),
130            memory_saved: RefCell::new(0),
131        }
132    }
133
134    /// Check if checkpointing is enabled
135    pub fn is_enabled(&self) -> bool {
136        self.config.enabled
137    }
138
139    /// Get the number of segments
140    pub fn num_segments(&self) -> usize {
141        self.config.num_segments
142    }
143
144    /// Register a new segment
145    pub fn register_segment(&mut self, input: Tensor) -> Rc<CheckpointedSegment> {
146        let idx = self.segments.len();
147        let should_checkpoint = self.config.enabled && self.should_checkpoint_segment(idx);
148
149        let segment = Rc::new(CheckpointedSegment::new(input, should_checkpoint));
150        self.segments.push(segment.clone());
151
152        // Track memory savings
153        if should_checkpoint {
154            // Estimate: we save the intermediate activations
155            // For now, just track a placeholder value
156            *self.memory_saved.borrow_mut() += 1;
157        }
158
159        segment
160    }
161
162    /// Determine if a segment should be checkpointed
163    fn should_checkpoint_segment(&self, segment_idx: usize) -> bool {
164        if !self.config.enabled {
165            return false;
166        }
167
168        // Checkpoint every N segments based on config
169        let checkpoint_interval = self.segments.len().max(1) / self.config.num_segments.max(1);
170        if checkpoint_interval == 0 {
171            return true; // Checkpoint all if interval is 0
172        }
173
174        segment_idx.is_multiple_of(checkpoint_interval)
175    }
176
177    /// Get estimated memory saved (number of checkpointed segments)
178    pub fn memory_saved_segments(&self) -> usize {
179        *self.memory_saved.borrow()
180    }
181
182    /// Clear all segments (call after backward pass)
183    pub fn clear(&mut self) {
184        for segment in &self.segments {
185            segment.clear_output();
186        }
187        self.segments.clear();
188        *self.current_segment.borrow_mut() = 0;
189    }
190
191    /// Get total number of registered segments
192    pub fn total_segments(&self) -> usize {
193        self.segments.len()
194    }
195}
196
197/// Run a computation with gradient checkpointing
198///
199/// The function `f` is executed during forward pass. During backward pass,
200/// if checkpointing is enabled, `f` will be re-executed to recompute activations.
201///
202/// # Arguments
203///
204/// * `f` - Function that computes the forward pass
205/// * `input` - Input tensor (saved for recomputation)
206///
207/// # Returns
208///
209/// The output tensor from the computation
210pub fn checkpoint<F>(f: F, input: &Tensor) -> Tensor
211where
212    F: Fn(&Tensor) -> Tensor,
213{
214    // Simply run the function - actual checkpointing happens in training loop
215    f(input)
216}
217
218/// Run a computation with explicit checkpointing control
219///
220/// # Arguments
221///
222/// * `f` - Function that computes the forward pass
223/// * `input` - Input tensor
224/// * `should_checkpoint` - Whether to enable checkpointing for this segment
225pub fn checkpoint_if<F>(f: F, input: &Tensor, should_checkpoint: bool) -> Tensor
226where
227    F: Fn(&Tensor) -> Tensor,
228{
229    if should_checkpoint {
230        // In a full implementation, we would save `input` and `f` for recomputation
231        // For now, just run the function
232        f(input)
233    } else {
234        f(input)
235    }
236}
237
/// Estimate memory savings from checkpointing
///
/// Models every layer activation as `batch_size * seq_len * hidden_size` `f32`
/// values. Without checkpointing all `num_layers` activations are stored. With
/// checkpointing, `max(ceil(sqrt(num_layers)), num_checkpoints)` boundaries are
/// stored, clamped to `num_layers` (there cannot be more boundaries than layers).
/// Arithmetic saturates at `usize::MAX` instead of overflowing.
///
/// # Arguments
///
/// * `num_layers` - Number of transformer layers
/// * `hidden_size` - Hidden dimension
/// * `seq_len` - Sequence length
/// * `batch_size` - Batch size
/// * `num_checkpoints` - Number of checkpoint segments
///
/// # Returns
///
/// Tuple of (memory_without_checkpoint, memory_with_checkpoint) in bytes
pub fn estimate_memory_savings(
    num_layers: usize,
    hidden_size: usize,
    seq_len: usize,
    batch_size: usize,
    num_checkpoints: usize,
) -> (usize, usize) {
    // Each activation: batch_size * seq_len * hidden_size * sizeof(f32).
    // Saturating multiplication keeps huge configurations from panicking (debug)
    // or silently wrapping (release).
    let activation_size = batch_size
        .saturating_mul(seq_len)
        .saturating_mul(hidden_size)
        .saturating_mul(std::mem::size_of::<f32>());

    // Without checkpointing: store all layer activations
    let memory_without = num_layers.saturating_mul(activation_size);

    // With checkpointing: store only checkpoint boundaries.
    // Memory scales as O(sqrt(N)) with optimal checkpointing.
    let sqrt_layers = (num_layers as f64).sqrt().ceil() as usize;
    // Clamp: requesting more checkpoints than layers cannot store more than every layer.
    let stored_boundaries = sqrt_layers.max(num_checkpoints).min(num_layers);
    let memory_with = stored_boundaries.saturating_mul(activation_size);

    (memory_without, memory_with)
}
271
/// Recommended number of checkpoints for a model with `num_layers` layers.
///
/// Applies the square-root rule `ceil(sqrt(num_layers))` and never returns less
/// than one checkpoint, so zero layers still yields `1`. Only the layer count is
/// considered; no memory budget is involved.
pub fn optimal_checkpoints(num_layers: usize) -> usize {
    let root = (num_layers as f64).sqrt().ceil() as usize;
    if root == 0 {
        1
    } else {
        root
    }
}
278
279// ---------------------------------------------------------------------------
280// Policy-based selective gradient checkpointing (GH-83)
281// ---------------------------------------------------------------------------
282
283/// Metadata about an operation, used by checkpoint policies to decide
284/// whether to save or recompute its output activation.
285#[derive(Debug, Clone)]
286pub struct OperationInfo {
287    /// The type of operation
288    pub op_type: OpType,
289    /// Output size in bytes (batch_size * elements * sizeof(f32))
290    pub output_bytes: usize,
291    /// Whether any input has a batch dimension (ndim > 2)
292    pub has_batch_dim: bool,
293    /// Layer index in the sequential model
294    pub layer_index: usize,
295}
296
297impl OperationInfo {
298    /// Create operation info for a given op type and output size
299    pub fn new(op_type: OpType, output_bytes: usize) -> Self {
300        Self { op_type, output_bytes, has_batch_dim: false, layer_index: 0 }
301    }
302
303    /// Set whether this operation has batch dimensions
304    pub fn with_batch_dim(mut self, has_batch: bool) -> Self {
305        self.has_batch_dim = has_batch;
306        self
307    }
308
309    /// Set the layer index
310    pub fn with_layer_index(mut self, index: usize) -> Self {
311        self.layer_index = index;
312        self
313    }
314}
315
316/// Policy for deciding which activations to save vs recompute during
317/// gradient checkpointing.
318///
319/// Implementations control the memory/compute tradeoff by returning `true`
320/// from `should_save` for operations whose outputs should be cached.
321pub trait CheckpointPolicy {
322    /// Returns true if this operation's output should be saved (not recomputed)
323    fn should_save(&self, op: &OperationInfo) -> bool;
324
325    /// Estimated relative cost of recomputing this operation (default: 1.0)
326    fn recompute_cost(&self, _op: &OperationInfo) -> f64 {
327        1.0
328    }
329}
330
331/// Save everything — maximum memory usage, no recomputation overhead.
332pub struct SaveAll;
333
334impl CheckpointPolicy for SaveAll {
335    fn should_save(&self, _op: &OperationInfo) -> bool {
336        true
337    }
338}
339
340/// Save nothing — minimum memory usage, full recomputation during backward.
341pub struct SaveNothing;
342
343impl CheckpointPolicy for SaveNothing {
344    fn should_save(&self, _op: &OperationInfo) -> bool {
345        false
346    }
347}
348
349/// Save only matrix multiplication results (most expensive to recompute).
350pub struct SaveMatmuls;
351
352impl CheckpointPolicy for SaveMatmuls {
353    fn should_save(&self, op: &OperationInfo) -> bool {
354        matches!(op.op_type, OpType::Matmul | OpType::Attention)
355    }
356
357    fn recompute_cost(&self, op: &OperationInfo) -> f64 {
358        match op.op_type {
359            OpType::Matmul => 100.0,
360            OpType::Attention => 150.0,
361            OpType::Add
362            | OpType::Mul
363            | OpType::Scale
364            | OpType::Sum
365            | OpType::Relu
366            | OpType::Gelu
367            | OpType::Softmax
368            | OpType::LayerNorm
369            | OpType::Constant => 1.0,
370        }
371    }
372}
373
374/// Save matmuls that do NOT have batch dimensions (common in transformers).
375/// These are typically the most expensive weight-projection operations.
376pub struct SaveUnbatchedMatmuls;
377
378impl CheckpointPolicy for SaveUnbatchedMatmuls {
379    fn should_save(&self, op: &OperationInfo) -> bool {
380        matches!(op.op_type, OpType::Matmul | OpType::Attention) && !op.has_batch_dim
381    }
382}
383
384/// Save activations at regular intervals (every N layers).
385/// Uses the binomial checkpointing strategy: checkpoint sqrt(N) layers
386/// for O(sqrt(N)) memory with O(1) extra forward passes.
387pub struct BinomialCheckpointing {
388    /// Total number of layers in the model
389    pub num_layers: usize,
390}
391
392impl BinomialCheckpointing {
393    /// Compute the indices that should be checkpointed
394    pub fn checkpoint_indices(&self) -> Vec<usize> {
395        let num_checkpoints = optimal_checkpoints(self.num_layers);
396        let interval = self.num_layers / num_checkpoints.max(1);
397        (0..self.num_layers).step_by(interval.max(1)).collect()
398    }
399}
400
401impl CheckpointPolicy for BinomialCheckpointing {
402    fn should_save(&self, op: &OperationInfo) -> bool {
403        let indices = self.checkpoint_indices();
404        indices.contains(&op.layer_index)
405    }
406}
407
408/// Save activations up to a memory budget (in bytes).
409pub struct MemoryBudget {
410    /// Maximum total bytes for saved activations
411    pub max_bytes: usize,
412    /// Current bytes used (interior mutability for stateful tracking)
413    used_bytes: RefCell<usize>,
414}
415
416impl MemoryBudget {
417    /// Create a new memory budget policy
418    pub fn new(max_bytes: usize) -> Self {
419        Self { max_bytes, used_bytes: RefCell::new(0) }
420    }
421
422    /// Get the current bytes used
423    pub fn used_bytes(&self) -> usize {
424        *self.used_bytes.borrow()
425    }
426
427    /// Reset the used bytes counter
428    pub fn reset(&self) {
429        *self.used_bytes.borrow_mut() = 0;
430    }
431}
432
433impl CheckpointPolicy for MemoryBudget {
434    fn should_save(&self, op: &OperationInfo) -> bool {
435        let current = *self.used_bytes.borrow();
436        if current + op.output_bytes <= self.max_bytes {
437            *self.used_bytes.borrow_mut() += op.output_bytes;
438            true
439        } else {
440            false
441        }
442    }
443}
444
445/// Custom policy using a predicate function.
446pub struct CustomPolicy<F: Fn(&OperationInfo) -> bool> {
447    predicate: F,
448}
449
450impl<F: Fn(&OperationInfo) -> bool> CustomPolicy<F> {
451    /// Create a custom policy from a predicate
452    pub fn new(predicate: F) -> Self {
453        Self { predicate }
454    }
455}
456
457impl<F: Fn(&OperationInfo) -> bool> CheckpointPolicy for CustomPolicy<F> {
458    fn should_save(&self, op: &OperationInfo) -> bool {
459        (self.predicate)(op)
460    }
461}
462
463/// Policy-based checkpoint manager that uses a `CheckpointPolicy` to
464/// decide which activations to save vs recompute.
465pub struct PolicyCheckpointManager {
466    /// Activation storage (layer_index -> saved tensor)
467    saved: Vec<Option<Tensor>>,
468    /// Total bytes saved
469    total_bytes_saved: usize,
470    /// Number of layers
471    num_layers: usize,
472}
473
474impl PolicyCheckpointManager {
475    /// Create a new policy checkpoint manager
476    pub fn new(num_layers: usize) -> Self {
477        Self { saved: vec![None; num_layers], total_bytes_saved: 0, num_layers }
478    }
479
480    /// Record a forward activation, saving it if the policy says so
481    pub fn record<P: CheckpointPolicy>(
482        &mut self,
483        layer_index: usize,
484        activation: &Tensor,
485        op_info: &OperationInfo,
486        policy: &P,
487    ) {
488        if policy.should_save(op_info) && layer_index < self.num_layers {
489            self.saved[layer_index] = Some(activation.clone());
490            self.total_bytes_saved += op_info.output_bytes;
491        }
492    }
493
494    /// Get a saved activation (returns None if it was not saved / needs recompute)
495    pub fn get(&self, layer_index: usize) -> Option<&Tensor> {
496        self.saved.get(layer_index).and_then(|s| s.as_ref())
497    }
498
499    /// Check if an activation is saved for a given layer
500    pub fn is_saved(&self, layer_index: usize) -> bool {
501        self.saved.get(layer_index).is_some_and(Option::is_some)
502    }
503
504    /// Get total bytes used by saved activations
505    pub fn total_bytes(&self) -> usize {
506        contract_pre_total_bytes!();
507        self.total_bytes_saved
508    }
509
510    /// Get the number of saved activations
511    pub fn num_saved(&self) -> usize {
512        self.saved.iter().filter(|s| s.is_some()).count()
513    }
514
515    /// Clear all saved activations
516    pub fn clear(&mut self) {
517        self.saved.iter_mut().for_each(|s| *s = None);
518        self.total_bytes_saved = 0;
519    }
520
521    /// Get the total number of layers
522    pub fn num_layers(&self) -> usize {
523        self.num_layers
524    }
525}
526
527/// Estimate the memory/compute tradeoff for a given policy on a model.
528///
529/// Returns `(bytes_saved, bytes_used, recompute_overhead)` where:
530/// - `bytes_saved` is the memory freed by not saving some activations
531/// - `bytes_used` is the memory used by saved activations
532/// - `recompute_overhead` is the estimated relative compute cost of recomputation
533pub fn estimate_policy_tradeoff<P: CheckpointPolicy>(
534    policy: &P,
535    layer_infos: &[OperationInfo],
536) -> (usize, usize, f64) {
537    let mut bytes_saved = 0usize;
538    let mut bytes_used = 0usize;
539    let mut recompute_overhead = 0.0f64;
540
541    for info in layer_infos {
542        if policy.should_save(info) {
543            bytes_used += info.output_bytes;
544        } else {
545            bytes_saved += info.output_bytes;
546            recompute_overhead += policy.recompute_cost(info);
547        }
548    }
549
550    (bytes_saved, bytes_used, recompute_overhead)
551}
552
// Unit tests for both APIs in this module: the segment-based API
// (`CheckpointConfig`, `CheckpointedSegment`, `CheckpointManager`, `checkpoint`,
// `checkpoint_if`, memory estimators) and the policy-based API added for GH-83.
#[cfg(test)]
mod tests {
    use super::*;
    // Element-wise scaling op, used as a cheap forward computation in these tests.
    use crate::autograd::scale;

    #[test]
    fn test_checkpoint_config_enabled() {
        let config = CheckpointConfig::enabled(4);
        assert!(config.enabled);
        assert_eq!(config.num_segments, 4);
        assert!(!config.selective);
    }

    #[test]
    fn test_checkpoint_config_disabled() {
        let config = CheckpointConfig::disabled();
        assert!(!config.enabled);
    }

    #[test]
    fn test_checkpoint_config_default() {
        let config = CheckpointConfig::default();
        assert!(!config.enabled);
    }

    #[test]
    fn test_checkpoint_config_selective() {
        let config = CheckpointConfig::enabled(4).with_selective();
        assert!(config.selective);
    }

    #[test]
    fn test_checkpointed_segment_new() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let segment = CheckpointedSegment::new(input, true);
        assert!(segment.is_checkpointed());
        // A fresh segment has no cached output yet
        assert!(segment.output().is_none());
    }

    #[test]
    fn test_checkpointed_segment_output() {
        let input = Tensor::from_vec(vec![1.0, 2.0], true);
        let segment = CheckpointedSegment::new(input, true);

        let output = Tensor::from_vec(vec![2.0, 4.0], true);
        segment.set_output(output.clone());

        assert!(segment.output().is_some());
        assert_eq!(segment.output().expect("operation should succeed").len(), 2);
    }

    #[test]
    fn test_checkpointed_segment_clear() {
        let input = Tensor::from_vec(vec![1.0], true);
        let segment = CheckpointedSegment::new(input, true);
        segment.set_output(Tensor::from_vec(vec![2.0], true));

        segment.clear_output();
        assert!(segment.output().is_none());
    }

    #[test]
    fn test_checkpoint_manager_new() {
        let config = CheckpointConfig::enabled(4);
        let manager = CheckpointManager::new(config);
        assert!(manager.is_enabled());
        assert_eq!(manager.num_segments(), 4);
    }

    #[test]
    fn test_checkpoint_manager_disabled() {
        let config = CheckpointConfig::disabled();
        let manager = CheckpointManager::new(config);
        assert!(!manager.is_enabled());
    }

    #[test]
    fn test_checkpoint_manager_register() {
        let config = CheckpointConfig::enabled(2);
        let mut manager = CheckpointManager::new(config);

        let input1 = Tensor::from_vec(vec![1.0], true);
        let input2 = Tensor::from_vec(vec![2.0], true);

        let seg1 = manager.register_segment(input1);
        let seg2 = manager.register_segment(input2);

        assert_eq!(manager.total_segments(), 2);
        assert_eq!(seg1.input().len(), 1);
        assert_eq!(seg2.input().len(), 1);
    }

    #[test]
    fn test_checkpoint_manager_clear() {
        let config = CheckpointConfig::enabled(2);
        let mut manager = CheckpointManager::new(config);

        manager.register_segment(Tensor::from_vec(vec![1.0], true));
        manager.register_segment(Tensor::from_vec(vec![2.0], true));

        manager.clear();
        assert_eq!(manager.total_segments(), 0);
    }

    #[test]
    fn test_checkpoint_function() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let output = checkpoint(|x| scale(x, 2.0), &input);
        assert_eq!(output.len(), 3);
        assert_eq!(output.data()[0], 2.0);
    }

    #[test]
    fn test_checkpoint_if_enabled() {
        let input = Tensor::from_vec(vec![1.0, 2.0], true);
        let output = checkpoint_if(|x| scale(x, 3.0), &input, true);
        assert_eq!(output.data()[0], 3.0);
    }

    #[test]
    fn test_checkpoint_if_disabled() {
        let input = Tensor::from_vec(vec![1.0, 2.0], true);
        let output = checkpoint_if(|x| scale(x, 3.0), &input, false);
        assert_eq!(output.data()[0], 3.0);
    }

    #[test]
    fn test_estimate_memory_savings() {
        let (without, with) = estimate_memory_savings(32, 4096, 512, 1, 6);

        // With checkpointing should use less memory
        assert!(with < without);

        // Sanity check: without checkpointing stores all layers
        // 32 layers * 512 seq * 4096 hidden * 4 bytes = 268,435,456 bytes
        assert_eq!(without, 32 * 512 * 4096 * 4);
    }

    #[test]
    fn test_optimal_checkpoints() {
        // ceil(sqrt(n)), with a floor of 1; e.g. sqrt(32) ≈ 5.66 rounds up to 6
        assert_eq!(optimal_checkpoints(1), 1);
        assert_eq!(optimal_checkpoints(4), 2);
        assert_eq!(optimal_checkpoints(16), 4);
        assert_eq!(optimal_checkpoints(32), 6);
        assert_eq!(optimal_checkpoints(64), 8);
    }

    #[test]
    fn test_memory_savings_formula() {
        // For 32 layers with optimal sqrt(32) ≈ 6 checkpoints
        let num_layers = 32;
        let checkpoints = optimal_checkpoints(num_layers);

        let (without, with) = estimate_memory_savings(num_layers, 1024, 128, 1, checkpoints);

        // Memory reduction factor should be approximately sqrt(N)
        let ratio = without as f64 / with as f64;
        assert!(ratio > 4.0); // Should be significant savings
    }

    #[test]
    fn test_checkpoint_preserves_computation() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], true);

        // Without checkpoint
        let direct = scale(&input, 2.5);

        // With checkpoint
        let checkpointed = checkpoint(|x| scale(x, 2.5), &input);

        // Results should be identical
        for i in 0..4 {
            assert_eq!(direct.data()[i], checkpointed.data()[i]);
        }
    }

    #[test]
    fn test_nested_checkpoints() {
        let input = Tensor::from_vec(vec![1.0, 2.0], true);

        let output = checkpoint(
            |x| {
                let h1 = scale(x, 2.0);
                checkpoint(|y| scale(y, 3.0), &h1)
            },
            &input,
        );

        // 1.0 * 2.0 * 3.0 = 6.0
        assert_eq!(output.data()[0], 6.0);
    }

    #[test]
    fn test_checkpoint_manager_memory_tracking() {
        let config = CheckpointConfig::enabled(2);
        let mut manager = CheckpointManager::new(config);

        for i in 0..4 {
            manager.register_segment(Tensor::from_vec(vec![i as f32], true));
        }

        // Should have tracked some memory savings
        // (the counter is a number of checkpointed segments, not bytes)
        assert!(manager.memory_saved_segments() > 0);
    }

    // --- Policy tests (GH-83) ---

    // Shorthand for an `OperationInfo` with default batch flag and layer index.
    fn make_op(op_type: OpType, bytes: usize) -> OperationInfo {
        OperationInfo::new(op_type, bytes)
    }

    #[test]
    fn test_operation_info_builder() {
        let info =
            OperationInfo::new(OpType::Matmul, 1024).with_batch_dim(true).with_layer_index(5);
        assert_eq!(info.op_type, OpType::Matmul);
        assert_eq!(info.output_bytes, 1024);
        assert!(info.has_batch_dim);
        assert_eq!(info.layer_index, 5);
    }

    #[test]
    fn test_save_all_policy() {
        let policy = SaveAll;
        assert!(policy.should_save(&make_op(OpType::Add, 100)));
        assert!(policy.should_save(&make_op(OpType::Matmul, 10000)));
        assert!(policy.should_save(&make_op(OpType::Relu, 50)));
    }

    #[test]
    fn test_save_nothing_policy() {
        let policy = SaveNothing;
        assert!(!policy.should_save(&make_op(OpType::Add, 100)));
        assert!(!policy.should_save(&make_op(OpType::Matmul, 10000)));
        assert!(!policy.should_save(&make_op(OpType::Relu, 50)));
    }

    #[test]
    fn test_save_matmuls_policy() {
        let policy = SaveMatmuls;
        assert!(policy.should_save(&make_op(OpType::Matmul, 1000)));
        assert!(policy.should_save(&make_op(OpType::Attention, 2000)));
        assert!(!policy.should_save(&make_op(OpType::Add, 100)));
        assert!(!policy.should_save(&make_op(OpType::Relu, 50)));
        assert!(!policy.should_save(&make_op(OpType::Softmax, 100)));
    }

    #[test]
    fn test_save_matmuls_recompute_cost() {
        let policy = SaveMatmuls;
        assert!((policy.recompute_cost(&make_op(OpType::Matmul, 0)) - 100.0).abs() < f64::EPSILON);
        assert!(
            (policy.recompute_cost(&make_op(OpType::Attention, 0)) - 150.0).abs() < f64::EPSILON
        );
        assert!((policy.recompute_cost(&make_op(OpType::Add, 0)) - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_save_unbatched_matmuls_policy() {
        let policy = SaveUnbatchedMatmuls;

        // No batch dim -> should save
        let unbatched = OperationInfo::new(OpType::Matmul, 1000).with_batch_dim(false);
        assert!(policy.should_save(&unbatched));

        // With batch dim -> should not save
        let batched = OperationInfo::new(OpType::Matmul, 1000).with_batch_dim(true);
        assert!(!policy.should_save(&batched));

        // Non-matmul -> should not save
        let add = OperationInfo::new(OpType::Add, 100).with_batch_dim(false);
        assert!(!policy.should_save(&add));
    }

    #[test]
    fn test_binomial_checkpointing_indices() {
        let policy = BinomialCheckpointing { num_layers: 16 };
        let indices = policy.checkpoint_indices();

        // sqrt(16) = 4 checkpoints, interval = 16/4 = 4
        assert_eq!(indices, vec![0, 4, 8, 12]);
    }

    #[test]
    fn test_binomial_checkpointing_policy() {
        let policy = BinomialCheckpointing { num_layers: 16 };

        let at_checkpoint = OperationInfo::new(OpType::Add, 100).with_layer_index(0);
        assert!(policy.should_save(&at_checkpoint));

        let not_at_checkpoint = OperationInfo::new(OpType::Add, 100).with_layer_index(1);
        assert!(!policy.should_save(&not_at_checkpoint));

        let at_checkpoint_4 = OperationInfo::new(OpType::Add, 100).with_layer_index(4);
        assert!(policy.should_save(&at_checkpoint_4));
    }

    #[test]
    fn test_memory_budget_policy() {
        let policy = MemoryBudget::new(500);

        // First op fits
        let op1 = make_op(OpType::Matmul, 200);
        assert!(policy.should_save(&op1));
        assert_eq!(policy.used_bytes(), 200);

        // Second op fits
        let op2 = make_op(OpType::Add, 200);
        assert!(policy.should_save(&op2));
        assert_eq!(policy.used_bytes(), 400);

        // Third op doesn't fit
        let op3 = make_op(OpType::Relu, 200);
        assert!(!policy.should_save(&op3));
        assert_eq!(policy.used_bytes(), 400);

        // Reset and try again
        policy.reset();
        assert_eq!(policy.used_bytes(), 0);
        assert!(policy.should_save(&op3));
    }

    #[test]
    fn test_custom_policy() {
        // Only save ops with output > 500 bytes
        let policy = CustomPolicy::new(|op: &OperationInfo| op.output_bytes > 500);

        assert!(!policy.should_save(&make_op(OpType::Add, 100)));
        assert!(policy.should_save(&make_op(OpType::Matmul, 1000)));
        assert!(!policy.should_save(&make_op(OpType::Relu, 500)));
        assert!(policy.should_save(&make_op(OpType::Softmax, 501)));
    }

    #[test]
    fn test_policy_checkpoint_manager_basic() {
        let mut manager = PolicyCheckpointManager::new(4);
        let policy = SaveAll;

        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let info = make_op(OpType::Matmul, 12);

        manager.record(0, &tensor, &info, &policy);
        assert!(manager.is_saved(0));
        assert!(!manager.is_saved(1));
        assert_eq!(manager.num_saved(), 1);
        assert_eq!(manager.total_bytes(), 12);

        // Retrieve saved activation
        let saved = manager.get(0).expect("key should exist");
        assert_eq!(saved.len(), 3);
    }

    #[test]
    fn test_policy_checkpoint_manager_selective() {
        let mut manager = PolicyCheckpointManager::new(4);
        let policy = SaveMatmuls;

        let t1 = Tensor::from_vec(vec![1.0], true);
        let t2 = Tensor::from_vec(vec![2.0], true);

        // Matmul -> saved
        manager.record(0, &t1, &make_op(OpType::Matmul, 4), &policy);
        // Add -> not saved
        manager.record(1, &t2, &make_op(OpType::Add, 4), &policy);

        assert!(manager.is_saved(0));
        assert!(!manager.is_saved(1));
        assert_eq!(manager.num_saved(), 1);
    }

    #[test]
    fn test_policy_checkpoint_manager_clear() {
        let mut manager = PolicyCheckpointManager::new(2);
        let policy = SaveAll;

        let t = Tensor::from_vec(vec![1.0], true);
        manager.record(0, &t, &make_op(OpType::Add, 4), &policy);

        manager.clear();
        assert_eq!(manager.num_saved(), 0);
        assert_eq!(manager.total_bytes(), 0);
        assert!(!manager.is_saved(0));
    }

    #[test]
    fn test_policy_checkpoint_manager_out_of_bounds() {
        let mut manager = PolicyCheckpointManager::new(2);
        let policy = SaveAll;

        let t = Tensor::from_vec(vec![1.0], true);
        // Layer index beyond capacity — should be a no-op
        manager.record(5, &t, &make_op(OpType::Add, 4), &policy);
        assert_eq!(manager.num_saved(), 0);
    }

    #[test]
    fn test_estimate_policy_tradeoff_save_all() {
        let policy = SaveAll;
        let infos = vec![
            make_op(OpType::Matmul, 1000),
            make_op(OpType::Add, 200),
            make_op(OpType::Relu, 200),
        ];

        let (saved, used, overhead) = estimate_policy_tradeoff(&policy, &infos);
        assert_eq!(saved, 0); // Everything saved
        assert_eq!(used, 1400);
        // Nothing is recomputed, so no recompute cost accumulates
        assert!((overhead - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_estimate_policy_tradeoff_save_nothing() {
        let policy = SaveNothing;
        let infos = vec![make_op(OpType::Matmul, 1000), make_op(OpType::Add, 200)];

        let (saved, used, overhead) = estimate_policy_tradeoff(&policy, &infos);
        assert_eq!(saved, 1200); // Nothing saved
        assert_eq!(used, 0);
        assert!(overhead > 0.0); // Must recompute everything
    }

    #[test]
    fn test_estimate_policy_tradeoff_save_matmuls() {
        let policy = SaveMatmuls;
        let infos = vec![
            make_op(OpType::Matmul, 1000),
            make_op(OpType::Add, 200),
            make_op(OpType::Relu, 200),
        ];

        let (saved, used, overhead) = estimate_policy_tradeoff(&policy, &infos);
        assert_eq!(used, 1000); // Only matmul saved
        assert_eq!(saved, 400); // Add + Relu not saved
        assert!(overhead > 0.0); // Recompute cost for add + relu
    }

    #[test]
    fn test_policy_checkpoint_manager_num_layers() {
        let manager = PolicyCheckpointManager::new(8);
        assert_eq!(manager.num_layers(), 8);
    }

    #[test]
    fn test_binomial_single_layer() {
        // optimal_checkpoints(1) == 1, so the interval is 1 and only layer 0 exists
        let policy = BinomialCheckpointing { num_layers: 1 };
        let indices = policy.checkpoint_indices();
        assert_eq!(indices, vec![0]);
    }

    #[test]
    fn test_default_recompute_cost() {
        // SaveAll does not override `recompute_cost`, so the trait default applies
        let policy = SaveAll;
        let info = make_op(OpType::Add, 100);
        assert!((policy.recompute_cost(&info) - 1.0).abs() < f64::EPSILON);
    }
}