1use std::ops::{Add, Div, Mul, Sub};
10
11use ferray_core::Array;
12use ferray_core::dimension::Dimension;
13use ferray_core::dimension::broadcast::{broadcast_shapes, broadcast_to};
14use ferray_core::dtype::Element;
15use ferray_core::error::{FerrayError, FerrayResult};
16
17use crate::MaskedArray;
18
19fn mask_union<D: Dimension>(
24 a: &Array<bool, D>,
25 b: &Array<bool, D>,
26) -> FerrayResult<Array<bool, D>> {
27 if a.shape() != b.shape() {
28 return Err(FerrayError::shape_mismatch(format!(
29 "mask_union: shapes {:?} and {:?} differ",
30 a.shape(),
31 b.shape()
32 )));
33 }
34 let data: Vec<bool> = a.iter().zip(b.iter()).map(|(x, y)| *x || *y).collect();
35 Array::from_vec(a.dim().clone(), data)
36}
37
/// Flattened operands of a binary masked operation after broadcasting.
///
/// All four vectors have the same length (the element count of the broadcast
/// shape) and are laid out in the iteration order of the broadcast views.
struct BroadcastedPair<T> {
    /// Left operand's broadcast data.
    a_data: Vec<T>,
    /// Right operand's broadcast data.
    b_data: Vec<T>,
    /// Left operand's broadcast mask (`true` = masked).
    a_mask: Vec<bool>,
    /// Right operand's broadcast mask (`true` = masked).
    b_mask: Vec<bool>,
}
49
50fn broadcast_masked_pair<T, D>(
59 a: &MaskedArray<T, D>,
60 b: &MaskedArray<T, D>,
61 op_name: &str,
62) -> FerrayResult<(BroadcastedPair<T>, D)>
63where
64 T: Element + Copy,
65 D: Dimension,
66{
67 let target_shape = broadcast_shapes(a.shape(), b.shape()).map_err(|_| {
68 FerrayError::shape_mismatch(format!(
69 "{}: shapes {:?} and {:?} are not broadcast-compatible",
70 op_name,
71 a.shape(),
72 b.shape()
73 ))
74 })?;
75
76 let a_data_view = broadcast_to(a.data(), &target_shape)?;
77 let a_mask_view = broadcast_to(a.mask(), &target_shape)?;
78 let b_data_view = broadcast_to(b.data(), &target_shape)?;
79 let b_mask_view = broadcast_to(b.mask(), &target_shape)?;
80
81 let pair = BroadcastedPair {
82 a_data: a_data_view.iter().copied().collect(),
83 a_mask: a_mask_view.iter().copied().collect(),
84 b_data: b_data_view.iter().copied().collect(),
85 b_mask: b_mask_view.iter().copied().collect(),
86 };
87
88 let result_dim = D::from_dim_slice(&target_shape).ok_or_else(|| {
89 FerrayError::shape_mismatch(format!(
90 "{op_name}: cannot represent broadcast result shape {target_shape:?} as the input dimension type"
91 ))
92 })?;
93
94 Ok((pair, result_dim))
95}
96
97pub(crate) fn masked_unary_op<T, D, F>(
102 ma: &MaskedArray<T, D>,
103 f: F,
104) -> FerrayResult<MaskedArray<T, D>>
105where
106 T: Element + Copy,
107 D: Dimension,
108 F: Fn(T) -> T,
109{
110 let fill = ma.fill_value;
111
112 if !ma.has_real_mask() {
118 let data: Vec<T> = ma.data().iter().map(|&v| f(v)).collect();
119 let result_data = Array::from_vec(ma.dim().clone(), data)?;
120 let mut out = MaskedArray::from_data(result_data)?;
121 out.fill_value = fill;
122 return Ok(out);
123 }
124
125 let data_vec: Vec<T> = ma.data().iter().map(|&v| f(v)).collect();
141 let data: Vec<T> = data_vec
142 .into_iter()
143 .zip(ma.mask().iter())
144 .map(|(v, m)| if *m { fill } else { v })
145 .collect();
146 let result_data = Array::from_vec(ma.dim().clone(), data)?;
147 let mut out = MaskedArray::new(result_data, ma.mask().clone())?;
148 out.fill_value = fill;
149 Ok(out)
150}
151
152pub(crate) fn masked_binary_op<T, D, F>(
160 a: &MaskedArray<T, D>,
161 b: &MaskedArray<T, D>,
162 op: F,
163 op_name: &str,
164) -> FerrayResult<MaskedArray<T, D>>
165where
166 T: Element + Copy,
167 D: Dimension,
168 F: Fn(T, T) -> T,
169{
170 if a.shape() == b.shape() {
172 let fill = a.fill_value;
173
174 if !a.has_real_mask() && !b.has_real_mask() {
179 let data: Vec<T> = a
180 .data()
181 .iter()
182 .zip(b.data().iter())
183 .map(|(&x, &y)| op(x, y))
184 .collect();
185 let result_data = Array::from_vec(a.dim().clone(), data)?;
186 let mut result = MaskedArray::from_data(result_data)?;
187 result.fill_value = fill;
188 return Ok(result);
189 }
190
191 let result_mask = mask_union(a.mask(), b.mask())?;
194 let data: Vec<T> = a
195 .data()
196 .iter()
197 .zip(b.data().iter())
198 .zip(result_mask.iter())
199 .map(|((x, y), m)| if *m { fill } else { op(*x, *y) })
200 .collect();
201 let result_data = Array::from_vec(a.dim().clone(), data)?;
202 let mut result = MaskedArray::new(result_data, result_mask)?;
203 result.fill_value = fill;
204 return Ok(result);
205 }
206
207 let (pair, result_dim) = broadcast_masked_pair(a, b, op_name)?;
209 let fill = a.fill_value;
210 let n = pair.a_data.len();
211 let mut result_data = Vec::with_capacity(n);
212 let mut result_mask = Vec::with_capacity(n);
213 for i in 0..n {
214 let m = pair.a_mask[i] || pair.b_mask[i];
215 result_mask.push(m);
216 result_data.push(if m {
217 fill
218 } else {
219 op(pair.a_data[i], pair.b_data[i])
220 });
221 }
222 let data_arr = Array::from_vec(result_dim.clone(), result_data)?;
223 let mask_arr = Array::from_vec(result_dim, result_mask)?;
224 let mut out = MaskedArray::new(data_arr, mask_arr)?;
225 out.fill_value = fill;
226 Ok(out)
227}
228
229pub fn masked_add<T, D>(
237 a: &MaskedArray<T, D>,
238 b: &MaskedArray<T, D>,
239) -> FerrayResult<MaskedArray<T, D>>
240where
241 T: Element + Add<Output = T> + Copy,
242 D: Dimension,
243{
244 masked_binary_op(a, b, |x, y| x + y, "masked_add")
245}
246
247pub fn masked_sub<T, D>(
249 a: &MaskedArray<T, D>,
250 b: &MaskedArray<T, D>,
251) -> FerrayResult<MaskedArray<T, D>>
252where
253 T: Element + Sub<Output = T> + Copy,
254 D: Dimension,
255{
256 masked_binary_op(a, b, |x, y| x - y, "masked_sub")
257}
258
259pub fn masked_mul<T, D>(
261 a: &MaskedArray<T, D>,
262 b: &MaskedArray<T, D>,
263) -> FerrayResult<MaskedArray<T, D>>
264where
265 T: Element + Mul<Output = T> + Copy,
266 D: Dimension,
267{
268 masked_binary_op(a, b, |x, y| x * y, "masked_mul")
269}
270
271pub fn masked_div<T, D>(
273 a: &MaskedArray<T, D>,
274 b: &MaskedArray<T, D>,
275) -> FerrayResult<MaskedArray<T, D>>
276where
277 T: Element + Div<Output = T> + Copy,
278 D: Dimension,
279{
280 masked_binary_op(a, b, |x, y| x / y, "masked_div")
281}
282
283fn masked_array_op<T, D, F>(
289 ma: &MaskedArray<T, D>,
290 arr: &Array<T, D>,
291 op: F,
292 op_name: &str,
293) -> FerrayResult<MaskedArray<T, D>>
294where
295 T: Element + Copy,
296 D: Dimension,
297 F: Fn(T, T) -> T,
298{
299 let fill = ma.fill_value;
300
301 if ma.shape() == arr.shape() {
303 let data: Vec<T> = ma
304 .data()
305 .iter()
306 .zip(arr.iter())
307 .zip(ma.mask().iter())
308 .map(|((x, y), m)| if *m { fill } else { op(*x, *y) })
309 .collect();
310 let result_data = Array::from_vec(ma.dim().clone(), data)?;
311 let mut out = MaskedArray::new(result_data, ma.mask().clone())?;
312 out.fill_value = fill;
313 return Ok(out);
314 }
315
316 let target_shape = broadcast_shapes(ma.shape(), arr.shape()).map_err(|_| {
318 FerrayError::shape_mismatch(format!(
319 "{}: shapes {:?} and {:?} are not broadcast-compatible",
320 op_name,
321 ma.shape(),
322 arr.shape()
323 ))
324 })?;
325 let ma_data_view = broadcast_to(ma.data(), &target_shape)?;
326 let ma_mask_view = broadcast_to(ma.mask(), &target_shape)?;
327 let arr_view = broadcast_to(arr, &target_shape)?;
328
329 let ma_data: Vec<T> = ma_data_view.iter().copied().collect();
330 let ma_mask: Vec<bool> = ma_mask_view.iter().copied().collect();
331 let arr_data: Vec<T> = arr_view.iter().copied().collect();
332
333 let n = ma_data.len();
334 let mut result_data = Vec::with_capacity(n);
335 let mut result_mask = Vec::with_capacity(n);
336 for i in 0..n {
337 let m = ma_mask[i];
338 result_mask.push(m);
339 result_data.push(if m { fill } else { op(ma_data[i], arr_data[i]) });
340 }
341 let result_dim = D::from_dim_slice(&target_shape).ok_or_else(|| {
342 FerrayError::shape_mismatch(format!(
343 "{op_name}: cannot represent broadcast result shape {target_shape:?} as the input dimension type"
344 ))
345 })?;
346 let data_arr = Array::from_vec(result_dim.clone(), result_data)?;
347 let mask_arr = Array::from_vec(result_dim, result_mask)?;
348 let mut out = MaskedArray::new(data_arr, mask_arr)?;
349 out.fill_value = fill;
350 Ok(out)
351}
352
353pub fn masked_add_array<T, D>(
356 ma: &MaskedArray<T, D>,
357 arr: &Array<T, D>,
358) -> FerrayResult<MaskedArray<T, D>>
359where
360 T: Element + Add<Output = T> + Copy,
361 D: Dimension,
362{
363 masked_array_op(ma, arr, |x, y| x + y, "masked_add_array")
364}
365
366pub fn masked_sub_array<T, D>(
369 ma: &MaskedArray<T, D>,
370 arr: &Array<T, D>,
371) -> FerrayResult<MaskedArray<T, D>>
372where
373 T: Element + Sub<Output = T> + Copy,
374 D: Dimension,
375{
376 masked_array_op(ma, arr, |x, y| x - y, "masked_sub_array")
377}
378
379pub fn masked_mul_array<T, D>(
382 ma: &MaskedArray<T, D>,
383 arr: &Array<T, D>,
384) -> FerrayResult<MaskedArray<T, D>>
385where
386 T: Element + Mul<Output = T> + Copy,
387 D: Dimension,
388{
389 masked_array_op(ma, arr, |x, y| x * y, "masked_mul_array")
390}
391
392pub fn masked_div_array<T, D>(
395 ma: &MaskedArray<T, D>,
396 arr: &Array<T, D>,
397) -> FerrayResult<MaskedArray<T, D>>
398where
399 T: Element + Div<Output = T> + Copy,
400 D: Dimension,
401{
402 masked_array_op(ma, arr, |x, y| x / y, "masked_div_array")
403}
404
405impl<T, D> std::ops::Add<&MaskedArray<T, D>> for &MaskedArray<T, D>
411where
412 T: Element + Add<Output = T> + Copy,
413 D: Dimension,
414{
415 type Output = FerrayResult<MaskedArray<T, D>>;
416
417 fn add(self, rhs: &MaskedArray<T, D>) -> Self::Output {
418 masked_add(self, rhs)
419 }
420}
421
422impl<T, D> std::ops::Sub<&MaskedArray<T, D>> for &MaskedArray<T, D>
424where
425 T: Element + Sub<Output = T> + Copy,
426 D: Dimension,
427{
428 type Output = FerrayResult<MaskedArray<T, D>>;
429
430 fn sub(self, rhs: &MaskedArray<T, D>) -> Self::Output {
431 masked_sub(self, rhs)
432 }
433}
434
435impl<T, D> std::ops::Mul<&MaskedArray<T, D>> for &MaskedArray<T, D>
437where
438 T: Element + Mul<Output = T> + Copy,
439 D: Dimension,
440{
441 type Output = FerrayResult<MaskedArray<T, D>>;
442
443 fn mul(self, rhs: &MaskedArray<T, D>) -> Self::Output {
444 masked_mul(self, rhs)
445 }
446}
447
448impl<T, D> std::ops::Div<&MaskedArray<T, D>> for &MaskedArray<T, D>
450where
451 T: Element + Div<Output = T> + Copy,
452 D: Dimension,
453{
454 type Output = FerrayResult<MaskedArray<T, D>>;
455
456 fn div(self, rhs: &MaskedArray<T, D>) -> Self::Output {
457 masked_div(self, rhs)
458 }
459}
460
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    /// Builds a 1-D `f64` masked array from matching data and mask vectors.
    fn ma1d(data: Vec<f64>, mask: Vec<bool>) -> MaskedArray<f64, Ix1> {
        let len = data.len();
        let values = Array::<f64, Ix1>::from_vec(Ix1::new([len]), data).unwrap();
        let flags = Array::<bool, Ix1>::from_vec(Ix1::new([len]), mask).unwrap();
        MaskedArray::new(values, flags).unwrap()
    }

    /// Copies the data of a 1-D masked array into a `Vec`.
    fn data_of(ma: &MaskedArray<f64, Ix1>) -> Vec<f64> {
        ma.data().iter().copied().collect()
    }

    /// Copies the mask of a 1-D masked array into a `Vec`.
    fn mask_of(ma: &MaskedArray<f64, Ix1>) -> Vec<bool> {
        ma.mask().iter().copied().collect()
    }

    // Unmasked float division by zero follows IEEE-754 and is NOT masked.
    #[test]
    fn masked_div_positive_by_zero_yields_positive_infinity_unmasked() {
        let numer = ma1d(vec![1.0, 2.0, 3.0], vec![false; 3]);
        let denom = ma1d(vec![1.0, 0.0, 3.0], vec![false; 3]);
        let quotient = masked_div(&numer, &denom).unwrap();
        let values = data_of(&quotient);
        assert_eq!(values[0], 1.0);
        assert!(values[1].is_infinite() && values[1].is_sign_positive());
        assert_eq!(values[2], 1.0);
        assert_eq!(mask_of(&quotient), vec![false, false, false]);
    }

    #[test]
    fn masked_div_negative_by_zero_yields_negative_infinity_unmasked() {
        let numer = ma1d(vec![-4.0], vec![false]);
        let denom = ma1d(vec![0.0], vec![false]);
        let quotient = masked_div(&numer, &denom).unwrap();
        let value = data_of(&quotient)[0];
        assert!(value.is_infinite() && value.is_sign_negative());
        assert!(!mask_of(&quotient)[0]);
    }

    #[test]
    fn masked_div_zero_by_zero_yields_nan_unmasked() {
        let numer = ma1d(vec![0.0], vec![false]);
        let denom = ma1d(vec![0.0], vec![false]);
        let quotient = masked_div(&numer, &denom).unwrap();
        assert!(data_of(&quotient)[0].is_nan());
        assert!(!mask_of(&quotient)[0]);
    }

    // A masked divisor must short-circuit: the slot gets the left operand's
    // fill value rather than the result of dividing by the masked zero.
    #[test]
    fn masked_div_skips_op_at_masked_divisor_positions() {
        let numer = ma1d(vec![1.0, 2.0, 3.0], vec![false; 3]).with_fill_value(-42.0);
        let denom = ma1d(vec![2.0, 0.0, 4.0], vec![false, true, false]);
        let quotient = masked_div(&numer, &denom).unwrap();
        let values = data_of(&quotient);
        assert_eq!(values, vec![0.5, -42.0, 0.75]);
        assert_eq!(mask_of(&quotient), vec![false, true, false]);
        assert!(!values[1].is_infinite() && !values[1].is_nan());
    }

    #[test]
    fn masked_div_array_by_zero_yields_infinity_unmasked() {
        let numer = ma1d(vec![5.0, 6.0], vec![false; 2]);
        let divisor = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![0.0, 2.0]).unwrap();
        let quotient = masked_div_array(&numer, &divisor).unwrap();
        let values = data_of(&quotient);
        assert!(values[0].is_infinite() && values[0].is_sign_positive());
        assert_eq!(values[1], 3.0);
        assert_eq!(mask_of(&quotient), vec![false, false]);
    }
}
542}