use crate::{
Array, Result,
lm::load::Weights,
ops::{arithmetic, reduction::sum},
};
/// Builds an `f32` array with an empty shape (a single scalar value) filled with `v`.
///
/// Used to express constants (`max_norm`, epsilon, `1.0`) as arrays so they can take part
/// in element-wise array arithmetic.
fn scalar(v: f32) -> Result<Array> {
    // An empty shape means zero dimensions, i.e. a scalar.
    const SCALAR_SHAPE: [i32; 0] = [];
    Array::full::<f32>(&SCALAR_SHAPE, v)
}
/// Rescales every gradient in `grads` in place so that their combined (global) L2 norm
/// does not exceed `max_norm`, and returns the global norm measured *before* clipping.
///
/// The global norm is `sqrt(sum over all entries of sum(g * g))`. Every gradient is then
/// multiplied by the same factor `min(max_norm / (norm + 1e-6), 1)`. Gradients whose
/// combined norm is already within `max_norm` are therefore only multiplied by `1`. The
/// `1e-6` term keeps the division finite when all gradients are zero. An empty map yields
/// a norm of `0` and is left unchanged.
///
/// # Errors
///
/// Propagates any error from the underlying array operations. If an error occurs during
/// rescaling, entries already processed hold clipped values and the rest keep their
/// original values. No entry is ever removed from the map.
pub fn clip_grad_norm(grads: &mut Weights, max_norm: f32) -> Result<Array> {
    // Accumulate per-tensor squared norms in sorted key order. `HashMap` iteration order
    // is randomized per process and float addition is not associative, so a fixed order
    // keeps the reported norm bit-for-bit reproducible across runs.
    let mut entries: Vec<_> = grads.iter().collect();
    entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
    let mut norm_squared: Option<Array> = None;
    for (_, grad) in entries {
        let s = sum(&arithmetic::square(grad)?, false)?;
        norm_squared = Some(match norm_squared.take() {
            Some(acc) => arithmetic::add(&acc, &s)?,
            None => s,
        });
    }
    let total_norm = match norm_squared {
        Some(ns) => arithmetic::sqrt(&ns)?,
        None => scalar(0.0)?,
    };
    let max_s = scalar(max_norm)?;
    let eps = scalar(1e-6)?;
    let denom = arithmetic::add(&total_norm, &eps)?;
    let ratio = arithmetic::divide(&max_s, &denom)?;
    // Never scale gradients *up*: cap the factor at 1.
    let normalizer = arithmetic::minimum(&ratio, &scalar(1.0)?)?;
    for grad in grads.values_mut() {
        // Compute the clipped value before overwriting. If `multiply` fails, this entry
        // keeps its original value instead of being dropped from the map.
        let clipped = arithmetic::multiply(&*grad, &normalizer)?;
        *grad = clipped;
    }
    Ok(total_norm)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    #[test]
    fn clip_grad_norm_no_clip_when_below_threshold() -> Result<()> {
        let mut grads: Weights = HashMap::new();
        grads.insert("w1".into(), Array::from_slice::<f32>(&[3.0, 4.0], &[2])?);
        let mut norm = clip_grad_norm(&mut grads, 10.0)?;
        assert!((norm.item::<f32>()? - 5.0).abs() < 1e-4);
        let mut got = grads["w1"].try_clone()?;
        let v: Vec<f32> = got.to_vec()?;
        assert!((v[0] - 3.0).abs() < 1e-5 && (v[1] - 4.0).abs() < 1e-5);
        Ok(())
    }
    #[test]
    fn clip_grad_norm_rescales_when_above_threshold() -> Result<()> {
        let mut grads: Weights = HashMap::new();
        grads.insert("w1".into(), Array::from_slice::<f32>(&[3.0, 4.0], &[2])?);
        let mut norm = clip_grad_norm(&mut grads, 2.5)?;
        assert!((norm.item::<f32>()? - 5.0).abs() < 1e-4);
        let mut got = grads["w1"].try_clone()?;
        let v: Vec<f32> = got.to_vec()?;
        assert!((v[0] - 1.5).abs() < 1e-3);
        assert!((v[1] - 2.0).abs() < 1e-3);
        Ok(())
    }
    #[test]
    fn clip_grad_norm_handles_multiple_entries() -> Result<()> {
        let mut grads: Weights = HashMap::new();
        grads.insert("w1".into(), Array::from_slice::<f32>(&[2.0, 3.0], &[2])?);
        grads.insert("w2".into(), Array::from_slice::<f32>(&[1.0], &[1])?);
        let mut norm = clip_grad_norm(&mut grads, 100.0)?;
        let got = norm.item::<f32>()?;
        assert!((got - 14.0_f32.sqrt()).abs() < 1e-4, "got {got}");
        Ok(())
    }
    /// The global norm spans all entries, so every entry must be scaled by the same
    /// factor, and no entry may go missing from the map.
    #[test]
    fn clip_grad_norm_rescales_all_entries_by_same_factor() -> Result<()> {
        let mut grads: Weights = HashMap::new();
        grads.insert("w1".into(), Array::from_slice::<f32>(&[3.0, 0.0], &[2])?);
        grads.insert("w2".into(), Array::from_slice::<f32>(&[4.0], &[1])?);
        let mut norm = clip_grad_norm(&mut grads, 1.0)?;
        assert!((norm.item::<f32>()? - 5.0).abs() < 1e-4);
        assert_eq!(grads.len(), 2);
        let mut w1 = grads["w1"].try_clone()?;
        let v1: Vec<f32> = w1.to_vec()?;
        let mut w2 = grads["w2"].try_clone()?;
        let v2: Vec<f32> = w2.to_vec()?;
        // Factor is 1 / 5 = 0.2 (up to the epsilon in the denominator).
        assert!((v1[0] - 0.6).abs() < 1e-3 && v1[1].abs() < 1e-6, "w1 = {v1:?}");
        assert!((v2[0] - 0.8).abs() < 1e-3, "w2 = {v2:?}");
        Ok(())
    }
    /// All-zero gradients must not produce NaN/inf. The epsilon keeps the ratio finite
    /// and the factor is capped at 1.
    #[test]
    fn clip_grad_norm_zero_gradients_stay_zero() -> Result<()> {
        let mut grads: Weights = HashMap::new();
        grads.insert("w1".into(), Array::from_slice::<f32>(&[0.0, 0.0], &[2])?);
        let mut norm = clip_grad_norm(&mut grads, 1.0)?;
        assert!((norm.item::<f32>()?).abs() < 1e-6);
        let mut got = grads["w1"].try_clone()?;
        let v: Vec<f32> = got.to_vec()?;
        assert!(v.iter().all(|x| x.is_finite() && x.abs() < 1e-6), "got {v:?}");
        Ok(())
    }
    #[test]
    fn clip_grad_norm_empty_map_returns_zero_norm() -> Result<()> {
        let mut grads: Weights = HashMap::new();
        let mut norm = clip_grad_norm(&mut grads, 1.0)?;
        assert!((norm.item::<f32>()?).abs() < 1e-6);
        Ok(())
    }
}