Skip to main content

ferray_core/indexing/
advanced.rs

1// ferray-core: Advanced (fancy) indexing (REQ-13, REQ-15)
2//
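// - index_select: integer-array indexing along an axis → copies
// - boolean_index: boolean-mask indexing (mask shape == array shape) → copies (always 1-D)
// - boolean_index_flat: 1-D mask over the row-major flattened array → copies (always 1-D)
// - boolean_index_assign: masked scalar assignment (a[mask] = value)
// - boolean_index_assign_array: masked assignment from a 1-D array (a[mask] = values)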
6//
7// All advanced indexing operations return COPIES, not views.
8
9use super::normalize_index;
10use crate::array::owned::Array;
11use crate::array::view::ArrayView;
12use crate::dimension::{Axis, Dimension, Ix1, IxDyn};
13use crate::dtype::Element;
14use crate::error::{FerrayError, FerrayResult};
15
16// ---------------------------------------------------------------------------
// Array advanced indexing (index_select, boolean masks)
18// ---------------------------------------------------------------------------
19
20impl<T: Element, D: Dimension> Array<T, D> {
21    /// Select elements along an axis using an array of indices.
22    ///
23    /// This is advanced (fancy) indexing — it always returns a **copy**.
24    /// The result has the same number of dimensions as the input, but
25    /// the size along `axis` is replaced by `indices.len()`.
26    ///
27    /// Negative indices are supported (counting from end).
28    ///
29    /// # Errors
30    /// - `AxisOutOfBounds` if `axis >= ndim`
31    /// - `IndexOutOfBounds` if any index is out of range
32    pub fn index_select(&self, axis: Axis, indices: &[isize]) -> FerrayResult<Array<T, IxDyn>> {
33        let ndim = self.ndim();
34        let ax = axis.index();
35        if ax >= ndim {
36            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
37        }
38        let axis_size = self.shape()[ax];
39
40        // Normalize all indices
41        let normalized: Vec<usize> = indices
42            .iter()
43            .map(|&idx| normalize_index(idx, axis_size, ax))
44            .collect::<FerrayResult<Vec<_>>>()?;
45
46        let dyn_view = self.inner.view().into_dyn();
47        let nd_axis = ndarray::Axis(ax);
48        let selected = dyn_view.select(nd_axis, &normalized);
49        Ok(Array::from_ndarray(selected))
50    }
51
52    /// Select elements using a boolean mask.
53    ///
54    /// Returns a 1-D array containing elements where `mask` is `true`.
55    /// This is always a **copy**.
56    ///
57    /// The mask must be broadcastable to the array's shape, or have the
58    /// same total number of elements. When the mask is 1-D and the array
59    /// is N-D, the mask is applied to the flattened array.
60    ///
61    /// # Errors
62    /// - `ShapeMismatch` if mask shape is incompatible
63    pub fn boolean_index(&self, mask: &Array<bool, D>) -> FerrayResult<Array<T, Ix1>> {
64        if self.shape() != mask.shape() {
65            return Err(FerrayError::shape_mismatch(format!(
66                "boolean index mask shape {:?} does not match array shape {:?}",
67                mask.shape(),
68                self.shape()
69            )));
70        }
71
72        let data: Vec<T> = self
73            .inner
74            .iter()
75            .zip(mask.inner.iter())
76            .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
77            .collect();
78
79        let len = data.len();
80        Array::from_vec(Ix1::new([len]), data)
81    }
82
83    /// Boolean indexing with a flat mask (1-D mask on N-D array).
84    ///
85    /// The mask is applied to the flattened (row-major) array.
86    ///
87    /// # Errors
88    /// - `ShapeMismatch` if `mask.len() != self.size()`
89    pub fn boolean_index_flat(&self, mask: &Array<bool, Ix1>) -> FerrayResult<Array<T, Ix1>> {
90        if mask.size() != self.size() {
91            return Err(FerrayError::shape_mismatch(format!(
92                "flat boolean mask length {} does not match array size {}",
93                mask.size(),
94                self.size()
95            )));
96        }
97
98        let data: Vec<T> = self
99            .inner
100            .iter()
101            .zip(mask.inner.iter())
102            .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
103            .collect();
104
105        let len = data.len();
106        Array::from_vec(Ix1::new([len]), data)
107    }
108
109    /// Assign a scalar value to elements selected by a boolean mask.
110    ///
111    /// Equivalent to `a[mask] = value` in NumPy.
112    ///
113    /// # Errors
114    /// - `ShapeMismatch` if mask shape differs from array shape
115    pub fn boolean_index_assign(&mut self, mask: &Array<bool, D>, value: T) -> FerrayResult<()> {
116        if self.shape() != mask.shape() {
117            return Err(FerrayError::shape_mismatch(format!(
118                "boolean index mask shape {:?} does not match array shape {:?}",
119                mask.shape(),
120                self.shape()
121            )));
122        }
123
124        for (elem, &m) in self.inner.iter_mut().zip(mask.inner.iter()) {
125            if m {
126                *elem = value.clone();
127            }
128        }
129        Ok(())
130    }
131
132    /// Assign values from an array to elements selected by a boolean mask.
133    ///
134    /// `values` must have exactly as many elements as `mask` has `true`
135    /// entries.
136    ///
137    /// # Errors
138    /// - `ShapeMismatch` if mask shape differs or values length mismatches
139    pub fn boolean_index_assign_array(
140        &mut self,
141        mask: &Array<bool, D>,
142        values: &Array<T, Ix1>,
143    ) -> FerrayResult<()> {
144        if self.shape() != mask.shape() {
145            return Err(FerrayError::shape_mismatch(format!(
146                "boolean index mask shape {:?} does not match array shape {:?}",
147                mask.shape(),
148                self.shape()
149            )));
150        }
151
152        let true_count = mask.inner.iter().filter(|&&m| m).count();
153        if values.size() != true_count {
154            return Err(FerrayError::shape_mismatch(format!(
155                "values array has {} elements but mask has {} true entries",
156                values.size(),
157                true_count
158            )));
159        }
160
161        let mut val_iter = values.inner.iter();
162        for (elem, &m) in self.inner.iter_mut().zip(mask.inner.iter()) {
163            if m {
164                if let Some(v) = val_iter.next() {
165                    *elem = v.clone();
166                }
167            }
168        }
169        Ok(())
170    }
171}
172
173// ---------------------------------------------------------------------------
174// ArrayView advanced indexing
175// ---------------------------------------------------------------------------
176
177impl<T: Element, D: Dimension> ArrayView<'_, T, D> {
178    /// Select elements along an axis using an array of indices (copy).
179    pub fn index_select(&self, axis: Axis, indices: &[isize]) -> FerrayResult<Array<T, IxDyn>> {
180        let ndim = self.ndim();
181        let ax = axis.index();
182        if ax >= ndim {
183            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
184        }
185        let axis_size = self.shape()[ax];
186
187        let normalized: Vec<usize> = indices
188            .iter()
189            .map(|&idx| normalize_index(idx, axis_size, ax))
190            .collect::<FerrayResult<Vec<_>>>()?;
191
192        let dyn_view = self.inner.clone().into_dyn();
193        let nd_axis = ndarray::Axis(ax);
194        let selected = dyn_view.select(nd_axis, &normalized);
195        Ok(Array::from_ndarray(selected))
196    }
197
198    /// Select elements using a boolean mask (copy).
199    pub fn boolean_index(&self, mask: &Array<bool, D>) -> FerrayResult<Array<T, Ix1>> {
200        if self.shape() != mask.shape() {
201            return Err(FerrayError::shape_mismatch(format!(
202                "boolean index mask shape {:?} does not match view shape {:?}",
203                mask.shape(),
204                self.shape()
205            )));
206        }
207
208        let data: Vec<T> = self
209            .inner
210            .iter()
211            .zip(mask.inner.iter())
212            .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
213            .collect();
214
215        let len = data.len();
216        Array::from_vec(Ix1::new([len]), data)
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    // -----------------------------------------------------------------------
    // index_select
    // -----------------------------------------------------------------------

    #[test]
    fn index_select_rows() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([4, 3]), (0..12).collect()).unwrap();
        let sel = arr.index_select(Axis(0), &[0, 2, 3]).unwrap();
        assert_eq!(sel.shape(), &[3, 3]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![0, 1, 2, 6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn index_select_columns() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let sel = arr.index_select(Axis(1), &[0, 2]).unwrap();
        assert_eq!(sel.shape(), &[3, 2]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![0, 2, 4, 6, 8, 10]);
    }

    #[test]
    fn index_select_negative() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let sel = arr.index_select(Axis(0), &[-1, -3]).unwrap();
        assert_eq!(sel.shape(), &[2]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![50, 30]);
    }

    #[test]
    fn index_select_negative_columns() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let sel = arr.index_select(Axis(1), &[-1, 0]).unwrap();
        assert_eq!(sel.shape(), &[2, 2]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![3, 1, 6, 4]);
    }

    #[test]
    fn index_select_out_of_bounds() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert!(arr.index_select(Axis(0), &[3]).is_err());
        assert!(arr.index_select(Axis(0), &[-4]).is_err());
    }

    #[test]
    fn index_select_axis_out_of_bounds() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        assert!(arr.index_select(Axis(2), &[0]).is_err());
    }

    #[test]
    fn index_select_returns_copy() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let sel = arr.index_select(Axis(0), &[0, 1]).unwrap();
        // Should be a different allocation
        assert_ne!(sel.as_ptr() as usize, arr.as_ptr() as usize);
    }

    #[test]
    fn index_select_duplicate_indices() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let sel = arr.index_select(Axis(0), &[1, 1, 0, 2, 2]).unwrap();
        assert_eq!(sel.shape(), &[5]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![20, 20, 10, 30, 30]);
    }

    #[test]
    fn index_select_empty() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let sel = arr.index_select(Axis(0), &[]).unwrap();
        assert_eq!(sel.shape(), &[0]);
    }

    // -----------------------------------------------------------------------
    // boolean_index
    // -----------------------------------------------------------------------

    #[test]
    fn boolean_index_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
                .unwrap();
        let selected = arr.boolean_index(&mask).unwrap();
        assert_eq!(selected.shape(), &[3]);
        assert_eq!(selected.as_slice().unwrap(), &[10, 30, 50]);
    }

    #[test]
    fn boolean_index_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let mask = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![true, false, true, false, true, false],
        )
        .unwrap();
        let selected = arr.boolean_index(&mask).unwrap();
        assert_eq!(selected.shape(), &[3]);
        assert_eq!(selected.as_slice().unwrap(), &[1, 3, 5]);
    }

    #[test]
    fn boolean_index_all_false() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
        let selected = arr.boolean_index(&mask).unwrap();
        assert_eq!(selected.shape(), &[0]);
    }

    #[test]
    fn boolean_index_shape_mismatch() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([2]), vec![true, false]).unwrap();
        assert!(arr.boolean_index(&mask).is_err());
    }

    #[test]
    fn boolean_index_returns_copy() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        let selected = arr.boolean_index(&mask).unwrap();
        assert_ne!(selected.as_ptr() as usize, arr.as_ptr() as usize);
    }

    // -----------------------------------------------------------------------
    // boolean_index_flat
    // -----------------------------------------------------------------------

    #[test]
    fn boolean_index_flat_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(
            Ix1::new([6]),
            vec![false, true, false, true, false, true],
        )
        .unwrap();
        let selected = arr.boolean_index_flat(&mask).unwrap();
        assert_eq!(selected.as_slice().unwrap(), &[2, 4, 6]);
    }

    #[test]
    fn boolean_index_flat_wrong_size() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        assert!(arr.boolean_index_flat(&mask).is_err());
    }

    // -----------------------------------------------------------------------
    // boolean_index_assign
    // -----------------------------------------------------------------------

    #[test]
    fn boolean_assign_scalar() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
                .unwrap();
        arr.boolean_index_assign(&mask, 0).unwrap();
        assert_eq!(arr.as_slice().unwrap(), &[0, 2, 0, 4, 0]);
    }

    #[test]
    fn boolean_assign_scalar_shape_mismatch_leaves_array_unchanged() {
        let mut arr =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let mask = Array::<bool, Ix2>::from_vec(Ix2::new([3, 2]), vec![true; 6]).unwrap();
        assert!(arr.boolean_index_assign(&mask, 0).is_err());
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn boolean_assign_array() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![false, true, false, true, false])
                .unwrap();
        let values = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![99, 88]).unwrap();
        arr.boolean_index_assign_array(&mask, &values).unwrap();
        assert_eq!(arr.as_slice().unwrap(), &[1, 99, 3, 88, 5]);
    }

    #[test]
    fn boolean_assign_array_wrong_count() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, false]).unwrap();
        let values = Array::<i32, Ix1>::from_vec(Ix1::new([1]), vec![99]).unwrap();
        assert!(arr.boolean_index_assign_array(&mask, &values).is_err());
    }

    #[test]
    fn boolean_assign_array_wrong_count_leaves_array_unchanged() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, false]).unwrap();
        let values = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![7, 8, 9]).unwrap();
        assert!(arr.boolean_index_assign_array(&mask, &values).is_err());
        assert_eq!(arr.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn boolean_assign_array_mask_shape_mismatch() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([2]), vec![true, true]).unwrap();
        let values = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![7, 8]).unwrap();
        assert!(arr.boolean_index_assign_array(&mask, &values).is_err());
        assert_eq!(arr.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn boolean_assign_array_2d_row_major_order() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        let mask =
            Array::<bool, Ix2>::from_vec(Ix2::new([2, 2]), vec![true, false, false, true])
                .unwrap();
        let values = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![10, 40]).unwrap();
        arr.boolean_index_assign_array(&mask, &values).unwrap();
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![10, 2, 3, 40]);
    }

    #[test]
    fn boolean_assign_2d() {
        let mut arr =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let mask = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![false, true, false, false, true, false],
        )
        .unwrap();
        arr.boolean_index_assign(&mask, -1).unwrap();
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![1, -1, 3, 4, -1, 6]);
    }

    // -----------------------------------------------------------------------
    // ArrayView advanced indexing
    // -----------------------------------------------------------------------

    #[test]
    fn view_index_select() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let v = arr.view();
        let sel = v.index_select(Axis(1), &[0, 3]).unwrap();
        assert_eq!(sel.shape(), &[3, 2]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![0, 3, 4, 7, 8, 11]);
    }

    #[test]
    fn view_index_select_axis_out_of_bounds() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert!(arr.view().index_select(Axis(1), &[0]).is_err());
    }

    #[test]
    fn view_boolean_index() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let v = arr.view();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, false, true]).unwrap();
        let selected = v.boolean_index(&mask).unwrap();
        assert_eq!(selected.as_slice().unwrap(), &[10, 40]);
    }

    #[test]
    fn view_boolean_index_shape_mismatch() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let v = arr.view();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([2]), vec![true, false]).unwrap();
        assert!(v.boolean_index(&mask).is_err());
    }
}