use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Rng};
use std::fmt::Debug;

use crate::error::{OptimError, Result};

/// Configuration for gradient clipping and processing.
#[derive(Debug, Clone)]
pub struct GradientClipConfig<A: Float> {
    /// Maximum allowed gradient value (element-wise).
    pub max_value: Option<A>,
    /// Minimum allowed gradient value (element-wise).
    pub min_value: Option<A>,
    /// Maximum allowed L2 norm of the gradient.
    pub maxnorm: Option<A>,
    /// Maximum allowed L1 norm of the gradient.
    pub max_l1norm: Option<A>,
    /// Whether to apply gradient centralization (subtract the mean).
    pub centralization: bool,
    /// Threshold below which gradient elements are zeroed.
    pub zero_threshold: Option<A>,
}

impl<A: Float + Send + Sync> Default for GradientClipConfig<A> {
    fn default() -> Self {
        Self {
            max_value: None,
            min_value: None,
            maxnorm: None,
            max_l1norm: None,
            centralization: false,
            zero_threshold: None,
        }
    }
}

/// Applies the clipping and processing steps from a [`GradientClipConfig`]
/// to gradient arrays.
pub struct GradientProcessor<A: Float> {
    config: GradientClipConfig<A>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> Default for GradientProcessor<A> {
    fn default() -> Self {
        Self {
            config: GradientClipConfig::default(),
        }
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> GradientProcessor<A> {
    /// Creates a processor with the default (no-op) configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a processor from an explicit configuration.
    pub fn with_config(config: GradientClipConfig<A>) -> Self {
        Self { config }
    }

    /// Sets the element-wise maximum value.
    pub fn set_max_value(&mut self, value: A) -> &mut Self {
        self.config.max_value = Some(value);
        self
    }

    /// Sets the element-wise minimum value.
    pub fn set_min_value(&mut self, value: A) -> &mut Self {
        self.config.min_value = Some(value);
        self
    }

    /// Sets the maximum L2 norm.
    pub fn set_max_norm(&mut self, value: A) -> &mut Self {
        self.config.maxnorm = Some(value);
        self
    }

    /// Sets the maximum L1 norm.
    pub fn set_max_l1_norm(&mut self, value: A) -> &mut Self {
        self.config.max_l1norm = Some(value);
        self
    }

    /// Enables or disables gradient centralization.
    pub fn set_centralization(&mut self, enabled: bool) -> &mut Self {
        self.config.centralization = enabled;
        self
    }

    /// Sets the threshold below which gradient elements are zeroed.
    pub fn set_zero_threshold(&mut self, value: A) -> &mut Self {
        self.config.zero_threshold = Some(value);
        self
    }

    /// Sets both value-clipping bounds at once.
    pub fn set_value_clip(&mut self, min: A, max: A) -> &mut Self {
        self.config.min_value = Some(min);
        self.config.max_value = Some(max);
        self
    }

    /// Convenience alias for [`set_max_norm`](Self::set_max_norm).
    pub fn set_norm_clip(&mut self, maxnorm: A) -> &mut Self {
        self.config.maxnorm = Some(maxnorm);
        self
    }

    /// Convenience alias for [`set_max_l1_norm`](Self::set_max_l1_norm).
    pub fn set_l1_norm_clip(&mut self, max_l1norm: A) -> &mut Self {
        self.config.max_l1norm = Some(max_l1norm);
        self
    }

    /// Enables gradient centralization.
    pub fn enable_centralization(&mut self) -> &mut Self {
        self.config.centralization = true;
        self
    }

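    /// Applies every configured operation to `gradients`, in order: value
    /// clipping, L2 norm clipping, L1 norm clipping, centralization, and
    /// small-gradient zeroing.
    ///
    /// A minimal usage sketch (not compiled as a doctest; `Array1` follows
    /// the imports used in the tests below):
    ///
    /// ```ignore
    /// let mut processor = GradientProcessor::<f64>::new();
    /// processor.set_value_clip(-1.0, 1.0).set_norm_clip(5.0);
    /// let mut grads = Array1::from_vec(vec![2.0, -3.0, 0.5]);
    /// processor.process(&mut grads)?;
    /// ```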
    pub fn process<D: Dimension>(&self, gradients: &mut Array<A, D>) -> Result<()> {
        // Value clipping applies when either bound is set; a missing bound
        // leaves that side unbounded.
        if self.config.min_value.is_some() || self.config.max_value.is_some() {
            let min = self.config.min_value.unwrap_or_else(A::neg_infinity);
            let max = self.config.max_value.unwrap_or_else(A::infinity);
            clip_gradients_by_value(gradients, min, max);
        }

        if let Some(maxnorm) = self.config.maxnorm {
            clip_gradient_norm(gradients, maxnorm)?;
        }

        if let Some(max_l1norm) = self.config.max_l1norm {
            clip_gradient_l1_norm(gradients, max_l1norm)?;
        }

        if self.config.centralization {
            gradient_centralization(gradients);
        }

        if let Some(threshold) = self.config.zero_threshold {
            zero_small_gradients(gradients, threshold);
        }

        Ok(())
    }
}

/// Clips each gradient element to the closed range `[min_value, max_value]`.
#[allow(dead_code)]
pub fn clip_gradients_by_value<A, D>(
    gradients: &mut Array<A, D>,
    min_value: A,
    max_value: A,
) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    gradients.mapv_inplace(|x| {
        if x < min_value {
            min_value
        } else if x > max_value {
            max_value
        } else {
            x
        }
    });
    gradients
}

/// Clips the gradient so its L2 norm does not exceed `maxnorm`.
///
/// Returns an error if `maxnorm` is not positive.
#[allow(dead_code)]
pub fn clip_gradient_norm<A, D>(gradients: &mut Array<A, D>, maxnorm: A) -> Result<&mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if maxnorm <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "maxnorm must be positive".to_string(),
        ));
    }

    let norm = gradients
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    if norm > maxnorm {
        let scale = maxnorm / norm;
        gradients.mapv_inplace(|x| x * scale);
    }

    Ok(gradients)
}

/// Clips the gradient so its L1 norm does not exceed `max_l1norm`.
///
/// Returns an error if `max_l1norm` is not positive.
#[allow(dead_code)]
pub fn clip_gradient_l1_norm<A, D>(
    gradients: &mut Array<A, D>,
    max_l1norm: A,
) -> Result<&mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if max_l1norm <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_l1norm must be positive".to_string(),
        ));
    }

    let l1_norm = gradients.iter().fold(A::zero(), |acc, &x| acc + x.abs());

    if l1_norm > max_l1norm {
        let scale = max_l1norm / l1_norm;
        gradients.mapv_inplace(|x| x * scale);
    }

    Ok(gradients)
}

/// Subtracts the mean of all elements from the gradient (gradient
/// centralization over the whole array).
#[allow(dead_code)]
pub fn gradient_centralization<A, D>(gradients: &mut Array<A, D>) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    // An empty array has no mean; return it unchanged to avoid a 0/0 NaN.
    if gradients.is_empty() {
        return gradients;
    }

    let sum = gradients.iter().fold(A::zero(), |acc, &x| acc + x);
    let mean = sum / A::from(gradients.len()).unwrap_or(A::one());

    gradients.mapv_inplace(|x| x - mean);

    gradients
}

/// Sets gradient elements with absolute value below `threshold` to zero.
#[allow(dead_code)]
pub fn zero_small_gradients<A, D>(gradients: &mut Array<A, D>, threshold: A) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    let abs_threshold = threshold.abs();

    gradients.mapv_inplace(|x| {
        if x.abs() < abs_threshold {
            A::zero()
        } else {
            x
        }
    });

    gradients
}

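/// Accumulates gradients over several micro-batches and signals when enough
/// have been collected for a single optimizer update, optionally averaging
/// them on retrieval.
///
/// A minimal usage sketch (not compiled as a doctest; `optimizer_step` and
/// `micro_batch_gradients` are hypothetical stand-ins for the caller's
/// training loop):
///
/// ```ignore
/// let mut acc = GradientAccumulator::new(4, true);
/// for grads in micro_batch_gradients {
///     if acc.accumulate(&grads) {
///         let averaged = acc.get_and_reset().unwrap();
///         optimizer_step(&averaged); // hypothetical update
///     }
/// }
/// ```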
#[derive(Debug, Clone)]
pub struct GradientAccumulator<A: Float, D: Dimension> {
    /// Running sum of accumulated gradients, if any have been added.
    accumulated_gradients: Option<Array<A, D>>,
    /// Number of gradient arrays accumulated since the last reset.
    num_accumulated: usize,
    /// Number of accumulation steps before an update is due.
    accumulation_steps: usize,
    /// Whether to average (rather than sum) on `get_and_reset`.
    averagegradients: bool,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientAccumulator<A, D> {
    /// Creates an accumulator that is ready after `accumulation_steps` calls
    /// to [`accumulate`](Self::accumulate); if `averagegradients` is true,
    /// `get_and_reset` returns the mean instead of the sum.
    pub fn new(accumulation_steps: usize, averagegradients: bool) -> Self {
        Self {
            accumulated_gradients: None,
            num_accumulated: 0,
            accumulation_steps,
            averagegradients,
        }
    }

    /// Adds `gradients` to the running sum and returns `true` once the
    /// configured number of steps has been reached.
    ///
    /// The incoming array is assumed to have the same shape as the
    /// gradients accumulated so far.
    pub fn accumulate(&mut self, gradients: &Array<A, D>) -> bool {
        if let Some(acc) = &mut self.accumulated_gradients {
            for (acc_val, &grad_val) in acc.iter_mut().zip(gradients.iter()) {
                *acc_val = *acc_val + grad_val;
            }
        } else {
            self.accumulated_gradients = Some(gradients.clone());
        }

        self.num_accumulated += 1;
        self.num_accumulated >= self.accumulation_steps
    }

    /// Returns the accumulated (optionally averaged) gradients and resets
    /// the accumulator, or `None` if nothing has been accumulated.
    pub fn get_and_reset(&mut self) -> Option<Array<A, D>> {
        if let Some(mut gradients) = self.accumulated_gradients.take() {
            if self.averagegradients && self.num_accumulated > 0 {
                let scale = A::one() / A::from(self.num_accumulated).unwrap_or(A::one());
                gradients.mapv_inplace(|x| x * scale);
            }
            self.num_accumulated = 0;
            Some(gradients)
        } else {
            None
        }
    }

    /// Returns `(steps_accumulated, steps_required)`.
    pub fn progress(&self) -> (usize, usize) {
        (self.num_accumulated, self.accumulation_steps)
    }

    /// Returns `true` if enough gradients have been accumulated.
    pub fn is_ready(&self) -> bool {
        self.num_accumulated >= self.accumulation_steps
    }

    /// Clears the accumulated gradients and the step counter.
    pub fn reset(&mut self) {
        self.accumulated_gradients = None;
        self.num_accumulated = 0;
    }

    /// Changes the number of accumulation steps required before updates.
    pub fn set_accumulation_steps(&mut self, steps: usize) {
        self.accumulation_steps = steps;
    }
}

/// Adaptive gradient clipping: rescales the gradient when the ratio of its
/// L2 norm to the parameter L2 norm exceeds `max_ratio`.
///
/// Returns an error if `max_ratio` is not positive.
#[allow(dead_code)]
pub fn adaptive_gradient_clipping<'a, A, D>(
    gradients: &'a mut Array<A, D>,
    parameters: &Array<A, D>,
    max_ratio: A,
) -> Result<&'a mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if max_ratio <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_ratio must be positive".to_string(),
        ));
    }

    let grad_norm = gradients
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    let param_norm = parameters
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    if param_norm > A::zero() && grad_norm > A::zero() {
        let ratio = grad_norm / param_norm;
        if ratio > max_ratio {
            let scale = max_ratio / ratio;
            gradients.mapv_inplace(|x| x * scale);
        }
    }

    Ok(gradients)
}

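/// Adds zero-mean Gaussian noise with standard deviation `noise_std` to
/// every gradient element (a common trick for escaping saddle points).
/// If `noise_std` is not positive, the gradients are returned unchanged.
///
/// A minimal usage sketch (not compiled as a doctest; see the note on
/// seeding in the function body):
///
/// ```ignore
/// let mut grads = Array1::from_vec(vec![1.0_f64, 2.0, 3.0]);
/// add_gradient_noise(&mut grads, 0.1, None);
/// ```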
#[allow(dead_code)]
pub fn add_gradient_noise<A, D>(
    gradients: &mut Array<A, D>,
    noise_std: A,
    _seed: Option<u64>,
) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    use scirs2_core::random::RandNormal;

    if noise_std <= A::zero() {
        return gradients;
    }

    // NOTE: the seed parameter is currently unused; noise is drawn from the
    // thread-local RNG, so results are not reproducible across runs.
    let mut rng = thread_rng();

    let shape = gradients.raw_dim();
    let mut noise = Array::zeros(shape);
    let normal = RandNormal::new(0.0, noise_std.to_f64().unwrap_or(0.01))
        .expect("valid normal distribution");

    for elem in noise.iter_mut() {
        *elem = A::from(rng.sample(normal)).unwrap_or(A::zero());
    }

    gradients.zip_mut_with(&noise, |g, &n| {
        *g = *g + n;
    });

    gradients
}

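/// Masks gradients for parameter freezing: frozen entries receive zero
/// gradient, and optional per-parameter learning-rate multipliers are
/// applied to the rest.
///
/// A minimal usage sketch (not compiled as a doctest; `grads` is assumed
/// to be an `Array1<f64>` of matching length):
///
/// ```ignore
/// let mask = Array1::from_vec(vec![true, false, true]);
/// let grad_mask = GradientMask::new(mask);
/// grad_mask.apply_mask(&mut grads); // grads[1] is zeroed
/// ```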
#[derive(Debug, Clone)]
pub struct GradientMask<A: Float, D: Dimension> {
    /// `true` marks parameters whose gradients are kept; `false` zeroes them.
    mask: Array<bool, D>,
    /// Optional per-parameter learning-rate multipliers applied after masking.
    lr_multipliers: Option<Array<A, D>>,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientMask<A, D> {
    /// Creates a mask from a boolean array (`true` = keep the gradient).
    pub fn new(mask: Array<bool, D>) -> Self {
        Self {
            mask,
            lr_multipliers: None,
        }
    }

    /// Creates a mask that freezes every parameter.
    pub fn freeze_all(shape: D) -> Self {
        Self {
            mask: Array::from_elem(shape, false),
            lr_multipliers: None,
        }
    }

    /// Creates a mask that keeps every parameter trainable.
    pub fn update_all(shape: D) -> Self {
        Self {
            mask: Array::from_elem(shape, true),
            lr_multipliers: None,
        }
    }

    /// Attaches per-parameter learning-rate multipliers.
    pub fn with_lr_multipliers(mut self, multipliers: Array<A, D>) -> Self {
        self.lr_multipliers = Some(multipliers);
        self
    }

    /// Zeroes masked gradients, then applies the learning-rate multipliers
    /// if they are set.
    pub fn apply_mask<'a>(&self, gradients: &'a mut Array<A, D>) -> &'a mut Array<A, D> {
        gradients.zip_mut_with(&self.mask, |grad, &should_update| {
            if !should_update {
                *grad = A::zero();
            }
        });

        if let Some(multipliers) = &self.lr_multipliers {
            gradients.zip_mut_with(multipliers, |grad, &mult| {
                *grad = *grad * mult;
            });
        }

        gradients
    }

    /// Freezes the parameters at the given flat indices.
    pub fn freeze_indices(&mut self, indices: &[usize]) -> Result<()> {
        let flat_mask = self.mask.as_slice_mut().ok_or_else(|| {
            OptimError::InvalidConfig("Cannot access mask as flat slice".to_string())
        })?;

        for &idx in indices {
            if idx < flat_mask.len() {
                flat_mask[idx] = false;
            } else {
                return Err(OptimError::InvalidConfig(format!(
                    "Index {} out of bounds for mask of size {}",
                    idx,
                    flat_mask.len()
                )));
            }
        }
        Ok(())
    }

    /// Unfreezes the parameters at the given flat indices.
    pub fn unfreeze_indices(&mut self, indices: &[usize]) -> Result<()> {
        let flat_mask = self.mask.as_slice_mut().ok_or_else(|| {
            OptimError::InvalidConfig("Cannot access mask as flat slice".to_string())
        })?;

        for &idx in indices {
            if idx < flat_mask.len() {
                flat_mask[idx] = true;
            } else {
                return Err(OptimError::InvalidConfig(format!(
                    "Index {} out of bounds for mask of size {}",
                    idx,
                    flat_mask.len()
                )));
            }
        }
        Ok(())
    }

    /// Number of frozen (masked-out) parameters.
    pub fn num_frozen(&self) -> usize {
        self.mask.iter().filter(|&&x| !x).count()
    }

    /// Number of trainable (unmasked) parameters.
    pub fn num_active(&self) -> usize {
        self.mask.iter().filter(|&&x| x).count()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_gradient_processor() {
        let config = GradientClipConfig::<f64> {
            max_value: Some(5.0),
            min_value: Some(-5.0),
            maxnorm: Some(10.0),
            ..Default::default()
        };

        let processor = GradientProcessor::with_config(config);

        let mut gradients = Array1::from_vec(vec![-8.0, 3.0, 7.0, -2.0, 6.0]);
        processor.process(&mut gradients).expect("processing failed");

        assert_eq!(gradients[0], -5.0);
        assert_eq!(gradients[2], 5.0);
        assert_eq!(gradients[4], 5.0);
    }

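    // Direct test of `clip_gradients_by_value` (the processor test above only
    // exercises it indirectly); expected values follow from the element-wise
    // clamping defined in this file.
    #[test]
    fn test_clip_gradients_by_value() {
        let mut gradients = Array1::from_vec(vec![-3.0, 0.5, 4.0]);
        clip_gradients_by_value(&mut gradients, -1.0, 1.0);
        assert_eq!(
            gradients.as_slice().expect("contiguous array"),
            &[-1.0, 0.5, 1.0]
        );
    }
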
    #[test]
    fn test_adaptive_clipping() {
        let mut gradients = Array1::from_vec(vec![3.0, 4.0]);
        let parameters = Array1::from_vec(vec![1.0, 0.0]);
        adaptive_gradient_clipping(&mut gradients, &parameters, 2.0).expect("clipping failed");

        let new_grad_norm = gradients.iter().fold(0.0, |acc, &x| acc + x * x).sqrt();
        assert!((new_grad_norm - 2.0).abs() < 1e-6);
    }

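    // Direct test of `clip_gradient_norm`: a [3, 4] gradient has L2 norm 5,
    // so clipping to 2.5 should halve each component.
    #[test]
    fn test_clip_gradient_norm() {
        let mut gradients = Array1::from_vec(vec![3.0, 4.0]);
        clip_gradient_norm(&mut gradients, 2.5).expect("clipping failed");
        assert_relative_eq!(gradients[0], 1.5, epsilon = 1e-6);
        assert_relative_eq!(gradients[1], 2.0, epsilon = 1e-6);

        // A non-positive max norm is rejected.
        assert!(clip_gradient_norm(&mut gradients, 0.0).is_err());
    }
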
    #[test]
    fn test_gradient_accumulator() {
        let mut accumulator = GradientAccumulator::new(3, true);

        let grad1 = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(!accumulator.accumulate(&grad1));
        assert_eq!(accumulator.progress(), (1, 3));

        let grad2 = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        assert!(!accumulator.accumulate(&grad2));
        assert_eq!(accumulator.progress(), (2, 3));

        let grad3 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        assert!(accumulator.accumulate(&grad3));
        assert!(accumulator.is_ready());

        let final_grads = accumulator.get_and_reset().expect("accumulated gradients");
        assert_relative_eq!(final_grads[0], 2.0, epsilon = 1e-6);
        assert_relative_eq!(final_grads[1], 3.0, epsilon = 1e-6);
        assert_relative_eq!(final_grads[2], 4.0, epsilon = 1e-6);
        assert_eq!(accumulator.progress(), (0, 3));
        assert!(!accumulator.is_ready());
    }

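    // Direct test of `clip_gradient_l1_norm`: |1| + |-3| = 4, so clipping to
    // an L1 norm of 2 should scale both elements by 0.5.
    #[test]
    fn test_clip_gradient_l1_norm() {
        let mut gradients = Array1::from_vec(vec![1.0, -3.0]);
        clip_gradient_l1_norm(&mut gradients, 2.0).expect("clipping failed");
        assert_relative_eq!(gradients[0], 0.5, epsilon = 1e-6);
        assert_relative_eq!(gradients[1], -1.5, epsilon = 1e-6);
    }
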
    #[test]
    fn test_gradient_accumulator_sum_mode() {
        let mut accumulator = GradientAccumulator::new(2, false);

        let grad1 = Array1::from_vec(vec![1.0, 2.0]);
        let grad2 = Array1::from_vec(vec![3.0, 4.0]);

        accumulator.accumulate(&grad1);
        accumulator.accumulate(&grad2);

        let final_grads = accumulator.get_and_reset().expect("accumulated gradients");
        assert_relative_eq!(final_grads[0], 4.0, epsilon = 1e-6);
        assert_relative_eq!(final_grads[1], 6.0, epsilon = 1e-6);
    }

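    // Test of `gradient_centralization`: the mean of [1, 2, 3] is 2, so
    // centralization should yield [-1, 0, 1].
    #[test]
    fn test_gradient_centralization() {
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        gradient_centralization(&mut gradients);
        assert_relative_eq!(gradients[0], -1.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[1], 0.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[2], 1.0, epsilon = 1e-6);
    }
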
    #[test]
    fn test_gradient_noise() {
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let original = gradients.clone();

        add_gradient_noise(&mut gradients, 0.1, Some(42));

        for (i, (&orig, &noisy)) in original.iter().zip(gradients.iter()).enumerate() {
            assert!(
                (orig - noisy).abs() < 1.0,
                "Index {}: {} vs {}",
                i,
                orig,
                noisy
            );
        }
    }

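    // Test of `zero_small_gradients`: elements with absolute value below the
    // threshold are zeroed, the rest are left untouched.
    #[test]
    fn test_zero_small_gradients() {
        let mut gradients = Array1::from_vec(vec![0.001, -0.5, 0.0001, 2.0]);
        zero_small_gradients(&mut gradients, 0.01);
        assert_eq!(
            gradients.as_slice().expect("contiguous array"),
            &[0.0, -0.5, 0.0, 2.0]
        );
    }
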
    #[test]
    fn test_gradient_noise_zero_std() {
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let original = gradients.clone();

        add_gradient_noise(&mut gradients, 0.0, Some(42));

        for (orig, noisy) in original.iter().zip(gradients.iter()) {
            assert_relative_eq!(*orig, *noisy, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_gradient_mask_creation() {
        let mask = Array1::from_vec(vec![true, false, true]);
        let grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> = GradientMask::new(mask);

        assert_eq!(grad_mask.num_active(), 2);
        assert_eq!(grad_mask.num_frozen(), 1);
    }

    #[test]
    fn test_gradient_mask_apply() {
        let mask = Array1::from_vec(vec![true, false, true]);
        let grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> = GradientMask::new(mask);
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        grad_mask.apply_mask(&mut gradients);

        assert_eq!(
            gradients.as_slice().expect("contiguous array"),
            &[1.0, 0.0, 3.0]
        );
    }

    #[test]
    fn test_gradient_mask_freeze_unfreeze() {
        let mask = Array1::from_vec(vec![true, true, true]);
        let mut grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> = GradientMask::new(mask);

        grad_mask.freeze_indices(&[0, 2]).expect("freeze failed");
        assert_eq!(grad_mask.num_frozen(), 2);
        assert_eq!(grad_mask.num_active(), 1);

        grad_mask.unfreeze_indices(&[0]).expect("unfreeze failed");
        assert_eq!(grad_mask.num_frozen(), 1);
        assert_eq!(grad_mask.num_active(), 2);
    }

    #[test]
    fn test_gradient_mask_with_lr_multipliers() {
        let mask = Array1::from_vec(vec![true, true, true]);
        let multipliers = Array1::from_vec(vec![1.0, 0.5, 2.0]);
        let grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> =
            GradientMask::new(mask).with_lr_multipliers(multipliers);
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        grad_mask.apply_mask(&mut gradients);

        assert_relative_eq!(gradients[0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[1], 1.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[2], 6.0, epsilon = 1e-6);
    }

    #[test]
    fn test_gradient_mask_freeze_all() {
        let grad_mask = GradientMask::<f64, scirs2_core::ndarray::Ix1>::freeze_all(
            scirs2_core::ndarray::Ix1(3),
        );
        assert_eq!(grad_mask.num_frozen(), 3);
        assert_eq!(grad_mask.num_active(), 0);
    }

    #[test]
    fn test_gradient_mask_update_all() {
        let grad_mask = GradientMask::<f64, scirs2_core::ndarray::Ix1>::update_all(
            scirs2_core::ndarray::Ix1(3),
        );
        assert_eq!(grad_mask.num_frozen(), 0);
        assert_eq!(grad_mask.num_active(), 3);
    }
}