1use ndarray::{Array, ArrayBase, Axis, Data, DataMut, Dimension, RemoveAxis};
2use num_traits::Float;
3
4pub trait FillInPlaceExt<A, S, D>
5where
6 S: DataMut<Elem = A>,
7 A: Float,
8{
9 fn fill_non_finite_inplace(&mut self, with: A);
12}
13
14impl<A, S, D> FillInPlaceExt<A, S, D> for ArrayBase<S, D>
15where
16 A: Float,
17 S: DataMut<Elem = A>,
18 D: Dimension,
19{
20 fn fill_non_finite_inplace(&mut self, with: A) {
21 self.map_inplace(|x| {
22 if !x.is_finite() {
23 *x = with;
24 }
25 });
26 }
27}
28
29pub trait CountExt<A, S, D>
30where
31 S: Data<Elem = A>,
32 A: Float,
33{
34 fn count_finite(&self) -> usize;
36
37 fn count_non_finite(&self) -> usize;
39}
40
41impl<A, S, D> CountExt<A, S, D> for ArrayBase<S, D>
42where
43 A: Float,
44 S: Data<Elem = A>,
45 D: Dimension,
46{
47 fn count_finite(&self) -> usize {
48 self.fold(0, |a, b| a + if b.is_finite() { 1 } else { 0 })
49 }
50
51 fn count_non_finite(&self) -> usize {
52 self.fold(0, |a, b| a + if !b.is_finite() { 1 } else { 0 })
53 }
54}
55
56pub trait CountAxisExt<A, S, D>
57where
58 S: Data<Elem = A>,
59 A: Float,
60 D: Dimension + RemoveAxis,
61 <D as Dimension>::Smaller: Dimension,
62{
63 fn count_finite_axis(&self, axis: Axis) -> Array<usize, D::Smaller>;
67
68 fn count_non_finite_axis(&self, axis: Axis) -> Array<usize, D::Smaller>;
72}
73
74impl<A, S, D> CountAxisExt<A, S, D> for ArrayBase<S, D>
75where
76 A: Float,
77 S: Data<Elem = A>,
78 D: Dimension + RemoveAxis,
79 <D as Dimension>::Smaller: Dimension,
80 Array<usize, <D as Dimension>::Smaller>: FromIterator<usize>,
81{
82 fn count_finite_axis(&self, axis: Axis) -> Array<usize, D::Smaller> {
83 self.axis_iter(axis)
84 .map(|view| view.count_finite())
85 .collect()
86 }
87
88 fn count_non_finite_axis(&self, axis: Axis) -> Array<usize, D::Smaller> {
89 self.axis_iter(axis)
90 .map(|view| view.count_non_finite())
91 .collect()
92 }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// NaN is counted as non-finite and replaced by the fill value.
    #[test]
    fn count_and_fill() {
        // `f64::NAN` replaces the legacy `std::f64::NAN` module constant
        // (clippy `legacy_numeric_constants`).
        let mut vals = array![1., 2., f64::NAN, 3.];
        assert_eq!(3, vals.count_finite());
        assert_eq!(1, vals.count_non_finite());
        vals.fill_non_finite_inplace(42.);
        assert_eq!(vals, array![1., 2., 42., 3.]);
        assert_eq!(4, vals.count_finite());
        assert_eq!(0, vals.count_non_finite());
    }

    /// "Non-finite" also covers both infinities, not only NaN.
    #[test]
    fn count_and_fill_infinities() {
        let mut vals = array![f64::INFINITY, 1., f64::NEG_INFINITY, f64::NAN];
        assert_eq!(1, vals.count_finite());
        assert_eq!(3, vals.count_non_finite());
        vals.fill_non_finite_inplace(0.);
        assert_eq!(vals, array![0., 1., 0., 0.]);
        assert_eq!(4, vals.count_finite());
        assert_eq!(0, vals.count_non_finite());
    }

    /// An empty array has nothing to count or fill.
    #[test]
    fn count_and_fill_empty() {
        let mut vals = Array::from(Vec::<f64>::new());
        assert_eq!(0, vals.count_finite());
        assert_eq!(0, vals.count_non_finite());
        vals.fill_non_finite_inplace(1.);
        assert_eq!(0, vals.len());
    }

    /// Axis counts are per `axis_iter` subview: `Axis(0)` -> per row,
    /// `Axis(1)` -> per column.
    #[test]
    fn count_matrix() {
        let vals = array![[1., 2., f64::NAN, 3.], [f64::NAN, 4., 5., f64::NAN]];
        assert_eq!(5, vals.count_finite());
        assert_eq!(3, vals.count_non_finite());
        assert_eq!(array![3, 2], vals.count_finite_axis(Axis(0)));
        assert_eq!(array![1, 2], vals.count_non_finite_axis(Axis(0)));
        assert_eq!(array![1, 2, 1, 1], vals.count_finite_axis(Axis(1)));
        assert_eq!(array![1, 0, 1, 1], vals.count_non_finite_axis(Axis(1)));
    }
}
121}