optirs_core/utils/mod.rs

// Utility functions for machine learning optimization
//
// This module provides utility functions and helpers for optimization
// tasks in machine learning.

use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

use crate::error::{OptimError, Result};

/// Clip gradient values to a specified range
///
/// # Arguments
///
/// * `gradients` - The gradients to clip
/// * `min_value` - Minimum allowed value
/// * `max_value` - Maximum allowed value
///
/// # Returns
///
/// The clipped gradients (in-place modification)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::utils::clip_gradients;
///
/// let mut gradients = Array1::from_vec(vec![-10.0, 0.5, 8.0, -0.2]);
/// clip_gradients(&mut gradients, -5.0, 5.0);
/// assert_eq!(gradients, Array1::from_vec(vec![-5.0, 0.5, 5.0, -0.2]));
/// ```
#[allow(dead_code)]
pub fn clip_gradients<A, D>(
    gradients: &mut Array<A, D>,
    min_value: A,
    max_value: A,
) -> &mut Array<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    for grad in gradients.iter_mut() {
        *grad = if *grad < min_value {
            min_value
        } else if *grad > max_value {
            max_value
        } else {
            *grad
        };
    }
    gradients
}

/// Clip gradient norm (global gradient clipping)
///
/// # Arguments
///
/// * `gradients` - The gradients to clip
/// * `max_norm` - Maximum allowed L2 norm
///
/// # Returns
///
/// The clipped gradients (in-place modification)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::utils::clip_gradient_norm;
///
/// let mut gradients = Array1::<f64>::from_vec(vec![3.0, 4.0]); // L2 norm = 5.0
/// clip_gradient_norm(&mut gradients, 1.0f64).unwrap();
/// // After clipping, L2 norm = 1.0
/// let diff0 = (gradients[0] - 0.6f64).abs();
/// let diff1 = (gradients[1] - 0.8f64).abs();
/// assert!(diff0 < 1e-5);
/// assert!(diff1 < 1e-5);
/// ```
#[allow(dead_code)]
pub fn clip_gradient_norm<A, D>(
    gradients: &mut Array<A, D>,
    max_norm: A,
) -> Result<&mut Array<A, D>>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    if max_norm <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_norm must be positive".to_string(),
        ));
    }

    // Calculate the current L2 norm
    let norm = gradients
        .iter()
        .fold(A::zero(), |acc, &x| acc + x * x)
        .sqrt();

    // If the norm exceeds max_norm, scale all gradients proportionally
    if norm > max_norm {
        let scale = max_norm / norm;
        for grad in gradients.iter_mut() {
            *grad = *grad * scale;
        }
    }

    Ok(gradients)
}
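
// A minimal usage sketch (illustrative only, not part of this module's public
// API) showing how the two clipping helpers compose in a typical
// gradient-preprocessing step: bound the global L2 norm first, then clamp any
// remaining element-wise outliers. The function name and the bounds below are
// arbitrary example values.
#[allow(dead_code)]
fn preprocess_gradients_example(gradients: &mut scirs2_core::ndarray::Array1<f64>) -> Result<()> {
    // Rescale so the overall update magnitude is at most 1.0
    clip_gradient_norm(gradients, 1.0)?;
    // Then clamp individual components into [-0.1, 0.1]
    clip_gradients(gradients, -0.1, 0.1);
    Ok(())
}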

/// Compute gradient centralization
///
/// Gradient Centralization is a technique that improves training stability
/// by removing the mean from each gradient tensor.
///
/// # Arguments
///
/// * `gradients` - The gradients to centralize
///
/// # Returns
///
/// The centralized gradients (in-place modification)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::utils::gradient_centralization;
///
/// let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0, 2.0]);
/// gradient_centralization(&mut gradients);
/// assert_eq!(gradients, Array1::from_vec(vec![-1.0, 0.0, 1.0, 0.0]));
/// ```
#[allow(dead_code)]
pub fn gradient_centralization<A, D>(gradients: &mut Array<A, D>) -> &mut Array<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    // Calculate mean
    let sum = gradients.iter().fold(A::zero(), |acc, &x| acc + x);
    let mean = sum / A::from(gradients.len()).unwrap_or(A::one());

    // Subtract mean from each element
    for grad in gradients.iter_mut() {
        *grad = *grad - mean;
    }

    gradients
}

/// Zero out small gradient values
///
/// # Arguments
///
/// * `gradients` - The gradients to process
/// * `threshold` - Gradients whose absolute value falls below this are set to zero
///
/// # Returns
///
/// The processed gradients (in-place modification)
///
/// # Examples
///
/// ```
/// use scirs2_core::ndarray::Array1;
/// use optirs_core::utils::zero_small_gradients;
///
/// let mut gradients = Array1::from_vec(vec![0.001, 0.02, -0.005, 0.3]);
/// zero_small_gradients(&mut gradients, 0.01);
/// assert_eq!(gradients, Array1::from_vec(vec![0.0, 0.02, 0.0, 0.3]));
/// ```
#[allow(dead_code)]
pub fn zero_small_gradients<A, D>(gradients: &mut Array<A, D>, threshold: A) -> &mut Array<A, D>
where
    A: Float + ScalarOperand + Debug,
    D: Dimension,
{
    let abs_threshold = threshold.abs();

    for grad in gradients.iter_mut() {
        if grad.abs() < abs_threshold {
            *grad = A::zero();
        }
    }

    gradients
}
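
// A minimal test sketch for the helpers above, assuming `Array1` is available
// from `scirs2_core::ndarray` as in the doc examples. These cover edge cases
// the doc examples do not: rejection of a non-positive `max_norm`, a norm
// already within bounds, and the absolute-value handling of `threshold`.
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn clip_gradient_norm_rejects_non_positive_max_norm() {
        let mut gradients = Array1::from_vec(vec![1.0f64, 2.0]);
        assert!(clip_gradient_norm(&mut gradients, 0.0).is_err());
        assert!(clip_gradient_norm(&mut gradients, -1.0).is_err());
    }

    #[test]
    fn clip_gradient_norm_keeps_small_gradients_unchanged() {
        let mut gradients = Array1::from_vec(vec![0.3f64, 0.4]); // L2 norm = 0.5
        clip_gradient_norm(&mut gradients, 1.0).unwrap();
        assert_eq!(gradients, Array1::from_vec(vec![0.3, 0.4]));
    }

    #[test]
    fn zero_small_gradients_treats_threshold_as_absolute() {
        // A negative threshold behaves like its absolute value.
        let mut gradients = Array1::from_vec(vec![0.001f64, -0.5]);
        zero_small_gradients(&mut gradients, -0.01);
        assert_eq!(gradients, Array1::from_vec(vec![0.0, -0.5]));
    }
}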