ndarray_stats/deviation.rs
1use ndarray::{ArrayRef, Dimension, Zip};
2use num_traits::{Signed, ToPrimitive};
3use std::convert::Into;
4use std::ops::AddAssign;
5
6use crate::errors::MultiInputError;
7
8/// An extension trait for `ndarray` providing functions
9/// to compute different deviation measures.
10pub trait DeviationExt<A, D>
11where
12 D: Dimension,
13{
14 /// Counts the number of indices at which the elements of the arrays `self`
15 /// and `other` are equal.
16 ///
17 /// The following **errors** may be returned:
18 ///
19 /// * `MultiInputError::EmptyInput` if `self` is empty
20 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
21 fn count_eq(&self, other: &ArrayRef<A, D>) -> Result<usize, MultiInputError>
22 where
23 A: PartialEq;
24
25 /// Counts the number of indices at which the elements of the arrays `self`
26 /// and `other` are not equal.
27 ///
28 /// The following **errors** may be returned:
29 ///
30 /// * `MultiInputError::EmptyInput` if `self` is empty
31 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
32 fn count_neq(&self, other: &ArrayRef<A, D>) -> Result<usize, MultiInputError>
33 where
34 A: PartialEq;
35
36 /// Computes the [squared L2 distance] between `self` and `other`.
37 ///
38 /// ```text
39 /// n
40 /// ∑ |aᵢ - bᵢ|²
41 /// i=1
42 /// ```
43 ///
44 /// where `self` is `a` and `other` is `b`.
45 ///
46 /// The following **errors** may be returned:
47 ///
48 /// * `MultiInputError::EmptyInput` if `self` is empty
49 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
50 ///
51 /// [squared L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance
52 fn sq_l2_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
53 where
54 A: AddAssign + Clone + Signed;
55
56 /// Computes the [L2 distance] between `self` and `other`.
57 ///
58 /// ```text
59 /// n
60 /// √ ( ∑ |aᵢ - bᵢ|² )
61 /// i=1
62 /// ```
63 ///
64 /// where `self` is `a` and `other` is `b`.
65 ///
66 /// The following **errors** may be returned:
67 ///
68 /// * `MultiInputError::EmptyInput` if `self` is empty
69 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
70 ///
71 /// **Panics** if the type cast from `A` to `f64` fails.
72 ///
73 /// [L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance
74 fn l2_dist(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
75 where
76 A: AddAssign + Clone + Signed + ToPrimitive;
77
78 /// Computes the [L1 distance] between `self` and `other`.
79 ///
80 /// ```text
81 /// n
82 /// ∑ |aᵢ - bᵢ|
83 /// i=1
84 /// ```
85 ///
86 /// where `self` is `a` and `other` is `b`.
87 ///
88 /// The following **errors** may be returned:
89 ///
90 /// * `MultiInputError::EmptyInput` if `self` is empty
91 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
92 ///
93 /// [L1 distance]: https://en.wikipedia.org/wiki/Taxicab_geometry
94 fn l1_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
95 where
96 A: AddAssign + Clone + Signed;
97
98 /// Computes the [L∞ distance] between `self` and `other`.
99 ///
100 /// ```text
101 /// max(|aᵢ - bᵢ|)
102 /// ᵢ
103 /// ```
104 ///
105 /// where `self` is `a` and `other` is `b`.
106 ///
107 /// The following **errors** may be returned:
108 ///
109 /// * `MultiInputError::EmptyInput` if `self` is empty
110 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
111 ///
112 /// [L∞ distance]: https://en.wikipedia.org/wiki/Chebyshev_distance
113 fn linf_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
114 where
115 A: Clone + PartialOrd + Signed;
116
117 /// Computes the [mean absolute error] between `self` and `other`.
118 ///
119 /// ```text
120 /// n
121 /// 1/n * ∑ |aᵢ - bᵢ|
122 /// i=1
123 /// ```
124 ///
125 /// where `self` is `a` and `other` is `b`.
126 ///
127 /// The following **errors** may be returned:
128 ///
129 /// * `MultiInputError::EmptyInput` if `self` is empty
130 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
131 ///
132 /// **Panics** if the type cast from `A` to `f64` fails.
133 ///
134 /// [mean absolute error]: https://en.wikipedia.org/wiki/Mean_absolute_error
135 fn mean_abs_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
136 where
137 A: AddAssign + Clone + Signed + ToPrimitive;
138
139 /// Computes the [mean squared error] between `self` and `other`.
140 ///
141 /// ```text
142 /// n
143 /// 1/n * ∑ |aᵢ - bᵢ|²
144 /// i=1
145 /// ```
146 ///
147 /// where `self` is `a` and `other` is `b`.
148 ///
149 /// The following **errors** may be returned:
150 ///
151 /// * `MultiInputError::EmptyInput` if `self` is empty
152 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
153 ///
154 /// **Panics** if the type cast from `A` to `f64` fails.
155 ///
156 /// [mean squared error]: https://en.wikipedia.org/wiki/Mean_squared_error
157 fn mean_sq_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
158 where
159 A: AddAssign + Clone + Signed + ToPrimitive;
160
161 /// Computes the unnormalized [root-mean-square error] between `self` and `other`.
162 ///
163 /// ```text
164 /// √ mse(a, b)
165 /// ```
166 ///
167 /// where `self` is `a`, `other` is `b` and `mse` is the mean-squared-error.
168 ///
169 /// The following **errors** may be returned:
170 ///
171 /// * `MultiInputError::EmptyInput` if `self` is empty
172 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
173 ///
174 /// **Panics** if the type cast from `A` to `f64` fails.
175 ///
176 /// [root-mean-square error]: https://en.wikipedia.org/wiki/Root-mean-square_deviation
177 fn root_mean_sq_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
178 where
179 A: AddAssign + Clone + Signed + ToPrimitive;
180
181 /// Computes the [peak signal-to-noise ratio] between `self` and `other`.
182 ///
183 /// ```text
184 /// 10 * log10(maxv^2 / mse(a, b))
185 /// ```
186 ///
187 /// where `self` is `a`, `other` is `b`, `mse` is the mean-squared-error
188 /// and `maxv` is the maximum possible value either array can take.
189 ///
190 /// The following **errors** may be returned:
191 ///
192 /// * `MultiInputError::EmptyInput` if `self` is empty
193 /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
194 ///
195 /// **Panics** if the type cast from `A` to `f64` fails.
196 ///
197 /// [peak signal-to-noise ratio]: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
198 fn peak_signal_to_noise_ratio(
199 &self,
200 other: &ArrayRef<A, D>,
201 maxv: A,
202 ) -> Result<f64, MultiInputError>
203 where
204 A: AddAssign + Clone + Signed + ToPrimitive;
205
206 private_decl! {}
207}
208
209impl<A, D> DeviationExt<A, D> for ArrayRef<A, D>
210where
211 D: Dimension,
212{
213 fn count_eq(&self, other: &ArrayRef<A, D>) -> Result<usize, MultiInputError>
214 where
215 A: PartialEq,
216 {
217 return_err_if_empty!(self);
218 return_err_unless_same_shape!(self, other);
219
220 let mut count = 0;
221
222 Zip::from(self).and(other).for_each(|a, b| {
223 if a == b {
224 count += 1;
225 }
226 });
227
228 Ok(count)
229 }
230
231 fn count_neq(&self, other: &ArrayRef<A, D>) -> Result<usize, MultiInputError>
232 where
233 A: PartialEq,
234 {
235 self.count_eq(other).map(|n_eq| self.len() - n_eq)
236 }
237
238 fn sq_l2_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
239 where
240 A: AddAssign + Clone + Signed,
241 {
242 return_err_if_empty!(self);
243 return_err_unless_same_shape!(self, other);
244
245 let mut result = A::zero();
246
247 Zip::from(self).and(other).for_each(|self_i, other_i| {
248 let (a, b) = (self_i.clone(), other_i.clone());
249 let diff = a - b;
250 result += diff.clone() * diff;
251 });
252
253 Ok(result)
254 }
255
256 fn l2_dist(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
257 where
258 A: AddAssign + Clone + Signed + ToPrimitive,
259 {
260 let sq_l2_dist = self
261 .sq_l2_dist(other)?
262 .to_f64()
263 .expect("failed cast from type A to f64");
264
265 Ok(sq_l2_dist.sqrt())
266 }
267
268 fn l1_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
269 where
270 A: AddAssign + Clone + Signed,
271 {
272 return_err_if_empty!(self);
273 return_err_unless_same_shape!(self, other);
274
275 let mut result = A::zero();
276
277 Zip::from(self).and(other).for_each(|self_i, other_i| {
278 let (a, b) = (self_i.clone(), other_i.clone());
279 result += (a - b).abs();
280 });
281
282 Ok(result)
283 }
284
285 fn linf_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
286 where
287 A: Clone + PartialOrd + Signed,
288 {
289 return_err_if_empty!(self);
290 return_err_unless_same_shape!(self, other);
291
292 let mut max = A::zero();
293
294 Zip::from(self).and(other).for_each(|self_i, other_i| {
295 let (a, b) = (self_i.clone(), other_i.clone());
296 let diff = (a - b).abs();
297 if diff > max {
298 max = diff;
299 }
300 });
301
302 Ok(max)
303 }
304
305 fn mean_abs_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
306 where
307 A: AddAssign + Clone + Signed + ToPrimitive,
308 {
309 let l1_dist = self
310 .l1_dist(other)?
311 .to_f64()
312 .expect("failed cast from type A to f64");
313 let n = self.len() as f64;
314
315 Ok(l1_dist / n)
316 }
317
318 fn mean_sq_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
319 where
320 A: AddAssign + Clone + Signed + ToPrimitive,
321 {
322 let sq_l2_dist = self
323 .sq_l2_dist(other)?
324 .to_f64()
325 .expect("failed cast from type A to f64");
326 let n = self.len() as f64;
327
328 Ok(sq_l2_dist / n)
329 }
330
331 fn root_mean_sq_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
332 where
333 A: AddAssign + Clone + Signed + ToPrimitive,
334 {
335 let msd = self.mean_sq_err(other)?;
336 Ok(msd.sqrt())
337 }
338
339 fn peak_signal_to_noise_ratio(
340 &self,
341 other: &ArrayRef<A, D>,
342 maxv: A,
343 ) -> Result<f64, MultiInputError>
344 where
345 A: AddAssign + Clone + Signed + ToPrimitive,
346 {
347 let maxv_f = maxv.to_f64().expect("failed cast from type A to f64");
348 let msd = self.mean_sq_err(&other)?;
349 let psnr = 10. * f64::log10(maxv_f * maxv_f / msd);
350
351 Ok(psnr)
352 }
353
354 private_impl! {}
355}