use crate::error::{NeuralError, Result};
use scirs2_core::ndarray::{Array, ArrayD, IxDyn};
use std::collections::HashMap;
/// Rescales all gradients in `grads` so their combined (global) L2 norm does
/// not exceed `max_norm`, returning the norm measured *before* any clipping.
///
/// A `max_norm` of exactly zero disables clipping (the norm is still
/// reported); a negative `max_norm` yields `InvalidArgument`.
pub fn clip_grad_norm_map(
    grads: &mut HashMap<String, ArrayD<f64>>,
    max_norm: f64,
) -> Result<f64> {
    if max_norm < 0.0 {
        return Err(NeuralError::InvalidArgument(
            "max_norm must be non-negative".to_string(),
        ));
    }
    // Global L2 norm: sum of squares over every element of every tensor.
    let total_sq: f64 = grads
        .values()
        .map(|tensor| tensor.iter().map(|&v| v * v).sum::<f64>())
        .sum();
    let global_norm = total_sq.sqrt();
    // Only rescale when a positive threshold is actually exceeded.
    if max_norm > 0.0 && global_norm > max_norm {
        // Small epsilon in the denominator guards the division near the
        // threshold; the clipped norm lands slightly below `max_norm`.
        let scale = max_norm / (global_norm + 1e-6);
        for tensor in grads.values_mut() {
            tensor.mapv_inplace(|v| v * scale);
        }
    }
    Ok(global_norm)
}
/// Clamps every gradient element element-wise into `[-clip_value, clip_value]`.
///
/// A `clip_value` of zero forces every element to exactly zero; a negative
/// `clip_value` yields `InvalidArgument`.
pub fn clip_grad_value_map(
    grads: &mut HashMap<String, ArrayD<f64>>,
    clip_value: f64,
) -> Result<()> {
    if clip_value < 0.0 {
        return Err(NeuralError::InvalidArgument(
            "clip_value must be non-negative".to_string(),
        ));
    }
    grads
        .values_mut()
        .for_each(|tensor| tensor.mapv_inplace(|v| v.clamp(-clip_value, clip_value)));
    Ok(())
}
/// Returns the global L2 norm over every element of every tensor in `grads`.
/// An empty map has norm `0.0`.
pub fn grad_norm_map(grads: &HashMap<String, ArrayD<f64>>) -> f64 {
    grads
        .values()
        .flat_map(|tensor| tensor.iter())
        .map(|&v| v * v)
        .sum::<f64>()
        .sqrt()
}
/// Accumulates per-parameter gradients over several `accumulate` calls so a
/// single averaged update can be applied every `accumulation_steps` steps
/// (simulating a larger effective batch size).
#[derive(Debug, Clone)]
pub struct GradientAccumulatorMap {
    // Number of accumulate() calls per update; clamped to at least 1 in new().
    accumulation_steps: usize,
    // accumulate() calls since construction or the last reset().
    current_step: usize,
    // Running element-wise sums, keyed by parameter name.
    accumulated_grads: HashMap<String, ArrayD<f64>>,
}
impl GradientAccumulatorMap {
    /// Creates an accumulator that averages over `steps` batches; zero is
    /// clamped to one so the accumulator is always usable.
    pub fn new(steps: usize) -> Self {
        Self {
            accumulation_steps: if steps == 0 { 1 } else { steps },
            current_step: 0,
            accumulated_grads: HashMap::new(),
        }
    }

    /// Adds `grads` element-wise into the running sums and bumps the step
    /// counter. A shape mismatch for any parameter returns `InvalidArgument`
    /// immediately (before the step counter is advanced).
    pub fn accumulate(&mut self, grads: &HashMap<String, ArrayD<f64>>) -> Result<()> {
        for (name, grad) in grads.iter() {
            if let Some(existing) = self.accumulated_grads.get_mut(name) {
                if existing.shape() != grad.shape() {
                    return Err(NeuralError::InvalidArgument(format!(
                        "Shape mismatch for parameter '{}': accumulated {:?} vs new {:?}",
                        name,
                        existing.shape(),
                        grad.shape()
                    )));
                }
                existing.zip_mut_with(grad, |a, &b| *a += b);
            } else {
                // First time this parameter is seen: start its running sum.
                self.accumulated_grads.insert(name.clone(), grad.clone());
            }
        }
        self.current_step += 1;
        Ok(())
    }

    /// True once at least `accumulation_steps` batches have been accumulated.
    pub fn should_update(&self) -> bool {
        self.current_step >= self.accumulation_steps
    }

