1use std::ops::{Add, Div, Mul, Sub};
10
11use ferray_core::Array;
12use ferray_core::dimension::Dimension;
13use ferray_core::dimension::broadcast::{broadcast_shapes, broadcast_to};
14use ferray_core::dtype::Element;
15use ferray_core::error::{FerrayError, FerrayResult};
16
17use crate::MaskedArray;
18
19fn mask_union<D: Dimension>(
24 a: &Array<bool, D>,
25 b: &Array<bool, D>,
26) -> FerrayResult<Array<bool, D>> {
27 if a.shape() != b.shape() {
28 return Err(FerrayError::shape_mismatch(format!(
29 "mask_union: shapes {:?} and {:?} differ",
30 a.shape(),
31 b.shape()
32 )));
33 }
34 let data: Vec<bool> = a.iter().zip(b.iter()).map(|(x, y)| *x || *y).collect();
35 Array::from_vec(a.dim().clone(), data)
36}
37
/// Flattened data and mask buffers for both operands of a binary masked
/// operation, after broadcasting.
///
/// `broadcast_masked_pair` fills all four vectors from views broadcast to the
/// same target shape, so index `i` in each vector refers to the same output
/// position.
struct BroadcastedPair<T> {
    /// Left operand values after broadcasting.
    a_data: Vec<T>,
    /// Right operand values after broadcasting.
    b_data: Vec<T>,
    /// Left operand mask after broadcasting (`true` = masked).
    a_mask: Vec<bool>,
    /// Right operand mask after broadcasting (`true` = masked).
    b_mask: Vec<bool>,
}
49
50fn broadcast_masked_pair<T, D>(
59 a: &MaskedArray<T, D>,
60 b: &MaskedArray<T, D>,
61 op_name: &str,
62) -> FerrayResult<(BroadcastedPair<T>, D)>
63where
64 T: Element + Copy,
65 D: Dimension,
66{
67 let target_shape = broadcast_shapes(a.shape(), b.shape()).map_err(|_| {
68 FerrayError::shape_mismatch(format!(
69 "{}: shapes {:?} and {:?} are not broadcast-compatible",
70 op_name,
71 a.shape(),
72 b.shape()
73 ))
74 })?;
75
76 let a_data_view = broadcast_to(a.data(), &target_shape)?;
77 let a_mask_view = broadcast_to(a.mask(), &target_shape)?;
78 let b_data_view = broadcast_to(b.data(), &target_shape)?;
79 let b_mask_view = broadcast_to(b.mask(), &target_shape)?;
80
81 let pair = BroadcastedPair {
82 a_data: a_data_view.iter().copied().collect(),
83 a_mask: a_mask_view.iter().copied().collect(),
84 b_data: b_data_view.iter().copied().collect(),
85 b_mask: b_mask_view.iter().copied().collect(),
86 };
87
88 let result_dim = D::from_dim_slice(&target_shape).ok_or_else(|| {
89 FerrayError::shape_mismatch(format!(
90 "{op_name}: cannot represent broadcast result shape {target_shape:?} as the input dimension type"
91 ))
92 })?;
93
94 Ok((pair, result_dim))
95}
96
97pub(crate) fn masked_unary_op<T, D, F>(
102 ma: &MaskedArray<T, D>,
103 f: F,
104) -> FerrayResult<MaskedArray<T, D>>
105where
106 T: Element + Copy,
107 D: Dimension,
108 F: Fn(T) -> T,
109{
110 let fill = ma.fill_value;
111
112 if !ma.has_real_mask() {
118 let data: Vec<T> = ma.data().iter().map(|&v| f(v)).collect();
119 let result_data = Array::from_vec(ma.dim().clone(), data)?;
120 let mut out = MaskedArray::from_data(result_data)?;
121 out.fill_value = fill;
122 return Ok(out);
123 }
124
125 let data: Vec<T> = ma
128 .data()
129 .iter()
130 .zip(ma.mask().iter())
131 .map(|(v, m)| if *m { fill } else { f(*v) })
132 .collect();
133 let result_data = Array::from_vec(ma.dim().clone(), data)?;
134 let mut out = MaskedArray::new(result_data, ma.mask().clone())?;
135 out.fill_value = fill;
136 Ok(out)
137}
138
139pub(crate) fn masked_binary_op<T, D, F>(
147 a: &MaskedArray<T, D>,
148 b: &MaskedArray<T, D>,
149 op: F,
150 op_name: &str,
151) -> FerrayResult<MaskedArray<T, D>>
152where
153 T: Element + Copy,
154 D: Dimension,
155 F: Fn(T, T) -> T,
156{
157 if a.shape() == b.shape() {
159 let fill = a.fill_value;
160
161 if !a.has_real_mask() && !b.has_real_mask() {
166 let data: Vec<T> = a
167 .data()
168 .iter()
169 .zip(b.data().iter())
170 .map(|(&x, &y)| op(x, y))
171 .collect();
172 let result_data = Array::from_vec(a.dim().clone(), data)?;
173 let mut result = MaskedArray::from_data(result_data)?;
174 result.fill_value = fill;
175 return Ok(result);
176 }
177
178 let result_mask = mask_union(a.mask(), b.mask())?;
181 let data: Vec<T> = a
182 .data()
183 .iter()
184 .zip(b.data().iter())
185 .zip(result_mask.iter())
186 .map(|((x, y), m)| if *m { fill } else { op(*x, *y) })
187 .collect();
188 let result_data = Array::from_vec(a.dim().clone(), data)?;
189 let mut result = MaskedArray::new(result_data, result_mask)?;
190 result.fill_value = fill;
191 return Ok(result);
192 }
193
194 let (pair, result_dim) = broadcast_masked_pair(a, b, op_name)?;
196 let fill = a.fill_value;
197 let n = pair.a_data.len();
198 let mut result_data = Vec::with_capacity(n);
199 let mut result_mask = Vec::with_capacity(n);
200 for i in 0..n {
201 let m = pair.a_mask[i] || pair.b_mask[i];
202 result_mask.push(m);
203 result_data.push(if m {
204 fill
205 } else {
206 op(pair.a_data[i], pair.b_data[i])
207 });
208 }
209 let data_arr = Array::from_vec(result_dim.clone(), result_data)?;
210 let mask_arr = Array::from_vec(result_dim, result_mask)?;
211 let mut out = MaskedArray::new(data_arr, mask_arr)?;
212 out.fill_value = fill;
213 Ok(out)
214}
215
216pub fn masked_add<T, D>(
224 a: &MaskedArray<T, D>,
225 b: &MaskedArray<T, D>,
226) -> FerrayResult<MaskedArray<T, D>>
227where
228 T: Element + Add<Output = T> + Copy,
229 D: Dimension,
230{
231 masked_binary_op(a, b, |x, y| x + y, "masked_add")
232}
233
234pub fn masked_sub<T, D>(
236 a: &MaskedArray<T, D>,
237 b: &MaskedArray<T, D>,
238) -> FerrayResult<MaskedArray<T, D>>
239where
240 T: Element + Sub<Output = T> + Copy,
241 D: Dimension,
242{
243 masked_binary_op(a, b, |x, y| x - y, "masked_sub")
244}
245
246pub fn masked_mul<T, D>(
248 a: &MaskedArray<T, D>,
249 b: &MaskedArray<T, D>,
250) -> FerrayResult<MaskedArray<T, D>>
251where
252 T: Element + Mul<Output = T> + Copy,
253 D: Dimension,
254{
255 masked_binary_op(a, b, |x, y| x * y, "masked_mul")
256}
257
258pub fn masked_div<T, D>(
260 a: &MaskedArray<T, D>,
261 b: &MaskedArray<T, D>,
262) -> FerrayResult<MaskedArray<T, D>>
263where
264 T: Element + Div<Output = T> + Copy,
265 D: Dimension,
266{
267 masked_binary_op(a, b, |x, y| x / y, "masked_div")
268}
269
270fn masked_array_op<T, D, F>(
276 ma: &MaskedArray<T, D>,
277 arr: &Array<T, D>,
278 op: F,
279 op_name: &str,
280) -> FerrayResult<MaskedArray<T, D>>
281where
282 T: Element + Copy,
283 D: Dimension,
284 F: Fn(T, T) -> T,
285{
286 let fill = ma.fill_value;
287
288 if ma.shape() == arr.shape() {
290 let data: Vec<T> = ma
291 .data()
292 .iter()
293 .zip(arr.iter())
294 .zip(ma.mask().iter())
295 .map(|((x, y), m)| if *m { fill } else { op(*x, *y) })
296 .collect();
297 let result_data = Array::from_vec(ma.dim().clone(), data)?;
298 let mut out = MaskedArray::new(result_data, ma.mask().clone())?;
299 out.fill_value = fill;
300 return Ok(out);
301 }
302
303 let target_shape = broadcast_shapes(ma.shape(), arr.shape()).map_err(|_| {
305 FerrayError::shape_mismatch(format!(
306 "{}: shapes {:?} and {:?} are not broadcast-compatible",
307 op_name,
308 ma.shape(),
309 arr.shape()
310 ))
311 })?;
312 let ma_data_view = broadcast_to(ma.data(), &target_shape)?;
313 let ma_mask_view = broadcast_to(ma.mask(), &target_shape)?;
314 let arr_view = broadcast_to(arr, &target_shape)?;
315
316 let ma_data: Vec<T> = ma_data_view.iter().copied().collect();
317 let ma_mask: Vec<bool> = ma_mask_view.iter().copied().collect();
318 let arr_data: Vec<T> = arr_view.iter().copied().collect();
319
320 let n = ma_data.len();
321 let mut result_data = Vec::with_capacity(n);
322 let mut result_mask = Vec::with_capacity(n);
323 for i in 0..n {
324 let m = ma_mask[i];
325 result_mask.push(m);
326 result_data.push(if m { fill } else { op(ma_data[i], arr_data[i]) });
327 }
328 let result_dim = D::from_dim_slice(&target_shape).ok_or_else(|| {
329 FerrayError::shape_mismatch(format!(
330 "{op_name}: cannot represent broadcast result shape {target_shape:?} as the input dimension type"
331 ))
332 })?;
333 let data_arr = Array::from_vec(result_dim.clone(), result_data)?;
334 let mask_arr = Array::from_vec(result_dim, result_mask)?;
335 let mut out = MaskedArray::new(data_arr, mask_arr)?;
336 out.fill_value = fill;
337 Ok(out)
338}
339
340pub fn masked_add_array<T, D>(
343 ma: &MaskedArray<T, D>,
344 arr: &Array<T, D>,
345) -> FerrayResult<MaskedArray<T, D>>
346where
347 T: Element + Add<Output = T> + Copy,
348 D: Dimension,
349{
350 masked_array_op(ma, arr, |x, y| x + y, "masked_add_array")
351}
352
353pub fn masked_sub_array<T, D>(
356 ma: &MaskedArray<T, D>,
357 arr: &Array<T, D>,
358) -> FerrayResult<MaskedArray<T, D>>
359where
360 T: Element + Sub<Output = T> + Copy,
361 D: Dimension,
362{
363 masked_array_op(ma, arr, |x, y| x - y, "masked_sub_array")
364}
365
366pub fn masked_mul_array<T, D>(
369 ma: &MaskedArray<T, D>,
370 arr: &Array<T, D>,
371) -> FerrayResult<MaskedArray<T, D>>
372where
373 T: Element + Mul<Output = T> + Copy,
374 D: Dimension,
375{
376 masked_array_op(ma, arr, |x, y| x * y, "masked_mul_array")
377}
378
379pub fn masked_div_array<T, D>(
382 ma: &MaskedArray<T, D>,
383 arr: &Array<T, D>,
384) -> FerrayResult<MaskedArray<T, D>>
385where
386 T: Element + Div<Output = T> + Copy,
387 D: Dimension,
388{
389 masked_array_op(ma, arr, |x, y| x / y, "masked_div_array")
390}
391
392impl<T, D> std::ops::Add<&MaskedArray<T, D>> for &MaskedArray<T, D>
398where
399 T: Element + Add<Output = T> + Copy,
400 D: Dimension,
401{
402 type Output = FerrayResult<MaskedArray<T, D>>;
403
404 fn add(self, rhs: &MaskedArray<T, D>) -> Self::Output {
405 masked_add(self, rhs)
406 }
407}
408
409impl<T, D> std::ops::Sub<&MaskedArray<T, D>> for &MaskedArray<T, D>
411where
412 T: Element + Sub<Output = T> + Copy,
413 D: Dimension,
414{
415 type Output = FerrayResult<MaskedArray<T, D>>;
416
417 fn sub(self, rhs: &MaskedArray<T, D>) -> Self::Output {
418 masked_sub(self, rhs)
419 }
420}
421
422impl<T, D> std::ops::Mul<&MaskedArray<T, D>> for &MaskedArray<T, D>
424where
425 T: Element + Mul<Output = T> + Copy,
426 D: Dimension,
427{
428 type Output = FerrayResult<MaskedArray<T, D>>;
429
430 fn mul(self, rhs: &MaskedArray<T, D>) -> Self::Output {
431 masked_mul(self, rhs)
432 }
433}
434
435impl<T, D> std::ops::Div<&MaskedArray<T, D>> for &MaskedArray<T, D>
437where
438 T: Element + Div<Output = T> + Copy,
439 D: Dimension,
440{
441 type Output = FerrayResult<MaskedArray<T, D>>;
442
443 fn div(self, rhs: &MaskedArray<T, D>) -> Self::Output {
444 masked_div(self, rhs)
445 }
446}
447
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    /// Builds a 1-D masked `f64` array from parallel value and mask vectors.
    fn ma1d(data: Vec<f64>, mask: Vec<bool>) -> MaskedArray<f64, Ix1> {
        let len = data.len();
        let values = Array::<f64, Ix1>::from_vec(Ix1::new([len]), data).unwrap();
        let flags = Array::<bool, Ix1>::from_vec(Ix1::new([len]), mask).unwrap();
        MaskedArray::new(values, flags).unwrap()
    }

    /// Copies the data buffer of a 1-D masked array into a `Vec`.
    fn data_of(ma: &MaskedArray<f64, Ix1>) -> Vec<f64> {
        ma.data().iter().copied().collect()
    }

    /// Copies the mask of a 1-D masked array into a `Vec`.
    fn mask_of(ma: &MaskedArray<f64, Ix1>) -> Vec<bool> {
        ma.mask().iter().copied().collect()
    }

    #[test]
    fn masked_div_positive_by_zero_yields_positive_infinity_unmasked() {
        let numer = ma1d(vec![1.0, 2.0, 3.0], vec![false; 3]);
        let denom = ma1d(vec![1.0, 0.0, 3.0], vec![false; 3]);
        let quotient = masked_div(&numer, &denom).unwrap();
        let values = data_of(&quotient);
        assert_eq!(values[0], 1.0);
        assert!(values[1].is_infinite() && values[1].is_sign_positive());
        assert_eq!(values[2], 1.0);
        assert_eq!(mask_of(&quotient), vec![false, false, false]);
    }

    #[test]
    fn masked_div_negative_by_zero_yields_negative_infinity_unmasked() {
        let numer = ma1d(vec![-4.0], vec![false]);
        let denom = ma1d(vec![0.0], vec![false]);
        let quotient = masked_div(&numer, &denom).unwrap();
        let value = data_of(&quotient)[0];
        assert!(value.is_infinite() && value.is_sign_negative());
        assert!(!mask_of(&quotient)[0]);
    }

    #[test]
    fn masked_div_zero_by_zero_yields_nan_unmasked() {
        let numer = ma1d(vec![0.0], vec![false]);
        let denom = ma1d(vec![0.0], vec![false]);
        let quotient = masked_div(&numer, &denom).unwrap();
        assert!(data_of(&quotient)[0].is_nan());
        assert!(!mask_of(&quotient)[0]);
    }

    #[test]
    fn masked_div_skips_op_at_masked_divisor_positions() {
        let numer = ma1d(vec![1.0, 2.0, 3.0], vec![false; 3]).with_fill_value(-42.0);
        let denom = ma1d(vec![2.0, 0.0, 4.0], vec![false, true, false]);
        let quotient = masked_div(&numer, &denom).unwrap();
        let values = data_of(&quotient);
        assert_eq!(values, vec![0.5, -42.0, 0.75]);
        assert_eq!(mask_of(&quotient), vec![false, true, false]);
        // The masked slot must hold the fill value, not the 2.0 / 0.0 result.
        assert!(!values[1].is_infinite() && !values[1].is_nan());
    }

    #[test]
    fn masked_div_array_by_zero_yields_infinity_unmasked() {
        let numer = ma1d(vec![5.0, 6.0], vec![false; 2]);
        let divisor = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![0.0, 2.0]).unwrap();
        let quotient = masked_div_array(&numer, &divisor).unwrap();
        let values = data_of(&quotient);
        assert!(values[0].is_infinite() && values[0].is_sign_positive());
        assert_eq!(values[1], 3.0);
        assert_eq!(mask_of(&quotient), vec![false, false]);
    }
}