ndarray_stats/
deviation.rs

1use ndarray::{ArrayRef, Dimension, Zip};
2use num_traits::{Signed, ToPrimitive};
3use std::convert::Into;
4use std::ops::AddAssign;
5
6use crate::errors::MultiInputError;
7
8/// An extension trait for `ndarray` providing functions
9/// to compute different deviation measures.
10pub trait DeviationExt<A, D>
11where
12    D: Dimension,
13{
14    /// Counts the number of indices at which the elements of the arrays `self`
15    /// and `other` are equal.
16    ///
17    /// The following **errors** may be returned:
18    ///
19    /// * `MultiInputError::EmptyInput` if `self` is empty
20    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
21    fn count_eq(&self, other: &ArrayRef<A, D>) -> Result<usize, MultiInputError>
22    where
23        A: PartialEq;
24
25    /// Counts the number of indices at which the elements of the arrays `self`
26    /// and `other` are not equal.
27    ///
28    /// The following **errors** may be returned:
29    ///
30    /// * `MultiInputError::EmptyInput` if `self` is empty
31    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
32    fn count_neq(&self, other: &ArrayRef<A, D>) -> Result<usize, MultiInputError>
33    where
34        A: PartialEq;
35
36    /// Computes the [squared L2 distance] between `self` and `other`.
37    ///
38    /// ```text
39    ///  n
40    ///  ∑  |aᵢ - bᵢ|²
41    /// i=1
42    /// ```
43    ///
44    /// where `self` is `a` and `other` is `b`.
45    ///
46    /// The following **errors** may be returned:
47    ///
48    /// * `MultiInputError::EmptyInput` if `self` is empty
49    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
50    ///
51    /// [squared L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance
52    fn sq_l2_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
53    where
54        A: AddAssign + Clone + Signed;
55
56    /// Computes the [L2 distance] between `self` and `other`.
57    ///
58    /// ```text
59    ///      n
60    /// √ (  ∑  |aᵢ - bᵢ|² )
61    ///     i=1
62    /// ```
63    ///
64    /// where `self` is `a` and `other` is `b`.
65    ///
66    /// The following **errors** may be returned:
67    ///
68    /// * `MultiInputError::EmptyInput` if `self` is empty
69    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
70    ///
71    /// **Panics** if the type cast from `A` to `f64` fails.
72    ///
73    /// [L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance
74    fn l2_dist(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
75    where
76        A: AddAssign + Clone + Signed + ToPrimitive;
77
78    /// Computes the [L1 distance] between `self` and `other`.
79    ///
80    /// ```text
81    ///  n
82    ///  ∑  |aᵢ - bᵢ|
83    /// i=1
84    /// ```
85    ///
86    /// where `self` is `a` and `other` is `b`.
87    ///
88    /// The following **errors** may be returned:
89    ///
90    /// * `MultiInputError::EmptyInput` if `self` is empty
91    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
92    ///
93    /// [L1 distance]: https://en.wikipedia.org/wiki/Taxicab_geometry
94    fn l1_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
95    where
96        A: AddAssign + Clone + Signed;
97
98    /// Computes the [L∞ distance] between `self` and `other`.
99    ///
100    /// ```text
101    /// max(|aᵢ - bᵢ|)
102    ///  ᵢ
103    /// ```
104    ///
105    /// where `self` is `a` and `other` is `b`.
106    ///
107    /// The following **errors** may be returned:
108    ///
109    /// * `MultiInputError::EmptyInput` if `self` is empty
110    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
111    ///
112    /// [L∞ distance]: https://en.wikipedia.org/wiki/Chebyshev_distance
113    fn linf_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
114    where
115        A: Clone + PartialOrd + Signed;
116
117    /// Computes the [mean absolute error] between `self` and `other`.
118    ///
119    /// ```text
120    ///        n
121    /// 1/n *  ∑  |aᵢ - bᵢ|
122    ///       i=1
123    /// ```
124    ///
125    /// where `self` is `a` and `other` is `b`.
126    ///
127    /// The following **errors** may be returned:
128    ///
129    /// * `MultiInputError::EmptyInput` if `self` is empty
130    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
131    ///
132    /// **Panics** if the type cast from `A` to `f64` fails.
133    ///
134    /// [mean absolute error]: https://en.wikipedia.org/wiki/Mean_absolute_error
135    fn mean_abs_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
136    where
137        A: AddAssign + Clone + Signed + ToPrimitive;
138
139    /// Computes the [mean squared error] between `self` and `other`.
140    ///
141    /// ```text
142    ///        n
143    /// 1/n *  ∑  |aᵢ - bᵢ|²
144    ///       i=1
145    /// ```
146    ///
147    /// where `self` is `a` and `other` is `b`.
148    ///
149    /// The following **errors** may be returned:
150    ///
151    /// * `MultiInputError::EmptyInput` if `self` is empty
152    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
153    ///
154    /// **Panics** if the type cast from `A` to `f64` fails.
155    ///
156    /// [mean squared error]: https://en.wikipedia.org/wiki/Mean_squared_error
157    fn mean_sq_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
158    where
159        A: AddAssign + Clone + Signed + ToPrimitive;
160
161    /// Computes the unnormalized [root-mean-square error] between `self` and `other`.
162    ///
163    /// ```text
164    /// √ mse(a, b)
165    /// ```
166    ///
167    /// where `self` is `a`, `other` is `b` and `mse` is the mean-squared-error.
168    ///
169    /// The following **errors** may be returned:
170    ///
171    /// * `MultiInputError::EmptyInput` if `self` is empty
172    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
173    ///
174    /// **Panics** if the type cast from `A` to `f64` fails.
175    ///
176    /// [root-mean-square error]: https://en.wikipedia.org/wiki/Root-mean-square_deviation
177    fn root_mean_sq_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
178    where
179        A: AddAssign + Clone + Signed + ToPrimitive;
180
181    /// Computes the [peak signal-to-noise ratio] between `self` and `other`.
182    ///
183    /// ```text
184    /// 10 * log10(maxv^2 / mse(a, b))
185    /// ```
186    ///
187    /// where `self` is `a`, `other` is `b`, `mse` is the mean-squared-error
188    /// and `maxv` is the maximum possible value either array can take.
189    ///
190    /// The following **errors** may be returned:
191    ///
192    /// * `MultiInputError::EmptyInput` if `self` is empty
193    /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
194    ///
195    /// **Panics** if the type cast from `A` to `f64` fails.
196    ///
197    /// [peak signal-to-noise ratio]: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
198    fn peak_signal_to_noise_ratio(
199        &self,
200        other: &ArrayRef<A, D>,
201        maxv: A,
202    ) -> Result<f64, MultiInputError>
203    where
204        A: AddAssign + Clone + Signed + ToPrimitive;
205
206    private_decl! {}
207}
208
209impl<A, D> DeviationExt<A, D> for ArrayRef<A, D>
210where
211    D: Dimension,
212{
213    fn count_eq(&self, other: &ArrayRef<A, D>) -> Result<usize, MultiInputError>
214    where
215        A: PartialEq,
216    {
217        return_err_if_empty!(self);
218        return_err_unless_same_shape!(self, other);
219
220        let mut count = 0;
221
222        Zip::from(self).and(other).for_each(|a, b| {
223            if a == b {
224                count += 1;
225            }
226        });
227
228        Ok(count)
229    }
230
231    fn count_neq(&self, other: &ArrayRef<A, D>) -> Result<usize, MultiInputError>
232    where
233        A: PartialEq,
234    {
235        self.count_eq(other).map(|n_eq| self.len() - n_eq)
236    }
237
238    fn sq_l2_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
239    where
240        A: AddAssign + Clone + Signed,
241    {
242        return_err_if_empty!(self);
243        return_err_unless_same_shape!(self, other);
244
245        let mut result = A::zero();
246
247        Zip::from(self).and(other).for_each(|self_i, other_i| {
248            let (a, b) = (self_i.clone(), other_i.clone());
249            let diff = a - b;
250            result += diff.clone() * diff;
251        });
252
253        Ok(result)
254    }
255
256    fn l2_dist(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
257    where
258        A: AddAssign + Clone + Signed + ToPrimitive,
259    {
260        let sq_l2_dist = self
261            .sq_l2_dist(other)?
262            .to_f64()
263            .expect("failed cast from type A to f64");
264
265        Ok(sq_l2_dist.sqrt())
266    }
267
268    fn l1_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
269    where
270        A: AddAssign + Clone + Signed,
271    {
272        return_err_if_empty!(self);
273        return_err_unless_same_shape!(self, other);
274
275        let mut result = A::zero();
276
277        Zip::from(self).and(other).for_each(|self_i, other_i| {
278            let (a, b) = (self_i.clone(), other_i.clone());
279            result += (a - b).abs();
280        });
281
282        Ok(result)
283    }
284
285    fn linf_dist(&self, other: &ArrayRef<A, D>) -> Result<A, MultiInputError>
286    where
287        A: Clone + PartialOrd + Signed,
288    {
289        return_err_if_empty!(self);
290        return_err_unless_same_shape!(self, other);
291
292        let mut max = A::zero();
293
294        Zip::from(self).and(other).for_each(|self_i, other_i| {
295            let (a, b) = (self_i.clone(), other_i.clone());
296            let diff = (a - b).abs();
297            if diff > max {
298                max = diff;
299            }
300        });
301
302        Ok(max)
303    }
304
305    fn mean_abs_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
306    where
307        A: AddAssign + Clone + Signed + ToPrimitive,
308    {
309        let l1_dist = self
310            .l1_dist(other)?
311            .to_f64()
312            .expect("failed cast from type A to f64");
313        let n = self.len() as f64;
314
315        Ok(l1_dist / n)
316    }
317
318    fn mean_sq_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
319    where
320        A: AddAssign + Clone + Signed + ToPrimitive,
321    {
322        let sq_l2_dist = self
323            .sq_l2_dist(other)?
324            .to_f64()
325            .expect("failed cast from type A to f64");
326        let n = self.len() as f64;
327
328        Ok(sq_l2_dist / n)
329    }
330
331    fn root_mean_sq_err(&self, other: &ArrayRef<A, D>) -> Result<f64, MultiInputError>
332    where
333        A: AddAssign + Clone + Signed + ToPrimitive,
334    {
335        let msd = self.mean_sq_err(other)?;
336        Ok(msd.sqrt())
337    }
338
339    fn peak_signal_to_noise_ratio(
340        &self,
341        other: &ArrayRef<A, D>,
342        maxv: A,
343    ) -> Result<f64, MultiInputError>
344    where
345        A: AddAssign + Clone + Signed + ToPrimitive,
346    {
347        let maxv_f = maxv.to_f64().expect("failed cast from type A to f64");
348        let msd = self.mean_sq_err(&other)?;
349        let psnr = 10. * f64::log10(maxv_f * maxv_f / msd);
350
351        Ok(psnr)
352    }
353
354    private_impl! {}
355}