Skip to main content

ferray_core/indexing/
basic.rs

1// ferray-core: Basic indexing (REQ-12, REQ-14)
2//
3// Integer + slice indexing returning views (zero-copy).
4// insert_axis / remove_axis for dimension manipulation.
5//
6// These are implemented as methods on Array, ArrayView, ArrayViewMut.
7// The s![] macro is out of scope (Agent 1d).
8
9use crate::array::owned::Array;
10use crate::array::view::ArrayView;
11use crate::array::view_mut::ArrayViewMut;
12use crate::dimension::{Axis, Dimension, IxDyn};
13use crate::dtype::Element;
14use crate::error::{FerrayError, FerrayResult};
15
16/// Normalize a potentially negative index to a positive one.
17///
18/// Negative indices count from the end: -1 is the last element, -2 is
19/// second-to-last, etc. Returns `Err` if the normalized index is out of
20/// bounds.
21fn normalize_index(index: isize, size: usize, axis: usize) -> FerrayResult<usize> {
22    let normalized = if index < 0 {
23        let pos = size as isize + index;
24        if pos < 0 {
25            return Err(FerrayError::index_out_of_bounds(index, axis, size));
26        }
27        pos as usize
28    } else {
29        let idx = index as usize;
30        if idx >= size {
31            return Err(FerrayError::index_out_of_bounds(index, axis, size));
32        }
33        idx
34    };
35    Ok(normalized)
36}
37
/// A slice specification for one axis, modelled on Python's `start:stop:step`.
///
/// Every field is optional; `None` selects the default:
/// - `start`: inclusive lower bound, defaults to 0. Negative values count
///   from the end of the axis.
/// - `stop`: exclusive upper bound, defaults to the axis length. Negative
///   values count from the end of the axis.
/// - `step`: defaults to 1 and must not be 0.
///
/// A negative step follows ndarray rather than NumPy. The bounds always
/// describe the forward range `[start, stop)`, which is then walked from its
/// last element backwards. For example, `1:4:-1` over `[0, 1, 2, 3, 4]`
/// selects `[3, 2, 1]`, whereas NumPy would return an empty result. With both
/// bounds open (`::-1`) the whole axis is reversed, as in NumPy.
///
/// `SliceSpec::default()` is the full-range slice, the same as `SliceSpec::full()`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct SliceSpec {
    /// Start index (inclusive). Negative counts from end.
    pub start: Option<isize>,
    /// Stop index (exclusive). Negative counts from end.
    pub stop: Option<isize>,
    /// Step size. Must not be zero.
    pub step: Option<isize>,
}
53
54impl SliceSpec {
55    /// Create a new full-range slice (equivalent to `:`).
56    pub fn full() -> Self {
57        Self {
58            start: None,
59            stop: None,
60            step: None,
61        }
62    }
63
64    /// Create a slice `start:stop` with step 1.
65    pub fn new(start: isize, stop: isize) -> Self {
66        Self {
67            start: Some(start),
68            stop: Some(stop),
69            step: None,
70        }
71    }
72
73    /// Create a slice `start:stop:step`.
74    pub fn with_step(start: isize, stop: isize, step: isize) -> Self {
75        Self {
76            start: Some(start),
77            stop: Some(stop),
78            step: Some(step),
79        }
80    }
81
82    /// Validate that the step is not zero.
83    fn validate(&self) -> FerrayResult<()> {
84        if let Some(0) = self.step {
85            return Err(FerrayError::invalid_value("slice step cannot be zero"));
86        }
87        Ok(())
88    }
89
90    /// Convert to an ndarray Slice.
91    #[allow(clippy::wrong_self_convention)]
92    fn to_ndarray_slice(&self) -> ndarray::Slice {
93        ndarray::Slice::new(self.start.unwrap_or(0), self.stop, self.step.unwrap_or(1))
94    }
95
96    /// Convert to an ndarray SliceInfoElem (used by s![] macro integration).
97    #[allow(dead_code, clippy::wrong_self_convention)]
98    pub(crate) fn to_ndarray_elem(&self) -> ndarray::SliceInfoElem {
99        ndarray::SliceInfoElem::Slice {
100            start: self.start.unwrap_or(0),
101            end: self.stop,
102            step: self.step.unwrap_or(1),
103        }
104    }
105}
106
107// ---------------------------------------------------------------------------
108// Array methods — basic indexing
109// ---------------------------------------------------------------------------
110
111impl<T: Element, D: Dimension> Array<T, D> {
112    /// Index into the array along a given axis, removing that axis.
113    ///
114    /// Equivalent to `a[i]` for axis 0, or `a[:, i]` for axis 1, etc.
115    /// Returns a view with one fewer dimension (dynamic-rank).
116    ///
117    /// # Errors
118    /// - `AxisOutOfBounds` if `axis >= ndim`
119    /// - `IndexOutOfBounds` if `index` is out of range (supports negative)
120    pub fn index_axis(&self, axis: Axis, index: isize) -> FerrayResult<ArrayView<'_, T, IxDyn>>
121    where
122        D::NdarrayDim: ndarray::RemoveAxis,
123    {
124        let ndim = self.ndim();
125        let ax = axis.index();
126        if ax >= ndim {
127            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
128        }
129        let size = self.shape()[ax];
130        let idx = normalize_index(index, size, ax)?;
131
132        let nd_axis = ndarray::Axis(ax);
133        let sub = self.inner.index_axis(nd_axis, idx);
134        let dyn_view = sub.into_dyn();
135        Ok(ArrayView::from_ndarray(dyn_view))
136    }
137
138    /// Slice the array along a given axis, returning a view.
139    ///
140    /// The returned view shares data with the source (zero-copy).
141    ///
142    /// # Errors
143    /// - `AxisOutOfBounds` if `axis >= ndim`
144    /// - `InvalidValue` if step is zero
145    pub fn slice_axis(&self, axis: Axis, spec: SliceSpec) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
146        let ndim = self.ndim();
147        let ax = axis.index();
148        if ax >= ndim {
149            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
150        }
151        spec.validate()?;
152
153        let nd_axis = ndarray::Axis(ax);
154        let nd_slice = spec.to_ndarray_slice();
155        let sliced = self.inner.slice_axis(nd_axis, nd_slice);
156        let dyn_view = sliced.into_dyn();
157        Ok(ArrayView::from_ndarray(dyn_view))
158    }
159
160    /// Slice the array along a given axis, returning a mutable view.
161    ///
162    /// # Errors
163    /// Same as [`slice_axis`](Self::slice_axis).
164    pub fn slice_axis_mut(
165        &mut self,
166        axis: Axis,
167        spec: SliceSpec,
168    ) -> FerrayResult<ArrayViewMut<'_, T, IxDyn>> {
169        let ndim = self.ndim();
170        let ax = axis.index();
171        if ax >= ndim {
172            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
173        }
174        spec.validate()?;
175
176        let nd_axis = ndarray::Axis(ax);
177        let nd_slice = spec.to_ndarray_slice();
178        let sliced = self.inner.slice_axis_mut(nd_axis, nd_slice);
179        let dyn_view = sliced.into_dyn();
180        Ok(ArrayViewMut::from_ndarray(dyn_view))
181    }
182
183    /// Multi-axis slicing: apply a slice specification to each axis.
184    ///
185    /// `specs` must have length equal to `ndim()`. For axes you don't
186    /// want to slice, pass `SliceSpec::full()`.
187    ///
188    /// # Errors
189    /// - `InvalidValue` if `specs.len() != ndim()`
190    /// - Any errors from individual axis slicing
191    pub fn slice_multi(&self, specs: &[SliceSpec]) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
192        let ndim = self.ndim();
193        if specs.len() != ndim {
194            return Err(FerrayError::invalid_value(format!(
195                "expected {} slice specs, got {}",
196                ndim,
197                specs.len()
198            )));
199        }
200
201        for spec in specs {
202            spec.validate()?;
203        }
204
205        // Apply axis-by-axis slicing using move variants to preserve lifetimes
206        let mut result = self.inner.view().into_dyn();
207        for (ax, spec) in specs.iter().enumerate() {
208            let nd_axis = ndarray::Axis(ax);
209            let nd_slice = spec.to_ndarray_slice();
210            result = result.slice_axis_move(nd_axis, nd_slice).into_dyn();
211        }
212        Ok(ArrayView::from_ndarray(result))
213    }
214
215    /// Insert a new axis of length 1 at the given position.
216    ///
217    /// This is equivalent to `np.expand_dims` or `np.newaxis`.
218    /// Returns a dynamic-rank view with one more dimension.
219    ///
220    /// # Errors
221    /// - `AxisOutOfBounds` if `axis > ndim`
222    pub fn insert_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
223        let ndim = self.ndim();
224        let ax = axis.index();
225        if ax > ndim {
226            return Err(FerrayError::axis_out_of_bounds(ax, ndim + 1));
227        }
228
229        let dyn_view = self.inner.view().into_dyn();
230        let expanded = dyn_view.insert_axis(ndarray::Axis(ax));
231        Ok(ArrayView::from_ndarray(expanded))
232    }
233
234    /// Remove an axis of length 1.
235    ///
236    /// This is equivalent to `np.squeeze` for a single axis.
237    /// Returns a dynamic-rank view with one fewer dimension.
238    ///
239    /// # Errors
240    /// - `AxisOutOfBounds` if `axis >= ndim`
241    /// - `InvalidValue` if the axis has size != 1
242    pub fn remove_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
243        let ndim = self.ndim();
244        let ax = axis.index();
245        if ax >= ndim {
246            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
247        }
248        if self.shape()[ax] != 1 {
249            return Err(FerrayError::invalid_value(format!(
250                "cannot remove axis {} with size {} (must be 1)",
251                ax,
252                self.shape()[ax]
253            )));
254        }
255
256        // index_axis_move at 0 removes the axis (consumes the view, preserving lifetime)
257        let dyn_view = self.inner.view().into_dyn();
258        let squeezed = dyn_view.index_axis_move(ndarray::Axis(ax), 0);
259        Ok(ArrayView::from_ndarray(squeezed))
260    }
261
262    /// Index into the array with a flat (linear) index.
263    ///
264    /// Elements are ordered in row-major (C) order.
265    ///
266    /// # Errors
267    /// Returns `IndexOutOfBounds` if the index is out of range.
268    pub fn flat_index(&self, index: isize) -> FerrayResult<&T> {
269        let size = self.size();
270        let idx = normalize_index(index, size, 0)?;
271        self.inner
272            .iter()
273            .nth(idx)
274            .ok_or_else(|| FerrayError::index_out_of_bounds(index, 0, size))
275    }
276
277    /// Get a reference to a single element by multi-dimensional index.
278    ///
279    /// Supports negative indices (counting from end).
280    ///
281    /// # Errors
282    /// - `InvalidValue` if `indices.len() != ndim()`
283    /// - `IndexOutOfBounds` if any index is out of range
284    pub fn get(&self, indices: &[isize]) -> FerrayResult<&T> {
285        let ndim = self.ndim();
286        if indices.len() != ndim {
287            return Err(FerrayError::invalid_value(format!(
288                "expected {} indices, got {}",
289                ndim,
290                indices.len()
291            )));
292        }
293
294        // Compute the flat offset manually
295        let shape = self.shape();
296        let strides = self.inner.strides();
297        let base_ptr = self.inner.as_ptr();
298
299        let mut offset: isize = 0;
300        for (ax, &idx) in indices.iter().enumerate() {
301            let pos = normalize_index(idx, shape[ax], ax)?;
302            offset += pos as isize * strides[ax];
303        }
304
305        // SAFETY: all indices are validated in-bounds, so the computed
306        // offset is within the array's data allocation.
307        Ok(unsafe { &*base_ptr.offset(offset) })
308    }
309
310    /// Get a mutable reference to a single element by multi-dimensional index.
311    ///
312    /// # Errors
313    /// Same as [`get`](Self::get).
314    pub fn get_mut(&mut self, indices: &[isize]) -> FerrayResult<&mut T> {
315        let ndim = self.ndim();
316        if indices.len() != ndim {
317            return Err(FerrayError::invalid_value(format!(
318                "expected {} indices, got {}",
319                ndim,
320                indices.len()
321            )));
322        }
323
324        let shape = self.shape().to_vec();
325        let strides: Vec<isize> = self.inner.strides().to_vec();
326        let base_ptr = self.inner.as_mut_ptr();
327
328        let mut offset: isize = 0;
329        for (ax, &idx) in indices.iter().enumerate() {
330            let pos = normalize_index(idx, shape[ax], ax)?;
331            offset += pos as isize * strides[ax];
332        }
333
334        // SAFETY: we have &mut self so exclusive access is guaranteed,
335        // and all indices are validated in-bounds.
336        Ok(unsafe { &mut *base_ptr.offset(offset) })
337    }
338}
339
340// ---------------------------------------------------------------------------
341// ArrayView methods — basic indexing
342// ---------------------------------------------------------------------------
343
344impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
345    /// Index into the view along a given axis, removing that axis.
346    pub fn index_axis(&self, axis: Axis, index: isize) -> FerrayResult<ArrayView<'a, T, IxDyn>>
347    where
348        D::NdarrayDim: ndarray::RemoveAxis,
349    {
350        let ndim = self.ndim();
351        let ax = axis.index();
352        if ax >= ndim {
353            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
354        }
355        let size = self.shape()[ax];
356        let idx = normalize_index(index, size, ax)?;
357
358        let nd_axis = ndarray::Axis(ax);
359        // clone() on ArrayView is cheap (it's Copy-like)
360        let sub = self.inner.clone().index_axis_move(nd_axis, idx);
361        let dyn_view = sub.into_dyn();
362        Ok(ArrayView::from_ndarray(dyn_view))
363    }
364
365    /// Slice the view along a given axis.
366    pub fn slice_axis(&self, axis: Axis, spec: SliceSpec) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
367        let ndim = self.ndim();
368        let ax = axis.index();
369        if ax >= ndim {
370            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
371        }
372        spec.validate()?;
373
374        let nd_axis = ndarray::Axis(ax);
375        let nd_slice = spec.to_ndarray_slice();
376        // slice_axis on a cloned view preserves the 'a lifetime
377        let sliced = self.inner.clone().slice_axis_move(nd_axis, nd_slice);
378        let dyn_view = sliced.into_dyn();
379        Ok(ArrayView::from_ndarray(dyn_view))
380    }
381
382    /// Insert a new axis of length 1 at the given position.
383    pub fn insert_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
384        let ndim = self.ndim();
385        let ax = axis.index();
386        if ax > ndim {
387            return Err(FerrayError::axis_out_of_bounds(ax, ndim + 1));
388        }
389
390        let dyn_view = self.inner.clone().into_dyn();
391        let expanded = dyn_view.insert_axis(ndarray::Axis(ax));
392        Ok(ArrayView::from_ndarray(expanded))
393    }
394
395    /// Remove an axis of length 1.
396    pub fn remove_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
397        let ndim = self.ndim();
398        let ax = axis.index();
399        if ax >= ndim {
400            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
401        }
402        if self.shape()[ax] != 1 {
403            return Err(FerrayError::invalid_value(format!(
404                "cannot remove axis {} with size {} (must be 1)",
405                ax,
406                self.shape()[ax]
407            )));
408        }
409
410        let dyn_view = self.inner.clone().into_dyn();
411        let squeezed = dyn_view.index_axis_move(ndarray::Axis(ax), 0);
412        Ok(ArrayView::from_ndarray(squeezed))
413    }
414
415    /// Get a reference to a single element by multi-dimensional index.
416    pub fn get(&self, indices: &[isize]) -> FerrayResult<&'a T> {
417        let ndim = self.ndim();
418        if indices.len() != ndim {
419            return Err(FerrayError::invalid_value(format!(
420                "expected {} indices, got {}",
421                ndim,
422                indices.len()
423            )));
424        }
425
426        let shape = self.shape();
427        let strides = self.inner.strides();
428        let base_ptr = self.inner.as_ptr();
429
430        let mut offset: isize = 0;
431        for (ax, &idx) in indices.iter().enumerate() {
432            let pos = normalize_index(idx, shape[ax], ax)?;
433            offset += pos as isize * strides[ax];
434        }
435
436        // SAFETY: indices validated in-bounds; the pointer is valid for 'a.
437        Ok(unsafe { &*base_ptr.offset(offset) })
438    }
439}
440
441// ---------------------------------------------------------------------------
442// ArrayViewMut methods — basic indexing
443// ---------------------------------------------------------------------------
444
445impl<'a, T: Element, D: Dimension> ArrayViewMut<'a, T, D> {
446    /// Slice the mutable view along a given axis.
447    pub fn slice_axis_mut(
448        &mut self,
449        axis: Axis,
450        spec: SliceSpec,
451    ) -> FerrayResult<ArrayViewMut<'_, T, IxDyn>> {
452        let ndim = self.ndim();
453        let ax = axis.index();
454        if ax >= ndim {
455            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
456        }
457        spec.validate()?;
458
459        let nd_axis = ndarray::Axis(ax);
460        let nd_slice = spec.to_ndarray_slice();
461        let sliced = self.inner.slice_axis_mut(nd_axis, nd_slice);
462        let dyn_view = sliced.into_dyn();
463        Ok(ArrayViewMut::from_ndarray(dyn_view))
464    }
465}
466
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3};

    // -----------------------------------------------------------------------
    // Fixtures
    // -----------------------------------------------------------------------

    /// `rows x cols` `i32` array filled with `0, 1, 2, ...` in row-major order.
    fn iota_2d(rows: usize, cols: usize) -> Array<i32, Ix2> {
        let count = (rows * cols) as i32;
        Array::<i32, Ix2>::from_vec(Ix2::new([rows, cols]), (0..count).collect()).unwrap()
    }

    /// 1-D `i32` array holding exactly `values`.
    fn ints_1d(values: Vec<i32>) -> Array<i32, Ix1> {
        let len = values.len();
        Array::<i32, Ix1>::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// 1-D `f64` array holding exactly `values`.
    fn floats_1d(values: Vec<f64>) -> Array<f64, Ix1> {
        let len = values.len();
        Array::<f64, Ix1>::from_vec(Ix1::new([len]), values).unwrap()
    }

    /// Elements of an `i32` view, in iteration order.
    fn elements(view: &ArrayView<'_, i32, IxDyn>) -> Vec<i32> {
        view.iter().copied().collect()
    }

    // -----------------------------------------------------------------------
    // Normalization
    // -----------------------------------------------------------------------

    #[test]
    fn normalize_positive_in_bounds() {
        assert_eq!(normalize_index(2, 5, 0).unwrap(), 2);
    }

    #[test]
    fn normalize_negative() {
        let cases: [(isize, usize); 2] = [(-1, 4), (-5, 0)];
        for &(index, expected) in cases.iter() {
            assert_eq!(normalize_index(index, 5, 0).unwrap(), expected);
        }
    }

    #[test]
    fn normalize_out_of_bounds() {
        let bad: [isize; 2] = [5, -6];
        for &index in bad.iter() {
            assert!(normalize_index(index, 5, 0).is_err());
        }
    }

    // -----------------------------------------------------------------------
    // index_axis
    // -----------------------------------------------------------------------

    #[test]
    fn index_axis_row() {
        let arr = iota_2d(3, 4);
        let row = arr.index_axis(Axis(0), 1).unwrap();
        assert_eq!(row.shape(), &[4]);
        assert_eq!(elements(&row), vec![4, 5, 6, 7]);
    }

    #[test]
    fn index_axis_column() {
        let arr = iota_2d(3, 4);
        let col = arr.index_axis(Axis(1), 2).unwrap();
        assert_eq!(col.shape(), &[3]);
        assert_eq!(elements(&col), vec![2, 6, 10]);
    }

    #[test]
    fn index_axis_negative() {
        let arr = iota_2d(3, 4);
        let last_row = arr.index_axis(Axis(0), -1).unwrap();
        assert_eq!(elements(&last_row), vec![8, 9, 10, 11]);
    }

    #[test]
    fn index_axis_out_of_bounds() {
        let arr = iota_2d(3, 4);
        assert!(arr.index_axis(Axis(0), 3).is_err());
        assert!(arr.index_axis(Axis(2), 0).is_err());
    }

    #[test]
    fn index_axis_is_zero_copy() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let first_row = arr.index_axis(Axis(0), 0).unwrap();
        assert_eq!(first_row.as_ptr(), arr.as_ptr());
    }

    // -----------------------------------------------------------------------
    // slice_axis
    // -----------------------------------------------------------------------

    #[test]
    fn slice_axis_basic() {
        let arr = ints_1d(vec![10, 20, 30, 40, 50]);
        let sliced = arr.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        assert_eq!(sliced.shape(), &[3]);
        assert_eq!(elements(&sliced), vec![20, 30, 40]);
    }

    #[test]
    fn slice_axis_step() {
        let arr = ints_1d((0..6).collect());
        let every_other = SliceSpec::with_step(0, 6, 2);
        let sliced = arr.slice_axis(Axis(0), every_other).unwrap();
        assert_eq!(sliced.shape(), &[3]);
        assert_eq!(elements(&sliced), vec![0, 2, 4]);
    }

    #[test]
    fn slice_axis_negative_step() {
        let arr = ints_1d((0..5).collect());
        // Both bounds open plus step -1: ndarray walks the whole axis from
        // the end, i.e. reverses it.
        let reversed = SliceSpec {
            start: None,
            stop: None,
            step: Some(-1),
        };
        let sliced = arr.slice_axis(Axis(0), reversed).unwrap();
        assert_eq!(elements(&sliced), vec![4, 3, 2, 1, 0]);
    }

    #[test]
    fn slice_axis_negative_step_partial() {
        let arr = ints_1d((0..5).collect());
        // ndarray walks the forward range [1, 4) backwards: [3, 2, 1].
        let backwards = SliceSpec::with_step(1, 4, -1);
        let sliced = arr.slice_axis(Axis(0), backwards).unwrap();
        assert_eq!(elements(&sliced), vec![3, 2, 1]);
    }

    #[test]
    fn slice_axis_full() {
        let arr = ints_1d(vec![1, 2, 3]);
        let sliced = arr.slice_axis(Axis(0), SliceSpec::full()).unwrap();
        assert_eq!(elements(&sliced), vec![1, 2, 3]);
    }

    #[test]
    fn slice_axis_2d_rows() {
        let arr = iota_2d(4, 3);
        let sliced = arr.slice_axis(Axis(0), SliceSpec::new(1, 3)).unwrap();
        assert_eq!(sliced.shape(), &[2, 3]);
        assert_eq!(elements(&sliced), vec![3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn slice_axis_is_zero_copy() {
        let arr = floats_1d(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let sliced = arr.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        // SAFETY: the slice is non-empty, so its base pointer addresses a
        // live element of `arr`.
        let first = unsafe { *sliced.as_ptr() };
        assert_eq!(first, 2.0);
    }

    #[test]
    fn slice_axis_zero_step_error() {
        let arr = ints_1d(vec![1, 2, 3]);
        let zero_step = SliceSpec::with_step(0, 3, 0);
        assert!(arr.slice_axis(Axis(0), zero_step).is_err());
    }

    // -----------------------------------------------------------------------
    // slice_multi
    // -----------------------------------------------------------------------

    #[test]
    fn slice_multi_2d() {
        let arr = iota_2d(4, 5);
        let specs = [SliceSpec::new(1, 3), SliceSpec::new(0, 4)];
        let sliced = arr.slice_multi(&specs).unwrap();
        assert_eq!(sliced.shape(), &[2, 4]);
    }

    #[test]
    fn slice_multi_wrong_count() {
        let arr = iota_2d(2, 3);
        assert!(arr.slice_multi(&[SliceSpec::full()]).is_err());
    }

    // -----------------------------------------------------------------------
    // insert_axis / remove_axis
    // -----------------------------------------------------------------------

    #[test]
    fn insert_axis_at_front() {
        let arr = floats_1d(vec![1.0, 2.0, 3.0]);
        let expanded = arr.insert_axis(Axis(0)).unwrap();
        assert_eq!(expanded.shape(), &[1, 3]);
    }

    #[test]
    fn insert_axis_at_end() {
        let arr = floats_1d(vec![1.0, 2.0, 3.0]);
        let expanded = arr.insert_axis(Axis(1)).unwrap();
        assert_eq!(expanded.shape(), &[3, 1]);
    }

    #[test]
    fn insert_axis_out_of_bounds() {
        let arr = floats_1d(vec![1.0, 2.0, 3.0]);
        assert!(arr.insert_axis(Axis(3)).is_err());
    }

    #[test]
    fn remove_axis_single() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
        let squeezed = arr.remove_axis(Axis(0)).unwrap();
        assert_eq!(squeezed.shape(), &[3]);
        let values: Vec<f64> = squeezed.iter().copied().collect();
        assert_eq!(values, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn remove_axis_not_one() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        assert!(arr.remove_axis(Axis(0)).is_err());
    }

    #[test]
    fn remove_axis_out_of_bounds() {
        let arr = floats_1d(vec![1.0, 2.0, 3.0]);
        assert!(arr.remove_axis(Axis(1)).is_err());
    }

    // -----------------------------------------------------------------------
    // flat_index
    // -----------------------------------------------------------------------

    #[test]
    fn flat_index_positive() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let cases: [(isize, i32); 3] = [(0, 1), (3, 4), (5, 6)];
        for &(flat, expected) in cases.iter() {
            assert_eq!(*arr.flat_index(flat).unwrap(), expected);
        }
    }

    #[test]
    fn flat_index_negative() {
        let arr = ints_1d(vec![10, 20, 30, 40, 50]);
        let cases: [(isize, i32); 2] = [(-1, 50), (-5, 10)];
        for &(flat, expected) in cases.iter() {
            assert_eq!(*arr.flat_index(flat).unwrap(), expected);
        }
    }

    #[test]
    fn flat_index_out_of_bounds() {
        let arr = ints_1d(vec![1, 2, 3]);
        let bad: [isize; 2] = [3, -4];
        for &flat in bad.iter() {
            assert!(arr.flat_index(flat).is_err());
        }
    }

    // -----------------------------------------------------------------------
    // get / get_mut
    // -----------------------------------------------------------------------

    #[test]
    fn get_2d() {
        let arr = iota_2d(3, 4);
        assert_eq!(*arr.get(&[0, 0]).unwrap(), 0);
        assert_eq!(*arr.get(&[1, 2]).unwrap(), 6);
        assert_eq!(*arr.get(&[2, 3]).unwrap(), 11);
    }

    #[test]
    fn get_negative_indices() {
        let arr = iota_2d(3, 4);
        assert_eq!(*arr.get(&[-1, -1]).unwrap(), 11);
        assert_eq!(*arr.get(&[-3, 0]).unwrap(), 0);
    }

    #[test]
    fn get_wrong_ndim() {
        let arr = iota_2d(2, 3);
        assert!(arr.get(&[0]).is_err());
        assert!(arr.get(&[0, 0, 0]).is_err());
    }

    #[test]
    fn get_mut_modify() {
        let mut arr =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        {
            let slot = arr.get_mut(&[1, 2]).unwrap();
            *slot = 99;
        }
        assert_eq!(*arr.get(&[1, 2]).unwrap(), 99);
    }

    // -----------------------------------------------------------------------
    // ArrayView basic indexing
    // -----------------------------------------------------------------------

    #[test]
    fn view_index_axis() {
        let arr = iota_2d(3, 4);
        let v = arr.view();
        let row = v.index_axis(Axis(0), 1).unwrap();
        assert_eq!(elements(&row), vec![4, 5, 6, 7]);
    }

    #[test]
    fn view_slice_axis() {
        let arr = ints_1d(vec![10, 20, 30, 40, 50]);
        let v = arr.view();
        let sliced = v.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        assert_eq!(elements(&sliced), vec![20, 30, 40]);
    }

    #[test]
    fn view_insert_remove_axis() {
        let arr = floats_1d(vec![1.0, 2.0, 3.0, 4.0]);
        let v = arr.view();
        let expanded = v.insert_axis(Axis(0)).unwrap();
        assert_eq!(expanded.shape(), &[1, 4]);
        let squeezed = expanded.remove_axis(Axis(0)).unwrap();
        assert_eq!(squeezed.shape(), &[4]);
    }

    #[test]
    fn view_get() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let v = arr.view();
        assert_eq!(*v.get(&[1, 2]).unwrap(), 6);
    }

    // -----------------------------------------------------------------------
    // ArrayViewMut slice
    // -----------------------------------------------------------------------

    #[test]
    fn view_mut_slice_axis() {
        let mut arr = ints_1d(vec![1, 2, 3, 4, 5]);
        {
            let mut vm = arr.view_mut();
            let mut sliced = vm.slice_axis_mut(Axis(0), SliceSpec::new(1, 3)).unwrap();
            if let Some(window) = sliced.as_slice_mut() {
                window[0] = 20;
                window[1] = 30;
            }
        }
        assert_eq!(arr.as_slice().unwrap(), &[1, 20, 30, 4, 5]);
    }

    // -----------------------------------------------------------------------
    // 3D indexing
    // -----------------------------------------------------------------------

    #[test]
    fn index_axis_3d() {
        let arr = Array::<i32, Ix3>::from_vec(Ix3::new([2, 3, 4]), (0..24).collect()).unwrap();
        let plane = arr.index_axis(Axis(0), 1).unwrap();
        assert_eq!(plane.shape(), &[3, 4]);
        assert_eq!(*plane.get(&[0, 0]).unwrap(), 12);
    }
}