ndarray_utils/
floats.rs

1use ndarray::{Array, ArrayBase, Axis, Data, DataMut, Dimension, RemoveAxis};
2use num_traits::Float;
3
4pub trait FillInPlaceExt<A, S, D>
5where
6    S: DataMut<Elem = A>,
7    A: Float,
8{
9    /// Fills non-finite floating point values (NaN, infinity, and negative
10    /// infinity) with the given replacement.
11    fn fill_non_finite_inplace(&mut self, with: A);
12}
13
14impl<A, S, D> FillInPlaceExt<A, S, D> for ArrayBase<S, D>
15where
16    A: Float,
17    S: DataMut<Elem = A>,
18    D: Dimension,
19{
20    fn fill_non_finite_inplace(&mut self, with: A) {
21        self.map_inplace(|x| {
22            if !x.is_finite() {
23                *x = with;
24            }
25        });
26    }
27}
28
29pub trait CountExt<A, S, D>
30where
31    S: Data<Elem = A>,
32    A: Float,
33{
34    /// Returns the number of finite values.
35    fn count_finite(&self) -> usize;
36
37    /// Returns the number of non-finite values.
38    fn count_non_finite(&self) -> usize;
39}
40
41impl<A, S, D> CountExt<A, S, D> for ArrayBase<S, D>
42where
43    A: Float,
44    S: Data<Elem = A>,
45    D: Dimension,
46{
47    fn count_finite(&self) -> usize {
48        self.fold(0, |a, b| a + if b.is_finite() { 1 } else { 0 })
49    }
50
51    fn count_non_finite(&self) -> usize {
52        self.fold(0, |a, b| a + if !b.is_finite() { 1 } else { 0 })
53    }
54}
55
56pub trait CountAxisExt<A, S, D>
57where
58    S: Data<Elem = A>,
59    A: Float,
60    D: Dimension + RemoveAxis,
61    <D as Dimension>::Smaller: Dimension,
62{
63    /// Returns the number of finite values for each index along the given axis.
64    /// For example, in a matrix, specifying Axis(0) will give the number of
65    /// finite values per row.
66    fn count_finite_axis(&self, axis: Axis) -> Array<usize, D::Smaller>;
67
68    /// Returns the number of non-finite values for each index along the given
69    /// axis.  For example, in a matrix, specifying Axis(0) will give the number
70    /// of non-finite values per row.
71    fn count_non_finite_axis(&self, axis: Axis) -> Array<usize, D::Smaller>;
72}
73
74impl<A, S, D> CountAxisExt<A, S, D> for ArrayBase<S, D>
75where
76    A: Float,
77    S: Data<Elem = A>,
78    D: Dimension + RemoveAxis,
79    <D as Dimension>::Smaller: Dimension,
80    Array<usize, <D as Dimension>::Smaller>: FromIterator<usize>,
81{
82    fn count_finite_axis(&self, axis: Axis) -> Array<usize, D::Smaller> {
83        self.axis_iter(axis)
84            .map(|view| view.count_finite())
85            .collect()
86    }
87
88    fn count_non_finite_axis(&self, axis: Axis) -> Array<usize, D::Smaller> {
89        self.axis_iter(axis)
90            .map(|view| view.count_non_finite())
91            .collect()
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, s, Array1};

    // Associated float constants replace the legacy `std::f64::NAN` module
    // constants (flagged by clippy's `legacy_numeric_constants`).
    const NAN: f64 = f64::NAN;
    const INF: f64 = f64::INFINITY;
    const NEG_INF: f64 = f64::NEG_INFINITY;

    #[test]
    fn count_and_fill() {
        let mut vals = array![1., 2., NAN, 3.];
        assert_eq!(3, vals.count_finite());
        assert_eq!(1, vals.count_non_finite());
        vals.fill_non_finite_inplace(42.);
        assert_eq!(vals, array![1., 2., 42., 3.]);
        assert_eq!(4, vals.count_finite());
        assert_eq!(0, vals.count_non_finite());
    }

    #[test]
    fn infinities_are_non_finite() {
        // The docs promise NaN, +inf and -inf are all treated as non-finite.
        let mut vals = array![INF, 1., NEG_INF, NAN];
        assert_eq!(1, vals.count_finite());
        assert_eq!(3, vals.count_non_finite());
        vals.fill_non_finite_inplace(0.);
        assert_eq!(vals, array![0., 1., 0., 0.]);
    }

    #[test]
    fn fill_through_view_only_touches_view() {
        let mut vals = array![NAN, NAN, NAN];
        vals.slice_mut(s![..2]).fill_non_finite_inplace(7.);
        assert_eq!(7., vals[0]);
        assert_eq!(7., vals[1]);
        // NaN != NaN, so check the untouched element explicitly.
        assert!(vals[2].is_nan());
    }

    #[test]
    fn empty_array_counts_zero() {
        let empty = Array1::<f64>::zeros(0);
        assert_eq!(0, empty.count_finite());
        assert_eq!(0, empty.count_non_finite());
    }

    #[test]
    fn count_matrix() {
        let vals = array![[1., 2., NAN, 3.], [NAN, 4., 5., NAN]];
        assert_eq!(5, vals.count_finite());
        assert_eq!(3, vals.count_non_finite());
        assert_eq!(array![3, 2], vals.count_finite_axis(Axis(0)));
        assert_eq!(array![1, 2], vals.count_non_finite_axis(Axis(0)));
        assert_eq!(array![1, 2, 1, 1], vals.count_finite_axis(Axis(1)));
        assert_eq!(array![1, 0, 1, 1], vals.count_non_finite_axis(Axis(1)));
    }

    #[test]
    fn count_matrix_with_infinities() {
        let vals = array![[INF, 1.], [2., NEG_INF]];
        assert_eq!(array![1, 1], vals.count_finite_axis(Axis(0)));
        assert_eq!(array![1, 1], vals.count_non_finite_axis(Axis(1)));
    }
}