optirs_core/memory_efficient/mod.rs

//! Memory-efficient optimizers and utilities
//!
//! This module provides in-place parameter update capabilities and
//! memory-efficient implementations of optimization algorithms.

use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;
use std::ops::{AddAssign, MulAssign, SubAssign};

/// Trait for in-place parameter updates
pub trait InPlaceOptimizer<A: Float + ScalarOperand + Debug, D: Dimension> {
    /// Update parameters in-place using the given gradients
    ///
    /// This method modifies the parameters directly rather than returning new arrays,
    /// which can significantly reduce memory usage for large models.
    fn step_inplace(&mut self, params: &mut Array<A, D>, gradients: &Array<A, D>) -> Result<()>;

    /// Update multiple parameter arrays in-place
    fn step_list_inplace(
        &mut self,
        params_list: &mut [&mut Array<A, D>],
        gradients_list: &[&Array<A, D>],
    ) -> Result<()> {
        if params_list.len() != gradients_list.len() {
            return Err(OptimError::InvalidConfig(format!(
                "Number of parameter arrays ({}) does not match number of gradient arrays ({})",
                params_list.len(),
                gradients_list.len()
            )));
        }

        for (params, grads) in params_list.iter_mut().zip(gradients_list.iter()) {
            self.step_inplace(params, grads)?;
        }
        Ok(())
    }
}

/// Memory-efficient SGD optimizer with in-place updates
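///
/// A minimal usage sketch (the surrounding crate paths are assumed, so this
/// is not compiled as a doctest):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let mut opt = InPlaceSGD::new(0.1_f64).with_weight_decay(0.01);
/// let mut params = Array1::from_vec(vec![1.0, 2.0]);
/// let grads = Array1::from_vec(vec![0.1, 0.2]);
/// opt.step_inplace(&mut params, &grads).unwrap(); // params updated in place
/// ```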
#[derive(Debug, Clone)]
pub struct InPlaceSGD<A: Float> {
    learning_rate: A,
    momentum: A,
    weight_decay: A,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> InPlaceSGD<A> {
    /// Create a new in-place SGD optimizer
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            momentum: A::zero(),
            weight_decay: A::zero(),
        }
    }

    /// Set momentum
    ///
    /// Note: the momentum coefficient is stored but not yet applied by
    /// `step_inplace`; use `fused::fused_sgd_update` for momentum support.
    pub fn with_momentum(mut self, momentum: A) -> Self {
        self.momentum = momentum;
        self
    }

    /// Set weight decay
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.weight_decay = weight_decay;
        self
    }
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> InPlaceOptimizer<A, D>
    for InPlaceSGD<A>
{
    fn step_inplace(&mut self, params: &mut Array<A, D>, gradients: &Array<A, D>) -> Result<()> {
        // Apply weight decay if configured
        if self.weight_decay > A::zero() {
            params.zip_mut_with(gradients, |p, &g| {
                *p = *p - self.learning_rate * (g + *p * self.weight_decay);
            });
        } else {
            // Simple gradient descent
            params.zip_mut_with(gradients, |p, &g| {
                *p = *p - self.learning_rate * g;
            });
        }
        Ok(())
    }
}

/// Memory-efficient Adam optimizer with in-place updates
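///
/// A minimal usage sketch (crate paths assumed; not compiled as a doctest):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let mut opt = InPlaceAdam::new(0.001_f64);
/// let mut params = Array1::from_vec(vec![1.0, 2.0]);
/// let grads = Array1::from_vec(vec![0.1, 0.2]);
/// for _ in 0..5 {
///     opt.step_inplace(&mut params, &grads).unwrap();
/// }
/// ```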
#[derive(Debug)]
pub struct InPlaceAdam<A: Float, D: Dimension> {
    learning_rate: A,
    beta1: A,
    beta2: A,
    epsilon: A,
    weight_decay: A,
    t: i32,
    /// First moment estimate (momentum)
    m: Option<Array<A, D>>,
    /// Second moment estimate (uncentered variance)
    v: Option<Array<A, D>>,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> InPlaceAdam<A, D> {
    /// Create a new in-place Adam optimizer
    pub fn new(learning_rate: A) -> Self {
        Self {
            learning_rate,
            beta1: A::from(0.9).unwrap(),
            beta2: A::from(0.999).unwrap(),
            epsilon: A::from(1e-8).unwrap(),
            weight_decay: A::zero(),
            t: 0,
            m: None,
            v: None,
        }
    }

    /// Set beta1 (first moment decay)
    pub fn with_beta1(mut self, beta1: A) -> Self {
        self.beta1 = beta1;
        self
    }

    /// Set beta2 (second moment decay)
    pub fn with_beta2(mut self, beta2: A) -> Self {
        self.beta2 = beta2;
        self
    }

    /// Set weight decay
    pub fn with_weight_decay(mut self, weight_decay: A) -> Self {
        self.weight_decay = weight_decay;
        self
    }

    /// Set epsilon
    pub fn with_epsilon(mut self, epsilon: A) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Reset optimizer state
    pub fn reset(&mut self) {
        self.t = 0;
        self.m = None;
        self.v = None;
    }
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> InPlaceOptimizer<A, D>
    for InPlaceAdam<A, D>
{
    fn step_inplace(&mut self, params: &mut Array<A, D>, gradients: &Array<A, D>) -> Result<()> {
        self.t += 1;

        // Initialize the moment estimates lazily so the optimizer can be
        // created before the parameter shape is known
        if self.m.is_none() {
            self.m = Some(Array::zeros(params.raw_dim()));
        }
        if self.v.is_none() {
            self.v = Some(Array::zeros(params.raw_dim()));
        }

        let m = self.m.as_mut().unwrap();
        let v = self.v.as_mut().unwrap();

        let beta1 = self.beta1;
        let beta2 = self.beta2;
        let one_minus_beta1 = A::one() - beta1;
        let one_minus_beta2 = A::one() - beta2;

        // Fold weight decay into a temporary gradient only when configured;
        // otherwise update the moments directly from `gradients` to avoid
        // an extra allocation.
        if self.weight_decay > A::zero() {
            let wd = self.weight_decay;
            let mut grad_with_decay = gradients.clone();
            grad_with_decay.zip_mut_with(params, |g, &p| {
                *g = *g + p * wd;
            });

            // Update biased first moment estimate
            m.zip_mut_with(&grad_with_decay, |m_i, &g| {
                *m_i = beta1 * *m_i + one_minus_beta1 * g;
            });

            // Update biased second raw moment estimate
            v.zip_mut_with(&grad_with_decay, |v_i, &g| {
                *v_i = beta2 * *v_i + one_minus_beta2 * g * g;
            });
        } else {
            // Update biased first moment estimate
            m.zip_mut_with(gradients, |m_i, &g| {
                *m_i = beta1 * *m_i + one_minus_beta1 * g;
            });

            // Update biased second raw moment estimate
            v.zip_mut_with(gradients, |v_i, &g| {
                *v_i = beta2 * *v_i + one_minus_beta2 * g * g;
            });
        }

        // Bias-correction terms (1 - beta^t)
        let bias1 = A::one() - self.beta1.powi(self.t);
        let bias2 = A::one() - self.beta2.powi(self.t);

        // Update parameters in-place
        for ((p, &m_i), &v_i) in params.iter_mut().zip(m.iter()).zip(v.iter()) {
            let m_hat = m_i / bias1;
            let v_hat = v_i / bias2;
            *p = *p - self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
        }

        Ok(())
    }
}

/// Utility functions for memory-efficient operations
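///
/// A small sketch of the in-place helpers (crate paths assumed; not
/// compiled as a doctest):
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let mut a = Array1::from_vec(vec![1.0_f64, -2.0, 3.0]);
/// utils::scale_inplace(&mut a, 2.0);      // a = [2.0, -4.0, 6.0]
/// utils::clip_inplace(&mut a, -3.0, 3.0); // a = [2.0, -3.0, 3.0]
/// utils::normalize_inplace(&mut a);       // a now has unit L2 norm
/// ```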
pub mod utils {
    use super::*;

    /// Apply a scalar operation in-place
    pub fn scale_inplace<A, D>(array: &mut Array<A, D>, scalar: A)
    where
        A: Float + ScalarOperand + MulAssign,
        D: Dimension,
    {
        array.map_inplace(|x| *x *= scalar);
    }

    /// Add arrays in-place (a += b)
    pub fn add_inplace<A, D>(a: &mut Array<A, D>, b: &Array<A, D>)
    where
        A: Float + ScalarOperand + AddAssign,
        D: Dimension,
    {
        a.zip_mut_with(b, |x, &y| *x += y);
    }

    /// Subtract arrays in-place (a -= b)
    pub fn subtract_inplace<A, D>(a: &mut Array<A, D>, b: &Array<A, D>)
    where
        A: Float + ScalarOperand + SubAssign,
        D: Dimension,
    {
        a.zip_mut_with(b, |x, &y| *x -= y);
    }

    /// Apply element-wise operation in-place
    pub fn apply_inplace<A, D, F>(array: &mut Array<A, D>, f: F)
    where
        A: Float + ScalarOperand,
        D: Dimension,
        F: Fn(&mut A),
    {
        array.map_inplace(f);
    }

    /// Clip values in-place
    pub fn clip_inplace<A, D>(array: &mut Array<A, D>, min: A, max: A)
    where
        A: Float + ScalarOperand,
        D: Dimension,
    {
        array.map_inplace(|x| {
            if *x < min {
                *x = min;
            } else if *x > max {
                *x = max;
            }
        });
    }

    /// Normalize array in-place (divide by its L2 norm)
    pub fn normalize_inplace<A, D>(array: &mut Array<A, D>)
    where
        A: Float + ScalarOperand + MulAssign,
        D: Dimension,
    {
        // Accumulate the squared norm without allocating a temporary array
        let norm = array
            .iter()
            .fold(A::zero(), |acc, &x| acc + x * x)
            .sqrt();
        if norm > A::zero() {
            let inv_norm = A::one() / norm;
            array.map_inplace(|x| *x *= inv_norm);
        }
    }
}

/// Fused operations for maximum memory efficiency
pub mod fused {
    use super::*;

    /// Adam optimizer configuration
    #[derive(Debug, Clone, Copy)]
    pub struct AdamConfig<A> {
        /// Learning rate
        pub lr: A,
        /// First moment (momentum) decay rate
        pub beta1: A,
        /// Second moment (variance) decay rate
        pub beta2: A,
        /// Numerical stability constant
        pub epsilon: A,
        /// Precomputed bias-correction term `1 - beta1^t`
        pub bias1: A,
        /// Precomputed bias-correction term `1 - beta2^t`
        pub bias2: A,
        /// Optional weight decay coefficient
        pub weight_decay: Option<A>,
    }

    /// Fused Adam update operation: combines momentum, variance, and parameter update in one pass
    ///
    /// This operation fuses all Adam computations into a single loop iteration,
    /// reducing memory allocations and improving cache efficiency.
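    ///
    /// A minimal call-site sketch (crate paths assumed; not compiled as a
    /// doctest). `bias1`/`bias2` are the precomputed `1 - beta^t` terms:
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array1;
    ///
    /// let mut params = Array1::from_vec(vec![1.0_f64, 2.0]);
    /// let grads = Array1::from_vec(vec![0.1, 0.2]);
    /// let (mut m, mut v) = (Array1::zeros(2), Array1::zeros(2));
    /// let t = 1;
    /// let config = fused::AdamConfig {
    ///     lr: 0.01,
    ///     beta1: 0.9,
    ///     beta2: 0.999,
    ///     epsilon: 1e-8,
    ///     bias1: 1.0 - 0.9_f64.powi(t),
    ///     bias2: 1.0 - 0.999_f64.powi(t),
    ///     weight_decay: None,
    /// };
    /// fused::fused_adam_update(&mut params, &grads, &mut m, &mut v, config);
    /// ```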
    pub fn fused_adam_update<A, D>(
        params: &mut Array<A, D>,
        gradients: &Array<A, D>,
        m: &mut Array<A, D>,
        v: &mut Array<A, D>,
        config: AdamConfig<A>,
    ) where
        A: Float + ScalarOperand,
        D: Dimension,
    {
        let one = A::one();
        let one_minus_beta1 = one - config.beta1;
        let one_minus_beta2 = one - config.beta2;

        if let Some(wd) = config.weight_decay {
            // Fused Adam with weight decay
            for (((p, &g), m_val), v_val) in params
                .iter_mut()
                .zip(gradients.iter())
                .zip(m.iter_mut())
                .zip(v.iter_mut())
            {
                // Apply weight decay to the gradient
                let g_with_decay = g + *p * wd;

                // Update momentum
                *m_val = config.beta1 * *m_val + one_minus_beta1 * g_with_decay;

                // Update variance
                *v_val = config.beta2 * *v_val + one_minus_beta2 * g_with_decay * g_with_decay;

                // Bias-corrected estimates and parameter update
                let m_hat = *m_val / config.bias1;
                let v_hat = *v_val / config.bias2;
                *p = *p - config.lr * m_hat / (v_hat.sqrt() + config.epsilon);
            }
        } else {
            // Fused Adam without weight decay
            for (((p, &g), m_val), v_val) in params
                .iter_mut()
                .zip(gradients.iter())
                .zip(m.iter_mut())
                .zip(v.iter_mut())
            {
                // Update momentum
                *m_val = config.beta1 * *m_val + one_minus_beta1 * g;

                // Update variance
                *v_val = config.beta2 * *v_val + one_minus_beta2 * g * g;

                // Bias-corrected estimates and parameter update
                let m_hat = *m_val / config.bias1;
                let v_hat = *v_val / config.bias2;
                *p = *p - config.lr * m_hat / (v_hat.sqrt() + config.epsilon);
            }
        }
    }

    /// Fused SGD with momentum and weight decay
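    ///
    /// A minimal call-site sketch (crate paths assumed; not compiled as a
    /// doctest):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array1;
    ///
    /// let mut params = Array1::from_vec(vec![1.0_f64, 2.0]);
    /// let grads = Array1::from_vec(vec![0.1, 0.2]);
    /// let mut momentum_buf = Array1::zeros(2);
    /// fused::fused_sgd_update(
    ///     &mut params,
    ///     &grads,
    ///     Some(&mut momentum_buf),
    ///     0.1,        // lr
    ///     0.9,        // momentum
    ///     Some(0.01), // weight_decay
    ///     0.0,        // dampening
    /// );
    /// ```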
    pub fn fused_sgd_update<A, D>(
        params: &mut Array<A, D>,
        gradients: &Array<A, D>,
        momentum_buf: Option<&mut Array<A, D>>,
        lr: A,
        momentum: A,
        weight_decay: Option<A>,
        dampening: A,
    ) where
        A: Float + ScalarOperand,
        D: Dimension,
    {
        if let Some(buf) = momentum_buf {
            if let Some(wd) = weight_decay {
                // Fused SGD with momentum and weight decay
                for ((p, g), buf_val) in
                    params.iter_mut().zip(gradients.iter()).zip(buf.iter_mut())
                {
                    let g_with_decay = *g + *p * wd;
                    *buf_val = momentum * *buf_val + (A::one() - dampening) * g_with_decay;
                    *p = *p - lr * *buf_val;
                }
            } else {
                // Fused SGD with momentum only
                for ((p, g), buf_val) in
                    params.iter_mut().zip(gradients.iter()).zip(buf.iter_mut())
                {
                    *buf_val = momentum * *buf_val + (A::one() - dampening) * *g;
                    *p = *p - lr * *buf_val;
                }
            }
        } else if let Some(wd) = weight_decay {
            // Fused SGD with weight decay only
            for (p, g) in params.iter_mut().zip(gradients.iter()) {
                *p = *p - lr * (*g + *p * wd);
            }
        } else {
            // Simple fused SGD
            for (p, g) in params.iter_mut().zip(gradients.iter()) {
                *p = *p - lr * *g;
            }
        }
    }

    /// Fused gradient clipping and normalization
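    ///
    /// A minimal sketch (crate paths assumed; not compiled as a doctest):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array1;
    ///
    /// let mut grads = Array1::from_vec(vec![5.0_f64, -3.0, 2.0]);
    /// // Clip each element to [-1, 1], then rescale so the L2 norm <= 2.
    /// fused::fused_gradient_clip_normalize(&mut grads, Some(2.0), Some(1.0));
    /// ```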
    pub fn fused_gradient_clip_normalize<A, D>(
        gradients: &mut Array<A, D>,
        max_norm: Option<A>,
        clip_value: Option<A>,
    ) where
        A: Float + ScalarOperand,
        D: Dimension,
    {
        if let Some(clip_val) = clip_value {
            // First pass: clip values
            for g in gradients.iter_mut() {
                if *g > clip_val {
                    *g = clip_val;
                } else if *g < -clip_val {
                    *g = -clip_val;
                }
            }
        }

        if let Some(max_norm_val) = max_norm {
            // Second pass: rescale if the norm exceeds max_norm
            let norm_sq = gradients
                .iter()
                .map(|&x| x * x)
                .fold(A::zero(), |acc, x| acc + x);
            let norm = norm_sq.sqrt();

            if norm > max_norm_val {
                let scale = max_norm_val / norm;
                for g in gradients.iter_mut() {
                    *g = *g * scale;
                }
            }
        }
    }

    /// Fused parameter constraint application
    pub fn fused_apply_constraints<A, D>(
        params: &mut Array<A, D>,
        l2_constraint: Option<A>,
        value_bounds: Option<(A, A)>,
    ) where
        A: Float + ScalarOperand,
        D: Dimension,
    {
        // Apply value bounds first
        if let Some((min_val, max_val)) = value_bounds {
            for p in params.iter_mut() {
                if *p < min_val {
                    *p = min_val;
                } else if *p > max_val {
                    *p = max_val;
                }
            }
        }

        // Apply L2 norm constraint
        if let Some(max_norm) = l2_constraint {
            let norm_sq = params
                .iter()
                .map(|&x| x * x)
                .fold(A::zero(), |acc, x| acc + x);
            let norm = norm_sq.sqrt();

            if norm > max_norm {
                let scale = max_norm / norm;
                for p in params.iter_mut() {
                    *p = *p * scale;
                }
            }
        }
    }
}

/// Mixed-precision training support
pub mod mixed_precision {
    use super::*;

    /// Loss scaler for mixed-precision training
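    ///
    /// A sketch of a typical scaling loop (crate paths assumed; not compiled
    /// as a doctest):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Array1;
    ///
    /// let mut scaler = mixed_precision::LossScaler::new(65536.0);
    /// let scaled_loss = scaler.scale_loss(0.5);
    /// // ... backward pass on scaled_loss produces scaled gradients ...
    /// let mut grads = Array1::from_vec(vec![65536.0_f64, 131072.0]);
    /// scaler.unscale_gradients(&mut grads);
    /// let found_inf = scaler.check_gradients(&grads);
    /// scaler.update(found_inf); // grow or back off the scale
    /// ```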
    #[derive(Debug, Clone)]
    pub struct LossScaler {
        scale: f32,
        growth_factor: f32,
        backoff_factor: f32,
        growth_interval: usize,
        steps_since_update: usize,
    }

    impl LossScaler {
        /// Create a new loss scaler
        pub fn new(initial_scale: f32) -> Self {
            Self {
                scale: initial_scale,
                growth_factor: 2.0,
                backoff_factor: 0.5,
                growth_interval: 2000,
                steps_since_update: 0,
            }
        }

        /// Get current scale factor
        pub fn get_scale(&self) -> f32 {
            self.scale
        }

        /// Scale loss for backward pass
        pub fn scale_loss(&self, loss: f32) -> f32 {
            loss * self.scale
        }

        /// Unscale gradients after backward pass
        pub fn unscale_gradients<A, D>(&self, gradients: &mut Array<A, D>)
        where
            A: Float + ScalarOperand,
            D: Dimension,
        {
            let inv_scale = A::one() / A::from(self.scale).unwrap();
            for g in gradients.iter_mut() {
                *g = *g * inv_scale;
            }
        }

        /// Update scale based on gradient overflow detection
        pub fn update(&mut self, found_inf: bool) {
            self.steps_since_update += 1;

            if found_inf {
                // Reduce scale if overflow detected
                self.scale *= self.backoff_factor;
                self.steps_since_update = 0;
            } else if self.steps_since_update >= self.growth_interval {
                // Increase scale if no overflow for growth_interval steps
                self.scale *= self.growth_factor;
                self.steps_since_update = 0;
            }
        }

        /// Check whether gradients contain infinite or NaN values
        pub fn check_gradients<A, D>(&self, gradients: &Array<A, D>) -> bool
        where
            A: Float + ScalarOperand,
            D: Dimension,
        {
            gradients.iter().any(|&x| !x.is_finite())
        }
    }
}

/// Gradient checkpointing for memory optimization
pub mod gradient_checkpointing {
    use super::*;
    use std::collections::{HashMap, VecDeque};

    /// Checkpointing strategy for gradient computation
    #[derive(Debug, Clone, PartialEq)]
    pub enum CheckpointStrategy {
        /// No checkpointing (store all intermediate values)
        None,
        /// Uniform checkpointing (checkpoint every N layers)
        Uniform {
            /// Interval between checkpoints
            interval: usize,
        },
        /// Logarithmic checkpointing (checkpoint at exponential intervals)
        Logarithmic {
            /// Base for exponential intervals
            base: f64,
        },
        /// Memory-aware checkpointing (adaptive based on memory usage)
        MemoryAware {
            /// Memory threshold for triggering checkpoints
            memory_threshold: f64,
        },
        /// Custom checkpointing pattern
        Custom {
            /// Pattern of checkpointing decisions
            pattern: Vec<bool>,
        },
    }

    /// Gradient checkpointing manager
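    ///
    /// A minimal sketch (crate paths assumed; not compiled as a doctest):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::{Array1, Ix1};
    ///
    /// let mut cp: GradientCheckpointer<f64, Ix1> =
    ///     GradientCheckpointer::new(CheckpointStrategy::Uniform { interval: 2 });
    /// cp.set_max_depth(10);
    /// cp.store_checkpoint(2, Array1::from_vec(vec![1.0, 2.0]));
    /// assert!(cp.get_checkpoint(2).is_some());
    /// ```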
    #[derive(Debug)]
    pub struct GradientCheckpointer<A: Float, D: Dimension> {
        /// Checkpointing strategy
        strategy: CheckpointStrategy,
        /// Stored checkpoints (layer index -> activation)
        checkpoints: HashMap<usize, Array<A, D>>,
        /// Memory usage tracker
        memory_tracker: MemoryTracker,
        /// Current computation depth
        current_depth: usize,
        /// Maximum depth for this computation
        max_depth: usize,
        /// Whether checkpointing is enabled
        enabled: bool,
    }

    impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientCheckpointer<A, D> {
        /// Create a new gradient checkpointer
        pub fn new(strategy: CheckpointStrategy) -> Self {
            Self {
                strategy,
                checkpoints: HashMap::new(),
                memory_tracker: MemoryTracker::new(),
                current_depth: 0,
                max_depth: 0,
                enabled: true,
            }
        }

        /// Set the maximum computation depth
        pub fn set_max_depth(&mut self, depth: usize) {
            self.max_depth = depth;
        }

        /// Enable or disable checkpointing
        pub fn set_enabled(&mut self, enabled: bool) {
            self.enabled = enabled;
        }

        /// Check if we should checkpoint at the current depth
        pub fn should_checkpoint(&self, depth: usize) -> bool {
            if !self.enabled || self.max_depth == 0 {
                return false;
            }

            match self.strategy {
                CheckpointStrategy::None => false,
                CheckpointStrategy::Uniform { interval } => depth.is_multiple_of(interval),
                CheckpointStrategy::Logarithmic { base } => {
                    let log_depth = (depth as f64).log(base).floor() as usize;
                    depth == base.powi(log_depth as i32) as usize
                }
                CheckpointStrategy::MemoryAware { memory_threshold } => {
                    self.memory_tracker.usage_ratio() > memory_threshold
                }
                CheckpointStrategy::Custom { ref pattern } => {
                    if depth < pattern.len() {
                        pattern[depth]
                    } else {
                        false
                    }
                }
            }
        }

        /// Store a checkpoint
        pub fn store_checkpoint(&mut self, depth: usize, activation: Array<A, D>) {
            if self.should_checkpoint(depth) {
                let memory_size = activation.len() * std::mem::size_of::<A>();
                self.memory_tracker.add_allocation(memory_size);
                self.checkpoints.insert(depth, activation);
            }
        }

        /// Retrieve a checkpoint
        pub fn get_checkpoint(&self, depth: usize) -> Option<&Array<A, D>> {
            self.checkpoints.get(&depth)
        }

        /// Remove a checkpoint to free memory
        pub fn remove_checkpoint(&mut self, depth: usize) -> Option<Array<A, D>> {
            if let Some(checkpoint) = self.checkpoints.remove(&depth) {
                let memory_size = checkpoint.len() * std::mem::size_of::<A>();
                self.memory_tracker.remove_allocation(memory_size);
                Some(checkpoint)
            } else {
                None
            }
        }

        /// Clear all checkpoints
        pub fn clear_checkpoints(&mut self) {
            self.checkpoints.clear();
            self.memory_tracker.reset();
        }

        /// Get memory usage information
        pub fn memory_usage(&self) -> MemoryUsage {
            self.memory_tracker.usage()
        }

        /// Optimize checkpointing strategy based on memory usage
        pub fn optimize_strategy(&mut self, target_memory_usage: f64) {
            let current_usage = self.memory_tracker.usage_ratio();

            if current_usage > target_memory_usage {
                // Increase checkpointing frequency to reduce memory usage
                self.strategy = match &self.strategy {
                    CheckpointStrategy::Uniform { interval } => CheckpointStrategy::Uniform {
                        interval: (interval / 2).max(1),
                    },
                    CheckpointStrategy::MemoryAware { .. } => CheckpointStrategy::MemoryAware {
                        memory_threshold: target_memory_usage * 0.8,
                    },
                    other => other.clone(),
                };
            } else if current_usage < target_memory_usage * 0.5 {
                // Decrease checkpointing frequency to improve performance
                self.strategy = match &self.strategy {
                    CheckpointStrategy::Uniform { interval } => CheckpointStrategy::Uniform {
                        interval: interval * 2,
                    },
                    CheckpointStrategy::MemoryAware { .. } => CheckpointStrategy::MemoryAware {
                        memory_threshold: target_memory_usage * 1.2,
                    },
                    other => other.clone(),
                };
            }
        }

        /// Execute a checkpointed computation
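        ///
        /// A minimal sketch, given a checkpointer `cp` and an `input` array
        /// from the caller (not compiled as a doctest):
        ///
        /// ```ignore
        /// let (output, checkpoint) = cp.checkpointed_forward(0, &input, |x| {
        ///     // Return (output, activation to optionally checkpoint)
        ///     Ok((x.sum(), x.mapv(|v| v * 2.0)))
        /// })?;
        /// ```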
        pub fn checkpointed_forward<F, Output>(
            &mut self,
            depth: usize,
            input: &Array<A, D>,
            forward_fn: F,
        ) -> Result<(Output, Option<Array<A, D>>)>
        where
            F: FnOnce(&Array<A, D>) -> Result<(Output, Array<A, D>)>,
        {
            self.current_depth = depth;

            // Execute forward computation
            let (output, activation) = forward_fn(input)?;

            // Decide whether to store checkpoint
            let checkpoint = if self.should_checkpoint(depth) {
                self.store_checkpoint(depth, activation.clone());
                Some(activation)
            } else {
                None
            };

            Ok((output, checkpoint))
        }

        /// Recompute activations from checkpoint
        pub fn recompute_from_checkpoint<F>(
            &self,
            start_depth: usize,
            target_depth: usize,
            recompute_fn: F,
        ) -> Result<Array<A, D>>
        where
            F: Fn(usize, &Array<A, D>) -> Result<Array<A, D>>,
        {
            // Find the nearest checkpoint at or before start_depth
            let checkpoint_depth = (0..=start_depth)
                .rev()
                .find(|&d| self.checkpoints.contains_key(&d))
                .ok_or_else(|| {
                    OptimError::InvalidConfig("No checkpoint found for recomputation".to_string())
                })?;

            let mut current_activation = self.checkpoints[&checkpoint_depth].clone();

            // Recompute forward from the checkpoint to the target depth
            for depth in (checkpoint_depth + 1)..=target_depth {
                current_activation = recompute_fn(depth, &current_activation)?;
            }

            Ok(current_activation)
        }
    }

    /// Memory usage tracking
    #[derive(Debug, Clone)]
    pub struct MemoryTracker {
        allocated_bytes: usize,
        peak_bytes: usize,
        total_system_memory: usize,
    }

    impl Default for MemoryTracker {
        fn default() -> Self {
            Self::new()
        }
    }

    impl MemoryTracker {
        /// Create a new memory tracker
        pub fn new() -> Self {
            Self {
                allocated_bytes: 0,
                peak_bytes: 0,
                total_system_memory: Self::estimate_system_memory(),
            }
        }

        /// Add an allocation
        pub fn add_allocation(&mut self, bytes: usize) {
            self.allocated_bytes += bytes;
            self.peak_bytes = self.peak_bytes.max(self.allocated_bytes);
        }

        /// Remove an allocation
        pub fn remove_allocation(&mut self, bytes: usize) {
            self.allocated_bytes = self.allocated_bytes.saturating_sub(bytes);
        }

        /// Get current memory usage
        pub fn usage(&self) -> MemoryUsage {
            MemoryUsage {
                current_bytes: self.allocated_bytes,
                peak_bytes: self.peak_bytes,
                total_system_bytes: self.total_system_memory,
            }
        }

        /// Get memory usage ratio (0.0 to 1.0)
        pub fn usage_ratio(&self) -> f64 {
            if self.total_system_memory == 0 {
                0.0
            } else {
                self.allocated_bytes as f64 / self.total_system_memory as f64
            }
        }

        /// Reset memory tracking
        pub fn reset(&mut self) {
            self.allocated_bytes = 0;
            self.peak_bytes = 0;
        }

        /// Estimate total system memory (simplified)
        fn estimate_system_memory() -> usize {
            // This is a simplified estimation
            // In a real implementation, you would use system APIs
            8 * 1024 * 1024 * 1024 // Assume 8GB
        }
    }

    /// Memory usage information
    #[derive(Debug, Clone, Copy)]
    pub struct MemoryUsage {
        /// Current allocated bytes
        pub current_bytes: usize,
        /// Peak allocated bytes
        pub peak_bytes: usize,
        /// Total system memory bytes
        pub total_system_bytes: usize,
    }

    impl MemoryUsage {
        /// Get current usage as a ratio (0.0 to 1.0)
        pub fn current_ratio(&self) -> f64 {
            if self.total_system_bytes == 0 {
                0.0
            } else {
                self.current_bytes as f64 / self.total_system_bytes as f64
            }
        }

        /// Get peak usage as a ratio (0.0 to 1.0)
        pub fn peak_ratio(&self) -> f64 {
            if self.total_system_bytes == 0 {
                0.0
            } else {
                self.peak_bytes as f64 / self.total_system_bytes as f64
            }
        }

        /// Format as human-readable string
        pub fn format(&self) -> String {
            format!(
                "Current: {:.1} MB ({:.1}%), Peak: {:.1} MB ({:.1}%), Total: {:.1} MB",
                self.current_bytes as f64 / (1024.0 * 1024.0),
                self.current_ratio() * 100.0,
                self.peak_bytes as f64 / (1024.0 * 1024.0),
                self.peak_ratio() * 100.0,
                self.total_system_bytes as f64 / (1024.0 * 1024.0)
            )
        }
    }

    /// Automatic checkpointing manager for optimization workflows
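    ///
    /// A minimal sketch, given an `input` array from the caller (crate paths
    /// assumed; not compiled as a doctest):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::Ix1;
    ///
    /// let mut auto: AutoCheckpointer<f64, Ix1> =
    ///     AutoCheckpointer::new(CheckpointStrategy::Uniform { interval: 2 }, 0.6);
    /// let (output, _checkpoint) = auto.auto_step(0, &input, |x| Ok((x.sum(), x.clone())))?;
    /// ```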
    #[derive(Debug)]
    pub struct AutoCheckpointer<A: Float, D: Dimension> {
        checkpointer: GradientCheckpointer<A, D>,
        /// History of memory usage for adaptive optimization
        memory_history: VecDeque<f64>,
        /// Target memory usage ratio
        target_memory_ratio: f64,
        /// Adaptation frequency (steps)
        adaptation_frequency: usize,
        /// Current step count
        step_count: usize,
    }

    impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> AutoCheckpointer<A, D> {
        /// Create a new auto checkpointer
        pub fn new(initial_strategy: CheckpointStrategy, target_memory_ratio: f64) -> Self {
            Self {
                checkpointer: GradientCheckpointer::new(initial_strategy),
                memory_history: VecDeque::with_capacity(100),
                target_memory_ratio: target_memory_ratio.clamp(0.1, 0.9),
                adaptation_frequency: 10,
                step_count: 0,
            }
        }

        /// Set adaptation frequency
        pub fn with_adaptation_frequency(mut self, frequency: usize) -> Self {
            self.adaptation_frequency = frequency.max(1);
            self
        }

        /// Execute a step with automatic checkpointing
        pub fn auto_step<F, Output>(
            &mut self,
            depth: usize,
            input: &Array<A, D>,
            forward_fn: F,
        ) -> Result<(Output, Option<Array<A, D>>)>
        where
            F: FnOnce(&Array<A, D>) -> Result<(Output, Array<A, D>)>,
        {
            self.step_count += 1;

            // Execute checkpointed forward
            let result = self
                .checkpointer
                .checkpointed_forward(depth, input, forward_fn)?;

            // Track memory usage
            let current_usage = self.checkpointer.memory_usage().current_ratio();
            self.memory_history.push_back(current_usage);
            if self.memory_history.len() > 100 {
                self.memory_history.pop_front();
            }

            // Adapt strategy periodically
            if self.step_count.is_multiple_of(self.adaptation_frequency) {
                self.adapt_strategy();
            }

            Ok(result)
        }

        /// Adapt checkpointing strategy based on memory usage history
        fn adapt_strategy(&mut self) {
            if self.memory_history.len() < 5 {
                return;
            }

            // Average memory usage over the most recent (up to 10) samples
            let window = self.memory_history.len().min(10);
            let recent_avg =
                self.memory_history.iter().rev().take(window).sum::<f64>() / window as f64;

            // Optimize strategy if we're significantly off target
            let deviation = (recent_avg - self.target_memory_ratio).abs();
            if deviation > 0.1 {
                self.checkpointer.optimize_strategy(self.target_memory_ratio);
            }
        }

        /// Get checkpointer reference
        pub fn checkpointer(&self) -> &GradientCheckpointer<A, D> {
            &self.checkpointer
        }

        /// Get mutable checkpointer reference
        pub fn checkpointer_mut(&mut self) -> &mut GradientCheckpointer<A, D> {
            &mut self.checkpointer
        }

        /// Get memory usage statistics
        pub fn get_memory_stats(&self) -> MemoryStats {
            let usage = self.checkpointer.memory_usage();
            let avg_usage = if self.memory_history.is_empty() {
                0.0
            } else {
                self.memory_history.iter().sum::<f64>() / self.memory_history.len() as f64
            };

            MemoryStats {
                current_usage: usage.current_ratio(),
                peak_usage: usage.peak_ratio(),
                average_usage: avg_usage,
                target_usage: self.target_memory_ratio,
                checkpoints_stored: self.checkpointer.checkpoints.len(),
            }
        }
    }

    /// Memory usage statistics
    #[derive(Debug, Clone, Copy)]
    pub struct MemoryStats {
        /// Current memory usage ratio
        pub current_usage: f64,
        /// Peak memory usage ratio
        pub peak_usage: f64,
        /// Average memory usage ratio
        pub average_usage: f64,
        /// Target memory usage ratio
        pub target_usage: f64,
        /// Number of checkpoints currently stored
        pub checkpoints_stored: usize,
    }

    impl MemoryStats {
        /// Check if memory usage is within target range
        pub fn is_within_target(&self, tolerance: f64) -> bool {
            (self.current_usage - self.target_usage).abs() <= tolerance
        }

        /// Get efficiency score (how close to target without exceeding)
        pub fn efficiency_score(&self) -> f64 {
            if self.current_usage <= self.target_usage {
                self.current_usage / self.target_usage
            } else {
                self.target_usage / self.current_usage
            }
        }
    }
}

/// Dynamic resource adaptation
pub mod adaptive {
    use super::*;

    /// Memory-aware batch size adapter
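    ///
    /// A minimal sketch (crate paths assumed; not compiled as a doctest):
    ///
    /// ```ignore
    /// let mut sizer = adaptive::MemoryAwareBatchSizer::new(32)
    ///     .with_memory_threshold(0.8);
    /// sizer.adapt(0.9); // high memory pressure -> smaller batches
    /// assert!(sizer.current_batch_size() < 32);
    /// ```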
    #[derive(Debug, Clone)]
    pub struct MemoryAwareBatchSizer {
        initial_batch_size: usize,
        max_batch_size: usize,
        min_batch_size: usize,
        current_batch_size: usize,
        memory_threshold: f64, // Memory usage threshold (0.0 to 1.0)
        adaptation_factor: f64,
    }

    impl MemoryAwareBatchSizer {
        /// Create a new memory-aware batch sizer
        pub fn new(initial_batch_size: usize) -> Self {
            Self {
                initial_batch_size,
                max_batch_size: initial_batch_size * 4,
                // Keep the lower bound at least 1 so the batch size never
                // collapses to zero
                min_batch_size: (initial_batch_size / 4).max(1),
                current_batch_size: initial_batch_size,
                memory_threshold: 0.8,
                adaptation_factor: 1.2,
            }
        }

        /// Set memory threshold (0.0 to 1.0)
        pub fn with_memory_threshold(mut self, threshold: f64) -> Self {
            self.memory_threshold = threshold.clamp(0.1, 0.95);
            self
        }

        /// Set adaptation factor
        pub fn with_adaptation_factor(mut self, factor: f64) -> Self {
            self.adaptation_factor = factor.max(1.0);
            self
        }

        /// Get current batch size
        pub fn current_batch_size(&self) -> usize {
            self.current_batch_size
        }

        /// Adapt batch size based on memory usage
        pub fn adapt(&mut self, memory_usage_ratio: f64) {
            if memory_usage_ratio > self.memory_threshold {
                // Reduce batch size if memory usage is high
                let new_size = (self.current_batch_size as f64 / self.adaptation_factor) as usize;
                self.current_batch_size = new_size.max(self.min_batch_size);
            } else if memory_usage_ratio < self.memory_threshold * 0.7 {
                // Increase batch size if memory usage is low
                let new_size = (self.current_batch_size as f64 * self.adaptation_factor) as usize;
                self.current_batch_size = new_size.min(self.max_batch_size);
            }
        }

        /// Reset to initial batch size
        pub fn reset(&mut self) {
            self.current_batch_size = self.initial_batch_size;
        }
    }

    /// Memory usage estimator for arrays
    pub fn estimate_memory_usage<A, D>(arrays: &[&Array<A, D>]) -> usize
    where
        A: Sized,
        D: Dimension,
    {
        arrays
            .iter()
            .map(|arr| arr.len() * std::mem::size_of::<A>())
            .sum()
    }

    /// Get approximate system memory usage ratio
    pub fn get_memory_usage_ratio() -> f64 {
        // This is a simplified estimation
        // In a real implementation, you would use system APIs
        // to get actual memory information
        0.5 // Placeholder: assume 50% memory usage
    }
}

// Re-export utility functions at module level for convenience
pub use utils::{
    add_inplace, apply_inplace, clip_inplace, normalize_inplace, scale_inplace, subtract_inplace,
};

// Re-export the remaining submodules
pub use adaptive::*;
pub use fused::*;
pub use gradient_checkpointing::*;
pub use mixed_precision::*;

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_inplace_sgd() {
        let mut optimizer = InPlaceSGD::new(0.1);
        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        optimizer.step_inplace(&mut params, &gradients).unwrap();

        assert_relative_eq!(params[0], 0.99, epsilon = 1e-6);
        assert_relative_eq!(params[1], 1.98, epsilon = 1e-6);
        assert_relative_eq!(params[2], 2.97, epsilon = 1e-6);
    }

    #[test]
    fn test_inplace_adam() {
        let mut optimizer = InPlaceAdam::new(0.001);
        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        // Multiple steps to see momentum effects
        for _ in 0..5 {
            optimizer.step_inplace(&mut params, &gradients).unwrap();
        }

        // Verify parameters have been updated
        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
        assert!(params[2] < 3.0);
    }

    #[test]
    fn test_utils_scale_inplace() {
        let mut array = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        utils::scale_inplace(&mut array, 2.0);

        assert_eq!(array.as_slice().unwrap(), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_utils_clip_inplace() {
        let mut array = Array1::from_vec(vec![0.5, 1.5, 2.5]);
        utils::clip_inplace(&mut array, 1.0, 2.0);

        assert_eq!(array.as_slice().unwrap(), &[1.0, 1.5, 2.0]);
    }

    #[test]
    fn test_memory_efficiency() {
        // Test that in-place operations don't allocate new arrays
        let mut params = Array1::from_vec(vec![1.0; 1000]);
        let gradients = Array1::from_vec(vec![0.01; 1000]);
        let params_ptr = params.as_ptr();

        let mut optimizer = InPlaceSGD::new(0.1);
        optimizer.step_inplace(&mut params, &gradients).unwrap();

        // Verify the same memory is being used
        assert_eq!(params_ptr, params.as_ptr());
    }

    #[test]
    fn test_fused_adam_update() {
        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let mut m = Array1::zeros(3);
        let mut v = Array1::zeros(3);

        let config = fused::AdamConfig {
            lr: 0.01,
            beta1: 0.9,
            beta2: 0.999,
            epsilon: 1e-8,
            // Bias-correction terms, approximately 1 - beta^t at t = 1
            bias1: 0.1,
            bias2: 0.001,
            weight_decay: None,
        };

        fused::fused_adam_update(&mut params, &gradients, &mut m, &mut v, config);

        // Verify parameters were updated
        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
        assert!(params[2] < 3.0);

        // Verify momentum and variance were updated
        assert!(m[0] > 0.0);
        assert!(v[0] > 0.0);
    }

    #[test]
    fn test_fused_sgd_update() {
        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradients = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let mut momentum_buf = Array1::zeros(3);

        fused::fused_sgd_update(
            &mut params,
            &gradients,
            Some(&mut momentum_buf),
            0.1,        // lr
            0.9,        // momentum
            Some(0.01), // weight_decay
            0.0,        // dampening
        );

        // Verify parameters were updated
        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
        assert!(params[2] < 3.0);
    }

    #[test]
    fn test_fused_gradient_clip_normalize() {
        let mut gradients = Array1::from_vec(vec![5.0, -3.0, 2.0]);

        fused::fused_gradient_clip_normalize(
            &mut gradients,
            Some(2.0), // max_norm
            Some(1.0), // clip_value
        );

        // Verify values were clipped
        assert!(gradients.iter().all(|&x| x.abs() <= 1.0));

        // Verify norm constraint
        let norm = gradients.iter().map(|&x| x * x).sum::<f64>().sqrt();
        assert!(norm <= 2.0 + 1e-6);
    }

    #[test]
    fn test_mixed_precision_loss_scaler() {
        let scaler = mixed_precision::LossScaler::new(65536.0);

        // Test loss scaling
        let loss = 0.5;
        let scaled_loss = scaler.scale_loss(loss);
        assert_eq!(scaled_loss, 0.5 * 65536.0);

        // Test gradient unscaling
        let mut gradients = Array1::from_vec(vec![65536.0, 131072.0]);
        scaler.unscale_gradients(&mut gradients);
        assert_relative_eq!(gradients[0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[1], 2.0, epsilon = 1e-6);

        // Test overflow detection
        let inf_gradients = Array1::from_vec(vec![f64::INFINITY, 1.0]);
        assert!(scaler.check_gradients(&inf_gradients));

        let finite_gradients = Array1::from_vec(vec![1.0, 2.0]);
        assert!(!scaler.check_gradients(&finite_gradients));
    }

    #[test]
    fn test_memory_aware_batch_sizer() {
        let mut sizer = adaptive::MemoryAwareBatchSizer::new(32)
            .with_memory_threshold(0.8)
            .with_adaptation_factor(1.3); // Use a smaller factor for more predictable behavior

        assert_eq!(sizer.current_batch_size(), 32);

        // High memory usage should reduce batch size
        sizer.adapt(0.9);
        let reduced_size = sizer.current_batch_size();
        assert!(reduced_size < 32);

        // Low memory usage should increase batch size (multiple calls to ensure growth)
        sizer.adapt(0.3);
        sizer.adapt(0.3); // Call twice to ensure we exceed the original size
        assert!(sizer.current_batch_size() >= 32);

        // Reset should restore initial size
        sizer.reset();
        assert_eq!(sizer.current_batch_size(), 32);
    }

    #[test]
    fn test_memory_estimation() {
        let array1 = Array1::from_vec(vec![1.0; 100]);
        let array2 = Array1::from_vec(vec![2.0; 200]);

        let arrays = vec![&array1, &array2];
        let estimated_size = adaptive::estimate_memory_usage(&arrays);

        // Should be roughly 300 * size_of::<f64>()
        let expected_size = 300 * std::mem::size_of::<f64>();
        assert_eq!(estimated_size, expected_size);
    }

    #[test]
    fn test_gradient_checkpointing_uniform() {
        let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
            f64,
            scirs2_core::ndarray::Ix1,
        > = gradient_checkpointing::GradientCheckpointer::new(
            gradient_checkpointing::CheckpointStrategy::Uniform { interval: 2 },
        );
        checkpointer.set_max_depth(10);

        // Should checkpoint at depths 0, 2, 4, 6, 8
        assert!(checkpointer.should_checkpoint(0));
        assert!(!checkpointer.should_checkpoint(1));
        assert!(checkpointer.should_checkpoint(2));
        assert!(!checkpointer.should_checkpoint(3));
        assert!(checkpointer.should_checkpoint(4));

        // Store a checkpoint
        let activation = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        checkpointer.store_checkpoint(2, activation.clone());

        // Retrieve checkpoint
        let retrieved = checkpointer.get_checkpoint(2).unwrap();
        assert_eq!(
            retrieved.as_slice().unwrap(),
            activation.as_slice().unwrap()
        );

        // Non-checkpointed depth should return None
        assert!(checkpointer.get_checkpoint(1).is_none());
    }

    #[test]
    fn test_gradient_checkpointing_logarithmic() {
        let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
            f64,
            scirs2_core::ndarray::Ix1,
        > = gradient_checkpointing::GradientCheckpointer::new(
            gradient_checkpointing::CheckpointStrategy::Logarithmic { base: 2.0 },
        );

        // Set max depth to enable checkpointing
        checkpointer.set_max_depth(10);

        // Should checkpoint at powers of 2: 1, 2, 4, 8, 16...
        assert!(checkpointer.should_checkpoint(1));
        assert!(checkpointer.should_checkpoint(2));
        assert!(!checkpointer.should_checkpoint(3));
        assert!(checkpointer.should_checkpoint(4));
        assert!(!checkpointer.should_checkpoint(5));
        assert!(!checkpointer.should_checkpoint(6));
        assert!(!checkpointer.should_checkpoint(7));
        assert!(checkpointer.should_checkpoint(8));
    }

    #[test]
    fn test_gradient_checkpointing_custom() {
        let pattern = vec![true, false, false, true, false];
        let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
            f64,
            scirs2_core::ndarray::Ix1,
        > = gradient_checkpointing::GradientCheckpointer::new(
            gradient_checkpointing::CheckpointStrategy::Custom { pattern },
        );

        // Set max depth to enable checkpointing
        checkpointer.set_max_depth(10);

        // Should follow the custom pattern
        assert!(checkpointer.should_checkpoint(0));
        assert!(!checkpointer.should_checkpoint(1));
        assert!(!checkpointer.should_checkpoint(2));
        assert!(checkpointer.should_checkpoint(3));
        assert!(!checkpointer.should_checkpoint(4));
        assert!(!checkpointer.should_checkpoint(5)); // Beyond pattern length
    }

    #[test]
    fn test_gradient_checkpointing_memory_tracking() {
        let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
            f64,
            scirs2_core::ndarray::Ix1,
        > = gradient_checkpointing::GradientCheckpointer::new(
            gradient_checkpointing::CheckpointStrategy::Uniform { interval: 1 },
        );
        checkpointer.set_max_depth(5);

        let activation1 = Array1::from_vec(vec![1.0; 100]);
        let activation2 = Array1::from_vec(vec![2.0; 200]);

        checkpointer.store_checkpoint(0, activation1);
        let usage_after_first = checkpointer.memory_usage();
        assert!(usage_after_first.current_bytes > 0);

        checkpointer.store_checkpoint(1, activation2);
        let usage_after_second = checkpointer.memory_usage();
        assert!(usage_after_second.current_bytes > usage_after_first.current_bytes);

        // Remove first checkpoint
        checkpointer.remove_checkpoint(0);
        let usage_after_removal = checkpointer.memory_usage();
        assert!(usage_after_removal.current_bytes < usage_after_second.current_bytes);
    }

    #[test]
    fn test_checkpointed_forward() {
        let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
            f64,
            scirs2_core::ndarray::Ix1,
        > = gradient_checkpointing::GradientCheckpointer::new(
            gradient_checkpointing::CheckpointStrategy::Uniform { interval: 1 },
        );
        checkpointer.set_max_depth(5);

        let input = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        // Forward function that returns the sum as the output and the
        // doubled input as the activation
        let forward_fn = |x: &Array1<f64>| -> Result<(f64, Array1<f64>)> {
            let output = x.sum();
            let activation = x.mapv(|val| val * 2.0);
            Ok((output, activation))
        };

        let (output, checkpoint) = checkpointer
            .checkpointed_forward(0, &input, forward_fn)
            .unwrap();

        assert_eq!(output, 6.0); // 1 + 2 + 3
        assert!(checkpoint.is_some());
        let checkpoint = checkpoint.unwrap();
        assert_eq!(checkpoint.as_slice().unwrap(), &[2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_recompute_from_checkpoint() {
        let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
            f64,
            scirs2_core::ndarray::Ix1,
        > = gradient_checkpointing::GradientCheckpointer::new(
            gradient_checkpointing::CheckpointStrategy::Uniform { interval: 2 },
        );
        checkpointer.set_max_depth(10);

        // Store checkpoints at depths 0 and 2
        let checkpoint0 = Array1::from_vec(vec![1.0, 2.0]);
        let checkpoint2 = Array1::from_vec(vec![3.0, 4.0]);

        checkpointer.store_checkpoint(0, checkpoint0);
        checkpointer.store_checkpoint(2, checkpoint2);

        // Recompute function that adds 1 to each element
        let recompute_fn =
            |_depth: usize, x: &Array1<f64>| -> Result<Array1<f64>> { Ok(x.mapv(|val| val + 1.0)) };

        // Recompute from checkpoint 2 to depth 4
        let result = checkpointer
            .recompute_from_checkpoint(2, 4, recompute_fn)
            .unwrap();

        // Should be [3,4] + 1 + 1 = [5,6]
        assert_eq!(result.as_slice().unwrap(), &[5.0, 6.0]);
    }

    #[test]
    fn test_auto_checkpointer() {
        let mut auto_checkpointer: AutoCheckpointer<f64, scirs2_core::ndarray::Ix1> =
            gradient_checkpointing::AutoCheckpointer::new(
                gradient_checkpointing::CheckpointStrategy::Uniform { interval: 2 },
                0.6, // target 60% memory usage
            );

        let input = Array1::from_vec(vec![1.0, 2.0]);

        // Simple forward function
        let forward_fn = |x: &Array1<f64>| -> Result<(f64, Array1<f64>)> {
            let output = x.sum();
            let activation = x.clone();
            Ok((output, activation))
        };

        // Execute several steps
        for depth in 0..5 {
            let (output, _checkpoint) = auto_checkpointer
                .auto_step(depth, &input, forward_fn)
                .unwrap();
            assert_eq!(output, 3.0); // 1 + 2
        }

        let stats = auto_checkpointer.get_memory_stats();
        assert!(stats.target_usage > 0.0);
    }

    #[test]
    fn test_memory_stats() {
        let stats = gradient_checkpointing::MemoryStats {
            current_usage: 0.5,
            peak_usage: 0.7,
            average_usage: 0.6,
            target_usage: 0.6,
            checkpoints_stored: 3,
        };

        assert!(stats.is_within_target(0.1));
        assert!(!stats.is_within_target(0.01));

        let efficiency = stats.efficiency_score();
        assert!(efficiency > 0.8 && efficiency <= 1.0);
    }

    #[test]
    fn test_memory_usage_formatting() {
        let usage = gradient_checkpointing::MemoryUsage {
            current_bytes: 1024 * 1024,                 // 1 MB
            peak_bytes: 2 * 1024 * 1024,                // 2 MB
            total_system_bytes: 8 * 1024 * 1024 * 1024, // 8 GB
        };

        let formatted = usage.format();
        assert!(formatted.contains("1.0 MB"));
        assert!(formatted.contains("2.0 MB"));
        assert!(formatted.contains("8192.0 MB"));

        assert_relative_eq!(usage.current_ratio(), 1.0 / 8192.0, epsilon = 1e-6);
        assert_relative_eq!(usage.peak_ratio(), 2.0 / 8192.0, epsilon = 1e-6);
    }

    #[test]
    fn test_checkpointing_strategy_optimization() {
        let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
            f64,
            scirs2_core::ndarray::Ix1,
        > = gradient_checkpointing::GradientCheckpointer::new(
            gradient_checkpointing::CheckpointStrategy::Uniform { interval: 4 },
        );

        // Set max depth to enable checkpointing
        checkpointer.set_max_depth(10);

        // Add some memory usage first to trigger optimization
        let checkpoint = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        checkpointer.store_checkpoint(0, checkpoint);

        // Simulate high memory usage - should reduce the interval
        checkpointer.optimize_strategy(0.3); // Target 30% usage

        // Check that the strategy was adapted (should checkpoint more frequently)
        // With interval 4, it would checkpoint at 0, 4, 8... but optimization may change this
        assert!(
            checkpointer.should_checkpoint(0)
                || checkpointer.should_checkpoint(1)
                || checkpointer.should_checkpoint(2)
        );
    }

    #[test]
    fn test_checkpointing_disabled() {
        let mut checkpointer: gradient_checkpointing::GradientCheckpointer<
            f64,
            scirs2_core::ndarray::Ix1,
        > = gradient_checkpointing::GradientCheckpointer::new(
            gradient_checkpointing::CheckpointStrategy::Uniform { interval: 1 },
        );
        checkpointer.set_enabled(false);

        // Should not checkpoint when disabled
        assert!(!checkpointer.should_checkpoint(0));
        assert!(!checkpointer.should_checkpoint(1));
        assert!(!checkpointer.should_checkpoint(2));
    }
}
1584}