// kizzasi_model/gradient_checkpoint.rs
//! Gradient Checkpointing for Memory-Efficient Training
//!
//! Implements activation checkpointing (also known as gradient checkpointing)
//! to trade compute for memory during training. Instead of storing all
//! intermediate activations for the backward pass, only activations at
//! designated checkpoint boundaries are retained. During backpropagation,
//! discarded activations are recomputed from the nearest checkpoint.
//!
//! This technique, introduced by Chen et al. (2016), can reduce memory usage
//! from O(N) to O(sqrt(N)) for N-layer networks at the cost of one
//! additional forward pass.
//!
//! # Example
//!
//! ```rust,ignore
//! use kizzasi_model::gradient_checkpoint::{ActivationCheckpointer, CheckpointConfig};
//! use scirs2_core::ndarray::Array1;
//!
//! let config = CheckpointConfig {
//!     checkpoint_every_n_layers: 3,
//!     max_checkpoints: 10,
//!     use_mixed_precision: false,
//! };
//!
//! let mut checkpointer = ActivationCheckpointer::new(config);
//!
//! let input = Array1::from_vec(vec![1.0, 2.0, 3.0]);
//! let layers = vec![0, 1, 2, 3, 4, 5];
//!
//! let output = checkpointer.checkpointed_forward(
//!     &input,
//!     &layers,
//!     |activation, layer_idx| {
//!         // Your layer forward function
//!         Ok(activation.mapv(|x| x * 1.1 + layer_idx as f32 * 0.01))
//!     },
//! )?;
//! ```

use crate::error::{ModelError, ModelResult};
use scirs2_core::ndarray::Array1;

// ---------------------------------------------------------------------------
// CheckpointConfig
// ---------------------------------------------------------------------------

/// Configuration for activation checkpointing.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Interval between checkpointed layers: every N-th layer (0, N, 2N, ...)
    /// has its activation retained.
    pub checkpoint_every_n_layers: usize,

    /// Upper bound on the number of checkpoint slots that may hold an
    /// activation at the same time.
    pub max_checkpoints: usize,

    /// Store activations in reduced precision to save additional memory.
    ///
    /// When set, activations are quantised on save (simulated half precision:
    /// values rounded to 3 decimal places, still stored as `f32`).
    pub use_mixed_precision: bool,
}

69impl Default for CheckpointConfig {
70    fn default() -> Self {
71        Self {
72            checkpoint_every_n_layers: 4,
73            max_checkpoints: 64,
74            use_mixed_precision: false,
75        }
76    }
77}
78
// ---------------------------------------------------------------------------
// ActivationCheckpointer
// ---------------------------------------------------------------------------

