// concision_core/traits/clip.rs
//
// Clipping traits: `Clip` produces a clipped result from a borrowed value,
// `ClipMut` clips the receiver in place.

/// Limit the values held by `self` to a closed range and return the result.
///
/// `self` is only borrowed; the clipped values are delivered through
/// [`Clip::Output`] (for the ndarray implementation in this module that is an
/// owned `Array`).
pub trait Clip<T> {
    /// The type produced by [`Clip::clip`].
    type Output;

    /// Return a version of `self` whose values are limited to `[min, max]`.
    fn clip(&self, min: T, max: T) -> Self::Output;
}
13
/// In-place clipping operations over the elements of `self`.
///
/// The element type parameter `T` defaults to `f32`, so `impl ClipMut for X`
/// means `impl ClipMut<f32> for X`.
pub trait ClipMut<T = f32> {
    /// Clamp every element into `[min, max]`.
    fn clip_between(&mut self, min: T, max: T);

    /// Replace NaN elements with `on_nan` and infinite elements (either sign)
    /// with `on_inf`; finite elements are left unchanged.
    fn clip_inf_nan(&mut self, on_inf: T, on_nan: T);
    /// Replace NaN elements with `on_nan` and infinite elements with `on_inf`,
    /// and clamp the remaining finite elements into `[-boundary, boundary]`.
    fn clip_inf_nan_between(&mut self, boundary: T, on_inf: T, on_nan: T);
    /// Replace infinite elements (either sign) with `threshold`.
    fn clip_inf(&mut self, threshold: T);
    /// Lower every element greater than `threshold` to `threshold`.
    fn clip_max(&mut self, threshold: T);
    /// Raise every element less than `threshold` to `threshold`.
    fn clip_min(&mut self, threshold: T);
    /// If the L1 norm of `self` exceeds `threshold`, scale every element by
    /// `threshold / norm`.
    fn clip_norm_l1(&mut self, threshold: T);
    /// If the L2 norm of `self` exceeds `threshold`, scale every element by
    /// `threshold / norm`.
    fn clip_norm_l2(&mut self, threshold: T);
    /// Replace NaN elements with `threshold`.
    fn clip_nan(&mut self, threshold: T);
}
37
use ndarray::{ArrayBase, Dimension, ScalarOperand};
use num_traits::Float;

use super::{L1Norm, L2Norm};
44
45impl<A, S, D> Clip<A> for ArrayBase<S, D>
46where
47 A: 'static + Clone + PartialOrd,
48 S: ndarray::Data<Elem = A>,
49 D: Dimension,
50{
51 type Output = ndarray::Array<A, D>;
52
53 fn clip(&self, min: A, max: A) -> Self::Output {
54 self.clamp(min, max)
55 }
56}
57
58impl<A, S, D> ClipMut<A> for ArrayBase<S, D>
59where
60 A: Float + ScalarOperand,
61 S: ndarray::DataMut<Elem = A>,
62 D: Dimension,
63{
64 fn clip_between(&mut self, min: A, max: A) {
65 self.mapv_inplace(|x| {
66 if x < min {
67 min
68 } else if x > max {
69 max
70 } else {
71 x
72 }
73 });
74 }
75
76 fn clip_inf_nan(&mut self, on_inf: A, on_nan: A) {
77 self.mapv_inplace(|x| {
78 if x.is_nan() {
79 on_nan
80 } else if x.is_infinite() {
81 on_inf
82 } else {
83 x
84 }
85 });
86 }
87
88 fn clip_inf_nan_between(&mut self, boundary: A, on_inf: A, on_nan: A) {
89 self.mapv_inplace(|x| {
90 if x.is_nan() {
91 on_nan
92 } else if x.is_infinite() {
93 on_inf
94 } else if x < boundary.neg() {
95 boundary.neg()
96 } else if x > boundary {
97 boundary
98 } else {
99 x
100 }
101 });
102 }
103
104 fn clip_inf(&mut self, threshold: A) {
105 self.mapv_inplace(|x| if x.is_infinite() { threshold } else { x });
106 }
107
108 fn clip_max(&mut self, threshold: A) {
109 self.mapv_inplace(|x| if x > threshold { threshold } else { x });
110 }
111
112 fn clip_min(&mut self, threshold: A) {
113 self.mapv_inplace(|x| if x < threshold { threshold } else { x });
114 }
115
116 fn clip_nan(&mut self, threshold: A) {
117 self.mapv_inplace(|x| if x.is_nan() { threshold } else { x });
118 }
119
120 fn clip_norm_l1(&mut self, threshold: A) {
121 let norm = self.l1_norm();
122 if norm > threshold {
123 self.mapv_inplace(|x| x * threshold / norm);
124 }
125 }
126
127 fn clip_norm_l2(&mut self, threshold: A) {
128 let norm = self.l2_norm();
129 if norm > threshold {
130 self.mapv_inplace(|x| x * threshold / norm);
131 }
132 }
133}