concision_core/traits/
clip.rs

/*
    Appellation: clip <module>
    Contrib: @FL03
*/
5
/// Types whose contents can be _clipped_ into a closed interval.
///
/// Implementors receive a lower bound `min` and an upper bound `max` and produce a value of
/// type [`Clip::Output`] whose contents are intended to lie within `min..=max`.
pub trait Clip<T> {
    /// The value produced by [`Clip::clip`].
    type Output;

    /// Returns the result of limiting every value in `self` to the range between `min` and `max`.
    fn clip(&self, min: T, max: T) -> Self::Output;
}
13
/// In-place tensor clipping; implemented for `ndarray::ArrayBase` with float elements.
///
/// Every method mutates `self` and returns nothing. `T` is the element type used for bounds
/// and replacement values; it defaults to `f32`.
pub trait ClipMut<T = f32> {
    /// Clamps every element into the closed range `[min, max]`.
    fn clip_between(&mut self, min: T, max: T);
    /// Replaces NaN elements with `on_nan` and infinite elements (of either sign) with `on_inf`.
    fn clip_inf_nan(&mut self, on_inf: T, on_nan: T);
    /// Replaces NaN elements with `on_nan` and infinite elements with `on_inf`, then clamps
    /// the remaining finite elements into `[-boundary, boundary]` (`boundary` is expected to be
    /// non-negative).
    fn clip_inf_nan_between(&mut self, boundary: T, on_inf: T, on_nan: T);
    /// Replaces infinite elements (of either sign) with `threshold`.
    fn clip_inf(&mut self, threshold: T);
    /// Caps the tensor from above: elements greater than `threshold` become `threshold`.
    fn clip_max(&mut self, threshold: T);
    /// Floors the tensor from below: elements less than `threshold` become `threshold`.
    fn clip_min(&mut self, threshold: T);
    /// L1 norm clipping: if the tensor's L1 norm exceeds `threshold`, the whole tensor is
    /// rescaled so that its L1 norm equals `threshold`; otherwise it is left unchanged.
    fn clip_norm_l1(&mut self, threshold: T);
    /// L2 norm clipping: if the tensor's L2 norm exceeds `threshold`, the whole tensor is
    /// rescaled so that its L2 norm equals `threshold`; otherwise it is left unchanged.
    fn clip_norm_l2(&mut self, threshold: T);
    /// Replaces NaN elements with `threshold`.
    fn clip_nan(&mut self, threshold: T);
}
37
/*
 ************* Implementations *************
*/
41use super::{L1Norm, L2Norm};
42use ndarray::{ArrayBase, Dimension, ScalarOperand};
43use num_traits::Float;
44
45impl<A, S, D> Clip<A> for ArrayBase<S, D>
46where
47    A: 'static + Clone + PartialOrd,
48    S: ndarray::Data<Elem = A>,
49    D: Dimension,
50{
51    type Output = ndarray::Array<A, D>;
52
53    fn clip(&self, min: A, max: A) -> Self::Output {
54        self.clamp(min, max)
55    }
56}
57
58impl<A, S, D> ClipMut<A> for ArrayBase<S, D>
59where
60    A: Float + ScalarOperand,
61    S: ndarray::DataMut<Elem = A>,
62    D: Dimension,
63{
64    fn clip_between(&mut self, min: A, max: A) {
65        self.mapv_inplace(|x| {
66            if x < min {
67                min
68            } else if x > max {
69                max
70            } else {
71                x
72            }
73        });
74    }
75
76    fn clip_inf_nan(&mut self, on_inf: A, on_nan: A) {
77        self.mapv_inplace(|x| {
78            if x.is_nan() {
79                on_nan
80            } else if x.is_infinite() {
81                on_inf
82            } else {
83                x
84            }
85        });
86    }
87
88    fn clip_inf_nan_between(&mut self, boundary: A, on_inf: A, on_nan: A) {
89        self.mapv_inplace(|x| {
90            if x.is_nan() {
91                on_nan
92            } else if x.is_infinite() {
93                on_inf
94            } else if x < boundary.neg() {
95                boundary.neg()
96            } else if x > boundary {
97                boundary
98            } else {
99                x
100            }
101        });
102    }
103
104    fn clip_inf(&mut self, threshold: A) {
105        self.mapv_inplace(|x| if x.is_infinite() { threshold } else { x });
106    }
107
108    fn clip_max(&mut self, threshold: A) {
109        self.mapv_inplace(|x| if x > threshold { threshold } else { x });
110    }
111
112    fn clip_min(&mut self, threshold: A) {
113        self.mapv_inplace(|x| if x < threshold { threshold } else { x });
114    }
115
116    fn clip_nan(&mut self, threshold: A) {
117        self.mapv_inplace(|x| if x.is_nan() { threshold } else { x });
118    }
119
120    fn clip_norm_l1(&mut self, threshold: A) {
121        let norm = self.l1_norm();
122        if norm > threshold {
123            self.mapv_inplace(|x| x * threshold / norm);
124        }
125    }
126
127    fn clip_norm_l2(&mut self, threshold: A) {
128        let norm = self.l2_norm();
129        if norm > threshold {
130            self.mapv_inplace(|x| x * threshold / norm);
131        }
132    }
133}