use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use std::ops::{AddAssign, MulAssign, SubAssign};

/// Trait for optimizers that update parameters in place, avoiding the
/// allocation of new parameter arrays on every step.
pub trait InPlaceOptimizer<A: Float + ScalarOperand + Debug, D: Dimension> {
    /// Performs a single optimization step, updating `params` in place.
    fn step_inplace(&mut self, params: &mut Array<A, D>, gradients: &Array<A, D>) -> Result<()>;

    /// Updates a list of parameter arrays in place, pairing each with the
    /// corresponding gradient array.
    fn step_list_inplace(
        &mut self,
        params_list: &mut [&mut Array<A, D>],
        gradients_list: &[&Array<A, D>],
    ) -> Result<()> {
        if params_list.len() != gradients_list.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Number of parameter arrays ({}) does not match number of gradient arrays ({})",
                params_list.len(),
                gradients_list.len()
            )));
        }

        for (params, grads) in params_list.iter_mut().zip(gradients_list.iter()) {
            self.step_inplace(params, grads)?;
        }
        Ok(())
    }
}

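// A minimal sketch of implementing `InPlaceOptimizer` for a custom update
// rule (signed gradient descent here; `SignSGD` and its step size are
// illustrative assumptions, not part of this module's API):
#[cfg(test)]
mod custom_optimizer_example {
    use super::*;
    use scirs2_core::ndarray::Array1;

    struct SignSGD {
        step_size: f64,
    }

    impl<D: Dimension> InPlaceOptimizer<f64, D> for SignSGD {
        fn step_inplace(
            &mut self,
            params: &mut Array<f64, D>,
            gradients: &Array<f64, D>,
        ) -> Result<()> {
            // Move each parameter a fixed step against the gradient's sign.
            params.zip_mut_with(gradients, |p, &g| *p -= self.step_size * g.signum());
            Ok(())
        }
    }

    #[test]
    fn sign_sgd_steps_by_sign_only() {
        let mut opt = SignSGD { step_size: 0.01 };
        let mut params = Array1::from_vec(vec![1.0, -1.0]);
        let grads = Array1::from_vec(vec![0.3, -0.7]);
        opt.step_inplace(&mut params, &grads).unwrap();
        assert_eq!(params.as_slice().unwrap(), &[0.99, -0.99]);
    }
}
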
/// In-place stochastic gradient descent optimizer.
#[derive(Debug, Clone)]
pub struct InPlaceSGD<A: Float> {
    learning_rate: A,
    /// Momentum coefficient. Note: `step_inplace` is stateless and does not
    /// apply momentum; use `fused::fused_sgd_update` with an explicit
    /// momentum buffer when momentum is required.
    momentum: A,
    weight_decay: A,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> InPlaceSGD<A> {
    /// Creates a new SGD optimizer with the given learning rate.
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            momentum: A::zero(),
            weight_decay: A::zero(),
        }
    }

    /// Sets the momentum coefficient (see the note on the `momentum` field).
    pub fn with_momentum(mut self, momentum: A) -> Self {
        self.momentum = momentum;
        self
    }

    /// Sets the L2 weight-decay coefficient.
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.weight_decay = weight_decay;
        self
    }
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> InPlaceOptimizer<A, D>
    for InPlaceSGD<A>
{
    fn step_inplace(&mut self, params: &mut Array<A, D>, gradients: &Array<A, D>) -> Result<()> {
        if self.weight_decay > A::zero() {
            // Fold L2 weight decay into the gradient: g' = g + wd * p.
            params.zip_mut_with(gradients, |p, &g| {
                *p = *p - self.learning_rate * (g + *p * self.weight_decay);
            });
        } else {
            params.zip_mut_with(gradients, |p, &g| {
                *p = *p - self.learning_rate * g;
            });
        }
        Ok(())
    }
}

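// A minimal usage sketch for `InPlaceSGD` (the learning rate, weight decay,
// and data below are illustrative assumptions, not library defaults):
#[cfg(test)]
mod inplace_sgd_example {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn sgd_step_updates_params_in_place() {
        let mut opt = InPlaceSGD::new(0.1).with_weight_decay(0.01);
        let mut params = Array1::from_vec(vec![1.0_f64, -2.0]);
        let grads = Array1::from_vec(vec![0.5, 0.5]);
        // p <- p - lr * (g + wd * p)
        opt.step_inplace(&mut params, &grads).unwrap();
        assert!((params[0] - (1.0 - 0.1 * (0.5 + 0.01 * 1.0))).abs() < 1e-12);
    }
}
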
/// In-place Adam optimizer that reuses its first- and second-moment
/// buffers across steps instead of reallocating them.
#[derive(Debug)]
pub struct InPlaceAdam<A: Float, D: Dimension> {
    learning_rate: A,
    beta1: A,
    beta2: A,
    epsilon: A,
    weight_decay: A,
    /// Step counter used for bias correction.
    t: i32,
    /// First-moment (mean) estimate, lazily initialized on the first step.
    m: Option<Array<A, D>>,
    /// Second-moment (uncentered variance) estimate, lazily initialized.
    v: Option<Array<A, D>>,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> InPlaceAdam<A, D> {
    /// Creates a new Adam optimizer with the given learning rate and the
    /// standard defaults (beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8).
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            beta1: A::from(0.9).unwrap(),
            beta2: A::from(0.999).unwrap(),
            epsilon: A::from(1e-8).unwrap(),
            weight_decay: A::zero(),
            t: 0,
            m: None,
            v: None,
        }
    }

    /// Sets the exponential decay rate for the first-moment estimate.
    pub fn with_beta1(mut self, beta1: A) -> Self {
        self.beta1 = beta1;
        self
    }

    /// Sets the exponential decay rate for the second-moment estimate.
    pub fn with_beta2(mut self, beta2: A) -> Self {
        self.beta2 = beta2;
        self
    }

    /// Sets the L2 weight-decay coefficient.
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Sets the numerical-stability epsilon.
    pub fn with_epsilon(mut self, epsilon: A) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Resets the step counter and moment buffers.
    pub fn reset(&mut self) {
        self.t = 0;
        self.m = None;
        self.v = None;
    }
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> InPlaceOptimizer<A, D>
    for InPlaceAdam<A, D>
{
    fn step_inplace(&mut self, params: &mut Array<A, D>, gradients: &Array<A, D>) -> Result<()> {
        self.t += 1;

        // Lazily allocate the moment buffers on the first step.
        if self.m.is_none() {
            self.m = Some(Array::zeros(params.raw_dim()));
        }
        if self.v.is_none() {
            self.v = Some(Array::zeros(params.raw_dim()));
        }

        let m = self.m.as_mut().unwrap();
        let v = self.v.as_mut().unwrap();

        // Fold L2 weight decay into the gradient (classic L2 regularization).
        let grad_with_decay = if self.weight_decay > A::zero() {
            let mut temp = gradients.clone();
            temp.zip_mut_with(params, |g, &p| {
                *g = *g + p * self.weight_decay;
            });
            temp
        } else {
            gradients.clone()
        };

        // Update biased first-moment estimate.
        m.zip_mut_with(&grad_with_decay, |m_i, &g| {
            *m_i = self.beta1 * *m_i + (A::one() - self.beta1) * g;
        });

        // Update biased second-moment estimate.
        v.zip_mut_with(&grad_with_decay, |v_i, &g| {
            *v_i = self.beta2 * *v_i + (A::one() - self.beta2) * g * g;
        });

        // Bias-correction denominators for step t.
        let bias1 = A::one() - self.beta1.powi(self.t);
        let bias2 = A::one() - self.beta2.powi(self.t);

        for ((p, &m_i), &v_i) in params.iter_mut().zip(m.iter()).zip(v.iter()) {
            let m_hat = m_i / bias1;
            let v_hat = v_i / bias2;
            *p = *p - self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
        }

        Ok(())
    }
}

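// A minimal usage sketch for `InPlaceAdam` (learning rate, betas, and data
// are illustrative assumptions):
#[cfg(test)]
mod inplace_adam_example {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn adam_reuses_moment_buffers_across_steps() {
        let mut opt = InPlaceAdam::new(0.001).with_beta1(0.9).with_beta2(0.999);
        let mut params = Array1::from_vec(vec![1.0_f64, 2.0]);
        let grads = Array1::from_vec(vec![0.1, 0.1]);
        for _ in 0..10 {
            opt.step_inplace(&mut params, &grads).unwrap();
        }
        // With a constant positive gradient, parameters decrease.
        assert!(params[0] < 1.0 && params[1] < 2.0);
        // `reset` clears the step counter and the moment buffers.
        opt.reset();
    }
}
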
pub mod utils {
    use super::*;

    /// Multiplies every element of `array` by `scalar` in place.
    pub fn scale_inplace<A, D>(array: &mut Array<A, D>, scalar: A)
    where
        A: Float + ScalarOperand + MulAssign,
        D: Dimension,
    {
        array.map_inplace(|x| *x *= scalar);
    }

    /// Adds `b` to `a` element-wise in place.
    pub fn add_inplace<A, D>(a: &mut Array<A, D>, b: &Array<A, D>)
    where
        A: Float + ScalarOperand + AddAssign,
        D: Dimension,
    {
        a.zip_mut_with(b, |x, &y| *x += y);
    }

    /// Subtracts `b` from `a` element-wise in place.
    pub fn subtract_inplace<A, D>(a: &mut Array<A, D>, b: &Array<A, D>)
    where
        A: Float + ScalarOperand + SubAssign,
        D: Dimension,
    {
        a.zip_mut_with(b, |x, &y| *x -= y);
    }

    /// Applies `f` to every element of `array` in place.
    pub fn apply_inplace<A, D, F>(array: &mut Array<A, D>, f: F)
    where
        A: Float + ScalarOperand,
        D: Dimension,
        F: Fn(&mut A),
    {
        array.map_inplace(f);
    }

    /// Clamps every element of `array` to the range [`min`, `max`] in place.
    pub fn clip_inplace<A, D>(array: &mut Array<A, D>, min: A, max: A)
    where
        A: Float + ScalarOperand,
        D: Dimension,
    {
        array.map_inplace(|x| {
            if *x < min {
                *x = min;
            } else if *x > max {
                *x = max;
            }
        });
    }

    /// Rescales `array` in place to unit L2 norm (no-op for the zero array).
    pub fn normalize_inplace<A, D>(array: &mut Array<A, D>)
    where
        A: Float + ScalarOperand + MulAssign,
        D: Dimension,
    {
        let norm = array.mapv(|x| x * x).sum().sqrt();
        if norm > A::zero() {
            let inv_norm = A::one() / norm;
            array.map_inplace(|x| *x *= inv_norm);
        }
    }
}

pub mod fused {
    use super::*;

    /// Configuration for a fused Adam update. The bias-correction terms
    /// `bias1` and `bias2` are the denominators `1 - beta1^t` and
    /// `1 - beta2^t`, precomputed by the caller for the current step `t`.
    #[derive(Debug, Clone, Copy)]
    pub struct AdamConfig<A> {
        pub lr: A,
        pub beta1: A,
        pub beta2: A,
        pub epsilon: A,
        pub bias1: A,
        pub bias2: A,
        pub weight_decay: Option<A>,
    }

    /// Performs a fused Adam update: moment updates, bias correction, and
    /// the parameter step in a single pass over the arrays.
    pub fn fused_adam_update<A, D>(
        params: &mut Array<A, D>,
        gradients: &Array<A, D>,
        m: &mut Array<A, D>,
        v: &mut Array<A, D>,
        config: AdamConfig<A>,
    ) where
        A: Float + ScalarOperand,
        D: Dimension,
    {
        let one = A::one();
        let one_minus_beta1 = one - config.beta1;
        let one_minus_beta2 = one - config.beta2;

        if let Some(wd) = config.weight_decay {
            for (((p, &g), m_val), v_val) in params
                .iter_mut()
                .zip(gradients.iter())
                .zip(m.iter_mut())
                .zip(v.iter_mut())
            {
                // Fold L2 weight decay into the gradient.
                let g_with_decay = g + *p * wd;

                *m_val = config.beta1 * *m_val + one_minus_beta1 * g_with_decay;
                *v_val = config.beta2 * *v_val + one_minus_beta2 * g_with_decay * g_with_decay;

                let m_hat = *m_val / config.bias1;
                let v_hat = *v_val / config.bias2;
                *p = *p - config.lr * m_hat / (v_hat.sqrt() + config.epsilon);
            }
        } else {
            for (((p, &g), m_val), v_val) in params
                .iter_mut()
                .zip(gradients.iter())
                .zip(m.iter_mut())
                .zip(v.iter_mut())
            {
                *m_val = config.beta1 * *m_val + one_minus_beta1 * g;
                *v_val = config.beta2 * *v_val + one_minus_beta2 * g * g;

                let m_hat = *m_val / config.bias1;
                let v_hat = *v_val / config.bias2;
                *p = *p - config.lr * m_hat / (v_hat.sqrt() + config.epsilon);
            }
        }
    }

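    // A minimal sketch of driving `fused_adam_update` across steps; the
    // hyperparameters are illustrative assumptions. Note that `bias1` and
    // `bias2` must be recomputed from the step counter `t` by the caller.
    #[cfg(test)]
    mod fused_adam_example {
        use super::*;
        use scirs2_core::ndarray::Array1;

        #[test]
        fn bias_terms_are_recomputed_each_step() {
            let (lr, beta1, beta2) = (0.01_f64, 0.9, 0.999);
            let mut params = Array1::from_vec(vec![1.0_f64, 2.0]);
            let grads = Array1::from_vec(vec![0.1, 0.2]);
            let mut m = Array1::zeros(2);
            let mut v = Array1::zeros(2);

            for t in 1..=3_i32 {
                let config = AdamConfig {
                    lr,
                    beta1,
                    beta2,
                    epsilon: 1e-8,
                    bias1: 1.0 - beta1.powi(t),
                    bias2: 1.0 - beta2.powi(t),
                    weight_decay: None,
                };
                fused_adam_update(&mut params, &grads, &mut m, &mut v, config);
            }
            assert!(params[0] < 1.0);
        }
    }
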
    /// Performs a fused SGD update with optional momentum, weight decay,
    /// and dampening in a single pass over the arrays.
    pub fn fused_sgd_update<A, D>(
        params: &mut Array<A, D>,
        gradients: &Array<A, D>,
        momentum_buf: Option<&mut Array<A, D>>,
        lr: A,
        momentum: A,
        weight_decay: Option<A>,
        dampening: A,
    ) where
        A: Float + ScalarOperand,
        D: Dimension,
    {
        if let Some(buf) = momentum_buf {
            if let Some(wd) = weight_decay {
                for ((p, g), buf_val) in
                    params.iter_mut().zip(gradients.iter()).zip(buf.iter_mut())
                {
                    let g_with_decay = *g + *p * wd;
                    *buf_val = momentum * *buf_val + (A::one() - dampening) * g_with_decay;
                    *p = *p - lr * *buf_val;
                }
            } else {
                for ((p, g), buf_val) in
                    params.iter_mut().zip(gradients.iter()).zip(buf.iter_mut())
                {
                    *buf_val = momentum * *buf_val + (A::one() - dampening) * *g;
                    *p = *p - lr * *buf_val;
                }
            }
        } else if let Some(wd) = weight_decay {
            for (p, g) in params.iter_mut().zip(gradients.iter()) {
                *p = *p - lr * (*g + *p * wd);
            }
        } else {
            for (p, g) in params.iter_mut().zip(gradients.iter()) {
                *p = *p - lr * *g;
            }
        }
    }

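    // A sketch of `fused_sgd_update` with a momentum buffer (all
    // hyperparameters below are illustrative assumptions):
    #[cfg(test)]
    mod fused_sgd_example {
        use super::*;
        use scirs2_core::ndarray::Array1;

        #[test]
        fn momentum_buffer_accumulates_across_calls() {
            let mut params = Array1::from_vec(vec![1.0_f64, 2.0]);
            let grads = Array1::from_vec(vec![0.1, 0.1]);
            let mut buf = Array1::zeros(2);

            // First call: buf = g, so p -= lr * g.
            fused_sgd_update(&mut params, &grads, Some(&mut buf), 0.1, 0.9, None, 0.0);
            // Second call: buf = 0.9 * g + g, a larger effective step.
            fused_sgd_update(&mut params, &grads, Some(&mut buf), 0.1, 0.9, None, 0.0);
            assert!((buf[0] - 0.19).abs() < 1e-12);
        }
    }
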
    /// Clips gradients by value and/or rescales them to a maximum L2 norm,
    /// in a single in-place pass.
    pub fn fused_gradient_clip_normalize<A, D>(
        gradients: &mut Array<A, D>,
        max_norm: Option<A>,
        clip_value: Option<A>,
    ) where
        A: Float + ScalarOperand,
        D: Dimension,
    {
        // Element-wise value clipping.
        if let Some(clip_val) = clip_value {
            for g in gradients.iter_mut() {
                if *g > clip_val {
                    *g = clip_val;
                } else if *g < -clip_val {
                    *g = -clip_val;
                }
            }
        }

        // Global L2-norm rescaling.
        if let Some(max_norm_val) = max_norm {
            let norm_sq = gradients
                .iter()
                .map(|&x| x * x)
                .fold(A::zero(), |acc, x| acc + x);
            let norm = norm_sq.sqrt();

            if norm > max_norm_val {
                let scale = max_norm_val / norm;
                for g in gradients.iter_mut() {
                    *g = *g * scale;
                }
            }
        }
    }

    /// Applies value bounds and/or a maximum-L2-norm constraint to the
    /// parameters in place.
    pub fn fused_apply_constraints<A, D>(
        params: &mut Array<A, D>,
        l2_constraint: Option<A>,
        value_bounds: Option<(A, A)>,
    ) where
        A: Float + ScalarOperand,
        D: Dimension,
    {
        if let Some((min_val, max_val)) = value_bounds {
            for p in params.iter_mut() {
                if *p < min_val {
                    *p = min_val;
                } else if *p > max_val {
                    *p = max_val;
                }
            }
        }

        if let Some(max_norm) = l2_constraint {
            let norm_sq = params
                .iter()
                .map(|&x| x * x)
                .fold(A::zero(), |acc, x| acc + x);
            let norm = norm_sq.sqrt();

            if norm > max_norm {
                let scale = max_norm / norm;
                for p in params.iter_mut() {
                    *p = *p * scale;
                }
            }
        }
    }
}

pub mod mixed_precision {
    use super::*;

    /// Dynamic loss scaler for mixed-precision training. Grows the scale
    /// after a streak of finite-gradient steps and backs off when an
    /// overflow is detected.
    #[derive(Debug, Clone)]
    pub struct LossScaler {
        scale: f32,
        growth_factor: f32,
        backoff_factor: f32,
        growth_interval: usize,
        steps_since_update: usize,
    }

    impl LossScaler {
        /// Creates a scaler with the given initial scale and default growth
        /// (2x every 2000 clean steps) and backoff (0.5x on overflow).
        pub fn new(initial_scale: f32) -> Self {
            Self {
                scale: initial_scale,
                growth_factor: 2.0,
                backoff_factor: 0.5,
                growth_interval: 2000,
                steps_since_update: 0,
            }
        }

        /// Returns the current loss scale.
        pub fn get_scale(&self) -> f32 {
            self.scale
        }

        /// Scales a loss value before backpropagation.
        pub fn scale_loss(&self, loss: f32) -> f32 {
            loss * self.scale
        }

        /// Divides gradients by the current scale in place.
        pub fn unscale_gradients<A, D>(&self, gradients: &mut Array<A, D>)
        where
            A: Float + ScalarOperand,
            D: Dimension,
        {
            let inv_scale = A::one() / A::from(self.scale).unwrap();
            for g in gradients.iter_mut() {
                *g = *g * inv_scale;
            }
        }

        /// Updates the scale after a step: backs off if an overflow was
        /// found, otherwise grows after `growth_interval` clean steps.
        pub fn update(&mut self, found_inf: bool) {
            self.steps_since_update += 1;

            if found_inf {
                self.scale *= self.backoff_factor;
                self.steps_since_update = 0;
            } else if self.steps_since_update >= self.growth_interval {
                self.scale *= self.growth_factor;
                self.steps_since_update = 0;
            }
        }

        /// Returns `true` if any gradient element is non-finite (inf or NaN).
        pub fn check_gradients<A, D>(&self, gradients: &Array<A, D>) -> bool
        where
            A: Float + ScalarOperand,
            D: Dimension,
        {
            gradients.iter().any(|&x| !x.is_finite())
        }
    }
}

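// A sketch of the usual dynamic-loss-scaling loop (the backward pass is
// stubbed out; the gradient values fed in are illustrative assumptions):
#[cfg(test)]
mod loss_scaler_example {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn skip_step_on_overflow_and_back_off() {
        let mut scaler = mixed_precision::LossScaler::new(1024.0);
        // Pretend these are gradients of `scaler.scale_loss(loss)`.
        let mut grads = Array1::from_vec(vec![1024.0_f64, f64::INFINITY]);

        let found_inf = scaler.check_gradients(&grads);
        if !found_inf {
            scaler.unscale_gradients(&mut grads);
            // ... the optimizer step would go here ...
        }
        scaler.update(found_inf);
        // Overflow detected: the step is skipped and the scale halves.
        assert_eq!(scaler.get_scale(), 512.0);
    }
}
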
pub mod gradient_checkpointing {
    use super::*;
    use std::collections::VecDeque;

    /// Strategy deciding at which depths activations are checkpointed.
    #[derive(Debug, Clone, PartialEq)]
    pub enum CheckpointStrategy {
        /// Never checkpoint.
        None,
        /// Checkpoint every `interval` layers.
        Uniform { interval: usize },
        /// Checkpoint at depths that are exact powers of `base`.
        Logarithmic { base: f64 },
        /// Checkpoint whenever memory usage exceeds `memory_threshold`.
        MemoryAware { memory_threshold: f64 },
        /// Checkpoint according to an explicit per-depth pattern.
        Custom { pattern: Vec<bool> },
    }

    /// Stores activation checkpoints by depth and tracks their memory cost.
    #[derive(Debug)]
    pub struct GradientCheckpointer<A: Float, D: Dimension> {
        strategy: CheckpointStrategy,
        checkpoints: std::collections::HashMap<usize, Array<A, D>>,
        memory_tracker: MemoryTracker,
        current_depth: usize,
        max_depth: usize,
        enabled: bool,
    }

    impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientCheckpointer<A, D> {
        /// Creates a checkpointer with the given strategy.
        pub fn new(strategy: CheckpointStrategy) -> Self {
            Self {
                strategy,
                checkpoints: std::collections::HashMap::new(),
                memory_tracker: MemoryTracker::new(),
                current_depth: 0,
                max_depth: 0,
                enabled: true,
            }
        }

        /// Sets the maximum depth; checkpointing is inactive while this is 0.
        pub fn set_max_depth(&mut self, depth: usize) {
            self.max_depth = depth;
        }

        /// Enables or disables checkpointing globally.
        pub fn set_enabled(&mut self, enabled: bool) {
            self.enabled = enabled;
        }

        /// Returns whether an activation at `depth` should be checkpointed
        /// under the current strategy.
        pub fn should_checkpoint(&self, depth: usize) -> bool {
            if !self.enabled || self.max_depth == 0 {
                return false;
            }

            match self.strategy {
                CheckpointStrategy::None => false,
                CheckpointStrategy::Uniform { interval } => depth.is_multiple_of(interval),
                CheckpointStrategy::Logarithmic { base } => {
                    // Checkpoint exact powers of `base` (e.g. 1, 2, 4, 8 for base 2).
                    let log_depth = (depth as f64).log(base).floor() as usize;
                    depth == base.powi(log_depth as i32) as usize
                }
                CheckpointStrategy::MemoryAware { memory_threshold } => {
                    self.memory_tracker.usage_ratio() > memory_threshold
                }
                CheckpointStrategy::Custom { ref pattern } => {
                    if depth < pattern.len() {
                        pattern[depth]
                    } else {
                        false
                    }
                }
            }
        }

        /// Stores `activation` as a checkpoint if the strategy selects `depth`.
        pub fn store_checkpoint(&mut self, depth: usize, activation: Array<A, D>) {
            if self.should_checkpoint(depth) {
                let memory_size = activation.len() * std::mem::size_of::<A>();
                self.memory_tracker.add_allocation(memory_size);
                self.checkpoints.insert(depth, activation);
            }
        }

        /// Returns the checkpoint stored at `depth`, if any.
        pub fn get_checkpoint(&self, depth: usize) -> Option<&Array<A, D>> {
            self.checkpoints.get(&depth)
        }

        /// Removes and returns the checkpoint at `depth`, updating the
        /// memory tracker.
        pub fn remove_checkpoint(&mut self, depth: usize) -> Option<Array<A, D>> {
            if let Some(checkpoint) = self.checkpoints.remove(&depth) {
                let memory_size = checkpoint.len() * std::mem::size_of::<A>();
                self.memory_tracker.remove_allocation(memory_size);
                Some(checkpoint)
            } else {
                None
            }
        }

        /// Removes all checkpoints and resets the memory tracker.
        pub fn clear_checkpoints(&mut self) {
            self.checkpoints.clear();
            self.memory_tracker.reset();
        }

        /// Returns a snapshot of the tracked memory usage.
        pub fn memory_usage(&self) -> MemoryUsage {
            self.memory_tracker.usage()
        }

        /// Adapts the checkpointing strategy based on current memory usage
        /// relative to `target_memory_usage`.
        pub fn optimize_strategy(&mut self, target_memory_usage: f64) {
            let current_usage = self.memory_tracker.usage_ratio();

            if current_usage > target_memory_usage {
                self.strategy = match &self.strategy {
                    CheckpointStrategy::Uniform { interval } => CheckpointStrategy::Uniform {
                        interval: (interval / 2).max(1),
                    },
                    CheckpointStrategy::MemoryAware { .. } => CheckpointStrategy::MemoryAware {
                        memory_threshold: target_memory_usage * 0.8,
                    },
                    other => other.clone(),
                };
            } else if current_usage < target_memory_usage * 0.5 {
                self.strategy = match &self.strategy {
                    CheckpointStrategy::Uniform { interval } => CheckpointStrategy::Uniform {
                        interval: interval * 2,
                    },
                    CheckpointStrategy::MemoryAware { .. } => CheckpointStrategy::MemoryAware {
                        memory_threshold: target_memory_usage * 1.2,
                    },
                    other => other.clone(),
                };
            }
        }

        /// Runs `forward_fn` on `input`, checkpointing the returned
        /// activation when the strategy selects `depth`. Returns the output
        /// and the checkpointed activation, if one was stored.
        pub fn checkpointed_forward<F, Output>(
            &mut self,
            depth: usize,
            input: &Array<A, D>,
            forward_fn: F,
        ) -> Result<(Output, Option<Array<A, D>>)>
        where
            F: FnOnce(&Array<A, D>) -> Result<(Output, Array<A, D>)>,
        {
            self.current_depth = depth;

            let (output, activation) = forward_fn(input)?;

            let checkpoint = if self.should_checkpoint(depth) {
                self.store_checkpoint(depth, activation.clone());
                Some(activation)
            } else {
                None
            };

            Ok((output, checkpoint))
        }

        /// Recomputes the activation at `target_depth` by replaying
        /// `recompute_fn` from the nearest checkpoint at or below `start_depth`.
        pub fn recompute_from_checkpoint<F>(
            &self,
            start_depth: usize,
            target_depth: usize,
            recompute_fn: F,
        ) -> Result<Array<A, D>>
        where
            F: Fn(usize, &Array<A, D>) -> Result<Array<A, D>>,
        {
            // Find the nearest checkpoint at or below `start_depth`.
            let checkpoint_depth = (0..=start_depth)
                .rev()
                .find(|&d| self.checkpoints.contains_key(&d))
                .ok_or_else(|| {
                    OptimError::InvalidConfig("No checkpoint found for recomputation".to_string())
                })?;

            let mut current_activation = self.checkpoints[&checkpoint_depth].clone();

            for depth in (checkpoint_depth + 1)..=target_depth {
                current_activation = recompute_fn(depth, &current_activation)?;
            }

            Ok(current_activation)
        }
    }

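    // A sketch of the checkpoint/recompute round trip (the doubling layer
    // is a stand-in; any `Fn(usize, &Array) -> Result<Array>` works):
    #[cfg(test)]
    mod checkpointer_example {
        use super::*;
        use scirs2_core::ndarray::Array1;

        #[test]
        fn recompute_replays_layers_from_nearest_checkpoint() {
            let mut cp: GradientCheckpointer<f64, scirs2_core::ndarray::Ix1> =
                GradientCheckpointer::new(CheckpointStrategy::Uniform { interval: 2 });
            cp.set_max_depth(8);

            // Checkpoint the activation produced at depth 2.
            cp.store_checkpoint(2, Array1::from_vec(vec![1.0, 2.0]));

            // Each "layer" doubles its input in this stand-in model.
            let layer = |_depth: usize, x: &Array1<f64>| -> Result<Array1<f64>> {
                Ok(x.mapv(|v| v * 2.0))
            };

            // Replays depths 3 and 4 starting from the checkpoint at depth 2.
            let act = cp.recompute_from_checkpoint(3, 4, layer).unwrap();
            assert_eq!(act.as_slice().unwrap(), &[4.0, 8.0]);
        }
    }
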
    /// Tracks bytes allocated for checkpoints against an estimate of total
    /// system memory.
    #[derive(Debug, Clone)]
    pub struct MemoryTracker {
        allocated_bytes: usize,
        peak_bytes: usize,
        total_system_memory: usize,
    }

    impl Default for MemoryTracker {
        fn default() -> Self {
            Self::new()
        }
    }

    impl MemoryTracker {
        pub fn new() -> Self {
            Self {
                allocated_bytes: 0,
                peak_bytes: 0,
                total_system_memory: Self::estimate_system_memory(),
            }
        }

        /// Records an allocation of `bytes` and updates the peak.
        pub fn add_allocation(&mut self, bytes: usize) {
            self.allocated_bytes += bytes;
            self.peak_bytes = self.peak_bytes.max(self.allocated_bytes);
        }

        /// Records a deallocation of `bytes` (saturating at zero).
        pub fn remove_allocation(&mut self, bytes: usize) {
            self.allocated_bytes = self.allocated_bytes.saturating_sub(bytes);
        }

        /// Returns a snapshot of current, peak, and total memory.
        pub fn usage(&self) -> MemoryUsage {
            MemoryUsage {
                current_bytes: self.allocated_bytes,
                peak_bytes: self.peak_bytes,
                total_system_bytes: self.total_system_memory,
            }
        }

        /// Returns allocated bytes as a fraction of total system memory.
        pub fn usage_ratio(&self) -> f64 {
            if self.total_system_memory == 0 {
                0.0
            } else {
                self.allocated_bytes as f64 / self.total_system_memory as f64
            }
        }

        /// Resets the current and peak counters.
        pub fn reset(&mut self) {
            self.allocated_bytes = 0;
            self.peak_bytes = 0;
        }

        /// Returns a fixed 8 GiB estimate; querying the actual system
        /// memory is not implemented here.
        fn estimate_system_memory() -> usize {
            8 * 1024 * 1024 * 1024
        }
    }

    #[derive(Debug, Clone, Copy)]
    pub struct MemoryUsage {
        pub current_bytes: usize,
        pub peak_bytes: usize,
        pub total_system_bytes: usize,
    }

    impl MemoryUsage {
        /// Current usage as a fraction of total system memory.
        pub fn current_ratio(&self) -> f64 {
            if self.total_system_bytes == 0 {
                0.0
            } else {
                self.current_bytes as f64 / self.total_system_bytes as f64
            }
        }

        /// Peak usage as a fraction of total system memory.
        pub fn peak_ratio(&self) -> f64 {
            if self.total_system_bytes == 0 {
                0.0
            } else {
                self.peak_bytes as f64 / self.total_system_bytes as f64
            }
        }

        /// Formats the usage as a human-readable string in MB.
        pub fn format(&self) -> String {
            format!(
                "Current: {:.1} MB ({:.1}%), Peak: {:.1} MB ({:.1}%), Total: {:.1} MB",
                self.current_bytes as f64 / (1024.0 * 1024.0),
                self.current_ratio() * 100.0,
                self.peak_bytes as f64 / (1024.0 * 1024.0),
                self.peak_ratio() * 100.0,
                self.total_system_bytes as f64 / (1024.0 * 1024.0)
            )
        }
    }

    /// Wraps a `GradientCheckpointer` and periodically adapts its strategy
    /// toward a target memory-usage ratio.
    #[derive(Debug)]
    pub struct AutoCheckpointer<A: Float, D: Dimension> {
        checkpointer: GradientCheckpointer<A, D>,
        /// Rolling window of recent memory-usage ratios.
        memory_history: VecDeque<f64>,
        target_memory_ratio: f64,
        /// Adapt the strategy every this many steps.
        adaptation_frequency: usize,
        step_count: usize,
    }

    impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> AutoCheckpointer<A, D> {
        /// Creates an auto checkpointer; the target ratio is clamped to [0.1, 0.9].
        pub fn new(initial_strategy: CheckpointStrategy, target_memory_ratio: f64) -> Self {
            Self {
                checkpointer: GradientCheckpointer::new(initial_strategy),
                memory_history: VecDeque::with_capacity(100),
                target_memory_ratio: target_memory_ratio.clamp(0.1, 0.9),
                adaptation_frequency: 10,
                step_count: 0,
            }
        }

        /// Sets how often (in steps) the strategy is re-evaluated.
        pub fn with_adaptation_frequency(mut self, frequency: usize) -> Self {
            self.adaptation_frequency = frequency.max(1);
            self
        }

        /// Runs a checkpointed forward step, records memory usage, and
        /// periodically adapts the strategy.
        pub fn auto_step<F, Output>(
            &mut self,
            depth: usize,
            input: &Array<A, D>,
            forward_fn: F,
        ) -> Result<(Output, Option<Array<A, D>>)>
        where
            F: FnOnce(&Array<A, D>) -> Result<(Output, Array<A, D>)>,
        {
            self.step_count += 1;

            let result = self
                .checkpointer
                .checkpointed_forward(depth, input, forward_fn)?;

            // Record memory usage in a bounded history window.
            let current_usage = self.checkpointer.memory_usage().current_ratio();
            self.memory_history.push_back(current_usage);
            if self.memory_history.len() > 100 {
                self.memory_history.pop_front();
            }

            if self.step_count.is_multiple_of(self.adaptation_frequency) {
                self.adapt_strategy();
            }

            Ok(result)
        }

        fn adapt_strategy(&mut self) {
            if self.memory_history.len() < 5 {
                return;
            }

            // Average over the most recent (up to 10) measurements.
            let window = self.memory_history.len().min(10);
            let recent_avg =
                self.memory_history.iter().rev().take(window).sum::<f64>() / window as f64;

            // Only re-optimize when usage deviates noticeably from the target.
            let deviation = (recent_avg - self.target_memory_ratio).abs();
            if deviation > 0.1 {
                self.checkpointer.optimize_strategy(self.target_memory_ratio);
            }
        }

        /// Returns the underlying checkpointer.
        pub fn checkpointer(&self) -> &GradientCheckpointer<A, D> {
            &self.checkpointer
        }

        /// Returns the underlying checkpointer mutably.
        pub fn checkpointer_mut(&mut self) -> &mut GradientCheckpointer<A, D> {
            &mut self.checkpointer
        }

        /// Returns aggregate memory statistics for the run so far.
        pub fn get_memory_stats(&self) -> MemoryStats {
            let usage = self.checkpointer.memory_usage();
            let avg_usage = if self.memory_history.is_empty() {
                0.0
            } else {
                self.memory_history.iter().sum::<f64>() / self.memory_history.len() as f64
            };

            MemoryStats {
                current_usage: usage.current_ratio(),
                peak_usage: usage.peak_ratio(),
                average_usage: avg_usage,
                target_usage: self.target_memory_ratio,
                checkpoints_stored: self.checkpointer.checkpoints.len(),
            }
        }
    }

    #[derive(Debug, Clone, Copy)]
    pub struct MemoryStats {
        pub current_usage: f64,
        pub peak_usage: f64,
        pub average_usage: f64,
        pub target_usage: f64,
        pub checkpoints_stored: usize,
    }

    impl MemoryStats {
        /// Returns whether current usage is within `tolerance` of the target.
        pub fn is_within_target(&self, tolerance: f64) -> bool {
            (self.current_usage - self.target_usage).abs() <= tolerance
        }

        /// Ratio of the smaller of (current, target) usage to the larger;
        /// 1.0 means usage exactly matches the target.
        pub fn efficiency_score(&self) -> f64 {
            if self.current_usage <= self.target_usage {
                self.current_usage / self.target_usage
            } else {
                self.target_usage / self.current_usage
            }
        }
    }
}

pub mod adaptive {
    use super::*;

    /// Adjusts the training batch size in response to memory pressure.
    #[derive(Debug, Clone)]
    pub struct MemoryAwareBatchSizer {
        initial_batch_size: usize,
        max_batch_size: usize,
        min_batch_size: usize,
        current_batch_size: usize,
        memory_threshold: f64,
        adaptation_factor: f64,
    }

    impl MemoryAwareBatchSizer {
        /// Creates a batch sizer that can scale between a quarter and four
        /// times the initial batch size.
        pub fn new(initial_batch_size: usize) -> Self {
            Self {
                initial_batch_size,
                max_batch_size: initial_batch_size * 4,
                min_batch_size: (initial_batch_size / 4).max(1),
                current_batch_size: initial_batch_size,
                memory_threshold: 0.8,
                adaptation_factor: 1.2,
            }
        }

        /// Sets the memory-usage ratio above which the batch size shrinks.
        pub fn with_memory_threshold(mut self, threshold: f64) -> Self {
            self.memory_threshold = threshold.clamp(0.1, 0.95);
            self
        }

        /// Sets the multiplicative factor used when growing or shrinking.
        pub fn with_adaptation_factor(mut self, factor: f64) -> Self {
            self.adaptation_factor = factor.max(1.0);
            self
        }

        /// Returns the current batch size.
        pub fn current_batch_size(&self) -> usize {
            self.current_batch_size
        }

        /// Shrinks the batch size when usage exceeds the threshold and
        /// grows it when usage falls well below it.
        pub fn adapt(&mut self, memory_usage_ratio: f64) {
            if memory_usage_ratio > self.memory_threshold {
                let new_size = (self.current_batch_size as f64 / self.adaptation_factor) as usize;
                self.current_batch_size = new_size.max(self.min_batch_size);
            } else if memory_usage_ratio < self.memory_threshold * 0.7 {
                let new_size = (self.current_batch_size as f64 * self.adaptation_factor) as usize;
                self.current_batch_size = new_size.min(self.max_batch_size);
            }
        }

        /// Restores the initial batch size.
        pub fn reset(&mut self) {
            self.current_batch_size = self.initial_batch_size;
        }
    }

    /// Estimates the total memory, in bytes, held by the given arrays.
    pub fn estimate_memory_usage<A, D>(arrays: &[&Array<A, D>]) -> usize
    where
        A: Sized,
        D: Dimension,
    {
        arrays
            .iter()
            .map(|arr| arr.len() * std::mem::size_of::<A>())
            .sum()
    }

    /// Returns the current system memory-usage ratio. This is a fixed
    /// placeholder value; a real implementation would query the OS.
    pub fn get_memory_usage_ratio() -> f64 {
        0.5
    }
}

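// A sketch of wiring `MemoryAwareBatchSizer` to observed memory pressure
// (the usage ratios fed to `adapt` are illustrative assumptions):
#[cfg(test)]
mod batch_sizer_example {
    use super::*;

    #[test]
    fn batch_size_tracks_memory_pressure() {
        let mut sizer = adaptive::MemoryAwareBatchSizer::new(64).with_memory_threshold(0.8);

        // Sustained pressure above the threshold shrinks the batch...
        sizer.adapt(0.95);
        sizer.adapt(0.95);
        assert!(sizer.current_batch_size() < 64);

        // ...and plenty of headroom lets it grow back, up to 4x initial.
        for _ in 0..20 {
            sizer.adapt(0.2);
        }
        assert!(sizer.current_batch_size() <= 64 * 4);
    }
}
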
pub use utils::{
    add_inplace, apply_inplace, clip_inplace, normalize_inplace, scale_inplace, subtract_inplace,
};

pub use adaptive::*;
pub use fused::*;
pub use gradient_checkpointing::*;
pub use mixed_precision::*;

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_inplace_sgd() {
        let mut optimizer = InPlaceSGD::new(0.1);
        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        optimizer.step_inplace(&mut params, &gradients).unwrap();

        assert_relative_eq!(params[0], 0.99, epsilon = 1e-6);
        assert_relative_eq!(params[1], 1.98, epsilon = 1e-6);
        assert_relative_eq!(params[2], 2.97, epsilon = 1e-6);
    }

    #[test]
    fn test_inplace_adam() {
        let mut optimizer = InPlaceAdam::new(0.001);
        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        for _ in 0..5 {
            optimizer.step_inplace(&mut params, &gradients).unwrap();
        }

        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
        assert!(params[2] < 3.0);
    }

    #[test]
    fn test_utils_scale_inplace() {
        let mut array = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        utils::scale_inplace(&mut array, 2.0);

        assert_eq!(array.as_slice().unwrap(), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_utils_clip_inplace() {
        let mut array = Array1::from_vec(vec![0.5, 1.5, 2.5]);
        utils::clip_inplace(&mut array, 1.0, 2.0);

        assert_eq!(array.as_slice().unwrap(), &[1.0, 1.5, 2.0]);
    }

    #[test]
    fn test_memory_efficiency() {
        let mut params = Array1::from_vec(vec![1.0; 1000]);
        let gradients = Array1::from_vec(vec![0.01; 1000]);
        let params_ptr = params.as_ptr();

        let mut optimizer = InPlaceSGD::new(0.1);
        optimizer.step_inplace(&mut params, &gradients).unwrap();

        // The update must not reallocate the parameter buffer.
        assert_eq!(params_ptr, params.as_ptr());
    }

    #[test]
    fn test_fused_adam_update() {
        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let mut m = Array1::zeros(3);
        let mut v = Array1::zeros(3);

        let config = fused::AdamConfig {
            lr: 0.01,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            // Bias-correction denominators for step t = 1.
            bias1: 0.1,
            bias2: 0.001,
            weight_decay: None,
        };

        fused::fused_adam_update(&mut params, &gradients, &mut m, &mut v, config);

        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
        assert!(params[2] < 3.0);

        assert!(m[0] > 0.0);
        assert!(v[0] > 0.0);
    }

    #[test]
    fn test_fused_sgd_update() {
        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let mut momentum_buf = Array1::zeros(3);

        fused::fused_sgd_update(
            &mut params,
            &gradients,
            Some(&mut momentum_buf),
            0.1,        // learning rate
            0.9,        // momentum
            Some(0.01), // weight decay
            0.0,        // dampening
        );

        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
        assert!(params[2] < 3.0);
    }

    #[test]
    fn test_fused_gradient_clip_normalize() {
        let mut gradients = Array1::from_vec(vec![5.0, -3.0, 2.0]);

        fused::fused_gradient_clip_normalize(
            &mut gradients,
            Some(2.0), // max norm
            Some(1.0), // clip value
        );

        assert!(gradients.iter().all(|&x| x.abs() <= 1.0));

        let norm = gradients.iter().map(|&x| x * x).sum::<f64>().sqrt();
        assert!(norm <= 2.0 + 1e-6);
    }

    #[test]
    fn test_mixed_precision_loss_scaler() {
        let scaler = mixed_precision::LossScaler::new(65536.0);

        let loss = 0.5;
        let scaled_loss = scaler.scale_loss(loss);
        assert_eq!(scaled_loss, 0.5 * 65536.0);

        let mut gradients = Array1::from_vec(vec![65536.0, 131072.0]);
        scaler.unscale_gradients(&mut gradients);
        assert_relative_eq!(gradients[0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[1], 2.0, epsilon = 1e-6);

        let inf_gradients = Array1::from_vec(vec![f64::INFINITY, 1.0]);
        assert!(scaler.check_gradients(&inf_gradients));

        let finite_gradients = Array1::from_vec(vec![1.0, 2.0]);
        assert!(!scaler.check_gradients(&finite_gradients));
    }

    #[test]
    fn test_memory_aware_batch_sizer() {
        let mut sizer = adaptive::MemoryAwareBatchSizer::new(32)
            .with_memory_threshold(0.8)
            .with_adaptation_factor(1.3);

        assert_eq!(sizer.current_batch_size(), 32);

        // High memory pressure shrinks the batch size.
        sizer.adapt(0.9);
        let reduced_size = sizer.current_batch_size();
        assert!(reduced_size < 32);

        // Low memory pressure grows it back.
        sizer.adapt(0.3);
        sizer.adapt(0.3);
        assert!(sizer.current_batch_size() >= 32);

        sizer.reset();
        assert_eq!(sizer.current_batch_size(), 32);
    }

    #[test]
    fn test_memory_estimation() {
        let array1 = Array1::from_vec(vec![1.0; 100]);
        let array2 = Array1::from_vec(vec![2.0; 200]);

        let arrays = vec![&array1, &array2];
        let estimated_size = adaptive::estimate_memory_usage(&arrays);

        let expected_size = 300 * std::mem::size_of::<f64>();
        assert_eq!(estimated_size, expected_size);
    }

    #[test]
    fn test_gradient_checkpointing_uniform() {
        let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
            f64,
            scirs2_core::ndarray::Ix1,
        > = gradient_checkpointing::GradientCheckpointer::new(
            gradient_checkpointing::CheckpointStrategy::Uniform { interval: 2 },
        );
        checkpointer.set_max_depth(10);

        assert!(checkpointer.should_checkpoint(0));
        assert!(!checkpointer.should_checkpoint(1));
        assert!(checkpointer.should_checkpoint(2));
        assert!(!checkpointer.should_checkpoint(3));
        assert!(checkpointer.should_checkpoint(4));

        let activation = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        checkpointer.store_checkpoint(2, activation.clone());

        let retrieved = checkpointer.get_checkpoint(2).unwrap();
        assert_eq!(
            retrieved.as_slice().unwrap(),
            activation.as_slice().unwrap()
        );

        // Depth 1 is not selected by the strategy, so nothing is stored there.
        assert!(checkpointer.get_checkpoint(1).is_none());
    }

    #[test]
    fn test_gradient_checkpointing_logarithmic() {
        let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
            f64,
            scirs2_core::ndarray::Ix1,
        > = gradient_checkpointing::GradientCheckpointer::new(
            gradient_checkpointing::CheckpointStrategy::Logarithmic { base: 2.0 },
        );

        checkpointer.set_max_depth(10);

        // Only exact powers of two are checkpointed.
        assert!(checkpointer.should_checkpoint(1));
        assert!(checkpointer.should_checkpoint(2));
        assert!(!checkpointer.should_checkpoint(3));
        assert!(checkpointer.should_checkpoint(4));
        assert!(!checkpointer.should_checkpoint(5));
        assert!(!checkpointer.should_checkpoint(6));
        assert!(!checkpointer.should_checkpoint(7));
        assert!(checkpointer.should_checkpoint(8));
    }

    #[test]
    fn test_gradient_checkpointing_custom() {
        let pattern = vec![true, false, false, true, false];
        let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
            f64,
            scirs2_core::ndarray::Ix1,
        > = gradient_checkpointing::GradientCheckpointer::new(
            gradient_checkpointing::CheckpointStrategy::Custom { pattern },
        );

        checkpointer.set_max_depth(10);

        assert!(checkpointer.should_checkpoint(0));
        assert!(!checkpointer.should_checkpoint(1));
        assert!(!checkpointer.should_checkpoint(2));
        assert!(checkpointer.should_checkpoint(3));
        assert!(!checkpointer.should_checkpoint(4));
        // Depths past the end of the pattern are never checkpointed.
        assert!(!checkpointer.should_checkpoint(5));
    }

    #[test]
    fn test_gradient_checkpointing_memory_tracking() {
        let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
            f64,
            scirs2_core::ndarray::Ix1,
        > = gradient_checkpointing::GradientCheckpointer::new(
            gradient_checkpointing::CheckpointStrategy::Uniform { interval: 1 },
        );
        checkpointer.set_max_depth(5);

        let activation1 = Array1::from_vec(vec![1.0; 100]);
        let activation2 = Array1::from_vec(vec![2.0; 200]);

        checkpointer.store_checkpoint(0, activation1);
        let usage_after_first = checkpointer.memory_usage();
        assert!(usage_after_first.current_bytes > 0);

        checkpointer.store_checkpoint(1, activation2);
        let usage_after_second = checkpointer.memory_usage();
        assert!(usage_after_second.current_bytes > usage_after_first.current_bytes);

        checkpointer.remove_checkpoint(0);
        let usage_after_removal = checkpointer.memory_usage();
        assert!(usage_after_removal.current_bytes < usage_after_second.current_bytes);
    }

    #[test]
    fn test_checkpointed_forward() {
        let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
            f64,
            scirs2_core::ndarray::Ix1,
        > = gradient_checkpointing::GradientCheckpointer::new(
            gradient_checkpointing::CheckpointStrategy::Uniform { interval: 1 },
        );
        checkpointer.set_max_depth(5);

        let input = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let forward_fn = |x: &Array1<f64>| -> Result<(f64, Array1<f64>)> {
            let output = x.sum();
            let activation = x.mapv(|val| val * 2.0);
            Ok((output, activation))
        };

        let (output, checkpoint) = checkpointer
            .checkpointed_forward(0, &input, forward_fn)
            .unwrap();

        assert_eq!(output, 6.0);
        assert!(checkpoint.is_some());
        let checkpoint = checkpoint.unwrap();
        assert_eq!(checkpoint.as_slice().unwrap(), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_recompute_from_checkpoint() {
        let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
            f64,
            scirs2_core::ndarray::Ix1,
        > = gradient_checkpointing::GradientCheckpointer::new(
            gradient_checkpointing::CheckpointStrategy::Uniform { interval: 2 },
        );
        checkpointer.set_max_depth(10);

        let checkpoint0 = Array1::from_vec(vec![1.0, 2.0]);
        let checkpoint2 = Array1::from_vec(vec![3.0, 4.0]);

        checkpointer.store_checkpoint(0, checkpoint0);
        checkpointer.store_checkpoint(2, checkpoint2);

        let recompute_fn =
            |_depth: usize, x: &Array1<f64>| -> Result<Array1<f64>> { Ok(x.mapv(|val| val + 1.0)) };

        let result = checkpointer
            .recompute_from_checkpoint(2, 4, recompute_fn)
            .unwrap();

        assert_eq!(result.as_slice().unwrap(), &[5.0, 6.0]);
    }

    #[test]
    fn test_auto_checkpointer() {
        let mut auto_checkpointer: AutoCheckpointer<f64, scirs2_core::ndarray::Ix1> =
            gradient_checkpointing::AutoCheckpointer::new(
                gradient_checkpointing::CheckpointStrategy::Uniform { interval: 2 },
                0.6, // target memory ratio
            );

        let input = Array1::from_vec(vec![1.0, 2.0]);

        let forward_fn = |x: &Array1<f64>| -> Result<(f64, Array1<f64>)> {
            let output = x.sum();
            let activation = x.clone();
            Ok((output, activation))
        };

        for depth in 0..5 {
            let (output, _) = auto_checkpointer
                .auto_step(depth, &input, forward_fn)
                .unwrap();
            assert_eq!(output, 3.0);
        }

        let stats = auto_checkpointer.get_memory_stats();
        assert!(stats.target_usage > 0.0);
    }

    #[test]
    fn test_memory_stats() {
        let stats = gradient_checkpointing::MemoryStats {
            current_usage: 0.5,
            peak_usage: 0.7,
            average_usage: 0.6,
            target_usage: 0.6,
            checkpoints_stored: 3,
        };

        assert!(stats.is_within_target(0.1));
        assert!(!stats.is_within_target(0.01));

        let efficiency = stats.efficiency_score();
        assert!(efficiency > 0.8 && efficiency <= 1.0);
    }

    #[test]
    fn test_memory_usage_formatting() {
        let usage = gradient_checkpointing::MemoryUsage {
            current_bytes: 1024 * 1024,                 // 1 MB
            peak_bytes: 2 * 1024 * 1024,                // 2 MB
            total_system_bytes: 8 * 1024 * 1024 * 1024, // 8 GB
        };

        let formatted = usage.format();
        assert!(formatted.contains("1.0 MB"));
        assert!(formatted.contains("2.0 MB"));
        assert!(formatted.contains("8192.0 MB"));

        assert_relative_eq!(usage.current_ratio(), 1.0 / 8192.0, epsilon = 1e-6);
        assert_relative_eq!(usage.peak_ratio(), 2.0 / 8192.0, epsilon = 1e-6);
    }

    #[test]
    fn test_checkpointing_strategy_optimization() {
        let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
            f64,
            scirs2_core::ndarray::Ix1,
        > = gradient_checkpointing::GradientCheckpointer::new(
            gradient_checkpointing::CheckpointStrategy::Uniform { interval: 4 },
        );

        checkpointer.set_max_depth(10);

        let checkpoint = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        checkpointer.store_checkpoint(0, checkpoint);

        checkpointer.optimize_strategy(0.3);

        // The strategy may have changed, but some early depth is still checkpointed.
        assert!(
            checkpointer.should_checkpoint(0)
                || checkpointer.should_checkpoint(1)
                || checkpointer.should_checkpoint(2)
        );
    }

    #[test]
    fn test_checkpointing_disabled() {
        let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
            f64,
            scirs2_core::ndarray::Ix1,
        > = gradient_checkpointing::GradientCheckpointer::new(
            gradient_checkpointing::CheckpointStrategy::Uniform { interval: 1 },
        );
        checkpointer.set_enabled(false);

        assert!(!checkpointer.should_checkpoint(0));
        assert!(!checkpointer.should_checkpoint(1));
        assert!(!checkpointer.should_checkpoint(2));
    }
}