concision_core/utils/
gradient.rs

1/*
2    Appellation: gradient <module>
3    Contrib: @FL03
4*/
5use crate::traits::L2Norm;
6use ndarray::{Array, Dimension, ScalarOperand};
7use num_traits::Float;
8
9/// Clip the gradient to a maximum value.
10pub fn clip_gradient<A, D>(gradient: &mut Array<A, D>, threshold: A)
11where
12    A: Float + ScalarOperand,
13    D: Dimension,
14{
15    let norm = gradient.l2_norm();
16    if norm > threshold {
17        let scale = threshold / norm;
18        gradient.mapv_inplace(|x| x * scale);
19    }
20}
21
22pub fn clip_inf_nan<A, D>(gradient: &mut Array<A, D>, threshold: A)
23where
24    A: Float + ScalarOperand,
25    D: Dimension,
26{
27    let norm = gradient.l2_norm();
28    gradient.mapv_inplace(|x| {
29        if x.is_nan() {
30            A::one() / norm
31        } else if x.is_infinite() {
32            threshold / norm
33        } else {
34            x
35        }
36    });
37}