1use std::ops::{Add, Div, Mul, Sub};
4
5use ferray_core::Array;
6use ferray_core::dimension::Dimension;
7use ferray_core::dtype::Element;
8use ferray_core::error::{FerrayError, FerrayResult};
9
10use crate::MaskedArray;
11
12fn mask_union<D: Dimension>(
14 a: &Array<bool, D>,
15 b: &Array<bool, D>,
16) -> FerrayResult<Array<bool, D>> {
17 if a.shape() != b.shape() {
18 return Err(FerrayError::shape_mismatch(format!(
19 "mask_union: shapes {:?} and {:?} differ",
20 a.shape(),
21 b.shape()
22 )));
23 }
24 let data: Vec<bool> = a.iter().zip(b.iter()).map(|(x, y)| *x || *y).collect();
25 Array::from_vec(a.dim().clone(), data)
26}
27
28pub fn masked_add<T, D>(
33 a: &MaskedArray<T, D>,
34 b: &MaskedArray<T, D>,
35) -> FerrayResult<MaskedArray<T, D>>
36where
37 T: Element + Add<Output = T> + Copy,
38 D: Dimension,
39{
40 if a.shape() != b.shape() {
41 return Err(FerrayError::shape_mismatch(format!(
42 "masked_add: shapes {:?} and {:?} differ",
43 a.shape(),
44 b.shape()
45 )));
46 }
47 let result_mask = mask_union(a.mask(), b.mask())?;
48 let data: Vec<T> = a
49 .data()
50 .iter()
51 .zip(b.data().iter())
52 .zip(result_mask.iter())
53 .map(|((x, y), m)| if *m { T::zero() } else { *x + *y })
54 .collect();
55 let result_data = Array::from_vec(a.dim().clone(), data)?;
56 MaskedArray::new(result_data, result_mask)
57}
58
59pub fn masked_sub<T, D>(
64 a: &MaskedArray<T, D>,
65 b: &MaskedArray<T, D>,
66) -> FerrayResult<MaskedArray<T, D>>
67where
68 T: Element + Sub<Output = T> + Copy,
69 D: Dimension,
70{
71 if a.shape() != b.shape() {
72 return Err(FerrayError::shape_mismatch(format!(
73 "masked_sub: shapes {:?} and {:?} differ",
74 a.shape(),
75 b.shape()
76 )));
77 }
78 let result_mask = mask_union(a.mask(), b.mask())?;
79 let data: Vec<T> = a
80 .data()
81 .iter()
82 .zip(b.data().iter())
83 .zip(result_mask.iter())
84 .map(|((x, y), m)| if *m { T::zero() } else { *x - *y })
85 .collect();
86 let result_data = Array::from_vec(a.dim().clone(), data)?;
87 MaskedArray::new(result_data, result_mask)
88}
89
90pub fn masked_mul<T, D>(
95 a: &MaskedArray<T, D>,
96 b: &MaskedArray<T, D>,
97) -> FerrayResult<MaskedArray<T, D>>
98where
99 T: Element + Mul<Output = T> + Copy,
100 D: Dimension,
101{
102 if a.shape() != b.shape() {
103 return Err(FerrayError::shape_mismatch(format!(
104 "masked_mul: shapes {:?} and {:?} differ",
105 a.shape(),
106 b.shape()
107 )));
108 }
109 let result_mask = mask_union(a.mask(), b.mask())?;
110 let data: Vec<T> = a
111 .data()
112 .iter()
113 .zip(b.data().iter())
114 .zip(result_mask.iter())
115 .map(|((x, y), m)| if *m { T::zero() } else { *x * *y })
116 .collect();
117 let result_data = Array::from_vec(a.dim().clone(), data)?;
118 MaskedArray::new(result_data, result_mask)
119}
120
121pub fn masked_div<T, D>(
126 a: &MaskedArray<T, D>,
127 b: &MaskedArray<T, D>,
128) -> FerrayResult<MaskedArray<T, D>>
129where
130 T: Element + Div<Output = T> + Copy,
131 D: Dimension,
132{
133 if a.shape() != b.shape() {
134 return Err(FerrayError::shape_mismatch(format!(
135 "masked_div: shapes {:?} and {:?} differ",
136 a.shape(),
137 b.shape()
138 )));
139 }
140 let result_mask = mask_union(a.mask(), b.mask())?;
141 let data: Vec<T> = a
142 .data()
143 .iter()
144 .zip(b.data().iter())
145 .zip(result_mask.iter())
146 .map(|((x, y), m)| if *m { T::zero() } else { *x / *y })
147 .collect();
148 let result_data = Array::from_vec(a.dim().clone(), data)?;
149 MaskedArray::new(result_data, result_mask)
150}
151
152pub fn masked_add_array<T, D>(
157 ma: &MaskedArray<T, D>,
158 arr: &Array<T, D>,
159) -> FerrayResult<MaskedArray<T, D>>
160where
161 T: Element + Add<Output = T> + Copy,
162 D: Dimension,
163{
164 if ma.shape() != arr.shape() {
165 return Err(FerrayError::shape_mismatch(format!(
166 "masked_add_array: shapes {:?} and {:?} differ",
167 ma.shape(),
168 arr.shape()
169 )));
170 }
171 let data: Vec<T> = ma
172 .data()
173 .iter()
174 .zip(arr.iter())
175 .zip(ma.mask().iter())
176 .map(|((x, y), m)| if *m { T::zero() } else { *x + *y })
177 .collect();
178 let result_data = Array::from_vec(ma.dim().clone(), data)?;
179 MaskedArray::new(result_data, ma.mask().clone())
180}
181
182pub fn masked_sub_array<T, D>(
187 ma: &MaskedArray<T, D>,
188 arr: &Array<T, D>,
189) -> FerrayResult<MaskedArray<T, D>>
190where
191 T: Element + Sub<Output = T> + Copy,
192 D: Dimension,
193{
194 if ma.shape() != arr.shape() {
195 return Err(FerrayError::shape_mismatch(format!(
196 "masked_sub_array: shapes {:?} and {:?} differ",
197 ma.shape(),
198 arr.shape()
199 )));
200 }
201 let data: Vec<T> = ma
202 .data()
203 .iter()
204 .zip(arr.iter())
205 .zip(ma.mask().iter())
206 .map(|((x, y), m)| if *m { T::zero() } else { *x - *y })
207 .collect();
208 let result_data = Array::from_vec(ma.dim().clone(), data)?;
209 MaskedArray::new(result_data, ma.mask().clone())
210}
211
212pub fn masked_mul_array<T, D>(
217 ma: &MaskedArray<T, D>,
218 arr: &Array<T, D>,
219) -> FerrayResult<MaskedArray<T, D>>
220where
221 T: Element + Mul<Output = T> + Copy,
222 D: Dimension,
223{
224 if ma.shape() != arr.shape() {
225 return Err(FerrayError::shape_mismatch(format!(
226 "masked_mul_array: shapes {:?} and {:?} differ",
227 ma.shape(),
228 arr.shape()
229 )));
230 }
231 let data: Vec<T> = ma
232 .data()
233 .iter()
234 .zip(arr.iter())
235 .zip(ma.mask().iter())
236 .map(|((x, y), m)| if *m { T::zero() } else { *x * *y })
237 .collect();
238 let result_data = Array::from_vec(ma.dim().clone(), data)?;
239 MaskedArray::new(result_data, ma.mask().clone())
240}
241
242pub fn masked_div_array<T, D>(
247 ma: &MaskedArray<T, D>,
248 arr: &Array<T, D>,
249) -> FerrayResult<MaskedArray<T, D>>
250where
251 T: Element + Div<Output = T> + Copy,
252 D: Dimension,
253{
254 if ma.shape() != arr.shape() {
255 return Err(FerrayError::shape_mismatch(format!(
256 "masked_div_array: shapes {:?} and {:?} differ",
257 ma.shape(),
258 arr.shape()
259 )));
260 }
261 let data: Vec<T> = ma
262 .data()
263 .iter()
264 .zip(arr.iter())
265 .zip(ma.mask().iter())
266 .map(|((x, y), m)| if *m { T::zero() } else { *x / *y })
267 .collect();
268 let result_data = Array::from_vec(ma.dim().clone(), data)?;
269 MaskedArray::new(result_data, ma.mask().clone())
270}