1use ferray_core::Array;
8use ferray_core::dimension::Dimension;
9use ferray_core::dtype::Element;
10use ferray_core::error::FerrayResult;
11use num_traits::Float;
12
13use crate::MaskedArray;
14
15pub fn masked_where<T: Element + Copy, D: Dimension>(
20 condition: &Array<bool, D>,
21 data: &Array<T, D>,
22) -> FerrayResult<MaskedArray<T, D>> {
23 MaskedArray::new(data.clone(), condition.clone())
24}
25
26pub fn masked_invalid<T: Element + Float, D: Dimension>(
35 data: &Array<T, D>,
36) -> FerrayResult<MaskedArray<T, D>> {
37 let mask_data: Vec<bool> = data.iter().map(|v| v.is_nan() || v.is_infinite()).collect();
38 let mask = Array::from_vec(data.dim().clone(), mask_data)?;
39 MaskedArray::new(data.clone(), mask)
40}
41
42pub fn fix_invalid<T: Element + Float, D: Dimension>(
61 data: &Array<T, D>,
62 fill_value: T,
63) -> FerrayResult<MaskedArray<T, D>> {
64 let mut new_data: Vec<T> = Vec::with_capacity(data.size());
65 let mut new_mask: Vec<bool> = Vec::with_capacity(data.size());
66 for &v in data.iter() {
67 if v.is_nan() || v.is_infinite() {
68 new_data.push(fill_value);
69 new_mask.push(true);
70 } else {
71 new_data.push(v);
72 new_mask.push(false);
73 }
74 }
75 let data_arr = Array::from_vec(data.dim().clone(), new_data)?;
76 let mask_arr = Array::from_vec(data.dim().clone(), new_mask)?;
77 let mut out = MaskedArray::new(data_arr, mask_arr)?;
78 out.set_fill_value(fill_value);
79 Ok(out)
80}
81
82pub fn masked_equal<T: Element + PartialEq + Copy, D: Dimension>(
87 data: &Array<T, D>,
88 value: T,
89) -> FerrayResult<MaskedArray<T, D>> {
90 let mask_data: Vec<bool> = data.iter().map(|v| *v == value).collect();
91 let mask = Array::from_vec(data.dim().clone(), mask_data)?;
92 MaskedArray::new(data.clone(), mask)
93}
94
95pub fn masked_not_equal<T: Element + PartialEq + Copy, D: Dimension>(
100 data: &Array<T, D>,
101 value: T,
102) -> FerrayResult<MaskedArray<T, D>> {
103 let mask_data: Vec<bool> = data.iter().map(|v| *v != value).collect();
104 let mask = Array::from_vec(data.dim().clone(), mask_data)?;
105 MaskedArray::new(data.clone(), mask)
106}
107
108pub fn masked_greater<T: Element + PartialOrd + Copy, D: Dimension>(
113 data: &Array<T, D>,
114 value: T,
115) -> FerrayResult<MaskedArray<T, D>> {
116 let mask_data: Vec<bool> = data.iter().map(|v| *v > value).collect();
117 let mask = Array::from_vec(data.dim().clone(), mask_data)?;
118 MaskedArray::new(data.clone(), mask)
119}
120
121pub fn masked_less<T: Element + PartialOrd + Copy, D: Dimension>(
126 data: &Array<T, D>,
127 value: T,
128) -> FerrayResult<MaskedArray<T, D>> {
129 let mask_data: Vec<bool> = data.iter().map(|v| *v < value).collect();
130 let mask = Array::from_vec(data.dim().clone(), mask_data)?;
131 MaskedArray::new(data.clone(), mask)
132}
133
134pub fn masked_greater_equal<T: Element + PartialOrd + Copy, D: Dimension>(
139 data: &Array<T, D>,
140 value: T,
141) -> FerrayResult<MaskedArray<T, D>> {
142 let mask_data: Vec<bool> = data.iter().map(|v| *v >= value).collect();
143 let mask = Array::from_vec(data.dim().clone(), mask_data)?;
144 MaskedArray::new(data.clone(), mask)
145}
146
147pub fn masked_less_equal<T: Element + PartialOrd + Copy, D: Dimension>(
152 data: &Array<T, D>,
153 value: T,
154) -> FerrayResult<MaskedArray<T, D>> {
155 let mask_data: Vec<bool> = data.iter().map(|v| *v <= value).collect();
156 let mask = Array::from_vec(data.dim().clone(), mask_data)?;
157 MaskedArray::new(data.clone(), mask)
158}
159
160pub fn masked_inside<T: Element + PartialOrd + Copy, D: Dimension>(
168 data: &Array<T, D>,
169 v1: T,
170 v2: T,
171) -> FerrayResult<MaskedArray<T, D>> {
172 let (lo, hi) = if v1 <= v2 { (v1, v2) } else { (v2, v1) };
173 let mask_data: Vec<bool> = data.iter().map(|v| *v >= lo && *v <= hi).collect();
174 let mask = Array::from_vec(data.dim().clone(), mask_data)?;
175 MaskedArray::new(data.clone(), mask)
176}
177
178pub fn masked_outside<T: Element + PartialOrd + Copy, D: Dimension>(
186 data: &Array<T, D>,
187 v1: T,
188 v2: T,
189) -> FerrayResult<MaskedArray<T, D>> {
190 let (lo, hi) = if v1 <= v2 { (v1, v2) } else { (v2, v1) };
191 let mask_data: Vec<bool> = data.iter().map(|v| *v < lo || *v > hi).collect();
192 let mask = Array::from_vec(data.dim().clone(), mask_data)?;
193 MaskedArray::new(data.clone(), mask)
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::Ix1;

    // NaN and both infinities are masked and overwritten with the fill value,
    // and the fill value is recorded on the result.
    #[test]
    fn fix_invalid_masks_and_replaces_nan_and_inf() {
        let data = Array::<f64, Ix1>::from_vec(
            Ix1::new([6]),
            vec![1.0, f64::NAN, 3.0, f64::INFINITY, f64::NEG_INFINITY, 6.0],
        )
        .unwrap();
        let ma = fix_invalid(&data, -99.0).unwrap();

        let m: Vec<bool> = ma.mask().iter().copied().collect();
        assert_eq!(m, vec![false, true, false, true, true, false]);

        let d: Vec<f64> = ma.data().iter().copied().collect();
        assert_eq!(d, vec![1.0, -99.0, 3.0, -99.0, -99.0, 6.0]);

        assert_eq!(ma.fill_value(), -99.0);
    }

    #[test]
    fn fix_invalid_preserves_valid_values() {
        let data = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let ma = fix_invalid(&data, 0.0).unwrap();
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, false, false, false]
        );
        assert_eq!(
            ma.data().iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0, 3.0, 4.0]
        );
    }

    #[test]
    fn fix_invalid_all_nan_input() {
        let data =
            Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![f64::NAN, f64::NAN, f64::NAN]).unwrap();
        let ma = fix_invalid(&data, 0.0).unwrap();
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, true, true]
        );
        assert_eq!(
            ma.data().iter().copied().collect::<Vec<_>>(),
            vec![0.0, 0.0, 0.0]
        );
        // No NaN may survive in the underlying data.
        assert!(ma.data().iter().all(|v| !v.is_nan()));
    }

    #[test]
    fn fix_invalid_vs_masked_invalid_data_difference() {
        let data = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, f64::NAN, 3.0]).unwrap();
        let via_masked = masked_invalid(&data).unwrap();
        let via_fixed = fix_invalid(&data, -1.0).unwrap();

        // Both functions agree on which elements are invalid...
        assert_eq!(
            via_masked.mask().iter().copied().collect::<Vec<_>>(),
            via_fixed.mask().iter().copied().collect::<Vec<_>>()
        );

        // ...but only fix_invalid rewrites the underlying data.
        assert!(via_masked.data().iter().nth(1).unwrap().is_nan());
        assert_eq!(*via_fixed.data().iter().nth(1).unwrap(), -1.0);
    }

    #[test]
    fn fix_invalid_2d_shape_preserved() {
        use ferray_core::dimension::Ix2;
        let data = Array::<f64, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![1.0, f64::NAN, 3.0, 4.0, 5.0, f64::INFINITY],
        )
        .unwrap();
        let ma = fix_invalid(&data, -1.0).unwrap();
        assert_eq!(ma.shape(), &[2, 3]);
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, true, false, false, false, true]
        );
    }

    // Bounds in ascending order: the closed interior [2, 4] is masked.
    #[test]
    fn masked_inside_canonical_order_masks_interior() {
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let ma = masked_inside(&data, 2, 4).unwrap();
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, true, true, true, false]
        );
    }

    #[test]
    fn masked_inside_swaps_when_v1_greater_than_v2() {
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        // Reversed bounds must behave exactly like the canonical order.
        let swapped = masked_inside(&data, 4, 2).unwrap();
        let canonical = masked_inside(&data, 2, 4).unwrap();
        assert_eq!(
            swapped.mask().iter().copied().collect::<Vec<_>>(),
            canonical.mask().iter().copied().collect::<Vec<_>>()
        );
    }

    #[test]
    fn masked_outside_canonical_order_masks_exterior() {
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let ma = masked_outside(&data, 2, 4).unwrap();
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, false, false, false, true]
        );
    }

    #[test]
    fn masked_outside_swaps_when_v1_greater_than_v2() {
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        // Reversed bounds must behave exactly like the canonical order.
        let swapped = masked_outside(&data, 4, 2).unwrap();
        let canonical = masked_outside(&data, 2, 4).unwrap();
        assert_eq!(
            swapped.mask().iter().copied().collect::<Vec<_>>(),
            canonical.mask().iter().copied().collect::<Vec<_>>()
        );
    }

    #[test]
    fn masked_inside_equal_endpoints_masks_only_that_value() {
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        // A degenerate interval [3, 3] is closed, so exactly 3 is masked.
        let ma = masked_inside(&data, 3, 3).unwrap();
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, false, true, false, false]
        );
    }

    #[test]
    fn masked_outside_equal_endpoints_masks_everything_else() {
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        // Endpoints are not "outside", so only the value 3 stays unmasked.
        let ma = masked_outside(&data, 3, 3).unwrap();
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, true, false, true, true]
        );
    }

    #[test]
    fn masked_where_uses_condition_as_mask() {
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let cond =
            Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
        let ma = masked_where(&cond, &data).unwrap();
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, false, true]
        );
        assert_eq!(
            ma.data().iter().copied().collect::<Vec<_>>(),
            vec![10, 20, 30]
        );
    }

    #[test]
    fn masked_invalid_masks_nan_and_both_infinities() {
        let data = Array::<f64, Ix1>::from_vec(
            Ix1::new([5]),
            vec![1.0, f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 2.0],
        )
        .unwrap();
        let ma = masked_invalid(&data).unwrap();
        assert_eq!(
            ma.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, true, true, true, false]
        );
        // The underlying data is copied verbatim, invalid values included.
        let d: Vec<f64> = ma.data().iter().copied().collect();
        assert_eq!(d[0], 1.0);
        assert!(d[1].is_nan());
        assert_eq!(d[2], f64::INFINITY);
        assert_eq!(d[3], f64::NEG_INFINITY);
        assert_eq!(d[4], 2.0);
    }

    #[test]
    fn equality_predicates_are_complementary() {
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![1, 2, 2, 3]).unwrap();
        let eq = masked_equal(&data, 2).unwrap();
        let ne = masked_not_equal(&data, 2).unwrap();
        assert_eq!(
            eq.mask().iter().copied().collect::<Vec<_>>(),
            vec![false, true, true, false]
        );
        assert_eq!(
            ne.mask().iter().copied().collect::<Vec<_>>(),
            vec![true, false, false, true]
        );
    }

    #[test]
    fn ordering_predicates_mask_expected_elements() {
        let data = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        assert_eq!(
            masked_greater(&data, 3)
                .unwrap()
                .mask()
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            vec![false, false, false, true, true]
        );
        assert_eq!(
            masked_greater_equal(&data, 3)
                .unwrap()
                .mask()
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            vec![false, false, true, true, true]
        );
        assert_eq!(
            masked_less(&data, 3)
                .unwrap()
                .mask()
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            vec![true, true, false, false, false]
        );
        assert_eq!(
            masked_less_equal(&data, 3)
                .unwrap()
                .mask()
                .iter()
                .copied()
                .collect::<Vec<_>>(),
            vec![true, true, true, false, false]
        );
    }
}