use scirs2_core::ndarray::{Array, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Rng};
use std::fmt::Debug;
use crate::error::{OptimError, Result};
/// Configuration for the gradient post-processing steps applied by
/// [`GradientProcessor::process`].
///
/// Every step is disabled in the `Default` configuration.
#[derive(Debug, Clone)]
pub struct GradientClipConfig<A: Float> {
/// Upper bound for element-wise value clipping.
pub max_value: Option<A>,
/// Lower bound for element-wise value clipping.
pub min_value: Option<A>,
/// Maximum allowed L2 norm; gradients with a larger norm are rescaled down to it.
pub maxnorm: Option<A>,
/// Maximum allowed L1 norm; gradients with a larger L1 norm are rescaled down to it.
pub max_l1norm: Option<A>,
/// Whether to subtract the mean of all elements from every gradient element.
pub centralization: bool,
/// Elements whose absolute value is strictly below `|zero_threshold|` are set to zero.
pub zero_threshold: Option<A>,
}
impl<A: Float + Send + Sync> Default for GradientClipConfig<A> {
fn default() -> Self {
Self {
max_value: None,
min_value: None,
maxnorm: None,
max_l1norm: None,
centralization: false,
zero_threshold: None,
}
}
}
/// Applies the clipping / normalisation steps described by a
/// [`GradientClipConfig`] to gradient arrays in place.
#[derive(Debug, Clone)]
pub struct GradientProcessor<A: Float> {
    /// Which steps to run and with which limits.
    config: GradientClipConfig<A>,
}
impl<A: Float + ScalarOperand + Debug + Send + Sync> Default for GradientProcessor<A> {
fn default() -> Self {
Self {
config: GradientClipConfig::default(),
}
}
}
impl<A: Float + ScalarOperand + Debug + Send + Sync> GradientProcessor<A> {
pub fn new() -> Self {
Self::default()
}
pub fn with_config(config: GradientClipConfig<A>) -> Self {
Self { config }
}
pub fn set_max_value(&mut self, value: A) -> &mut Self {
self.config.max_value = Some(value);
self
}
pub fn set_min_value(&mut self, value: A) -> &mut Self {
self.config.min_value = Some(value);
self
}
pub fn set_max_norm(&mut self, value: A) -> &mut Self {
self.config.maxnorm = Some(value);
self
}
pub fn set_max_l1_norm(&mut self, value: A) -> &mut Self {
self.config.max_l1norm = Some(value);
self
}
pub fn set_centralization(&mut self, enabled: bool) -> &mut Self {
self.config.centralization = enabled;
self
}
pub fn set_zero_threshold(&mut self, value: A) -> &mut Self {
self.config.zero_threshold = Some(value);
self
}
pub fn set_value_clip(&mut self, min: A, max: A) -> &mut Self {
self.config.min_value = Some(min);
self.config.max_value = Some(max);
self
}
pub fn set_norm_clip(&mut self, maxnorm: A) -> &mut Self {
self.config.maxnorm = Some(maxnorm);
self
}
pub fn set_l1_norm_clip(&mut self, max_l1norm: A) -> &mut Self {
self.config.max_l1norm = Some(max_l1norm);
self
}
pub fn enable_centralization(&mut self) -> &mut Self {
self.config.centralization = true;
self
}
pub fn process<D: Dimension>(&self, gradients: &mut Array<A, D>) -> Result<()> {
if let (Some(min), Some(max)) = (self.config.min_value, self.config.max_value) {
clip_gradients_by_value(gradients, min, max);
}
if let Some(maxnorm) = self.config.maxnorm {
clip_gradient_norm(gradients, maxnorm)?;
}
if let Some(max_l1norm) = self.config.max_l1norm {
clip_gradient_l1_norm(gradients, max_l1norm)?;
}
if self.config.centralization {
gradient_centralization(gradients);
}
if let Some(threshold) = self.config.zero_threshold {
zero_small_gradients(gradients, threshold);
}
Ok(())
}
}
/// Clamps every element of `gradients` into `[min_value, max_value]`, in place.
///
/// Elements below `min_value` become `min_value`; elements above `max_value`
/// become `max_value`. All other elements, including NaN, are left untouched.
/// The lower-bound test runs first. Returns the same array to allow chaining.
#[allow(dead_code)]
pub fn clip_gradients_by_value<A, D>(
    gradients: &mut Array<A, D>,
    min_value: A,
    max_value: A,
) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    let clamp = |value: A| -> A {
        if value < min_value {
            return min_value;
        }
        if value > max_value {
            return max_value;
        }
        value
    };
    gradients.mapv_inplace(clamp);
    gradients
}
/// Rescales `gradients` in place so that their L2 (Euclidean) norm is at most
/// `maxnorm`.
///
/// Gradients whose norm is already within the limit, or whose norm is NaN,
/// are left unchanged. Returns the same array to allow chaining.
///
/// # Errors
///
/// Returns [`OptimError::InvalidConfig`] if `maxnorm` is not positive or is NaN.
#[allow(dead_code)]
pub fn clip_gradient_norm<A, D>(gradients: &mut Array<A, D>, maxnorm: A) -> Result<&mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    /// L2 norm that does not overflow when squaring large elements.
    fn l2_norm<A: Float, D: Dimension>(values: &Array<A, D>) -> A {
        let sum_sq = values.iter().fold(A::zero(), |acc, &x| acc + x * x);
        if !sum_sq.is_infinite() {
            return sum_sq.sqrt();
        }
        // The squares overflowed even though the norm itself may be
        // representable (e.g. exploding f32 gradients). A +inf norm would make
        // the scale factor 0 and wipe the gradients instead of rescaling them,
        // so divide by the largest magnitude first to keep the sum finite.
        let peak = values.iter().fold(A::zero(), |m, &x| m.max(x.abs()));
        if peak.is_infinite() {
            // A genuinely infinite element: the norm really is infinite.
            return peak;
        }
        let scaled_sum = values.iter().fold(A::zero(), |acc, &x| {
            let r = x / peak;
            acc + r * r
        });
        peak * scaled_sum.sqrt()
    }

    // `NaN <= 0` is false, so a NaN limit must be rejected explicitly;
    // otherwise it would silently disable clipping.
    if maxnorm.is_nan() || maxnorm <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "maxnorm must be positive".to_string(),
        ));
    }
    let norm = l2_norm(gradients);
    if norm > maxnorm {
        let scale = maxnorm / norm;
        gradients.mapv_inplace(|x| x * scale);
    }
    Ok(gradients)
}
/// Rescales `gradients` in place so that their L1 norm (sum of absolute
/// values) is at most `max_l1norm`.
///
/// Gradients already within the limit, or whose L1 norm is NaN, are left
/// unchanged. Returns the same array to allow chaining.
///
/// # Errors
///
/// Returns [`OptimError::InvalidConfig`] if `max_l1norm` is not positive or is NaN.
#[allow(dead_code)]
pub fn clip_gradient_l1_norm<A, D>(
    gradients: &mut Array<A, D>,
    max_l1norm: A,
) -> Result<&mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    // `NaN <= 0` is false, so a NaN limit must be rejected explicitly;
    // otherwise it would silently disable clipping.
    if max_l1norm.is_nan() || max_l1norm <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_l1norm must be positive".to_string(),
        ));
    }
    let l1_norm = gradients.iter().fold(A::zero(), |acc, &x| acc + x.abs());
    if l1_norm > max_l1norm {
        let scale = max_l1norm / l1_norm;
        gradients.mapv_inplace(|x| x * scale);
    }
    Ok(gradients)
}
/// Subtracts the arithmetic mean of all elements from every element, in place.
///
/// If the element count cannot be converted to `A`, the divisor falls back to
/// one. Returns the same array to allow chaining.
#[allow(dead_code)]
pub fn gradient_centralization<A, D>(gradients: &mut Array<A, D>) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    let count = A::from(gradients.len()).unwrap_or(A::one());
    let mut total = A::zero();
    for &g in gradients.iter() {
        total = total + g;
    }
    let mean = total / count;
    gradients.mapv_inplace(|g| g - mean);
    gradients
}
/// Replaces every element whose magnitude is strictly below `|threshold|` with
/// zero, in place. NaN elements are kept. Returns the same array to allow
/// chaining.
#[allow(dead_code)]
pub fn zero_small_gradients<A, D>(gradients: &mut Array<A, D>, threshold: A) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    let cutoff = threshold.abs();
    let squash = move |g: A| -> A {
        let too_small = g.abs() < cutoff;
        if too_small {
            A::zero()
        } else {
            g
        }
    };
    gradients.mapv_inplace(squash);
    gradients
}
/// Sums gradients over several `accumulate` calls and hands out the sum,
/// optionally averaged, on `get_and_reset`.
#[derive(Debug, Clone)]
pub struct GradientAccumulator<A: Float, D: Dimension> {
/// Running element-wise sum since the last reset; `None` until the first `accumulate`.
accumulated_gradients: Option<Array<A, D>>,
/// Number of gradient arrays added since the last reset.
num_accumulated: usize,
/// Count at which the accumulator reports itself as ready.
accumulation_steps: usize,
/// If `true`, `get_and_reset` divides the sum by `num_accumulated`.
averagegradients: bool,
}
impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientAccumulator<A, D> {
    /// Creates an empty accumulator.
    ///
    /// * `accumulation_steps` - number of [`accumulate`](Self::accumulate) calls
    ///   after which the accumulator reports itself as ready.
    /// * `averagegradients` - when `true`, [`get_and_reset`](Self::get_and_reset)
    ///   returns the mean of the accumulated gradients instead of their sum.
    pub fn new(accumulation_steps: usize, averagegradients: bool) -> Self {
        Self {
            accumulated_gradients: None,
            num_accumulated: 0,
            accumulation_steps,
            averagegradients,
        }
    }

    /// Adds `gradients` element-wise to the running sum.
    ///
    /// Returns `true` once at least `accumulation_steps` arrays have been
    /// accumulated since the last reset.
    ///
    /// # Panics
    ///
    /// Panics if `gradients` has a different shape from the gradients already
    /// accumulated since the last reset.
    pub fn accumulate(&mut self, gradients: &Array<A, D>) -> bool {
        match self.accumulated_gradients.as_mut() {
            Some(acc) => {
                // Zipping two flat iterators would silently drop trailing
                // elements on a shape mismatch and corrupt the sum, so treat a
                // mismatch as a caller bug and fail loudly.
                assert_eq!(
                    acc.shape(),
                    gradients.shape(),
                    "GradientAccumulator::accumulate: gradient shape changed between calls"
                );
                acc.zip_mut_with(gradients, |total, &grad| *total = *total + grad);
            }
            None => self.accumulated_gradients = Some(gradients.clone()),
        }
        self.num_accumulated += 1;
        self.is_ready()
    }

    /// Takes the accumulated gradients and resets the counter.
    ///
    /// The result is divided by the number of accumulated arrays if averaging
    /// is enabled. Returns `None` if nothing has been accumulated.
    pub fn get_and_reset(&mut self) -> Option<Array<A, D>> {
        let mut gradients = self.accumulated_gradients.take()?;
        if self.averagegradients && self.num_accumulated > 0 {
            let scale = A::one() / A::from(self.num_accumulated).unwrap_or(A::one());
            gradients.mapv_inplace(|x| x * scale);
        }
        self.num_accumulated = 0;
        Some(gradients)
    }

    /// Returns `(accumulated so far, steps required)`.
    pub fn progress(&self) -> (usize, usize) {
        (self.num_accumulated, self.accumulation_steps)
    }

    /// Whether at least `accumulation_steps` arrays have been accumulated.
    pub fn is_ready(&self) -> bool {
        self.num_accumulated >= self.accumulation_steps
    }

    /// Discards the accumulated gradients and resets the counter.
    pub fn reset(&mut self) {
        self.accumulated_gradients = None;
        self.num_accumulated = 0;
    }

    /// Changes the number of accumulation steps required to become ready.
    pub fn set_accumulation_steps(&mut self, steps: usize) {
        self.accumulation_steps = steps;
    }
}
/// Rescales `gradients` in place so that `‖gradients‖₂ / ‖parameters‖₂` does
/// not exceed `max_ratio`.
///
/// Nothing is changed when either norm is zero, or when either norm is NaN.
/// Returns the same gradient array to allow chaining.
///
/// # Errors
///
/// Returns [`OptimError::InvalidConfig`] if `max_ratio` is not positive or is NaN.
#[allow(dead_code)]
pub fn adaptive_gradient_clipping<'a, A, D>(
    gradients: &'a mut Array<A, D>,
    parameters: &Array<A, D>,
    max_ratio: A,
) -> Result<&'a mut Array<A, D>>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    /// L2 norm that does not overflow when squaring large elements.
    fn l2_norm<A: Float, D: Dimension>(values: &Array<A, D>) -> A {
        let sum_sq = values.iter().fold(A::zero(), |acc, &x| acc + x * x);
        if !sum_sq.is_infinite() {
            return sum_sq.sqrt();
        }
        // The squares overflowed even though the norm may be representable.
        // An infinite gradient norm would yield a scale factor of 0, wiping
        // the gradients; an infinite ratio on both sides would be NaN. Divide
        // by the largest magnitude first so the sum stays finite.
        let peak = values.iter().fold(A::zero(), |m, &x| m.max(x.abs()));
        if peak.is_infinite() {
            return peak;
        }
        let scaled_sum = values.iter().fold(A::zero(), |acc, &x| {
            let r = x / peak;
            acc + r * r
        });
        peak * scaled_sum.sqrt()
    }

    // `NaN <= 0` is false, so a NaN limit must be rejected explicitly.
    if max_ratio.is_nan() || max_ratio <= A::zero() {
        return Err(OptimError::InvalidConfig(
            "max_ratio must be positive".to_string(),
        ));
    }
    let grad_norm = l2_norm(gradients);
    let param_norm = l2_norm(parameters);
    if param_norm > A::zero() && grad_norm > A::zero() {
        let ratio = grad_norm / param_norm;
        if ratio > max_ratio {
            let scale = max_ratio / ratio;
            gradients.mapv_inplace(|x| x * scale);
        }
    }
    Ok(gradients)
}
/// Adds Gaussian noise with mean 0 and standard deviation `noise_std` to
/// every element of `gradients`, in place.
///
/// A non-positive or NaN `noise_std` leaves the gradients untouched. If
/// `noise_std` cannot be converted to `f64`, a standard deviation of `0.01`
/// is used. Returns the same array to allow chaining.
///
/// NOTE(review): `seed` is currently ignored. Noise is always drawn from
/// `thread_rng()`, so results are not reproducible even when a seed is
/// given. Honouring it needs a seedable RNG from `scirs2_core::random`.
///
/// # Panics
///
/// Panics if `RandNormal::new` rejects the standard deviation derived from
/// `noise_std`.
#[allow(dead_code)]
pub fn add_gradient_noise<A, D>(
    gradients: &mut Array<A, D>,
    noise_std: A,
    seed: Option<u64>,
) -> &mut Array<A, D>
where
    A: Float + ScalarOperand,
    D: Dimension,
{
    use scirs2_core::random::RandNormal;
    use scirs2_core::random::Rng;
    // `NaN <= 0` is false; without the explicit check a NaN std would reach
    // the distribution constructor and panic.
    if noise_std.is_nan() || noise_std <= A::zero() {
        return gradients;
    }
    let std_dev = noise_std.to_f64().unwrap_or(0.01);
    let normal = RandNormal::new(0.0, std_dev)
        .expect("noise_std must yield a valid normal distribution");
    let mut rng = thread_rng();
    // Add each sample directly instead of materialising a separate noise array.
    gradients.mapv_inplace(|g| g + A::from(rng.sample(normal)).unwrap_or(A::zero()));
    gradients
}
/// Per-element gradient mask with optional per-element multipliers.
#[derive(Debug, Clone)]
pub struct GradientMask<A: Float, D: Dimension> {
/// `true` keeps the gradient element; `false` zeroes it (frozen).
mask: Array<bool, D>,
/// Optional factors multiplied into the gradients after masking.
lr_multipliers: Option<Array<A, D>>,
}
impl<A: Float + ScalarOperand + Debug, D: Dimension + Send + Sync> GradientMask<A, D> {
    /// Creates a mask from an explicit boolean array (`true` = update, `false` = frozen).
    pub fn new(mask: Array<bool, D>) -> Self {
        Self {
            mask,
            lr_multipliers: None,
        }
    }

