1use crate::autograd::graph_opt::OpType;
26use crate::Tensor;
27use std::cell::RefCell;
28use std::rc::Rc;
29
/// Settings describing whether and how checkpointing should be applied.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Whether checkpointing is turned on.
    pub enabled: bool,
    /// Requested number of segments.
    pub num_segments: usize,
    /// Selective-mode flag; only set through [`CheckpointConfig::with_selective`].
    pub selective: bool,
}

impl CheckpointConfig {
    /// Returns a configuration with checkpointing on, the given segment
    /// count, and selective mode off.
    pub fn enabled(num_segments: usize) -> Self {
        let mut config = Self::disabled();
        config.enabled = true;
        config.num_segments = num_segments;
        config
    }

    /// Returns a configuration with checkpointing off, a single segment,
    /// and selective mode off.
    pub fn disabled() -> Self {
        Self { enabled: false, num_segments: 1, selective: false }
    }

    /// Builder-style switch that turns selective mode on and leaves every
    /// other field untouched.
    pub fn with_selective(self) -> Self {
        Self { selective: true, ..self }
    }
}

impl Default for CheckpointConfig {
    /// The default configuration is the disabled one.
    fn default() -> Self {
        Self { enabled: false, num_segments: 1, selective: false }
    }
}
64
65pub struct CheckpointedSegment {
70 input: Tensor,
72 output: RefCell<Option<Tensor>>,
74 is_checkpointed: bool,
76}
77
78impl CheckpointedSegment {
79 pub fn new(input: Tensor, is_checkpointed: bool) -> Self {
81 Self { input, output: RefCell::new(None), is_checkpointed }
82 }
83
84 pub fn input(&self) -> &Tensor {
86 &self.input
87 }
88
89 pub fn is_checkpointed(&self) -> bool {
91 self.is_checkpointed
92 }
93
94 pub fn set_output(&self, output: Tensor) {
96 *self.output.borrow_mut() = Some(output);
97 }
98
99 pub fn output(&self) -> Option<Tensor> {
101 contract_pre_output!();
102 self.output.borrow().clone()
103 }
104
105 pub fn clear_output(&self) {
107 *self.output.borrow_mut() = None;
108 }
109}
110
111pub struct CheckpointManager {
113 config: CheckpointConfig,
115 segments: Vec<Rc<CheckpointedSegment>>,
117 current_segment: RefCell<usize>,
119 memory_saved: RefCell<usize>,
121}
122
123impl CheckpointManager {
124 pub fn new(config: CheckpointConfig) -> Self {
126 Self {
127 config,
128 segments: Vec::new(),
129 current_segment: RefCell::new(0),
130 memory_saved: RefCell::new(0),
131 }
132 }
133
134 pub fn is_enabled(&self) -> bool {
136 self.config.enabled
137 }
138
139 pub fn num_segments(&self) -> usize {
141 self.config.num_segments
142 }
143
144 pub fn register_segment(&mut self, input: Tensor) -> Rc<CheckpointedSegment> {
146 let idx = self.segments.len();
147 let should_checkpoint = self.config.enabled && self.should_checkpoint_segment(idx);
148
149 let segment = Rc::new(CheckpointedSegment::new(input, should_checkpoint));
150 self.segments.push(segment.clone());
151
152 if should_checkpoint {
154 *self.memory_saved.borrow_mut() += 1;
157 }
158
159 segment
160 }
161
162 fn should_checkpoint_segment(&self, segment_idx: usize) -> bool {
164 if !self.config.enabled {
165 return false;
166 }
167
168 let checkpoint_interval = self.segments.len().max(1) / self.config.num_segments.max(1);
170 if checkpoint_interval == 0 {
171 return true; }
173
174 segment_idx.is_multiple_of(checkpoint_interval)
175 }
176
177 pub fn memory_saved_segments(&self) -> usize {
179 *self.memory_saved.borrow()
180 }
181
182 pub fn clear(&mut self) {
184 for segment in &self.segments {
185 segment.clear_output();
186 }
187 self.segments.clear();
188 *self.current_segment.borrow_mut() = 0;
189 }
190
191 pub fn total_segments(&self) -> usize {
193 self.segments.len()
194 }
195}
196
197pub fn checkpoint<F>(f: F, input: &Tensor) -> Tensor
211where
212 F: Fn(&Tensor) -> Tensor,
213{
214 f(input)
216}
217
218pub fn checkpoint_if<F>(f: F, input: &Tensor, should_checkpoint: bool) -> Tensor
226where
227 F: Fn(&Tensor) -> Tensor,
228{
229 if should_checkpoint {
230 f(input)
233 } else {
234 f(input)
235 }
236}
237
/// Estimates activation memory, in bytes, without and with checkpointing.
///
/// Returns `(memory_without, memory_with)`:
/// * `memory_without` is one activation per layer;
/// * `memory_with` is one activation per retained layer, where the retained
///   count is the larger of `ceil(sqrt(num_layers))` and `num_checkpoints`,
///   capped at `num_layers` (keeping more activations than there are layers
///   is impossible, so the "with" figure never exceeds the "without" one).
///
/// All products saturate at `usize::MAX` instead of overflowing.
pub fn estimate_memory_savings(
    num_layers: usize,
    hidden_size: usize,
    seq_len: usize,
    batch_size: usize,
    num_checkpoints: usize,
) -> (usize, usize) {
    // Each element is counted as 4 bytes (presumably f32 — confirm).
    const BYTES_PER_ELEMENT: usize = 4;

    let activation_size = batch_size
        .saturating_mul(seq_len)
        .saturating_mul(hidden_size)
        .saturating_mul(BYTES_PER_ELEMENT);

    let memory_without = num_layers.saturating_mul(activation_size);

    let sqrt_layers = (num_layers as f64).sqrt().ceil() as usize;
    let retained_layers = sqrt_layers.max(num_checkpoints).min(num_layers);
    let memory_with = retained_layers.saturating_mul(activation_size);

    (memory_without, memory_with)
}
271
/// Returns `ceil(sqrt(num_layers))`, but never less than 1.
pub fn optimal_checkpoints(num_layers: usize) -> usize {
    let root = (num_layers as f64).sqrt();
    let rounded_up = root.ceil() as usize;
    if rounded_up == 0 {
        1
    } else {
        rounded_up
    }
}
278
279#[derive(Debug, Clone)]
286pub struct OperationInfo {
287 pub op_type: OpType,
289 pub output_bytes: usize,
291 pub has_batch_dim: bool,
293 pub layer_index: usize,
295}
296
297impl OperationInfo {
298 pub fn new(op_type: OpType, output_bytes: usize) -> Self {
300 Self { op_type, output_bytes, has_batch_dim: false, layer_index: 0 }
301 }
302
303 pub fn with_batch_dim(mut self, has_batch: bool) -> Self {
305 self.has_batch_dim = has_batch;
306 self
307 }
308
309 pub fn with_layer_index(mut self, index: usize) -> Self {
311 self.layer_index = index;
312 self
313 }
314}
315
316pub trait CheckpointPolicy {
322 fn should_save(&self, op: &OperationInfo) -> bool;
324
325 fn recompute_cost(&self, _op: &OperationInfo) -> f64 {
327 1.0
328 }
329}
330
331pub struct SaveAll;
333
334impl CheckpointPolicy for SaveAll {
335 fn should_save(&self, _op: &OperationInfo) -> bool {
336 true
337 }
338}
339
340pub struct SaveNothing;
342
343impl CheckpointPolicy for SaveNothing {
344 fn should_save(&self, _op: &OperationInfo) -> bool {
345 false
346 }
347}
348
349pub struct SaveMatmuls;
351
352impl CheckpointPolicy for SaveMatmuls {
353 fn should_save(&self, op: &OperationInfo) -> bool {
354 matches!(op.op_type, OpType::Matmul | OpType::Attention)
355 }
356
357 fn recompute_cost(&self, op: &OperationInfo) -> f64 {
358 match op.op_type {
359 OpType::Matmul => 100.0,
360 OpType::Attention => 150.0,
361 OpType::Add
362 | OpType::Mul
363 | OpType::Scale
364 | OpType::Sum
365 | OpType::Relu
366 | OpType::Gelu
367 | OpType::Softmax
368 | OpType::LayerNorm
369 | OpType::Constant => 1.0,
370 }
371 }
372}
373
374pub struct SaveUnbatchedMatmuls;
377
378impl CheckpointPolicy for SaveUnbatchedMatmuls {
379 fn should_save(&self, op: &OperationInfo) -> bool {
380 matches!(op.op_type, OpType::Matmul | OpType::Attention) && !op.has_batch_dim
381 }
382}
383
384pub struct BinomialCheckpointing {
388 pub num_layers: usize,
390}
391
392impl BinomialCheckpointing {
393 pub fn checkpoint_indices(&self) -> Vec<usize> {
395 let num_checkpoints = optimal_checkpoints(self.num_layers);
396 let interval = self.num_layers / num_checkpoints.max(1);
397 (0..self.num_layers).step_by(interval.max(1)).collect()
398 }
399}
400
401impl CheckpointPolicy for BinomialCheckpointing {
402 fn should_save(&self, op: &OperationInfo) -> bool {
403 let indices = self.checkpoint_indices();
404 indices.contains(&op.layer_index)
405 }
406}
407
408pub struct MemoryBudget {
410 pub max_bytes: usize,
412 used_bytes: RefCell<usize>,
414}
415
416impl MemoryBudget {
417 pub fn new(max_bytes: usize) -> Self {
419 Self { max_bytes, used_bytes: RefCell::new(0) }
420 }
421
422 pub fn used_bytes(&self) -> usize {
424 *self.used_bytes.borrow()
425 }
426
427 pub fn reset(&self) {
429 *self.used_bytes.borrow_mut() = 0;
430 }
431}
432
433impl CheckpointPolicy for MemoryBudget {
434 fn should_save(&self, op: &OperationInfo) -> bool {
435 let current = *self.used_bytes.borrow();
436 if current + op.output_bytes <= self.max_bytes {
437 *self.used_bytes.borrow_mut() += op.output_bytes;
438 true
439 } else {
440 false
441 }
442 }
443}
444
445pub struct CustomPolicy<F: Fn(&OperationInfo) -> bool> {
447 predicate: F,
448}
449
450impl<F: Fn(&OperationInfo) -> bool> CustomPolicy<F> {
451 pub fn new(predicate: F) -> Self {
453 Self { predicate }
454 }
455}
456
457impl<F: Fn(&OperationInfo) -> bool> CheckpointPolicy for CustomPolicy<F> {
458 fn should_save(&self, op: &OperationInfo) -> bool {
459 (self.predicate)(op)
460 }
461}
462
463pub struct PolicyCheckpointManager {
466 saved: Vec<Option<Tensor>>,
468 total_bytes_saved: usize,
470 num_layers: usize,
472}
473
474impl PolicyCheckpointManager {
475 pub fn new(num_layers: usize) -> Self {
477 Self { saved: vec![None; num_layers], total_bytes_saved: 0, num_layers }
478 }
479
480 pub fn record<P: CheckpointPolicy>(
482 &mut self,
483 layer_index: usize,
484 activation: &Tensor,
485 op_info: &OperationInfo,
486 policy: &P,
487 ) {
488 if policy.should_save(op_info) && layer_index < self.num_layers {
489 self.saved[layer_index] = Some(activation.clone());
490 self.total_bytes_saved += op_info.output_bytes;
491 }
492 }
493
494 pub fn get(&self, layer_index: usize) -> Option<&Tensor> {
496 self.saved.get(layer_index).and_then(|s| s.as_ref())
497 }
498
499 pub fn is_saved(&self, layer_index: usize) -> bool {
501 self.saved.get(layer_index).is_some_and(Option::is_some)
502 }
503
504 pub fn total_bytes(&self) -> usize {
506 contract_pre_total_bytes!();
507 self.total_bytes_saved
508 }
509
510 pub fn num_saved(&self) -> usize {
512 self.saved.iter().filter(|s| s.is_some()).count()
513 }
514
515 pub fn clear(&mut self) {
517 self.saved.iter_mut().for_each(|s| *s = None);
518 self.total_bytes_saved = 0;
519 }
520
521 pub fn num_layers(&self) -> usize {
523 self.num_layers
524 }
525}
526
527pub fn estimate_policy_tradeoff<P: CheckpointPolicy>(
534 policy: &P,
535 layer_infos: &[OperationInfo],
536) -> (usize, usize, f64) {
537 let mut bytes_saved = 0usize;
538 let mut bytes_used = 0usize;
539 let mut recompute_overhead = 0.0f64;
540
541 for info in layer_infos {
542 if policy.should_save(info) {
543 bytes_used += info.output_bytes;
544 } else {
545 bytes_saved += info.output_bytes;
546 recompute_overhead += policy.recompute_cost(info);
547 }
548 }
549
550 (bytes_saved, bytes_used, recompute_overhead)
551}
552
#[cfg(test)]
mod tests {
    //! Unit tests for the checkpoint configuration, segment bookkeeping,
    //! memory estimators, and checkpoint policies defined in this file.

