/// Errors reported by the gradient helpers (`l2_norm`, `clip_by_global_norm`).
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum GradientError {
    /// The input slice contained no elements.
    Empty,
    /// The clipping threshold was NaN, infinite, or not strictly positive.
    InvalidThreshold,
}
/// Computes the Euclidean (L2) norm `sqrt(sum(v * v))` of `values`.
///
/// The common case squares and sums the elements directly. If that sum
/// overflows to infinity, or drops below the smallest normal `f32`, the true
/// norm may still be representable. In that case the norm is recomputed with
/// every element first divided by the largest magnitude, so large-but-finite
/// inputs do not collapse to `inf` and tiny-but-nonzero inputs do not collapse
/// to `0`.
///
/// An infinite element yields `inf`. A NaN element makes the sum NaN, and that
/// value is handed to `crate::math::sqrtf` unchanged.
///
/// # Errors
///
/// Returns [`GradientError::Empty`] if `values` is empty.
pub fn l2_norm(values: &[f32]) -> Result<f32, GradientError> {
    if values.is_empty() {
        return Err(GradientError::Empty);
    }
    // Accumulate left to right from 0.0 so well-scaled inputs get exactly the
    // plain sum-of-squares result.
    let sum_sq = values.iter().fold(0.0f32, |acc, &v| acc + v * v);
    // Squares are non-negative, so an infinite sum means overflow (or an
    // infinite element); a sum below MIN_POSITIVE means underflow (or all
    // zeros). NaN fails both tests and stays on the fast path.
    if sum_sq.is_infinite() || sum_sq < f32::MIN_POSITIVE {
        return Ok(scaled_l2_norm(values));
    }
    Ok(crate::math::sqrtf(sum_sq))
}

/// Overflow/underflow-safe L2 norm: `m * sqrt(sum((v / m)^2))` with `m = max |v|`.
///
/// Expects a non-empty slice containing no NaN (guaranteed by `l2_norm`'s
/// fast-path check). All-zero input returns `0.0`; any infinite element
/// returns `inf`.
fn scaled_l2_norm(values: &[f32]) -> f32 {
    // Clearing the sign bit is the exact IEEE-754 absolute value. It is used
    // instead of `f32::abs` in case this crate is built without `std`, which
    // seems likely since `sqrt` is routed through `crate::math`.
    let max_abs = values
        .iter()
        .map(|&v| f32::from_bits(v.to_bits() & 0x7fff_ffff))
        .fold(0.0f32, |m, a| if a > m { a } else { m });
    if max_abs == 0.0 || max_abs.is_infinite() {
        return max_abs;
    }
    // Every ratio is in [-1, 1], so the squared sum is bounded by the element
    // count and cannot overflow.
    let sum_sq = values.iter().fold(0.0f32, |acc, &v| {
        let r = v / max_abs;
        acc + r * r
    });
    max_abs * crate::math::sqrtf(sum_sq)
}
/// Rescales `values` in place so that their L2 norm does not exceed `max_norm`.
///
/// The norm is computed with [`l2_norm`]. When it is strictly greater than
/// `max_norm`, every element is multiplied by `max_norm / norm`. Otherwise the
/// slice is left untouched. Either way, the norm measured *before* any
/// rescaling is returned.
///
/// A NaN norm compares false against `max_norm`, so the slice is left as-is.
/// An infinite norm produces a scale of `0.0`.
///
/// # Errors
///
/// * [`GradientError::Empty`] if `values` is empty. This is checked before the
///   threshold, so it wins even when `max_norm` is also invalid.
/// * [`GradientError::InvalidThreshold`] if `max_norm` is NaN, infinite, or not
///   strictly positive.
/// * Any error propagated from [`l2_norm`].
pub fn clip_by_global_norm(values: &mut [f32], max_norm: f32) -> Result<f32, GradientError> {
    let threshold_ok = max_norm.is_finite() && max_norm > 0.0;
    match (values.is_empty(), threshold_ok) {
        (true, _) => Err(GradientError::Empty),
        (false, false) => Err(GradientError::InvalidThreshold),
        (false, true) => {
            let norm = l2_norm(values)?;
            if norm > max_norm {
                let factor = max_norm / norm;
                values.iter_mut().for_each(|g| *g *= factor);
            }
            Ok(norm)
        }
    }
}
/// Reports whether every element of `values` is finite (neither NaN nor ±infinity).
///
/// An empty slice has no non-finite element and therefore yields `true`.
pub fn all_finite(values: &[f32]) -> bool {
    for &x in values {
        if !x.is_finite() {
            return false;
        }
    }
    true
}