// concision_utils/utils/gradient.rs
use ndarray::{Array, Dimension, ScalarOperand};
use num_traits::Float;

8pub fn clip_gradient<A, D>(gradient: &mut Array<A, D>, threshold: A)
10where
11    A: Float + ScalarOperand,
12    D: Dimension,
13{
14    gradient.clamp(-threshold, threshold);
15}
16
17pub fn clip_inf_nan<A, D>(gradient: &mut Array<A, D>, threshold: A)
18where
19    A: Float + ScalarOperand,
20    D: Dimension,
21{
22    let norm = gradient.pow2().sum().sqrt();
23    gradient.mapv_inplace(|x| {
24        if x.is_nan() {
25            A::one() / norm
26        } else if x.is_infinite() {
27            threshold / norm
28        } else {
29            x
30        }
31    });
32}