redstone_ml/ndarray/
equals.rs1use crate::dtype::RawDataType;
2use crate::NdArray;
3
4impl<T: RawDataType> PartialEq<NdArray<'_, T>> for NdArray<'_, T> {
5 #[allow(clippy::op_ref)]
6 fn eq(&self, other: &NdArray<T>) -> bool {
7 &self == other
8 }
9}
10
11impl<T: RawDataType> PartialEq<NdArray<'_, T>> for &NdArray<'_, T> {
12 fn eq(&self, other: &NdArray<T>) -> bool {
13 if self.shape != other.shape {
14 return false;
15 }
16 self.flatiter().zip(other.flatiter()).all(|(a, b)| a == b)
17 }
18}
19
20impl<T: RawDataType> PartialEq<&NdArray<'_, T>> for NdArray<'_, T> {
21 fn eq(&self, other: &&NdArray<T>) -> bool {
22 if self.shape != other.shape {
23 return false;
24 }
25 self.flatiter().zip(other.flatiter()).all(|(a, b)| a == b)
26 }
27}