1use ghostflow_core::Tensor;
10use std::collections::HashMap;
11
/// Policy that decides which layers get their input activation stored.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CheckpointStrategy {
    /// Checkpoint layers whose index is a multiple of the given period.
    EveryN(usize),
    /// Checkpoint only the layers listed in `CheckpointConfig::checkpoint_layers`.
    Selective,
    /// Checkpoint every layer.
    All,
    /// Never checkpoint.
    None,
}
24
25#[derive(Debug, Clone)]
27pub struct CheckpointConfig {
28 pub strategy: CheckpointStrategy,
30 pub checkpoint_layers: Vec<usize>,
32 pub cpu_offload: bool,
34}
35
36impl Default for CheckpointConfig {
37 fn default() -> Self {
38 CheckpointConfig {
39 strategy: CheckpointStrategy::EveryN(2),
40 checkpoint_layers: Vec::new(),
41 cpu_offload: false,
42 }
43 }
44}
45
46impl CheckpointConfig {
47 pub fn every_n(n: usize) -> Self {
49 CheckpointConfig {
50 strategy: CheckpointStrategy::EveryN(n),
51 ..Default::default()
52 }
53 }
54
55 pub fn selective(layers: Vec<usize>) -> Self {
57 CheckpointConfig {
58 strategy: CheckpointStrategy::Selective,
59 checkpoint_layers: layers,
60 ..Default::default()
61 }
62 }
63
64 pub fn all() -> Self {
66 CheckpointConfig {
67 strategy: CheckpointStrategy::All,
68 ..Default::default()
69 }
70 }
71}
72
73pub struct CheckpointManager {
75 config: CheckpointConfig,
76 checkpoints: HashMap<usize, Tensor>,
77 recompute_count: usize,
78 memory_saved: usize,
79}
80
81impl CheckpointManager {
82 pub fn new(config: CheckpointConfig) -> Self {
84 CheckpointManager {
85 config,
86 checkpoints: HashMap::new(),
87 recompute_count: 0,
88 memory_saved: 0,
89 }
90 }
91
92 pub fn should_checkpoint(&self, layer_idx: usize) -> bool {
94 match self.config.strategy {
95 CheckpointStrategy::None => false,
96 CheckpointStrategy::All => true,
97 CheckpointStrategy::EveryN(n) => layer_idx % n == 0,
98 CheckpointStrategy::Selective => self.config.checkpoint_layers.contains(&layer_idx),
99 }
100 }
101
102 pub fn save_checkpoint(&mut self, layer_idx: usize, activation: Tensor) {
104 if self.should_checkpoint(layer_idx) {
105 let memory_size = activation.data_f32().len() * 4; self.memory_saved += memory_size;
107 self.checkpoints.insert(layer_idx, activation);
108 }
109 }
110
111 pub fn get_checkpoint(&mut self, layer_idx: usize) -> Option<&Tensor> {
113 self.checkpoints.get(&layer_idx)
114 }
115
116 pub fn recompute<F>(&mut self, layer_idx: usize, recompute_fn: F) -> Tensor
118 where
119 F: FnOnce() -> Tensor,
120 {
121 self.recompute_count += 1;
122 recompute_fn()
123 }
124
125 pub fn clear(&mut self) {
127 self.checkpoints.clear();
128 }
129
130 pub fn get_stats(&self) -> CheckpointStats {
132 CheckpointStats {
133 num_checkpoints: self.checkpoints.len(),
134 recompute_count: self.recompute_count,
135 memory_saved_bytes: self.memory_saved,
136 }
137 }
138}
139
/// Point-in-time counters reported by `CheckpointManager::get_stats`.
#[derive(Debug, Clone)]
pub struct CheckpointStats {
    /// Number of checkpoints stored when the snapshot was taken.
    pub num_checkpoints: usize,
    /// Number of recomputations counted so far.
    pub recompute_count: usize,
    /// Value of the manager's byte counter.
    pub memory_saved_bytes: usize,
}
150
151pub struct CheckpointedLayer<F>
153where
154 F: Fn(&Tensor) -> Tensor,
155{
156 forward_fn: F,
157 layer_idx: usize,
158 manager: CheckpointManager,
159}
160
161impl<F> CheckpointedLayer<F>
162where
163 F: Fn(&Tensor) -> Tensor,
164{
165 pub fn new(forward_fn: F, layer_idx: usize, config: CheckpointConfig) -> Self {
167 CheckpointedLayer {
168 forward_fn,
169 layer_idx,
170 manager: CheckpointManager::new(config),
171 }
172 }
173
174 pub fn forward(&mut self, input: &Tensor) -> Tensor {
176 let output = (self.forward_fn)(input);
177
178 if self.manager.should_checkpoint(self.layer_idx) {
180 self.manager.save_checkpoint(self.layer_idx, input.clone());
181 }
182
183 output
184 }
185
186 pub fn backward(&mut self, grad_output: &Tensor) -> Tensor {
188 if let Some(checkpoint) = self.manager.get_checkpoint(self.layer_idx) {
190 let _recomputed = (self.forward_fn)(checkpoint);
192 grad_output.clone()
194 } else {
195 grad_output.clone()
197 }
198 }
199
200 pub fn get_stats(&self) -> CheckpointStats {
202 self.manager.get_stats()
203 }
204}
205
206pub struct CheckpointedSequential {
208 layers: Vec<Box<dyn Fn(&Tensor) -> Tensor>>,
209 manager: CheckpointManager,
210}
211
212impl CheckpointedSequential {
213 pub fn new(config: CheckpointConfig) -> Self {
215 CheckpointedSequential {
216 layers: Vec::new(),
217 manager: CheckpointManager::new(config),
218 }
219 }
220
221 pub fn add_layer<F>(&mut self, layer: F)
223 where
224 F: Fn(&Tensor) -> Tensor + 'static,
225 {
226 self.layers.push(Box::new(layer));
227 }
228
229 pub fn forward(&mut self, input: &Tensor) -> Tensor {
231 let mut x = input.clone();
232
233 for (idx, layer) in self.layers.iter().enumerate() {
234 if self.manager.should_checkpoint(idx) {
236 self.manager.save_checkpoint(idx, x.clone());
237 }
238
239 x = layer(&x);
241 }
242
243 x
244 }
245
246 pub fn get_stats(&self) -> CheckpointStats {
248 self.manager.get_stats()
249 }
250
251 pub fn clear_checkpoints(&mut self) {
253 self.manager.clear();
254 }
255}
256
257pub fn estimate_memory_savings(
259 num_layers: usize,
260 activation_size_mb: f32,
261 strategy: CheckpointStrategy,
262) -> f32 {
263 let checkpointed_layers = match strategy {
264 CheckpointStrategy::None => 0,
265 CheckpointStrategy::All => num_layers,
266 CheckpointStrategy::EveryN(n) => num_layers / n,
267 CheckpointStrategy::Selective => 0, };
269
270 let saved_memory = (num_layers - checkpointed_layers) as f32 * activation_size_mb;
271 saved_memory
272}
273
#[cfg(test)]
mod tests {
    use super::*;

