optirs_core/utils/mod.rs
//! Utility functions for machine learning optimization
//!
//! This module provides utility functions and helpers for optimization
//! tasks in machine learning.

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use crate::error::{OptimError, Result};

/// Clip gradient values to a specified range
///
/// # Arguments
///
/// * `gradients` - The gradients to clip
/// * `min_value` - Minimum allowed value
/// * `max_value` - Maximum allowed value
///
/// # Returns
///
/// A mutable reference to the gradients, clipped in place
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::utils::clip_gradients;
///
/// let mut gradients = Array1::from_vec(vec![-10.0, 0.5, 8.0, -0.2]);
/// clip_gradients(&mut gradients, -5.0, 5.0);
/// assert_eq!(gradients, Array1::from_vec(vec![-5.0, 0.5, 5.0, -0.2]));
/// ```
#[allow(dead_code)]
pub fn clip_gradients<A, D>(
    gradients: &mut Array<A, D>,
    min_value: A,
    max_value: A,
) -> &mut Array<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    // Clamp each element to [min_value, max_value]. NaN values pass through
    // unchanged, since both comparisons are false for NaN.
    for grad in gradients.iter_mut() {
        if *grad < min_value {
            *grad = min_value;
        } else if *grad > max_value {
            *grad = max_value;
        }
    }
    gradients
}


/// Clip the gradient L2 norm (global gradient clipping)
///
/// If the L2 norm `||g||` of the gradients exceeds `max_norm`, every element
/// is scaled by `max_norm / ||g||`, so the result has norm exactly `max_norm`;
/// otherwise the gradients are left unchanged.
///
/// # Arguments
///
/// * `gradients` - The gradients to clip
/// * `max_norm` - Maximum allowed L2 norm (must be positive)
///
/// # Returns
///
/// A mutable reference to the gradients, rescaled in place
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::utils::clip_gradient_norm;
///
/// let mut gradients = Array1::<f64>::from_vec(vec![3.0, 4.0]); // L2 norm = 5.0
/// clip_gradient_norm(&mut gradients, 1.0f64).unwrap();
/// // After clipping, L2 norm = 1.0
/// let diff0 = (gradients[0] - 0.6f64).abs();
/// let diff1 = (gradients[1] - 0.8f64).abs();
/// assert!(diff0 < 1e-5);
/// assert!(diff1 < 1e-5);
/// ```
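///
/// A non-positive `max_norm` is rejected, as the validation below shows:
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::utils::clip_gradient_norm;
///
/// let mut gradients = Array1::from_vec(vec![1.0f64, 2.0]);
/// assert!(clip_gradient_norm(&mut gradients, 0.0f64).is_err());
/// ```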
#[allow(dead_code)]
pub fn clip_gradient_norm<A, D>(
    gradients: &mut Array<A, D>,
    max_norm: A,
) -> Result<&mut Array<A, D>>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    if max_norm <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_norm must be positive".to_string(),
        ));
    }

    // Compute the current L2 norm of the gradients
    let norm = gradients
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    // If the norm exceeds max_norm, scale all gradients down proportionally
    if norm > max_norm {
        let scale = max_norm / norm;
        for grad in gradients.iter_mut() {
            *grad = *grad * scale;
        }
    }

    Ok(gradients)
}


/// Compute gradient centralization
///
/// Gradient centralization improves training stability by removing the mean
/// from each gradient tensor: every element is updated as `g_i <- g_i - mean(g)`,
/// so the centralized gradients sum to zero. This implementation takes the
/// mean over all elements of the tensor.
///
/// # Arguments
///
/// * `gradients` - The gradients to centralize
///
/// # Returns
///
/// A mutable reference to the gradients, centralized in place
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::utils::gradient_centralization;
///
/// let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0]);
/// gradient_centralization(&mut gradients);
/// assert_eq!(gradients, Array1::from_vec(vec![-1.0, 0.0, 1.0, 0.0]));
/// ```
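///
/// For multi-dimensional gradients the mean is still taken over all elements.
/// A sketch with made-up values, assuming `Array2` is re-exported alongside
/// `Array1` (note that the technique is often applied per layer or per output
/// channel instead):
///
/// ```
/// use scirs2_core::ndarray::Array2;
/// use optirs_core::utils::gradient_centralization;
///
/// let mut gradients = Array2::from_shape_vec((2, 2), vec![1.0, 3.0, 5.0, 7.0]).unwrap(); // mean = 4.0
/// gradient_centralization(&mut gradients);
/// let expected = Array2::from_shape_vec((2, 2), vec![-3.0, -1.0, 1.0, 3.0]).unwrap();
/// assert_eq!(gradients, expected);
/// ```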
#[allow(dead_code)]
pub fn gradient_centralization<A, D>(gradients: &mut Array<A, D>) -> &mut Array<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    // Calculate the mean over all elements; if the length cannot be
    // represented in A, fall back to dividing by one.
    let sum = gradients.iter().fold(A::zero(), |acc, &x| acc + x);
    let mean = sum / A::from(gradients.len()).unwrap_or(A::one());

    // Subtract the mean from each element
    for grad in gradients.iter_mut() {
        *grad = *grad - mean;
    }

    gradients
}


/// Zero out small gradient values
///
/// # Arguments
///
/// * `gradients` - The gradients to process
/// * `threshold` - Magnitude below which gradients are set to zero
///
/// # Returns
///
/// A mutable reference to the gradients, modified in place
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::utils::zero_small_gradients;
///
/// let mut gradients = Array1::from_vec(vec![0.001, 0.02, -0.005, 0.3]);
/// zero_small_gradients(&mut gradients, 0.01);
/// assert_eq!(gradients, Array1::from_vec(vec![0.0, 0.02, 0.0, 0.3]));
/// ```
#[allow(dead_code)]
pub fn zero_small_gradients<A, D>(gradients: &mut Array<A, D>, threshold: A) -> &mut Array<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    // Compare against the magnitude of the threshold so that a negative
    // threshold behaves like its absolute value.
    let abs_threshold = threshold.abs();

    for grad in gradients.iter_mut() {
        if grad.abs() < abs_threshold {
            *grad = A::zero();
        }
    }

    gradients
}
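
// A minimal test sketch for the helpers above, assuming `Array1` is
// re-exported through `scirs2_core::ndarray` exactly as in the doctests.
// Tolerances are checked by hand rather than with an assertion crate.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn clip_gradients_clamps_out_of_range_values() {
        let mut grads = Array1::from_vec(vec![-10.0, 0.5, 8.0]);
        clip_gradients(&mut grads, -5.0, 5.0);
        assert_eq!(grads, Array1::from_vec(vec![-5.0, 0.5, 5.0]));
    }

    #[test]
    fn clip_gradient_norm_rescales_to_max_norm() {
        let mut grads = Array1::from_vec(vec![3.0f64, 4.0]); // L2 norm = 5.0
        clip_gradient_norm(&mut grads, 1.0).unwrap();
        let norm = grads.iter().map(|g| g * g).sum::<f64>().sqrt();
        assert!((norm - 1.0).abs() < 1e-10);
    }

    #[test]
    fn clip_gradient_norm_rejects_non_positive_max_norm() {
        let mut grads = Array1::from_vec(vec![1.0f64]);
        assert!(clip_gradient_norm(&mut grads, -1.0).is_err());
    }

    #[test]
    fn gradient_centralization_produces_zero_mean() {
        let mut grads = Array1::from_vec(vec![1.0f64, 2.0, 3.0, 2.0]);
        gradient_centralization(&mut grads);
        let mean = grads.iter().sum::<f64>() / grads.len() as f64;
        assert!(mean.abs() < 1e-10);
    }

    #[test]
    fn zero_small_gradients_uses_threshold_magnitude() {
        // A negative threshold behaves like its absolute value.
        let mut grads = Array1::from_vec(vec![0.001, 0.02, -0.005, 0.3]);
        zero_small_gradients(&mut grads, -0.01);
        assert_eq!(grads, Array1::from_vec(vec![0.0, 0.02, 0.0, 0.3]));
    }
}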