    /// Creates a mask of the given shape with every element frozen.
    pub fn freeze_all(shape: D) -> Self {
        Self {
            mask: Array::from_elem(shape, false),
            lr_multipliers: None,
        }
    }

    /// Creates a mask of the given shape with every element active.
    pub fn update_all(shape: D) -> Self {
        Self {
            mask: Array::from_elem(shape, true),
            lr_multipliers: None,
        }
    }

    /// Attaches per-element multipliers applied after masking.
    pub fn with_lr_multipliers(mut self, multipliers: Array<A, D>) -> Self {
        self.lr_multipliers = Some(multipliers);
        self
    }

    /// Zeroes `gradients` at frozen positions, then multiplies by the
    /// multipliers if any are set. Returns the same array to allow chaining.
    ///
    /// # Panics
    ///
    /// Panics (inside ndarray's `zip_mut_with`) if the mask or multipliers
    /// cannot be broadcast to the shape of `gradients`.
    pub fn apply_mask<'a>(&self, gradients: &'a mut Array<A, D>) -> &'a mut Array<A, D> {
        gradients.zip_mut_with(&self.mask, |grad, &should_update| {
            if !should_update {
                *grad = A::zero();
            }
        });
        if let Some(multipliers) = &self.lr_multipliers {
            gradients.zip_mut_with(multipliers, |grad, &mult| {
                *grad = *grad * mult;
            });
        }
        gradients
    }

    /// Freezes the elements at the given flat (row-major) indices.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] if the mask is not in standard
    /// layout, or if any index is out of bounds. In that case the mask is
    /// left unchanged.
    pub fn freeze_indices(&mut self, indices: &[usize]) -> Result<()> {
        self.set_indices(indices, false)
    }

    /// Unfreezes the elements at the given flat (row-major) indices.
    ///
    /// # Errors
    ///
    /// Returns [`OptimError::InvalidConfig`] if the mask is not in standard
    /// layout, or if any index is out of bounds. In that case the mask is
    /// left unchanged.
    pub fn unfreeze_indices(&mut self, indices: &[usize]) -> Result<()> {
        self.set_indices(indices, true)
    }

    /// Number of frozen (`false`) mask elements.
    pub fn num_frozen(&self) -> usize {
        self.mask.iter().filter(|&&x| !x).count()
    }

    /// Number of active (`true`) mask elements.
    pub fn num_active(&self) -> usize {
        self.mask.iter().filter(|&&x| x).count()
    }

    /// Writes `value` at every flat index in `indices`.
    ///
    /// All indices are validated before anything is written, so an error
    /// never leaves the mask partially updated.
    fn set_indices(&mut self, indices: &[usize], value: bool) -> Result<()> {
        let flat_mask = self.mask.as_slice_mut().ok_or_else(|| {
            OptimError::InvalidConfig("Cannot access mask as flat slice".to_string())
        })?;
        let len = flat_mask.len();
        if let Some(&idx) = indices.iter().find(|&&idx| idx >= len) {
            return Err(OptimError::InvalidConfig(format!(
                "Index {} out of bounds for mask of size {}",
                idx, len
            )));
        }
        for &idx in indices {
            flat_mask[idx] = value;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Array1;

    #[test]
    fn test_gradient_processor() {
        let config = GradientClipConfig::<f64> {
            max_value: Some(5.0),
            min_value: Some(-5.0),
            maxnorm: Some(10.0),
            ..Default::default()
        };
        let processor = GradientProcessor::with_config(config);
        let mut gradients = Array1::from_vec(vec![-8.0, 3.0, 7.0, -2.0, 6.0]);
        processor.process(&mut gradients).expect("unwrap failed");
        assert_eq!(gradients[0], -5.0);
        assert_eq!(gradients[2], 5.0);
        assert_eq!(gradients[4], 5.0);
    }

    #[test]
    fn test_adaptive_clipping() {
        // |g| = 5 and |p| = 1, so the ratio 5 exceeds 2 and g is rescaled to norm 2.
        let mut gradients = Array1::from_vec(vec![3.0, 4.0]);
        let parameters = Array1::from_vec(vec![1.0, 0.0]);
        adaptive_gradient_clipping(&mut gradients, &parameters, 2.0).expect("unwrap failed");
        let new_grad_norm = gradients.iter().fold(0.0, |acc, &x| acc + x * x).sqrt();
        assert!((new_grad_norm - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_gradient_accumulator() {
        let mut accumulator = GradientAccumulator::new(3, true);
        let grad1 = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        assert!(!accumulator.accumulate(&grad1));
        assert_eq!(accumulator.progress(), (1, 3));
        let grad2 = Array1::from_vec(vec![2.0, 3.0, 4.0]);
        assert!(!accumulator.accumulate(&grad2));
        assert_eq!(accumulator.progress(), (2, 3));
        let grad3 = Array1::from_vec(vec![3.0, 4.0, 5.0]);
        assert!(accumulator.accumulate(&grad3));
        assert!(accumulator.is_ready());
        let final_grads = accumulator.get_and_reset().expect("unwrap failed");
        assert_relative_eq!(final_grads[0], 2.0, epsilon = 1e-6);
        assert_relative_eq!(final_grads[1], 3.0, epsilon = 1e-6);
        assert_relative_eq!(final_grads[2], 4.0, epsilon = 1e-6);
        assert_eq!(accumulator.progress(), (0, 3));
        assert!(!accumulator.is_ready());
    }

    #[test]
    fn test_gradient_accumulator_sum_mode() {
        let mut accumulator = GradientAccumulator::new(2, false);
        let grad1 = Array1::from_vec(vec![1.0, 2.0]);
        let grad2 = Array1::from_vec(vec![3.0, 4.0]);
        accumulator.accumulate(&grad1);
        accumulator.accumulate(&grad2);
        let final_grads = accumulator.get_and_reset().expect("unwrap failed");
        assert_relative_eq!(final_grads[0], 4.0, epsilon = 1e-6);
        assert_relative_eq!(final_grads[1], 6.0, epsilon = 1e-6);
    }

    #[test]
    fn test_gradient_noise() {
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let original = gradients.clone();
        add_gradient_noise(&mut gradients, 0.1, Some(42));
        for (i, (&orig, &noisy)) in original.iter().zip(gradients.iter()).enumerate() {
            assert!(
                (orig - noisy).abs() < 1.0,
                "Index {}: {} vs {}",
                i,
                orig,
                noisy
            );
        }
    }

    #[test]
    fn test_gradient_noise_zero_std() {
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let original = gradients.clone();
        add_gradient_noise(&mut gradients, 0.0, Some(42));
        for (orig, noisy) in original.iter().zip(gradients.iter()) {
            assert_relative_eq!(*orig, *noisy, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_gradient_mask_creation() {
        let mask = Array1::from_vec(vec![true, false, true]);
        let grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> = GradientMask::new(mask);
        assert_eq!(grad_mask.num_active(), 2);
        assert_eq!(grad_mask.num_frozen(), 1);
    }

    #[test]
    fn test_gradient_mask_apply() {
        let mask = Array1::from_vec(vec![true, false, true]);
        let grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> = GradientMask::new(mask);
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        grad_mask.apply_mask(&mut gradients);
        assert_eq!(
            gradients.as_slice().expect("unwrap failed"),
            &[1.0, 0.0, 3.0]
        );
    }

    #[test]
    fn test_gradient_mask_freeze_unfreeze() {
        let mask = Array1::from_vec(vec![true, true, true]);
        let mut grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> = GradientMask::new(mask);
        grad_mask.freeze_indices(&[0, 2]).expect("unwrap failed");
        assert_eq!(grad_mask.num_frozen(), 2);
        assert_eq!(grad_mask.num_active(), 1);
        grad_mask.unfreeze_indices(&[0]).expect("unwrap failed");
        assert_eq!(grad_mask.num_frozen(), 1);
        assert_eq!(grad_mask.num_active(), 2);
    }

    #[test]
    fn test_gradient_mask_with_lr_multipliers() {
        let mask = Array1::from_vec(vec![true, true, true]);
        let multipliers = Array1::from_vec(vec![1.0, 0.5, 2.0]);
        let grad_mask: GradientMask<f64, scirs2_core::ndarray::Ix1> =
            GradientMask::new(mask).with_lr_multipliers(multipliers);
        let mut gradients = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        grad_mask.apply_mask(&mut gradients);
        assert_relative_eq!(gradients[0], 1.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[1], 1.0, epsilon = 1e-6);
        assert_relative_eq!(gradients[2], 6.0, epsilon = 1e-6);
    }

    #[test]
    fn test_gradient_mask_freeze_all() {
        let grad_mask = GradientMask::<f64, scirs2_core::ndarray::Ix1>::freeze_all(
            scirs2_core::ndarray::Ix1(3),
        );
        assert_eq!(grad_mask.num_frozen(), 3);
        assert_eq!(grad_mask.num_active(), 0);
    }

    #[test]
    fn test_gradient_mask_update_all() {
        let grad_mask = GradientMask::<f64, scirs2_core::ndarray::Ix1>::update_all(
            scirs2_core::ndarray::Ix1(3),
        );
        assert_eq!(grad_mask.num_frozen(), 0);
        assert_eq!(grad_mask.num_active(), 3);
    }
}