Skip to main content

ferray_core/indexing/
basic.rs

1// ferray-core: Basic indexing (REQ-12, REQ-14)
2//
3// Integer + slice indexing returning views (zero-copy).
4// insert_axis / remove_axis for dimension manipulation.
5//
6// These are implemented as methods on Array, ArrayView, ArrayViewMut.
7// The s![] macro is out of scope (Agent 1d).
8
9use crate::array::owned::Array;
10use crate::array::view::ArrayView;
11use crate::array::view_mut::ArrayViewMut;
12use crate::dimension::{Axis, Dimension, IxDyn};
13use crate::dtype::Element;
14use crate::error::{FerrayError, FerrayResult};
15
16use super::normalize_index;
17
/// A slice specification for one axis, written like Python's `start:stop:step`.
///
/// Every field is optional; `None` selects the default:
/// - `start`: defaults to 0
/// - `stop`: defaults to the end of the axis
/// - `step`: defaults to 1; must not be 0
///
/// The spec is converted to an `ndarray::Slice`, so a negative `step` does
/// **not** follow Python/`NumPy` semantics: `start..stop` still names a forward
/// range and the negative step walks that range backwards. For example,
/// `SliceSpec::with_step(1, 4, -1)` applied to `[0, 1, 2, 3, 4]` yields
/// `[3, 2, 1]`, whereas `NumPy`'s `a[1:4:-1]` is empty.
#[derive(Debug, Clone, Copy)]
pub struct SliceSpec {
    /// Start index (inclusive). Negative counts from end. `None` means 0.
    pub start: Option<isize>,
    /// Stop index (exclusive). Negative counts from end. `None` means the
    /// end of the axis.
    pub stop: Option<isize>,
    /// Step size. Must not be zero. `None` means 1.
    pub step: Option<isize>,
}
33
34impl SliceSpec {
35    /// Create a new full-range slice (equivalent to `:`).
36    #[must_use]
37    pub const fn full() -> Self {
38        Self {
39            start: None,
40            stop: None,
41            step: None,
42        }
43    }
44
45    /// Create a slice `start:stop` with step 1.
46    #[must_use]
47    pub const fn new(start: isize, stop: isize) -> Self {
48        Self {
49            start: Some(start),
50            stop: Some(stop),
51            step: None,
52        }
53    }
54
55    /// Create a slice `start:stop:step`.
56    #[must_use]
57    pub const fn with_step(start: isize, stop: isize, step: isize) -> Self {
58        Self {
59            start: Some(start),
60            stop: Some(stop),
61            step: Some(step),
62        }
63    }
64
65    /// Validate that the step is not zero.
66    fn validate(&self) -> FerrayResult<()> {
67        if self.step == Some(0) {
68            return Err(FerrayError::invalid_value("slice step cannot be zero"));
69        }
70        Ok(())
71    }
72
73    /// Convert to an ndarray Slice.
74    #[allow(clippy::wrong_self_convention)]
75    fn to_ndarray_slice(&self) -> ndarray::Slice {
76        ndarray::Slice::new(self.start.unwrap_or(0), self.stop, self.step.unwrap_or(1))
77    }
78
79    /// Convert to an ndarray `SliceInfoElem` (used by s![] macro integration).
80    #[allow(dead_code, clippy::wrong_self_convention)]
81    pub(crate) fn to_ndarray_elem(&self) -> ndarray::SliceInfoElem {
82        ndarray::SliceInfoElem::Slice {
83            start: self.start.unwrap_or(0),
84            end: self.stop,
85            step: self.step.unwrap_or(1),
86        }
87    }
88}
89
90// ---------------------------------------------------------------------------
91// Array methods — basic indexing
92// ---------------------------------------------------------------------------
93
94impl<T: Element, D: Dimension> Array<T, D> {
95    /// Index into the array along a given axis, removing that axis.
96    ///
97    /// Equivalent to `a[i]` for axis 0, or `a[:, i]` for axis 1, etc.
98    /// Returns a view with one fewer dimension (dynamic-rank).
99    ///
100    /// # Errors
101    /// - `AxisOutOfBounds` if `axis >= ndim`
102    /// - `IndexOutOfBounds` if `index` is out of range (supports negative)
103    pub fn index_axis(&self, axis: Axis, index: isize) -> FerrayResult<ArrayView<'_, T, D::Smaller>>
104    where
105        D::NdarrayDim:
106            ndarray::RemoveAxis<Smaller = <D::Smaller as crate::dimension::Dimension>::NdarrayDim>,
107    {
108        // #349: return type now preserves rank — Array<T, Ix2>::index_axis
109        // produces ArrayView<T, Ix1> instead of the old IxDyn fallback.
110        let ndim = self.ndim();
111        let ax = axis.index();
112        if ax >= ndim {
113            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
114        }
115        let size = self.shape()[ax];
116        let idx = normalize_index(index, size, ax)?;
117
118        let nd_axis = ndarray::Axis(ax);
119        let sub = self.inner.index_axis(nd_axis, idx);
120        Ok(ArrayView::from_ndarray(sub))
121    }
122
123    /// Slice the array along a given axis, returning a view.
124    ///
125    /// The returned view shares data with the source (zero-copy).
126    ///
127    /// # Errors
128    /// - `AxisOutOfBounds` if `axis >= ndim`
129    /// - `InvalidValue` if step is zero
130    pub fn slice_axis(&self, axis: Axis, spec: SliceSpec) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
131        let ndim = self.ndim();
132        let ax = axis.index();
133        if ax >= ndim {
134            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
135        }
136        spec.validate()?;
137
138        let nd_axis = ndarray::Axis(ax);
139        let nd_slice = spec.to_ndarray_slice();
140        let sliced = self.inner.slice_axis(nd_axis, nd_slice);
141        let dyn_view = sliced.into_dyn();
142        Ok(ArrayView::from_ndarray(dyn_view))
143    }
144
145    /// Slice the array along a given axis, returning a mutable view.
146    ///
147    /// # Errors
148    /// Same as [`slice_axis`](Self::slice_axis).
149    pub fn slice_axis_mut(
150        &mut self,
151        axis: Axis,
152        spec: SliceSpec,
153    ) -> FerrayResult<ArrayViewMut<'_, T, IxDyn>> {
154        let ndim = self.ndim();
155        let ax = axis.index();
156        if ax >= ndim {
157            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
158        }
159        spec.validate()?;
160
161        let nd_axis = ndarray::Axis(ax);
162        let nd_slice = spec.to_ndarray_slice();
163        let sliced = self.inner.slice_axis_mut(nd_axis, nd_slice);
164        let dyn_view = sliced.into_dyn();
165        Ok(ArrayViewMut::from_ndarray(dyn_view))
166    }
167
168    /// Multi-axis slicing: apply a slice specification to each axis.
169    ///
170    /// `specs` must have length equal to `ndim()`. For axes you don't
171    /// want to slice, pass `SliceSpec::full()`.
172    ///
173    /// # Errors
174    /// - `InvalidValue` if `specs.len() != ndim()`
175    /// - Any errors from individual axis slicing
176    pub fn slice_multi(&self, specs: &[SliceSpec]) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
177        let ndim = self.ndim();
178        if specs.len() != ndim {
179            return Err(FerrayError::invalid_value(format!(
180                "expected {} slice specs, got {}",
181                ndim,
182                specs.len()
183            )));
184        }
185
186        for spec in specs {
187            spec.validate()?;
188        }
189
190        // Apply axis-by-axis slicing using move variants to preserve lifetimes
191        let mut result = self.inner.view().into_dyn();
192        for (ax, spec) in specs.iter().enumerate() {
193            let nd_axis = ndarray::Axis(ax);
194            let nd_slice = spec.to_ndarray_slice();
195            result = result.slice_axis_move(nd_axis, nd_slice).into_dyn();
196        }
197        Ok(ArrayView::from_ndarray(result))
198    }
199
200    /// Insert a new axis of length 1 at the given position.
201    ///
202    /// This is equivalent to `np.expand_dims` or `np.newaxis`.
203    /// Returns a view with rank one greater than `Self` (#349) —
204    /// `Array<T, Ix2>::insert_axis` produces `ArrayView<T, Ix3>`.
205    ///
206    /// # Errors
207    /// - `AxisOutOfBounds` if `axis > ndim`
208    pub fn insert_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'_, T, D::Larger>>
209    where
210        D::NdarrayDim:
211            ndarray::Dimension<Larger = <D::Larger as crate::dimension::Dimension>::NdarrayDim>,
212    {
213        let ndim = self.ndim();
214        let ax = axis.index();
215        if ax > ndim {
216            return Err(FerrayError::axis_out_of_bounds(ax, ndim + 1));
217        }
218        let expanded = self.inner.view().insert_axis(ndarray::Axis(ax));
219        Ok(ArrayView::from_ndarray(expanded))
220    }
221
222    /// Remove an axis of length 1.
223    ///
224    /// This is equivalent to `np.squeeze` for a single axis.
225    /// Returns a view with rank one less than `Self` (#349).
226    ///
227    /// # Errors
228    /// - `AxisOutOfBounds` if `axis >= ndim`
229    /// - `InvalidValue` if the axis has size != 1
230    pub fn remove_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'_, T, D::Smaller>>
231    where
232        D::NdarrayDim:
233            ndarray::RemoveAxis<Smaller = <D::Smaller as crate::dimension::Dimension>::NdarrayDim>,
234    {
235        let ndim = self.ndim();
236        let ax = axis.index();
237        if ax >= ndim {
238            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
239        }
240        if self.shape()[ax] != 1 {
241            return Err(FerrayError::invalid_value(format!(
242                "cannot remove axis {} with size {} (must be 1)",
243                ax,
244                self.shape()[ax]
245            )));
246        }
247        let squeezed = self.inner.view().index_axis_move(ndarray::Axis(ax), 0);
248        Ok(ArrayView::from_ndarray(squeezed))
249    }
250
251    /// Index into the array with a flat (linear) index.
252    ///
253    /// Elements are ordered in row-major (C) order. The flat index is
254    /// converted to a multi-dimensional index via unravel, then stride
255    /// arithmetic computes the physical offset in O(ndim) time.
256    ///
257    /// # Errors
258    /// Returns `IndexOutOfBounds` if the index is out of range.
259    pub fn flat_index(&self, index: isize) -> FerrayResult<&T> {
260        let size = self.size();
261        let idx = normalize_index(index, size, 0)?;
262
263        // Unravel flat index to multi-dim index in row-major order,
264        // then compute the physical offset via strides.
265        let shape = self.shape();
266        let strides = self.inner.strides();
267        let base_ptr = self.inner.as_ptr();
268        let ndim = shape.len();
269
270        let mut remaining = idx;
271        let mut offset: isize = 0;
272        for d in 0..ndim {
273            let dim_stride: usize = shape[d + 1..].iter().product::<usize>().max(1);
274            let coord = remaining / dim_stride;
275            remaining %= dim_stride;
276            offset += coord as isize * strides[d];
277        }
278
279        // SAFETY: idx is validated in-bounds via normalize_index, and the
280        // unravel produces coordinates within each axis, so the computed
281        // offset is within the array's data allocation.
282        Ok(unsafe { &*base_ptr.offset(offset) })
283    }
284
285    /// Get a reference to a single element by multi-dimensional index.
286    ///
287    /// Supports negative indices (counting from end).
288    ///
289    /// # Errors
290    /// - `InvalidValue` if `indices.len() != ndim()`
291    /// - `IndexOutOfBounds` if any index is out of range
292    pub fn get(&self, indices: &[isize]) -> FerrayResult<&T> {
293        let ndim = self.ndim();
294        if indices.len() != ndim {
295            return Err(FerrayError::invalid_value(format!(
296                "expected {} indices, got {}",
297                ndim,
298                indices.len()
299            )));
300        }
301
302        // Compute the flat offset manually
303        let shape = self.shape();
304        let strides = self.inner.strides();
305        let base_ptr = self.inner.as_ptr();
306
307        let mut offset: isize = 0;
308        for (ax, &idx) in indices.iter().enumerate() {
309            let pos = normalize_index(idx, shape[ax], ax)?;
310            offset += pos as isize * strides[ax];
311        }
312
313        // SAFETY: all indices are validated in-bounds, so the computed
314        // offset is within the array's data allocation.
315        Ok(unsafe { &*base_ptr.offset(offset) })
316    }
317
318    /// Get a mutable reference to a single element by multi-dimensional index.
319    ///
320    /// # Errors
321    /// Same as [`get`](Self::get).
322    pub fn get_mut(&mut self, indices: &[isize]) -> FerrayResult<&mut T> {
323        let ndim = self.ndim();
324        if indices.len() != ndim {
325            return Err(FerrayError::invalid_value(format!(
326                "expected {} indices, got {}",
327                ndim,
328                indices.len()
329            )));
330        }
331
332        let shape = self.shape().to_vec();
333        let strides: Vec<isize> = self.inner.strides().to_vec();
334        let base_ptr = self.inner.as_mut_ptr();
335
336        let mut offset: isize = 0;
337        for (ax, &idx) in indices.iter().enumerate() {
338            let pos = normalize_index(idx, shape[ax], ax)?;
339            offset += pos as isize * strides[ax];
340        }
341
342        // SAFETY: we have &mut self so exclusive access is guaranteed,
343        // and all indices are validated in-bounds.
344        Ok(unsafe { &mut *base_ptr.offset(offset) })
345    }
346}
347
348// ---------------------------------------------------------------------------
349// ArrayView methods — basic indexing
350// ---------------------------------------------------------------------------
351
352impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
353    /// Index into the view along a given axis, removing that axis.
354    /// Returns a view with rank one less than `Self` (#349).
355    pub fn index_axis(&self, axis: Axis, index: isize) -> FerrayResult<ArrayView<'a, T, D::Smaller>>
356    where
357        D::NdarrayDim:
358            ndarray::RemoveAxis<Smaller = <D::Smaller as crate::dimension::Dimension>::NdarrayDim>,
359    {
360        let ndim = self.ndim();
361        let ax = axis.index();
362        if ax >= ndim {
363            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
364        }
365        let size = self.shape()[ax];
366        let idx = normalize_index(index, size, ax)?;
367
368        let nd_axis = ndarray::Axis(ax);
369        // clone() on ArrayView is cheap (it's Copy-like)
370        let sub = self.inner.clone().index_axis_move(nd_axis, idx);
371        Ok(ArrayView::from_ndarray(sub))
372    }
373
374    /// Slice the view along a given axis.
375    pub fn slice_axis(&self, axis: Axis, spec: SliceSpec) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
376        let ndim = self.ndim();
377        let ax = axis.index();
378        if ax >= ndim {
379            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
380        }
381        spec.validate()?;
382
383        let nd_axis = ndarray::Axis(ax);
384        let nd_slice = spec.to_ndarray_slice();
385        // slice_axis on a cloned view preserves the 'a lifetime
386        let sliced = self.inner.clone().slice_axis_move(nd_axis, nd_slice);
387        let dyn_view = sliced.into_dyn();
388        Ok(ArrayView::from_ndarray(dyn_view))
389    }
390
391    /// Insert a new axis of length 1 at the given position.
392    /// Returns a view with rank one greater than `Self` (#349).
393    pub fn insert_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'a, T, D::Larger>>
394    where
395        D::NdarrayDim:
396            ndarray::Dimension<Larger = <D::Larger as crate::dimension::Dimension>::NdarrayDim>,
397    {
398        let ndim = self.ndim();
399        let ax = axis.index();
400        if ax > ndim {
401            return Err(FerrayError::axis_out_of_bounds(ax, ndim + 1));
402        }
403        let expanded = self.inner.clone().insert_axis(ndarray::Axis(ax));
404        Ok(ArrayView::from_ndarray(expanded))
405    }
406
407    /// Remove an axis of length 1.
408    /// Returns a view with rank one less than `Self` (#349).
409    pub fn remove_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'a, T, D::Smaller>>
410    where
411        D::NdarrayDim:
412            ndarray::RemoveAxis<Smaller = <D::Smaller as crate::dimension::Dimension>::NdarrayDim>,
413    {
414        let ndim = self.ndim();
415        let ax = axis.index();
416        if ax >= ndim {
417            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
418        }
419        if self.shape()[ax] != 1 {
420            return Err(FerrayError::invalid_value(format!(
421                "cannot remove axis {} with size {} (must be 1)",
422                ax,
423                self.shape()[ax]
424            )));
425        }
426
427        let squeezed = self.inner.clone().index_axis_move(ndarray::Axis(ax), 0);
428        Ok(ArrayView::from_ndarray(squeezed))
429    }
430
431    /// Get a reference to a single element by multi-dimensional index.
432    pub fn get(&self, indices: &[isize]) -> FerrayResult<&'a T> {
433        let ndim = self.ndim();
434        if indices.len() != ndim {
435            return Err(FerrayError::invalid_value(format!(
436                "expected {} indices, got {}",
437                ndim,
438                indices.len()
439            )));
440        }
441
442        let shape = self.shape();
443        let strides = self.inner.strides();
444        let base_ptr = self.inner.as_ptr();
445
446        let mut offset: isize = 0;
447        for (ax, &idx) in indices.iter().enumerate() {
448            let pos = normalize_index(idx, shape[ax], ax)?;
449            offset += pos as isize * strides[ax];
450        }
451
452        // SAFETY: indices validated in-bounds; the pointer is valid for 'a.
453        Ok(unsafe { &*base_ptr.offset(offset) })
454    }
455}
456
457// ---------------------------------------------------------------------------
458// ArrayViewMut methods — basic indexing
459// ---------------------------------------------------------------------------
460
461impl<T: Element, D: Dimension> ArrayViewMut<'_, T, D> {
462    /// Slice the mutable view along a given axis.
463    pub fn slice_axis_mut(
464        &mut self,
465        axis: Axis,
466        spec: SliceSpec,
467    ) -> FerrayResult<ArrayViewMut<'_, T, IxDyn>> {
468        let ndim = self.ndim();
469        let ax = axis.index();
470        if ax >= ndim {
471            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
472        }
473        spec.validate()?;
474
475        let nd_axis = ndarray::Axis(ax);
476        let nd_slice = spec.to_ndarray_slice();
477        let sliced = self.inner.slice_axis_mut(nd_axis, nd_slice);
478        let dyn_view = sliced.into_dyn();
479        Ok(ArrayViewMut::from_ndarray(dyn_view))
480    }
481}
482
#[cfg(test)]
mod tests {
    //! Unit tests for basic indexing: index normalization, single-axis and
    //! multi-axis slicing, axis insertion/removal, flat and multi-dimensional
    //! element access, and the static rank of returned views (#349).

