Skip to main content

ferray_core/indexing/
advanced.rs

1// ferray-core: Advanced (fancy) indexing (REQ-13, REQ-15)
2//
3// - index_select: integer-array indexing along an axis → copies
4// - boolean_index: boolean-mask indexing → copies (always 1-D)
5// - boolean_index_assign: masked assignment (a[mask] = value)
6//
7// All advanced indexing operations return COPIES, not views.
8
9use super::normalize_index;
10use crate::array::owned::Array;
11use crate::array::view::ArrayView;
12use crate::dimension::{Axis, Dimension, Ix1, IxDyn};
13use crate::dtype::Element;
14use crate::error::{FerrayError, FerrayResult};
15
16// ---------------------------------------------------------------------------
17// index_select
18// ---------------------------------------------------------------------------
19
20impl<T: Element, D: Dimension> Array<T, D> {
21    /// Select elements along an axis using an array of indices.
22    ///
23    /// This is advanced (fancy) indexing — it always returns a **copy**.
24    /// The result has the same number of dimensions as the input, but
25    /// the size along `axis` is replaced by `indices.len()`.
26    ///
27    /// Negative indices are supported (counting from end).
28    ///
29    /// # Errors
30    /// - `AxisOutOfBounds` if `axis >= ndim`
31    /// - `IndexOutOfBounds` if any index is out of range
32    pub fn index_select(&self, axis: Axis, indices: &[isize]) -> FerrayResult<Array<T, IxDyn>> {
33        let ndim = self.ndim();
34        let ax = axis.index();
35        if ax >= ndim {
36            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
37        }
38        let axis_size = self.shape()[ax];
39
40        // Normalize all indices
41        let normalized: Vec<usize> = indices
42            .iter()
43            .map(|&idx| normalize_index(idx, axis_size, ax))
44            .collect::<FerrayResult<Vec<_>>>()?;
45
46        let dyn_view = self.inner.view().into_dyn();
47        let nd_axis = ndarray::Axis(ax);
48        let selected = dyn_view.select(nd_axis, &normalized);
49        Ok(Array::from_ndarray(selected))
50    }
51
52    /// Select elements using a boolean mask.
53    ///
54    /// Returns a 1-D array containing elements where `mask` is `true`.
55    /// This is always a **copy**.
56    ///
57    /// The mask must be broadcastable to the array's shape, or have the
58    /// same total number of elements. When the mask is 1-D and the array
59    /// is N-D, the mask is applied to the flattened array.
60    ///
61    /// # Errors
62    /// - `ShapeMismatch` if mask shape is incompatible
63    pub fn boolean_index(&self, mask: &Array<bool, D>) -> FerrayResult<Array<T, Ix1>> {
64        if self.shape() != mask.shape() {
65            return Err(FerrayError::shape_mismatch(format!(
66                "boolean index mask shape {:?} does not match array shape {:?}",
67                mask.shape(),
68                self.shape()
69            )));
70        }
71
72        let data: Vec<T> = self
73            .inner
74            .iter()
75            .zip(mask.inner.iter())
76            .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
77            .collect();
78
79        let len = data.len();
80        Array::from_vec(Ix1::new([len]), data)
81    }
82
83    /// Boolean indexing with a flat mask (1-D mask on N-D array).
84    ///
85    /// The mask is applied to the flattened (row-major) array.
86    ///
87    /// # Errors
88    /// - `ShapeMismatch` if `mask.len() != self.size()`
89    pub fn boolean_index_flat(&self, mask: &Array<bool, Ix1>) -> FerrayResult<Array<T, Ix1>> {
90        if mask.size() != self.size() {
91            return Err(FerrayError::shape_mismatch(format!(
92                "flat boolean mask length {} does not match array size {}",
93                mask.size(),
94                self.size()
95            )));
96        }
97
98        let data: Vec<T> = self
99            .inner
100            .iter()
101            .zip(mask.inner.iter())
102            .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
103            .collect();
104
105        let len = data.len();
106        Array::from_vec(Ix1::new([len]), data)
107    }
108
109    /// Assign a scalar value to elements selected by a boolean mask.
110    ///
111    /// Equivalent to `a[mask] = value` in `NumPy`.
112    ///
113    /// # Errors
114    /// - `ShapeMismatch` if mask shape differs from array shape
115    pub fn boolean_index_assign(&mut self, mask: &Array<bool, D>, value: T) -> FerrayResult<()> {
116        if self.shape() != mask.shape() {
117            return Err(FerrayError::shape_mismatch(format!(
118                "boolean index mask shape {:?} does not match array shape {:?}",
119                mask.shape(),
120                self.shape()
121            )));
122        }
123
124        for (elem, &m) in self.inner.iter_mut().zip(mask.inner.iter()) {
125            if m {
126                *elem = value.clone();
127            }
128        }
129        Ok(())
130    }
131
132    /// Assign values from an array to elements selected by a boolean mask.
133    ///
134    /// `values` must have exactly as many elements as `mask` has `true`
135    /// entries.
136    ///
137    /// # Errors
138    /// - `ShapeMismatch` if mask shape differs or values length mismatches
139    pub fn boolean_index_assign_array(
140        &mut self,
141        mask: &Array<bool, D>,
142        values: &Array<T, Ix1>,
143    ) -> FerrayResult<()> {
144        if self.shape() != mask.shape() {
145            return Err(FerrayError::shape_mismatch(format!(
146                "boolean index mask shape {:?} does not match array shape {:?}",
147                mask.shape(),
148                self.shape()
149            )));
150        }
151
152        let true_count = mask.inner.iter().filter(|&&m| m).count();
153        if values.size() != true_count {
154            return Err(FerrayError::shape_mismatch(format!(
155                "values array has {} elements but mask has {} true entries",
156                values.size(),
157                true_count
158            )));
159        }
160
161        let mut val_iter = values.inner.iter();
162        for (elem, &m) in self.inner.iter_mut().zip(mask.inner.iter()) {
163            if m && let Some(v) = val_iter.next() {
164                *elem = v.clone();
165            }
166        }
167        Ok(())
168    }
169}
170
171// ---------------------------------------------------------------------------
172// ArrayView advanced indexing
173// ---------------------------------------------------------------------------
174
175impl<T: Element, D: Dimension> ArrayView<'_, T, D> {
176    /// Select elements along an axis using an array of indices (copy).
177    pub fn index_select(&self, axis: Axis, indices: &[isize]) -> FerrayResult<Array<T, IxDyn>> {
178        let ndim = self.ndim();
179        let ax = axis.index();
180        if ax >= ndim {
181            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
182        }
183        let axis_size = self.shape()[ax];
184
185        let normalized: Vec<usize> = indices
186            .iter()
187            .map(|&idx| normalize_index(idx, axis_size, ax))
188            .collect::<FerrayResult<Vec<_>>>()?;
189
190        let dyn_view = self.inner.clone().into_dyn();
191        let nd_axis = ndarray::Axis(ax);
192        let selected = dyn_view.select(nd_axis, &normalized);
193        Ok(Array::from_ndarray(selected))
194    }
195
196    /// Select elements using a boolean mask (copy).
197    pub fn boolean_index(&self, mask: &Array<bool, D>) -> FerrayResult<Array<T, Ix1>> {
198        if self.shape() != mask.shape() {
199            return Err(FerrayError::shape_mismatch(format!(
200                "boolean index mask shape {:?} does not match view shape {:?}",
201                mask.shape(),
202                self.shape()
203            )));
204        }
205
206        let data: Vec<T> = self
207            .inner
208            .iter()
209            .zip(mask.inner.iter())
210            .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
211            .collect();
212
213        let len = data.len();
214        Array::from_vec(Ix1::new([len]), data)
215    }
216}
217
#[cfg(test)]
mod tests {
    //! Unit tests for advanced indexing on `Array` and `ArrayView`:
    //! integer-array selection, boolean-mask selection, and masked assignment,
    //! including the error paths each method documents.

