1use ferray_core::Array;
7use ferray_core::dimension::Dimension;
8use ferray_core::dtype::Element;
9use ferray_core::error::FerrayResult;
10use num_traits::Float;
11
12use crate::helpers::binary_map_op;
13
14pub fn equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
16where
17 T: Element + PartialEq + Copy,
18 D: Dimension,
19{
20 binary_map_op(a, b, |x, y| x == y)
21}
22
23pub fn not_equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
25where
26 T: Element + PartialEq + Copy,
27 D: Dimension,
28{
29 binary_map_op(a, b, |x, y| x != y)
30}
31
32pub fn less<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
34where
35 T: Element + PartialOrd + Copy,
36 D: Dimension,
37{
38 binary_map_op(a, b, |x, y| x < y)
39}
40
41pub fn less_equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
43where
44 T: Element + PartialOrd + Copy,
45 D: Dimension,
46{
47 binary_map_op(a, b, |x, y| x <= y)
48}
49
50pub fn greater<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
52where
53 T: Element + PartialOrd + Copy,
54 D: Dimension,
55{
56 binary_map_op(a, b, |x, y| x > y)
57}
58
59pub fn greater_equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> FerrayResult<Array<bool, D>>
61where
62 T: Element + PartialOrd + Copy,
63 D: Dimension,
64{
65 binary_map_op(a, b, |x, y| x >= y)
66}
67
68pub fn array_equal<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> bool
70where
71 T: Element + PartialEq,
72 D: Dimension,
73{
74 if a.shape() != b.shape() {
75 return false;
76 }
77 a.iter().zip(b.iter()).all(|(x, y)| x == y)
78}
79
80pub fn array_equiv<T, D>(a: &Array<T, D>, b: &Array<T, D>) -> bool
85where
86 T: Element + PartialEq,
87 D: Dimension,
88{
89 array_equal(a, b)
91}
92
93pub fn allclose<T, D>(a: &Array<T, D>, b: &Array<T, D>, rtol: T, atol: T) -> FerrayResult<bool>
97where
98 T: Element + Float,
99 D: Dimension,
100{
101 let close = isclose(a, b, rtol, atol, false)?;
102 Ok(close.iter().all(|&x| x))
103}
104
105pub fn isclose<T, D>(
111 a: &Array<T, D>,
112 b: &Array<T, D>,
113 rtol: T,
114 atol: T,
115 equal_nan: bool,
116) -> FerrayResult<Array<bool, D>>
117where
118 T: Element + Float,
119 D: Dimension,
120{
121 binary_map_op(a, b, |x, y| {
122 if equal_nan && x.is_nan() && y.is_nan() {
123 return true;
124 }
125 if x.is_nan() || y.is_nan() {
126 return false;
127 }
128 (x - y).abs() <= atol + rtol * y.abs()
129 })
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    /// Builds a 1-D `f64` array holding `data`.
    fn arr1(data: Vec<f64>) -> Array<f64, Ix1> {
        Array::from_vec(Ix1::new([data.len()]), data).unwrap()
    }

    /// Builds a 1-D `i32` array holding `data`.
    fn arr1_i32(data: Vec<i32>) -> Array<i32, Ix1> {
        Array::from_vec(Ix1::new([data.len()]), data).unwrap()
    }

    /// Integer operands for the equality tests; they differ only at index 1.
    fn int_pair() -> (Array<i32, Ix1>, Array<i32, Ix1>) {
        (arr1_i32(vec![1, 2, 3]), arr1_i32(vec![1, 5, 3]))
    }

    /// Float operands for the ordering tests.
    /// At index 0 `a < b`, at index 1 `a > b`, and at index 2 `a == b`.
    fn ordering_pair() -> (Array<f64, Ix1>, Array<f64, Ix1>) {
        (arr1(vec![1.0, 5.0, 3.0]), arr1(vec![2.0, 3.0, 3.0]))
    }

    #[test]
    fn test_equal() {
        let (a, b) = int_pair();
        let mask = equal(&a, &b).unwrap();
        assert_eq!(mask.as_slice().unwrap(), &[true, false, true]);
    }

    #[test]
    fn test_not_equal() {
        let (a, b) = int_pair();
        let mask = not_equal(&a, &b).unwrap();
        assert_eq!(mask.as_slice().unwrap(), &[false, true, false]);
    }

    #[test]
    fn test_less() {
        let (a, b) = ordering_pair();
        let mask = less(&a, &b).unwrap();
        assert_eq!(mask.as_slice().unwrap(), &[true, false, false]);
    }

    #[test]
    fn test_less_equal() {
        let (a, b) = ordering_pair();
        let mask = less_equal(&a, &b).unwrap();
        assert_eq!(mask.as_slice().unwrap(), &[true, false, true]);
    }

    #[test]
    fn test_greater() {
        let (a, b) = ordering_pair();
        let mask = greater(&a, &b).unwrap();
        assert_eq!(mask.as_slice().unwrap(), &[false, true, false]);
    }

    #[test]
    fn test_greater_equal() {
        let (a, b) = ordering_pair();
        let mask = greater_equal(&a, &b).unwrap();
        assert_eq!(mask.as_slice().unwrap(), &[false, true, true]);
    }

    #[test]
    fn test_array_equal() {
        let base = arr1(vec![1.0, 2.0, 3.0]);
        let same = arr1(vec![1.0, 2.0, 3.0]);
        let last_differs = arr1(vec![1.0, 2.0, 4.0]);
        assert!(array_equal(&base, &same));
        assert!(!array_equal(&base, &last_differs));
    }

    #[test]
    fn test_array_equal_different_shapes() {
        let short = arr1(vec![1.0, 2.0]);
        let long = arr1(vec![1.0, 2.0, 3.0]);
        assert!(!array_equal(&short, &long));
    }

    #[test]
    fn test_allclose() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let b = arr1([1.0, 2.0, 3.0].iter().map(|v| v + 1e-9).collect());
        assert!(allclose(&a, &b, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_allclose_not_close() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let b = arr1(vec![1.0, 2.0, 4.0]);
        assert!(!allclose(&a, &b, 1e-5, 1e-8).unwrap());
    }

    #[test]
    fn test_isclose() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let b = arr1(vec![1.0, 2.1, 3.0]);
        let mask = isclose(&a, &b, 1e-5, 1e-8, false).unwrap();
        assert_eq!(mask.as_slice().unwrap(), &[true, false, true]);
    }

    #[test]
    fn test_isclose_equal_nan() {
        let a = arr1(vec![f64::NAN, 1.0]);
        let b = arr1(vec![f64::NAN, 1.0]);
        let mask = isclose(&a, &b, 1e-5, 1e-8, true).unwrap();
        assert_eq!(mask.as_slice().unwrap(), &[true, true]);
    }
}