concision_core/traits/
clip.rs

1/*
2    Appellation: clip <module>
3    Contrib: @FL03
4*/
5
/// Produce a clipped copy of a value.
///
/// Unlike the in-place clipping operations, `clip` borrows `self` immutably and returns a new
/// `Self` whose contents have been restricted using the `min` and `max` bounds.
pub trait Clip<A> {
    /// Return a copy of `self` clipped using the lower bound `min` and the upper bound `max`.
    fn clip(&self, min: A, max: A) -> Self;
}
9
/// In-place, element-wise clipping operations for tensors.
///
/// This trait is implemented for `ndarray::ArrayBase` with floating-point elements. Every
/// method mutates `self` and returns nothing.
pub trait ClipMut<T = f32> {
    /// Clamp every element into the closed interval `[min, max]`.
    ///
    /// Elements below `min` become `min` and elements above `max` become `max`; all other
    /// elements are left unchanged.
    fn clip_between(&mut self, min: T, max: T);
    /// Replace non-finite elements.
    ///
    /// NaN elements become `on_nan` and infinite elements (positive or negative) become
    /// `on_inf`. Finite elements are left unchanged.
    fn clip_inf_nan(&mut self, on_inf: T, on_nan: T);
    /// Replace non-finite elements, then clamp the rest into `[-boundary, boundary]`.
    ///
    /// NaN elements become `on_nan` and infinite elements (either sign) become `on_inf`.
    /// Finite elements are clamped into the symmetric interval `[-boundary, boundary]`.
    fn clip_inf_nan_between(&mut self, boundary: T, on_inf: T, on_nan: T);
    /// Replace every infinite element, positive or negative, with `threshold`.
    fn clip_inf(&mut self, threshold: T);
    /// Replace every element greater than `threshold` with `threshold`.
    fn clip_max(&mut self, threshold: T);
    /// Replace every element less than `threshold` with `threshold`.
    fn clip_min(&mut self, threshold: T);
    /// Rescale the tensor so that its L1 norm does not exceed `threshold`.
    ///
    /// If the L1 norm is greater than `threshold`, every element is multiplied by
    /// `threshold / norm`, so the tensor keeps its direction and its norm becomes `threshold`.
    /// Otherwise the tensor is left untouched. Individual elements are not clamped.
    fn clip_norm_l1(&mut self, threshold: T);
    /// Rescale the tensor so that its L2 norm does not exceed `threshold`.
    ///
    /// If the L2 norm is greater than `threshold`, every element is multiplied by
    /// `threshold / norm`, so the tensor keeps its direction and its norm becomes `threshold`.
    /// Otherwise the tensor is left untouched. Individual elements are not clamped.
    fn clip_norm_l2(&mut self, threshold: T);
    /// Replace every NaN element with `threshold`.
    fn clip_nan(&mut self, threshold: T);
}
33
34/*
35 ************* Implementations *************
36*/
37use super::{L1Norm, L2Norm};
38use ndarray::{ArrayBase, DataMut, Dimension, ScalarOperand};
39use num_traits::Float;
40
41impl<A, S, D> ClipMut<A> for ArrayBase<S, D>
42where
43    A: Float + ScalarOperand,
44    S: DataMut<Elem = A>,
45    D: Dimension,
46{
47    fn clip_between(&mut self, min: A, max: A) {
48        self.mapv_inplace(|x| {
49            if x < min {
50                min
51            } else if x > max {
52                max
53            } else {
54                x
55            }
56        });
57    }
58
59    fn clip_inf_nan(&mut self, on_inf: A, on_nan: A) {
60        self.mapv_inplace(|x| {
61            if x.is_nan() {
62                on_nan
63            } else if x.is_infinite() {
64                on_inf
65            } else {
66                x
67            }
68        });
69    }
70
71    fn clip_inf_nan_between(&mut self, boundary: A, on_inf: A, on_nan: A) {
72        self.mapv_inplace(|x| {
73            if x.is_nan() {
74                on_nan
75            } else if x.is_infinite() {
76                on_inf
77            } else if x < boundary.neg() {
78                boundary.neg()
79            } else if x > boundary {
80                boundary
81            } else {
82                x
83            }
84        });
85    }
86
87    fn clip_inf(&mut self, threshold: A) {
88        self.mapv_inplace(|x| if x.is_infinite() { threshold } else { x });
89    }
90
91    fn clip_max(&mut self, threshold: A) {
92        self.mapv_inplace(|x| if x > threshold { threshold } else { x });
93    }
94
95    fn clip_min(&mut self, threshold: A) {
96        self.mapv_inplace(|x| if x < threshold { threshold } else { x });
97    }
98
99    fn clip_nan(&mut self, threshold: A) {
100        self.mapv_inplace(|x| if x.is_nan() { threshold } else { x });
101    }
102
103    fn clip_norm_l1(&mut self, threshold: A) {
104        let norm = self.l1_norm();
105        if norm > threshold {
106            self.mapv_inplace(|x| x * threshold / norm);
107        }
108    }
109
110    fn clip_norm_l2(&mut self, threshold: A) {
111        let norm = self.l2_norm();
112        if norm > threshold {
113            self.mapv_inplace(|x| x * threshold / norm);
114        }
115    }
116}