    use super::*;
    use crate::dimension::{Ix1, Ix2};

    // -----------------------------------------------------------------------
    // index_select
    // -----------------------------------------------------------------------

    #[test]
    fn index_select_rows() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([4, 3]), (0..12).collect()).unwrap();
        let sel = arr.index_select(Axis(0), &[0, 2, 3]).unwrap();
        assert_eq!(sel.shape(), &[3, 3]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![0, 1, 2, 6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn index_select_columns() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let sel = arr.index_select(Axis(1), &[0, 2]).unwrap();
        assert_eq!(sel.shape(), &[3, 2]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![0, 2, 4, 6, 8, 10]);
    }

    #[test]
    fn index_select_negative() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let sel = arr.index_select(Axis(0), &[-1, -3]).unwrap();
        assert_eq!(sel.shape(), &[2]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![50, 30]);
    }

    #[test]
    fn index_select_out_of_bounds() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert!(arr.index_select(Axis(0), &[3]).is_err());
        assert!(arr.index_select(Axis(0), &[-4]).is_err());
    }

    #[test]
    fn index_select_axis_out_of_bounds() {
        // A 1-D array only has axis 0; any larger axis must be rejected.
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert!(arr.index_select(Axis(1), &[0]).is_err());
    }

    #[test]
    fn index_select_returns_copy() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let sel = arr.index_select(Axis(0), &[0, 1]).unwrap();
        // Should be a different allocation
        assert_ne!(sel.as_ptr() as usize, arr.as_ptr() as usize);
    }

    #[test]
    fn index_select_duplicate_indices() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let sel = arr.index_select(Axis(0), &[1, 1, 0, 2, 2]).unwrap();
        assert_eq!(sel.shape(), &[5]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![20, 20, 10, 30, 30]);
    }

    #[test]
    fn index_select_empty() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let sel = arr.index_select(Axis(0), &[]).unwrap();
        assert_eq!(sel.shape(), &[0]);
    }

    // -----------------------------------------------------------------------
    // boolean_index
    // -----------------------------------------------------------------------

    #[test]
    fn boolean_index_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
                .unwrap();
        let selected = arr.boolean_index(&mask).unwrap();
        assert_eq!(selected.shape(), &[3]);
        assert_eq!(selected.as_slice().unwrap(), &[10, 30, 50]);
    }

    #[test]
    fn boolean_index_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let mask = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![true, false, true, false, true, false],
        )
        .unwrap();
        let selected = arr.boolean_index(&mask).unwrap();
        assert_eq!(selected.shape(), &[3]);
        assert_eq!(selected.as_slice().unwrap(), &[1, 3, 5]);
    }

    #[test]
    fn boolean_index_all_false() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
        let selected = arr.boolean_index(&mask).unwrap();
        assert_eq!(selected.shape(), &[0]);
    }

    #[test]
    fn boolean_index_shape_mismatch() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([2]), vec![true, false]).unwrap();
        assert!(arr.boolean_index(&mask).is_err());
    }

    #[test]
    fn boolean_index_returns_copy() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        let selected = arr.boolean_index(&mask).unwrap();
        assert_ne!(selected.as_ptr() as usize, arr.as_ptr() as usize);
    }

    // -----------------------------------------------------------------------
    // boolean_index_flat
    // -----------------------------------------------------------------------

    #[test]
    fn boolean_index_flat_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(
            Ix1::new([6]),
            vec![false, true, false, true, false, true],
        )
        .unwrap();
        let selected = arr.boolean_index_flat(&mask).unwrap();
        assert_eq!(selected.as_slice().unwrap(), &[2, 4, 6]);
    }

    #[test]
    fn boolean_index_flat_wrong_size() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        assert!(arr.boolean_index_flat(&mask).is_err());
    }

    // -----------------------------------------------------------------------
    // boolean_index_assign
    // -----------------------------------------------------------------------

    #[test]
    fn boolean_assign_scalar() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
                .unwrap();
        arr.boolean_index_assign(&mask, 0).unwrap();
        assert_eq!(arr.as_slice().unwrap(), &[0, 2, 0, 4, 0]);
    }

    #[test]
    fn boolean_assign_scalar_shape_mismatch() {
        // A mismatched mask must be rejected without touching the array.
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([2]), vec![true, true]).unwrap();
        assert!(arr.boolean_index_assign(&mask, 0).is_err());
        assert_eq!(arr.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn boolean_assign_array() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![false, true, false, true, false])
                .unwrap();
        let values = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![99, 88]).unwrap();
        arr.boolean_index_assign_array(&mask, &values).unwrap();
        assert_eq!(arr.as_slice().unwrap(), &[1, 99, 3, 88, 5]);
    }

    #[test]
    fn boolean_assign_array_wrong_count() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, false]).unwrap();
        let values = Array::<i32, Ix1>::from_vec(Ix1::new([1]), vec![99]).unwrap();
        assert!(arr.boolean_index_assign_array(&mask, &values).is_err());
        // The count check runs before any write, so nothing was modified.
        assert_eq!(arr.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn boolean_assign_array_shape_mismatch() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([2]), vec![true, false]).unwrap();
        let values = Array::<i32, Ix1>::from_vec(Ix1::new([1]), vec![99]).unwrap();
        assert!(arr.boolean_index_assign_array(&mask, &values).is_err());
        assert_eq!(arr.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn boolean_assign_2d() {
        let mut arr =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let mask = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![false, true, false, false, true, false],
        )
        .unwrap();
        arr.boolean_index_assign(&mask, -1).unwrap();
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![1, -1, 3, 4, -1, 6]);
    }

    // -----------------------------------------------------------------------
    // ArrayView advanced indexing
    // -----------------------------------------------------------------------

    #[test]
    fn view_index_select() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let v = arr.view();
        let sel = v.index_select(Axis(1), &[0, 3]).unwrap();
        assert_eq!(sel.shape(), &[3, 2]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![0, 3, 4, 7, 8, 11]);
    }

    #[test]
    fn view_index_select_negative() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let v = arr.view();
        let sel = v.index_select(Axis(0), &[-1, 0]).unwrap();
        assert_eq!(sel.shape(), &[2]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![40, 10]);
    }

    #[test]
    fn view_index_select_errors() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let v = arr.view();
        // Index past either end of the axis.
        assert!(v.index_select(Axis(0), &[4]).is_err());
        assert!(v.index_select(Axis(0), &[-5]).is_err());
        // Axis that does not exist on a 1-D view.
        assert!(v.index_select(Axis(1), &[0]).is_err());
    }

    #[test]
    fn view_boolean_index() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let v = arr.view();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, false, true]).unwrap();
        let selected = v.boolean_index(&mask).unwrap();
        assert_eq!(selected.as_slice().unwrap(), &[10, 40]);
    }

    #[test]
    fn view_boolean_index_shape_mismatch() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let v = arr.view();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, false, true]).unwrap();
        assert!(v.boolean_index(&mask).is_err());
    }
}