concision_utils/utils/
gradient.rs1use ndarray::{Array, Dimension, ScalarOperand};
6use num_traits::Float;
7
8pub fn clip_gradient<A, D>(gradient: &mut Array<A, D>, threshold: A)
10where
11 A: Float + ScalarOperand,
12 D: Dimension,
13{
14 gradient.clamp(-threshold, threshold);
15}
16
17pub fn clip_inf_nan<A, D>(gradient: &mut Array<A, D>, threshold: A)
18where
19 A: Float + ScalarOperand,
20 D: Dimension,
21{
22 let norm = gradient.pow2().sum().sqrt();
23 gradient.mapv_inplace(|x| {
24 if x.is_nan() {
25 A::one() / norm
26 } else if x.is_infinite() {
27 threshold / norm
28 } else {
29 x
30 }
31 });
32}