ndarray_utils/
pairwise.rs

1use ndarray::{Array, ArrayBase, Data, DataMut, Dimension, NdIndex};
2
3pub trait PairwiseInplaceExt<A, S, SS, D>
4where
5    S: DataMut<Elem = A>,
6    SS: Data<Elem = A>,
7{
8    /// Takes the elementwise maximum with another array.
9    fn maximum_with_inplace(&mut self, other: &ArrayBase<SS, D>);
10
11    /// Takes the elementwise minimum with another array.
12    fn minimum_with_inplace(&mut self, other: &ArrayBase<SS, D>);
13}
14
15pub trait PairwiseExt<A, S, D>
16where
17    S: Data<Elem = A>,
18{
19    /// Returns the elementwise maximum with another array.
20    fn maximum_with(&self, other: &ArrayBase<S, D>) -> Array<A, D>;
21
22    /// Returns the elementwise minimum with another array.
23    fn minimum_with(&self, other: &ArrayBase<S, D>) -> Array<A, D>;
24}
25
26impl<A, S, D> PairwiseExt<A, S, D> for ArrayBase<S, D>
27where
28    A: PartialOrd + Copy,
29    S: Data<Elem = A>,
30    D: Dimension,
31    <D as Dimension>::Pattern: NdIndex<D>,
32{
33    fn maximum_with(&self, other: &ArrayBase<S, D>) -> Array<A, D> {
34        let mut array = self.to_owned();
35        array.maximum_with_inplace(other);
36        array
37    }
38
39    fn minimum_with(&self, other: &ArrayBase<S, D>) -> Array<A, D> {
40        let mut array = self.to_owned();
41        array.minimum_with_inplace(other);
42        array
43    }
44}
45
46impl<A, S, SS, D> PairwiseInplaceExt<A, S, SS, D> for ArrayBase<S, D>
47where
48    A: PartialOrd + Copy,
49    S: DataMut<Elem = A>,
50    SS: Data<Elem = A>,
51    D: Dimension,
52    <D as Dimension>::Pattern: NdIndex<D>,
53{
54    fn maximum_with_inplace(&mut self, other: &ArrayBase<SS, D>) {
55        for (i, val) in self.indexed_iter_mut() {
56            let o = &other[i];
57            if *val < *o {
58                *val = *o;
59            }
60        }
61    }
62
63    fn minimum_with_inplace(&mut self, other: &ArrayBase<SS, D>) {
64        for (i, val) in self.indexed_iter_mut() {
65            let o = &other[i];
66            if *val > *o {
67                *val = *o;
68            }
69        }
70    }
71}
72
#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;

    /// Out-of-place max/min on 1-D owned float arrays.
    #[test]
    fn pairwise() {
        let lhs = array![1., 2., 3.];
        let rhs = array![-1., 2., 5.];

        assert_eq!(lhs.maximum_with(&rhs), array![1., 2., 5.]);
        assert_eq!(lhs.minimum_with(&rhs), array![-1., 2., 3.]);
    }

    /// In-place max followed by in-place min on the same 1-D array.
    #[test]
    fn inplace() {
        let mut lhs = array![1, 2, 3];
        let rhs = array![-1, 2, 5];

        lhs.maximum_with_inplace(&rhs);
        assert_eq!(lhs, array![1, 2, 5]);

        lhs.minimum_with_inplace(&rhs);
        assert_eq!(lhs, array![-1, 2, 5]);
    }

    /// Out-of-place variants on 2-D arrays; the inputs must not be modified.
    #[test]
    fn pairwise_2d() {
        let lhs = array![[1, 8], [6, 2]];
        let rhs = array![[4, 4], [4, 4]];

        assert_eq!(lhs.maximum_with(&rhs), array![[4, 8], [6, 4]]);
        assert_eq!(lhs.minimum_with(&rhs), array![[1, 4], [4, 2]]);
        assert_eq!(lhs, array![[1, 8], [6, 2]]);
    }

    /// In-place variants applied through mutable views rather than owned arrays.
    #[test]
    fn inplace_through_views() {
        let mut target = array![[1, 8], [6, 2]];
        let bound = array![[4, 4], [4, 4]];

        // A mutable view is `DataMut`, and an immutable view is `Data`.
        target.view_mut().maximum_with_inplace(&bound.view());
        assert_eq!(target, array![[4, 8], [6, 4]]);

        // Only row 0 is clamped; row 1 must be left unchanged.
        target.row_mut(0).minimum_with_inplace(&array![5, 5]);
        assert_eq!(target, array![[4, 5], [6, 4]]);
    }
}