concision_core/utils/
gradient.rs1use crate::traits::L2Norm;
6use ndarray::{Array, Dimension, ScalarOperand};
7use num_traits::Float;
8
9pub fn clip_gradient<A, D>(gradient: &mut Array<A, D>, threshold: A)
11where
12 A: Float + ScalarOperand,
13 D: Dimension,
14{
15 let norm = gradient.l2_norm();
16 if norm > threshold {
17 let scale = threshold / norm;
18 gradient.mapv_inplace(|x| x * scale);
19 }
20}
21
22pub fn clip_inf_nan<A, D>(gradient: &mut Array<A, D>, threshold: A)
23where
24 A: Float + ScalarOperand,
25 D: Dimension,
26{
27 let norm = gradient.l2_norm();
28 gradient.mapv_inplace(|x| {
29 if x.is_nan() {
30 A::one() / norm
31 } else if x.is_infinite() {
32 threshold / norm
33 } else {
34 x
35 }
36 });
37}