1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
use crate::tensor::gen_tensor::GenTensor;
/// Element-wise comparison operations between two tensors of the same kind.
pub trait CompareTensor {
    /// The tensor type both operands and the result share.
    type TensorType;

    /// Returns a new tensor holding, at each position, the smaller of the two inputs.
    fn min_pair(&self, o: &Self::TensorType) -> Self::TensorType;

    /// Returns a new tensor holding, at each position, the larger of the two inputs.
    fn max_pair(&self, o: &Self::TensorType) -> Self::TensorType;
}
/// Element-wise max/min for floating-point `GenTensor`s.
///
/// Both operations require the operands to have identical sizes and return a
/// freshly allocated tensor of that size.
impl<T> CompareTensor for GenTensor<T>
where
    T: num_traits::Float,
{
    type TensorType = GenTensor<T>;

    /// Element-wise maximum of `self` and `o`.
    ///
    /// At each position the element of `self` is kept when `a >= b`;
    /// otherwise the element of `o` is taken. If either value is NaN the
    /// comparison is false, so the element from `o` is returned.
    ///
    /// # Panics
    /// Panics if `self.size() != o.size()`.
    fn max_pair(&self, o: &GenTensor<T>) -> GenTensor<T> {
        pairwise_select("max", self, o, |a, b| if a >= b { a } else { b })
    }

    /// Element-wise minimum of `self` and `o`.
    ///
    /// At each position the element of `o` is taken when `a >= b`;
    /// otherwise the element of `self` is kept. If either value is NaN the
    /// comparison is false, so the element from `self` is returned.
    ///
    /// # Panics
    /// Panics if `self.size() != o.size()`.
    fn min_pair(&self, o: &GenTensor<T>) -> GenTensor<T> {
        pairwise_select("min", self, o, |a, b| if a >= b { b } else { a })
    }
}

/// Builds a new tensor by applying `pick` to every aligned pair of elements of
/// `lhs` and `rhs`.
///
/// `op` names the calling operation and is used only in the panic message, so
/// a size mismatch is reported against the operation that actually failed.
///
/// # Panics
/// Panics if `lhs.size() != rhs.size()`.
fn pairwise_select<T, F>(op: &str, lhs: &GenTensor<T>, rhs: &GenTensor<T>, pick: F) -> GenTensor<T>
where
    T: num_traits::Float,
    F: Fn(T, T) -> T,
{
    if lhs.size() != rhs.size() {
        panic!(
            "{} needs two tensor have the same size, {:?}, {:?}",
            op,
            lhs.size(),
            rhs.size()
        );
    }
    let mut ret: GenTensor<T> = GenTensor::empty(&lhs.size());
    for ((a, b), c) in lhs
        .get_data()
        .iter()
        .zip(rhs.get_data().iter())
        .zip(ret.get_data_mut().iter_mut())
    {
        *c = pick(*a, *b);
    }
    ret
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::gen_tensor::GenTensor;

    /// `max_pair` keeps the larger value at every position.
    #[test]
    fn max_pair() {
        let shape = vec![2, 2];
        let lhs = GenTensor::<f32>::new_raw(&vec![1., 3., 10., 11.], &shape);
        let rhs = GenTensor::<f32>::new_raw(&vec![2., 4., 5., 6.], &shape);
        let expected = GenTensor::<f32>::new_raw(&vec![2., 4., 10., 11.], &shape);
        let actual = lhs.max_pair(&rhs);
        assert_eq!(actual, expected);
    }

    /// `min_pair` keeps the smaller value at every position.
    #[test]
    fn min_pair() {
        let shape = vec![2, 2];
        let lhs = GenTensor::<f32>::new_raw(&vec![1., 3., 10., 11.], &shape);
        let rhs = GenTensor::<f32>::new_raw(&vec![2., 4., 5., 6.], &shape);
        let expected = GenTensor::<f32>::new_raw(&vec![1., 3., 5., 6.], &shape);
        let actual = lhs.min_pair(&rhs);
        assert_eq!(actual, expected);
    }
}