redstone_ml/ndarray/equals.rs

1use crate::dtype::RawDataType;
2use crate::NdArray;
3
4impl<T: RawDataType> PartialEq<NdArray<'_, T>> for NdArray<'_, T> {
5    #[allow(clippy::op_ref)]
6    fn eq(&self, other: &NdArray<T>) -> bool {
7        &self == other
8    }
9}
10
11impl<T: RawDataType> PartialEq<NdArray<'_, T>> for &NdArray<'_, T> {
12    fn eq(&self, other: &NdArray<T>) -> bool {
13        if self.shape != other.shape {
14            return false;
15        }
16        self.flatiter().zip(other.flatiter()).all(|(a, b)| a == b)
17    }
18}
19
20impl<T: RawDataType> PartialEq<&NdArray<'_, T>> for NdArray<'_, T> {
21    fn eq(&self, other: &&NdArray<T>) -> bool {
22        if self.shape != other.shape {
23            return false;
24        }
25        self.flatiter().zip(other.flatiter()).all(|(a, b)| a == b)
26    }
27}