    /// Returns the accumulated gradients divided by the number of steps taken
    /// so far. Errors if nothing has been accumulated yet.
    pub fn get_averaged_grads(&self) -> Result<HashMap<String, ArrayD<f64>>> {
        if self.current_step == 0 {
            return Err(NeuralError::InvalidArgument(
                "No gradients have been accumulated yet".to_string(),
            ));
        }
        let inv_steps = 1.0 / self.current_step as f64;
        let mut averaged = HashMap::with_capacity(self.accumulated_grads.len());
        for (name, sum) in &self.accumulated_grads {
            averaged.insert(name.clone(), sum.mapv(|v| v * inv_steps));
        }
        Ok(averaged)
    }

    /// Clears all running sums and resets the step counter to zero.
    pub fn reset(&mut self) {
        self.current_step = 0;
        self.accumulated_grads.clear();
    }

    /// Number of accumulate() calls since construction or the last reset.
    pub fn current_step(&self) -> usize {
        self.current_step
    }

    /// Configured number of steps per update (always at least 1).
    pub fn accumulation_steps(&self) -> usize {
        self.accumulation_steps
    }

    /// Iterates over the parameter names seen so far (arbitrary order).
    pub fn param_names(&self) -> impl Iterator<Item = &str> {
        self.accumulated_grads.keys().map(|s| s.as_str())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Builds a gradient map of 1-D dynamic-dimension arrays from
    /// `(name, values)` pairs.
    fn make_grad_map(names_vals: &[(&str, Vec<f64>)]) -> HashMap<String, ArrayD<f64>> {
        names_vals
            .iter()
            .map(|(name, vals)| {
                (
                    name.to_string(),
                    Array::from_vec(vals.clone()).into_dyn(),
                )
            })
            .collect()
    }

    // Global norm of {3,0} ∪ {0,4} is 5; clipping to 2.5 halves everything.
    #[test]
    fn test_clip_grad_norm_map_clips_above_threshold() {
        let mut grads = make_grad_map(&[("w1", vec![3.0, 0.0]), ("w2", vec![0.0, 4.0])]);
        let orig = clip_grad_norm_map(&mut grads, 2.5).expect("failed to create orig");
        assert!((orig - 5.0).abs() < 1e-6);
        let clipped = grad_norm_map(&grads);
        // Loose tolerance: the 1e-6 epsilon makes the result slightly < 2.5.
        assert!((clipped - 2.5).abs() < 0.1);
    }

    // Norm below the threshold: gradients must be left untouched.
    #[test]
    fn test_clip_grad_norm_map_no_clip_below_threshold() {
        let mut grads = make_grad_map(&[("w", vec![1.0, 1.0])]);
        let orig = clip_grad_norm_map(&mut grads, 100.0).expect("failed to create orig");
        let expected = 2.0_f64.sqrt();
        assert!((orig - expected).abs() < 1e-10);
        let vals: Vec<f64> = grads["w"].iter().copied().collect();
        assert!((vals[0] - 1.0).abs() < 1e-10);
        assert!((vals[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_clip_grad_norm_map_negative_max_norm_errors() {
        let mut grads = make_grad_map(&[("w", vec![1.0])]);
        assert!(clip_grad_norm_map(&mut grads, -1.0).is_err());
    }

    // An empty map has norm 0 and clipping is a no-op.
    #[test]
    fn test_clip_grad_norm_map_empty_map() {
        let mut grads: HashMap<String, ArrayD<f64>> = HashMap::new();
        let orig = clip_grad_norm_map(&mut grads, 1.0).expect("failed to create orig");
        assert!((orig - 0.0).abs() < 1e-10);
    }

    // max_norm == 0 is documented as "clipping disabled": values unchanged.
    #[test]
    fn test_clip_grad_norm_map_zero_max_norm() {
        let mut grads = make_grad_map(&[("w", vec![3.0, 4.0])]);
        let orig = clip_grad_norm_map(&mut grads, 0.0).expect("failed to create orig");
        assert!((orig - 5.0).abs() < 1e-6);
        let vals: Vec<f64> = grads["w"].iter().copied().collect();
        assert!((vals[0] - 3.0).abs() < 1e-10);
    }

    // Values outside [-1, 1] are clamped; values inside are untouched.
    #[test]
    fn test_clip_grad_value_map_clips_both_directions() {
        let mut grads = make_grad_map(&[("w", vec![10.0, -10.0, 0.5, -0.5])]);
        clip_grad_value_map(&mut grads, 1.0).expect("unexpected None or Err");
        let vals: Vec<f64> = grads["w"].iter().copied().collect();
        assert!((vals[0] - 1.0).abs() < 1e-10);
        assert!((vals[1] - (-1.0)).abs() < 1e-10);
        assert!((vals[2] - 0.5).abs() < 1e-10);
        assert!((vals[3] - (-0.5)).abs() < 1e-10);
    }

    #[test]
    fn test_clip_grad_value_map_no_op_within_range() {
        let mut grads = make_grad_map(&[("w", vec![0.1, -0.2, 0.3])]);
        clip_grad_value_map(&mut grads, 1.0).expect("unexpected None or Err");
        let vals: Vec<f64> = grads["w"].iter().copied().collect();
        assert!((vals[0] - 0.1).abs() < 1e-10);
        assert!((vals[1] - (-0.2)).abs() < 1e-10);
        assert!((vals[2] - 0.3).abs() < 1e-10);
    }

    #[test]
    fn test_clip_grad_value_map_negative_clip_errors() {
        let mut grads = make_grad_map(&[("w", vec![1.0])]);
        assert!(clip_grad_value_map(&mut grads, -1.0).is_err());
    }

    // clip_value == 0 zeroes everything (unlike clip_grad_norm_map's no-op).
    #[test]
    fn test_clip_grad_value_map_zero_clip() {
        let mut grads = make_grad_map(&[("w", vec![5.0, -3.0, 0.0])]);
        clip_grad_value_map(&mut grads, 0.0).expect("unexpected None or Err");
        for &v in grads["w"].iter() {
            assert!((v - 0.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_grad_norm_map_single_tensor() {
        let grads = make_grad_map(&[("w", vec![3.0, 4.0])]);
        let norm = grad_norm_map(&grads);
        assert!((norm - 5.0).abs() < 1e-10);
    }

    // Norm is computed over all tensors jointly, not per-tensor.
    #[test]
    fn test_grad_norm_map_multiple_tensors() {
        let grads = make_grad_map(&[("w1", vec![3.0, 0.0]), ("w2", vec![0.0, 4.0])]);
        let norm = grad_norm_map(&grads);
        assert!((norm - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_grad_norm_map_empty() {
        let grads: HashMap<String, ArrayD<f64>> = HashMap::new();
        assert!((grad_norm_map(&grads) - 0.0).abs() < 1e-10);
    }

    // Identical gradients averaged over 4 steps come back unchanged.
    #[test]
    fn test_accumulator_map_basic_flow() {
        let mut acc = GradientAccumulatorMap::new(4);
        for _ in 0..4 {
            let grads = make_grad_map(&[("w", vec![1.0, 2.0])]);
            acc.accumulate(&grads).expect("unexpected None or Err");
        }
        assert!(acc.should_update());
        let avg = acc.get_averaged_grads().expect("failed to create avg");
        let vals: Vec<f64> = avg["w"].iter().copied().collect();
        assert!((vals[0] - 1.0).abs() < 1e-10);
        assert!((vals[1] - 2.0).abs() < 1e-10);
    }

    // Average of [2,4] and [4,8] is [3,6]; should_update flips at step 2.
    #[test]
    fn test_accumulator_map_averaging() {
        let mut acc = GradientAccumulatorMap::new(2);
        let g1 = make_grad_map(&[("w", vec![2.0, 4.0])]);
        acc.accumulate(&g1).expect("unexpected None or Err");
        assert!(!acc.should_update());
        let g2 = make_grad_map(&[("w", vec![4.0, 8.0])]);
        acc.accumulate(&g2).expect("unexpected None or Err");
        assert!(acc.should_update());
        let avg = acc.get_averaged_grads().expect("failed to create avg");
        let vals: Vec<f64> = avg["w"].iter().copied().collect();
        assert!((vals[0] - 3.0).abs() < 1e-10);
        assert!((vals[1] - 6.0).abs() < 1e-10);
    }

    #[test]
    fn test_accumulator_map_reset() {
        let mut acc = GradientAccumulatorMap::new(2);
        let grads = make_grad_map(&[("w", vec![1.0])]);
        acc.accumulate(&grads).expect("unexpected None or Err");
        acc.accumulate(&grads).expect("unexpected None or Err");
        assert!(acc.should_update());
        acc.reset();
        assert!(!acc.should_update());
        assert_eq!(acc.current_step(), 0);
    }

    // Averaging before any accumulation is an error, not a silent empty map.
    #[test]
    fn test_accumulator_map_get_averaged_grads_empty_errors() {
        let acc = GradientAccumulatorMap::new(4);
        assert!(acc.get_averaged_grads().is_err());
    }

    // A parameter re-submitted with a different shape must be rejected.
    #[test]
    fn test_accumulator_map_shape_mismatch_errors() {
        let mut acc = GradientAccumulatorMap::new(4);
        let g1 = make_grad_map(&[("w", vec![1.0, 2.0])]);
        acc.accumulate(&g1).expect("unexpected None or Err");
        let g2: HashMap<String, ArrayD<f64>> = {
            let mut m = HashMap::new();
            m.insert(
                "w".to_string(),
                Array::from_shape_vec(IxDyn(&[1, 3]), vec![1.0, 2.0, 3.0]).expect("unexpected None or Err"),
            );
            m
        };
        assert!(acc.accumulate(&g2).is_err());
    }

    // Each named parameter is accumulated and averaged independently.
    #[test]
    fn test_accumulator_map_multiple_params() {
        let mut acc = GradientAccumulatorMap::new(2);
        let g1 = make_grad_map(&[("weight", vec![1.0, 0.0]), ("bias", vec![0.5])]);
        let g2 = make_grad_map(&[("weight", vec![3.0, 2.0]), ("bias", vec![1.5])]);
        acc.accumulate(&g1).expect("unexpected None or Err");
        acc.accumulate(&g2).expect("unexpected None or Err");
        assert!(acc.should_update());
        let avg = acc.get_averaged_grads().expect("failed to create avg");
        let w_vals: Vec<f64> = avg["weight"].iter().copied().collect();
        let b_vals: Vec<f64> = avg["bias"].iter().copied().collect();
        assert!((w_vals[0] - 2.0).abs() < 1e-10);
        assert!((w_vals[1] - 1.0).abs() < 1e-10);
        assert!((b_vals[0] - 1.0).abs() < 1e-10);
    }

    // Step size 1 degenerates to pass-through averaging.
    #[test]
    fn test_accumulator_map_step_size_one() {
        let mut acc = GradientAccumulatorMap::new(1);
        let grads = make_grad_map(&[("w", vec![5.0, 10.0])]);
        acc.accumulate(&grads).expect("unexpected None or Err");
        assert!(acc.should_update());
        let avg = acc.get_averaged_grads().expect("failed to create avg");
        let vals: Vec<f64> = avg["w"].iter().copied().collect();
        assert!((vals[0] - 5.0).abs() < 1e-10);
        assert!((vals[1] - 10.0).abs() < 1e-10);
    }

    // new(0) must clamp to 1 rather than producing a never-updating accumulator.
    #[test]
    fn test_accumulator_map_zero_step_size_clamped_to_one() {
        let mut acc = GradientAccumulatorMap::new(0);
        assert_eq!(acc.accumulation_steps(), 1);
        let grads = make_grad_map(&[("w", vec![1.0])]);
        acc.accumulate(&grads).expect("unexpected None or Err");
        assert!(acc.should_update());
    }

    // reset() must fully clear state so a second accumulation cycle works.
    #[test]
    fn test_accumulator_map_reusable_after_reset() {
        let mut acc = GradientAccumulatorMap::new(2);
        let g1 = make_grad_map(&[("w", vec![4.0])]);
        acc.accumulate(&g1).expect("unexpected None or Err");
        acc.accumulate(&g1).expect("unexpected None or Err");
        let avg1 = acc.get_averaged_grads().expect("failed to create avg1");
        assert!((avg1["w"][[0]] - 4.0).abs() < 1e-10);
        acc.reset();
        let g2 = make_grad_map(&[("w", vec![8.0])]);
        acc.accumulate(&g2).expect("unexpected None or Err");
        acc.accumulate(&g2).expect("unexpected None or Err");
        let avg2 = acc.get_averaged_grads().expect("failed to create avg2");
        assert!((avg2["w"][[0]] - 8.0).abs() < 1e-10);
    }

    // param_names() yields keys in arbitrary order, hence the sort.
    #[test]
    fn test_accumulator_map_param_names() {
        let mut acc = GradientAccumulatorMap::new(4);
        let grads = make_grad_map(&[("layer1.weight", vec![1.0]), ("layer1.bias", vec![0.1])]);
        acc.accumulate(&grads).expect("unexpected None or Err");
        let mut names: Vec<&str> = acc.param_names().collect();
        names.sort();
        assert_eq!(names, &["layer1.bias", "layer1.weight"]);
    }

    // Clipping each batch to norm 1 before accumulating keeps the averaged
    // gradient's norm at or below 1 (average of unit-norm vectors).
    #[test]
    fn test_clip_and_accumulate_integration() {
        let mut acc = GradientAccumulatorMap::new(3);
        for step in 0..3 {
            let mut grads =
                make_grad_map(&[("w", vec![10.0 * (step as f64 + 1.0), -5.0])]);
            clip_grad_norm_map(&mut grads, 1.0).expect("unexpected None or Err");
            acc.accumulate(&grads).expect("unexpected None or Err");
        }
        assert!(acc.should_update());
        let avg = acc.get_averaged_grads().expect("failed to create avg");
        let norm = grad_norm_map(&avg);
        assert!(norm <= 1.0 + 1e-6);
    }
}