Skip to main content

ferray_ma/
arithmetic.rs

1// ferray-ma: Masked binary ops with mask union (REQ-10, REQ-11)
2
3use std::ops::{Add, Div, Mul, Sub};
4
5use ferray_core::Array;
6use ferray_core::dimension::Dimension;
7use ferray_core::dtype::Element;
8use ferray_core::error::{FerrayError, FerrayResult};
9
10use crate::MaskedArray;
11
12/// Helper: compute the union of two boolean mask arrays (element-wise OR).
13fn mask_union<D: Dimension>(
14    a: &Array<bool, D>,
15    b: &Array<bool, D>,
16) -> FerrayResult<Array<bool, D>> {
17    if a.shape() != b.shape() {
18        return Err(FerrayError::shape_mismatch(format!(
19            "mask_union: shapes {:?} and {:?} differ",
20            a.shape(),
21            b.shape()
22        )));
23    }
24    let data: Vec<bool> = a.iter().zip(b.iter()).map(|(x, y)| *x || *y).collect();
25    Array::from_vec(a.dim().clone(), data)
26}
27
28/// Add two masked arrays elementwise, propagating the mask union.
29///
30/// # Errors
31/// Returns `FerrayError::ShapeMismatch` if shapes differ.
32pub fn masked_add<T, D>(
33    a: &MaskedArray<T, D>,
34    b: &MaskedArray<T, D>,
35) -> FerrayResult<MaskedArray<T, D>>
36where
37    T: Element + Add<Output = T> + Copy,
38    D: Dimension,
39{
40    if a.shape() != b.shape() {
41        return Err(FerrayError::shape_mismatch(format!(
42            "masked_add: shapes {:?} and {:?} differ",
43            a.shape(),
44            b.shape()
45        )));
46    }
47    let result_mask = mask_union(a.mask(), b.mask())?;
48    let data: Vec<T> = a
49        .data()
50        .iter()
51        .zip(b.data().iter())
52        .zip(result_mask.iter())
53        .map(|((x, y), m)| if *m { T::zero() } else { *x + *y })
54        .collect();
55    let result_data = Array::from_vec(a.dim().clone(), data)?;
56    MaskedArray::new(result_data, result_mask)
57}
58
59/// Subtract two masked arrays elementwise, propagating the mask union.
60///
61/// # Errors
62/// Returns `FerrayError::ShapeMismatch` if shapes differ.
63pub fn masked_sub<T, D>(
64    a: &MaskedArray<T, D>,
65    b: &MaskedArray<T, D>,
66) -> FerrayResult<MaskedArray<T, D>>
67where
68    T: Element + Sub<Output = T> + Copy,
69    D: Dimension,
70{
71    if a.shape() != b.shape() {
72        return Err(FerrayError::shape_mismatch(format!(
73            "masked_sub: shapes {:?} and {:?} differ",
74            a.shape(),
75            b.shape()
76        )));
77    }
78    let result_mask = mask_union(a.mask(), b.mask())?;
79    let data: Vec<T> = a
80        .data()
81        .iter()
82        .zip(b.data().iter())
83        .zip(result_mask.iter())
84        .map(|((x, y), m)| if *m { T::zero() } else { *x - *y })
85        .collect();
86    let result_data = Array::from_vec(a.dim().clone(), data)?;
87    MaskedArray::new(result_data, result_mask)
88}
89
90/// Multiply two masked arrays elementwise, propagating the mask union.
91///
92/// # Errors
93/// Returns `FerrayError::ShapeMismatch` if shapes differ.
94pub fn masked_mul<T, D>(
95    a: &MaskedArray<T, D>,
96    b: &MaskedArray<T, D>,
97) -> FerrayResult<MaskedArray<T, D>>
98where
99    T: Element + Mul<Output = T> + Copy,
100    D: Dimension,
101{
102    if a.shape() != b.shape() {
103        return Err(FerrayError::shape_mismatch(format!(
104            "masked_mul: shapes {:?} and {:?} differ",
105            a.shape(),
106            b.shape()
107        )));
108    }
109    let result_mask = mask_union(a.mask(), b.mask())?;
110    let data: Vec<T> = a
111        .data()
112        .iter()
113        .zip(b.data().iter())
114        .zip(result_mask.iter())
115        .map(|((x, y), m)| if *m { T::zero() } else { *x * *y })
116        .collect();
117    let result_data = Array::from_vec(a.dim().clone(), data)?;
118    MaskedArray::new(result_data, result_mask)
119}
120
121/// Divide two masked arrays elementwise, propagating the mask union.
122///
123/// # Errors
124/// Returns `FerrayError::ShapeMismatch` if shapes differ.
125pub fn masked_div<T, D>(
126    a: &MaskedArray<T, D>,
127    b: &MaskedArray<T, D>,
128) -> FerrayResult<MaskedArray<T, D>>
129where
130    T: Element + Div<Output = T> + Copy,
131    D: Dimension,
132{
133    if a.shape() != b.shape() {
134        return Err(FerrayError::shape_mismatch(format!(
135            "masked_div: shapes {:?} and {:?} differ",
136            a.shape(),
137            b.shape()
138        )));
139    }
140    let result_mask = mask_union(a.mask(), b.mask())?;
141    let data: Vec<T> = a
142        .data()
143        .iter()
144        .zip(b.data().iter())
145        .zip(result_mask.iter())
146        .map(|((x, y), m)| if *m { T::zero() } else { *x / *y })
147        .collect();
148    let result_data = Array::from_vec(a.dim().clone(), data)?;
149    MaskedArray::new(result_data, result_mask)
150}
151
152/// Add a masked array and a regular array, treating the regular array as unmasked.
153///
154/// # Errors
155/// Returns `FerrayError::ShapeMismatch` if shapes differ.
156pub fn masked_add_array<T, D>(
157    ma: &MaskedArray<T, D>,
158    arr: &Array<T, D>,
159) -> FerrayResult<MaskedArray<T, D>>
160where
161    T: Element + Add<Output = T> + Copy,
162    D: Dimension,
163{
164    if ma.shape() != arr.shape() {
165        return Err(FerrayError::shape_mismatch(format!(
166            "masked_add_array: shapes {:?} and {:?} differ",
167            ma.shape(),
168            arr.shape()
169        )));
170    }
171    let data: Vec<T> = ma
172        .data()
173        .iter()
174        .zip(arr.iter())
175        .zip(ma.mask().iter())
176        .map(|((x, y), m)| if *m { T::zero() } else { *x + *y })
177        .collect();
178    let result_data = Array::from_vec(ma.dim().clone(), data)?;
179    MaskedArray::new(result_data, ma.mask().clone())
180}
181
182/// Subtract a regular array from a masked array, treating the regular array as unmasked.
183///
184/// # Errors
185/// Returns `FerrayError::ShapeMismatch` if shapes differ.
186pub fn masked_sub_array<T, D>(
187    ma: &MaskedArray<T, D>,
188    arr: &Array<T, D>,
189) -> FerrayResult<MaskedArray<T, D>>
190where
191    T: Element + Sub<Output = T> + Copy,
192    D: Dimension,
193{
194    if ma.shape() != arr.shape() {
195        return Err(FerrayError::shape_mismatch(format!(
196            "masked_sub_array: shapes {:?} and {:?} differ",
197            ma.shape(),
198            arr.shape()
199        )));
200    }
201    let data: Vec<T> = ma
202        .data()
203        .iter()
204        .zip(arr.iter())
205        .zip(ma.mask().iter())
206        .map(|((x, y), m)| if *m { T::zero() } else { *x - *y })
207        .collect();
208    let result_data = Array::from_vec(ma.dim().clone(), data)?;
209    MaskedArray::new(result_data, ma.mask().clone())
210}
211
212/// Multiply a masked array and a regular array, treating the regular array as unmasked.
213///
214/// # Errors
215/// Returns `FerrayError::ShapeMismatch` if shapes differ.
216pub fn masked_mul_array<T, D>(
217    ma: &MaskedArray<T, D>,
218    arr: &Array<T, D>,
219) -> FerrayResult<MaskedArray<T, D>>
220where
221    T: Element + Mul<Output = T> + Copy,
222    D: Dimension,
223{
224    if ma.shape() != arr.shape() {
225        return Err(FerrayError::shape_mismatch(format!(
226            "masked_mul_array: shapes {:?} and {:?} differ",
227            ma.shape(),
228            arr.shape()
229        )));
230    }
231    let data: Vec<T> = ma
232        .data()
233        .iter()
234        .zip(arr.iter())
235        .zip(ma.mask().iter())
236        .map(|((x, y), m)| if *m { T::zero() } else { *x * *y })
237        .collect();
238    let result_data = Array::from_vec(ma.dim().clone(), data)?;
239    MaskedArray::new(result_data, ma.mask().clone())
240}
241
242/// Divide a masked array by a regular array, treating the regular array as unmasked.
243///
244/// # Errors
245/// Returns `FerrayError::ShapeMismatch` if shapes differ.
246pub fn masked_div_array<T, D>(
247    ma: &MaskedArray<T, D>,
248    arr: &Array<T, D>,
249) -> FerrayResult<MaskedArray<T, D>>
250where
251    T: Element + Div<Output = T> + Copy,
252    D: Dimension,
253{
254    if ma.shape() != arr.shape() {
255        return Err(FerrayError::shape_mismatch(format!(
256            "masked_div_array: shapes {:?} and {:?} differ",
257            ma.shape(),
258            arr.shape()
259        )));
260    }
261    let data: Vec<T> = ma
262        .data()
263        .iter()
264        .zip(arr.iter())
265        .zip(ma.mask().iter())
266        .map(|((x, y), m)| if *m { T::zero() } else { *x / *y })
267        .collect();
268    let result_data = Array::from_vec(ma.dim().clone(), data)?;
269    MaskedArray::new(result_data, ma.mask().clone())
270}