1use super::normalize_index;
10use crate::array::owned::Array;
11use crate::array::view::ArrayView;
12use crate::dimension::{Axis, Dimension, Ix1, IxDyn};
13use crate::dtype::Element;
14use crate::error::{FerrayError, FerrayResult};
15
16impl<T: Element, D: Dimension> Array<T, D> {
21 pub fn index_select(&self, axis: Axis, indices: &[isize]) -> FerrayResult<Array<T, IxDyn>> {
33 let ndim = self.ndim();
34 let ax = axis.index();
35 if ax >= ndim {
36 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
37 }
38 let axis_size = self.shape()[ax];
39
40 let normalized: Vec<usize> = indices
42 .iter()
43 .map(|&idx| normalize_index(idx, axis_size, ax))
44 .collect::<FerrayResult<Vec<_>>>()?;
45
46 let dyn_view = self.inner.view().into_dyn();
47 let nd_axis = ndarray::Axis(ax);
48 let selected = dyn_view.select(nd_axis, &normalized);
49 Ok(Array::from_ndarray(selected))
50 }
51
52 pub fn boolean_index(&self, mask: &Array<bool, D>) -> FerrayResult<Array<T, Ix1>> {
64 if self.shape() != mask.shape() {
65 return Err(FerrayError::shape_mismatch(format!(
66 "boolean index mask shape {:?} does not match array shape {:?}",
67 mask.shape(),
68 self.shape()
69 )));
70 }
71
72 let data: Vec<T> = self
73 .inner
74 .iter()
75 .zip(mask.inner.iter())
76 .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
77 .collect();
78
79 let len = data.len();
80 Array::from_vec(Ix1::new([len]), data)
81 }
82
83 pub fn boolean_index_flat(&self, mask: &Array<bool, Ix1>) -> FerrayResult<Array<T, Ix1>> {
90 if mask.size() != self.size() {
91 return Err(FerrayError::shape_mismatch(format!(
92 "flat boolean mask length {} does not match array size {}",
93 mask.size(),
94 self.size()
95 )));
96 }
97
98 let data: Vec<T> = self
99 .inner
100 .iter()
101 .zip(mask.inner.iter())
102 .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
103 .collect();
104
105 let len = data.len();
106 Array::from_vec(Ix1::new([len]), data)
107 }
108
109 pub fn boolean_index_assign(&mut self, mask: &Array<bool, D>, value: T) -> FerrayResult<()> {
116 if self.shape() != mask.shape() {
117 return Err(FerrayError::shape_mismatch(format!(
118 "boolean index mask shape {:?} does not match array shape {:?}",
119 mask.shape(),
120 self.shape()
121 )));
122 }
123
124 for (elem, &m) in self.inner.iter_mut().zip(mask.inner.iter()) {
125 if m {
126 *elem = value.clone();
127 }
128 }
129 Ok(())
130 }
131
132 pub fn boolean_index_assign_array(
140 &mut self,
141 mask: &Array<bool, D>,
142 values: &Array<T, Ix1>,
143 ) -> FerrayResult<()> {
144 if self.shape() != mask.shape() {
145 return Err(FerrayError::shape_mismatch(format!(
146 "boolean index mask shape {:?} does not match array shape {:?}",
147 mask.shape(),
148 self.shape()
149 )));
150 }
151
152 let true_count = mask.inner.iter().filter(|&&m| m).count();
153 if values.size() != true_count {
154 return Err(FerrayError::shape_mismatch(format!(
155 "values array has {} elements but mask has {} true entries",
156 values.size(),
157 true_count
158 )));
159 }
160
161 let mut val_iter = values.inner.iter();
162 for (elem, &m) in self.inner.iter_mut().zip(mask.inner.iter()) {
163 if m {
164 if let Some(v) = val_iter.next() {
165 *elem = v.clone();
166 }
167 }
168 }
169 Ok(())
170 }
171}
172
173impl<T: Element, D: Dimension> ArrayView<'_, T, D> {
178 pub fn index_select(&self, axis: Axis, indices: &[isize]) -> FerrayResult<Array<T, IxDyn>> {
180 let ndim = self.ndim();
181 let ax = axis.index();
182 if ax >= ndim {
183 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
184 }
185 let axis_size = self.shape()[ax];
186
187 let normalized: Vec<usize> = indices
188 .iter()
189 .map(|&idx| normalize_index(idx, axis_size, ax))
190 .collect::<FerrayResult<Vec<_>>>()?;
191
192 let dyn_view = self.inner.clone().into_dyn();
193 let nd_axis = ndarray::Axis(ax);
194 let selected = dyn_view.select(nd_axis, &normalized);
195 Ok(Array::from_ndarray(selected))
196 }
197
198 pub fn boolean_index(&self, mask: &Array<bool, D>) -> FerrayResult<Array<T, Ix1>> {
200 if self.shape() != mask.shape() {
201 return Err(FerrayError::shape_mismatch(format!(
202 "boolean index mask shape {:?} does not match view shape {:?}",
203 mask.shape(),
204 self.shape()
205 )));
206 }
207
208 let data: Vec<T> = self
209 .inner
210 .iter()
211 .zip(mask.inner.iter())
212 .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
213 .collect();
214
215 let len = data.len();
216 Array::from_vec(Ix1::new([len]), data)
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    // ----- fixture helpers -------------------------------------------------

    /// Builds a 1-D `i32` array holding `data`.
    fn ints_1d(data: Vec<i32>) -> Array<i32, Ix1> {
        let len = data.len();
        Array::<i32, Ix1>::from_vec(Ix1::new([len]), data).unwrap()
    }

    /// Builds a `rows` x `cols` `i32` array from `data`.
    fn ints_2d(rows: usize, cols: usize, data: Vec<i32>) -> Array<i32, Ix2> {
        Array::<i32, Ix2>::from_vec(Ix2::new([rows, cols]), data).unwrap()
    }

    /// Builds a 1-D `f64` array holding `data`.
    fn floats_1d(data: Vec<f64>) -> Array<f64, Ix1> {
        let len = data.len();
        Array::<f64, Ix1>::from_vec(Ix1::new([len]), data).unwrap()
    }

    /// Builds a 1-D boolean mask from `bits`.
    fn mask_1d(bits: &[bool]) -> Array<bool, Ix1> {
        Array::<bool, Ix1>::from_vec(Ix1::new([bits.len()]), bits.to_vec()).unwrap()
    }

    /// Builds a `rows` x `cols` boolean mask from `bits`.
    fn mask_2d(rows: usize, cols: usize, bits: &[bool]) -> Array<bool, Ix2> {
        Array::<bool, Ix2>::from_vec(Ix2::new([rows, cols]), bits.to_vec()).unwrap()
    }

    // ----- index_select ----------------------------------------------------

    #[test]
    fn index_select_rows() {
        let source = ints_2d(4, 3, (0..12).collect());
        let picked = source.index_select(Axis(0), &[0, 2, 3]).unwrap();
        assert_eq!(picked.shape(), &[3, 3]);
        assert_eq!(
            picked.iter().copied().collect::<Vec<i32>>(),
            vec![0, 1, 2, 6, 7, 8, 9, 10, 11]
        );
    }

    #[test]
    fn index_select_columns() {
        let source = ints_2d(3, 4, (0..12).collect());
        let picked = source.index_select(Axis(1), &[0, 2]).unwrap();
        assert_eq!(picked.shape(), &[3, 2]);
        assert_eq!(
            picked.iter().copied().collect::<Vec<i32>>(),
            vec![0, 2, 4, 6, 8, 10]
        );
    }

    #[test]
    fn index_select_negative() {
        let source = ints_1d(vec![10, 20, 30, 40, 50]);
        let picked = source.index_select(Axis(0), &[-1, -3]).unwrap();
        assert_eq!(picked.shape(), &[2]);
        assert_eq!(picked.iter().copied().collect::<Vec<i32>>(), vec![50, 30]);
    }

    #[test]
    fn index_select_out_of_bounds() {
        let source = ints_1d(vec![1, 2, 3]);
        // One past the end, and one before the most negative valid index.
        let past_end = source.index_select(Axis(0), &[3]);
        assert!(past_end.is_err());
        let before_start = source.index_select(Axis(0), &[-4]);
        assert!(before_start.is_err());
    }

    #[test]
    fn index_select_returns_copy() {
        let source = floats_1d(vec![1.0, 2.0, 3.0]);
        let picked = source.index_select(Axis(0), &[0, 1]).unwrap();
        // A fresh allocation must back the result, not the source buffer.
        assert_ne!(picked.as_ptr() as usize, source.as_ptr() as usize);
    }

    #[test]
    fn index_select_duplicate_indices() {
        let source = ints_1d(vec![10, 20, 30]);
        let picked = source.index_select(Axis(0), &[1, 1, 0, 2, 2]).unwrap();
        assert_eq!(picked.shape(), &[5]);
        assert_eq!(
            picked.iter().copied().collect::<Vec<i32>>(),
            vec![20, 20, 10, 30, 30]
        );
    }

    #[test]
    fn index_select_empty() {
        let source = ints_1d(vec![1, 2, 3]);
        let picked = source.index_select(Axis(0), &[]).unwrap();
        assert_eq!(picked.shape(), &[0]);
    }

    // ----- boolean_index ---------------------------------------------------

    #[test]
    fn boolean_index_1d() {
        let source = ints_1d(vec![10, 20, 30, 40, 50]);
        let mask = mask_1d(&[true, false, true, false, true]);
        let picked = source.boolean_index(&mask).unwrap();
        assert_eq!(picked.shape(), &[3]);
        assert_eq!(picked.as_slice().unwrap(), &[10, 30, 50]);
    }

    #[test]
    fn boolean_index_2d() {
        let source = ints_2d(2, 3, vec![1, 2, 3, 4, 5, 6]);
        let mask = mask_2d(2, 3, &[true, false, true, false, true, false]);
        let picked = source.boolean_index(&mask).unwrap();
        assert_eq!(picked.shape(), &[3]);
        assert_eq!(picked.as_slice().unwrap(), &[1, 3, 5]);
    }

    #[test]
    fn boolean_index_all_false() {
        let source = ints_1d(vec![1, 2, 3]);
        let mask = mask_1d(&[false, false, false]);
        let picked = source.boolean_index(&mask).unwrap();
        assert_eq!(picked.shape(), &[0]);
    }

    #[test]
    fn boolean_index_shape_mismatch() {
        let source = ints_1d(vec![1, 2, 3]);
        let short_mask = mask_1d(&[true, false]);
        assert!(source.boolean_index(&short_mask).is_err());
    }

    #[test]
    fn boolean_index_returns_copy() {
        let source = floats_1d(vec![1.0, 2.0, 3.0]);
        let mask = mask_1d(&[true, true, true]);
        let picked = source.boolean_index(&mask).unwrap();
        assert_ne!(picked.as_ptr() as usize, source.as_ptr() as usize);
    }

    // ----- boolean_index_flat ----------------------------------------------

    #[test]
    fn boolean_index_flat_2d() {
        let source = ints_2d(2, 3, vec![1, 2, 3, 4, 5, 6]);
        let mask = mask_1d(&[false, true, false, true, false, true]);
        let picked = source.boolean_index_flat(&mask).unwrap();
        assert_eq!(picked.as_slice().unwrap(), &[2, 4, 6]);
    }

    #[test]
    fn boolean_index_flat_wrong_size() {
        let source = ints_2d(2, 3, vec![1, 2, 3, 4, 5, 6]);
        let short_mask = mask_1d(&[true, false, true, false]);
        assert!(source.boolean_index_flat(&short_mask).is_err());
    }

    // ----- boolean assignment ----------------------------------------------

    #[test]
    fn boolean_assign_scalar() {
        let mut target = ints_1d(vec![1, 2, 3, 4, 5]);
        let mask = mask_1d(&[true, false, true, false, true]);
        target.boolean_index_assign(&mask, 0).unwrap();
        assert_eq!(target.as_slice().unwrap(), &[0, 2, 0, 4, 0]);
    }

    #[test]
    fn boolean_assign_array() {
        let mut target = ints_1d(vec![1, 2, 3, 4, 5]);
        let mask = mask_1d(&[false, true, false, true, false]);
        let replacements = ints_1d(vec![99, 88]);
        target
            .boolean_index_assign_array(&mask, &replacements)
            .unwrap();
        assert_eq!(target.as_slice().unwrap(), &[1, 99, 3, 88, 5]);
    }

    #[test]
    fn boolean_assign_array_wrong_count() {
        let mut target = ints_1d(vec![1, 2, 3]);
        let mask = mask_1d(&[true, true, false]);
        // Two positions are selected but only one replacement is supplied.
        let replacements = ints_1d(vec![99]);
        assert!(target
            .boolean_index_assign_array(&mask, &replacements)
            .is_err());
    }

    #[test]
    fn boolean_assign_2d() {
        let mut target = ints_2d(2, 3, vec![1, 2, 3, 4, 5, 6]);
        let mask = mask_2d(2, 3, &[false, true, false, false, true, false]);
        target.boolean_index_assign(&mask, -1).unwrap();
        assert_eq!(
            target.iter().copied().collect::<Vec<i32>>(),
            vec![1, -1, 3, 4, -1, 6]
        );
    }

    // ----- ArrayView -------------------------------------------------------

    #[test]
    fn view_index_select() {
        let source = ints_2d(3, 4, (0..12).collect());
        let window = source.view();
        let picked = window.index_select(Axis(1), &[0, 3]).unwrap();
        assert_eq!(picked.shape(), &[3, 2]);
        assert_eq!(
            picked.iter().copied().collect::<Vec<i32>>(),
            vec![0, 3, 4, 7, 8, 11]
        );
    }

    #[test]
    fn view_boolean_index() {
        let source = ints_1d(vec![10, 20, 30, 40]);
        let window = source.view();
        let mask = mask_1d(&[true, false, false, true]);
        let picked = window.boolean_index(&mask).unwrap();
        assert_eq!(picked.as_slice().unwrap(), &[10, 40]);
    }
}