use crate::Tensor;
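
/// Clips the gradients of `params` in place so that their global L2 norm does
/// not exceed `max_norm`, and returns the global norm measured before clipping.
///
/// The global norm is the square root of the sum of squared gradient elements
/// across every parameter that currently holds a gradient; parameters without
/// a gradient are skipped. When the norm exceeds `max_norm`, each gradient is
/// scaled by `max_norm / global_norm`, so the gradients' relative magnitudes
/// are preserved. A `max_norm` of `0.0` zeroes every gradient.
///
/// # Examples
///
/// A minimal usage sketch (marked `ignore` because it assumes this crate's
/// `Tensor` API as exercised in the tests below):
///
/// ```ignore
/// let mut params = vec![Tensor::from_vec(vec![1.0, 2.0], true)];
/// params[0].set_grad(ndarray::arr1(&[3.0, 4.0]));
/// // The global norm is 5.0 and exceeds max_norm = 1.0, so the gradient is
/// // rescaled by 1.0 / 5.0 to [0.6, 0.8]; the pre-clipping norm is returned.
/// let norm = clip_grad_norm(&mut params, 1.0);
/// ```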
pub fn clip_grad_norm(params: &mut [Tensor], max_norm: f32) -> f32 {
    // Accumulate the squared L2 norm over every parameter that has a gradient.
    let mut total_norm_sq = 0.0;
    for param in params.iter() {
        if let Some(grad) = param.grad() {
            let grad_norm_sq: f32 = grad.iter().map(|&g| g * g).sum();
            total_norm_sq += grad_norm_sq;
        }
    }
    let global_norm = total_norm_sq.sqrt();
    // Rescale only when the norm exceeds the threshold; a single shared
    // coefficient keeps the gradients' relative magnitudes intact.
    if global_norm > max_norm {
        let clip_coef = max_norm / global_norm;
        for param in params.iter_mut() {
            if let Some(grad) = param.grad() {
                let clipped_grad = grad * clip_coef;
                param.set_grad(clipped_grad);
            }
        }
    }
    global_norm
}
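
/// Variant of [`clip_grad_norm`] that takes a slice of mutable references,
/// for callers whose parameters are not stored in one contiguous
/// `Vec<Tensor>` (for example, fields spread across a model struct).
/// Clipping behavior and return value are identical.
///
/// # Examples
///
/// A minimal sketch under the same `Tensor` API assumptions as above:
///
/// ```ignore
/// let mut weight = Tensor::from_vec(vec![1.0], true);
/// let mut bias = Tensor::from_vec(vec![1.0], true);
/// weight.set_grad(ndarray::arr1(&[10.0]));
/// bias.set_grad(ndarray::arr1(&[5.0]));
/// // The norm is sqrt(125) ≈ 11.18; both gradients are scaled by the same
/// // coefficient, so their 2:1 ratio is preserved after clipping.
/// let norm = clip_grad_norm_refs(&mut [&mut weight, &mut bias], 1.0);
/// ```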
pub fn clip_grad_norm_refs(params: &mut [&mut Tensor], max_norm: f32) -> f32 {
    // Accumulate the squared L2 norm over every parameter that has a gradient.
    let mut total_norm_sq = 0.0;
    for param in params.iter() {
        if let Some(grad) = param.grad() {
            let grad_norm_sq: f32 = grad.iter().map(|&g| g * g).sum();
            total_norm_sq += grad_norm_sq;
        }
    }
    let global_norm = total_norm_sq.sqrt();
    // Rescale only when the norm exceeds the threshold, mirroring
    // `clip_grad_norm`.
    if global_norm > max_norm {
        let clip_coef = max_norm / global_norm;
        for param in params.iter_mut() {
            if let Some(grad) = param.grad() {
                let clipped_grad = grad * clip_coef;
                param.set_grad(clipped_grad);
            }
        }
    }
    global_norm
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_clip_grad_norm_no_clipping() {
        let mut params =
            vec![Tensor::from_vec(vec![1.0, 2.0], true), Tensor::from_vec(vec![3.0], true)];
        params[0].set_grad(ndarray::arr1(&[0.1, 0.2]));
        params[1].set_grad(ndarray::arr1(&[0.1]));
        let global_norm = clip_grad_norm(&mut params, 1.0);
        assert_abs_diff_eq!(global_norm, 0.245, epsilon = 1e-3);
        assert_abs_diff_eq!(
            params[0].grad().expect("gradient should be available")[0],
            0.1,
            epsilon = 1e-6
        );
        assert_abs_diff_eq!(
            params[0].grad().expect("gradient should be available")[1],
            0.2,
            epsilon = 1e-6
        );
        assert_abs_diff_eq!(
            params[1].grad().expect("gradient should be available")[0],
            0.1,
            epsilon = 1e-6
        );
    }

    #[test]
    fn test_clip_grad_norm_with_clipping() {
        let mut params =
            vec![Tensor::from_vec(vec![1.0, 2.0], true), Tensor::from_vec(vec![3.0], true)];
        params[0].set_grad(ndarray::arr1(&[3.0, 4.0]));
        params[1].set_grad(ndarray::arr1(&[0.0]));
        let global_norm = clip_grad_norm(&mut params, 1.0);
        assert_abs_diff_eq!(global_norm, 5.0, epsilon = 1e-6);
        assert_abs_diff_eq!(
            params[0].grad().expect("gradient should be available")[0],
            0.6,
            epsilon = 1e-6
        );
        assert_abs_diff_eq!(
            params[0].grad().expect("gradient should be available")[1],
            0.8,
            epsilon = 1e-6
        );
        assert_abs_diff_eq!(
            params[1].grad().expect("gradient should be available")[0],
            0.0,
            epsilon = 1e-6
        );
    }

    #[test]
    fn test_clip_grad_norm_exactly_at_threshold() {
        let mut params = vec![Tensor::from_vec(vec![3.0, 4.0], true)];
        params[0].set_grad(ndarray::arr1(&[3.0, 4.0]));
        let global_norm = clip_grad_norm(&mut params, 5.0);
        assert_abs_diff_eq!(global_norm, 5.0, epsilon = 1e-6);
        assert_abs_diff_eq!(
            params[0].grad().expect("gradient should be available")[0],
            3.0,
            epsilon = 1e-6
        );
        assert_abs_diff_eq!(
            params[0].grad().expect("gradient should be available")[1],
            4.0,
            epsilon = 1e-6
        );
    }

    #[test]
    fn test_clip_grad_norm_preserves_relative_magnitudes() {
        let mut params =
            vec![Tensor::from_vec(vec![1.0], true), Tensor::from_vec(vec![1.0], true)];
        params[0].set_grad(ndarray::arr1(&[10.0]));
        params[1].set_grad(ndarray::arr1(&[5.0]));
        let _global_norm = clip_grad_norm(&mut params, 1.0);
        let grad0 = params[0].grad().expect("gradient should be available")[0];
        let grad1 = params[1].grad().expect("gradient should be available")[0];
        assert_abs_diff_eq!(grad0 / grad1, 2.0, epsilon = 1e-4);
    }

    #[test]
    fn test_clip_grad_norm_no_gradients() {
        let mut params = vec![
            Tensor::from_vec(vec![1.0, 2.0], false),
            Tensor::from_vec(vec![3.0], false),
        ];
        let global_norm = clip_grad_norm(&mut params, 1.0);
        assert_abs_diff_eq!(global_norm, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_clip_grad_norm_mixed_gradients() {
        let mut params =
            vec![Tensor::from_vec(vec![1.0], true), Tensor::from_vec(vec![1.0], true)];
        params[0].set_grad(ndarray::arr1(&[3.0]));
        let global_norm = clip_grad_norm(&mut params, 1.0);
        assert_abs_diff_eq!(global_norm, 3.0, epsilon = 1e-6);
        assert_abs_diff_eq!(
            params[0].grad().expect("gradient should be available")[0],
            1.0,
            epsilon = 1e-6
        );
        assert!(params[1].grad().is_none());
    }

    #[test]
    fn test_clip_grad_norm_zero_max_norm() {
        let mut params = vec![Tensor::from_vec(vec![1.0], true)];
        params[0].set_grad(ndarray::arr1(&[5.0]));
        let global_norm = clip_grad_norm(&mut params, 0.0);
        assert_abs_diff_eq!(global_norm, 5.0, epsilon = 1e-6);
        assert_abs_diff_eq!(
            params[0].grad().expect("gradient should be available")[0],
            0.0,
            epsilon = 1e-6
        );
    }

    #[test]
    fn test_clip_grad_norm_refs_no_clipping() {
        let mut p0 = Tensor::from_vec(vec![1.0, 2.0], true);
        let mut p1 = Tensor::from_vec(vec![3.0], true);
        p0.set_grad(ndarray::arr1(&[0.1, 0.2]));
        p1.set_grad(ndarray::arr1(&[0.1]));
        let global_norm = clip_grad_norm_refs(&mut [&mut p0, &mut p1], 1.0);
        assert_abs_diff_eq!(global_norm, 0.245, epsilon = 1e-3);
        assert_abs_diff_eq!(
            p0.grad().expect("gradient should be available")[0],
            0.1,
            epsilon = 1e-6
        );
        assert_abs_diff_eq!(
            p0.grad().expect("gradient should be available")[1],
            0.2,
            epsilon = 1e-6
        );
        assert_abs_diff_eq!(
            p1.grad().expect("gradient should be available")[0],
            0.1,
            epsilon = 1e-6
        );
    }

    #[test]
    fn test_clip_grad_norm_refs_with_clipping() {
        let mut p0 = Tensor::from_vec(vec![1.0, 2.0], true);
        let mut p1 = Tensor::from_vec(vec![3.0], true);
        p0.set_grad(ndarray::arr1(&[3.0, 4.0]));
        p1.set_grad(ndarray::arr1(&[0.0]));
        let global_norm = clip_grad_norm_refs(&mut [&mut p0, &mut p1], 1.0);
        assert_abs_diff_eq!(global_norm, 5.0, epsilon = 1e-6);
        assert_abs_diff_eq!(
            p0.grad().expect("gradient should be available")[0],
            0.6,
            epsilon = 1e-6
        );
        assert_abs_diff_eq!(
            p0.grad().expect("gradient should be available")[1],
            0.8,
            epsilon = 1e-6
        );
        assert_abs_diff_eq!(
            p1.grad().expect("gradient should be available")[0],
            0.0,
            epsilon = 1e-6
        );
    }

    #[test]
    fn test_clip_grad_norm_refs_preserves_relative_magnitudes() {
        let mut p0 = Tensor::from_vec(vec![1.0], true);
        let mut p1 = Tensor::from_vec(vec![1.0], true);
        p0.set_grad(ndarray::arr1(&[10.0]));
        p1.set_grad(ndarray::arr1(&[5.0]));
        let _global_norm = clip_grad_norm_refs(&mut [&mut p0, &mut p1], 1.0);
        let grad0 = p0.grad().expect("gradient should be available")[0];
        let grad1 = p1.grad().expect("gradient should be available")[0];
        assert_abs_diff_eq!(grad0 / grad1, 2.0, epsilon = 1e-4);
    }

    #[test]
    fn test_clip_grad_norm_refs_no_gradients() {
        let mut p0 = Tensor::from_vec(vec![1.0, 2.0], false);
        let mut p1 = Tensor::from_vec(vec![3.0], false);
        let global_norm = clip_grad_norm_refs(&mut [&mut p0, &mut p1], 1.0);
        assert_abs_diff_eq!(global_norm, 0.0, epsilon = 1e-6);
    }
}