// ferray_ma/manipulation.rs

// ferray-ma: Shape manipulation and indexing for MaskedArray (#507, #508)
//
// Shape ops (reshape, ravel, flatten, transpose, t, squeeze) delegate to the
// corresponding ferray-core `manipulation` functions, applied to the data
// and the mask in lockstep so both always describe the same positions.
// Indexing (get_flat, boolean_index, take) carries each selected value's
// mask bit along with it.
8
9use ferray_core::Array;
10use ferray_core::dimension::{Dimension, Ix1, IxDyn};
11use ferray_core::dtype::Element;
12use ferray_core::error::{FerrayError, FerrayResult};
13use ferray_core::manipulation;
14
15use crate::MaskedArray;
16
17// ---------------------------------------------------------------------------
18// Shape operations (#508)
19// ---------------------------------------------------------------------------
20
21impl<T: Element + Copy, D: Dimension> MaskedArray<T, D> {
22    /// Return a new `MaskedArray` with the given shape.
23    ///
24    /// Equivalent to `numpy.ma.reshape(a, new_shape)`. The result is a
25    /// dynamic-rank `MaskedArray<T, IxDyn>`. The total element count
26    /// must match; mask and data are reshaped in lockstep so their
27    /// logical positions remain aligned. The `fill_value` and
28    /// `hard_mask` flag are preserved.
29    ///
30    /// # Errors
31    /// - `FerrayError::ShapeMismatch` if `new_shape`'s product does not
32    ///   equal `self.size()`.
33    pub fn reshape(&self, new_shape: &[usize]) -> FerrayResult<MaskedArray<T, IxDyn>> {
34        let data = manipulation::reshape(self.data(), new_shape)?;
35        let mask = manipulation::reshape(self.mask(), new_shape)?;
36        let mut out = MaskedArray::new(data, mask)?;
37        out.set_fill_value(self.fill_value());
38        out.hard_mask = self.hard_mask;
39        Ok(out)
40    }
41
42    /// Return a 1-D `MaskedArray` with the same total element count.
43    ///
44    /// Equivalent to `numpy.ma.ravel(a)`. Preserves `fill_value` and
45    /// `hard_mask`.
46    pub fn ravel(&self) -> FerrayResult<MaskedArray<T, Ix1>> {
47        let data = manipulation::ravel(self.data())?;
48        let mask = manipulation::ravel(self.mask())?;
49        let mut out = MaskedArray::new(data, mask)?;
50        out.set_fill_value(self.fill_value());
51        out.hard_mask = self.hard_mask;
52        Ok(out)
53    }
54
55    /// Alias for [`MaskedArray::ravel`].
56    ///
57    /// Equivalent to `numpy.ma.flatten(a)`.
58    pub fn flatten(&self) -> FerrayResult<MaskedArray<T, Ix1>> {
59        self.ravel()
60    }
61
62    /// Return a transposed `MaskedArray` by permuting axes.
63    ///
64    /// When `axes` is `None`, reverses the axis order (equivalent to
65    /// `numpy.ma.transpose(a)` or `a.T`). When `Some(ax)`, uses the
66    /// supplied permutation. Both data and mask are permuted together
67    /// so their logical positions remain aligned.
68    ///
69    /// Returns a dynamic-rank result because the permutation is chosen
70    /// at runtime.
71    ///
72    /// # Errors
73    /// - `FerrayError::InvalidValue` if `axes` is the wrong length or
74    ///   not a valid permutation.
75    pub fn transpose(&self, axes: Option<&[usize]>) -> FerrayResult<MaskedArray<T, IxDyn>> {
76        let data = manipulation::transpose(self.data(), axes)?;
77        let mask = manipulation::transpose(self.mask(), axes)?;
78        let mut out = MaskedArray::new(data, mask)?;
79        out.set_fill_value(self.fill_value());
80        out.hard_mask = self.hard_mask;
81        Ok(out)
82    }
83
84    /// Return a transposed `MaskedArray` with reversed axis order.
85    ///
86    /// Shorthand for `self.transpose(None)`, equivalent to `NumPy`'s
87    /// `.T` property.
88    pub fn t(&self) -> FerrayResult<MaskedArray<T, IxDyn>> {
89        self.transpose(None)
90    }
91
92    /// Remove size-1 axes from the masked array.
93    ///
94    /// Equivalent to `numpy.ma.squeeze(a)`. When `axis` is `None`,
95    /// removes every axis whose length is 1. When `Some(ax)`, only
96    /// the specified axis (which must have length 1) is removed.
97    ///
98    /// The single-axis restriction mirrors `ferray_core::manipulation::squeeze`
99    /// — chain multiple calls if you need to drop several axes.
100    pub fn squeeze(&self, axis: Option<usize>) -> FerrayResult<MaskedArray<T, IxDyn>> {
101        let data = manipulation::squeeze(self.data(), axis)?;
102        let mask = manipulation::squeeze(self.mask(), axis)?;
103        let mut out = MaskedArray::new(data, mask)?;
104        out.set_fill_value(self.fill_value());
105        out.hard_mask = self.hard_mask;
106        Ok(out)
107    }
108}
109
110// ---------------------------------------------------------------------------
111// Indexing and slicing (#507)
112// ---------------------------------------------------------------------------
113
114impl<T: Element + Copy, D: Dimension> MaskedArray<T, D> {
115    /// Get a single element by flat (row-major) index.
116    ///
117    /// Returns `Ok((value, is_masked))` where `is_masked` is `true`
118    /// when the mask bit at that position is set. Returns an error if
119    /// the index is out of bounds.
120    ///
121    /// Returning the raw `(T, bool)` pair lets callers decide what to
122    /// do with masked values instead of forcing a fill-value
123    /// substitution — if you want the NumPy-style masked scalar
124    /// behavior, check `is_masked` and fall back to `self.fill_value()`.
125    ///
126    /// # Errors
127    /// - `FerrayError::IndexOutOfBounds` if `flat_idx >= self.size()`.
128    pub fn get_flat(&self, flat_idx: usize) -> FerrayResult<(T, bool)> {
129        let size = self.size();
130        if flat_idx >= size {
131            return Err(FerrayError::index_out_of_bounds(flat_idx as isize, 0, size));
132        }
133        // Fast path: contiguous buffers give O(1) slice indexing.
134        let value = if let Some(s) = self.data().as_slice() {
135            s[flat_idx]
136        } else {
137            self.data().iter().nth(flat_idx).copied().unwrap()
138        };
139        let is_masked = if let Some(s) = self.mask().as_slice() {
140            s[flat_idx]
141        } else {
142            self.mask().iter().nth(flat_idx).copied().unwrap()
143        };
144        Ok((value, is_masked))
145    }
146
147    /// Select elements where `bool_mask` is `true`.
148    ///
149    /// Equivalent to `a[bool_mask]` in `NumPy` boolean indexing. Returns
150    /// a 1-D `MaskedArray` containing only positions where the
151    /// supplied `bool_mask` is `true`; each selected position carries
152    /// through both its value and its original mask bit. The
153    /// `bool_mask` must have exactly the same shape as `self`.
154    ///
155    /// # Errors
156    /// - `FerrayError::ShapeMismatch` if `bool_mask.shape() != self.shape()`.
157    pub fn boolean_index(&self, bool_mask: &Array<bool, D>) -> FerrayResult<MaskedArray<T, Ix1>> {
158        if bool_mask.shape() != self.shape() {
159            return Err(FerrayError::shape_mismatch(format!(
160                "boolean_index: selector shape {:?} does not match masked array shape {:?}",
161                bool_mask.shape(),
162                self.shape()
163            )));
164        }
165        let mut picked_data: Vec<T> = Vec::new();
166        let mut picked_mask: Vec<bool> = Vec::new();
167        for ((&v, &m_bit), &sel) in self
168            .data()
169            .iter()
170            .zip(self.mask().iter())
171            .zip(bool_mask.iter())
172        {
173            if sel {
174                picked_data.push(v);
175                picked_mask.push(m_bit);
176            }
177        }
178        let n = picked_data.len();
179        let data_arr = Array::<T, Ix1>::from_vec(Ix1::new([n]), picked_data)?;
180        let mask_arr = Array::<bool, Ix1>::from_vec(Ix1::new([n]), picked_mask)?;
181        let mut out = MaskedArray::new(data_arr, mask_arr)?;
182        out.set_fill_value(self.fill_value());
183        out.hard_mask = self.hard_mask;
184        Ok(out)
185    }
186
187    /// Fancy index selection from a 1-D `MaskedArray`.
188    ///
189    /// Equivalent to `a[indices]` in `NumPy` fancy indexing (restricted
190    /// to 1-D because higher-rank fancy indexing has NumPy-specific
191    /// broadcasting semantics that would be easy to get subtly wrong).
192    /// Returns a new 1-D `MaskedArray` whose elements are picked from
193    /// `self` at the supplied flat positions.
194    ///
195    /// Each result position carries through both the selected value
196    /// and its original mask bit.
197    ///
198    /// # Errors
199    /// - `FerrayError::IndexOutOfBounds` if any index is out of range.
200    pub fn take(&self, indices: &[usize]) -> FerrayResult<MaskedArray<T, Ix1>>
201    where
202        D: Dimension,
203    {
204        let size = self.size();
205        let mut picked_data: Vec<T> = Vec::with_capacity(indices.len());
206        let mut picked_mask: Vec<bool> = Vec::with_capacity(indices.len());
207        // Walk the flat buffers once rather than calling get_flat in a
208        // loop — for contiguous inputs this is O(1) per index.
209        let data_slice = self.data().as_slice();
210        let mask_slice = self.mask().as_slice();
211        let data_fallback: Option<Vec<T>> = if data_slice.is_none() {
212            Some(self.data().iter().copied().collect())
213        } else {
214            None
215        };
216        let mask_fallback: Option<Vec<bool>> = if mask_slice.is_none() {
217            Some(self.mask().iter().copied().collect())
218        } else {
219            None
220        };
221        for &idx in indices {
222            if idx >= size {
223                return Err(FerrayError::index_out_of_bounds(idx as isize, 0, size));
224            }
225            let v = if let Some(s) = data_slice {
226                s[idx]
227            } else {
228                data_fallback.as_ref().unwrap()[idx]
229            };
230            let m = if let Some(s) = mask_slice {
231                s[idx]
232            } else {
233                mask_fallback.as_ref().unwrap()[idx]
234            };
235            picked_data.push(v);
236            picked_mask.push(m);
237        }
238        let n = picked_data.len();
239        let data_arr = Array::<T, Ix1>::from_vec(Ix1::new([n]), picked_data)?;
240        let mask_arr = Array::<bool, Ix1>::from_vec(Ix1::new([n]), picked_mask)?;
241        let mut out = MaskedArray::new(data_arr, mask_arr)?;
242        out.set_fill_value(self.fill_value());
243        out.hard_mask = self.hard_mask;
244        Ok(out)
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::{Ix2, Ix3};

    /// Build a `rows x cols` masked array from row-major data and mask buffers.
    fn ma2d(rows: usize, cols: usize, data: Vec<f64>, mask: Vec<bool>) -> MaskedArray<f64, Ix2> {
        let d = Array::<f64, Ix2>::from_vec(Ix2::new([rows, cols]), data).unwrap();
        let m = Array::<bool, Ix2>::from_vec(Ix2::new([rows, cols]), mask).unwrap();
        MaskedArray::new(d, m).unwrap()
    }

    /// Build a 3-D masked array of the given shape from row-major buffers.
    fn ma3d(shape: [usize; 3], data: Vec<f64>, mask: Vec<bool>) -> MaskedArray<f64, Ix3> {
        let d = Array::<f64, Ix3>::from_vec(Ix3::new(shape), data).unwrap();
        let m = Array::<bool, Ix3>::from_vec(Ix3::new(shape), mask).unwrap();
        MaskedArray::new(d, m).unwrap()
    }

    /// The 2x3 fixture shared by several tests: values 1..=6 with the
    /// middle column masked.
    fn sample_2x3() -> MaskedArray<f64, Ix2> {
        ma2d(
            2,
            3,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![false, true, false, false, true, false],
        )
    }

    /// Collect the data values in iteration order.
    fn values<D: Dimension>(ma: &MaskedArray<f64, D>) -> Vec<f64> {
        ma.data().iter().copied().collect()
    }

    /// Collect the mask bits in iteration order.
    fn mask_bits<D: Dimension>(ma: &MaskedArray<f64, D>) -> Vec<bool> {
        ma.mask().iter().copied().collect()
    }

    // ---- reshape / ravel / transpose / squeeze (#508) ----

    #[test]
    fn reshape_2d_to_different_2d() {
        let reshaped = sample_2x3().reshape(&[3, 2]).unwrap();
        assert_eq!(reshaped.shape(), &[3, 2]);
        // Row-major order preserved for both buffers.
        assert_eq!(values(&reshaped), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(
            mask_bits(&reshaped),
            vec![false, true, false, false, true, false]
        );
    }

    #[test]
    fn reshape_2d_to_1d() {
        let reshaped = sample_2x3().reshape(&[6]).unwrap();
        assert_eq!(reshaped.shape(), &[6]);
        assert_eq!(reshaped.size(), 6);
    }

    #[test]
    fn reshape_mismatched_size_errors() {
        let ma = ma2d(2, 3, vec![1.0; 6], vec![false; 6]);
        // 2 * 4 = 8 elements requested, only 6 available.
        assert!(ma.reshape(&[2, 4]).is_err());
    }

    #[test]
    fn reshape_preserves_fill_value_and_hard_mask() {
        let mut ma = ma2d(2, 3, vec![1.0; 6], vec![false; 6]);
        ma.set_fill_value(-99.0);
        ma.harden_mask().unwrap();
        let reshaped = ma.reshape(&[3, 2]).unwrap();
        assert_eq!(reshaped.fill_value(), -99.0);
        assert!(reshaped.is_hard_mask());
    }

    #[test]
    fn ravel_2d_flattens_in_row_major() {
        let flat = sample_2x3().ravel().unwrap();
        assert_eq!(flat.shape(), &[6]);
        assert_eq!(values(&flat), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(
            mask_bits(&flat),
            vec![false, true, false, false, true, false]
        );
    }

    #[test]
    fn flatten_is_alias_for_ravel() {
        let ma = ma2d(
            2,
            2,
            vec![1.0, 2.0, 3.0, 4.0],
            vec![false, true, false, true],
        );
        let via_ravel = ma.ravel().unwrap();
        let via_flatten = ma.flatten().unwrap();
        assert_eq!(values(&via_ravel), values(&via_flatten));
        assert_eq!(mask_bits(&via_ravel), mask_bits(&via_flatten));
    }

    #[test]
    fn transpose_swaps_2d() {
        // [[1,2,3],[4,5,6]] → [[1,4],[2,5],[3,6]]
        let transposed = sample_2x3().transpose(None).unwrap();
        assert_eq!(transposed.shape(), &[3, 2]);
        assert_eq!(values(&transposed), vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
        // The mask is permuted the same way.
        assert_eq!(
            mask_bits(&transposed),
            vec![false, false, true, true, false, false]
        );
    }

    #[test]
    fn t_is_alias_for_transpose_none() {
        let ma = ma2d(2, 3, vec![1.0; 6], vec![false; 6]);
        let via_transpose = ma.transpose(None).unwrap();
        let via_t = ma.t().unwrap();
        assert_eq!(via_transpose.shape(), via_t.shape());
    }

    #[test]
    fn transpose_with_explicit_permutation() {
        // 3-D with explicit permutation [2, 0, 1] — 2x3x4 → 4x2x3
        let ma = ma3d(
            [2, 3, 4],
            (0..24).map(f64::from).collect(),
            vec![false; 24],
        );
        let transposed = ma.transpose(Some(&[2, 0, 1])).unwrap();
        assert_eq!(transposed.shape(), &[4, 2, 3]);
    }

    #[test]
    fn squeeze_removes_all_size_1_dims_when_axis_none() {
        // (1, 3, 1) → (3,)
        let ma = ma3d(
            [1, 3, 1],
            vec![10.0, 20.0, 30.0],
            vec![false, true, false],
        );
        let squeezed = ma.squeeze(None).unwrap();
        assert_eq!(squeezed.shape(), &[3]);
        assert_eq!(mask_bits(&squeezed), vec![false, true, false]);
    }

    #[test]
    fn squeeze_single_axis() {
        // (1, 3, 1), squeeze axis=0 → (3, 1)
        let ma = ma3d(
            [1, 3, 1],
            vec![10.0, 20.0, 30.0],
            vec![false, true, false],
        );
        let squeezed = ma.squeeze(Some(0)).unwrap();
        assert_eq!(squeezed.shape(), &[3, 1]);
    }

    // ---- indexing (#507) ----

    #[test]
    fn get_flat_returns_value_and_mask_bit() {
        let ma = sample_2x3();
        // Row-major: position 1 is (0, 1) with value=2, masked=true.
        let (value, masked) = ma.get_flat(1).unwrap();
        assert_eq!(value, 2.0);
        assert!(masked);
        // Position 3 is (1, 0) with value=4, masked=false.
        let (value, masked) = ma.get_flat(3).unwrap();
        assert_eq!(value, 4.0);
        assert!(!masked);
    }

    #[test]
    fn get_flat_out_of_bounds_errors() {
        let ma = ma2d(2, 2, vec![1.0; 4], vec![false; 4]);
        // Index equal to the size and index far past it are both rejected.
        assert!(ma.get_flat(4).is_err());
        assert!(ma.get_flat(99).is_err());
    }

    #[test]
    fn boolean_index_selects_unmasked_structure() {
        let selector = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![true, true, false, false, true, true],
        )
        .unwrap();
        let picked = sample_2x3().boolean_index(&selector).unwrap();
        // Selected positions: 0, 1, 4, 5 → values 1, 2, 5, 6
        assert_eq!(values(&picked), vec![1.0, 2.0, 5.0, 6.0]);
        // Original mask bits at those positions: F, T, T, F
        assert_eq!(mask_bits(&picked), vec![false, true, true, false]);
    }

    #[test]
    fn boolean_index_rejects_wrong_shape() {
        let ma = ma2d(2, 3, vec![1.0; 6], vec![false; 6]);
        // Same element count, different shape: must still be rejected.
        let wrong = Array::<bool, Ix2>::from_vec(Ix2::new([3, 2]), vec![false; 6]).unwrap();
        assert!(ma.boolean_index(&wrong).is_err());
    }

    #[test]
    fn take_fancy_index_picks_flat_positions() {
        let ma = ma2d(
            2,
            3,
            vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0],
            vec![false, true, false, false, false, true],
        );
        // Pick positions 0, 5, 2, 1 in that order.
        let taken = ma.take(&[0, 5, 2, 1]).unwrap();
        assert_eq!(values(&taken), vec![10.0, 60.0, 30.0, 20.0]);
        assert_eq!(mask_bits(&taken), vec![false, true, false, true]);
    }

    #[test]
    fn take_out_of_bounds_errors() {
        let ma = ma2d(2, 2, vec![1.0; 4], vec![false; 4]);
        assert!(ma.take(&[0, 1, 5]).is_err());
    }

    #[test]
    fn take_with_repeated_indices() {
        let ma = ma2d(1, 3, vec![1.0, 2.0, 3.0], vec![false, false, true]);
        let taken = ma.take(&[0, 0, 2, 2]).unwrap();
        assert_eq!(values(&taken), vec![1.0, 1.0, 3.0, 3.0]);
        assert_eq!(mask_bits(&taken), vec![false, false, true, true]);
    }
}