Skip to main content

ferray_core/indexing/
advanced.rs

1// ferray-core: Advanced (fancy) indexing (REQ-13, REQ-15)
2//
3// - index_select: integer-array indexing along an axis → copies
4// - boolean_index / boolean_index_flat: boolean-mask indexing → copies (always 1-D)
5// - boolean_index_assign / boolean_index_assign_array: masked assignment (a[mask] = value)
6//
7// All advanced indexing selections return COPIES, not views; masked assignment mutates the array in place.
8
9use crate::array::owned::Array;
10use crate::array::view::ArrayView;
11use crate::dimension::{Axis, Dimension, Ix1, IxDyn};
12use crate::dtype::Element;
13use crate::error::{FerrayError, FerrayResult};
14
15// ---------------------------------------------------------------------------
16// index_select
17// ---------------------------------------------------------------------------
18
19impl<T: Element, D: Dimension> Array<T, D> {
20    /// Select elements along an axis using an array of indices.
21    ///
22    /// This is advanced (fancy) indexing — it always returns a **copy**.
23    /// The result has the same number of dimensions as the input, but
24    /// the size along `axis` is replaced by `indices.len()`.
25    ///
26    /// Negative indices are supported (counting from end).
27    ///
28    /// # Errors
29    /// - `AxisOutOfBounds` if `axis >= ndim`
30    /// - `IndexOutOfBounds` if any index is out of range
31    pub fn index_select(&self, axis: Axis, indices: &[isize]) -> FerrayResult<Array<T, IxDyn>> {
32        let ndim = self.ndim();
33        let ax = axis.index();
34        if ax >= ndim {
35            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
36        }
37        let axis_size = self.shape()[ax];
38
39        // Normalize all indices
40        let normalized: Vec<usize> = indices
41            .iter()
42            .map(|&idx| normalize_index_adv(idx, axis_size, ax))
43            .collect::<FerrayResult<Vec<_>>>()?;
44
45        let dyn_view = self.inner.view().into_dyn();
46        let nd_axis = ndarray::Axis(ax);
47        let selected = dyn_view.select(nd_axis, &normalized);
48        Ok(Array::from_ndarray(selected))
49    }
50
51    /// Select elements using a boolean mask.
52    ///
53    /// Returns a 1-D array containing elements where `mask` is `true`.
54    /// This is always a **copy**.
55    ///
56    /// The mask must be broadcastable to the array's shape, or have the
57    /// same total number of elements. When the mask is 1-D and the array
58    /// is N-D, the mask is applied to the flattened array.
59    ///
60    /// # Errors
61    /// - `ShapeMismatch` if mask shape is incompatible
62    pub fn boolean_index(&self, mask: &Array<bool, D>) -> FerrayResult<Array<T, Ix1>> {
63        if self.shape() != mask.shape() {
64            return Err(FerrayError::shape_mismatch(format!(
65                "boolean index mask shape {:?} does not match array shape {:?}",
66                mask.shape(),
67                self.shape()
68            )));
69        }
70
71        let data: Vec<T> = self
72            .inner
73            .iter()
74            .zip(mask.inner.iter())
75            .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
76            .collect();
77
78        let len = data.len();
79        Array::from_vec(Ix1::new([len]), data)
80    }
81
82    /// Boolean indexing with a flat mask (1-D mask on N-D array).
83    ///
84    /// The mask is applied to the flattened (row-major) array.
85    ///
86    /// # Errors
87    /// - `ShapeMismatch` if `mask.len() != self.size()`
88    pub fn boolean_index_flat(&self, mask: &Array<bool, Ix1>) -> FerrayResult<Array<T, Ix1>> {
89        if mask.size() != self.size() {
90            return Err(FerrayError::shape_mismatch(format!(
91                "flat boolean mask length {} does not match array size {}",
92                mask.size(),
93                self.size()
94            )));
95        }
96
97        let data: Vec<T> = self
98            .inner
99            .iter()
100            .zip(mask.inner.iter())
101            .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
102            .collect();
103
104        let len = data.len();
105        Array::from_vec(Ix1::new([len]), data)
106    }
107
108    /// Assign a scalar value to elements selected by a boolean mask.
109    ///
110    /// Equivalent to `a[mask] = value` in NumPy.
111    ///
112    /// # Errors
113    /// - `ShapeMismatch` if mask shape differs from array shape
114    pub fn boolean_index_assign(&mut self, mask: &Array<bool, D>, value: T) -> FerrayResult<()> {
115        if self.shape() != mask.shape() {
116            return Err(FerrayError::shape_mismatch(format!(
117                "boolean index mask shape {:?} does not match array shape {:?}",
118                mask.shape(),
119                self.shape()
120            )));
121        }
122
123        for (elem, &m) in self.inner.iter_mut().zip(mask.inner.iter()) {
124            if m {
125                *elem = value.clone();
126            }
127        }
128        Ok(())
129    }
130
131    /// Assign values from an array to elements selected by a boolean mask.
132    ///
133    /// `values` must have exactly as many elements as `mask` has `true`
134    /// entries.
135    ///
136    /// # Errors
137    /// - `ShapeMismatch` if mask shape differs or values length mismatches
138    pub fn boolean_index_assign_array(
139        &mut self,
140        mask: &Array<bool, D>,
141        values: &Array<T, Ix1>,
142    ) -> FerrayResult<()> {
143        if self.shape() != mask.shape() {
144            return Err(FerrayError::shape_mismatch(format!(
145                "boolean index mask shape {:?} does not match array shape {:?}",
146                mask.shape(),
147                self.shape()
148            )));
149        }
150
151        let true_count = mask.inner.iter().filter(|&&m| m).count();
152        if values.size() != true_count {
153            return Err(FerrayError::shape_mismatch(format!(
154                "values array has {} elements but mask has {} true entries",
155                values.size(),
156                true_count
157            )));
158        }
159
160        let mut val_iter = values.inner.iter();
161        for (elem, &m) in self.inner.iter_mut().zip(mask.inner.iter()) {
162            if m {
163                if let Some(v) = val_iter.next() {
164                    *elem = v.clone();
165                }
166            }
167        }
168        Ok(())
169    }
170}
171
172// ---------------------------------------------------------------------------
173// ArrayView advanced indexing
174// ---------------------------------------------------------------------------
175
176impl<T: Element, D: Dimension> ArrayView<'_, T, D> {
177    /// Select elements along an axis using an array of indices (copy).
178    pub fn index_select(&self, axis: Axis, indices: &[isize]) -> FerrayResult<Array<T, IxDyn>> {
179        let ndim = self.ndim();
180        let ax = axis.index();
181        if ax >= ndim {
182            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
183        }
184        let axis_size = self.shape()[ax];
185
186        let normalized: Vec<usize> = indices
187            .iter()
188            .map(|&idx| normalize_index_adv(idx, axis_size, ax))
189            .collect::<FerrayResult<Vec<_>>>()?;
190
191        let dyn_view = self.inner.clone().into_dyn();
192        let nd_axis = ndarray::Axis(ax);
193        let selected = dyn_view.select(nd_axis, &normalized);
194        Ok(Array::from_ndarray(selected))
195    }
196
197    /// Select elements using a boolean mask (copy).
198    pub fn boolean_index(&self, mask: &Array<bool, D>) -> FerrayResult<Array<T, Ix1>> {
199        if self.shape() != mask.shape() {
200            return Err(FerrayError::shape_mismatch(format!(
201                "boolean index mask shape {:?} does not match view shape {:?}",
202                mask.shape(),
203                self.shape()
204            )));
205        }
206
207        let data: Vec<T> = self
208            .inner
209            .iter()
210            .zip(mask.inner.iter())
211            .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
212            .collect();
213
214        let len = data.len();
215        Array::from_vec(Ix1::new([len]), data)
216    }
217}
218
219/// Normalize a (potentially negative) index for advanced indexing.
220fn normalize_index_adv(index: isize, size: usize, axis: usize) -> FerrayResult<usize> {
221    if index < 0 {
222        let pos = size as isize + index;
223        if pos < 0 {
224            return Err(FerrayError::index_out_of_bounds(index, axis, size));
225        }
226        Ok(pos as usize)
227    } else {
228        let idx = index as usize;
229        if idx >= size {
230            return Err(FerrayError::index_out_of_bounds(index, axis, size));
231        }
232        Ok(idx)
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    /// Build a 1-D array that owns `data`; its length is taken from the vector.
    fn vector<T: Element>(data: Vec<T>) -> Array<T, Ix1> {
        let len = data.len();
        Array::from_vec(Ix1::new([len]), data).unwrap()
    }

    /// Build a `rows x cols` array from `data`.
    fn matrix<T: Element>(rows: usize, cols: usize, data: Vec<T>) -> Array<T, Ix2> {
        Array::from_vec(Ix2::new([rows, cols]), data).unwrap()
    }

    // -----------------------------------------------------------------------
    // index_select
    // -----------------------------------------------------------------------

    #[test]
    fn index_select_rows() {
        let source = matrix::<i32>(4, 3, (0..12).collect());
        let picked = source.index_select(Axis(0), &[0, 2, 3]).unwrap();
        assert_eq!(picked.shape(), &[3, 3]);
        let flat = picked.iter().copied().collect::<Vec<i32>>();
        assert_eq!(flat, vec![0, 1, 2, 6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn index_select_columns() {
        let source = matrix::<i32>(3, 4, (0..12).collect());
        let picked = source.index_select(Axis(1), &[0, 2]).unwrap();
        assert_eq!(picked.shape(), &[3, 2]);
        let flat = picked.iter().copied().collect::<Vec<i32>>();
        assert_eq!(flat, vec![0, 2, 4, 6, 8, 10]);
    }

    #[test]
    fn index_select_negative() {
        let source = vector::<i32>(vec![10, 20, 30, 40, 50]);
        let picked = source.index_select(Axis(0), &[-1, -3]).unwrap();
        assert_eq!(picked.shape(), &[2]);
        let flat = picked.iter().copied().collect::<Vec<i32>>();
        assert_eq!(flat, vec![50, 30]);
    }

    #[test]
    fn index_select_out_of_bounds() {
        let source = vector::<i32>(vec![1, 2, 3]);
        assert!(source.index_select(Axis(0), &[3]).is_err());
        assert!(source.index_select(Axis(0), &[-4]).is_err());
    }

    #[test]
    fn index_select_returns_copy() {
        let source = vector::<f64>(vec![1.0, 2.0, 3.0]);
        let picked = source.index_select(Axis(0), &[0, 1]).unwrap();
        // A copy must live in its own allocation.
        assert_ne!(picked.as_ptr() as usize, source.as_ptr() as usize);
    }

    #[test]
    fn index_select_duplicate_indices() {
        let source = vector::<i32>(vec![10, 20, 30]);
        let picked = source.index_select(Axis(0), &[1, 1, 0, 2, 2]).unwrap();
        assert_eq!(picked.shape(), &[5]);
        let flat = picked.iter().copied().collect::<Vec<i32>>();
        assert_eq!(flat, vec![20, 20, 10, 30, 30]);
    }

    #[test]
    fn index_select_empty() {
        let source = vector::<i32>(vec![1, 2, 3]);
        let picked = source.index_select(Axis(0), &[]).unwrap();
        assert_eq!(picked.shape(), &[0]);
    }

    // -----------------------------------------------------------------------
    // boolean_index
    // -----------------------------------------------------------------------

    #[test]
    fn boolean_index_1d() {
        let source = vector::<i32>(vec![10, 20, 30, 40, 50]);
        let mask = vector::<bool>(vec![true, false, true, false, true]);
        let kept = source.boolean_index(&mask).unwrap();
        assert_eq!(kept.shape(), &[3]);
        assert_eq!(kept.as_slice().unwrap(), &[10, 30, 50]);
    }

    #[test]
    fn boolean_index_2d() {
        let source = matrix::<i32>(2, 3, vec![1, 2, 3, 4, 5, 6]);
        let mask = matrix::<bool>(2, 3, vec![true, false, true, false, true, false]);
        let kept = source.boolean_index(&mask).unwrap();
        assert_eq!(kept.shape(), &[3]);
        assert_eq!(kept.as_slice().unwrap(), &[1, 3, 5]);
    }

    #[test]
    fn boolean_index_all_false() {
        let source = vector::<i32>(vec![1, 2, 3]);
        let mask = vector::<bool>(vec![false, false, false]);
        let kept = source.boolean_index(&mask).unwrap();
        assert_eq!(kept.shape(), &[0]);
    }

    #[test]
    fn boolean_index_shape_mismatch() {
        let source = vector::<i32>(vec![1, 2, 3]);
        let mask = vector::<bool>(vec![true, false]);
        assert!(source.boolean_index(&mask).is_err());
    }

    #[test]
    fn boolean_index_returns_copy() {
        let source = vector::<f64>(vec![1.0, 2.0, 3.0]);
        let mask = vector::<bool>(vec![true, true, true]);
        let kept = source.boolean_index(&mask).unwrap();
        assert_ne!(kept.as_ptr() as usize, source.as_ptr() as usize);
    }

    // -----------------------------------------------------------------------
    // boolean_index_flat
    // -----------------------------------------------------------------------

    #[test]
    fn boolean_index_flat_2d() {
        let source = matrix::<i32>(2, 3, vec![1, 2, 3, 4, 5, 6]);
        let mask = vector::<bool>(vec![false, true, false, true, false, true]);
        let kept = source.boolean_index_flat(&mask).unwrap();
        assert_eq!(kept.as_slice().unwrap(), &[2, 4, 6]);
    }

    #[test]
    fn boolean_index_flat_wrong_size() {
        let source = matrix::<i32>(2, 3, vec![1, 2, 3, 4, 5, 6]);
        let mask = vector::<bool>(vec![true, false, true, false]);
        assert!(source.boolean_index_flat(&mask).is_err());
    }

    // -----------------------------------------------------------------------
    // boolean_index_assign
    // -----------------------------------------------------------------------

    #[test]
    fn boolean_assign_scalar() {
        let mut target = vector::<i32>(vec![1, 2, 3, 4, 5]);
        let mask = vector::<bool>(vec![true, false, true, false, true]);
        target.boolean_index_assign(&mask, 0).unwrap();
        assert_eq!(target.as_slice().unwrap(), &[0, 2, 0, 4, 0]);
    }

    #[test]
    fn boolean_assign_array() {
        let mut target = vector::<i32>(vec![1, 2, 3, 4, 5]);
        let mask = vector::<bool>(vec![false, true, false, true, false]);
        let replacements = vector::<i32>(vec![99, 88]);
        target
            .boolean_index_assign_array(&mask, &replacements)
            .unwrap();
        assert_eq!(target.as_slice().unwrap(), &[1, 99, 3, 88, 5]);
    }

    #[test]
    fn boolean_assign_array_wrong_count() {
        let mut target = vector::<i32>(vec![1, 2, 3]);
        let mask = vector::<bool>(vec![true, true, false]);
        let replacements = vector::<i32>(vec![99]);
        assert!(target
            .boolean_index_assign_array(&mask, &replacements)
            .is_err());
    }

    #[test]
    fn boolean_assign_2d() {
        let mut target = matrix::<i32>(2, 3, vec![1, 2, 3, 4, 5, 6]);
        let mask = matrix::<bool>(2, 3, vec![false, true, false, false, true, false]);
        target.boolean_index_assign(&mask, -1).unwrap();
        let flat = target.iter().copied().collect::<Vec<i32>>();
        assert_eq!(flat, vec![1, -1, 3, 4, -1, 6]);
    }

    // -----------------------------------------------------------------------
    // ArrayView advanced indexing
    // -----------------------------------------------------------------------

    #[test]
    fn view_index_select() {
        let source = matrix::<i32>(3, 4, (0..12).collect());
        let window = source.view();
        let picked = window.index_select(Axis(1), &[0, 3]).unwrap();
        assert_eq!(picked.shape(), &[3, 2]);
        let flat = picked.iter().copied().collect::<Vec<i32>>();
        assert_eq!(flat, vec![0, 3, 4, 7, 8, 11]);
    }

    #[test]
    fn view_boolean_index() {
        let source = vector::<i32>(vec![10, 20, 30, 40]);
        let window = source.view();
        let mask = vector::<bool>(vec![true, false, false, true]);
        let kept = window.boolean_index(&mask).unwrap();
        assert_eq!(kept.as_slice().unwrap(), &[10, 40]);
    }
}