// ghostflow_nn/gradient_checkpointing.rs

//! Gradient Checkpointing
//!
//! Implements gradient checkpointing to reduce memory usage during training:
//! - Selective activation storage
//! - Recomputation during backward pass
//! - Memory-efficient training for large models
//! - Configurable checkpoint intervals
9use ghostflow_core::Tensor;
10use std::collections::HashMap;
11
/// Strategy controlling which layers have their activations checkpointed.
///
/// NOTE(review): in this module a "checkpointed" layer is one whose input is
/// *stored* (see `CheckpointManager::save_checkpoint` and
/// `estimate_memory_savings`, where `All` yields zero estimated savings), so
/// the "maximum memory savings" wording on `All` appears inverted relative to
/// the implementation — confirm the intended semantics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CheckpointStrategy {
    /// Checkpoint every N layers
    EveryN(usize),
    /// Checkpoint at specific layer indices
    Selective,
    /// Checkpoint all layers (maximum memory savings)
    All,
    /// No checkpointing (maximum speed)
    None,
}
24
/// Gradient checkpointing configuration
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Checkpoint strategy
    pub strategy: CheckpointStrategy,
    /// Specific layers to checkpoint (only consulted by the `Selective` strategy)
    pub checkpoint_layers: Vec<usize>,
    /// Enable CPU offloading for checkpoints
    // NOTE(review): nothing in this file reads `cpu_offload` yet — confirm it
    // is wired up elsewhere before relying on it.
    pub cpu_offload: bool,
}
35
36impl Default for CheckpointConfig {
37    fn default() -> Self {
38        CheckpointConfig {
39            strategy: CheckpointStrategy::EveryN(2),
40            checkpoint_layers: Vec::new(),
41            cpu_offload: false,
42        }
43    }
44}
45
46impl CheckpointConfig {
47    /// Checkpoint every N layers
48    pub fn every_n(n: usize) -> Self {
49        CheckpointConfig {
50            strategy: CheckpointStrategy::EveryN(n),
51            ..Default::default()
52        }
53    }
54    
55    /// Checkpoint specific layers
56    pub fn selective(layers: Vec<usize>) -> Self {
57        CheckpointConfig {
58            strategy: CheckpointStrategy::Selective,
59            checkpoint_layers: layers,
60            ..Default::default()
61        }
62    }
63    
64    /// Checkpoint all layers
65    pub fn all() -> Self {
66        CheckpointConfig {
67            strategy: CheckpointStrategy::All,
68            ..Default::default()
69        }
70    }
71}
72
/// Checkpoint manager
///
/// Stores checkpointed activations keyed by layer index and tracks
/// recomputation and memory statistics for reporting via `get_stats`.
pub struct CheckpointManager {
    // Strategy and per-layer checkpoint settings.
    config: CheckpointConfig,
    // Saved activations, keyed by layer index.
    checkpoints: HashMap<usize, Tensor>,
    // Number of times `recompute` has been invoked.
    recompute_count: usize,
    // Running byte total accumulated by `save_checkpoint` (assumes f32 data).
    memory_saved: usize,
}
80
81impl CheckpointManager {
82    /// Create new checkpoint manager
83    pub fn new(config: CheckpointConfig) -> Self {
84        CheckpointManager {
85            config,
86            checkpoints: HashMap::new(),
87            recompute_count: 0,
88            memory_saved: 0,
89        }
90    }
91    
92    /// Check if layer should be checkpointed
93    pub fn should_checkpoint(&self, layer_idx: usize) -> bool {
94        match self.config.strategy {
95            CheckpointStrategy::None => false,
96            CheckpointStrategy::All => true,
97            CheckpointStrategy::EveryN(n) => layer_idx % n == 0,
98            CheckpointStrategy::Selective => self.config.checkpoint_layers.contains(&layer_idx),
99        }
100    }
101    
102    /// Save checkpoint
103    pub fn save_checkpoint(&mut self, layer_idx: usize, activation: Tensor) {
104        if self.should_checkpoint(layer_idx) {
105            let memory_size = activation.data_f32().len() * 4; // 4 bytes per f32
106            self.memory_saved += memory_size;
107            self.checkpoints.insert(layer_idx, activation);
108        }
109    }
110    
111    /// Get checkpoint
112    pub fn get_checkpoint(&mut self, layer_idx: usize) -> Option<&Tensor> {
113        self.checkpoints.get(&layer_idx)
114    }
115    
116    /// Recompute activation (called during backward pass)
117    pub fn recompute<F>(&mut self, layer_idx: usize, recompute_fn: F) -> Tensor
118    where
119        F: FnOnce() -> Tensor,
120    {
121        self.recompute_count += 1;
122        recompute_fn()
123    }
124    
125    /// Clear checkpoints
126    pub fn clear(&mut self) {
127        self.checkpoints.clear();
128    }
129    
130    /// Get statistics
131    pub fn get_stats(&self) -> CheckpointStats {
132        CheckpointStats {
133            num_checkpoints: self.checkpoints.len(),
134            recompute_count: self.recompute_count,
135            memory_saved_bytes: self.memory_saved,
136        }
137    }
138}
139
/// Checkpoint statistics
///
/// Plain-data snapshot returned by `get_stats`; three machine words, so it is
/// cheap to copy, compare, and default-initialize.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct CheckpointStats {
    /// Number of active checkpoints
    pub num_checkpoints: usize,
    /// Number of recomputations
    pub recompute_count: usize,
    /// Estimated memory saved (bytes)
    pub memory_saved_bytes: usize,
}
150
/// Checkpointed layer wrapper
///
/// Wraps a forward function together with its own `CheckpointManager`, so the
/// layer's *input* can be saved in `forward` and replayed during `backward`.
pub struct CheckpointedLayer<F>
where
    F: Fn(&Tensor) -> Tensor,
{
    // The layer's forward computation.
    forward_fn: F,
    // Index used to decide whether this layer is checkpointed.
    layer_idx: usize,
    // Per-layer checkpoint storage and statistics.
    manager: CheckpointManager,
}
160
161impl<F> CheckpointedLayer<F>
162where
163    F: Fn(&Tensor) -> Tensor,
164{
165    /// Create new checkpointed layer
166    pub fn new(forward_fn: F, layer_idx: usize, config: CheckpointConfig) -> Self {
167        CheckpointedLayer {
168            forward_fn,
169            layer_idx,
170            manager: CheckpointManager::new(config),
171        }
172    }
173    
174    /// Forward pass with checkpointing
175    pub fn forward(&mut self, input: &Tensor) -> Tensor {
176        let output = (self.forward_fn)(input);
177        
178        // Save checkpoint if needed
179        if self.manager.should_checkpoint(self.layer_idx) {
180            self.manager.save_checkpoint(self.layer_idx, input.clone());
181        }
182        
183        output
184    }
185    
186    /// Backward pass with recomputation
187    pub fn backward(&mut self, grad_output: &Tensor) -> Tensor {
188        // Check if we have a checkpoint
189        if let Some(checkpoint) = self.manager.get_checkpoint(self.layer_idx) {
190            // Recompute forward pass
191            let _recomputed = (self.forward_fn)(checkpoint);
192            // In a real implementation, we'd compute gradients here
193            grad_output.clone()
194        } else {
195            // No checkpoint, assume we have the activation
196            grad_output.clone()
197        }
198    }
199    
200    /// Get statistics
201    pub fn get_stats(&self) -> CheckpointStats {
202        self.manager.get_stats()
203    }
204}
205
/// Sequential model with gradient checkpointing
///
/// Chains boxed layer closures and shares one `CheckpointManager` across all
/// of them, checkpointing per-layer inputs by layer index.
pub struct CheckpointedSequential {
    // Layers applied in insertion order.
    layers: Vec<Box<dyn Fn(&Tensor) -> Tensor>>,
    // Shared checkpoint storage and statistics for the whole pipeline.
    manager: CheckpointManager,
}
211
212impl CheckpointedSequential {
213    /// Create new checkpointed sequential model
214    pub fn new(config: CheckpointConfig) -> Self {
215        CheckpointedSequential {
216            layers: Vec::new(),
217            manager: CheckpointManager::new(config),
218        }
219    }
220    
221    /// Add layer
222    pub fn add_layer<F>(&mut self, layer: F)
223    where
224        F: Fn(&Tensor) -> Tensor + 'static,
225    {
226        self.layers.push(Box::new(layer));
227    }
228    
229    /// Forward pass with checkpointing
230    pub fn forward(&mut self, input: &Tensor) -> Tensor {
231        let mut x = input.clone();
232        
233        for (idx, layer) in self.layers.iter().enumerate() {
234            // Save checkpoint if needed
235            if self.manager.should_checkpoint(idx) {
236                self.manager.save_checkpoint(idx, x.clone());
237            }
238            
239            // Forward through layer
240            x = layer(&x);
241        }
242        
243        x
244    }
245    
246    /// Get statistics
247    pub fn get_stats(&self) -> CheckpointStats {
248        self.manager.get_stats()
249    }
250    
251    /// Clear checkpoints
252    pub fn clear_checkpoints(&mut self) {
253        self.manager.clear();
254    }
255}
256
257/// Utility function to estimate memory savings
258pub fn estimate_memory_savings(
259    num_layers: usize,
260    activation_size_mb: f32,
261    strategy: CheckpointStrategy,
262) -> f32 {
263    let checkpointed_layers = match strategy {
264        CheckpointStrategy::None => 0,
265        CheckpointStrategy::All => num_layers,
266        CheckpointStrategy::EveryN(n) => num_layers / n,
267        CheckpointStrategy::Selective => 0, // Can't estimate without knowing which layers
268    };
269    
270    let saved_memory = (num_layers - checkpointed_layers) as f32 * activation_size_mb;
271    saved_memory
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_checkpoint_strategy() {
        let mgr = CheckpointManager::new(CheckpointConfig::every_n(2));

        // Even layer indices are checkpointed, odd ones are not.
        assert!(mgr.should_checkpoint(0));
        assert!(!mgr.should_checkpoint(1));
        assert!(mgr.should_checkpoint(2));
        assert!(!mgr.should_checkpoint(3));
    }

    #[test]
    fn test_selective_checkpointing() {
        let mgr = CheckpointManager::new(CheckpointConfig::selective(vec![1, 3, 5]));

        assert!(!mgr.should_checkpoint(0));
        assert!(mgr.should_checkpoint(1));
        assert!(!mgr.should_checkpoint(2));
        assert!(mgr.should_checkpoint(3));
    }

    #[test]
    fn test_checkpoint_save_and_get() {
        let mut mgr = CheckpointManager::new(CheckpointConfig::all());

        let original = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        mgr.save_checkpoint(0, original.clone());

        let stored = mgr.get_checkpoint(0).unwrap();
        assert_eq!(stored.data_f32(), original.data_f32());
    }

    #[test]
    fn test_checkpointed_layer() {
        let double = |t: &Tensor| t.mul_scalar(2.0);
        let mut layer = CheckpointedLayer::new(double, 0, CheckpointConfig::all());

        let input = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let output = layer.forward(&input);

        let data = output.data_f32();
        assert_eq!(data[0], 2.0);
        assert_eq!(data[1], 4.0);
        assert_eq!(data[2], 6.0);

        assert_eq!(layer.get_stats().num_checkpoints, 1);
    }

    #[test]
    fn test_checkpointed_sequential() {
        let mut model = CheckpointedSequential::new(CheckpointConfig::every_n(1));

        model.add_layer(|t: &Tensor| t.mul_scalar(2.0));
        model.add_layer(|t: &Tensor| t.add_scalar(1.0));
        model.add_layer(|t: &Tensor| t.mul_scalar(0.5));

        let input = Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap();
        let output = model.forward(&input);

        // (x * 2 + 1) * 0.5: 1 -> 1.5, 2 -> 2.5
        let data = output.data_f32();
        assert!((data[0] - 1.5).abs() < 1e-5);
        assert!((data[1] - 2.5).abs() < 1e-5);

        assert!(model.get_stats().num_checkpoints > 0);
    }

    #[test]
    fn test_memory_savings_estimation() {
        // EveryN(2) over 10 layers: 5 stored, 5 * 100 MB saved.
        let saved = estimate_memory_savings(10, 100.0, CheckpointStrategy::EveryN(2));
        assert_eq!(saved, 500.0);

        // All layers stored: nothing saved.
        let saved = estimate_memory_savings(10, 100.0, CheckpointStrategy::All);
        assert_eq!(saved, 0.0);

        // Nothing stored: everything saved.
        let saved = estimate_memory_savings(10, 100.0, CheckpointStrategy::None);
        assert_eq!(saved, 1000.0);
    }

    #[test]
    fn test_checkpoint_clear() {
        let mut mgr = CheckpointManager::new(CheckpointConfig::all());
        mgr.save_checkpoint(0, Tensor::from_slice(&[1.0f32], &[1]).unwrap());
        assert_eq!(mgr.checkpoints.len(), 1);

        mgr.clear();
        assert_eq!(mgr.checkpoints.len(), 0);
    }

    #[test]
    fn test_recompute_tracking() {
        let mut mgr = CheckpointManager::new(CheckpointConfig::all());
        let before = mgr.recompute_count;

        mgr.recompute(0, || Tensor::from_slice(&[1.0f32], &[1]).unwrap());

        assert_eq!(mgr.recompute_count, before + 1);
    }
}