use ferrotorch_core::{Device, FerrotorchResult, Float, Tensor, TensorStorage};
use crate::parameter::Parameter;
/// Rescales the gradients of `params` in place so that their combined norm does not
/// exceed `max_norm`.
///
/// The norm is taken over all gradients together, as if they were concatenated into a
/// single vector:
///
/// * `norm_type == f64::INFINITY` uses the largest absolute gradient element;
/// * any other `norm_type` computes `(Σ |g|^norm_type)^(1 / norm_type)`.
///
/// A NaN gradient element makes the returned norm NaN for every `norm_type`. Clipping
/// is then skipped, because `NaN > max_norm` is false. Parameters without a gradient
/// are skipped.
///
/// If the total norm is greater than `max_norm`, every gradient is multiplied by
/// `max_norm / total_norm`; otherwise the gradients are left untouched. Rescaled
/// gradients are rebuilt on the CPU and then moved to the device of the original
/// gradient.
///
/// # Returns
///
/// The total gradient norm measured *before* any clipping was applied.
///
/// # Errors
///
/// Propagates any error from reading a parameter's gradient, copying gradient data,
/// building the rescaled tensor, moving it to its device, or storing it back with
/// `set_grad`.
pub fn clip_grad_norm_<T: Float>(
    params: &[&Parameter<T>],
    max_norm: f64,
    norm_type: f64,
) -> FerrotorchResult<f64> {
    // Read every gradient exactly once and surface failures, instead of silently
    // leaving a gradient out of the norm and only erroring later if clipping happens.
    let mut with_grads: Vec<(&Parameter<T>, Tensor<T>)> = Vec::with_capacity(params.len());
    for &param in params {
        if let Some(g) = param.grad()? {
            with_grads.push((param, g));
        }
    }

    let total_norm: f64 = if norm_type == f64::INFINITY {
        let mut max_abs: f64 = 0.0;
        for (_, g) in &with_grads {
            let data = g.data_vec()?;
            for v in &data {
                let abs_v = v.to_f64().unwrap().abs();
                // `>` is false for NaN, so check it explicitly: a NaN gradient must
                // poison the norm (as it does for finite `norm_type`) rather than
                // being ignored. Once `max_abs` is NaN it stays NaN.
                if abs_v.is_nan() || abs_v > max_abs {
                    max_abs = abs_v;
                }
            }
        }
        max_abs
    } else {
        let mut accum: f64 = 0.0;
        for (_, g) in &with_grads {
            let data = g.data_vec()?;
            for v in &data {
                accum += v.to_f64().unwrap().abs().powf(norm_type);
            }
        }
        accum.powf(1.0 / norm_type)
    };

    if total_norm > max_norm {
        let clip_t = T::from(max_norm / total_norm).unwrap();
        for (param, g) in &with_grads {
            let data = g.data_vec()?;
            let scaled: Vec<T> = data.iter().map(|&v| v * clip_t).collect();
            let device = g.device();
            let new_grad =
                Tensor::from_storage(TensorStorage::cpu(scaled), g.shape().to_vec(), false)?;
            // The new tensor is built on the CPU; move it back if the gradient lived
            // elsewhere.
            let new_grad = if device != Device::Cpu {
                new_grad.to(device)?
            } else {
                new_grad
            };
            param.set_grad(Some(new_grad))?;
        }
    }
    Ok(total_norm)
}
/// Clamps every gradient element of `params`, in place, into the closed range
/// `[-clip_value, clip_value]`.
///
/// Parameters without a gradient are left alone. An element below the lower bound
/// becomes the lower bound, and an element above the upper bound becomes the upper
/// bound. Anything else is kept unchanged, including NaN, which compares false against
/// both bounds. Each clamped gradient is rebuilt on the CPU and then moved to the
/// device the original gradient lived on.
///
/// # Errors
///
/// Propagates failures from reading a gradient, copying its data, constructing the new
/// tensor, transferring it to its device, or storing it with `set_grad`.
pub fn clip_grad_value_<T: Float>(
    params: &[&Parameter<T>],
    clip_value: f64,
) -> FerrotorchResult<()> {
    let (lower, upper) = (
        T::from(-clip_value).unwrap(),
        T::from(clip_value).unwrap(),
    );

    // The lower bound takes priority when both comparisons could apply.
    let clamp = |x: T| -> T {
        match (x < lower, x > upper) {
            (true, _) => lower,
            (false, true) => upper,
            (false, false) => x,
        }
    };

    for param in params.iter() {
        let grad = match param.grad()? {
            Some(grad) => grad,
            None => continue,
        };

        let raw = grad.data_vec()?;
        let values: Vec<T> = raw.iter().map(|&x| clamp(x)).collect();

        let target = grad.device();
        let rebuilt =
            Tensor::from_storage(TensorStorage::cpu(values), grad.shape().to_vec(), false)?;
        let rebuilt = if target == Device::Cpu {
            rebuilt
        } else {
            rebuilt.to(target)?
        };

        param.set_grad(Some(rebuilt))?;
    }

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a zero-initialised parameter of `shape` whose gradient is set to
    /// `grad_data` (a CPU tensor of the same shape).
    fn param_with_grad(shape: &[usize], grad_data: &[f32]) -> Parameter<f32> {
        let p = Parameter::<f32>::zeros(shape).unwrap();
        let grad = Tensor::from_storage(
            TensorStorage::cpu(grad_data.to_vec()),
            shape.to_vec(),
            false,
        )
        .unwrap();
        p.set_grad(Some(grad)).unwrap();
        p
    }

    #[test]
    fn test_clip_grad_norm_reduces_norm() {
        let p = param_with_grad(&[2], &[3.0, 4.0]);
        let total = clip_grad_norm_(&[&p], 2.5, 2.0).unwrap();
        assert!((total - 5.0).abs() < 1e-6);
        let g = p.grad().unwrap().unwrap();
        let d = g.data().unwrap();
        let new_norm = (d[0] as f64 * d[0] as f64 + d[1] as f64 * d[1] as f64).sqrt();
        assert!((new_norm - 2.5).abs() < 1e-4);
    }

    #[test]
    fn test_clip_grad_norm_no_clip_when_below() {
        let p = param_with_grad(&[2], &[1.0, 0.0]);
        let total = clip_grad_norm_(&[&p], 10.0, 2.0).unwrap();
        assert!((total - 1.0).abs() < 1e-6);
        let g = p.grad().unwrap().unwrap();
        let d = g.data().unwrap();
        assert!((d[0] - 1.0).abs() < 1e-6);
        assert!((d[1] - 0.0).abs() < 1e-6);
    }

    #[test]
    fn test_clip_grad_norm_multiple_params() {
        let p1 = param_with_grad(&[1], &[3.0]);
        let p2 = param_with_grad(&[1], &[4.0]);
        let total = clip_grad_norm_(&[&p1, &p2], 2.5, 2.0).unwrap();
        assert!((total - 5.0).abs() < 1e-6);
        let g1 = p1.grad().unwrap().unwrap().data().unwrap()[0] as f64;
        let g2 = p2.grad().unwrap().unwrap().data().unwrap()[0] as f64;
        let new_norm = (g1 * g1 + g2 * g2).sqrt();
        assert!((new_norm - 2.5).abs() < 1e-4);
    }

    #[test]
    fn test_clip_grad_norm_returns_total_norm() {
        let p = param_with_grad(&[3], &[1.0, 2.0, 2.0]);
        let total = clip_grad_norm_(&[&p], 100.0, 2.0).unwrap();
        assert!((total - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_clip_grad_norm_skips_none_grads() {
        let p_with = param_with_grad(&[2], &[3.0, 4.0]);
        let p_without = Parameter::<f32>::zeros(&[2]).unwrap();
        let total = clip_grad_norm_(&[&p_with, &p_without], 2.5, 2.0).unwrap();
        assert!((total - 5.0).abs() < 1e-6);
    }

    // The infinity norm is the largest absolute element; here 4.0, so the clip
    // coefficient is 2.0 / 4.0 = 0.5 (exact in f32).
    #[test]
    fn test_clip_grad_norm_inf_norm_clips_by_max_abs() {
        let p = param_with_grad(&[3], &[3.0, -4.0, 1.0]);
        let total = clip_grad_norm_(&[&p], 2.0, f64::INFINITY).unwrap();
        assert!((total - 4.0).abs() < 1e-6);
        let g = p.grad().unwrap().unwrap();
        let d = g.data().unwrap();
        assert!((d[0] - 1.5).abs() < 1e-6);
        assert!((d[1] - (-2.0)).abs() < 1e-6);
        assert!((d[2] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_clip_grad_norm_l1_norm() {
        let p = param_with_grad(&[2], &[3.0, -4.0]);
        let total = clip_grad_norm_(&[&p], 100.0, 1.0).unwrap();
        assert!((total - 7.0).abs() < 1e-6);
        let g = p.grad().unwrap().unwrap();
        let d = g.data().unwrap();
        assert!((d[0] - 3.0).abs() < 1e-6);
        assert!((d[1] - (-4.0)).abs() < 1e-6);
    }

    #[test]
    fn test_clip_grad_norm_empty_params() {
        let params: [&Parameter<f32>; 0] = [];
        assert_eq!(clip_grad_norm_(&params, 1.0, 2.0).unwrap(), 0.0);
        assert_eq!(clip_grad_norm_(&params, 1.0, f64::INFINITY).unwrap(), 0.0);
    }

    #[test]
    fn test_clip_grad_value_clamps_elements() {
        let p = param_with_grad(&[4], &[-5.0, 0.5, 3.0, -0.1]);
        clip_grad_value_(&[&p], 1.0).unwrap();
        let g = p.grad().unwrap().unwrap();
        let d = g.data().unwrap();
        assert!((d[0] - (-1.0)).abs() < 1e-6);
        assert!((d[1] - 0.5).abs() < 1e-6);
        assert!((d[2] - 1.0).abs() < 1e-6);
        assert!((d[3] - (-0.1)).abs() < 1e-6);
    }

    #[test]
    fn test_clip_grad_value_skips_none_grads() {
        let p = Parameter::<f32>::zeros(&[3]).unwrap();
        clip_grad_value_(&[&p], 1.0).unwrap();
        assert!(p.grad().unwrap().is_none());
    }

    #[test]
    fn test_clip_grad_value_preserves_within_range() {
        let p = param_with_grad(&[3], &[0.1, -0.2, 0.3]);
        clip_grad_value_(&[&p], 1.0).unwrap();
        let g = p.grad().unwrap().unwrap();
        let d = g.data().unwrap();
        assert!((d[0] - 0.1).abs() < 1e-6);
        assert!((d[1] - (-0.2)).abs() < 1e-6);
        assert!((d[2] - 0.3).abs() < 1e-6);
    }

    #[test]
    fn test_clip_grad_value_multiple_params() {
        let p1 = param_with_grad(&[2], &[2.0, -0.5]);
        let p2 = param_with_grad(&[1], &[-3.0]);
        clip_grad_value_(&[&p1, &p2], 1.0).unwrap();
        let d1 = p1.grad().unwrap().unwrap().data().unwrap().to_vec();
        let d2 = p2.grad().unwrap().unwrap().data().unwrap().to_vec();
        assert!((d1[0] - 1.0).abs() < 1e-6);
        assert!((d1[1] - (-0.5)).abs() < 1e-6);
        assert!((d2[0] - (-1.0)).abs() < 1e-6);
    }
}