//! Element-wise (pairwise) maximum/minimum helpers for `ndarray` arrays.
//!
//! Source path: `ndarray_utils/pairwise.rs`.

use ndarray::{Array, ArrayBase, Data, DataMut, Dimension, NdIndex};

3pub trait PairwiseInplaceExt<A, S, SS, D>
4where
5 S: DataMut<Elem = A>,
6 SS: Data<Elem = A>,
7{
8 fn maximum_with_inplace(&mut self, other: &ArrayBase<SS, D>);
10
11 fn minimum_with_inplace(&mut self, other: &ArrayBase<SS, D>);
13}
14
15pub trait PairwiseExt<A, S, D>
16where
17 S: Data<Elem = A>,
18{
19 fn maximum_with(&self, other: &ArrayBase<S, D>) -> Array<A, D>;
21
22 fn minimum_with(&self, other: &ArrayBase<S, D>) -> Array<A, D>;
24}
25
26impl<A, S, D> PairwiseExt<A, S, D> for ArrayBase<S, D>
27where
28 A: PartialOrd + Copy,
29 S: Data<Elem = A>,
30 D: Dimension,
31 <D as Dimension>::Pattern: NdIndex<D>,
32{
33 fn maximum_with(&self, other: &ArrayBase<S, D>) -> Array<A, D> {
34 let mut array = self.to_owned();
35 array.maximum_with_inplace(other);
36 array
37 }
38
39 fn minimum_with(&self, other: &ArrayBase<S, D>) -> Array<A, D> {
40 let mut array = self.to_owned();
41 array.minimum_with_inplace(other);
42 array
43 }
44}
45
46impl<A, S, SS, D> PairwiseInplaceExt<A, S, SS, D> for ArrayBase<S, D>
47where
48 A: PartialOrd + Copy,
49 S: DataMut<Elem = A>,
50 SS: Data<Elem = A>,
51 D: Dimension,
52 <D as Dimension>::Pattern: NdIndex<D>,
53{
54 fn maximum_with_inplace(&mut self, other: &ArrayBase<SS, D>) {
55 for (i, val) in self.indexed_iter_mut() {
56 let o = &other[i];
57 if *val < *o {
58 *val = *o;
59 }
60 }
61 }
62
63 fn minimum_with_inplace(&mut self, other: &ArrayBase<SS, D>) {
64 for (i, val) in self.indexed_iter_mut() {
65 let o = &other[i];
66 if *val > *o {
67 *val = *o;
68 }
69 }
70 }
71}
72
#[cfg(test)]
mod tests {
    use ndarray::array;

    use super::*;

    #[test]
    fn pairwise() {
        let lhs = array![1., 2., 3.];
        let rhs = array![-1., 2., 5.];

        assert_eq!(lhs.maximum_with(&rhs), array![1., 2., 5.]);
        assert_eq!(lhs.minimum_with(&rhs), array![-1., 2., 3.]);
        // The non-mutating variants only borrow the receiver.
        assert_eq!(lhs, array![1., 2., 3.]);
    }

    #[test]
    fn inplace() {
        let mut lhs = array![1, 2, 3];
        let rhs = array![-1, 2, 5];

        lhs.maximum_with_inplace(&rhs);
        assert_eq!(lhs, array![1, 2, 5]);

        lhs.minimum_with_inplace(&rhs);
        assert_eq!(lhs, array![-1, 2, 5]);
    }

    #[test]
    fn minimum_inplace_from_fresh_array() {
        // Exercised independently so it is not masked by a preceding maximum.
        let mut lhs = array![1, 2, 3];
        let rhs = array![-1, 2, 5];

        lhs.minimum_with_inplace(&rhs);
        assert_eq!(lhs, array![-1, 2, 3]);
    }

    #[test]
    fn two_dimensional() {
        let lhs = array![[1, 8], [3, 4]];
        let rhs = array![[5, 2], [3, 0]];

        assert_eq!(lhs.maximum_with(&rhs), array![[5, 8], [3, 4]]);
        assert_eq!(lhs.minimum_with(&rhs), array![[1, 2], [3, 0]]);
    }

    #[test]
    fn inplace_through_views() {
        let mut lhs = array![1, 2, 3];
        let rhs = array![0, 5, 3];

        lhs.view_mut().maximum_with_inplace(&rhs.view());
        assert_eq!(lhs, array![1, 5, 3]);
    }
}