rnn/gradients/
/// Errors reported by the gradient utilities in this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GradientError {
    /// The gradient slice passed in contained no elements.
    Empty,
    /// The clipping threshold was NaN, infinite, zero, or negative.
    InvalidThreshold,
}

impl core::fmt::Display for GradientError {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let msg = match self {
            GradientError::Empty => "gradient slice is empty",
            GradientError::InvalidThreshold => "clipping threshold must be finite and positive",
        };
        f.write_str(msg)
    }
}
6
7pub fn l2_norm(values: &[f32]) -> Result<f32, GradientError> {
8 if values.is_empty() {
9 return Err(GradientError::Empty);
10 }
11 let mut sum = 0.0f32;
12 for &v in values {
13 sum += v * v;
14 }
15 Ok(crate::math::sqrtf(sum))
16}
17
18pub fn clip_by_global_norm(values: &mut [f32], max_norm: f32) -> Result<f32, GradientError> {
19 if values.is_empty() {
20 return Err(GradientError::Empty);
21 }
22 if !max_norm.is_finite() || max_norm <= 0.0 {
23 return Err(GradientError::InvalidThreshold);
24 }
25
26 let norm = l2_norm(values)?;
27 if norm > max_norm {
28 let scale = max_norm / norm;
29 for v in values.iter_mut() {
30 *v *= scale;
31 }
32 }
33 Ok(norm)
34}
35
/// Reports whether every element of `values` is a finite number.
///
/// Any NaN, `+inf`, or `-inf` makes the result `false`. An empty slice
/// contains no offending element and therefore yields `true`.
pub fn all_finite(values: &[f32]) -> bool {
    !values
        .iter()
        .copied()
        .any(|x| x.is_nan() || x.is_infinite())
}
38}