    /// `EveryN(2)` selects even layer indices only.
    #[test]
    fn test_checkpoint_strategy() {
        let manager = CheckpointManager::new(CheckpointConfig::every_n(2));

        let expectations = [(0, true), (1, false), (2, true), (3, false)];
        for (layer_idx, expected) in expectations.iter() {
            assert_eq!(manager.should_checkpoint(*layer_idx), *expected);
        }
    }

    /// `Selective` selects exactly the listed layer indices.
    #[test]
    fn test_selective_checkpointing() {
        let manager = CheckpointManager::new(CheckpointConfig::selective(vec![1, 3, 5]));

        let expectations = [(0, false), (1, true), (2, false), (3, true)];
        for (layer_idx, expected) in expectations.iter() {
            assert_eq!(manager.should_checkpoint(*layer_idx), *expected);
        }
    }

    /// A saved activation can be read back unchanged.
    #[test]
    fn test_checkpoint_save_and_get() {
        let mut manager = CheckpointManager::new(CheckpointConfig::all());

        let original = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        manager.save_checkpoint(0, original.clone());

        let stored = manager.get_checkpoint(0).unwrap();
        assert_eq!(stored.data_f32(), original.data_f32());
    }

    /// A checkpointed layer returns the wrapped function's output and stores one checkpoint.
    #[test]
    fn test_checkpointed_layer() {
        let double = |x: &Tensor| x.mul_scalar(2.0);
        let mut layer = CheckpointedLayer::new(double, 0, CheckpointConfig::all());

        let input = Tensor::from_slice(&[1.0f32, 2.0, 3.0], &[3]).unwrap();
        let output = layer.forward(&input);

        let values = output.data_f32();
        for (position, expected) in [2.0, 4.0, 6.0].iter().enumerate() {
            assert_eq!(values[position], *expected);
        }

        assert_eq!(layer.get_stats().num_checkpoints, 1);
    }

    /// A three-layer pipeline computes `(x * 2 + 1) * 0.5` and records checkpoints.
    #[test]
    fn test_checkpointed_sequential() {
        let mut model = CheckpointedSequential::new(CheckpointConfig::every_n(1));

        model.add_layer(|x: &Tensor| x.mul_scalar(2.0));
        model.add_layer(|x: &Tensor| x.add_scalar(1.0));
        model.add_layer(|x: &Tensor| x.mul_scalar(0.5));

        let input = Tensor::from_slice(&[1.0f32, 2.0], &[2]).unwrap();
        let output = model.forward(&input);

        let values = output.data_f32();
        for (position, expected) in [1.5, 2.5].iter().enumerate() {
            assert!((values[position] - *expected).abs() < 1e-5);
        }

        assert!(model.get_stats().num_checkpoints > 0);
    }

    /// Savings for 10 layers of 100 MB under three strategies.
    #[test]
    fn test_memory_savings_estimation() {
        let cases = [
            (CheckpointStrategy::EveryN(2), 500.0),
            (CheckpointStrategy::All, 0.0),
            (CheckpointStrategy::None, 1000.0),
        ];

        for (strategy, expected) in cases.iter() {
            assert_eq!(estimate_memory_savings(10, 100.0, *strategy), *expected);
        }
    }

    /// `clear` removes every stored checkpoint.
    #[test]
    fn test_checkpoint_clear() {
        let mut manager = CheckpointManager::new(CheckpointConfig::all());

        manager.save_checkpoint(0, Tensor::from_slice(&[1.0f32], &[1]).unwrap());
        assert_eq!(manager.checkpoints.len(), 1);

        manager.clear();
        assert_eq!(manager.checkpoints.len(), 0);
    }

    /// Each `recompute` call bumps the counter by one.
    #[test]
    fn test_recompute_tracking() {
        let mut manager = CheckpointManager::new(CheckpointConfig::all());
        let before = manager.recompute_count;

        let _recomputed = manager.recompute(0, || Tensor::from_slice(&[1.0f32], &[1]).unwrap());

        assert_eq!(manager.recompute_count, before + 1);
    }
}