1use crate::array::owned::Array;
10use crate::array::view::ArrayView;
11use crate::dimension::{Axis, Dimension, Ix1, IxDyn};
12use crate::dtype::Element;
13use crate::error::{FerrayError, FerrayResult};
14
15impl<T: Element, D: Dimension> Array<T, D> {
20 pub fn index_select(&self, axis: Axis, indices: &[isize]) -> FerrayResult<Array<T, IxDyn>> {
32 let ndim = self.ndim();
33 let ax = axis.index();
34 if ax >= ndim {
35 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
36 }
37 let axis_size = self.shape()[ax];
38
39 let normalized: Vec<usize> = indices
41 .iter()
42 .map(|&idx| normalize_index_adv(idx, axis_size, ax))
43 .collect::<FerrayResult<Vec<_>>>()?;
44
45 let dyn_view = self.inner.view().into_dyn();
46 let nd_axis = ndarray::Axis(ax);
47 let selected = dyn_view.select(nd_axis, &normalized);
48 Ok(Array::from_ndarray(selected))
49 }
50
51 pub fn boolean_index(&self, mask: &Array<bool, D>) -> FerrayResult<Array<T, Ix1>> {
63 if self.shape() != mask.shape() {
64 return Err(FerrayError::shape_mismatch(format!(
65 "boolean index mask shape {:?} does not match array shape {:?}",
66 mask.shape(),
67 self.shape()
68 )));
69 }
70
71 let data: Vec<T> = self
72 .inner
73 .iter()
74 .zip(mask.inner.iter())
75 .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
76 .collect();
77
78 let len = data.len();
79 Array::from_vec(Ix1::new([len]), data)
80 }
81
82 pub fn boolean_index_flat(&self, mask: &Array<bool, Ix1>) -> FerrayResult<Array<T, Ix1>> {
89 if mask.size() != self.size() {
90 return Err(FerrayError::shape_mismatch(format!(
91 "flat boolean mask length {} does not match array size {}",
92 mask.size(),
93 self.size()
94 )));
95 }
96
97 let data: Vec<T> = self
98 .inner
99 .iter()
100 .zip(mask.inner.iter())
101 .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
102 .collect();
103
104 let len = data.len();
105 Array::from_vec(Ix1::new([len]), data)
106 }
107
108 pub fn boolean_index_assign(&mut self, mask: &Array<bool, D>, value: T) -> FerrayResult<()> {
115 if self.shape() != mask.shape() {
116 return Err(FerrayError::shape_mismatch(format!(
117 "boolean index mask shape {:?} does not match array shape {:?}",
118 mask.shape(),
119 self.shape()
120 )));
121 }
122
123 for (elem, &m) in self.inner.iter_mut().zip(mask.inner.iter()) {
124 if m {
125 *elem = value.clone();
126 }
127 }
128 Ok(())
129 }
130
131 pub fn boolean_index_assign_array(
139 &mut self,
140 mask: &Array<bool, D>,
141 values: &Array<T, Ix1>,
142 ) -> FerrayResult<()> {
143 if self.shape() != mask.shape() {
144 return Err(FerrayError::shape_mismatch(format!(
145 "boolean index mask shape {:?} does not match array shape {:?}",
146 mask.shape(),
147 self.shape()
148 )));
149 }
150
151 let true_count = mask.inner.iter().filter(|&&m| m).count();
152 if values.size() != true_count {
153 return Err(FerrayError::shape_mismatch(format!(
154 "values array has {} elements but mask has {} true entries",
155 values.size(),
156 true_count
157 )));
158 }
159
160 let mut val_iter = values.inner.iter();
161 for (elem, &m) in self.inner.iter_mut().zip(mask.inner.iter()) {
162 if m {
163 if let Some(v) = val_iter.next() {
164 *elem = v.clone();
165 }
166 }
167 }
168 Ok(())
169 }
170}
171
172impl<T: Element, D: Dimension> ArrayView<'_, T, D> {
177 pub fn index_select(&self, axis: Axis, indices: &[isize]) -> FerrayResult<Array<T, IxDyn>> {
179 let ndim = self.ndim();
180 let ax = axis.index();
181 if ax >= ndim {
182 return Err(FerrayError::axis_out_of_bounds(ax, ndim));
183 }
184 let axis_size = self.shape()[ax];
185
186 let normalized: Vec<usize> = indices
187 .iter()
188 .map(|&idx| normalize_index_adv(idx, axis_size, ax))
189 .collect::<FerrayResult<Vec<_>>>()?;
190
191 let dyn_view = self.inner.clone().into_dyn();
192 let nd_axis = ndarray::Axis(ax);
193 let selected = dyn_view.select(nd_axis, &normalized);
194 Ok(Array::from_ndarray(selected))
195 }
196
197 pub fn boolean_index(&self, mask: &Array<bool, D>) -> FerrayResult<Array<T, Ix1>> {
199 if self.shape() != mask.shape() {
200 return Err(FerrayError::shape_mismatch(format!(
201 "boolean index mask shape {:?} does not match view shape {:?}",
202 mask.shape(),
203 self.shape()
204 )));
205 }
206
207 let data: Vec<T> = self
208 .inner
209 .iter()
210 .zip(mask.inner.iter())
211 .filter_map(|(val, &m)| if m { Some(val.clone()) } else { None })
212 .collect();
213
214 let len = data.len();
215 Array::from_vec(Ix1::new([len]), data)
216 }
217}
218
219fn normalize_index_adv(index: isize, size: usize, axis: usize) -> FerrayResult<usize> {
221 if index < 0 {
222 let pos = size as isize + index;
223 if pos < 0 {
224 return Err(FerrayError::index_out_of_bounds(index, axis, size));
225 }
226 Ok(pos as usize)
227 } else {
228 let idx = index as usize;
229 if idx >= size {
230 return Err(FerrayError::index_out_of_bounds(index, axis, size));
231 }
232 Ok(idx)
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2};

    /// Builds a 1-D `i32` array holding `values`.
    fn ints(values: Vec<i32>) -> Array<i32, Ix1> {
        let len = values.len();
        Array::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// Builds a 1-D boolean mask holding `bits`.
    fn flags(bits: Vec<bool>) -> Array<bool, Ix1> {
        let len = bits.len();
        Array::from_vec(Ix1::new([len]), bits).unwrap()
    }

    /// Builds a `rows` x `cols` `i32` array from `values`.
    fn int_grid(rows: usize, cols: usize, values: Vec<i32>) -> Array<i32, Ix2> {
        Array::from_vec(Ix2::new([rows, cols]), values).unwrap()
    }

    /// Builds a `rows` x `cols` boolean mask from `bits`.
    fn flag_grid(rows: usize, cols: usize, bits: Vec<bool>) -> Array<bool, Ix2> {
        Array::from_vec(Ix2::new([rows, cols]), bits).unwrap()
    }

    // ----- index_select -----

    #[test]
    fn index_select_rows() {
        let source = int_grid(4, 3, (0..12).collect());
        let picked = source.index_select(Axis(0), &[0, 2, 3]).unwrap();
        assert_eq!(picked.shape(), &[3, 3]);
        let flat: Vec<i32> = picked.iter().copied().collect();
        assert_eq!(flat, vec![0, 1, 2, 6, 7, 8, 9, 10, 11]);
    }

    #[test]
    fn index_select_columns() {
        let source = int_grid(3, 4, (0..12).collect());
        let picked = source.index_select(Axis(1), &[0, 2]).unwrap();
        assert_eq!(picked.shape(), &[3, 2]);
        let flat: Vec<i32> = picked.iter().copied().collect();
        assert_eq!(flat, vec![0, 2, 4, 6, 8, 10]);
    }

    #[test]
    fn index_select_negative() {
        let source = ints(vec![10, 20, 30, 40, 50]);
        let picked = source.index_select(Axis(0), &[-1, -3]).unwrap();
        assert_eq!(picked.shape(), &[2]);
        let flat: Vec<i32> = picked.iter().copied().collect();
        assert_eq!(flat, vec![50, 30]);
    }

    #[test]
    fn index_select_out_of_bounds() {
        let source = ints(vec![1, 2, 3]);
        // One past the end, and one before the start.
        assert!(source.index_select(Axis(0), &[3]).is_err());
        assert!(source.index_select(Axis(0), &[-4]).is_err());
    }

    #[test]
    fn index_select_returns_copy() {
        let source: Array<f64, Ix1> =
            Array::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let picked = source.index_select(Axis(0), &[0, 1]).unwrap();
        // Distinct base pointers: the selection owns its own storage.
        assert_ne!(picked.as_ptr() as usize, source.as_ptr() as usize);
    }

    #[test]
    fn index_select_duplicate_indices() {
        let source = ints(vec![10, 20, 30]);
        let picked = source.index_select(Axis(0), &[1, 1, 0, 2, 2]).unwrap();
        assert_eq!(picked.shape(), &[5]);
        let flat: Vec<i32> = picked.iter().copied().collect();
        assert_eq!(flat, vec![20, 20, 10, 30, 30]);
    }

    #[test]
    fn index_select_empty() {
        let source = ints(vec![1, 2, 3]);
        let picked = source.index_select(Axis(0), &[]).unwrap();
        assert_eq!(picked.shape(), &[0]);
    }

    // ----- boolean_index -----

    #[test]
    fn boolean_index_1d() {
        let source = ints(vec![10, 20, 30, 40, 50]);
        let keep = flags(vec![true, false, true, false, true]);
        let picked = source.boolean_index(&keep).unwrap();
        assert_eq!(picked.shape(), &[3]);
        assert_eq!(picked.as_slice().unwrap(), &[10, 30, 50]);
    }

    #[test]
    fn boolean_index_2d() {
        let source = int_grid(2, 3, vec![1, 2, 3, 4, 5, 6]);
        let keep = flag_grid(2, 3, vec![true, false, true, false, true, false]);
        let picked = source.boolean_index(&keep).unwrap();
        assert_eq!(picked.shape(), &[3]);
        assert_eq!(picked.as_slice().unwrap(), &[1, 3, 5]);
    }

    #[test]
    fn boolean_index_all_false() {
        let source = ints(vec![1, 2, 3]);
        let keep = flags(vec![false, false, false]);
        let picked = source.boolean_index(&keep).unwrap();
        assert_eq!(picked.shape(), &[0]);
    }

    #[test]
    fn boolean_index_shape_mismatch() {
        let source = ints(vec![1, 2, 3]);
        let short_mask = flags(vec![true, false]);
        assert!(source.boolean_index(&short_mask).is_err());
    }

    #[test]
    fn boolean_index_returns_copy() {
        let source: Array<f64, Ix1> =
            Array::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let keep = flags(vec![true, true, true]);
        let picked = source.boolean_index(&keep).unwrap();
        assert_ne!(picked.as_ptr() as usize, source.as_ptr() as usize);
    }

    // ----- boolean_index_flat -----

    #[test]
    fn boolean_index_flat_2d() {
        let source = int_grid(2, 3, vec![1, 2, 3, 4, 5, 6]);
        let keep = flags(vec![false, true, false, true, false, true]);
        let picked = source.boolean_index_flat(&keep).unwrap();
        assert_eq!(picked.as_slice().unwrap(), &[2, 4, 6]);
    }

    #[test]
    fn boolean_index_flat_wrong_size() {
        let source = int_grid(2, 3, vec![1, 2, 3, 4, 5, 6]);
        let short_mask = flags(vec![true, false, true, false]);
        assert!(source.boolean_index_flat(&short_mask).is_err());
    }

    // ----- boolean assignment -----

    #[test]
    fn boolean_assign_scalar() {
        let mut target = ints(vec![1, 2, 3, 4, 5]);
        let hit = flags(vec![true, false, true, false, true]);
        target.boolean_index_assign(&hit, 0).unwrap();
        assert_eq!(target.as_slice().unwrap(), &[0, 2, 0, 4, 0]);
    }

    #[test]
    fn boolean_assign_array() {
        let mut target = ints(vec![1, 2, 3, 4, 5]);
        let hit = flags(vec![false, true, false, true, false]);
        let replacements = ints(vec![99, 88]);
        target.boolean_index_assign_array(&hit, &replacements).unwrap();
        assert_eq!(target.as_slice().unwrap(), &[1, 99, 3, 88, 5]);
    }

    #[test]
    fn boolean_assign_array_wrong_count() {
        let mut target = ints(vec![1, 2, 3]);
        let hit = flags(vec![true, true, false]);
        // Two slots selected, only one replacement supplied.
        let replacements = ints(vec![99]);
        assert!(target.boolean_index_assign_array(&hit, &replacements).is_err());
    }

    #[test]
    fn boolean_assign_2d() {
        let mut target = int_grid(2, 3, vec![1, 2, 3, 4, 5, 6]);
        let hit = flag_grid(2, 3, vec![false, true, false, false, true, false]);
        target.boolean_index_assign(&hit, -1).unwrap();
        let flat: Vec<i32> = target.iter().copied().collect();
        assert_eq!(flat, vec![1, -1, 3, 4, -1, 6]);
    }

    // ----- ArrayView -----

    #[test]
    fn view_index_select() {
        let source = int_grid(3, 4, (0..12).collect());
        let window = source.view();
        let picked = window.index_select(Axis(1), &[0, 3]).unwrap();
        assert_eq!(picked.shape(), &[3, 2]);
        let flat: Vec<i32> = picked.iter().copied().collect();
        assert_eq!(flat, vec![0, 3, 4, 7, 8, 11]);
    }

    #[test]
    fn view_boolean_index() {
        let source = ints(vec![10, 20, 30, 40]);
        let window = source.view();
        let keep = flags(vec![true, false, false, true]);
        let picked = window.boolean_index(&keep).unwrap();
        assert_eq!(picked.as_slice().unwrap(), &[10, 40]);
    }
}