// concision_utils/utils/gradient.rs

/*
    Appellation: gradient <module>
    Contrib: @FL03
*/
5use ndarray::{Array, Dimension, ScalarOperand};
6use num_traits::Float;
7
8/// Clip the gradient to a maximum value.
9pub fn clip_gradient<A, D>(gradient: &mut Array<A, D>, threshold: A)
10where
11    A: Float + ScalarOperand,
12    D: Dimension,
13{
14    gradient.clamp(-threshold, threshold);
15}
16
17pub fn clip_inf_nan<A, D>(gradient: &mut Array<A, D>, threshold: A)
18where
19    A: Float + ScalarOperand,
20    D: Dimension,
21{
22    let norm = gradient.pow2().sum().sqrt();
23    gradient.mapv_inplace(|x| {
24        if x.is_nan() {
25            A::one() / norm
26        } else if x.is_infinite() {
27            threshold / norm
28        } else {
29            x
30        }
31    });
32}