1use ferray_core::Array;
10use ferray_core::dimension::{Dimension, Ix1, IxDyn};
11use ferray_core::dtype::Element;
12use ferray_core::error::{FerrayError, FerrayResult};
13use ferray_core::manipulation;
14
15use crate::MaskedArray;
16
17impl<T: Element + Copy, D: Dimension> MaskedArray<T, D> {
22 pub fn reshape(&self, new_shape: &[usize]) -> FerrayResult<MaskedArray<T, IxDyn>> {
34 let data = manipulation::reshape(self.data(), new_shape)?;
35 let mask = manipulation::reshape(self.mask(), new_shape)?;
36 let mut out = MaskedArray::new(data, mask)?;
37 out.set_fill_value(self.fill_value());
38 out.hard_mask = self.hard_mask;
39 Ok(out)
40 }
41
42 pub fn ravel(&self) -> FerrayResult<MaskedArray<T, Ix1>> {
47 let data = manipulation::ravel(self.data())?;
48 let mask = manipulation::ravel(self.mask())?;
49 let mut out = MaskedArray::new(data, mask)?;
50 out.set_fill_value(self.fill_value());
51 out.hard_mask = self.hard_mask;
52 Ok(out)
53 }
54
55 pub fn flatten(&self) -> FerrayResult<MaskedArray<T, Ix1>> {
59 self.ravel()
60 }
61
62 pub fn transpose(&self, axes: Option<&[usize]>) -> FerrayResult<MaskedArray<T, IxDyn>> {
76 let data = manipulation::transpose(self.data(), axes)?;
77 let mask = manipulation::transpose(self.mask(), axes)?;
78 let mut out = MaskedArray::new(data, mask)?;
79 out.set_fill_value(self.fill_value());
80 out.hard_mask = self.hard_mask;
81 Ok(out)
82 }
83
84 pub fn t(&self) -> FerrayResult<MaskedArray<T, IxDyn>> {
89 self.transpose(None)
90 }
91
92 pub fn squeeze(&self, axis: Option<usize>) -> FerrayResult<MaskedArray<T, IxDyn>> {
101 let data = manipulation::squeeze(self.data(), axis)?;
102 let mask = manipulation::squeeze(self.mask(), axis)?;
103 let mut out = MaskedArray::new(data, mask)?;
104 out.set_fill_value(self.fill_value());
105 out.hard_mask = self.hard_mask;
106 Ok(out)
107 }
108}
109
110impl<T: Element + Copy, D: Dimension> MaskedArray<T, D> {
115 pub fn get_flat(&self, flat_idx: usize) -> FerrayResult<(T, bool)> {
129 let size = self.size();
130 if flat_idx >= size {
131 return Err(FerrayError::index_out_of_bounds(flat_idx as isize, 0, size));
132 }
133 let value = if let Some(s) = self.data().as_slice() {
135 s[flat_idx]
136 } else {
137 self.data().iter().nth(flat_idx).copied().unwrap()
138 };
139 let is_masked = if let Some(s) = self.mask().as_slice() {
140 s[flat_idx]
141 } else {
142 self.mask().iter().nth(flat_idx).copied().unwrap()
143 };
144 Ok((value, is_masked))
145 }
146
147 pub fn boolean_index(&self, bool_mask: &Array<bool, D>) -> FerrayResult<MaskedArray<T, Ix1>> {
158 if bool_mask.shape() != self.shape() {
159 return Err(FerrayError::shape_mismatch(format!(
160 "boolean_index: selector shape {:?} does not match masked array shape {:?}",
161 bool_mask.shape(),
162 self.shape()
163 )));
164 }
165 let mut picked_data: Vec<T> = Vec::new();
166 let mut picked_mask: Vec<bool> = Vec::new();
167 for ((&v, &m_bit), &sel) in self
168 .data()
169 .iter()
170 .zip(self.mask().iter())
171 .zip(bool_mask.iter())
172 {
173 if sel {
174 picked_data.push(v);
175 picked_mask.push(m_bit);
176 }
177 }
178 let n = picked_data.len();
179 let data_arr = Array::<T, Ix1>::from_vec(Ix1::new([n]), picked_data)?;
180 let mask_arr = Array::<bool, Ix1>::from_vec(Ix1::new([n]), picked_mask)?;
181 let mut out = MaskedArray::new(data_arr, mask_arr)?;
182 out.set_fill_value(self.fill_value());
183 out.hard_mask = self.hard_mask;
184 Ok(out)
185 }
186
187 pub fn take(&self, indices: &[usize]) -> FerrayResult<MaskedArray<T, Ix1>>
201 where
202 D: Dimension,
203 {
204 let size = self.size();
205 let mut picked_data: Vec<T> = Vec::with_capacity(indices.len());
206 let mut picked_mask: Vec<bool> = Vec::with_capacity(indices.len());
207 let data_slice = self.data().as_slice();
210 let mask_slice = self.mask().as_slice();
211 let data_fallback: Option<Vec<T>> = if data_slice.is_none() {
212 Some(self.data().iter().copied().collect())
213 } else {
214 None
215 };
216 let mask_fallback: Option<Vec<bool>> = if mask_slice.is_none() {
217 Some(self.mask().iter().copied().collect())
218 } else {
219 None
220 };
221 for &idx in indices {
222 if idx >= size {
223 return Err(FerrayError::index_out_of_bounds(idx as isize, 0, size));
224 }
225 let v = if let Some(s) = data_slice {
226 s[idx]
227 } else {
228 data_fallback.as_ref().unwrap()[idx]
229 };
230 let m = if let Some(s) = mask_slice {
231 s[idx]
232 } else {
233 mask_fallback.as_ref().unwrap()[idx]
234 };
235 picked_data.push(v);
236 picked_mask.push(m);
237 }
238 let n = picked_data.len();
239 let data_arr = Array::<T, Ix1>::from_vec(Ix1::new([n]), picked_data)?;
240 let mask_arr = Array::<bool, Ix1>::from_vec(Ix1::new([n]), picked_mask)?;
241 let mut out = MaskedArray::new(data_arr, mask_arr)?;
242 out.set_fill_value(self.fill_value());
243 out.hard_mask = self.hard_mask;
244 Ok(out)
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::{Ix2, Ix3};

    /// Builds a `rows x cols` masked array from flat data and mask vectors.
    fn ma2d(rows: usize, cols: usize, data: Vec<f64>, mask: Vec<bool>) -> MaskedArray<f64, Ix2> {
        let values = Array::<f64, Ix2>::from_vec(Ix2::new([rows, cols]), data).unwrap();
        let flags = Array::<bool, Ix2>::from_vec(Ix2::new([rows, cols]), mask).unwrap();
        MaskedArray::new(values, flags).unwrap()
    }

    /// The 2x3 fixture shared by several tests: values 1..=6, with the second
    /// and fifth elements masked.
    fn sample_2x3() -> MaskedArray<f64, Ix2> {
        ma2d(
            2,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![false, true, false, false, true, false],
        )
    }

    /// A 1x3x1 fixture with the middle element masked, used by the squeeze tests.
    fn sample_1x3x1() -> MaskedArray<f64, Ix3> {
        let values =
            Array::<f64, Ix3>::from_vec(Ix3::new([1, 3, 1]), vec![10.0, 20.0, 30.0]).unwrap();
        let flags =
            Array::<bool, Ix3>::from_vec(Ix3::new([1, 3, 1]), vec![false, true, false]).unwrap();
        MaskedArray::new(values, flags).unwrap()
    }

    /// Collects the data values of `ma` in iteration order.
    fn values_of<D: Dimension>(ma: &MaskedArray<f64, D>) -> Vec<f64> {
        ma.data().iter().copied().collect::<Vec<_>>()
    }

    /// Collects the mask bits of `ma` in iteration order.
    fn mask_of<D: Dimension>(ma: &MaskedArray<f64, D>) -> Vec<bool> {
        ma.mask().iter().copied().collect::<Vec<_>>()
    }

    #[test]
    fn reshape_2d_to_different_2d() {
        let reshaped = sample_2x3().reshape(&[3, 2]).unwrap();
        assert_eq!(reshaped.shape(), &[3, 2]);
        assert_eq!(values_of(&reshaped), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(
            mask_of(&reshaped),
            vec![false, true, false, false, true, false]
        );
    }

    #[test]
    fn reshape_2d_to_1d() {
        let reshaped = sample_2x3().reshape(&[6]).unwrap();
        assert_eq!(reshaped.shape(), &[6]);
        assert_eq!(reshaped.size(), 6);
    }

    #[test]
    fn reshape_mismatched_size_errors() {
        let source = ma2d(2, 3, vec![1.0; 6], vec![false; 6]);
        assert!(source.reshape(&[2, 4]).is_err());
    }

    #[test]
    fn reshape_preserves_fill_value_and_hard_mask() {
        let mut source = ma2d(2, 3, vec![1.0; 6], vec![false; 6]);
        source.set_fill_value(-99.0);
        source.harden_mask().unwrap();
        let reshaped = source.reshape(&[3, 2]).unwrap();
        assert_eq!(reshaped.fill_value(), -99.0);
        assert!(reshaped.is_hard_mask());
    }

    #[test]
    fn ravel_2d_flattens_in_row_major() {
        let flat = sample_2x3().ravel().unwrap();
        assert_eq!(flat.shape(), &[6]);
        assert_eq!(values_of(&flat), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(mask_of(&flat), vec![false, true, false, false, true, false]);
    }

    #[test]
    fn flatten_is_alias_for_ravel() {
        let source = ma2d(
            2,
            2,
            vec![1.0, 2.0, 3.0, 4.0],
            vec![false, true, false, true],
        );
        let raveled = source.ravel().unwrap();
        let flattened = source.flatten().unwrap();
        assert_eq!(values_of(&raveled), values_of(&flattened));
        assert_eq!(mask_of(&raveled), mask_of(&flattened));
    }

    #[test]
    fn transpose_swaps_2d() {
        let transposed = sample_2x3().transpose(None).unwrap();
        assert_eq!(transposed.shape(), &[3, 2]);
        assert_eq!(values_of(&transposed), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
        assert_eq!(
            mask_of(&transposed),
            vec![false, false, true, true, false, false]
        );
    }

    #[test]
    fn t_is_alias_for_transpose_none() {
        let source = ma2d(2, 3, vec![1.0; 6], vec![false; 6]);
        let via_transpose = source.transpose(None).unwrap();
        let via_t = source.t().unwrap();
        assert_eq!(via_transpose.shape(), via_t.shape());
    }

    #[test]
    fn transpose_with_explicit_permutation() {
        let values: Vec<f64> = (0..24).map(f64::from).collect();
        let flags = vec![false; 24];
        let data = Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), values).unwrap();
        let mask = Array::<bool, Ix3>::from_vec(Ix3::new([2, 3, 4]), flags).unwrap();
        let source = MaskedArray::new(data, mask).unwrap();
        let permuted = source.transpose(Some(&[2, 0, 1])).unwrap();
        assert_eq!(permuted.shape(), &[4, 2, 3]);
    }

    #[test]
    fn squeeze_removes_all_size_1_dims_when_axis_none() {
        let squeezed = sample_1x3x1().squeeze(None).unwrap();
        assert_eq!(squeezed.shape(), &[3]);
        assert_eq!(mask_of(&squeezed), vec![false, true, false]);
    }

    #[test]
    fn squeeze_single_axis() {
        let squeezed = sample_1x3x1().squeeze(Some(0)).unwrap();
        assert_eq!(squeezed.shape(), &[3, 1]);
    }

    #[test]
    fn get_flat_returns_value_and_mask_bit() {
        let source = sample_2x3();

        let (value, masked) = source.get_flat(1).unwrap();
        assert_eq!(value, 2.0);
        assert!(masked);

        let (value, masked) = source.get_flat(3).unwrap();
        assert_eq!(value, 4.0);
        assert!(!masked);
    }

    #[test]
    fn get_flat_out_of_bounds_errors() {
        let source = ma2d(2, 2, vec![1.0; 4], vec![false; 4]);
        for idx in [4, 99] {
            assert!(source.get_flat(idx).is_err());
        }
    }

    #[test]
    fn boolean_index_selects_unmasked_structure() {
        let selector = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![true, true, false, false, true, true],
        )
        .unwrap();
        let picked = sample_2x3().boolean_index(&selector).unwrap();
        assert_eq!(values_of(&picked), vec![1.0, 2.0, 5.0, 6.0]);
        assert_eq!(mask_of(&picked), vec![false, true, true, false]);
    }

    #[test]
    fn boolean_index_rejects_wrong_shape() {
        let source = ma2d(2, 3, vec![1.0; 6], vec![false; 6]);
        let wrong = Array::<bool, Ix2>::from_vec(Ix2::new([3, 2]), vec![false; 6]).unwrap();
        assert!(source.boolean_index(&wrong).is_err());
    }

    #[test]
    fn take_fancy_index_picks_flat_positions() {
        let source = ma2d(
            2,
            3,
            vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0],
            vec![false, true, false, false, false, true],
        );
        let taken = source.take(&[0, 5, 2, 1]).unwrap();
        assert_eq!(values_of(&taken), vec![10.0, 60.0, 30.0, 20.0]);
        assert_eq!(mask_of(&taken), vec![false, true, false, true]);
    }

    #[test]
    fn take_out_of_bounds_errors() {
        let source = ma2d(2, 2, vec![1.0; 4], vec![false; 4]);
        assert!(source.take(&[0, 1, 5]).is_err());
    }

    #[test]
    fn take_with_repeated_indices() {
        let source = ma2d(1, 3, vec![1.0, 2.0, 3.0], vec![false, false, true]);
        let taken = source.take(&[0, 0, 2, 2]).unwrap();
        assert_eq!(values_of(&taken), vec![1.0, 1.0, 3.0, 3.0]);
        assert_eq!(mask_of(&taken), vec![false, false, true, true]);
    }
}