    use super::*;
    use crate::autograd::scale;

    // ---- CheckpointConfig ----

    #[test]
    fn test_checkpoint_config_enabled() {
        let config = CheckpointConfig::enabled(4);
        assert!(config.enabled);
        assert_eq!(config.num_segments, 4);
        assert!(!config.selective);
    }

    #[test]
    fn test_checkpoint_config_disabled() {
        let config = CheckpointConfig::disabled();
        assert!(!config.enabled);
    }

    #[test]
    fn test_checkpoint_config_default() {
        let config = CheckpointConfig::default();
        assert!(!config.enabled);
    }

    #[test]
    fn test_checkpoint_config_selective() {
        let config = CheckpointConfig::enabled(4).with_selective();
        assert!(config.selective);
    }

    // ---- CheckpointedSegment ----

    #[test]
    fn test_checkpointed_segment_new() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let segment = CheckpointedSegment::new(input, true);
        assert!(segment.is_checkpointed());
        assert!(segment.output().is_none());
    }

    #[test]
    fn test_checkpointed_segment_output() {
        let input = Tensor::from_vec(vec![1.0, 2.0], true);
        let segment = CheckpointedSegment::new(input, true);

        let output = Tensor::from_vec(vec![2.0, 4.0], true);
        segment.set_output(output.clone());

        assert!(segment.output().is_some());
        assert_eq!(segment.output().expect("operation should succeed").len(), 2);
    }

    #[test]
    fn test_checkpointed_segment_clear() {
        let input = Tensor::from_vec(vec![1.0], true);
        let segment = CheckpointedSegment::new(input, true);
        segment.set_output(Tensor::from_vec(vec![2.0], true));

        segment.clear_output();
        assert!(segment.output().is_none());
    }

    // ---- CheckpointManager ----

    #[test]
    fn test_checkpoint_manager_new() {
        let config = CheckpointConfig::enabled(4);
        let manager = CheckpointManager::new(config);
        assert!(manager.is_enabled());
        assert_eq!(manager.num_segments(), 4);
    }

    #[test]
    fn test_checkpoint_manager_disabled() {
        let config = CheckpointConfig::disabled();
        let manager = CheckpointManager::new(config);
        assert!(!manager.is_enabled());
    }

    #[test]
    fn test_checkpoint_manager_register() {
        let config = CheckpointConfig::enabled(2);
        let mut manager = CheckpointManager::new(config);

        let input1 = Tensor::from_vec(vec![1.0], true);
        let input2 = Tensor::from_vec(vec![2.0], true);

        let seg1 = manager.register_segment(input1);
        let seg2 = manager.register_segment(input2);

        assert_eq!(manager.total_segments(), 2);
        assert_eq!(seg1.input().len(), 1);
        assert_eq!(seg2.input().len(), 1);
    }

    #[test]
    fn test_checkpoint_manager_clear() {
        let config = CheckpointConfig::enabled(2);
        let mut manager = CheckpointManager::new(config);

        manager.register_segment(Tensor::from_vec(vec![1.0], true));
        manager.register_segment(Tensor::from_vec(vec![2.0], true));

        manager.clear();
        assert_eq!(manager.total_segments(), 0);
    }

    // ---- checkpoint / checkpoint_if ----

    #[test]
    fn test_checkpoint_function() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let output = checkpoint(|x| scale(x, 2.0), &input);
        assert_eq!(output.len(), 3);
        assert_eq!(output.data()[0], 2.0);
    }

    #[test]
    fn test_checkpoint_if_enabled() {
        let input = Tensor::from_vec(vec![1.0, 2.0], true);
        let output = checkpoint_if(|x| scale(x, 3.0), &input, true);
        assert_eq!(output.data()[0], 3.0);
    }

    #[test]
    fn test_checkpoint_if_disabled() {
        let input = Tensor::from_vec(vec![1.0, 2.0], true);
        let output = checkpoint_if(|x| scale(x, 3.0), &input, false);
        assert_eq!(output.data()[0], 3.0);
    }

    // ---- Memory estimators ----

    #[test]
    fn test_estimate_memory_savings() {
        let (without, with) = estimate_memory_savings(32, 4096, 512, 1, 6);

        assert!(with < without);

        assert_eq!(without, 32 * 512 * 4096 * 4);
    }

    #[test]
    fn test_optimal_checkpoints() {
        assert_eq!(optimal_checkpoints(1), 1);
        assert_eq!(optimal_checkpoints(4), 2);
        assert_eq!(optimal_checkpoints(16), 4);
        assert_eq!(optimal_checkpoints(32), 6);
        assert_eq!(optimal_checkpoints(64), 8);
    }

    #[test]
    fn test_memory_savings_formula() {
        let num_layers = 32;
        let checkpoints = optimal_checkpoints(num_layers);

        let (without, with) = estimate_memory_savings(num_layers, 1024, 128, 1, checkpoints);

        let ratio = without as f64 / with as f64;
        assert!(ratio > 4.0);
    }

    #[test]
    fn test_checkpoint_preserves_computation() {
        let input = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], true);

        let direct = scale(&input, 2.5);

        let checkpointed = checkpoint(|x| scale(x, 2.5), &input);

        for i in 0..4 {
            assert_eq!(direct.data()[i], checkpointed.data()[i]);
        }
    }

    #[test]
    fn test_nested_checkpoints() {
        let input = Tensor::from_vec(vec![1.0, 2.0], true);

        let output = checkpoint(
            |x| {
                let h1 = scale(x, 2.0);
                checkpoint(|y| scale(y, 3.0), &h1)
            },
            &input,
        );

        assert_eq!(output.data()[0], 6.0);
    }

    #[test]
    fn test_checkpoint_manager_memory_tracking() {
        let config = CheckpointConfig::enabled(2);
        let mut manager = CheckpointManager::new(config);

        for i in 0..4 {
            manager.register_segment(Tensor::from_vec(vec![i as f32], true));
        }

        assert!(manager.memory_saved_segments() > 0);
    }

    // ---- Policies ----

    /// Shorthand for an `OperationInfo` with default batch flag and layer.
    fn make_op(op_type: OpType, bytes: usize) -> OperationInfo {
        OperationInfo::new(op_type, bytes)
    }

    #[test]
    fn test_operation_info_builder() {
        let info =
            OperationInfo::new(OpType::Matmul, 1024).with_batch_dim(true).with_layer_index(5);
        assert_eq!(info.op_type, OpType::Matmul);
        assert_eq!(info.output_bytes, 1024);
        assert!(info.has_batch_dim);
        assert_eq!(info.layer_index, 5);
    }

    #[test]
    fn test_save_all_policy() {
        let policy = SaveAll;
        assert!(policy.should_save(&make_op(OpType::Add, 100)));
        assert!(policy.should_save(&make_op(OpType::Matmul, 10000)));
        assert!(policy.should_save(&make_op(OpType::Relu, 50)));
    }

    #[test]
    fn test_save_nothing_policy() {
        let policy = SaveNothing;
        assert!(!policy.should_save(&make_op(OpType::Add, 100)));
        assert!(!policy.should_save(&make_op(OpType::Matmul, 10000)));
        assert!(!policy.should_save(&make_op(OpType::Relu, 50)));
    }

    #[test]
    fn test_save_matmuls_policy() {
        let policy = SaveMatmuls;
        assert!(policy.should_save(&make_op(OpType::Matmul, 1000)));
        assert!(policy.should_save(&make_op(OpType::Attention, 2000)));
        assert!(!policy.should_save(&make_op(OpType::Add, 100)));
        assert!(!policy.should_save(&make_op(OpType::Relu, 50)));
        assert!(!policy.should_save(&make_op(OpType::Softmax, 100)));
    }

    #[test]
    fn test_save_matmuls_recompute_cost() {
        let policy = SaveMatmuls;
        assert!((policy.recompute_cost(&make_op(OpType::Matmul, 0)) - 100.0).abs() < f64::EPSILON);
        assert!(
            (policy.recompute_cost(&make_op(OpType::Attention, 0)) - 150.0).abs() < f64::EPSILON
        );
        assert!((policy.recompute_cost(&make_op(OpType::Add, 0)) - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_save_unbatched_matmuls_policy() {
        let policy = SaveUnbatchedMatmuls;

        let unbatched = OperationInfo::new(OpType::Matmul, 1000).with_batch_dim(false);
        assert!(policy.should_save(&unbatched));

        let batched = OperationInfo::new(OpType::Matmul, 1000).with_batch_dim(true);
        assert!(!policy.should_save(&batched));

        let add = OperationInfo::new(OpType::Add, 100).with_batch_dim(false);
        assert!(!policy.should_save(&add));
    }

    #[test]
    fn test_binomial_checkpointing_indices() {
        let policy = BinomialCheckpointing { num_layers: 16 };
        let indices = policy.checkpoint_indices();

        assert_eq!(indices, vec![0, 4, 8, 12]);
    }

    #[test]
    fn test_binomial_checkpointing_policy() {
        let policy = BinomialCheckpointing { num_layers: 16 };

        let at_checkpoint = OperationInfo::new(OpType::Add, 100).with_layer_index(0);
        assert!(policy.should_save(&at_checkpoint));

        let not_at_checkpoint = OperationInfo::new(OpType::Add, 100).with_layer_index(1);
        assert!(!policy.should_save(&not_at_checkpoint));

        let at_checkpoint_4 = OperationInfo::new(OpType::Add, 100).with_layer_index(4);
        assert!(policy.should_save(&at_checkpoint_4));
    }

    #[test]
    fn test_memory_budget_policy() {
        let policy = MemoryBudget::new(500);

        let op1 = make_op(OpType::Matmul, 200);
        assert!(policy.should_save(&op1));
        assert_eq!(policy.used_bytes(), 200);

        let op2 = make_op(OpType::Add, 200);
        assert!(policy.should_save(&op2));
        assert_eq!(policy.used_bytes(), 400);

        // A third 200-byte output would exceed the 500-byte budget.
        let op3 = make_op(OpType::Relu, 200);
        assert!(!policy.should_save(&op3));
        assert_eq!(policy.used_bytes(), 400);

        policy.reset();
        assert_eq!(policy.used_bytes(), 0);
        assert!(policy.should_save(&op3));
    }

    #[test]
    fn test_custom_policy() {
        let policy = CustomPolicy::new(|op: &OperationInfo| op.output_bytes > 500);

        assert!(!policy.should_save(&make_op(OpType::Add, 100)));
        assert!(policy.should_save(&make_op(OpType::Matmul, 1000)));
        assert!(!policy.should_save(&make_op(OpType::Relu, 500)));
        assert!(policy.should_save(&make_op(OpType::Softmax, 501)));
    }

    // ---- PolicyCheckpointManager ----

    #[test]
    fn test_policy_checkpoint_manager_basic() {
        let mut manager = PolicyCheckpointManager::new(4);
        let policy = SaveAll;

        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], true);
        let info = make_op(OpType::Matmul, 12);

        manager.record(0, &tensor, &info, &policy);
        assert!(manager.is_saved(0));
        assert!(!manager.is_saved(1));
        assert_eq!(manager.num_saved(), 1);
        assert_eq!(manager.total_bytes(), 12);

        let saved = manager.get(0).expect("key should exist");
        assert_eq!(saved.len(), 3);
    }

    #[test]
    fn test_policy_checkpoint_manager_selective() {
        let mut manager = PolicyCheckpointManager::new(4);
        let policy = SaveMatmuls;

        let t1 = Tensor::from_vec(vec![1.0], true);
        let t2 = Tensor::from_vec(vec![2.0], true);

        manager.record(0, &t1, &make_op(OpType::Matmul, 4), &policy);
        manager.record(1, &t2, &make_op(OpType::Add, 4), &policy);

        assert!(manager.is_saved(0));
        assert!(!manager.is_saved(1));
        assert_eq!(manager.num_saved(), 1);
    }

    #[test]
    fn test_policy_checkpoint_manager_clear() {
        let mut manager = PolicyCheckpointManager::new(2);
        let policy = SaveAll;

        let t = Tensor::from_vec(vec![1.0], true);
        manager.record(0, &t, &make_op(OpType::Add, 4), &policy);

        manager.clear();
        assert_eq!(manager.num_saved(), 0);
        assert_eq!(manager.total_bytes(), 0);
        assert!(!manager.is_saved(0));
    }

    #[test]
    fn test_policy_checkpoint_manager_out_of_bounds() {
        let mut manager = PolicyCheckpointManager::new(2);
        let policy = SaveAll;

        // Index 5 is past the two available slots, so nothing is stored.
        let t = Tensor::from_vec(vec![1.0], true);
        manager.record(5, &t, &make_op(OpType::Add, 4), &policy);
        assert_eq!(manager.num_saved(), 0);
    }

    // ---- estimate_policy_tradeoff ----

    #[test]
    fn test_estimate_policy_tradeoff_save_all() {
        let policy = SaveAll;
        let infos = vec![
            make_op(OpType::Matmul, 1000),
            make_op(OpType::Add, 200),
            make_op(OpType::Relu, 200),
        ];

        let (saved, used, overhead) = estimate_policy_tradeoff(&policy, &infos);
        assert_eq!(saved, 0);
        assert_eq!(used, 1400);
        assert!((overhead - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_estimate_policy_tradeoff_save_nothing() {
        let policy = SaveNothing;
        let infos = vec![make_op(OpType::Matmul, 1000), make_op(OpType::Add, 200)];

        let (saved, used, overhead) = estimate_policy_tradeoff(&policy, &infos);
        assert_eq!(saved, 1200);
        assert_eq!(used, 0);
        assert!(overhead > 0.0);
    }

    #[test]
    fn test_estimate_policy_tradeoff_save_matmuls() {
        let policy = SaveMatmuls;
        let infos = vec![
            make_op(OpType::Matmul, 1000),
            make_op(OpType::Add, 200),
            make_op(OpType::Relu, 200),
        ];

        let (saved, used, overhead) = estimate_policy_tradeoff(&policy, &infos);
        assert_eq!(used, 1000);
        assert_eq!(saved, 400);
        assert!(overhead > 0.0);
    }

    #[test]
    fn test_policy_checkpoint_manager_num_layers() {
        let manager = PolicyCheckpointManager::new(8);
        assert_eq!(manager.num_layers(), 8);
    }

    #[test]
    fn test_binomial_single_layer() {
        let policy = BinomialCheckpointing { num_layers: 1 };
        let indices = policy.checkpoint_indices();
        assert_eq!(indices, vec![0]);
    }

    #[test]
    fn test_default_recompute_cost() {
        let policy = SaveAll;
        let info = make_op(OpType::Add, 100);
        assert!((policy.recompute_cost(&info) - 1.0).abs() < f64::EPSILON);
    }
}