1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
use crate::tensor::gen_tensor::GenTensor;

/// Pairwise comparison operations that combine two tensors into a new one.
///
/// Implementors name the tensor type they operate on through
/// [`CompareTensor::TensorType`]; both methods borrow their operands and
/// hand back a freshly produced value of that type.
pub trait CompareTensor {
    /// Type of the right-hand operand and of the returned result.
    type TensorType;

    /// Combines `self` with `o`, keeping the larger value of each pair.
    fn max_pair(&self, o: &Self::TensorType) -> Self::TensorType;

    /// Combines `self` with `o`, keeping the smaller value of each pair.
    fn min_pair(&self, o: &Self::TensorType) -> Self::TensorType;
}

/// Builds a new tensor by choosing one value out of every `(lhs, rhs)` element pair.
///
/// The elements are visited in the order of the buffers returned by `get_data()`;
/// `pick` receives the value from `lhs` first and the value from `rhs` second and
/// returns the one to store in the result. `op_name` is only used to label the
/// panic message.
///
/// # Panics
///
/// Panics when `lhs.size()` and `rhs.size()` are not equal.
fn select_pairwise<T, F>(
    lhs: &GenTensor<T>,
    rhs: &GenTensor<T>,
    op_name: &str,
    pick: F,
) -> GenTensor<T>
where
    T: num_traits::Float,
    F: Fn(T, T) -> T,
{
    if lhs.size() != rhs.size() {
        panic!(
            "{} needs two tensor have the same size, {:?}, {:?}",
            op_name,
            lhs.size(),
            rhs.size()
        );
    }
    let mut ret = GenTensor::empty(&lhs.size());

    for ((a, b), c) in lhs
        .get_data()
        .iter()
        .zip(rhs.get_data().iter())
        .zip(ret.get_data_mut().iter_mut())
    {
        *c = pick(*a, *b);
    }
    ret
}

impl<T> CompareTensor for GenTensor<T> where T: num_traits::Float {
    type TensorType = GenTensor<T>;

    /// Returns a tensor of the same size holding the larger value of each element pair.
    ///
    /// The element of `self` is kept when `a >= b` holds; otherwise the element of
    /// `o` is kept. Because that comparison is false whenever an operand is NaN
    /// (for IEEE-754 types such as `f32`/`f64`), such pairs yield the element of `o`.
    ///
    /// # Panics
    ///
    /// Panics when the two tensors do not have the same size.
    fn max_pair(&self, o: &GenTensor<T>) -> GenTensor<T> {
        select_pairwise(self, o, "max", |a, b| if a >= b { a } else { b })
    }

    /// Returns a tensor of the same size holding the smaller value of each element pair.
    ///
    /// The element of `o` is kept when `a >= b` holds; otherwise the element of
    /// `self` is kept. Because that comparison is false whenever an operand is NaN
    /// (for IEEE-754 types such as `f32`/`f64`), such pairs yield the element of `self`.
    ///
    /// # Panics
    ///
    /// Panics when the two tensors do not have the same size.
    fn min_pair(&self, o: &GenTensor<T>) -> GenTensor<T> {
        // The label is "min" so a size mismatch is reported against the right operation.
        select_pairwise(self, o, "min", |a, b| if a >= b { b } else { a })
    }
}

#[cfg(test)]
mod tests {
    use crate::tensor::gen_tensor::GenTensor;
    use super::*;

    /// Creates an `f32` tensor from four values with `[2, 2]` as the shape argument.
    fn tensor_2x2(values: [f32; 4]) -> GenTensor<f32> {
        GenTensor::<f32>::new_raw(&values.to_vec(), &vec![2, 2])
    }

    /// `max_pair` keeps the larger value of every pair.
    #[test]
    fn max_pair() {
        let lhs = tensor_2x2([1., 3., 10., 11.]);
        let rhs = tensor_2x2([2., 4., 5., 6.]);
        let expected = tensor_2x2([2., 4., 10., 11.]);
        assert_eq!(lhs.max_pair(&rhs), expected);
    }

    /// `min_pair` keeps the smaller value of every pair.
    #[test]
    fn min_pair() {
        let lhs = tensor_2x2([1., 3., 10., 11.]);
        let rhs = tensor_2x2([2., 4., 5., 6.]);
        let expected = tensor_2x2([1., 3., 5., 6.]);
        assert_eq!(lhs.min_pair(&rhs), expected);
    }
}