// kizzasi_model/gradient_checkpoint.rs
1//! Gradient Checkpointing for Memory-Efficient Training
2//!
3//! Implements activation checkpointing (also known as gradient checkpointing)
4//! to trade compute for memory during training. Instead of storing all
5//! intermediate activations for the backward pass, only activations at
6//! designated checkpoint boundaries are retained. During backpropagation,
7//! discarded activations are recomputed from the nearest checkpoint.
8//!
9//! This technique, introduced by Chen et al. (2016), can reduce memory usage
10//! from O(N) to O(sqrt(N)) for N-layer networks at the cost of one
11//! additional forward pass.
12//!
13//! # Example
14//!
15//! ```rust,ignore
16//! use kizzasi_model::gradient_checkpoint::{ActivationCheckpointer, CheckpointConfig};
17//! use scirs2_core::ndarray::Array1;
18//!
19//! let config = CheckpointConfig {
20//! checkpoint_every_n_layers: 3,
21//! max_checkpoints: 10,
22//! use_mixed_precision: false,
23//! };
24//!
25//! let mut checkpointer = ActivationCheckpointer::new(config);
26//!
27//! let input = Array1::from_vec(vec![1.0, 2.0, 3.0]);
28//! let layers = vec![0, 1, 2, 3, 4, 5];
29//!
30//! let output = checkpointer.checkpointed_forward(
31//! &input,
32//! &layers,
33//! |activation, layer_idx| {
34//! // Your layer forward function
35//! Ok(activation.mapv(|x| x * 1.1 + layer_idx as f32 * 0.01))
36//! },
37//! )?;
38//! ```
39
40use crate::error::{ModelError, ModelResult};
41use scirs2_core::ndarray::Array1;
42
43// ---------------------------------------------------------------------------
44// CheckpointConfig
45// ---------------------------------------------------------------------------
46
/// Configuration for activation checkpointing.
///
/// Controls how often activations are retained during a checkpointed
/// forward pass and how much memory the checkpoint store may consume.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Save an activation checkpoint every N layers.
    ///
    /// For example, `3` means layers 0, 3, 6, 9, ... will have their
    /// activations stored. A value of `0` disables checkpoint boundaries
    /// entirely (no layer is ever treated as a checkpoint).
    pub checkpoint_every_n_layers: usize,

    /// Maximum number of checkpoint slots.
    ///
    /// Once this limit is reached, saving a *new* checkpoint returns an
    /// error; no eviction of older checkpoints is performed.
    pub max_checkpoints: usize,

    /// Whether to use reduced precision (f16-like truncation) for stored
    /// activations to further save memory.
    ///
    /// When enabled, stored activations are quantised (simulated by
    /// rounding to 3 decimal places) before storage; the in-memory
    /// representation remains `f32`.
    pub use_mixed_precision: bool,
}
68
69impl Default for CheckpointConfig {
70 fn default() -> Self {
71 Self {
72 checkpoint_every_n_layers: 4,
73 max_checkpoints: 64,
74 use_mixed_precision: false,
75 }
76 }
77}
78
79// ---------------------------------------------------------------------------
80// ActivationCheckpointer
81// ---------------------------------------------------------------------------
82
/// Manages activation checkpoints for a multi-layer forward pass.
///
/// Stores activations at designated layer boundaries so they can be
/// used during backpropagation without keeping every intermediate result
/// in memory.
#[derive(Debug, Clone)]
pub struct ActivationCheckpointer {
    /// Checkpointing policy: boundary spacing, capacity, and storage precision.
    config: CheckpointConfig,
    /// Stored activations indexed by layer. `None` means the activation
    /// for that layer was not checkpointed. Grown on demand by `save`.
    checkpoints: Vec<Option<Array1<f32>>>,
    /// Tracks total bytes of activations that were *not* stored due to
    /// the checkpointing policy — i.e., the memory that was saved.
    /// Updated only by `checkpointed_forward`; reset by `clear`.
    bytes_saved: usize,
    /// Tracks total bytes of activations that *are* stored.
    bytes_stored: usize,
}
100
101impl ActivationCheckpointer {
102 /// Create a new checkpointer with the given configuration.
103 pub fn new(config: CheckpointConfig) -> Self {
104 Self {
105 config,
106 checkpoints: Vec::new(),
107 bytes_saved: 0,
108 bytes_stored: 0,
109 }
110 }
111
112 /// Save an activation at the given layer index.
113 ///
114 /// If `use_mixed_precision` is enabled, the activation is quantised
115 /// before storage.
116 ///
117 /// # Errors
118 ///
119 /// Returns an error if `max_checkpoints` would be exceeded and the
120 /// layer is not a checkpoint boundary.
121 pub fn save(&mut self, layer_idx: usize, activation: Array1<f32>) -> ModelResult<()> {
122 // Ensure the checkpoints vector is large enough.
123 if layer_idx >= self.checkpoints.len() {
124 self.checkpoints.resize(layer_idx + 1, None);
125 }
126
127 // Check checkpoint capacity.
128 let current_count = self.num_checkpoints();
129 if current_count >= self.config.max_checkpoints && self.checkpoints[layer_idx].is_none() {
130 return Err(ModelError::invalid_config(format!(
131 "Maximum checkpoint count ({}) exceeded when saving layer {}",
132 self.config.max_checkpoints, layer_idx
133 )));
134 }
135
136 let byte_size = activation.len() * std::mem::size_of::<f32>();
137
138 let stored = if self.config.use_mixed_precision {
139 // Simulate half-precision by rounding to 3 decimal places.
140 // This halves effective precision while keeping f32 storage format.
141 activation.mapv(|x| (x * 1000.0).round() / 1000.0)
142 } else {
143 activation
144 };
145
146 self.bytes_stored += byte_size;
147 self.checkpoints[layer_idx] = Some(stored);
148
149 Ok(())
150 }
151
152 /// Retrieve the checkpointed activation at the given layer.
153 ///
154 /// # Errors
155 ///
156 /// Returns an error if no checkpoint exists for `layer_idx`.
157 pub fn get(&self, layer_idx: usize) -> ModelResult<&Array1<f32>> {
158 if layer_idx >= self.checkpoints.len() {
159 return Err(ModelError::IndexOutOfBounds {
160 index: layer_idx,
161 limit: self.checkpoints.len(),
162 context: "ActivationCheckpointer::get".to_string(),
163 });
164 }
165
166 self.checkpoints[layer_idx].as_ref().ok_or_else(|| {
167 ModelError::not_initialized(format!("No checkpoint stored for layer {}", layer_idx))
168 })
169 }
170
171 /// Clear all stored checkpoints and reset memory accounting.
172 pub fn clear(&mut self) {
173 self.checkpoints.clear();
174 self.bytes_saved = 0;
175 self.bytes_stored = 0;
176 }
177
178 /// Estimated bytes of memory saved by not storing non-checkpointed
179 /// activations.
180 ///
181 /// This value is updated during `checkpointed_forward` calls.
182 pub fn memory_saved_bytes(&self) -> usize {
183 self.bytes_saved
184 }
185
186 /// Bytes currently stored in checkpoints.
187 pub fn memory_stored_bytes(&self) -> usize {
188 self.bytes_stored
189 }
190
191 /// Number of non-`None` checkpoints currently held.
192 pub fn num_checkpoints(&self) -> usize {
193 self.checkpoints.iter().filter(|c| c.is_some()).count()
194 }
195
196 /// Whether a given layer index is a checkpoint boundary according to
197 /// the current configuration.
198 pub fn is_checkpoint_layer(&self, layer_idx: usize) -> bool {
199 if self.config.checkpoint_every_n_layers == 0 {
200 return false;
201 }
202 layer_idx.is_multiple_of(self.config.checkpoint_every_n_layers)
203 }
204
205 /// Run a checkpointed forward pass through the given layers.
206 ///
207 /// The `forward_fn` is called sequentially for each layer in `layers`,
208 /// receiving the current activation and the layer index. Activations
209 /// at checkpoint boundaries are saved; others are discarded (their
210 /// memory cost is recorded in `bytes_saved`).
211 ///
212 /// # Parameters
213 ///
214 /// - `input`: the initial activation fed into the first layer.
215 /// - `layers`: ordered list of layer indices to process.
216 /// - `forward_fn`: `Fn(&Array1<f32>, usize) -> ModelResult<Array1<f32>>`;
217 /// applies one layer's computation.
218 ///
219 /// # Returns
220 ///
221 /// The activation after all layers have been applied.
222 pub fn checkpointed_forward<F>(
223 &mut self,
224 input: &Array1<f32>,
225 layers: &[usize],
226 forward_fn: F,
227 ) -> ModelResult<Array1<f32>>
228 where
229 F: Fn(&Array1<f32>, usize) -> ModelResult<Array1<f32>>,
230 {
231 let mut current = input.clone();
232
233 for &layer_idx in layers {
234 current = forward_fn(¤t, layer_idx)?;
235
236 let byte_size = current.len() * std::mem::size_of::<f32>();
237
238 if self.is_checkpoint_layer(layer_idx) {
239 // Save checkpoint (respects max_checkpoints internally).
240 if self.num_checkpoints() < self.config.max_checkpoints {
241 self.save(layer_idx, current.clone())?;
242 } else {
243 // Cannot save more — count as saved memory.
244 self.bytes_saved += byte_size;
245 }
246 } else {
247 // Not a checkpoint layer — activation is discarded.
248 self.bytes_saved += byte_size;
249 }
250 }
251
252 Ok(current)
253 }
254
255 /// Recompute activations from the nearest checkpoint up to `target_layer`.
256 ///
257 /// This is used during the backward pass: find the closest checkpoint
258 /// before `target_layer`, then replay the forward function from there.
259 ///
260 /// # Parameters
261 ///
262 /// - `target_layer`: the layer whose activation is needed.
263 /// - `layers`: the full ordered list of layer indices.
264 /// - `forward_fn`: the same forward function used during the forward pass.
265 ///
266 /// # Returns
267 ///
268 /// The recomputed activation at `target_layer`.
269 pub fn recompute_from_checkpoint<F>(
270 &self,
271 target_layer: usize,
272 layers: &[usize],
273 forward_fn: F,
274 ) -> ModelResult<Array1<f32>>
275 where
276 F: Fn(&Array1<f32>, usize) -> ModelResult<Array1<f32>>,
277 {
278 // Find the nearest checkpoint at or before target_layer.
279 let mut nearest_checkpoint_layer = None;
280 let mut nearest_activation = None;
281
282 for &l in layers.iter().rev() {
283 if l > target_layer {
284 continue;
285 }
286 if l < self.checkpoints.len() {
287 if let Some(ref act) = self.checkpoints[l] {
288 nearest_checkpoint_layer = Some(l);
289 nearest_activation = Some(act.clone());
290 break;
291 }
292 }
293 }
294
295 let (start_layer, mut current) = match (nearest_checkpoint_layer, nearest_activation) {
296 (Some(l), Some(act)) => (l, act),
297 _ => {
298 return Err(ModelError::not_initialized(format!(
299 "No checkpoint found before layer {} for recomputation",
300 target_layer
301 )));
302 }
303 };
304
305 // Replay forward from the checkpoint layer to the target.
306 let mut started = false;
307 for &l in layers {
308 if l == start_layer {
309 started = true;
310 continue; // Skip the checkpoint layer itself — we already have its activation.
311 }
312 if !started {
313 continue;
314 }
315 current = forward_fn(¤t, l)?;
316 if l == target_layer {
317 break;
318 }
319 }
320
321 Ok(current)
322 }
323
324 /// Return the configuration.
325 pub fn config(&self) -> &CheckpointConfig {
326 &self.config
327 }
328}
329
330// ---------------------------------------------------------------------------
331// Tests
332// ---------------------------------------------------------------------------
333
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Toy layer computation: scale each element by 1.1 and add a small
    /// layer-dependent offset so every layer is distinguishable.
    fn simple_forward(activation: &Array1<f32>, layer_idx: usize) -> ModelResult<Array1<f32>> {
        Ok(activation.mapv(|x| x * 1.1 + layer_idx as f32 * 0.01))
    }

    /// Build a full-precision checkpointer with the given boundary
    /// spacing and capacity.
    fn make_checkpointer(every: usize, max: usize) -> ActivationCheckpointer {
        ActivationCheckpointer::new(CheckpointConfig {
            checkpoint_every_n_layers: every,
            max_checkpoints: max,
            use_mixed_precision: false,
        })
    }

    /// Saving then retrieving a checkpoint round-trips the activation;
    /// retrieving an unsaved layer fails.
    #[test]
    fn test_gradient_checkpoint_save_get() {
        let mut cp = make_checkpointer(2, 10);

        cp.save(2, Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]))
            .expect("save should succeed");

        let stored = cp.get(2).expect("get should succeed");
        assert_eq!(stored.len(), 4);
        assert!((stored[0] - 1.0).abs() < 1e-6);
        assert!((stored[3] - 4.0).abs() < 1e-6);

        // A layer that was never checkpointed must report an error.
        assert!(cp.get(5).is_err());
    }

    /// After a forward pass, both byte counters are populated and the
    /// boundary layers (0 and 3) hold checkpoints.
    #[test]
    fn test_gradient_checkpoint_memory_accounting() {
        let mut cp = make_checkpointer(3, 10);

        let input = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let layers: Vec<usize> = (0..6).collect();

        cp.checkpointed_forward(&input, &layers, simple_forward)
            .expect("forward should succeed");

        // Checkpoints land on layers 0 and 3; layers 1, 2, 4, 5 are dropped.
        assert!(
            cp.memory_saved_bytes() > 0,
            "should have saved some memory, got 0"
        );
        assert!(
            cp.memory_stored_bytes() > 0,
            "should have stored some activations"
        );
        assert_eq!(
            cp.num_checkpoints(),
            2,
            "should have 2 checkpoints (layers 0, 3)"
        );
    }

    /// `clear` drops every checkpoint and zeroes both byte counters.
    #[test]
    fn test_gradient_checkpoint_clear() {
        let mut cp = make_checkpointer(2, 10);

        cp.save(0, Array1::from_vec(vec![1.0, 2.0]))
            .expect("save should succeed");
        cp.save(2, Array1::from_vec(vec![3.0, 4.0]))
            .expect("save should succeed");
        assert_eq!(cp.num_checkpoints(), 2);

        cp.clear();

        assert_eq!(cp.num_checkpoints(), 0);
        assert_eq!(cp.memory_saved_bytes(), 0);
        assert_eq!(cp.memory_stored_bytes(), 0);
        assert!(cp.get(0).is_err());
    }

    /// A checkpointed forward pass must be numerically identical to
    /// running the layers sequentially with no checkpointing at all.
    #[test]
    fn test_gradient_checkpoint_forward() {
        let mut cp = make_checkpointer(2, 20);

        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let layers: Vec<usize> = (0..8).collect();

        let via_checkpoints = cp
            .checkpointed_forward(&input, &layers, simple_forward)
            .expect("checkpointed forward should succeed");

        // Reference: plain sequential application of the same layers.
        let reference = layers
            .iter()
            .try_fold(input.clone(), |acc, &l| simple_forward(&acc, l))
            .expect("forward should succeed");

        assert_eq!(via_checkpoints.len(), reference.len());
        for (a, b) in via_checkpoints.iter().zip(reference.iter()) {
            assert!(
                (a - b).abs() < 1e-4,
                "mismatch: checkpointed={}, direct={}",
                a,
                b
            );
        }

        // Verify some checkpoints were actually saved along the way.
        assert!(cp.num_checkpoints() > 0, "should have saved checkpoints");
    }
}
465}