use scivex_core::{Float, Tensor};
use crate::variable::Variable;
/// Rescales the gradients of `parameters` so that their combined L2 norm is at
/// most `max_norm`.
///
/// All gradient entries of all parameters are treated as one flattened vector.
/// The function computes the Euclidean norm of that vector. Parameters whose
/// `grad()` is `None` contribute nothing and are left alone.
///
/// If the norm is strictly greater than `max_norm`, every present gradient is
/// replaced by itself multiplied by `max_norm / norm`. Otherwise no gradient
/// is modified. A NaN norm fails the comparison, so nothing is rescaled.
///
/// Returns the norm measured *before* any rescaling.
pub fn clip_grad_norm<T: Float>(parameters: &[Variable<T>], max_norm: T) -> T {
    // Global L2 norm: square root of the sum of squares of every entry, visited
    // parameter by parameter and element by element.
    let norm = {
        let mut sum_sq = T::zero();
        for grad in parameters.iter().filter_map(|param| param.grad()) {
            for &x in grad.as_slice() {
                sum_sq += x * x;
            }
        }
        sum_sq.sqrt()
    };

    if norm > max_norm {
        let factor = max_norm / norm;
        for (param, grad) in parameters
            .iter()
            .filter_map(|param| param.grad().map(|grad| (param, grad)))
        {
            let rescaled = &grad * factor;
            // Overwrite the stored gradient: clear it, then accumulate the
            // rescaled tensor into the cleared slot.
            param.zero_grad();
            param.acc_grad(&rescaled);
        }
    }

    norm
}
/// Clamps every gradient entry of `parameters` into `[-clip_value, clip_value]`,
/// element-wise.
///
/// Parameters whose `grad()` is `None` are skipped. For each present gradient:
/// - entries greater than `clip_value` become `clip_value`;
/// - entries less than `-clip_value` become `-clip_value`;
/// - all other entries are kept. This includes NaN, which fails both comparisons.
///
/// The upper bound is checked first.
///
/// The clamped values are rebuilt into a tensor of the original shape, and
/// that tensor replaces the stored gradient.
pub fn clip_grad_value<T: Float>(parameters: &[Variable<T>], clip_value: T) {
    let upper = clip_value;
    let lower = T::zero() - clip_value;

    // Upper bound first, then lower bound, then pass-through.
    let clamp = |x: T| match x {
        v if v > upper => upper,
        v if v < lower => lower,
        v => v,
    };

    for (param, grad) in parameters
        .iter()
        .filter_map(|param| param.grad().map(|grad| (param, grad)))
    {
        let clamped_values: Vec<T> = grad.as_slice().iter().copied().map(&clamp).collect();
        // The element count is unchanged, so rebuilding with the same shape is infallible.
        let clamped = Tensor::from_vec(clamped_values, grad.shape().to_vec())
            .expect("shape unchanged; cannot fail");
        param.zero_grad();
        param.acc_grad(&clamped);
    }
}