83/// Manages activation checkpoints for a multi-layer forward pass.
84///
85/// Stores activations at designated layer boundaries so they can be
86/// used during backpropagation without keeping every intermediate result
87/// in memory.
88#[derive(Debug, Clone)]
89pub struct ActivationCheckpointer {
90    config: CheckpointConfig,
91    /// Stored activations indexed by layer. `None` means the activation
92    /// for that layer was not checkpointed.
93    checkpoints: Vec<Option<Array1<f32>>>,
94    /// Tracks total bytes of activations that were *not* stored due to
95    /// the checkpointing policy — i.e., the memory that was saved.
96    bytes_saved: usize,
97    /// Tracks total bytes of activations that *are* stored.
98    bytes_stored: usize,
99}
100
101impl ActivationCheckpointer {
102    /// Create a new checkpointer with the given configuration.
103    pub fn new(config: CheckpointConfig) -> Self {
104        Self {
105            config,
106            checkpoints: Vec::new(),
107            bytes_saved: 0,
108            bytes_stored: 0,
109        }
110    }
111
112    /// Save an activation at the given layer index.
113    ///
114    /// If `use_mixed_precision` is enabled, the activation is quantised
115    /// before storage.
116    ///
117    /// # Errors
118    ///
119    /// Returns an error if `max_checkpoints` would be exceeded and the
120    /// layer is not a checkpoint boundary.
121    pub fn save(&mut self, layer_idx: usize, activation: Array1<f32>) -> ModelResult<()> {
122        // Ensure the checkpoints vector is large enough.
123        if layer_idx >= self.checkpoints.len() {
124            self.checkpoints.resize(layer_idx + 1, None);
125        }
126
127        // Check checkpoint capacity.
128        let current_count = self.num_checkpoints();
129        if current_count >= self.config.max_checkpoints && self.checkpoints[layer_idx].is_none() {
130            return Err(ModelError::invalid_config(format!(
131                "Maximum checkpoint count ({}) exceeded when saving layer {}",
132                self.config.max_checkpoints, layer_idx
133            )));
134        }
135
136        let byte_size = activation.len() * std::mem::size_of::<f32>();
137
138        let stored = if self.config.use_mixed_precision {
139            // Simulate half-precision by rounding to 3 decimal places.
140            // This halves effective precision while keeping f32 storage format.
141            activation.mapv(|x| (x * 1000.0).round() / 1000.0)
142        } else {
143            activation
144        };
145
146        self.bytes_stored += byte_size;
147        self.checkpoints[layer_idx] = Some(stored);
148
149        Ok(())
150    }
151
152    /// Retrieve the checkpointed activation at the given layer.
153    ///
154    /// # Errors
155    ///
156    /// Returns an error if no checkpoint exists for `layer_idx`.
157    pub fn get(&self, layer_idx: usize) -> ModelResult<&Array1<f32>> {
158        if layer_idx >= self.checkpoints.len() {
159            return Err(ModelError::IndexOutOfBounds {
160                index: layer_idx,
161                limit: self.checkpoints.len(),
162                context: "ActivationCheckpointer::get".to_string(),
163            });
164        }
165
166        self.checkpoints[layer_idx].as_ref().ok_or_else(|| {
167            ModelError::not_initialized(format!("No checkpoint stored for layer {}", layer_idx))
168        })
169    }
170
171    /// Clear all stored checkpoints and reset memory accounting.
172    pub fn clear(&mut self) {
173        self.checkpoints.clear();
174        self.bytes_saved = 0;
175        self.bytes_stored = 0;
176    }
177
178    /// Estimated bytes of memory saved by not storing non-checkpointed
179    /// activations.
180    ///
181    /// This value is updated during `checkpointed_forward` calls.
182    pub fn memory_saved_bytes(&self) -> usize {
183        self.bytes_saved
184    }
185
186    /// Bytes currently stored in checkpoints.
187    pub fn memory_stored_bytes(&self) -> usize {
188        self.bytes_stored
189    }
190
191    /// Number of non-`None` checkpoints currently held.
192    pub fn num_checkpoints(&self) -> usize {
193        self.checkpoints.iter().filter(|c| c.is_some()).count()
194    }
195
196    /// Whether a given layer index is a checkpoint boundary according to
197    /// the current configuration.
198    pub fn is_checkpoint_layer(&self, layer_idx: usize) -> bool {
199        if self.config.checkpoint_every_n_layers == 0 {
200            return false;
201        }
202        layer_idx.is_multiple_of(self.config.checkpoint_every_n_layers)
203    }
204
205    /// Run a checkpointed forward pass through the given layers.
206    ///
207    /// The `forward_fn` is called sequentially for each layer in `layers`,
208    /// receiving the current activation and the layer index. Activations
209    /// at checkpoint boundaries are saved; others are discarded (their
210    /// memory cost is recorded in `bytes_saved`).
211    ///
212    /// # Parameters
213    ///
214    /// - `input`: the initial activation fed into the first layer.
215    /// - `layers`: ordered list of layer indices to process.
216    /// - `forward_fn`: `Fn(&Array1<f32>, usize) -> ModelResult<Array1<f32>>`;
217    ///   applies one layer's computation.
218    ///
219    /// # Returns
220    ///
221    /// The activation after all layers have been applied.
222    pub fn checkpointed_forward<F>(
223        &mut self,
224        input: &Array1<f32>,
225        layers: &[usize],
226        forward_fn: F,
227    ) -> ModelResult<Array1<f32>>
228    where
229        F: Fn(&Array1<f32>, usize) -> ModelResult<Array1<f32>>,
230    {
231        let mut current = input.clone();
232
233        for &layer_idx in layers {
234            current = forward_fn(&current, layer_idx)?;
235
236            let byte_size = current.len() * std::mem::size_of::<f32>();
237
238            if self.is_checkpoint_layer(layer_idx) {
239                // Save checkpoint (respects max_checkpoints internally).
240                if self.num_checkpoints() < self.config.max_checkpoints {
241                    self.save(layer_idx, current.clone())?;
242                } else {
243                    // Cannot save more — count as saved memory.
244                    self.bytes_saved += byte_size;
245                }
246            } else {
247                // Not a checkpoint layer — activation is discarded.
248                self.bytes_saved += byte_size;
249            }
250        }
251
252        Ok(current)
253    }
254
255    /// Recompute activations from the nearest checkpoint up to `target_layer`.
256    ///
257    /// This is used during the backward pass: find the closest checkpoint
258    /// before `target_layer`, then replay the forward function from there.
259    ///
260    /// # Parameters
261    ///
262    /// - `target_layer`: the layer whose activation is needed.
263    /// - `layers`: the full ordered list of layer indices.
264    /// - `forward_fn`: the same forward function used during the forward pass.
265    ///
266    /// # Returns
267    ///
268    /// The recomputed activation at `target_layer`.
269    pub fn recompute_from_checkpoint<F>(
270        &self,
271        target_layer: usize,
272        layers: &[usize],
273        forward_fn: F,
274    ) -> ModelResult<Array1<f32>>
275    where
276        F: Fn(&Array1<f32>, usize) -> ModelResult<Array1<f32>>,
277    {
278        // Find the nearest checkpoint at or before target_layer.
279        let mut nearest_checkpoint_layer = None;
280        let mut nearest_activation = None;
281
282        for &l in layers.iter().rev() {
283            if l > target_layer {
284                continue;
285            }
286            if l < self.checkpoints.len() {
287                if let Some(ref act) = self.checkpoints[l] {
288                    nearest_checkpoint_layer = Some(l);
289                    nearest_activation = Some(act.clone());
290                    break;
291                }
292            }
293        }
294
295        let (start_layer, mut current) = match (nearest_checkpoint_layer, nearest_activation) {
296            (Some(l), Some(act)) => (l, act),
297            _ => {
298                return Err(ModelError::not_initialized(format!(
299                    "No checkpoint found before layer {} for recomputation",
300                    target_layer
301                )));
302            }
303        };
304
305        // Replay forward from the checkpoint layer to the target.
306        let mut started = false;
307        for &l in layers {
308            if l == start_layer {
309                started = true;
310                continue; // Skip the checkpoint layer itself — we already have its activation.
311            }
312            if !started {
313                continue;
314            }
315            current = forward_fn(&current, l)?;
316            if l == target_layer {
317                break;
318            }
319        }
320
321        Ok(current)
322    }
323
324    /// Return the configuration.
325    pub fn config(&self) -> &CheckpointConfig {
326        &self.config
327    }
328}
329
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Toy layer: scales each element by 1.1 and adds a layer-dependent
    /// offset, so every layer's output is distinguishable.
    fn simple_forward(activation: &Array1<f32>, layer_idx: usize) -> ModelResult<Array1<f32>> {
        Ok(activation.mapv(|v| v * 1.1 + layer_idx as f32 * 0.01))
    }

    /// Build a full-precision checkpointer with the given interval and cap.
    fn make_checkpointer(every: usize, max: usize) -> ActivationCheckpointer {
        ActivationCheckpointer::new(CheckpointConfig {
            checkpoint_every_n_layers: every,
            max_checkpoints: max,
            use_mixed_precision: false,
        })
    }

    // 6. Save and get
    #[test]
    fn test_gradient_checkpoint_save_get() {
        let mut cp = make_checkpointer(2, 10);
        let act = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        cp.save(2, act.clone()).expect("save should succeed");

        let stored = cp.get(2).expect("get should succeed");
        assert_eq!(stored.len(), 4);
        assert!((stored[0] - 1.0).abs() < 1e-6);
        assert!((stored[3] - 4.0).abs() < 1e-6);

        // A layer that was never checkpointed must report an error.
        assert!(cp.get(5).is_err());
    }

    // 7. Memory accounting
    #[test]
    fn test_gradient_checkpoint_memory_accounting() {
        let mut cp = make_checkpointer(3, 10);
        let input = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let layer_ids: Vec<usize> = (0..6).collect();

        cp.checkpointed_forward(&input, &layer_ids, simple_forward)
            .expect("forward should succeed");

        // Boundaries at layers 0 and 3; layers 1, 2, 4, 5 are discarded.
        assert!(
            cp.memory_saved_bytes() > 0,
            "should have saved some memory, got 0"
        );
        assert!(
            cp.memory_stored_bytes() > 0,
            "should have stored some activations"
        );
        assert_eq!(
            cp.num_checkpoints(),
            2,
            "should have 2 checkpoints (layers 0, 3)"
        );
    }

    // 8. Clear resets everything
    #[test]
    fn test_gradient_checkpoint_clear() {
        let mut cp = make_checkpointer(2, 10);

        cp.save(0, Array1::from_vec(vec![1.0, 2.0]))
            .expect("save should succeed");
        cp.save(2, Array1::from_vec(vec![3.0, 4.0]))
            .expect("save should succeed");
        assert_eq!(cp.num_checkpoints(), 2);

        cp.clear();

        assert_eq!(cp.num_checkpoints(), 0);
        assert_eq!(cp.memory_saved_bytes(), 0);
        assert_eq!(cp.memory_stored_bytes(), 0);
        assert!(cp.get(0).is_err());
    }

    // 9. Checkpointed forward produces same output as direct sequential forward
    #[test]
    fn test_gradient_checkpoint_forward() {
        let mut cp = make_checkpointer(2, 20);
        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let layer_ids: Vec<usize> = (0..8).collect();

        // Checkpointed forward
        let with_ckpt = cp
            .checkpointed_forward(&input, &layer_ids, simple_forward)
            .expect("checkpointed forward should succeed");

        // Reference: plain sequential application of the same layers.
        let reference = layer_ids
            .iter()
            .try_fold(input.clone(), |acc, &l| simple_forward(&acc, l))
            .expect("forward should succeed");

        // Both paths must agree element-wise.
        assert_eq!(with_ckpt.len(), reference.len());
        for (a, b) in with_ckpt.iter().zip(reference.iter()) {
            assert!(
                (a - b).abs() < 1e-4,
                "mismatch: checkpointed={}, direct={}",
                a,
                b
            );
        }

        // Verify some checkpoints were actually saved
        assert!(cp.num_checkpoints() > 0, "should have saved checkpoints");
    }
}