    use super::*;
    use crate::dimension::{Ix0, Ix1, Ix2, Ix3};

    // -----------------------------------------------------------------------
    // Normalization
    // -----------------------------------------------------------------------

    #[test]
    fn normalize_positive_in_bounds() {
        assert_eq!(normalize_index(2, 5, 0).unwrap(), 2);
    }

    #[test]
    fn normalize_negative() {
        assert_eq!(normalize_index(-1, 5, 0).unwrap(), 4);
        assert_eq!(normalize_index(-5, 5, 0).unwrap(), 0);
    }

    #[test]
    fn normalize_out_of_bounds() {
        assert!(normalize_index(5, 5, 0).is_err());
        assert!(normalize_index(-6, 5, 0).is_err());
    }

    // -----------------------------------------------------------------------
    // index_axis
    // -----------------------------------------------------------------------

    #[test]
    fn index_axis_row() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let row = arr.index_axis(Axis(0), 1).unwrap();
        assert_eq!(row.shape(), &[4]);
        let data: Vec<i32> = row.iter().copied().collect();
        assert_eq!(data, vec![4, 5, 6, 7]);
    }

    #[test]
    fn index_axis_column() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let col = arr.index_axis(Axis(1), 2).unwrap();
        assert_eq!(col.shape(), &[3]);
        let data: Vec<i32> = col.iter().copied().collect();
        assert_eq!(data, vec![2, 6, 10]);
    }

    #[test]
    fn index_axis_negative() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let row = arr.index_axis(Axis(0), -1).unwrap();
        let data: Vec<i32> = row.iter().copied().collect();
        assert_eq!(data, vec![8, 9, 10, 11]);
    }

    #[test]
    fn index_axis_out_of_bounds() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        assert!(arr.index_axis(Axis(0), 3).is_err());
        assert!(arr.index_axis(Axis(2), 0).is_err());
    }

    #[test]
    fn index_axis_is_zero_copy() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let row = arr.index_axis(Axis(0), 0).unwrap();
        assert_eq!(row.as_ptr(), arr.as_ptr());
    }

    // -----------------------------------------------------------------------
    // slice_axis
    // -----------------------------------------------------------------------

    #[test]
    fn slice_axis_basic() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let sliced = arr.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        assert_eq!(sliced.shape(), &[3]);
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![20, 30, 40]);
    }

    #[test]
    fn slice_axis_step() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([6]), vec![0, 1, 2, 3, 4, 5]).unwrap();
        let sliced = arr
            .slice_axis(Axis(0), SliceSpec::with_step(0, 6, 2))
            .unwrap();
        assert_eq!(sliced.shape(), &[3]);
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![0, 2, 4]);
    }

    #[test]
    fn slice_axis_negative_step() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 2, 3, 4]).unwrap();
        // Reverse the entire array: range [0, len) traversed backwards
        // ndarray interprets Slice::new(start, end, step) where start/end define
        // a forward range and negative step reverses traversal within it.
        let spec = SliceSpec {
            start: None,
            stop: None,
            step: Some(-1),
        };
        let sliced = arr.slice_axis(Axis(0), spec).unwrap();
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![4, 3, 2, 1, 0]);
    }

    #[test]
    fn slice_axis_negative_step_partial() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 2, 3, 4]).unwrap();
        // Range [1, 4) traversed backwards with step -1: [3, 2, 1]
        let sliced = arr
            .slice_axis(Axis(0), SliceSpec::with_step(1, 4, -1))
            .unwrap();
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![3, 2, 1]);
    }

    #[test]
    fn slice_axis_full() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let sliced = arr.slice_axis(Axis(0), SliceSpec::full()).unwrap();
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 3]);
    }

    #[test]
    fn slice_axis_2d_rows() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([4, 3]), (0..12).collect()).unwrap();
        let sliced = arr.slice_axis(Axis(0), SliceSpec::new(1, 3)).unwrap();
        assert_eq!(sliced.shape(), &[2, 3]);
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn slice_axis_is_zero_copy() {
        let arr =
            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let sliced = arr.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        // Compare addresses, not values: a copy would hold the same values
        // but could not start at element 1 of the source buffer.
        assert_eq!(sliced.as_ptr(), arr.as_ptr().wrapping_add(1));
    }

    #[test]
    fn slice_axis_zero_step_error() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert!(
            arr.slice_axis(Axis(0), SliceSpec::with_step(0, 3, 0))
                .is_err()
        );
    }

    // -----------------------------------------------------------------------
    // slice_multi
    // -----------------------------------------------------------------------

    #[test]
    fn slice_multi_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([4, 5]), (0..20).collect()).unwrap();
        let sliced = arr
            .slice_multi(&[SliceSpec::new(1, 3), SliceSpec::new(0, 4)])
            .unwrap();
        assert_eq!(sliced.shape(), &[2, 4]);
    }

    #[test]
    fn slice_multi_wrong_count() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), (0..6).collect()).unwrap();
        assert!(arr.slice_multi(&[SliceSpec::full()]).is_err());
    }

    // -----------------------------------------------------------------------
    // insert_axis / remove_axis
    // -----------------------------------------------------------------------

    #[test]
    fn insert_axis_at_front() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let expanded = arr.insert_axis(Axis(0)).unwrap();
        assert_eq!(expanded.shape(), &[1, 3]);
    }

    #[test]
    fn insert_axis_at_end() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let expanded = arr.insert_axis(Axis(1)).unwrap();
        assert_eq!(expanded.shape(), &[3, 1]);
    }

    #[test]
    fn insert_axis_out_of_bounds() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(arr.insert_axis(Axis(3)).is_err());
    }

    #[test]
    fn remove_axis_single() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
        let squeezed = arr.remove_axis(Axis(0)).unwrap();
        assert_eq!(squeezed.shape(), &[3]);
        let data: Vec<f64> = squeezed.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn remove_axis_not_one() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        assert!(arr.remove_axis(Axis(0)).is_err());
    }

    #[test]
    fn remove_axis_out_of_bounds() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(arr.remove_axis(Axis(1)).is_err());
    }

    // -----------------------------------------------------------------------
    // flat_index
    // -----------------------------------------------------------------------

    #[test]
    fn flat_index_positive() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        assert_eq!(*arr.flat_index(0).unwrap(), 1);
        assert_eq!(*arr.flat_index(3).unwrap(), 4);
        assert_eq!(*arr.flat_index(5).unwrap(), 6);
    }

    #[test]
    fn flat_index_negative() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        assert_eq!(*arr.flat_index(-1).unwrap(), 50);
        assert_eq!(*arr.flat_index(-5).unwrap(), 10);
    }

    #[test]
    fn flat_index_out_of_bounds() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert!(arr.flat_index(3).is_err());
        assert!(arr.flat_index(-4).is_err());
    }

    // -----------------------------------------------------------------------
    // get / get_mut
    // -----------------------------------------------------------------------

    #[test]
    fn get_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        assert_eq!(*arr.get(&[0, 0]).unwrap(), 0);
        assert_eq!(*arr.get(&[1, 2]).unwrap(), 6);
        assert_eq!(*arr.get(&[2, 3]).unwrap(), 11);
    }

    #[test]
    fn get_negative_indices() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        assert_eq!(*arr.get(&[-1, -1]).unwrap(), 11);
        assert_eq!(*arr.get(&[-3, 0]).unwrap(), 0);
    }

    #[test]
    fn get_wrong_ndim() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), (0..6).collect()).unwrap();
        assert!(arr.get(&[0]).is_err());
        assert!(arr.get(&[0, 0, 0]).is_err());
    }

    #[test]
    fn get_mut_modify() {
        let mut arr =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        *arr.get_mut(&[1, 2]).unwrap() = 99;
        assert_eq!(*arr.get(&[1, 2]).unwrap(), 99);
    }

    // -----------------------------------------------------------------------
    // ArrayView basic indexing
    // -----------------------------------------------------------------------

    #[test]
    fn view_index_axis() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let v = arr.view();
        let row = v.index_axis(Axis(0), 1).unwrap();
        let data: Vec<i32> = row.iter().copied().collect();
        assert_eq!(data, vec![4, 5, 6, 7]);
    }

    #[test]
    fn view_slice_axis() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let v = arr.view();
        let sliced = v.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![20, 30, 40]);
    }

    #[test]
    fn view_insert_remove_axis() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let v = arr.view();
        let expanded = v.insert_axis(Axis(0)).unwrap();
        assert_eq!(expanded.shape(), &[1, 4]);
        let squeezed = expanded.remove_axis(Axis(0)).unwrap();
        assert_eq!(squeezed.shape(), &[4]);
    }

    #[test]
    fn view_get() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let v = arr.view();
        assert_eq!(*v.get(&[1, 2]).unwrap(), 6);
    }

    // -----------------------------------------------------------------------
    // ArrayViewMut slice
    // -----------------------------------------------------------------------

    #[test]
    fn view_mut_slice_axis() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        {
            let mut vm = arr.view_mut();
            let mut sliced = vm.slice_axis_mut(Axis(0), SliceSpec::new(1, 3)).unwrap();
            if let Some(s) = sliced.as_slice_mut() {
                s[0] = 20;
                s[1] = 30;
            }
        }
        assert_eq!(arr.as_slice().unwrap(), &[1, 20, 30, 4, 5]);
    }

    // -----------------------------------------------------------------------
    // 3D indexing
    // -----------------------------------------------------------------------

    #[test]
    fn index_axis_3d() {
        let arr = Array::<i32, Ix3>::from_vec(Ix3::new([2, 3, 4]), (0..24).collect()).unwrap();
        let plane = arr.index_axis(Axis(0), 1).unwrap();
        assert_eq!(plane.shape(), &[3, 4]);
        assert_eq!(*plane.get(&[0, 0]).unwrap(), 12);
    }

    // ---- Rank-preserving indexing (#349) ------------------------------
    //
    // index_axis on a fixed-rank Array now returns a fixed-rank
    // ArrayView with one fewer axis; insert_axis adds one axis. These
    // tests pin the *static type* of the result via explicit type
    // annotations (compile-time check, not just runtime shape).

    #[test]
    fn index_axis_2d_returns_ix1() {
        let arr =
            Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).map(|i| i as f64).collect())
                .unwrap();
        // The variable's type is explicit — must compile to Ix1, not IxDyn.
        let row: ArrayView<'_, f64, Ix1> = arr.index_axis(Axis(0), 1).unwrap();
        assert_eq!(row.shape(), &[4]);
    }

    #[test]
    fn index_axis_3d_returns_ix2() {
        let arr = Array::<i32, Ix3>::from_vec(Ix3::new([2, 3, 4]), (0..24).collect()).unwrap();
        let plane: ArrayView<'_, i32, Ix2> = arr.index_axis(Axis(0), 0).unwrap();
        assert_eq!(plane.shape(), &[3, 4]);
    }

    #[test]
    fn insert_axis_ix2_returns_ix3() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), (0..6).map(|i| i as f64).collect())
            .unwrap();
        let expanded: ArrayView<'_, f64, Ix3> = arr.insert_axis(Axis(0)).unwrap();
        assert_eq!(expanded.shape(), &[1, 2, 3]);
    }

    #[test]
    fn remove_axis_ix3_returns_ix2() {
        let arr =
            Array::<f64, Ix3>::from_vec(Ix3::new([1, 3, 4]), (0..12).map(|i| i as f64).collect())
                .unwrap();
        let squeezed: ArrayView<'_, f64, Ix2> = arr.remove_axis(Axis(0)).unwrap();
        assert_eq!(squeezed.shape(), &[3, 4]);
    }

    #[test]
    fn index_axis_chains_preserve_rank_at_each_step() {
        // Ix3 → Ix2 → Ix1 → Ix0 chain.
        let arr = Array::<i32, Ix3>::from_vec(Ix3::new([2, 3, 4]), (0..24).collect()).unwrap();
        let plane: ArrayView<'_, i32, Ix2> = arr.index_axis(Axis(0), 1).unwrap();
        let row: ArrayView<'_, i32, Ix1> = plane.index_axis(Axis(0), 1).unwrap();
        let scalar: ArrayView<'_, i32, Ix0> = row.index_axis(Axis(0), 2).unwrap();
        assert_eq!(scalar.shape(), &[] as &[usize]);
    }
}