optirs_core/gradient_processing/
mod.rs

//! Gradient processing utilities for machine learning optimization.
//!
//! This module provides gradient manipulation utilities, including value and
//! norm clipping, gradient centralization, accumulation, adaptive clipping,
//! noise injection, and masking.

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Rng};
use std::fmt::Debug;

use crate::error::{OptimError, Result};

/// Gradient clipping configuration
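///
/// # Example
///
/// A minimal sketch; the `optirs_core::gradient_processing` import path is an
/// assumption about how this module is exported.
///
/// ```
/// use optirs_core::gradient_processing::GradientClipConfig;
///
/// let config = GradientClipConfig::<f64> {
///     min_value: Some(-1.0),
///     max_value: Some(1.0),
///     maxnorm: Some(10.0),
///     ..Default::default()
/// };
/// assert!(!config.centralization);
/// ```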
#[derive(Debug, Clone)]
pub struct GradientClipConfig<A: Float> {
    /// Maximum allowed value for individual gradient elements
    pub max_value: Option<A>,
    /// Minimum allowed value for individual gradient elements
    pub min_value: Option<A>,
    /// Maximum allowed L2 norm for the entire gradient vector
    pub maxnorm: Option<A>,
    /// Maximum allowed L1 norm
    pub max_l1norm: Option<A>,
    /// Whether to apply gradient centralization
    pub centralization: bool,
    /// Threshold for zeroing small gradients
    pub zero_threshold: Option<A>,
}

impl<A: Float + Send + Sync> Default for GradientClipConfig<A> {
    fn default() -> Self {
        Self {
            max_value: None,
            min_value: None,
            maxnorm: None,
            max_l1norm: None,
            centralization: false,
            zero_threshold: None,
        }
    }
}

/// Gradient processor that applies the configured clipping and post-processing steps
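///
/// # Example
///
/// A usage sketch (assuming this module is available as
/// `optirs_core::gradient_processing`):
///
/// ```
/// use optirs_core::gradient_processing::GradientProcessor;
/// use scirs2_core::ndarray::Array1;
///
/// let mut processor = GradientProcessor::new();
/// processor.set_value_clip(-1.0, 1.0).set_norm_clip(5.0);
///
/// let mut gradients = Array1::from_vec(vec![-3.0, 0.5, 2.0]);
/// processor.process(&mut gradients).unwrap();
/// assert_eq!(gradients.as_slice().unwrap(), &[-1.0, 0.5, 1.0]);
/// ```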
pub struct GradientProcessor<A: Float> {
    config: GradientClipConfig<A>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> Default for GradientProcessor<A> {
    fn default() -> Self {
        Self {
            config: GradientClipConfig::default(),
        }
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> GradientProcessor<A> {
    /// Create a new gradient processor with default configuration
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a new gradient processor with a specific configuration
    pub fn with_config(config: GradientClipConfig<A>) -> Self {
        Self { config }
    }

    /// Set max value clipping
    pub fn set_max_value(&mut self, value: A) -> &mut Self {
        self.config.max_value = Some(value);
        self
    }

    /// Set min value clipping
    pub fn set_min_value(&mut self, value: A) -> &mut Self {
        self.config.min_value = Some(value);
        self
    }

    /// Set max L2 norm clipping
    pub fn set_max_norm(&mut self, value: A) -> &mut Self {
        self.config.maxnorm = Some(value);
        self
    }

    /// Set max L1 norm clipping
    pub fn set_max_l1_norm(&mut self, value: A) -> &mut Self {
        self.config.max_l1norm = Some(value);
        self
    }

    /// Enable or disable gradient centralization
    pub fn set_centralization(&mut self, enabled: bool) -> &mut Self {
        self.config.centralization = enabled;
        self
    }

    /// Set threshold for zeroing small gradients
    pub fn set_zero_threshold(&mut self, value: A) -> &mut Self {
        self.config.zero_threshold = Some(value);
        self
    }

    /// Set value clipping range
    pub fn set_value_clip(&mut self, min: A, max: A) -> &mut Self {
        self.config.min_value = Some(min);
        self.config.max_value = Some(max);
        self
    }

    /// Set L2 norm clipping
    pub fn set_norm_clip(&mut self, maxnorm: A) -> &mut Self {
        self.config.maxnorm = Some(maxnorm);
        self
    }

    /// Set L1 norm clipping
    pub fn set_l1_norm_clip(&mut self, max_l1norm: A) -> &mut Self {
        self.config.max_l1norm = Some(max_l1norm);
        self
    }

    /// Enable gradient centralization
    pub fn enable_centralization(&mut self) -> &mut Self {
        self.config.centralization = true;
        self
    }

    /// Process gradients according to the configuration
    pub fn process<D: Dimension>(&self, gradients: &mut Array<A, D>) -> Result<()> {
        // Apply element-wise value clipping if configured
        // (min and max bounds may be set independently)
        if self.config.min_value.is_some() || self.config.max_value.is_some() {
            let min = self.config.min_value.unwrap_or_else(A::neg_infinity);
            let max = self.config.max_value.unwrap_or_else(A::infinity);
            clip_gradients_by_value(gradients, min, max);
        }

        // Apply L2 norm clipping if configured
        if let Some(maxnorm) = self.config.maxnorm {
            clip_gradient_norm(gradients, maxnorm)?;
        }

        // Apply L1 norm clipping if configured
        if let Some(max_l1norm) = self.config.max_l1norm {
            clip_gradient_l1_norm(gradients, max_l1norm)?;
        }

        // Apply gradient centralization if enabled
        if self.config.centralization {
            gradient_centralization(gradients);
        }

        // Zero small gradients if threshold is set
        if let Some(threshold) = self.config.zero_threshold {
            zero_small_gradients(gradients, threshold);
        }

        Ok(())
    }
}

/// Clip gradient values to a specified range
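///
/// # Example
///
/// A small sketch (the `optirs_core::gradient_processing` import path is
/// assumed):
///
/// ```
/// use optirs_core::gradient_processing::clip_gradients_by_value;
/// use scirs2_core::ndarray::Array1;
///
/// let mut gradients = Array1::from_vec(vec![-3.0, 0.5, 7.0]);
/// clip_gradients_by_value(&mut gradients, -1.0, 1.0);
/// assert_eq!(gradients.as_slice().unwrap(), &[-1.0, 0.5, 1.0]);
/// ```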
#[allow(dead_code)]
pub fn clip_gradients_by_value<A, D>(
    gradients: &mut Array<A, D>,
    min_value: A,
    max_value: A,
) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    gradients.mapv_inplace(|x| {
        if x < min_value {
            min_value
        } else if x > max_value {
            max_value
        } else {
            x
        }
    });
    gradients
}

/// Clip gradient L2 norm (global gradient clipping)
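///
/// If the gradient's L2 norm exceeds `maxnorm`, all elements are scaled by
/// `maxnorm / norm`, preserving the gradient direction.
///
/// # Example
///
/// Sketch under the assumption that the module is exported as
/// `optirs_core::gradient_processing`:
///
/// ```
/// use optirs_core::gradient_processing::clip_gradient_norm;
/// use scirs2_core::ndarray::Array1;
///
/// let mut gradients = Array1::from_vec(vec![3.0, 4.0]); // L2 norm = 5
/// clip_gradient_norm(&mut gradients, 2.5).unwrap();     // scaled by 2.5 / 5
/// assert!((gradients[0] - 1.5).abs() < 1e-12);
/// assert!((gradients[1] - 2.0).abs() < 1e-12);
/// ```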
#[allow(dead_code)]
pub fn clip_gradient_norm<A, D>(gradients: &mut Array<A, D>, maxnorm: A) -> Result<&mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if maxnorm <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "maxnorm must be positive".to_string(),
        ));
    }

    // Calculate the current L2 norm
    let norm = gradients
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    // If the norm exceeds maxnorm, scale gradients down
    if norm > maxnorm {
        let scale = maxnorm / norm;
        gradients.mapv_inplace(|x| x * scale);
    }

    Ok(gradients)
}

/// Clip gradient L1 norm
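///
/// # Example
///
/// A minimal sketch (module path assumed to be
/// `optirs_core::gradient_processing`):
///
/// ```
/// use optirs_core::gradient_processing::clip_gradient_l1_norm;
/// use scirs2_core::ndarray::Array1;
///
/// let mut gradients = Array1::from_vec(vec![1.0, -2.0, 3.0]); // L1 norm = 6
/// clip_gradient_l1_norm(&mut gradients, 3.0).unwrap();        // scaled by 0.5
/// assert!((gradients[0] - 0.5).abs() < 1e-12);
/// assert!((gradients[1] + 1.0).abs() < 1e-12);
/// assert!((gradients[2] - 1.5).abs() < 1e-12);
/// ```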
#[allow(dead_code)]
pub fn clip_gradient_l1_norm<A, D>(
    gradients: &mut Array<A, D>,
    max_l1norm: A,
) -> Result<&mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if max_l1norm <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_l1norm must be positive".to_string(),
        ));
    }

    // Calculate the current L1 norm
    let l1_norm = gradients.iter().fold(A::zero(), |acc, &x| acc + x.abs());

    // If the norm exceeds max_l1norm, scale gradients down
    if l1_norm > max_l1norm {
        let scale = max_l1norm / l1_norm;
        gradients.mapv_inplace(|x| x * scale);
    }

    Ok(gradients)
}

/// Apply gradient centralization by subtracting the mean of all gradient elements
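///
/// # Example
///
/// Sketch, assuming the `optirs_core::gradient_processing` path:
///
/// ```
/// use optirs_core::gradient_processing::gradient_centralization;
/// use scirs2_core::ndarray::Array1;
///
/// let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]); // mean = 2
/// gradient_centralization(&mut gradients);
/// assert_eq!(gradients.as_slice().unwrap(), &[-1.0, 0.0, 1.0]);
/// ```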
#[allow(dead_code)]
pub fn gradient_centralization<A, D>(gradients: &mut Array<A, D>) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    // Calculate the mean
    let sum = gradients.iter().fold(A::zero(), |acc, &x| acc + x);
    let mean = sum / A::from(gradients.len()).unwrap_or(A::one());

    // Subtract the mean from each element
    gradients.mapv_inplace(|x| x - mean);

    gradients
}

/// Zero out small gradient values
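///
/// # Example
///
/// A small sketch (import path assumed):
///
/// ```
/// use optirs_core::gradient_processing::zero_small_gradients;
/// use scirs2_core::ndarray::Array1;
///
/// let mut gradients = Array1::from_vec(vec![0.05, -0.2, 0.01]);
/// zero_small_gradients(&mut gradients, 0.1);
/// assert_eq!(gradients.as_slice().unwrap(), &[0.0, -0.2, 0.0]);
/// ```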
#[allow(dead_code)]
pub fn zero_small_gradients<A, D>(gradients: &mut Array<A, D>, threshold: A) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    let abs_threshold = threshold.abs();

    gradients.mapv_inplace(|x| {
        if x.abs() < abs_threshold {
            A::zero()
        } else {
            x
        }
    });

    gradients
}

/// Gradient accumulation utility
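///
/// Gradients from several micro-batches are summed until `accumulation_steps`
/// batches have been seen; they can then be averaged (or left as a sum) and
/// used for a single optimizer step.
///
/// # Example
///
/// Sketch assuming the `optirs_core::gradient_processing` module path:
///
/// ```
/// use optirs_core::gradient_processing::GradientAccumulator;
/// use scirs2_core::ndarray::Array1;
///
/// // Average gradients over two micro-batches.
/// let mut accumulator = GradientAccumulator::new(2, true);
/// accumulator.accumulate(&Array1::from_vec(vec![1.0, 2.0]));
/// let ready = accumulator.accumulate(&Array1::from_vec(vec![3.0, 4.0]));
/// assert!(ready);
///
/// let averaged = accumulator.get_and_reset().unwrap();
/// assert_eq!(averaged.as_slice().unwrap(), &[2.0, 3.0]);
/// ```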
#[derive(Debug, Clone)]
pub struct GradientAccumulator<A: Float, D: Dimension> {
    /// Accumulated gradients
    accumulated_gradients: Option<Array<A, D>>,
    /// Number of accumulated micro-batches
    num_accumulated: usize,
    /// Target number of micro-batches before step
    accumulation_steps: usize,
    /// Whether to average gradients (vs. sum)
    average_gradients: bool,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientAccumulator<A, D> {
    /// Create a new gradient accumulator
    ///
    /// # Arguments
    ///
    /// * `accumulation_steps` - Number of micro-batches to accumulate before stepping
    /// * `average_gradients` - Whether to average gradients (true) or sum them (false)
    pub fn new(accumulation_steps: usize, average_gradients: bool) -> Self {
        Self {
            accumulated_gradients: None,
            num_accumulated: 0,
            accumulation_steps,
            average_gradients,
        }
    }

    /// Add gradients from a micro-batch
    ///
    /// # Arguments
    ///
    /// * `gradients` - Gradients from the current micro-batch
    ///
    /// # Returns
    ///
    /// `true` if enough gradients have been accumulated and it's time to step
    pub fn accumulate(&mut self, gradients: &Array<A, D>) -> bool {
        if self.accumulated_gradients.is_none() {
            self.accumulated_gradients = Some(gradients.clone());
        } else {
            let acc = self.accumulated_gradients.as_mut().unwrap();
            for (acc_val, &grad_val) in acc.iter_mut().zip(gradients.iter()) {
                *acc_val = *acc_val + grad_val;
            }
        }

        self.num_accumulated += 1;
        self.num_accumulated >= self.accumulation_steps
    }

    /// Get the accumulated gradients and reset the accumulator
    ///
    /// # Returns
    ///
    /// The accumulated gradients, ready for the optimization step
    pub fn get_and_reset(&mut self) -> Option<Array<A, D>> {
        if let Some(mut gradients) = self.accumulated_gradients.take() {
            if self.average_gradients && self.num_accumulated > 0 {
                let scale = A::one() / A::from(self.num_accumulated).unwrap_or(A::one());
                gradients.mapv_inplace(|x| x * scale);
            }
            self.num_accumulated = 0;
            Some(gradients)
        } else {
            None
        }
    }

    /// Get current accumulation progress as `(accumulated, target)`
    pub fn progress(&self) -> (usize, usize) {
        (self.num_accumulated, self.accumulation_steps)
    }

    /// Check if ready for an optimization step
    pub fn is_ready(&self) -> bool {
        self.num_accumulated >= self.accumulation_steps
    }

    /// Reset the accumulator
    pub fn reset(&mut self) {
        self.accumulated_gradients = None;
        self.num_accumulated = 0;
    }

    /// Change the number of accumulation steps
    pub fn set_accumulation_steps(&mut self, steps: usize) {
        self.accumulation_steps = steps;
    }
}

/// Adaptive gradient clipping
///
/// Clips gradients based on the ratio of gradient norm to parameter norm.
/// This is particularly useful for transformer models.
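///
/// # Example
///
/// A sketch (assuming the `optirs_core::gradient_processing` module path):
///
/// ```
/// use optirs_core::gradient_processing::adaptive_gradient_clipping;
/// use scirs2_core::ndarray::Array1;
///
/// let mut gradients = Array1::from_vec(vec![3.0, 4.0]); // norm = 5
/// let parameters = Array1::from_vec(vec![2.0, 0.0]);    // norm = 2
///
/// // ratio = 5 / 2 = 2.5 > 1.0, so the gradients are rescaled to ratio 1.0
/// adaptive_gradient_clipping(&mut gradients, &parameters, 1.0).unwrap();
/// let norm = gradients.iter().fold(0.0, |acc, &x| acc + x * x).sqrt();
/// assert!((norm - 2.0).abs() < 1e-9);
/// ```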
#[allow(dead_code)]
pub fn adaptive_gradient_clipping<'a, A, D>(
    gradients: &'a mut Array<A, D>,
    parameters: &Array<A, D>,
    max_ratio: A,
) -> Result<&'a mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if max_ratio <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_ratio must be positive".to_string(),
        ));
    }

    let grad_norm = gradients
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    let param_norm = parameters
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    if param_norm > A::zero() && grad_norm > A::zero() {
        let ratio = grad_norm / param_norm;
        if ratio > max_ratio {
            let scale = max_ratio / ratio;
            gradients.mapv_inplace(|x| x * scale);
        }
    }

    Ok(gradients)
}

/// Add Gaussian noise to gradients for regularization
///
/// # Arguments
///
/// * `gradients` - Gradients to add noise to
/// * `noise_std` - Standard deviation of the Gaussian noise to add
/// * `seed` - Optional seed for reproducible results (currently unused; noise is
///   drawn from a thread-local RNG)
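///
/// # Example
///
/// A minimal sketch (module path assumed):
///
/// ```
/// use optirs_core::gradient_processing::add_gradient_noise;
/// use scirs2_core::ndarray::Array1;
///
/// let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
/// add_gradient_noise(&mut gradients, 0.01, None);
/// // Each element is perturbed by zero-mean Gaussian noise with std 0.01.
/// assert!((gradients[0] - 1.0).abs() < 1.0);
/// ```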
#[allow(dead_code)]
pub fn add_gradient_noise<A, D>(
    gradients: &mut Array<A, D>,
    noise_std: A,
    _seed: Option<u64>,
) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    use scirs2_core::random::RandNormal;

    if noise_std <= A::zero() {
        return gradients;
    }

    let mut rng = thread_rng();

    // Create the noise array manually to avoid trait compatibility issues
    let shape = gradients.raw_dim();
    let mut noise = Array::zeros(shape);
    let normal = RandNormal::new(0.0, noise_std.to_f64().unwrap_or(0.01)).unwrap();

    for elem in noise.iter_mut() {
        *elem = A::from(rng.sample(normal)).unwrap_or(A::zero());
    }

    gradients.zip_mut_with(&noise, |g, &n| {
        *g = *g + n;
    });

    gradients
}

/// Gradient masking and freezing utilities
///
/// Allows selective gradient updates by masking certain parameters
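///
/// # Example
///
/// Sketch, assuming this module is exported as
/// `optirs_core::gradient_processing`:
///
/// ```
/// use optirs_core::gradient_processing::GradientMask;
/// use scirs2_core::ndarray::Array1;
///
/// // Freeze the middle parameter; its gradient is zeroed out.
/// let mask = Array1::from_vec(vec![true, false, true]);
/// let grad_mask: GradientMask<f64, _> = GradientMask::new(mask);
///
/// let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
/// grad_mask.apply_mask(&mut gradients);
/// assert_eq!(gradients.as_slice().unwrap(), &[1.0, 0.0, 3.0]);
/// ```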
#[derive(Debug, Clone)]
pub struct GradientMask<A: Float, D: Dimension> {
    /// Mask indicating which parameters to update (true = update, false = freeze)
    mask: Array<bool, D>,
    /// Optional learning rate multipliers for each parameter
    lr_multipliers: Option<Array<A, D>>,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientMask<A, D> {
    /// Create a new gradient mask
    ///
    /// # Arguments
    ///
    /// * `mask` - Boolean mask indicating which parameters to update
    pub fn new(mask: Array<bool, D>) -> Self {
        Self {
            mask,
            lr_multipliers: None,
        }
    }

    /// Create a mask that freezes all parameters
    pub fn freeze_all(shape: D) -> Self {
        Self {
            mask: Array::from_elem(shape, false),
            lr_multipliers: None,
        }
    }

    /// Create a mask that updates all parameters
    pub fn update_all(shape: D) -> Self {
        Self {
            mask: Array::from_elem(shape, true),
            lr_multipliers: None,
        }
    }

    /// Set learning rate multipliers for different parameters
    pub fn with_lr_multipliers(mut self, multipliers: Array<A, D>) -> Self {
        self.lr_multipliers = Some(multipliers);
        self
    }

    /// Apply the mask to gradients
    ///
    /// # Arguments
    ///
    /// * `gradients` - Gradients to mask
    ///
    /// # Returns
    ///
    /// Masked gradients where frozen parameters have zero gradients
    pub fn apply_mask<'a>(&self, gradients: &'a mut Array<A, D>) -> &'a mut Array<A, D> {
        gradients.zip_mut_with(&self.mask, |grad, &should_update| {
            if !should_update {
                *grad = A::zero();
            }
        });

        // Apply learning rate multipliers if present
        if let Some(multipliers) = &self.lr_multipliers {
            gradients.zip_mut_with(multipliers, |grad, &mult| {
                *grad = *grad * mult;
            });
        }

        gradients
    }

    /// Freeze specific parameters by indices
    pub fn freeze_indices(&mut self, indices: &[usize]) -> Result<()> {
        let flat_mask = self.mask.as_slice_mut().ok_or_else(|| {
            OptimError::InvalidConfig("Cannot access mask as flat slice".to_string())
        })?;

        for &idx in indices {
            if idx < flat_mask.len() {
                flat_mask[idx] = false;
            } else {
                return Err(OptimError::InvalidConfig(format!(
                    "Index {} out of bounds for mask of size {}",
                    idx,
                    flat_mask.len()
                )));
            }
        }
        Ok(())
    }

    /// Unfreeze specific parameters by indices
    pub fn unfreeze_indices(&mut self, indices: &[usize]) -> Result<()> {
        let flat_mask = self.mask.as_slice_mut().ok_or_else(|| {
            OptimError::InvalidConfig("Cannot access mask as flat slice".to_string())
        })?;

        for &idx in indices {
            if idx < flat_mask.len() {
                flat_mask[idx] = true;
            } else {
                return Err(OptimError::InvalidConfig(format!(
                    "Index {} out of bounds for mask of size {}",
                    idx,
                    flat_mask.len()
                )));
            }
        }
        Ok(())
    }

    /// Get the number of frozen parameters
    pub fn num_frozen(&self) -> usize {
        self.mask.iter().filter(|&&x| !x).count()
    }

    /// Get the number of active (unfrozen) parameters
    pub fn num_active(&self) -> usize {
        self.mask.iter().filter(|&&x| x).count()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_gradient_processor() {
        let config = GradientClipConfig::<f64> {
            max_value: Some(5.0),
            min_value: Some(-5.0),
            maxnorm: Some(10.0),
            ..Default::default()
        };

        let processor = GradientProcessor::with_config(config);

        let mut gradients = Array1::from_vec(vec![-8.0, 3.0, 7.0, -2.0, 6.0]);
        processor.process(&mut gradients).unwrap();

        // Check value clipping
        assert_eq!(gradients[0], -5.0);
        assert_eq!(gradients[2], 5.0);
        assert_eq!(gradients[4], 5.0);
    }

    #[test]
    fn test_adaptive_clipping() {
        let mut gradients = Array1::from_vec(vec![3.0, 4.0]); // norm = 5
        let parameters = Array1::from_vec(vec![1.0, 0.0]); // norm = 1

        // Gradient/parameter ratio = 5/1 = 5, max_ratio = 2
        adaptive_gradient_clipping(&mut gradients, &parameters, 2.0).unwrap();

        // After clipping, ratio should be 2
        let new_grad_norm = gradients.iter().fold(0.0, |acc, &x| acc + x * x).sqrt();
        assert!((new_grad_norm - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_gradient_accumulator() {
        let mut accumulator = GradientAccumulator::new(3, true);

        // First micro-batch
        let grad1 = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(!accumulator.accumulate(&grad1));
        assert_eq!(accumulator.progress(), (1, 3));

        // Second micro-batch
        let grad2 = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        assert!(!accumulator.accumulate(&grad2));
        assert_eq!(accumulator.progress(), (2, 3));

        // Third micro-batch - should trigger ready
        let grad3 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        assert!(accumulator.accumulate(&grad3));
        assert!(accumulator.is_ready());

        // Get accumulated gradients (should be averaged)
        let final_grads = accumulator.get_and_reset().unwrap();
        assert_relative_eq!(final_grads[0], 2.0, epsilon = 1e-6); // (1+2+3)/3
        assert_relative_eq!(final_grads[1], 3.0, epsilon = 1e-6); // (2+3+4)/3
        assert_relative_eq!(final_grads[2], 4.0, epsilon = 1e-6); // (3+4+5)/3

        // Should be reset now
        assert_eq!(accumulator.progress(), (0, 3));
        assert!(!accumulator.is_ready());
    }

    #[test]
    fn test_gradient_accumulator_sum_mode() {
        let mut accumulator = GradientAccumulator::new(2, false); // sum mode

        let grad1 = Array1::from_vec(vec![1.0, 2.0]);
        let grad2 = Array1::from_vec(vec![3.0, 4.0]);

        accumulator.accumulate(&grad1);
        accumulator.accumulate(&grad2);

        let final_grads = accumulator.get_and_reset().unwrap();
        assert_relative_eq!(final_grads[0], 4.0, epsilon = 1e-6); // 1+3
        assert_relative_eq!(final_grads[1], 6.0, epsilon = 1e-6); // 2+4
    }

    #[test]
    fn test_gradient_noise() {
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let original = gradients.clone();

        // Add small Gaussian noise (the seed argument is currently unused)
        add_gradient_noise(&mut gradients, 0.1, Some(42));

        // Gradients should be different but close to original
        for (i, (&orig, &noisy)) in original.iter().zip(gradients.iter()).enumerate() {
            assert!(
                (orig - noisy).abs() < 1.0,
                "Index {}: {} vs {}",
                i,
                orig,
                noisy
            );
        }
    }

    #[test]
    fn test_gradient_noise_zero_std() {
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let original = gradients.clone();

        // Zero noise should leave gradients unchanged
        add_gradient_noise(&mut gradients, 0.0, Some(42));

        for (orig, noisy) in original.iter().zip(gradients.iter()) {
            assert_relative_eq!(*orig, *noisy, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_gradient_mask_creation() {
        let mask = Array1::from_vec(vec![true, false, true]);
        let grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> = GradientMask::new(mask);

        assert_eq!(grad_mask.num_active(), 2);
        assert_eq!(grad_mask.num_frozen(), 1);
    }

    #[test]
    fn test_gradient_mask_apply() {
        let mask = Array1::from_vec(vec![true, false, true]);
        let grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> = GradientMask::new(mask);
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        grad_mask.apply_mask(&mut gradients);

        assert_eq!(gradients.as_slice().unwrap(), &[1.0, 0.0, 3.0]);
    }

    #[test]
    fn test_gradient_mask_freeze_unfreeze() {
        let mask = Array1::from_vec(vec![true, true, true]);
        let mut grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> = GradientMask::new(mask);

        // Freeze some indices
        grad_mask.freeze_indices(&[0, 2]).unwrap();
        assert_eq!(grad_mask.num_frozen(), 2);
        assert_eq!(grad_mask.num_active(), 1);

        // Unfreeze one index
        grad_mask.unfreeze_indices(&[0]).unwrap();
        assert_eq!(grad_mask.num_frozen(), 1);
        assert_eq!(grad_mask.num_active(), 2);
    }

    #[test]
    fn test_gradient_mask_with_lr_multipliers() {
        let mask = Array1::from_vec(vec![true, true, true]);
        let multipliers = Array1::from_vec(vec![1.0, 0.5, 2.0]);
        let grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> =
            GradientMask::new(mask).with_lr_multipliers(multipliers);
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        grad_mask.apply_mask(&mut gradients);

        assert_relative_eq!(gradients[0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[1], 1.0, epsilon = 1e-6); // 2.0 * 0.5
        assert_relative_eq!(gradients[2], 6.0, epsilon = 1e-6); // 3.0 * 2.0
    }

    #[test]
    fn test_gradient_mask_freeze_all() {
        let grad_mask = GradientMask::<f64, scirs2_core::ndarray::Ix1>::freeze_all(
            scirs2_core::ndarray::Ix1(3),
        );
        assert_eq!(grad_mask.num_frozen(), 3);
        assert_eq!(grad_mask.num_active(), 0);
    }

    #[test]
    fn test_gradient_mask_update_all() {
        let grad_mask = GradientMask::<f64, scirs2_core::ndarray::Ix1>::update_all(
            scirs2_core::ndarray::Ix1(3),
        );
        assert_eq!(grad_mask.num_frozen(), 0);
        assert_eq!(grad_mask.num_active(), 3);
    }
}