Skip to main content

rnn/gradients/
gradients.rs

/// Errors reported by the gradient utilities in this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GradientError {
    /// An input slice contained no elements.
    Empty,
    /// A threshold argument was rejected; `clip_by_global_norm` returns this
    /// when `max_norm` is not finite or not strictly positive.
    InvalidThreshold,
}

impl core::fmt::Display for GradientError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // `core::fmt` (re-exported by `std::fmt`) keeps this impl usable
        // whether or not the crate links `std`.
        let msg = match self {
            GradientError::Empty => "gradient slice is empty",
            GradientError::InvalidThreshold => "threshold must be finite and strictly positive",
        };
        f.write_str(msg)
    }
}
6
7pub fn l2_norm(values: &[f32]) -> Result<f32, GradientError> {
8    if values.is_empty() {
9        return Err(GradientError::Empty);
10    }
11    let mut sum = 0.0f32;
12    for &v in values {
13        sum += v * v;
14    }
15    Ok(crate::math::sqrtf(sum))
16}
17
18pub fn clip_by_global_norm(values: &mut [f32], max_norm: f32) -> Result<f32, GradientError> {
19    if values.is_empty() {
20        return Err(GradientError::Empty);
21    }
22    if !max_norm.is_finite() || max_norm <= 0.0 {
23        return Err(GradientError::InvalidThreshold);
24    }
25
26    let norm = l2_norm(values)?;
27    if norm > max_norm {
28        let scale = max_norm / norm;
29        for v in values.iter_mut() {
30            *v *= scale;
31        }
32    }
33    Ok(norm)
34}
35
/// Reports whether every element of `values` is finite (neither NaN nor
/// ±infinity), stopping at the first offending element.
///
/// An empty slice is vacuously finite and yields `true`.
pub fn all_finite(values: &[f32]) -> bool {
    for x in values {
        if !x.is_finite() {
            return false;
        }
    }
    true
}