// optirs_core/gradient_processing/mod.rs

// Gradient processing utilities for machine learning optimization
//
// This module provides comprehensive gradient manipulation utilities including
// various clipping strategies, normalization, and other processing techniques.

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Rng};
use std::fmt::Debug;

use crate::error::{OptimError, Result};
/// Gradient clipping configuration
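///
/// # Example
///
/// A minimal sketch of building a config with struct-update syntax, assuming
/// this module is reachable as `optirs_core::gradient_processing`:
///
/// ```
/// use optirs_core::gradient_processing::GradientClipConfig;
///
/// let config = GradientClipConfig::<f64> {
///     max_value: Some(1.0),
///     min_value: Some(-1.0),
///     maxnorm: Some(10.0),
///     ..Default::default()
/// };
/// assert!(config.zero_threshold.is_none());
/// ```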
#[derive(Debug, Clone)]
pub struct GradientClipConfig<A: Float> {
    /// Maximum allowed value for individual gradient elements
    pub max_value: Option<A>,
    /// Minimum allowed value for individual gradient elements
    pub min_value: Option<A>,
    /// Maximum allowed L2 norm for the entire gradient vector
    pub maxnorm: Option<A>,
    /// Maximum allowed L1 norm
    pub max_l1norm: Option<A>,
    /// Whether to apply gradient centralization
    pub centralization: bool,
    /// Threshold for zeroing small gradients
    pub zero_threshold: Option<A>,
}

impl<A: Float + Send + Sync> Default for GradientClipConfig<A> {
    fn default() -> Self {
        Self {
            max_value: None,
            min_value: None,
            maxnorm: None,
            max_l1norm: None,
            centralization: false,
            zero_threshold: None,
        }
    }
}

/// Gradient clipping processor
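///
/// # Example
///
/// A minimal usage sketch, assuming this module is reachable as
/// `optirs_core::gradient_processing`:
///
/// ```
/// use optirs_core::gradient_processing::GradientProcessor;
/// use scirs2_core::ndarray::Array1;
///
/// let mut processor = GradientProcessor::<f64>::new();
/// processor.set_value_clip(-1.0, 1.0).set_norm_clip(5.0);
///
/// let mut gradients = Array1::from_vec(vec![2.0, -3.0, 0.5]);
/// processor.process(&mut gradients).unwrap();
/// assert_eq!(gradients.as_slice().unwrap(), &[1.0, -1.0, 0.5]);
/// ```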
pub struct GradientProcessor<A: Float> {
    config: GradientClipConfig<A>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> Default for GradientProcessor<A> {
    fn default() -> Self {
        Self {
            config: GradientClipConfig::default(),
        }
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> GradientProcessor<A> {
    /// Create a new gradient processor with default configuration
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a new gradient processor with a specific configuration
    pub fn with_config(config: GradientClipConfig<A>) -> Self {
        Self { config }
    }

    /// Set max value clipping
    pub fn set_max_value(&mut self, value: A) -> &mut Self {
        self.config.max_value = Some(value);
        self
    }

    /// Set min value clipping
    pub fn set_min_value(&mut self, value: A) -> &mut Self {
        self.config.min_value = Some(value);
        self
    }

    /// Set max L2 norm clipping
    pub fn set_max_norm(&mut self, value: A) -> &mut Self {
        self.config.maxnorm = Some(value);
        self
    }

    /// Set max L1 norm clipping
    pub fn set_max_l1_norm(&mut self, value: A) -> &mut Self {
        self.config.max_l1norm = Some(value);
        self
    }

    /// Enable or disable gradient centralization
    pub fn set_centralization(&mut self, enabled: bool) -> &mut Self {
        self.config.centralization = enabled;
        self
    }

    /// Set threshold for zeroing small gradients
    pub fn set_zero_threshold(&mut self, value: A) -> &mut Self {
        self.config.zero_threshold = Some(value);
        self
    }

    /// Set value clipping range
    pub fn set_value_clip(&mut self, min: A, max: A) -> &mut Self {
        self.config.min_value = Some(min);
        self.config.max_value = Some(max);
        self
    }

    /// Set norm clipping (alias for `set_max_norm`)
    pub fn set_norm_clip(&mut self, maxnorm: A) -> &mut Self {
        self.config.maxnorm = Some(maxnorm);
        self
    }

    /// Set L1 norm clipping (alias for `set_max_l1_norm`)
    pub fn set_l1_norm_clip(&mut self, max_l1norm: A) -> &mut Self {
        self.config.max_l1norm = Some(max_l1norm);
        self
    }

    /// Enable gradient centralization
    pub fn enable_centralization(&mut self) -> &mut Self {
        self.config.centralization = true;
        self
    }

    /// Process gradients according to configuration
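    ///
    /// Operations are applied in a fixed order: value clipping, L2 norm
    /// clipping, L1 norm clipping, gradient centralization, and finally
    /// zeroing of small gradients.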
    pub fn process<D: Dimension>(&self, gradients: &mut Array<A, D>) -> Result<()> {
        // Apply value clipping if configured; min and max may be set independently
        if self.config.min_value.is_some() || self.config.max_value.is_some() {
            let min = self.config.min_value.unwrap_or_else(A::neg_infinity);
            let max = self.config.max_value.unwrap_or_else(A::infinity);
            clip_gradients_by_value(gradients, min, max);
        }

        // Apply L2 norm clipping if configured
        if let Some(maxnorm) = self.config.maxnorm {
            clip_gradient_norm(gradients, maxnorm)?;
        }

        // Apply L1 norm clipping if configured
        if let Some(max_l1norm) = self.config.max_l1norm {
            clip_gradient_l1_norm(gradients, max_l1norm)?;
        }

        // Apply gradient centralization if enabled
        if self.config.centralization {
            gradient_centralization(gradients);
        }

        // Zero small gradients if threshold is set
        if let Some(threshold) = self.config.zero_threshold {
            zero_small_gradients(gradients, threshold);
        }

        Ok(())
    }
}

/// Clip gradient values to a specified range
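///
/// # Example
///
/// A small sketch, assuming the module path `optirs_core::gradient_processing`:
///
/// ```
/// use optirs_core::gradient_processing::clip_gradients_by_value;
/// use scirs2_core::ndarray::Array1;
///
/// let mut grads = Array1::from_vec(vec![-3.0, 0.5, 2.0]);
/// clip_gradients_by_value(&mut grads, -1.0, 1.0);
/// assert_eq!(grads.as_slice().unwrap(), &[-1.0, 0.5, 1.0]);
/// ```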
#[allow(dead_code)]
pub fn clip_gradients_by_value<A, D>(
    gradients: &mut Array<A, D>,
    min_value: A,
    max_value: A,
) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    gradients.mapv_inplace(|x| {
        if x < min_value {
            min_value
        } else if x > max_value {
            max_value
        } else {
            x
        }
    });
    gradients
}

/// Clip gradient L2 norm (global gradient clipping)
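///
/// If the L2 norm of `gradients` exceeds `maxnorm`, every element is scaled
/// by `maxnorm / norm`, so the clipped gradient has L2 norm exactly `maxnorm`.
///
/// # Example
///
/// A small sketch, assuming the module path `optirs_core::gradient_processing`:
///
/// ```
/// use optirs_core::gradient_processing::clip_gradient_norm;
/// use scirs2_core::ndarray::Array1;
///
/// let mut grads = Array1::from_vec(vec![3.0, 4.0]); // L2 norm = 5
/// clip_gradient_norm(&mut grads, 1.0).unwrap();
/// let norm: f64 = grads.iter().map(|&x| x * x).sum::<f64>().sqrt();
/// assert!((norm - 1.0).abs() < 1e-12);
/// ```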
#[allow(dead_code)]
pub fn clip_gradient_norm<A, D>(gradients: &mut Array<A, D>, maxnorm: A) -> Result<&mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if maxnorm <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "maxnorm must be positive".to_string(),
        ));
    }

    // Calculate the current L2 norm
    let norm = gradients
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    // If the norm exceeds maxnorm, scale gradients down proportionally
    if norm > maxnorm {
        let scale = maxnorm / norm;
        gradients.mapv_inplace(|x| x * scale);
    }

    Ok(gradients)
}

/// Clip gradient L1 norm
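///
/// If the L1 norm (the sum of absolute values) exceeds `max_l1norm`, every
/// element is scaled by `max_l1norm / l1_norm`.
///
/// # Example
///
/// A small sketch, assuming the module path `optirs_core::gradient_processing`:
///
/// ```
/// use optirs_core::gradient_processing::clip_gradient_l1_norm;
/// use scirs2_core::ndarray::Array1;
///
/// let mut grads = Array1::from_vec(vec![1.0, -3.0]); // L1 norm = 4
/// clip_gradient_l1_norm(&mut grads, 2.0).unwrap();
/// assert_eq!(grads.as_slice().unwrap(), &[0.5, -1.5]); // L1 norm = 2
/// ```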
#[allow(dead_code)]
pub fn clip_gradient_l1_norm<A, D>(
    gradients: &mut Array<A, D>,
    max_l1norm: A,
) -> Result<&mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if max_l1norm <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_l1norm must be positive".to_string(),
        ));
    }

    // Calculate the current L1 norm (sum of absolute values)
    let l1_norm = gradients.iter().fold(A::zero(), |acc, &x| acc + x.abs());

    // If the norm exceeds max_l1norm, scale gradients down proportionally
    if l1_norm > max_l1norm {
        let scale = max_l1norm / l1_norm;
        gradients.mapv_inplace(|x| x * scale);
    }

    Ok(gradients)
}

/// Apply gradient centralization
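///
/// Subtracts the mean of all elements from each element, so the result has
/// zero mean.
///
/// # Example
///
/// A small sketch, assuming the module path `optirs_core::gradient_processing`:
///
/// ```
/// use optirs_core::gradient_processing::gradient_centralization;
/// use scirs2_core::ndarray::Array1;
///
/// let mut grads = Array1::from_vec(vec![1.0, 2.0, 3.0]); // mean = 2
/// gradient_centralization(&mut grads);
/// assert_eq!(grads.as_slice().unwrap(), &[-1.0, 0.0, 1.0]);
/// ```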
#[allow(dead_code)]
pub fn gradient_centralization<A, D>(gradients: &mut Array<A, D>) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    // Nothing to do for an empty array (and this avoids a division by zero below)
    if gradients.is_empty() {
        return gradients;
    }

    // Calculate the mean
    let sum = gradients.iter().fold(A::zero(), |acc, &x| acc + x);
    let mean = sum / A::from(gradients.len()).unwrap_or(A::one());

    // Subtract the mean from each element
    gradients.mapv_inplace(|x| x - mean);

    gradients
}

/// Zero out small gradient values
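///
/// # Example
///
/// A small sketch, assuming the module path `optirs_core::gradient_processing`:
///
/// ```
/// use optirs_core::gradient_processing::zero_small_gradients;
/// use scirs2_core::ndarray::Array1;
///
/// let mut grads = Array1::from_vec(vec![1e-9, 0.5, -1e-9]);
/// zero_small_gradients(&mut grads, 1e-6);
/// assert_eq!(grads.as_slice().unwrap(), &[0.0, 0.5, 0.0]);
/// ```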
#[allow(dead_code)]
pub fn zero_small_gradients<A, D>(gradients: &mut Array<A, D>, threshold: A) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    let abs_threshold = threshold.abs();

    gradients.mapv_inplace(|x| {
        if x.abs() < abs_threshold {
            A::zero()
        } else {
            x
        }
    });

    gradients
}

/// Gradient accumulation utility
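///
/// # Example
///
/// A sketch of a two-step accumulation loop, assuming the module path
/// `optirs_core::gradient_processing`:
///
/// ```
/// use optirs_core::gradient_processing::GradientAccumulator;
/// use scirs2_core::ndarray::Array1;
///
/// let mut acc = GradientAccumulator::new(2, true); // average over 2 micro-batches
/// assert!(!acc.accumulate(&Array1::from_vec(vec![1.0, 3.0])));
/// assert!(acc.accumulate(&Array1::from_vec(vec![3.0, 5.0])));
///
/// let grads = acc.get_and_reset().unwrap();
/// assert_eq!(grads.as_slice().unwrap(), &[2.0, 4.0]); // element-wise mean
/// ```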
#[derive(Debug, Clone)]
pub struct GradientAccumulator<A: Float, D: Dimension> {
    /// Accumulated gradients
    accumulated_gradients: Option<Array<A, D>>,
    /// Number of accumulated micro-batches
    num_accumulated: usize,
    /// Target number of micro-batches before step
    accumulation_steps: usize,
    /// Whether to average gradients (vs. sum them)
    average_gradients: bool,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientAccumulator<A, D> {
    /// Create a new gradient accumulator
    ///
    /// # Arguments
    ///
    /// * `accumulation_steps` - Number of micro-batches to accumulate before stepping
    /// * `average_gradients` - Whether to average gradients (true) or sum them (false)
    pub fn new(accumulation_steps: usize, average_gradients: bool) -> Self {
        Self {
            accumulated_gradients: None,
            num_accumulated: 0,
            accumulation_steps,
            average_gradients,
        }
    }

    /// Add gradients from a micro-batch
    ///
    /// # Arguments
    ///
    /// * `gradients` - Gradients from the current micro-batch
    ///
    /// # Returns
    ///
    /// `true` if enough gradients have been accumulated and it's time to step
    pub fn accumulate(&mut self, gradients: &Array<A, D>) -> bool {
        if let Some(acc) = &mut self.accumulated_gradients {
            // Element-wise addition; panics on shape mismatch rather than
            // silently truncating
            acc.zip_mut_with(gradients, |acc_val, &grad_val| {
                *acc_val = *acc_val + grad_val;
            });
        } else {
            self.accumulated_gradients = Some(gradients.clone());
        }

        self.num_accumulated += 1;
        self.num_accumulated >= self.accumulation_steps
    }

    /// Get the accumulated gradients and reset the accumulator
    ///
    /// # Returns
    ///
    /// The accumulated gradients, ready for the optimization step
    pub fn get_and_reset(&mut self) -> Option<Array<A, D>> {
        if let Some(mut gradients) = self.accumulated_gradients.take() {
            if self.average_gradients && self.num_accumulated > 0 {
                let scale = A::one() / A::from(self.num_accumulated).unwrap_or(A::one());
                gradients.mapv_inplace(|x| x * scale);
            }
            self.num_accumulated = 0;
            Some(gradients)
        } else {
            None
        }
    }

    /// Get current accumulation progress as (accumulated, target)
    pub fn progress(&self) -> (usize, usize) {
        (self.num_accumulated, self.accumulation_steps)
    }

    /// Check if ready for an optimization step
    pub fn is_ready(&self) -> bool {
        self.num_accumulated >= self.accumulation_steps
    }

    /// Reset the accumulator
    pub fn reset(&mut self) {
        self.accumulated_gradients = None;
        self.num_accumulated = 0;
    }

    /// Change the number of accumulation steps
    pub fn set_accumulation_steps(&mut self, steps: usize) {
        self.accumulation_steps = steps;
    }
}

/// Adaptive gradient clipping
///
/// Clips gradients based on the ratio of gradient norm to parameter norm.
/// This is particularly useful for transformer models.
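///
/// If `grad_norm / param_norm > max_ratio`, the gradients are rescaled so the
/// ratio equals `max_ratio`; otherwise they are left unchanged.
///
/// # Example
///
/// A small sketch, assuming the module path `optirs_core::gradient_processing`:
///
/// ```
/// use optirs_core::gradient_processing::adaptive_gradient_clipping;
/// use scirs2_core::ndarray::Array1;
///
/// let mut grads = Array1::from_vec(vec![3.0, 4.0]); // norm = 5
/// let params = Array1::from_vec(vec![1.0, 0.0]); // norm = 1
/// adaptive_gradient_clipping(&mut grads, &params, 2.0).unwrap();
///
/// let norm: f64 = grads.iter().map(|&x| x * x).sum::<f64>().sqrt();
/// assert!((norm - 2.0).abs() < 1e-12);
/// ```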
#[allow(dead_code)]
pub fn adaptive_gradient_clipping<'a, A, D>(
    gradients: &'a mut Array<A, D>,
    parameters: &Array<A, D>,
    max_ratio: A,
) -> Result<&'a mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if max_ratio <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_ratio must be positive".to_string(),
        ));
    }

    let grad_norm = gradients
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    let param_norm = parameters
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    if param_norm > A::zero() && grad_norm > A::zero() {
        let ratio = grad_norm / param_norm;
        if ratio > max_ratio {
            let scale = max_ratio / ratio;
            gradients.mapv_inplace(|x| x * scale);
        }
    }

    Ok(gradients)
}

/// Add noise to gradients for regularization
///
/// # Arguments
///
/// * `gradients` - Gradients to add noise to
/// * `noise_std` - Standard deviation of the Gaussian noise to add
/// * `_seed` - Reserved for reproducible results; currently ignored (a
///   thread-local RNG is always used)
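///
/// # Example
///
/// A small sketch, assuming the module path `optirs_core::gradient_processing`:
///
/// ```
/// use optirs_core::gradient_processing::add_gradient_noise;
/// use scirs2_core::ndarray::Array1;
///
/// let mut grads = Array1::from_vec(vec![1.0, 2.0]);
/// add_gradient_noise(&mut grads, 0.1, None);
/// ```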
#[allow(dead_code)]
pub fn add_gradient_noise<A, D>(
    gradients: &mut Array<A, D>,
    noise_std: A,
    _seed: Option<u64>,
) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    use scirs2_core::random::RandNormal;

    if noise_std <= A::zero() {
        return gradients;
    }

    // NOTE: `_seed` is accepted for API compatibility but not yet wired up;
    // a thread-local RNG is always used.
    let mut rng = thread_rng();

    // Create the noise array manually to avoid trait compatibility issues
    let shape = gradients.raw_dim();
    let mut noise = Array::zeros(shape);
    let normal = RandNormal::new(0.0, noise_std.to_f64().unwrap_or(0.01))
        .expect("failed to construct normal distribution");

    for elem in noise.iter_mut() {
        *elem = A::from(rng.sample(normal)).unwrap_or(A::zero());
    }

    gradients.zip_mut_with(&noise, |g, &n| {
        *g = *g + n;
    });

    gradients
}

/// Gradient masking and freezing utilities
///
/// Allows selective gradient updates by masking certain parameters
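///
/// # Example
///
/// A small sketch, assuming the module path `optirs_core::gradient_processing`:
///
/// ```
/// use optirs_core::gradient_processing::GradientMask;
/// use scirs2_core::ndarray::Array1;
///
/// let mask = Array1::from_vec(vec![true, false, true]);
/// let grad_mask: GradientMask<f64, _> = GradientMask::new(mask);
///
/// let mut grads = Array1::from_vec(vec![1.0, 2.0, 3.0]);
/// grad_mask.apply_mask(&mut grads);
/// assert_eq!(grads.as_slice().unwrap(), &[1.0, 0.0, 3.0]);
/// ```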
#[derive(Debug, Clone)]
pub struct GradientMask<A: Float, D: Dimension> {
    /// Mask indicating which parameters to update (true = update, false = freeze)
    mask: Array<bool, D>,
    /// Optional learning rate multipliers for each parameter
    lr_multipliers: Option<Array<A, D>>,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientMask<A, D> {
    /// Create a new gradient mask
    ///
    /// # Arguments
    ///
    /// * `mask` - Boolean mask indicating which parameters to update
    pub fn new(mask: Array<bool, D>) -> Self {
        Self {
            mask,
            lr_multipliers: None,
        }
    }

    /// Create a mask that freezes all parameters
    pub fn freeze_all(shape: D) -> Self {
        Self {
            mask: Array::from_elem(shape, false),
            lr_multipliers: None,
        }
    }

    /// Create a mask that updates all parameters
    pub fn update_all(shape: D) -> Self {
        Self {
            mask: Array::from_elem(shape, true),
            lr_multipliers: None,
        }
    }

    /// Set learning rate multipliers for different parameters
    pub fn with_lr_multipliers(mut self, multipliers: Array<A, D>) -> Self {
        self.lr_multipliers = Some(multipliers);
        self
    }

    /// Apply the mask to gradients
    ///
    /// # Arguments
    ///
    /// * `gradients` - Gradients to mask
    ///
    /// # Returns
    ///
    /// Masked gradients where frozen parameters have zero gradients
    pub fn apply_mask<'a>(&self, gradients: &'a mut Array<A, D>) -> &'a mut Array<A, D> {
        gradients.zip_mut_with(&self.mask, |grad, &should_update| {
            if !should_update {
                *grad = A::zero();
            }
        });

        // Apply learning rate multipliers if present
        if let Some(multipliers) = &self.lr_multipliers {
            gradients.zip_mut_with(multipliers, |grad, &mult| {
                *grad = *grad * mult;
            });
        }

        gradients
    }

    /// Freeze specific parameters by flat indices
    pub fn freeze_indices(&mut self, indices: &[usize]) -> Result<()> {
        let flat_mask = self.mask.as_slice_mut().ok_or_else(|| {
            OptimError::InvalidConfig("Cannot access mask as flat slice".to_string())
        })?;

        for &idx in indices {
            if idx < flat_mask.len() {
                flat_mask[idx] = false;
            } else {
                return Err(OptimError::InvalidConfig(format!(
                    "Index {} out of bounds for mask of size {}",
                    idx,
                    flat_mask.len()
                )));
            }
        }
        Ok(())
    }

    /// Unfreeze specific parameters by flat indices
    pub fn unfreeze_indices(&mut self, indices: &[usize]) -> Result<()> {
        let flat_mask = self.mask.as_slice_mut().ok_or_else(|| {
            OptimError::InvalidConfig("Cannot access mask as flat slice".to_string())
        })?;

        for &idx in indices {
            if idx < flat_mask.len() {
                flat_mask[idx] = true;
            } else {
                return Err(OptimError::InvalidConfig(format!(
                    "Index {} out of bounds for mask of size {}",
                    idx,
                    flat_mask.len()
                )));
            }
        }
        Ok(())
    }

    /// Get the number of frozen parameters
    pub fn num_frozen(&self) -> usize {
        self.mask.iter().filter(|&&x| !x).count()
    }

    /// Get the number of active (unfrozen) parameters
    pub fn num_active(&self) -> usize {
        self.mask.iter().filter(|&&x| x).count()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_gradient_processor() {
        let config = GradientClipConfig::<f64> {
            max_value: Some(5.0),
            min_value: Some(-5.0),
            maxnorm: Some(10.0),
            ..Default::default()
        };

        let processor = GradientProcessor::with_config(config);

        let mut gradients = Array1::from_vec(vec![-8.0, 3.0, 7.0, -2.0, 6.0]);
        processor
            .process(&mut gradients)
            .expect("gradient processing failed");

        // Check value clipping
        assert_eq!(gradients[0], -5.0);
        assert_eq!(gradients[2], 5.0);
        assert_eq!(gradients[4], 5.0);
    }

    #[test]
    fn test_adaptive_clipping() {
        let mut gradients = Array1::from_vec(vec![3.0, 4.0]); // norm = 5
        let parameters = Array1::from_vec(vec![1.0, 0.0]); // norm = 1

        // Gradient/parameter ratio = 5/1 = 5, max_ratio = 2
        adaptive_gradient_clipping(&mut gradients, &parameters, 2.0)
            .expect("adaptive clipping failed");

        // After clipping, the ratio should be 2
        let new_grad_norm = gradients.iter().fold(0.0, |acc, &x| acc + x * x).sqrt();
        assert!((new_grad_norm - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_gradient_accumulator() {
        let mut accumulator = GradientAccumulator::new(3, true);

        // First micro-batch
        let grad1 = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(!accumulator.accumulate(&grad1));
        assert_eq!(accumulator.progress(), (1, 3));

        // Second micro-batch
        let grad2 = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        assert!(!accumulator.accumulate(&grad2));
        assert_eq!(accumulator.progress(), (2, 3));

        // Third micro-batch - should trigger ready
        let grad3 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        assert!(accumulator.accumulate(&grad3));
        assert!(accumulator.is_ready());

        // Get accumulated gradients (should be averaged)
        let final_grads = accumulator
            .get_and_reset()
            .expect("accumulator should hold gradients");
        assert_relative_eq!(final_grads[0], 2.0, epsilon = 1e-6); // (1+2+3)/3
        assert_relative_eq!(final_grads[1], 3.0, epsilon = 1e-6); // (2+3+4)/3
        assert_relative_eq!(final_grads[2], 4.0, epsilon = 1e-6); // (3+4+5)/3

        // Should be reset now
        assert_eq!(accumulator.progress(), (0, 3));
        assert!(!accumulator.is_ready());
    }

    #[test]
    fn test_gradient_accumulator_sum_mode() {
        let mut accumulator = GradientAccumulator::new(2, false); // sum mode

        let grad1 = Array1::from_vec(vec![1.0, 2.0]);
        let grad2 = Array1::from_vec(vec![3.0, 4.0]);

        accumulator.accumulate(&grad1);
        accumulator.accumulate(&grad2);

        let final_grads = accumulator
            .get_and_reset()
            .expect("accumulator should hold gradients");
        assert_relative_eq!(final_grads[0], 4.0, epsilon = 1e-6); // 1+3
        assert_relative_eq!(final_grads[1], 6.0, epsilon = 1e-6); // 2+4
    }

    #[test]
    fn test_gradient_noise() {
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let original = gradients.clone();

        // Add noise (the seed argument is currently ignored; see add_gradient_noise)
        add_gradient_noise(&mut gradients, 0.1, Some(42));

        // Gradients should be different but close to the original
        for (i, (&orig, &noisy)) in original.iter().zip(gradients.iter()).enumerate() {
            assert!(
                (orig - noisy).abs() < 1.0,
                "Index {}: {} vs {}",
                i,
                orig,
                noisy
            );
        }
    }

    #[test]
    fn test_gradient_noise_zero_std() {
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let original = gradients.clone();

        // Zero noise should leave gradients unchanged
        add_gradient_noise(&mut gradients, 0.0, Some(42));

        for (orig, noisy) in original.iter().zip(gradients.iter()) {
            assert_relative_eq!(*orig, *noisy, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_gradient_mask_creation() {
        let mask = Array1::from_vec(vec![true, false, true]);
        let grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> = GradientMask::new(mask);

        assert_eq!(grad_mask.num_active(), 2);
        assert_eq!(grad_mask.num_frozen(), 1);
    }

    #[test]
    fn test_gradient_mask_apply() {
        let mask = Array1::from_vec(vec![true, false, true]);
        let grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> = GradientMask::new(mask);
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        grad_mask.apply_mask(&mut gradients);

        assert_eq!(
            gradients.as_slice().expect("gradients should be contiguous"),
            &[1.0, 0.0, 3.0]
        );
    }

    #[test]
    fn test_gradient_mask_freeze_unfreeze() {
        let mask = Array1::from_vec(vec![true, true, true]);
        let mut grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> = GradientMask::new(mask);

        // Freeze some indices
        grad_mask
            .freeze_indices(&[0, 2])
            .expect("freeze_indices failed");
        assert_eq!(grad_mask.num_frozen(), 2);
        assert_eq!(grad_mask.num_active(), 1);

        // Unfreeze one index
        grad_mask
            .unfreeze_indices(&[0])
            .expect("unfreeze_indices failed");
        assert_eq!(grad_mask.num_frozen(), 1);
        assert_eq!(grad_mask.num_active(), 2);
    }

    #[test]
    fn test_gradient_mask_with_lr_multipliers() {
        let mask = Array1::from_vec(vec![true, true, true]);
        let multipliers = Array1::from_vec(vec![1.0, 0.5, 2.0]);
        let grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> =
            GradientMask::new(mask).with_lr_multipliers(multipliers);
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        grad_mask.apply_mask(&mut gradients);

        assert_relative_eq!(gradients[0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[1], 1.0, epsilon = 1e-6); // 2.0 * 0.5
        assert_relative_eq!(gradients[2], 6.0, epsilon = 1e-6); // 3.0 * 2.0
    }

    #[test]
    fn test_gradient_mask_freeze_all() {
        let grad_mask = GradientMask::<f64, scirs2_core::ndarray::Ix1>::freeze_all(
            scirs2_core::ndarray::Ix1(3),
        );
        assert_eq!(grad_mask.num_frozen(), 3);
        assert_eq!(grad_mask.num_active(), 0);
    }

    #[test]
    fn test_gradient_mask_update_all() {
        let grad_mask = GradientMask::<f64, scirs2_core::ndarray::Ix1>::update_all(
            scirs2_core::ndarray::Ix1(3),
        );
        assert_eq!(grad_mask.num_frozen(), 0);
        assert_eq!(grad_mask.num_active(), 3);
    }
}