//! Gradient processing utilities: clipping, accumulation, centralization,
//! masking, and noise injection for optimizers.

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Rng};
use std::fmt::Debug;

use crate::error::{OptimError, Result};

/// Configuration options for gradient clipping and related gradient transformations.
#[derive(Debug, Clone)]
pub struct GradientClipConfig<A: Float> {
    /// Clip each gradient element to at most this value.
    pub max_value: Option<A>,
    /// Clip each gradient element to at least this value.
    pub min_value: Option<A>,
    /// Rescale gradients so that their L2 norm does not exceed this value.
    pub maxnorm: Option<A>,
    /// Rescale gradients so that their L1 norm does not exceed this value.
    pub max_l1norm: Option<A>,
    /// Subtract the mean from the gradients (gradient centralization).
    pub centralization: bool,
    /// Zero out gradient elements whose absolute value falls below this threshold.
    pub zero_threshold: Option<A>,
}

impl<A: Float + Send + Sync> Default for GradientClipConfig<A> {
    fn default() -> Self {
        Self {
            max_value: None,
            min_value: None,
            maxnorm: None,
            max_l1norm: None,
            centralization: false,
            zero_threshold: None,
        }
    }
}

/// Applies the transformations selected in a `GradientClipConfig` to gradients.
pub struct GradientProcessor<A: Float> {
    config: GradientClipConfig<A>,
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> Default for GradientProcessor<A> {
    fn default() -> Self {
        Self {
            config: GradientClipConfig::default(),
        }
    }
}

impl<A: Float + ScalarOperand + Debug + Send + Sync> GradientProcessor<A> {
    /// Creates a processor with the default (no-op) configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a processor from an existing configuration.
    pub fn with_config(config: GradientClipConfig<A>) -> Self {
        Self { config }
    }

    /// Sets the maximum allowed gradient value.
    pub fn set_max_value(&mut self, value: A) -> &mut Self {
        self.config.max_value = Some(value);
        self
    }

    /// Sets the minimum allowed gradient value.
    pub fn set_min_value(&mut self, value: A) -> &mut Self {
        self.config.min_value = Some(value);
        self
    }

    /// Sets the maximum allowed L2 norm.
    pub fn set_max_norm(&mut self, value: A) -> &mut Self {
        self.config.maxnorm = Some(value);
        self
    }

    /// Sets the maximum allowed L1 norm.
    pub fn set_max_l1_norm(&mut self, value: A) -> &mut Self {
        self.config.max_l1norm = Some(value);
        self
    }

    /// Enables or disables gradient centralization.
    pub fn set_centralization(&mut self, enabled: bool) -> &mut Self {
        self.config.centralization = enabled;
        self
    }

    /// Sets the threshold below which gradient elements are zeroed.
    pub fn set_zero_threshold(&mut self, value: A) -> &mut Self {
        self.config.zero_threshold = Some(value);
        self
    }

    /// Sets both value-clipping bounds at once.
    pub fn set_value_clip(&mut self, min: A, max: A) -> &mut Self {
        self.config.min_value = Some(min);
        self.config.max_value = Some(max);
        self
    }

    /// Sets the maximum L2 norm (same effect as `set_max_norm`).
    pub fn set_norm_clip(&mut self, maxnorm: A) -> &mut Self {
        self.config.maxnorm = Some(maxnorm);
        self
    }

    /// Sets the maximum L1 norm (same effect as `set_max_l1_norm`).
    pub fn set_l1_norm_clip(&mut self, max_l1norm: A) -> &mut Self {
        self.config.max_l1norm = Some(max_l1norm);
        self
    }

    /// Enables gradient centralization.
    pub fn enable_centralization(&mut self) -> &mut Self {
        self.config.centralization = true;
        self
    }

    /// Applies the configured transformations to `gradients` in place.
    ///
    /// The steps run in order: value clipping, L2-norm clipping, L1-norm
    /// clipping, centralization, and finally zeroing of small elements.
    pub fn process<D: Dimension>(&self, gradients: &mut Array<A, D>) -> Result<()> {
        // Apply element-wise value clipping when at least one bound is set.
        if self.config.min_value.is_some() || self.config.max_value.is_some() {
            let min = self.config.min_value.unwrap_or_else(A::neg_infinity);
            let max = self.config.max_value.unwrap_or_else(A::infinity);
            clip_gradients_by_value(gradients, min, max);
        }

        if let Some(maxnorm) = self.config.maxnorm {
            clip_gradient_norm(gradients, maxnorm)?;
        }

        if let Some(max_l1norm) = self.config.max_l1norm {
            clip_gradient_l1_norm(gradients, max_l1norm)?;
        }

        if self.config.centralization {
            gradient_centralization(gradients);
        }

        if let Some(threshold) = self.config.zero_threshold {
            zero_small_gradients(gradients, threshold);
        }

        Ok(())
    }
}

/// Clips each gradient element to the closed range `[min_value, max_value]`.
#[allow(dead_code)]
pub fn clip_gradients_by_value<A, D>(
    gradients: &mut Array<A, D>,
    min_value: A,
    max_value: A,
) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    gradients.mapv_inplace(|x| {
        if x < min_value {
            min_value
        } else if x > max_value {
            max_value
        } else {
            x
        }
    });
    gradients
}

/// Rescales `gradients` so that their L2 norm does not exceed `maxnorm`.
///
/// Returns an error if `maxnorm` is not strictly positive.
#[allow(dead_code)]
pub fn clip_gradient_norm<A, D>(gradients: &mut Array<A, D>, maxnorm: A) -> Result<&mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if maxnorm <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "maxnorm must be positive".to_string(),
        ));
    }

    let norm = gradients
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    if norm > maxnorm {
        let scale = maxnorm / norm;
        gradients.mapv_inplace(|x| x * scale);
    }

    Ok(gradients)
}

/// Rescales `gradients` so that their L1 norm does not exceed `max_l1norm`.
///
/// Returns an error if `max_l1norm` is not strictly positive.
#[allow(dead_code)]
pub fn clip_gradient_l1_norm<A, D>(
    gradients: &mut Array<A, D>,
    max_l1norm: A,
) -> Result<&mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if max_l1norm <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_l1norm must be positive".to_string(),
        ));
    }

    let l1_norm = gradients.iter().fold(A::zero(), |acc, &x| acc + x.abs());

    if l1_norm > max_l1norm {
        let scale = max_l1norm / l1_norm;
        gradients.mapv_inplace(|x| x * scale);
    }

    Ok(gradients)
}

/// Centralizes gradients by subtracting the mean of all elements, so the result has zero mean.
#[allow(dead_code)]
pub fn gradient_centralization<A, D>(gradients: &mut Array<A, D>) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    let sum = gradients.iter().fold(A::zero(), |acc, &x| acc + x);
    let mean = sum / A::from(gradients.len()).unwrap_or(A::one());

    gradients.mapv_inplace(|x| x - mean);

    gradients
}

/// Sets gradient elements whose absolute value is below `threshold` to zero.
#[allow(dead_code)]
pub fn zero_small_gradients<A, D>(gradients: &mut Array<A, D>, threshold: A) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    let abs_threshold = threshold.abs();

    gradients.mapv_inplace(|x| {
        if x.abs() < abs_threshold {
            A::zero()
        } else {
            x
        }
    });

    gradients
}

/// Accumulates gradients over several steps before they are applied,
/// optionally averaging them when retrieved.
#[derive(Debug, Clone)]
pub struct GradientAccumulator<A: Float, D: Dimension> {
    /// Running sum of the accumulated gradients, if any.
    accumulated_gradients: Option<Array<A, D>>,
    /// Number of gradient arrays accumulated so far.
    num_accumulated: usize,
    /// Number of accumulations required before the gradients are considered ready.
    accumulation_steps: usize,
    /// Whether to average (rather than sum) the gradients when they are retrieved.
    averagegradients: bool,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientAccumulator<A, D> {
    /// Creates a new accumulator that is ready after `accumulation_steps` calls to
    /// `accumulate`; `averagegradients` selects averaging over summing.
    pub fn new(accumulation_steps: usize, averagegradients: bool) -> Self {
        Self {
            accumulated_gradients: None,
            num_accumulated: 0,
            accumulation_steps,
            averagegradients,
        }
    }

    /// Adds `gradients` to the running accumulation and returns `true` once
    /// the configured number of steps has been reached.
    pub fn accumulate(&mut self, gradients: &Array<A, D>) -> bool {
        if self.accumulated_gradients.is_none() {
            self.accumulated_gradients = Some(gradients.clone());
        } else {
            let acc = self.accumulated_gradients.as_mut().unwrap();
            for (acc_val, &grad_val) in acc.iter_mut().zip(gradients.iter()) {
                *acc_val = *acc_val + grad_val;
            }
        }

        self.num_accumulated += 1;
        self.num_accumulated >= self.accumulation_steps
    }

    /// Returns the accumulated (and optionally averaged) gradients and resets the accumulator.
    pub fn get_and_reset(&mut self) -> Option<Array<A, D>> {
        if let Some(mut gradients) = self.accumulated_gradients.take() {
            if self.averagegradients && self.num_accumulated > 0 {
                let scale = A::one() / A::from(self.num_accumulated).unwrap_or(A::one());
                gradients.mapv_inplace(|x| x * scale);
            }
            self.num_accumulated = 0;
            Some(gradients)
        } else {
            None
        }
    }

    /// Returns `(steps_accumulated_so_far, steps_required)`.
    pub fn progress(&self) -> (usize, usize) {
        (self.num_accumulated, self.accumulation_steps)
    }

    /// Returns `true` if enough gradients have been accumulated.
    pub fn is_ready(&self) -> bool {
        self.num_accumulated >= self.accumulation_steps
    }

    /// Discards any accumulated gradients and resets the step counter.
    pub fn reset(&mut self) {
        self.accumulated_gradients = None;
        self.num_accumulated = 0;
    }

    /// Changes the number of accumulation steps required before the gradients are ready.
    pub fn set_accumulation_steps(&mut self, steps: usize) {
        self.accumulation_steps = steps;
    }
}

/// Adaptive gradient clipping: rescales `gradients` so that the ratio of the
/// gradient L2 norm to the parameter L2 norm does not exceed `max_ratio`.
///
/// Returns an error if `max_ratio` is not strictly positive.
#[allow(dead_code)]
pub fn adaptive_gradient_clipping<'a, A, D>(
    gradients: &'a mut Array<A, D>,
    parameters: &Array<A, D>,
    max_ratio: A,
) -> Result<&'a mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    if max_ratio <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_ratio must be positive".to_string(),
        ));
    }

    let grad_norm = gradients
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    let param_norm = parameters
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    if param_norm > A::zero() && grad_norm > A::zero() {
        let ratio = grad_norm / param_norm;
        if ratio > max_ratio {
            let scale = max_ratio / ratio;
            gradients.mapv_inplace(|x| x * scale);
        }
    }

    Ok(gradients)
}

/// Adds Gaussian noise with standard deviation `noise_std` to the gradients.
///
/// Note: the `_seed` parameter is currently unused; noise is always drawn from
/// the thread-local RNG.
#[allow(dead_code)]
pub fn add_gradient_noise<A, D>(
    gradients: &mut Array<A, D>,
    noise_std: A,
    _seed: Option<u64>,
) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    use scirs2_core::random::RandNormal;

    if noise_std <= A::zero() {
        return gradients;
    }

    let mut rng = thread_rng();

    let shape = gradients.raw_dim();
    let mut noise = Array::zeros(shape);
    let normal = RandNormal::new(0.0, noise_std.to_f64().unwrap_or(0.01)).unwrap();

    for elem in noise.iter_mut() {
        *elem = A::from(rng.sample(normal)).unwrap_or(A::zero());
    }

    gradients.zip_mut_with(&noise, |g, &n| {
        *g = *g + n;
    });

    gradients
}

/// Masks gradients so that selected parameters are frozen (receive zero gradient),
/// with optional per-parameter learning-rate multipliers.
#[derive(Debug, Clone)]
pub struct GradientMask<A: Float, D: Dimension> {
    /// `true` marks parameters that should be updated; `false` marks frozen parameters.
    mask: Array<bool, D>,
    /// Optional per-parameter multipliers applied to the surviving gradients.
    lr_multipliers: Option<Array<A, D>>,
}

impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientMask<A, D> {
    /// Creates a mask from a boolean array; `true` entries remain trainable.
    pub fn new(mask: Array<bool, D>) -> Self {
        Self {
            mask,
            lr_multipliers: None,
        }
    }

    /// Creates a mask that freezes every parameter.
    pub fn freeze_all(shape: D) -> Self {
        Self {
            mask: Array::from_elem(shape, false),
            lr_multipliers: None,
        }
    }

    /// Creates a mask that keeps every parameter trainable.
    pub fn update_all(shape: D) -> Self {
        Self {
            mask: Array::from_elem(shape, true),
            lr_multipliers: None,
        }
    }

    /// Attaches per-parameter learning-rate multipliers to the mask.
    pub fn with_lr_multipliers(mut self, multipliers: Array<A, D>) -> Self {
        self.lr_multipliers = Some(multipliers);
        self
    }

    /// Applies the mask in place: frozen entries are zeroed, then any
    /// learning-rate multipliers are applied element-wise.
    pub fn apply_mask<'a>(&self, gradients: &'a mut Array<A, D>) -> &'a mut Array<A, D> {
        gradients.zip_mut_with(&self.mask, |grad, &should_update| {
            if !should_update {
                *grad = A::zero();
            }
        });

        if let Some(multipliers) = &self.lr_multipliers {
            gradients.zip_mut_with(multipliers, |grad, &mult| {
                *grad = *grad * mult;
            });
        }

        gradients
    }

    /// Freezes the parameters at the given flat indices.
    pub fn freeze_indices(&mut self, indices: &[usize]) -> Result<()> {
        let flat_mask = self.mask.as_slice_mut().ok_or_else(|| {
            OptimError::InvalidConfig("Cannot access mask as flat slice".to_string())
        })?;

        for &idx in indices {
            if idx < flat_mask.len() {
                flat_mask[idx] = false;
            } else {
                return Err(OptimError::InvalidConfig(format!(
                    "Index {} out of bounds for mask of size {}",
                    idx,
                    flat_mask.len()
                )));
            }
        }
        Ok(())
    }

    /// Unfreezes the parameters at the given flat indices.
    pub fn unfreeze_indices(&mut self, indices: &[usize]) -> Result<()> {
        let flat_mask = self.mask.as_slice_mut().ok_or_else(|| {
            OptimError::InvalidConfig("Cannot access mask as flat slice".to_string())
        })?;

        for &idx in indices {
            if idx < flat_mask.len() {
                flat_mask[idx] = true;
            } else {
                return Err(OptimError::InvalidConfig(format!(
                    "Index {} out of bounds for mask of size {}",
                    idx,
                    flat_mask.len()
                )));
            }
        }
        Ok(())
    }

    /// Number of frozen parameters (`false` entries in the mask).
    pub fn num_frozen(&self) -> usize {
        self.mask.iter().filter(|&&x| !x).count()
    }

    /// Number of trainable parameters (`true` entries in the mask).
    pub fn num_active(&self) -> usize {
        self.mask.iter().filter(|&&x| x).count()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_gradient_processor() {
        let config = GradientClipConfig::<f64> {
            max_value: Some(5.0),
            min_value: Some(-5.0),
            maxnorm: Some(10.0),
            ..Default::default()
        };

        let processor = GradientProcessor::with_config(config);

        let mut gradients = Array1::from_vec(vec![-8.0, 3.0, 7.0, -2.0, 6.0]);
        processor.process(&mut gradients).unwrap();

        assert_eq!(gradients[0], -5.0);
        assert_eq!(gradients[2], 5.0);
        assert_eq!(gradients[4], 5.0);
    }
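
    // Illustrative sketch: exercises clip_gradient_norm directly. The vector
    // [3, 4] has L2 norm 5, so clipping to a max norm of 1 rescales both
    // elements by 1/5; a non-positive max norm is rejected.
    #[test]
    fn test_clip_gradient_norm() {
        let mut gradients = Array1::from_vec(vec![3.0, 4.0]);
        clip_gradient_norm(&mut gradients, 1.0).unwrap();

        assert_relative_eq!(gradients[0], 0.6, epsilon = 1e-6);
        assert_relative_eq!(gradients[1], 0.8, epsilon = 1e-6);

        assert!(clip_gradient_norm(&mut gradients, 0.0).is_err());
    }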

    #[test]
    fn test_adaptive_clipping() {
        let mut gradients = Array1::from_vec(vec![3.0, 4.0]);
        let parameters = Array1::from_vec(vec![1.0, 0.0]);
        adaptive_gradient_clipping(&mut gradients, &parameters, 2.0).unwrap();

        let new_grad_norm = gradients.iter().fold(0.0, |acc, &x| acc + x * x).sqrt();
        assert!((new_grad_norm - 2.0).abs() < 1e-6);
    }
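
    // Illustrative sketch: exercises clip_gradient_l1_norm directly. The vector
    // [1, -2, 3] has L1 norm 6, so clipping to 3 halves every element while
    // preserving signs; a non-positive limit is rejected.
    #[test]
    fn test_clip_gradient_l1_norm() {
        let mut gradients = Array1::from_vec(vec![1.0, -2.0, 3.0]);
        clip_gradient_l1_norm(&mut gradients, 3.0).unwrap();

        assert_relative_eq!(gradients[0], 0.5, epsilon = 1e-6);
        assert_relative_eq!(gradients[1], -1.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[2], 1.5, epsilon = 1e-6);

        assert!(clip_gradient_l1_norm(&mut gradients, 0.0).is_err());
    }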

    #[test]
    fn test_gradient_accumulator() {
        let mut accumulator = GradientAccumulator::new(3, true);

        let grad1 = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(!accumulator.accumulate(&grad1));
        assert_eq!(accumulator.progress(), (1, 3));

        let grad2 = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        assert!(!accumulator.accumulate(&grad2));
        assert_eq!(accumulator.progress(), (2, 3));

        let grad3 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        assert!(accumulator.accumulate(&grad3));
        assert!(accumulator.is_ready());

        let final_grads = accumulator.get_and_reset().unwrap();
        assert_relative_eq!(final_grads[0], 2.0, epsilon = 1e-6);
        assert_relative_eq!(final_grads[1], 3.0, epsilon = 1e-6);
        assert_relative_eq!(final_grads[2], 4.0, epsilon = 1e-6);
        assert_eq!(accumulator.progress(), (0, 3));
        assert!(!accumulator.is_ready());
    }
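
    // Illustrative sketch: reset() discards partially accumulated gradients, and
    // set_accumulation_steps() changes how many accumulate() calls are needed
    // before is_ready() reports true.
    #[test]
    fn test_gradient_accumulator_reset_and_resize() {
        let mut accumulator = GradientAccumulator::new(3, true);
        let grad = Array1::from_vec(vec![1.0, 1.0]);

        assert!(!accumulator.accumulate(&grad));
        accumulator.reset();
        assert_eq!(accumulator.progress(), (0, 3));
        assert!(accumulator.get_and_reset().is_none());

        accumulator.set_accumulation_steps(1);
        assert!(accumulator.accumulate(&grad));
        assert!(accumulator.is_ready());
    }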

    #[test]
    fn test_gradient_accumulator_sum_mode() {
        let mut accumulator = GradientAccumulator::new(2, false);

        let grad1 = Array1::from_vec(vec![1.0, 2.0]);
        let grad2 = Array1::from_vec(vec![3.0, 4.0]);

        accumulator.accumulate(&grad1);
        accumulator.accumulate(&grad2);

        let final_grads = accumulator.get_and_reset().unwrap();
        assert_relative_eq!(final_grads[0], 4.0, epsilon = 1e-6);
        assert_relative_eq!(final_grads[1], 6.0, epsilon = 1e-6);
    }
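
    // Illustrative sketch: gradient_centralization subtracts the global mean,
    // so the centralized gradients sum to zero.
    #[test]
    fn test_gradient_centralization() {
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0, 6.0]);
        gradient_centralization(&mut gradients);

        let sum: f64 = gradients.iter().sum();
        assert_relative_eq!(sum, 0.0, epsilon = 1e-10);
        assert_relative_eq!(gradients[0], -2.0, epsilon = 1e-10);
        assert_relative_eq!(gradients[3], 3.0, epsilon = 1e-10);
    }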

    #[test]
    fn test_gradient_noise() {
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let original = gradients.clone();

        add_gradient_noise(&mut gradients, 0.1, Some(42));

        for (i, (&orig, &noisy)) in original.iter().zip(gradients.iter()).enumerate() {
            assert!(
                (orig - noisy).abs() < 1.0,
                "Index {}: {} vs {}",
                i,
                orig,
                noisy
            );
        }
    }

    #[test]
    fn test_gradient_noise_zero_std() {
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let original = gradients.clone();

        add_gradient_noise(&mut gradients, 0.0, Some(42));

        for (orig, noisy) in original.iter().zip(gradients.iter()) {
            assert_relative_eq!(*orig, *noisy, epsilon = 1e-10);
        }
    }
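
    // Illustrative sketch: zero_small_gradients zeroes entries whose magnitude
    // is strictly below the threshold and leaves the rest untouched.
    #[test]
    fn test_zero_small_gradients() {
        let mut gradients = Array1::from_vec(vec![1e-9, -1e-9, 0.5, -0.5]);
        zero_small_gradients(&mut gradients, 1e-6);

        assert_eq!(gradients[0], 0.0);
        assert_eq!(gradients[1], 0.0);
        assert_relative_eq!(gradients[2], 0.5, epsilon = 1e-12);
        assert_relative_eq!(gradients[3], -0.5, epsilon = 1e-12);
    }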

    #[test]
    fn test_gradient_mask_creation() {
        let mask = Array1::from_vec(vec![true, false, true]);
        let grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> = GradientMask::new(mask);

        assert_eq!(grad_mask.num_active(), 2);
        assert_eq!(grad_mask.num_frozen(), 1);
    }

    #[test]
    fn test_gradient_mask_apply() {
        let mask = Array1::from_vec(vec![true, false, true]);
        let grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> = GradientMask::new(mask);
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        grad_mask.apply_mask(&mut gradients);

        assert_eq!(gradients.as_slice().unwrap(), &[1.0, 0.0, 3.0]);
    }

    #[test]
    fn test_gradient_mask_freeze_unfreeze() {
        let mask = Array1::from_vec(vec![true, true, true]);
        let mut grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> = GradientMask::new(mask);

        grad_mask.freeze_indices(&[0, 2]).unwrap();
        assert_eq!(grad_mask.num_frozen(), 2);
        assert_eq!(grad_mask.num_active(), 1);

        grad_mask.unfreeze_indices(&[0]).unwrap();
        assert_eq!(grad_mask.num_frozen(), 1);
        assert_eq!(grad_mask.num_active(), 2);
    }
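
    // Illustrative sketch: the builder-style setters on GradientProcessor can be
    // chained; value clipping and the zero threshold are combined here.
    #[test]
    fn test_gradient_processor_builder_chain() {
        let mut processor = GradientProcessor::new();
        processor.set_value_clip(-1.0, 1.0).set_zero_threshold(1e-3);

        let mut gradients = Array1::from_vec(vec![-2.0, 1e-4, 0.5, 3.0]);
        processor.process(&mut gradients).unwrap();

        assert_eq!(gradients[0], -1.0);
        assert_eq!(gradients[1], 0.0);
        assert_relative_eq!(gradients[2], 0.5, epsilon = 1e-12);
        assert_eq!(gradients[3], 1.0);
    }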

    #[test]
    fn test_gradient_mask_with_lr_multipliers() {
        let mask = Array1::from_vec(vec![true, true, true]);
        let multipliers = Array1::from_vec(vec![1.0, 0.5, 2.0]);
        let grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> =
            GradientMask::new(mask).with_lr_multipliers(multipliers);
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        grad_mask.apply_mask(&mut gradients);

        assert_relative_eq!(gradients[0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[1], 1.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[2], 6.0, epsilon = 1e-6);
    }

    #[test]
    fn test_gradient_mask_freeze_all() {
        let grad_mask = GradientMask::<f64, scirs2_core::ndarray::Ix1>::freeze_all(
            scirs2_core::ndarray::Ix1(3),
        );
        assert_eq!(grad_mask.num_frozen(), 3);
        assert_eq!(grad_mask.num_active(), 0);
    }

    #[test]
    fn test_gradient_mask_update_all() {
        let grad_mask = GradientMask::<f64, scirs2_core::ndarray::Ix1>::update_all(
            scirs2_core::ndarray::Ix1(3),
        );
        assert_eq!(grad_mask.num_frozen(), 0);
        assert_eq!(grad_mask.num_active(), 3);
    }
}