1use super::normalize_index;
10use crate::array::owned::Array;
11use crate::array::view::ArrayView;
12use crate::dimension::{Axis, Dimension, Ix1, IxDyn};
13use crate::dtype::Element;
14use crate::error::{FerrayError, FerrayResult};
15
16impl<T: Element, D: Dimension> Array<T, D> {
21 pub fn index_select(&self, axis: Axis, indices: &[isize]) -> FerrayResult<Array<T, IxDyn>> {
33 let ndim = self.ndim();
34 let ax = axis.index();
35 if ax >= ndim {
36 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
37 }
38 let axis_size = self.shape()[ax];
39
40 let normalized: Vec<usize> = indices
42 .iter()
43 .map(|&idx| normalize_index(idx, axis_size, ax))
44 .collect::<FerrayResult<Vec<_>>>()?;
45
46 let dyn_view = self.inner.view().into_dyn();
47 let nd_axis = ndarray::Axis(ax);
48 let selected = dyn_view.select(nd_axis, &normalized);
49 Ok(Array::from_ndarray(selected))
50 }
51
52 pub fn boolean_index(&self, mask: &Array<bool, D>) -> FerrayResult<Array<T, Ix1>> {
64 if self.shape() != mask.shape() {
65 return Err(FerrayError::shape_mismatch(format!(
66 "boolean index mask shape {:?} does not match array shape {:?}",
67 mask.shape(),
68 self.shape()
69 )));
70 }
71
72 let data: Vec<T> = self
73 .inner
74 .iter()
75 .zip(mask.inner.iter())
76 .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
77 .collect();
78
79 let len = data.len();
80 Array::from_vec(Ix1::new([len]), data)
81 }
82
83 pub fn boolean_index_flat(&self, mask: &Array<bool, Ix1>) -> FerrayResult<Array<T, Ix1>> {
90 if mask.size() != self.size() {
91 return Err(FerrayError::shape_mismatch(format!(
92 "flat boolean mask length {} does not match array size {}",
93 mask.size(),
94 self.size()
95 )));
96 }
97
98 let data: Vec<T> = self
99 .inner
100 .iter()
101 .zip(mask.inner.iter())
102 .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
103 .collect();
104
105 let len = data.len();
106 Array::from_vec(Ix1::new([len]), data)
107 }
108
109 pub fn boolean_index_assign(&mut self, mask: &Array<bool, D>, value: T) -> FerrayResult<()> {
116 if self.shape() != mask.shape() {
117 return Err(FerrayError::shape_mismatch(format!(
118 "boolean index mask shape {:?} does not match array shape {:?}",
119 mask.shape(),
120 self.shape()
121 )));
122 }
123
124 for (elem, &m) in self.inner.iter_mut().zip(mask.inner.iter()) {
125 if m {
126 *elem = value.clone();
127 }
128 }
129 Ok(())
130 }
131
132 pub fn boolean_index_assign_array(
140 &mut self,
141 mask: &Array<bool, D>,
142 values: &Array<T, Ix1>,
143 ) -> FerrayResult<()> {
144 if self.shape() != mask.shape() {
145 return Err(FerrayError::shape_mismatch(format!(
146 "boolean index mask shape {:?} does not match array shape {:?}",
147 mask.shape(),
148 self.shape()
149 )));
150 }
151
152 let true_count = mask.inner.iter().filter(|&&m| m).count();
153 if values.size() != true_count {
154 return Err(FerrayError::shape_mismatch(format!(
155 "values array has {} elements but mask has {} true entries",
156 values.size(),
157 true_count
158 )));
159 }
160
161 let mut val_iter = values.inner.iter();
162 for (elem, &m) in self.inner.iter_mut().zip(mask.inner.iter()) {
163 if m && let Some(v) = val_iter.next() {
164 *elem = v.clone();
165 }
166 }
167 Ok(())
168 }
169}
170
171impl<T: Element, D: Dimension> ArrayView<'_, T, D> {
176 pub fn index_select(&self, axis: Axis, indices: &[isize]) -> FerrayResult<Array<T, IxDyn>> {
178 let ndim = self.ndim();
179 let ax = axis.index();
180 if ax >= ndim {
181 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
182 }
183 let axis_size = self.shape()[ax];
184
185 let normalized: Vec<usize> = indices
186 .iter()
187 .map(|&idx| normalize_index(idx, axis_size, ax))
188 .collect::<FerrayResult<Vec<_>>>()?;
189
190 let dyn_view = self.inner.clone().into_dyn();
191 let nd_axis = ndarray::Axis(ax);
192 let selected = dyn_view.select(nd_axis, &normalized);
193 Ok(Array::from_ndarray(selected))
194 }
195
196 pub fn boolean_index(&self, mask: &Array<bool, D>) -> FerrayResult<Array<T, Ix1>> {
198 if self.shape() != mask.shape() {
199 return Err(FerrayError::shape_mismatch(format!(
200 "boolean index mask shape {:?} does not match view shape {:?}",
201 mask.shape(),
202 self.shape()
203 )));
204 }
205
206 let data: Vec<T> = self
207 .inner
208 .iter()
209 .zip(mask.inner.iter())
210 .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
211 .collect();
212
213 let len = data.len();
214 Array::from_vec(Ix1::new([len]), data)
215 }
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    // ---- index_select ----

    #[test]
    fn index_select_rows() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([4, 3]), (0..12).collect()).unwrap();
        let sel = arr.index_select(Axis(0), &[0, 2, 3]).unwrap();
        assert_eq!(sel.shape(), &[3, 3]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![0, 1, 2, 6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn index_select_columns() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let sel = arr.index_select(Axis(1), &[0, 2]).unwrap();
        assert_eq!(sel.shape(), &[3, 2]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![0, 2, 4, 6, 8, 10]);
    }

    #[test]
    fn index_select_negative() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let sel = arr.index_select(Axis(0), &[-1, -3]).unwrap();
        assert_eq!(sel.shape(), &[2]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![50, 30]);
    }

    #[test]
    fn index_select_out_of_bounds() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert!(arr.index_select(Axis(0), &[3]).is_err());
        assert!(arr.index_select(Axis(0), &[-4]).is_err());
    }

    #[test]
    fn index_select_axis_out_of_bounds() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        assert!(arr.index_select(Axis(2), &[0]).is_err());
    }

    #[test]
    fn index_select_returns_copy() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let sel = arr.index_select(Axis(0), &[0, 1]).unwrap();
        assert_ne!(sel.as_ptr() as usize, arr.as_ptr() as usize);
    }

    #[test]
    fn index_select_duplicate_indices() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![10, 20, 30]).unwrap();
        let sel = arr.index_select(Axis(0), &[1, 1, 0, 2, 2]).unwrap();
        assert_eq!(sel.shape(), &[5]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![20, 20, 10, 30, 30]);
    }

    #[test]
    fn index_select_empty() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let sel = arr.index_select(Axis(0), &[]).unwrap();
        assert_eq!(sel.shape(), &[0]);
    }

    // ---- boolean_index ----

    #[test]
    fn boolean_index_1d() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
                .unwrap();
        let selected = arr.boolean_index(&mask).unwrap();
        assert_eq!(selected.shape(), &[3]);
        assert_eq!(selected.as_slice().unwrap(), &[10, 30, 50]);
    }

    #[test]
    fn boolean_index_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let mask = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![true, false, true, false, true, false],
        )
        .unwrap();
        let selected = arr.boolean_index(&mask).unwrap();
        assert_eq!(selected.shape(), &[3]);
        assert_eq!(selected.as_slice().unwrap(), &[1, 3, 5]);
    }

    #[test]
    fn boolean_index_all_false() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![false, false, false]).unwrap();
        let selected = arr.boolean_index(&mask).unwrap();
        assert_eq!(selected.shape(), &[0]);
    }

    #[test]
    fn boolean_index_shape_mismatch() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([2]), vec![true, false]).unwrap();
        assert!(arr.boolean_index(&mask).is_err());
    }

    #[test]
    fn boolean_index_returns_copy() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, true]).unwrap();
        let selected = arr.boolean_index(&mask).unwrap();
        assert_ne!(selected.as_ptr() as usize, arr.as_ptr() as usize);
    }

    // ---- boolean_index_flat ----

    #[test]
    fn boolean_index_flat_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(
            Ix1::new([6]),
            vec![false, true, false, true, false, true],
        )
        .unwrap();
        let selected = arr.boolean_index_flat(&mask).unwrap();
        assert_eq!(selected.as_slice().unwrap(), &[2, 4, 6]);
    }

    #[test]
    fn boolean_index_flat_wrong_size() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, true, false]).unwrap();
        assert!(arr.boolean_index_flat(&mask).is_err());
    }

    // ---- boolean assignment ----

    #[test]
    fn boolean_assign_scalar() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![true, false, true, false, true])
                .unwrap();
        arr.boolean_index_assign(&mask, 0).unwrap();
        assert_eq!(arr.as_slice().unwrap(), &[0, 2, 0, 4, 0]);
    }

    #[test]
    fn boolean_assign_shape_mismatch() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([2]), vec![true, false]).unwrap();
        assert!(arr.boolean_index_assign(&mask, 0).is_err());
        // A rejected call must leave the array untouched.
        assert_eq!(arr.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn boolean_assign_array() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([5]), vec![false, true, false, true, false])
                .unwrap();
        let values = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![99, 88]).unwrap();
        arr.boolean_index_assign_array(&mask, &values).unwrap();
        assert_eq!(arr.as_slice().unwrap(), &[1, 99, 3, 88, 5]);
    }

    #[test]
    fn boolean_assign_array_wrong_count() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([3]), vec![true, true, false]).unwrap();
        let values = Array::<i32, Ix1>::from_vec(Ix1::new([1]), vec![99]).unwrap();
        assert!(arr.boolean_index_assign_array(&mask, &values).is_err());
    }

    #[test]
    fn boolean_assign_array_shape_mismatch() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([2]), vec![true, true]).unwrap();
        let values = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![7, 8]).unwrap();
        assert!(arr.boolean_index_assign_array(&mask, &values).is_err());
        // A rejected call must leave the array untouched.
        assert_eq!(arr.as_slice().unwrap(), &[1, 2, 3]);
    }

    #[test]
    fn boolean_assign_2d() {
        let mut arr =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let mask = Array::<bool, Ix2>::from_vec(
            Ix2::new([2, 3]),
            vec![false, true, false, false, true, false],
        )
        .unwrap();
        arr.boolean_index_assign(&mask, -1).unwrap();
        let data: Vec<i32> = arr.iter().copied().collect();
        assert_eq!(data, vec![1, -1, 3, 4, -1, 6]);
    }

    // ---- views ----

    #[test]
    fn view_index_select() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let v = arr.view();
        let sel = v.index_select(Axis(1), &[0, 3]).unwrap();
        assert_eq!(sel.shape(), &[3, 2]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![0, 3, 4, 7, 8, 11]);
    }

    #[test]
    fn view_index_select_negative() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let v = arr.view();
        let sel = v.index_select(Axis(0), &[-1]).unwrap();
        assert_eq!(sel.shape(), &[1, 4]);
        let data: Vec<i32> = sel.iter().copied().collect();
        assert_eq!(data, vec![8, 9, 10, 11]);
    }

    #[test]
    fn view_boolean_index() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), vec![10, 20, 30, 40]).unwrap();
        let v = arr.view();
        let mask =
            Array::<bool, Ix1>::from_vec(Ix1::new([4]), vec![true, false, false, true]).unwrap();
        let selected = v.boolean_index(&mask).unwrap();
        assert_eq!(selected.as_slice().unwrap(), &[10, 40]);
    }

    #[test]
    fn view_boolean_index_shape_mismatch() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let v = arr.view();
        let mask = Array::<bool, Ix1>::from_vec(Ix1::new([2]), vec![true, false]).unwrap();
        assert!(v.boolean_index(&mask).is_err());
    }
}