Skip to main content

ferray_core/indexing/
basic.rs

1// ferray-core: Basic indexing (REQ-12, REQ-14)
2//
3// Integer + slice indexing returning views (zero-copy).
4// insert_axis / remove_axis for dimension manipulation.
5//
6// These are implemented as methods on Array, ArrayView, ArrayViewMut.
7// The s![] macro is out of scope (Agent 1d).
8
9use crate::array::owned::Array;
10use crate::array::view::ArrayView;
11use crate::array::view_mut::ArrayViewMut;
12use crate::dimension::{Axis, Dimension, IxDyn};
13use crate::dtype::Element;
14use crate::error::{FerrayError, FerrayResult};
15
16use super::normalize_index;
17
/// A slice specification for one axis, modelled on Python's `start:stop:step`.
///
/// Every field is optional; `None` selects that field's default:
/// - `start`: defaults to 0
/// - `stop`: defaults to the axis length
/// - `step`: defaults to 1; must not be 0
///
/// # Negative steps
///
/// The conversion helpers on this type forward `start` and `stop` to
/// `ndarray` unchanged, so the pair always describes the *forward* range
/// `[start, stop)`; a negative `step` walks that range from its last element
/// back to its first. This differs from `NumPy`, where `a[1:4:-1]` is empty:
/// here `SliceSpec::with_step(1, 4, -1)` selects the elements at positions
/// 3, 2, 1, and an all-`None` spec with `step = -1` reverses the whole axis.
///
/// The derived [`Default`] value has every field set to `None`, i.e. it is the
/// full-range slice `:`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SliceSpec {
    /// Start index (inclusive). Negative counts from end.
    pub start: Option<isize>,
    /// Stop index (exclusive). Negative counts from end.
    pub stop: Option<isize>,
    /// Step size. Must not be zero.
    pub step: Option<isize>,
}
33
34impl SliceSpec {
35    /// Create a new full-range slice (equivalent to `:`).
36    #[must_use]
37    pub const fn full() -> Self {
38        Self {
39            start: None,
40            stop: None,
41            step: None,
42        }
43    }
44
45    /// Create a slice `start:stop` with step 1.
46    #[must_use]
47    pub const fn new(start: isize, stop: isize) -> Self {
48        Self {
49            start: Some(start),
50            stop: Some(stop),
51            step: None,
52        }
53    }
54
55    /// Create a slice `start:stop:step`.
56    #[must_use]
57    pub const fn with_step(start: isize, stop: isize, step: isize) -> Self {
58        Self {
59            start: Some(start),
60            stop: Some(stop),
61            step: Some(step),
62        }
63    }
64
65    /// Validate that the step is not zero.
66    fn validate(&self) -> FerrayResult<()> {
67        if self.step == Some(0) {
68            return Err(FerrayError::invalid_value("slice step cannot be zero"));
69        }
70        Ok(())
71    }
72
73    /// Convert to an ndarray Slice.
74    #[allow(clippy::wrong_self_convention)]
75    fn to_ndarray_slice(&self) -> ndarray::Slice {
76        ndarray::Slice::new(self.start.unwrap_or(0), self.stop, self.step.unwrap_or(1))
77    }
78
79    /// Convert to an ndarray `SliceInfoElem` (used by s![] macro integration).
80    #[allow(dead_code, clippy::wrong_self_convention)]
81    pub(crate) fn to_ndarray_elem(&self) -> ndarray::SliceInfoElem {
82        ndarray::SliceInfoElem::Slice {
83            start: self.start.unwrap_or(0),
84            end: self.stop,
85            step: self.step.unwrap_or(1),
86        }
87    }
88}
89
90// ---------------------------------------------------------------------------
91// Array methods — basic indexing
92// ---------------------------------------------------------------------------
93
94impl<T: Element, D: Dimension> Array<T, D> {
95    /// Index into the array along a given axis, removing that axis.
96    ///
97    /// Equivalent to `a[i]` for axis 0, or `a[:, i]` for axis 1, etc.
98    /// Returns a view with one fewer dimension (dynamic-rank).
99    ///
100    /// # Errors
101    /// - `AxisOutOfBounds` if `axis >= ndim`
102    /// - `IndexOutOfBounds` if `index` is out of range (supports negative)
103    pub fn index_axis(&self, axis: Axis, index: isize) -> FerrayResult<ArrayView<'_, T, IxDyn>>
104    where
105        D::NdarrayDim: ndarray::RemoveAxis,
106    {
107        let ndim = self.ndim();
108        let ax = axis.index();
109        if ax >= ndim {
110            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
111        }
112        let size = self.shape()[ax];
113        let idx = normalize_index(index, size, ax)?;
114
115        let nd_axis = ndarray::Axis(ax);
116        let sub = self.inner.index_axis(nd_axis, idx);
117        let dyn_view = sub.into_dyn();
118        Ok(ArrayView::from_ndarray(dyn_view))
119    }
120
121    /// Slice the array along a given axis, returning a view.
122    ///
123    /// The returned view shares data with the source (zero-copy).
124    ///
125    /// # Errors
126    /// - `AxisOutOfBounds` if `axis >= ndim`
127    /// - `InvalidValue` if step is zero
128    pub fn slice_axis(&self, axis: Axis, spec: SliceSpec) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
129        let ndim = self.ndim();
130        let ax = axis.index();
131        if ax >= ndim {
132            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
133        }
134        spec.validate()?;
135
136        let nd_axis = ndarray::Axis(ax);
137        let nd_slice = spec.to_ndarray_slice();
138        let sliced = self.inner.slice_axis(nd_axis, nd_slice);
139        let dyn_view = sliced.into_dyn();
140        Ok(ArrayView::from_ndarray(dyn_view))
141    }
142
143    /// Slice the array along a given axis, returning a mutable view.
144    ///
145    /// # Errors
146    /// Same as [`slice_axis`](Self::slice_axis).
147    pub fn slice_axis_mut(
148        &mut self,
149        axis: Axis,
150        spec: SliceSpec,
151    ) -> FerrayResult<ArrayViewMut<'_, T, IxDyn>> {
152        let ndim = self.ndim();
153        let ax = axis.index();
154        if ax >= ndim {
155            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
156        }
157        spec.validate()?;
158
159        let nd_axis = ndarray::Axis(ax);
160        let nd_slice = spec.to_ndarray_slice();
161        let sliced = self.inner.slice_axis_mut(nd_axis, nd_slice);
162        let dyn_view = sliced.into_dyn();
163        Ok(ArrayViewMut::from_ndarray(dyn_view))
164    }
165
166    /// Multi-axis slicing: apply a slice specification to each axis.
167    ///
168    /// `specs` must have length equal to `ndim()`. For axes you don't
169    /// want to slice, pass `SliceSpec::full()`.
170    ///
171    /// # Errors
172    /// - `InvalidValue` if `specs.len() != ndim()`
173    /// - Any errors from individual axis slicing
174    pub fn slice_multi(&self, specs: &[SliceSpec]) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
175        let ndim = self.ndim();
176        if specs.len() != ndim {
177            return Err(FerrayError::invalid_value(format!(
178                "expected {} slice specs, got {}",
179                ndim,
180                specs.len()
181            )));
182        }
183
184        for spec in specs {
185            spec.validate()?;
186        }
187
188        // Apply axis-by-axis slicing using move variants to preserve lifetimes
189        let mut result = self.inner.view().into_dyn();
190        for (ax, spec) in specs.iter().enumerate() {
191            let nd_axis = ndarray::Axis(ax);
192            let nd_slice = spec.to_ndarray_slice();
193            result = result.slice_axis_move(nd_axis, nd_slice).into_dyn();
194        }
195        Ok(ArrayView::from_ndarray(result))
196    }
197
198    /// Insert a new axis of length 1 at the given position.
199    ///
200    /// This is equivalent to `np.expand_dims` or `np.newaxis`.
201    /// Returns a dynamic-rank view with one more dimension.
202    ///
203    /// # Errors
204    /// - `AxisOutOfBounds` if `axis > ndim`
205    pub fn insert_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
206        let ndim = self.ndim();
207        let ax = axis.index();
208        if ax > ndim {
209            return Err(FerrayError::axis_out_of_bounds(ax, ndim + 1));
210        }
211
212        let dyn_view = self.inner.view().into_dyn();
213        let expanded = dyn_view.insert_axis(ndarray::Axis(ax));
214        Ok(ArrayView::from_ndarray(expanded))
215    }
216
217    /// Remove an axis of length 1.
218    ///
219    /// This is equivalent to `np.squeeze` for a single axis.
220    /// Returns a dynamic-rank view with one fewer dimension.
221    ///
222    /// # Errors
223    /// - `AxisOutOfBounds` if `axis >= ndim`
224    /// - `InvalidValue` if the axis has size != 1
225    pub fn remove_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
226        let ndim = self.ndim();
227        let ax = axis.index();
228        if ax >= ndim {
229            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
230        }
231        if self.shape()[ax] != 1 {
232            return Err(FerrayError::invalid_value(format!(
233                "cannot remove axis {} with size {} (must be 1)",
234                ax,
235                self.shape()[ax]
236            )));
237        }
238
239        // index_axis_move at 0 removes the axis (consumes the view, preserving lifetime)
240        let dyn_view = self.inner.view().into_dyn();
241        let squeezed = dyn_view.index_axis_move(ndarray::Axis(ax), 0);
242        Ok(ArrayView::from_ndarray(squeezed))
243    }
244
245    /// Index into the array with a flat (linear) index.
246    ///
247    /// Elements are ordered in row-major (C) order. The flat index is
248    /// converted to a multi-dimensional index via unravel, then stride
249    /// arithmetic computes the physical offset in O(ndim) time.
250    ///
251    /// # Errors
252    /// Returns `IndexOutOfBounds` if the index is out of range.
253    pub fn flat_index(&self, index: isize) -> FerrayResult<&T> {
254        let size = self.size();
255        let idx = normalize_index(index, size, 0)?;
256
257        // Unravel flat index to multi-dim index in row-major order,
258        // then compute the physical offset via strides.
259        let shape = self.shape();
260        let strides = self.inner.strides();
261        let base_ptr = self.inner.as_ptr();
262        let ndim = shape.len();
263
264        let mut remaining = idx;
265        let mut offset: isize = 0;
266        for d in 0..ndim {
267            let dim_stride: usize = shape[d + 1..].iter().product::<usize>().max(1);
268            let coord = remaining / dim_stride;
269            remaining %= dim_stride;
270            offset += coord as isize * strides[d];
271        }
272
273        // SAFETY: idx is validated in-bounds via normalize_index, and the
274        // unravel produces coordinates within each axis, so the computed
275        // offset is within the array's data allocation.
276        Ok(unsafe { &*base_ptr.offset(offset) })
277    }
278
279    /// Get a reference to a single element by multi-dimensional index.
280    ///
281    /// Supports negative indices (counting from end).
282    ///
283    /// # Errors
284    /// - `InvalidValue` if `indices.len() != ndim()`
285    /// - `IndexOutOfBounds` if any index is out of range
286    pub fn get(&self, indices: &[isize]) -> FerrayResult<&T> {
287        let ndim = self.ndim();
288        if indices.len() != ndim {
289            return Err(FerrayError::invalid_value(format!(
290                "expected {} indices, got {}",
291                ndim,
292                indices.len()
293            )));
294        }
295
296        // Compute the flat offset manually
297        let shape = self.shape();
298        let strides = self.inner.strides();
299        let base_ptr = self.inner.as_ptr();
300
301        let mut offset: isize = 0;
302        for (ax, &idx) in indices.iter().enumerate() {
303            let pos = normalize_index(idx, shape[ax], ax)?;
304            offset += pos as isize * strides[ax];
305        }
306
307        // SAFETY: all indices are validated in-bounds, so the computed
308        // offset is within the array's data allocation.
309        Ok(unsafe { &*base_ptr.offset(offset) })
310    }
311
312    /// Get a mutable reference to a single element by multi-dimensional index.
313    ///
314    /// # Errors
315    /// Same as [`get`](Self::get).
316    pub fn get_mut(&mut self, indices: &[isize]) -> FerrayResult<&mut T> {
317        let ndim = self.ndim();
318        if indices.len() != ndim {
319            return Err(FerrayError::invalid_value(format!(
320                "expected {} indices, got {}",
321                ndim,
322                indices.len()
323            )));
324        }
325
326        let shape = self.shape().to_vec();
327        let strides: Vec<isize> = self.inner.strides().to_vec();
328        let base_ptr = self.inner.as_mut_ptr();
329
330        let mut offset: isize = 0;
331        for (ax, &idx) in indices.iter().enumerate() {
332            let pos = normalize_index(idx, shape[ax], ax)?;
333            offset += pos as isize * strides[ax];
334        }
335
336        // SAFETY: we have &mut self so exclusive access is guaranteed,
337        // and all indices are validated in-bounds.
338        Ok(unsafe { &mut *base_ptr.offset(offset) })
339    }
340}
341
342// ---------------------------------------------------------------------------
343// ArrayView methods — basic indexing
344// ---------------------------------------------------------------------------
345
346impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
347    /// Index into the view along a given axis, removing that axis.
348    pub fn index_axis(&self, axis: Axis, index: isize) -> FerrayResult<ArrayView<'a, T, IxDyn>>
349    where
350        D::NdarrayDim: ndarray::RemoveAxis,
351    {
352        let ndim = self.ndim();
353        let ax = axis.index();
354        if ax >= ndim {
355            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
356        }
357        let size = self.shape()[ax];
358        let idx = normalize_index(index, size, ax)?;
359
360        let nd_axis = ndarray::Axis(ax);
361        // clone() on ArrayView is cheap (it's Copy-like)
362        let sub = self.inner.clone().index_axis_move(nd_axis, idx);
363        let dyn_view = sub.into_dyn();
364        Ok(ArrayView::from_ndarray(dyn_view))
365    }
366
367    /// Slice the view along a given axis.
368    pub fn slice_axis(&self, axis: Axis, spec: SliceSpec) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
369        let ndim = self.ndim();
370        let ax = axis.index();
371        if ax >= ndim {
372            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
373        }
374        spec.validate()?;
375
376        let nd_axis = ndarray::Axis(ax);
377        let nd_slice = spec.to_ndarray_slice();
378        // slice_axis on a cloned view preserves the 'a lifetime
379        let sliced = self.inner.clone().slice_axis_move(nd_axis, nd_slice);
380        let dyn_view = sliced.into_dyn();
381        Ok(ArrayView::from_ndarray(dyn_view))
382    }
383
384    /// Insert a new axis of length 1 at the given position.
385    pub fn insert_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
386        let ndim = self.ndim();
387        let ax = axis.index();
388        if ax > ndim {
389            return Err(FerrayError::axis_out_of_bounds(ax, ndim + 1));
390        }
391
392        let dyn_view = self.inner.clone().into_dyn();
393        let expanded = dyn_view.insert_axis(ndarray::Axis(ax));
394        Ok(ArrayView::from_ndarray(expanded))
395    }
396
397    /// Remove an axis of length 1.
398    pub fn remove_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
399        let ndim = self.ndim();
400        let ax = axis.index();
401        if ax >= ndim {
402            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
403        }
404        if self.shape()[ax] != 1 {
405            return Err(FerrayError::invalid_value(format!(
406                "cannot remove axis {} with size {} (must be 1)",
407                ax,
408                self.shape()[ax]
409            )));
410        }
411
412        let dyn_view = self.inner.clone().into_dyn();
413        let squeezed = dyn_view.index_axis_move(ndarray::Axis(ax), 0);
414        Ok(ArrayView::from_ndarray(squeezed))
415    }
416
417    /// Get a reference to a single element by multi-dimensional index.
418    pub fn get(&self, indices: &[isize]) -> FerrayResult<&'a T> {
419        let ndim = self.ndim();
420        if indices.len() != ndim {
421            return Err(FerrayError::invalid_value(format!(
422                "expected {} indices, got {}",
423                ndim,
424                indices.len()
425            )));
426        }
427
428        let shape = self.shape();
429        let strides = self.inner.strides();
430        let base_ptr = self.inner.as_ptr();
431
432        let mut offset: isize = 0;
433        for (ax, &idx) in indices.iter().enumerate() {
434            let pos = normalize_index(idx, shape[ax], ax)?;
435            offset += pos as isize * strides[ax];
436        }
437
438        // SAFETY: indices validated in-bounds; the pointer is valid for 'a.
439        Ok(unsafe { &*base_ptr.offset(offset) })
440    }
441}
442
443// ---------------------------------------------------------------------------
444// ArrayViewMut methods — basic indexing
445// ---------------------------------------------------------------------------
446
447impl<T: Element, D: Dimension> ArrayViewMut<'_, T, D> {
448    /// Slice the mutable view along a given axis.
449    pub fn slice_axis_mut(
450        &mut self,
451        axis: Axis,
452        spec: SliceSpec,
453    ) -> FerrayResult<ArrayViewMut<'_, T, IxDyn>> {
454        let ndim = self.ndim();
455        let ax = axis.index();
456        if ax >= ndim {
457            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
458        }
459        spec.validate()?;
460
461        let nd_axis = ndarray::Axis(ax);
462        let nd_slice = spec.to_ndarray_slice();
463        let sliced = self.inner.slice_axis_mut(nd_axis, nd_slice);
464        let dyn_view = sliced.into_dyn();
465        Ok(ArrayViewMut::from_ndarray(dyn_view))
466    }
467}
468
#[cfg(test)]
mod tests {
    //! Unit tests for basic indexing: index normalization, single- and
    //! multi-axis slicing, axis insertion/removal, and element access on
    //! `Array`, `ArrayView` and `ArrayViewMut`.

    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3};

    // -----------------------------------------------------------------------
    // Normalization
    // -----------------------------------------------------------------------

    #[test]
    fn normalize_positive_in_bounds() {
        assert_eq!(normalize_index(2, 5, 0).unwrap(), 2);
    }

    #[test]
    fn normalize_negative() {
        assert_eq!(normalize_index(-1, 5, 0).unwrap(), 4);
        assert_eq!(normalize_index(-5, 5, 0).unwrap(), 0);
    }

    #[test]
    fn normalize_out_of_bounds() {
        assert!(normalize_index(5, 5, 0).is_err());
        assert!(normalize_index(-6, 5, 0).is_err());
    }

    // -----------------------------------------------------------------------
    // index_axis
    // -----------------------------------------------------------------------

    #[test]
    fn index_axis_row() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let row = arr.index_axis(Axis(0), 1).unwrap();
        assert_eq!(row.shape(), &[4]);
        let data: Vec<i32> = row.iter().copied().collect();
        assert_eq!(data, vec![4, 5, 6, 7]);
    }

    #[test]
    fn index_axis_column() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let col = arr.index_axis(Axis(1), 2).unwrap();
        assert_eq!(col.shape(), &[3]);
        let data: Vec<i32> = col.iter().copied().collect();
        assert_eq!(data, vec![2, 6, 10]);
    }

    #[test]
    fn index_axis_negative() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let row = arr.index_axis(Axis(0), -1).unwrap();
        let data: Vec<i32> = row.iter().copied().collect();
        assert_eq!(data, vec![8, 9, 10, 11]);
    }

    #[test]
    fn index_axis_out_of_bounds() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        // Index past the end of axis 0, then an axis that does not exist.
        assert!(arr.index_axis(Axis(0), 3).is_err());
        assert!(arr.index_axis(Axis(2), 0).is_err());
    }

    #[test]
    fn index_axis_is_zero_copy() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let row = arr.index_axis(Axis(0), 0).unwrap();
        // Row 0 starts at the array's first element, so the pointers coincide.
        assert_eq!(row.as_ptr(), arr.as_ptr());
    }

    // -----------------------------------------------------------------------
    // slice_axis
    // -----------------------------------------------------------------------

    #[test]
    fn slice_axis_basic() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let sliced = arr.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        assert_eq!(sliced.shape(), &[3]);
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![20, 30, 40]);
    }

    #[test]
    fn slice_axis_step() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([6]), vec![0, 1, 2, 3, 4, 5]).unwrap();
        let sliced = arr
            .slice_axis(Axis(0), SliceSpec::with_step(0, 6, 2))
            .unwrap();
        assert_eq!(sliced.shape(), &[3]);
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![0, 2, 4]);
    }

    #[test]
    fn slice_axis_negative_step() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 2, 3, 4]).unwrap();
        // Reverse the entire array: range [0, len) traversed backwards
        // ndarray interprets Slice::new(start, end, step) where start/end define
        // a forward range and negative step reverses traversal within it.
        let spec = SliceSpec {
            start: None,
            stop: None,
            step: Some(-1),
        };
        let sliced = arr.slice_axis(Axis(0), spec).unwrap();
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![4, 3, 2, 1, 0]);
    }

    #[test]
    fn slice_axis_negative_step_partial() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 2, 3, 4]).unwrap();
        // Range [1, 4) traversed backwards with step -1: [3, 2, 1]
        let sliced = arr
            .slice_axis(Axis(0), SliceSpec::with_step(1, 4, -1))
            .unwrap();
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![3, 2, 1]);
    }

    #[test]
    fn slice_axis_full() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let sliced = arr.slice_axis(Axis(0), SliceSpec::full()).unwrap();
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 3]);
    }

    #[test]
    fn slice_axis_2d_rows() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([4, 3]), (0..12).collect()).unwrap();
        let sliced = arr.slice_axis(Axis(0), SliceSpec::new(1, 3)).unwrap();
        assert_eq!(sliced.shape(), &[2, 3]);
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn slice_axis_is_zero_copy() {
        let arr =
            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let sliced = arr.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        // A zero-copy slice starting at element 1 must point exactly one
        // element past the parent's base pointer; a copy would live elsewhere.
        assert_eq!(sliced.as_ptr(), arr.as_ptr().wrapping_add(1));
        // SAFETY: the view is non-empty, so its base pointer is valid to read.
        unsafe {
            assert_eq!(*sliced.as_ptr(), 2.0);
        }
    }

    #[test]
    fn slice_axis_zero_step_error() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert!(
            arr.slice_axis(Axis(0), SliceSpec::with_step(0, 3, 0))
                .is_err()
        );
    }

    // -----------------------------------------------------------------------
    // slice_multi
    // -----------------------------------------------------------------------

    #[test]
    fn slice_multi_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([4, 5]), (0..20).collect()).unwrap();
        let sliced = arr
            .slice_multi(&[SliceSpec::new(1, 3), SliceSpec::new(0, 4)])
            .unwrap();
        assert_eq!(sliced.shape(), &[2, 4]);
        // Rows 1..3 and columns 0..4 of the 4x5 grid 0..20.
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![5, 6, 7, 8, 10, 11, 12, 13]);
    }

    #[test]
    fn slice_multi_wrong_count() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), (0..6).collect()).unwrap();
        assert!(arr.slice_multi(&[SliceSpec::full()]).is_err());
    }

    // -----------------------------------------------------------------------
    // insert_axis / remove_axis
    // -----------------------------------------------------------------------

    #[test]
    fn insert_axis_at_front() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let expanded = arr.insert_axis(Axis(0)).unwrap();
        assert_eq!(expanded.shape(), &[1, 3]);
    }

    #[test]
    fn insert_axis_at_end() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let expanded = arr.insert_axis(Axis(1)).unwrap();
        assert_eq!(expanded.shape(), &[3, 1]);
    }

    #[test]
    fn insert_axis_out_of_bounds() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(arr.insert_axis(Axis(3)).is_err());
    }

    #[test]
    fn remove_axis_single() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
        let squeezed = arr.remove_axis(Axis(0)).unwrap();
        assert_eq!(squeezed.shape(), &[3]);
        let data: Vec<f64> = squeezed.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn remove_axis_not_one() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        assert!(arr.remove_axis(Axis(0)).is_err());
    }

    #[test]
    fn remove_axis_out_of_bounds() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(arr.remove_axis(Axis(1)).is_err());
    }

    // -----------------------------------------------------------------------
    // flat_index
    // -----------------------------------------------------------------------

    #[test]
    fn flat_index_positive() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        assert_eq!(*arr.flat_index(0).unwrap(), 1);
        assert_eq!(*arr.flat_index(3).unwrap(), 4);
        assert_eq!(*arr.flat_index(5).unwrap(), 6);
    }

    #[test]
    fn flat_index_negative() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        assert_eq!(*arr.flat_index(-1).unwrap(), 50);
        assert_eq!(*arr.flat_index(-5).unwrap(), 10);
    }

    #[test]
    fn flat_index_out_of_bounds() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert!(arr.flat_index(3).is_err());
        assert!(arr.flat_index(-4).is_err());
    }

    #[test]
    fn flat_index_3d() {
        // Values equal their own row-major position, so every flat index must
        // read back itself regardless of how the unravel walks the axes.
        let arr = Array::<i32, Ix3>::from_vec(Ix3::new([2, 3, 4]), (0..24).collect()).unwrap();
        assert_eq!(*arr.flat_index(0).unwrap(), 0);
        assert_eq!(*arr.flat_index(17).unwrap(), 17);
        assert_eq!(*arr.flat_index(-1).unwrap(), 23);
    }

    // -----------------------------------------------------------------------
    // get / get_mut
    // -----------------------------------------------------------------------

    #[test]
    fn get_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        assert_eq!(*arr.get(&[0, 0]).unwrap(), 0);
        assert_eq!(*arr.get(&[1, 2]).unwrap(), 6);
        assert_eq!(*arr.get(&[2, 3]).unwrap(), 11);
    }

    #[test]
    fn get_negative_indices() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        assert_eq!(*arr.get(&[-1, -1]).unwrap(), 11);
        assert_eq!(*arr.get(&[-3, 0]).unwrap(), 0);
    }

    #[test]
    fn get_wrong_ndim() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), (0..6).collect()).unwrap();
        assert!(arr.get(&[0]).is_err());
        assert!(arr.get(&[0, 0, 0]).is_err());
    }

    #[test]
    fn get_mut_modify() {
        let mut arr =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        *arr.get_mut(&[1, 2]).unwrap() = 99;
        assert_eq!(*arr.get(&[1, 2]).unwrap(), 99);
    }

    // -----------------------------------------------------------------------
    // ArrayView basic indexing
    // -----------------------------------------------------------------------

    #[test]
    fn view_index_axis() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let v = arr.view();
        let row = v.index_axis(Axis(0), 1).unwrap();
        let data: Vec<i32> = row.iter().copied().collect();
        assert_eq!(data, vec![4, 5, 6, 7]);
    }

    #[test]
    fn view_slice_axis() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let v = arr.view();
        let sliced = v.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![20, 30, 40]);
    }

    #[test]
    fn view_insert_remove_axis() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let v = arr.view();
        // Inserting then removing the same axis must round-trip the shape.
        let expanded = v.insert_axis(Axis(0)).unwrap();
        assert_eq!(expanded.shape(), &[1, 4]);
        let squeezed = expanded.remove_axis(Axis(0)).unwrap();
        assert_eq!(squeezed.shape(), &[4]);
    }

    #[test]
    fn view_get() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let v = arr.view();
        assert_eq!(*v.get(&[1, 2]).unwrap(), 6);
    }

    // -----------------------------------------------------------------------
    // ArrayViewMut slice
    // -----------------------------------------------------------------------

    #[test]
    fn view_mut_slice_axis() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        {
            // Scope the mutable views so `arr` can be read again afterwards.
            let mut vm = arr.view_mut();
            let mut sliced = vm.slice_axis_mut(Axis(0), SliceSpec::new(1, 3)).unwrap();
            if let Some(s) = sliced.as_slice_mut() {
                s[0] = 20;
                s[1] = 30;
            }
        }
        // Writes through the sliced view must land in the parent array.
        assert_eq!(arr.as_slice().unwrap(), &[1, 20, 30, 4, 5]);
    }

    // -----------------------------------------------------------------------
    // 3D indexing
    // -----------------------------------------------------------------------

    #[test]
    fn index_axis_3d() {
        let arr = Array::<i32, Ix3>::from_vec(Ix3::new([2, 3, 4]), (0..24).collect()).unwrap();
        let plane = arr.index_axis(Axis(0), 1).unwrap();
        assert_eq!(plane.shape(), &[3, 4]);
        // Plane 1 starts after the first 3 * 4 = 12 elements.
        assert_eq!(*plane.get(&[0, 0]).unwrap(), 12);
    }
}