Skip to main content

ferray_core/indexing/
basic.rs

1// ferray-core: Basic indexing (REQ-12, REQ-14)
2//
3// Integer + slice indexing returning views (zero-copy).
4// insert_axis / remove_axis for dimension manipulation.
5//
6// These are implemented as methods on Array, ArrayView, ArrayViewMut.
7// The s![] macro is out of scope (Agent 1d).
8
9use crate::array::owned::Array;
10use crate::array::view::ArrayView;
11use crate::array::view_mut::ArrayViewMut;
12use crate::dimension::{Axis, Dimension, IxDyn};
13use crate::dtype::Element;
14use crate::error::{FerrayError, FerrayResult};
15
16use super::normalize_index;
17
/// A slice specification for one axis: `start:stop:step`.
///
/// Every field is optional (`None` selects the default). Bounds are
/// resolved with ndarray's slicing rules. These agree with Python/NumPy
/// for positive steps but **not** for negative ones:
/// - `start`: defaults to 0. Negative values count from the end of the axis.
/// - `stop`: exclusive; defaults to the axis length. Negative values count
///   from the end of the axis.
/// - `step`: defaults to 1; must not be 0. A negative step walks the
///   half-open range `[start, stop)` backwards, beginning at its last
///   element. For example, `SliceSpec::with_step(1, 4, -1)` over
///   `[0, 1, 2, 3, 4]` yields `[3, 2, 1]`, whereas NumPy's `a[1:4:-1]`
///   would be empty.
///
/// `SliceSpec::default()` is the full range, identical to `SliceSpec::full()`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct SliceSpec {
    /// Start index (inclusive). Negative counts from end.
    pub start: Option<isize>,
    /// Stop index (exclusive). Negative counts from end.
    pub stop: Option<isize>,
    /// Step size. Must not be zero.
    pub step: Option<isize>,
}
33
34impl SliceSpec {
35    /// Create a new full-range slice (equivalent to `:`).
36    pub fn full() -> Self {
37        Self {
38            start: None,
39            stop: None,
40            step: None,
41        }
42    }
43
44    /// Create a slice `start:stop` with step 1.
45    pub fn new(start: isize, stop: isize) -> Self {
46        Self {
47            start: Some(start),
48            stop: Some(stop),
49            step: None,
50        }
51    }
52
53    /// Create a slice `start:stop:step`.
54    pub fn with_step(start: isize, stop: isize, step: isize) -> Self {
55        Self {
56            start: Some(start),
57            stop: Some(stop),
58            step: Some(step),
59        }
60    }
61
62    /// Validate that the step is not zero.
63    fn validate(&self) -> FerrayResult<()> {
64        if let Some(0) = self.step {
65            return Err(FerrayError::invalid_value("slice step cannot be zero"));
66        }
67        Ok(())
68    }
69
70    /// Convert to an ndarray Slice.
71    #[allow(clippy::wrong_self_convention)]
72    fn to_ndarray_slice(&self) -> ndarray::Slice {
73        ndarray::Slice::new(self.start.unwrap_or(0), self.stop, self.step.unwrap_or(1))
74    }
75
76    /// Convert to an ndarray SliceInfoElem (used by s![] macro integration).
77    #[allow(dead_code, clippy::wrong_self_convention)]
78    pub(crate) fn to_ndarray_elem(&self) -> ndarray::SliceInfoElem {
79        ndarray::SliceInfoElem::Slice {
80            start: self.start.unwrap_or(0),
81            end: self.stop,
82            step: self.step.unwrap_or(1),
83        }
84    }
85}
86
87// ---------------------------------------------------------------------------
88// Array methods — basic indexing
89// ---------------------------------------------------------------------------
90
91impl<T: Element, D: Dimension> Array<T, D> {
92    /// Index into the array along a given axis, removing that axis.
93    ///
94    /// Equivalent to `a[i]` for axis 0, or `a[:, i]` for axis 1, etc.
95    /// Returns a view with one fewer dimension (dynamic-rank).
96    ///
97    /// # Errors
98    /// - `AxisOutOfBounds` if `axis >= ndim`
99    /// - `IndexOutOfBounds` if `index` is out of range (supports negative)
100    pub fn index_axis(&self, axis: Axis, index: isize) -> FerrayResult<ArrayView<'_, T, IxDyn>>
101    where
102        D::NdarrayDim: ndarray::RemoveAxis,
103    {
104        let ndim = self.ndim();
105        let ax = axis.index();
106        if ax >= ndim {
107            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
108        }
109        let size = self.shape()[ax];
110        let idx = normalize_index(index, size, ax)?;
111
112        let nd_axis = ndarray::Axis(ax);
113        let sub = self.inner.index_axis(nd_axis, idx);
114        let dyn_view = sub.into_dyn();
115        Ok(ArrayView::from_ndarray(dyn_view))
116    }
117
118    /// Slice the array along a given axis, returning a view.
119    ///
120    /// The returned view shares data with the source (zero-copy).
121    ///
122    /// # Errors
123    /// - `AxisOutOfBounds` if `axis >= ndim`
124    /// - `InvalidValue` if step is zero
125    pub fn slice_axis(&self, axis: Axis, spec: SliceSpec) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
126        let ndim = self.ndim();
127        let ax = axis.index();
128        if ax >= ndim {
129            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
130        }
131        spec.validate()?;
132
133        let nd_axis = ndarray::Axis(ax);
134        let nd_slice = spec.to_ndarray_slice();
135        let sliced = self.inner.slice_axis(nd_axis, nd_slice);
136        let dyn_view = sliced.into_dyn();
137        Ok(ArrayView::from_ndarray(dyn_view))
138    }
139
140    /// Slice the array along a given axis, returning a mutable view.
141    ///
142    /// # Errors
143    /// Same as [`slice_axis`](Self::slice_axis).
144    pub fn slice_axis_mut(
145        &mut self,
146        axis: Axis,
147        spec: SliceSpec,
148    ) -> FerrayResult<ArrayViewMut<'_, T, IxDyn>> {
149        let ndim = self.ndim();
150        let ax = axis.index();
151        if ax >= ndim {
152            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
153        }
154        spec.validate()?;
155
156        let nd_axis = ndarray::Axis(ax);
157        let nd_slice = spec.to_ndarray_slice();
158        let sliced = self.inner.slice_axis_mut(nd_axis, nd_slice);
159        let dyn_view = sliced.into_dyn();
160        Ok(ArrayViewMut::from_ndarray(dyn_view))
161    }
162
163    /// Multi-axis slicing: apply a slice specification to each axis.
164    ///
165    /// `specs` must have length equal to `ndim()`. For axes you don't
166    /// want to slice, pass `SliceSpec::full()`.
167    ///
168    /// # Errors
169    /// - `InvalidValue` if `specs.len() != ndim()`
170    /// - Any errors from individual axis slicing
171    pub fn slice_multi(&self, specs: &[SliceSpec]) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
172        let ndim = self.ndim();
173        if specs.len() != ndim {
174            return Err(FerrayError::invalid_value(format!(
175                "expected {} slice specs, got {}",
176                ndim,
177                specs.len()
178            )));
179        }
180
181        for spec in specs {
182            spec.validate()?;
183        }
184
185        // Apply axis-by-axis slicing using move variants to preserve lifetimes
186        let mut result = self.inner.view().into_dyn();
187        for (ax, spec) in specs.iter().enumerate() {
188            let nd_axis = ndarray::Axis(ax);
189            let nd_slice = spec.to_ndarray_slice();
190            result = result.slice_axis_move(nd_axis, nd_slice).into_dyn();
191        }
192        Ok(ArrayView::from_ndarray(result))
193    }
194
195    /// Insert a new axis of length 1 at the given position.
196    ///
197    /// This is equivalent to `np.expand_dims` or `np.newaxis`.
198    /// Returns a dynamic-rank view with one more dimension.
199    ///
200    /// # Errors
201    /// - `AxisOutOfBounds` if `axis > ndim`
202    pub fn insert_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
203        let ndim = self.ndim();
204        let ax = axis.index();
205        if ax > ndim {
206            return Err(FerrayError::axis_out_of_bounds(ax, ndim + 1));
207        }
208
209        let dyn_view = self.inner.view().into_dyn();
210        let expanded = dyn_view.insert_axis(ndarray::Axis(ax));
211        Ok(ArrayView::from_ndarray(expanded))
212    }
213
214    /// Remove an axis of length 1.
215    ///
216    /// This is equivalent to `np.squeeze` for a single axis.
217    /// Returns a dynamic-rank view with one fewer dimension.
218    ///
219    /// # Errors
220    /// - `AxisOutOfBounds` if `axis >= ndim`
221    /// - `InvalidValue` if the axis has size != 1
222    pub fn remove_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'_, T, IxDyn>> {
223        let ndim = self.ndim();
224        let ax = axis.index();
225        if ax >= ndim {
226            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
227        }
228        if self.shape()[ax] != 1 {
229            return Err(FerrayError::invalid_value(format!(
230                "cannot remove axis {} with size {} (must be 1)",
231                ax,
232                self.shape()[ax]
233            )));
234        }
235
236        // index_axis_move at 0 removes the axis (consumes the view, preserving lifetime)
237        let dyn_view = self.inner.view().into_dyn();
238        let squeezed = dyn_view.index_axis_move(ndarray::Axis(ax), 0);
239        Ok(ArrayView::from_ndarray(squeezed))
240    }
241
242    /// Index into the array with a flat (linear) index.
243    ///
244    /// Elements are ordered in row-major (C) order. The flat index is
245    /// converted to a multi-dimensional index via unravel, then stride
246    /// arithmetic computes the physical offset in O(ndim) time.
247    ///
248    /// # Errors
249    /// Returns `IndexOutOfBounds` if the index is out of range.
250    pub fn flat_index(&self, index: isize) -> FerrayResult<&T> {
251        let size = self.size();
252        let idx = normalize_index(index, size, 0)?;
253
254        // Unravel flat index to multi-dim index in row-major order,
255        // then compute the physical offset via strides.
256        let shape = self.shape();
257        let strides = self.inner.strides();
258        let base_ptr = self.inner.as_ptr();
259        let ndim = shape.len();
260
261        let mut remaining = idx;
262        let mut offset: isize = 0;
263        for d in 0..ndim {
264            let dim_stride: usize = shape[d + 1..].iter().product::<usize>().max(1);
265            let coord = remaining / dim_stride;
266            remaining %= dim_stride;
267            offset += coord as isize * strides[d];
268        }
269
270        // SAFETY: idx is validated in-bounds via normalize_index, and the
271        // unravel produces coordinates within each axis, so the computed
272        // offset is within the array's data allocation.
273        Ok(unsafe { &*base_ptr.offset(offset) })
274    }
275
276    /// Get a reference to a single element by multi-dimensional index.
277    ///
278    /// Supports negative indices (counting from end).
279    ///
280    /// # Errors
281    /// - `InvalidValue` if `indices.len() != ndim()`
282    /// - `IndexOutOfBounds` if any index is out of range
283    pub fn get(&self, indices: &[isize]) -> FerrayResult<&T> {
284        let ndim = self.ndim();
285        if indices.len() != ndim {
286            return Err(FerrayError::invalid_value(format!(
287                "expected {} indices, got {}",
288                ndim,
289                indices.len()
290            )));
291        }
292
293        // Compute the flat offset manually
294        let shape = self.shape();
295        let strides = self.inner.strides();
296        let base_ptr = self.inner.as_ptr();
297
298        let mut offset: isize = 0;
299        for (ax, &idx) in indices.iter().enumerate() {
300            let pos = normalize_index(idx, shape[ax], ax)?;
301            offset += pos as isize * strides[ax];
302        }
303
304        // SAFETY: all indices are validated in-bounds, so the computed
305        // offset is within the array's data allocation.
306        Ok(unsafe { &*base_ptr.offset(offset) })
307    }
308
309    /// Get a mutable reference to a single element by multi-dimensional index.
310    ///
311    /// # Errors
312    /// Same as [`get`](Self::get).
313    pub fn get_mut(&mut self, indices: &[isize]) -> FerrayResult<&mut T> {
314        let ndim = self.ndim();
315        if indices.len() != ndim {
316            return Err(FerrayError::invalid_value(format!(
317                "expected {} indices, got {}",
318                ndim,
319                indices.len()
320            )));
321        }
322
323        let shape = self.shape().to_vec();
324        let strides: Vec<isize> = self.inner.strides().to_vec();
325        let base_ptr = self.inner.as_mut_ptr();
326
327        let mut offset: isize = 0;
328        for (ax, &idx) in indices.iter().enumerate() {
329            let pos = normalize_index(idx, shape[ax], ax)?;
330            offset += pos as isize * strides[ax];
331        }
332
333        // SAFETY: we have &mut self so exclusive access is guaranteed,
334        // and all indices are validated in-bounds.
335        Ok(unsafe { &mut *base_ptr.offset(offset) })
336    }
337}
338
339// ---------------------------------------------------------------------------
340// ArrayView methods — basic indexing
341// ---------------------------------------------------------------------------
342
343impl<'a, T: Element, D: Dimension> ArrayView<'a, T, D> {
344    /// Index into the view along a given axis, removing that axis.
345    pub fn index_axis(&self, axis: Axis, index: isize) -> FerrayResult<ArrayView<'a, T, IxDyn>>
346    where
347        D::NdarrayDim: ndarray::RemoveAxis,
348    {
349        let ndim = self.ndim();
350        let ax = axis.index();
351        if ax >= ndim {
352            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
353        }
354        let size = self.shape()[ax];
355        let idx = normalize_index(index, size, ax)?;
356
357        let nd_axis = ndarray::Axis(ax);
358        // clone() on ArrayView is cheap (it's Copy-like)
359        let sub = self.inner.clone().index_axis_move(nd_axis, idx);
360        let dyn_view = sub.into_dyn();
361        Ok(ArrayView::from_ndarray(dyn_view))
362    }
363
364    /// Slice the view along a given axis.
365    pub fn slice_axis(&self, axis: Axis, spec: SliceSpec) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
366        let ndim = self.ndim();
367        let ax = axis.index();
368        if ax >= ndim {
369            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
370        }
371        spec.validate()?;
372
373        let nd_axis = ndarray::Axis(ax);
374        let nd_slice = spec.to_ndarray_slice();
375        // slice_axis on a cloned view preserves the 'a lifetime
376        let sliced = self.inner.clone().slice_axis_move(nd_axis, nd_slice);
377        let dyn_view = sliced.into_dyn();
378        Ok(ArrayView::from_ndarray(dyn_view))
379    }
380
381    /// Insert a new axis of length 1 at the given position.
382    pub fn insert_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
383        let ndim = self.ndim();
384        let ax = axis.index();
385        if ax > ndim {
386            return Err(FerrayError::axis_out_of_bounds(ax, ndim + 1));
387        }
388
389        let dyn_view = self.inner.clone().into_dyn();
390        let expanded = dyn_view.insert_axis(ndarray::Axis(ax));
391        Ok(ArrayView::from_ndarray(expanded))
392    }
393
394    /// Remove an axis of length 1.
395    pub fn remove_axis(&self, axis: Axis) -> FerrayResult<ArrayView<'a, T, IxDyn>> {
396        let ndim = self.ndim();
397        let ax = axis.index();
398        if ax >= ndim {
399            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
400        }
401        if self.shape()[ax] != 1 {
402            return Err(FerrayError::invalid_value(format!(
403                "cannot remove axis {} with size {} (must be 1)",
404                ax,
405                self.shape()[ax]
406            )));
407        }
408
409        let dyn_view = self.inner.clone().into_dyn();
410        let squeezed = dyn_view.index_axis_move(ndarray::Axis(ax), 0);
411        Ok(ArrayView::from_ndarray(squeezed))
412    }
413
414    /// Get a reference to a single element by multi-dimensional index.
415    pub fn get(&self, indices: &[isize]) -> FerrayResult<&'a T> {
416        let ndim = self.ndim();
417        if indices.len() != ndim {
418            return Err(FerrayError::invalid_value(format!(
419                "expected {} indices, got {}",
420                ndim,
421                indices.len()
422            )));
423        }
424
425        let shape = self.shape();
426        let strides = self.inner.strides();
427        let base_ptr = self.inner.as_ptr();
428
429        let mut offset: isize = 0;
430        for (ax, &idx) in indices.iter().enumerate() {
431            let pos = normalize_index(idx, shape[ax], ax)?;
432            offset += pos as isize * strides[ax];
433        }
434
435        // SAFETY: indices validated in-bounds; the pointer is valid for 'a.
436        Ok(unsafe { &*base_ptr.offset(offset) })
437    }
438}
439
440// ---------------------------------------------------------------------------
441// ArrayViewMut methods — basic indexing
442// ---------------------------------------------------------------------------
443
444impl<'a, T: Element, D: Dimension> ArrayViewMut<'a, T, D> {
445    /// Slice the mutable view along a given axis.
446    pub fn slice_axis_mut(
447        &mut self,
448        axis: Axis,
449        spec: SliceSpec,
450    ) -> FerrayResult<ArrayViewMut<'_, T, IxDyn>> {
451        let ndim = self.ndim();
452        let ax = axis.index();
453        if ax >= ndim {
454            return Err(FerrayError::axis_out_of_bounds(ax, ndim));
455        }
456        spec.validate()?;
457
458        let nd_axis = ndarray::Axis(ax);
459        let nd_slice = spec.to_ndarray_slice();
460        let sliced = self.inner.slice_axis_mut(nd_axis, nd_slice);
461        let dyn_view = sliced.into_dyn();
462        Ok(ArrayViewMut::from_ndarray(dyn_view))
463    }
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::{Ix1, Ix2, Ix3};

    // -----------------------------------------------------------------------
    // Normalization
    // -----------------------------------------------------------------------

    #[test]
    fn normalize_positive_in_bounds() {
        assert_eq!(normalize_index(2, 5, 0).unwrap(), 2);
    }

    #[test]
    fn normalize_negative() {
        assert_eq!(normalize_index(-1, 5, 0).unwrap(), 4);
        assert_eq!(normalize_index(-5, 5, 0).unwrap(), 0);
    }

    #[test]
    fn normalize_out_of_bounds() {
        assert!(normalize_index(5, 5, 0).is_err());
        assert!(normalize_index(-6, 5, 0).is_err());
    }

    // -----------------------------------------------------------------------
    // index_axis
    // -----------------------------------------------------------------------

    #[test]
    fn index_axis_row() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let row = arr.index_axis(Axis(0), 1).unwrap();
        assert_eq!(row.shape(), &[4]);
        let data: Vec<i32> = row.iter().copied().collect();
        assert_eq!(data, vec![4, 5, 6, 7]);
    }

    #[test]
    fn index_axis_column() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let col = arr.index_axis(Axis(1), 2).unwrap();
        assert_eq!(col.shape(), &[3]);
        let data: Vec<i32> = col.iter().copied().collect();
        assert_eq!(data, vec![2, 6, 10]);
    }

    #[test]
    fn index_axis_negative() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let row = arr.index_axis(Axis(0), -1).unwrap();
        let data: Vec<i32> = row.iter().copied().collect();
        assert_eq!(data, vec![8, 9, 10, 11]);
    }

    #[test]
    fn index_axis_out_of_bounds() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        assert!(arr.index_axis(Axis(0), 3).is_err());
        assert!(arr.index_axis(Axis(2), 0).is_err());
    }

    #[test]
    fn index_axis_is_zero_copy() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        let row = arr.index_axis(Axis(0), 0).unwrap();
        assert_eq!(row.as_ptr(), arr.as_ptr());
    }

    // -----------------------------------------------------------------------
    // slice_axis
    // -----------------------------------------------------------------------

    #[test]
    fn slice_axis_basic() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let sliced = arr.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        assert_eq!(sliced.shape(), &[3]);
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![20, 30, 40]);
    }

    #[test]
    fn slice_axis_negative_bounds() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        // `-3:-1` resolves to `2:4`.
        let sliced = arr.slice_axis(Axis(0), SliceSpec::new(-3, -1)).unwrap();
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![30, 40]);
    }

    #[test]
    fn slice_axis_step() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([6]), vec![0, 1, 2, 3, 4, 5]).unwrap();
        let sliced = arr
            .slice_axis(Axis(0), SliceSpec::with_step(0, 6, 2))
            .unwrap();
        assert_eq!(sliced.shape(), &[3]);
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![0, 2, 4]);
    }

    #[test]
    fn slice_axis_negative_step() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 2, 3, 4]).unwrap();
        // Reverse the entire array: the range [0, len) is traversed backwards.
        // ndarray interprets Slice::new(start, end, step) so that start/end
        // define a forward range and a negative step reverses traversal in it.
        let spec = SliceSpec {
            start: None,
            stop: None,
            step: Some(-1),
        };
        let sliced = arr.slice_axis(Axis(0), spec).unwrap();
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![4, 3, 2, 1, 0]);
    }

    #[test]
    fn slice_axis_negative_step_partial() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![0, 1, 2, 3, 4]).unwrap();
        // Range [1, 4) traversed backwards with step -1: [3, 2, 1]
        let sliced = arr
            .slice_axis(Axis(0), SliceSpec::with_step(1, 4, -1))
            .unwrap();
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![3, 2, 1]);
    }

    #[test]
    fn slice_axis_full() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let sliced = arr.slice_axis(Axis(0), SliceSpec::full()).unwrap();
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![1, 2, 3]);
    }

    #[test]
    fn slice_axis_2d_rows() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([4, 3]), (0..12).collect()).unwrap();
        let sliced = arr.slice_axis(Axis(0), SliceSpec::new(1, 3)).unwrap();
        assert_eq!(sliced.shape(), &[2, 3]);
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn slice_axis_is_zero_copy() {
        let arr =
            Array::<f64, Ix1>::from_vec(Ix1::new([5]), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let sliced = arr.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        // Pointer identity proves the view aliases the source; a value check
        // alone would also pass for a copy.
        let expected_ptr = unsafe { arr.as_ptr().add(1) };
        assert_eq!(sliced.as_ptr(), expected_ptr);
        unsafe {
            assert_eq!(*sliced.as_ptr(), 2.0);
        }
    }

    #[test]
    fn slice_axis_zero_step_error() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert!(
            arr.slice_axis(Axis(0), SliceSpec::with_step(0, 3, 0))
                .is_err()
        );
    }

    #[test]
    fn slice_axis_axis_out_of_bounds() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert!(arr.slice_axis(Axis(1), SliceSpec::full()).is_err());
    }

    // -----------------------------------------------------------------------
    // slice_multi
    // -----------------------------------------------------------------------

    #[test]
    fn slice_multi_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([4, 5]), (0..20).collect()).unwrap();
        let sliced = arr
            .slice_multi(&[SliceSpec::new(1, 3), SliceSpec::new(0, 4)])
            .unwrap();
        assert_eq!(sliced.shape(), &[2, 4]);
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![5, 6, 7, 8, 10, 11, 12, 13]);
    }

    #[test]
    fn slice_multi_wrong_count() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), (0..6).collect()).unwrap();
        assert!(arr.slice_multi(&[SliceSpec::full()]).is_err());
    }

    #[test]
    fn slice_multi_zero_step_error() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), (0..6).collect()).unwrap();
        assert!(
            arr.slice_multi(&[SliceSpec::full(), SliceSpec::with_step(0, 3, 0)])
                .is_err()
        );
    }

    // -----------------------------------------------------------------------
    // insert_axis / remove_axis
    // -----------------------------------------------------------------------

    #[test]
    fn insert_axis_at_front() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let expanded = arr.insert_axis(Axis(0)).unwrap();
        assert_eq!(expanded.shape(), &[1, 3]);
    }

    #[test]
    fn insert_axis_at_end() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        let expanded = arr.insert_axis(Axis(1)).unwrap();
        assert_eq!(expanded.shape(), &[3, 1]);
    }

    #[test]
    fn insert_axis_out_of_bounds() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(arr.insert_axis(Axis(3)).is_err());
    }

    #[test]
    fn remove_axis_single() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([1, 3]), vec![1.0, 2.0, 3.0]).unwrap();
        let squeezed = arr.remove_axis(Axis(0)).unwrap();
        assert_eq!(squeezed.shape(), &[3]);
        let data: Vec<f64> = squeezed.iter().copied().collect();
        assert_eq!(data, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn remove_axis_not_one() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        assert!(arr.remove_axis(Axis(0)).is_err());
    }

    #[test]
    fn remove_axis_out_of_bounds() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(arr.remove_axis(Axis(1)).is_err());
    }

    // -----------------------------------------------------------------------
    // flat_index
    // -----------------------------------------------------------------------

    #[test]
    fn flat_index_positive() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        assert_eq!(*arr.flat_index(0).unwrap(), 1);
        assert_eq!(*arr.flat_index(3).unwrap(), 4);
        assert_eq!(*arr.flat_index(5).unwrap(), 6);
    }

    #[test]
    fn flat_index_negative() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        assert_eq!(*arr.flat_index(-1).unwrap(), 50);
        assert_eq!(*arr.flat_index(-5).unwrap(), 10);
    }

    #[test]
    fn flat_index_out_of_bounds() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        assert!(arr.flat_index(3).is_err());
        assert!(arr.flat_index(-4).is_err());
    }

    #[test]
    fn flat_index_3d() {
        let arr = Array::<i32, Ix3>::from_vec(Ix3::new([2, 3, 4]), (0..24).collect()).unwrap();
        for i in 0..24 {
            assert_eq!(*arr.flat_index(i).unwrap(), i as i32);
        }
    }

    // -----------------------------------------------------------------------
    // get / get_mut
    // -----------------------------------------------------------------------

    #[test]
    fn get_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        assert_eq!(*arr.get(&[0, 0]).unwrap(), 0);
        assert_eq!(*arr.get(&[1, 2]).unwrap(), 6);
        assert_eq!(*arr.get(&[2, 3]).unwrap(), 11);
    }

    #[test]
    fn get_negative_indices() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        assert_eq!(*arr.get(&[-1, -1]).unwrap(), 11);
        assert_eq!(*arr.get(&[-3, 0]).unwrap(), 0);
    }

    #[test]
    fn get_wrong_ndim() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), (0..6).collect()).unwrap();
        assert!(arr.get(&[0]).is_err());
        assert!(arr.get(&[0, 0, 0]).is_err());
    }

    #[test]
    fn get_mut_modify() {
        let mut arr =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        *arr.get_mut(&[1, 2]).unwrap() = 99;
        assert_eq!(*arr.get(&[1, 2]).unwrap(), 99);
    }

    #[test]
    fn get_mut_negative_and_out_of_bounds() {
        let mut arr =
            Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        *arr.get_mut(&[-2, -3]).unwrap() = 42;
        assert_eq!(*arr.get(&[0, 0]).unwrap(), 42);
        assert!(arr.get_mut(&[2, 0]).is_err());
        assert!(arr.get_mut(&[0]).is_err());
    }

    // -----------------------------------------------------------------------
    // ArrayView basic indexing
    // -----------------------------------------------------------------------

    #[test]
    fn view_index_axis() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([3, 4]), (0..12).collect()).unwrap();
        let v = arr.view();
        let row = v.index_axis(Axis(0), 1).unwrap();
        let data: Vec<i32> = row.iter().copied().collect();
        assert_eq!(data, vec![4, 5, 6, 7]);
    }

    #[test]
    fn view_slice_axis() {
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![10, 20, 30, 40, 50]).unwrap();
        let v = arr.view();
        let sliced = v.slice_axis(Axis(0), SliceSpec::new(1, 4)).unwrap();
        let data: Vec<i32> = sliced.iter().copied().collect();
        assert_eq!(data, vec![20, 30, 40]);
    }

    #[test]
    fn view_insert_remove_axis() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let v = arr.view();
        let expanded = v.insert_axis(Axis(0)).unwrap();
        assert_eq!(expanded.shape(), &[1, 4]);
        let squeezed = expanded.remove_axis(Axis(0)).unwrap();
        assert_eq!(squeezed.shape(), &[4]);
    }

    #[test]
    fn view_get() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        let v = arr.view();
        assert_eq!(*v.get(&[1, 2]).unwrap(), 6);
        assert_eq!(*v.get(&[-1, 0]).unwrap(), 4);
        assert!(v.get(&[2, 0]).is_err());
        assert!(v.get(&[0]).is_err());
    }

    // -----------------------------------------------------------------------
    // ArrayViewMut slice
    // -----------------------------------------------------------------------

    #[test]
    fn view_mut_slice_axis() {
        let mut arr = Array::<i32, Ix1>::from_vec(Ix1::new([5]), vec![1, 2, 3, 4, 5]).unwrap();
        {
            let mut vm = arr.view_mut();
            let mut sliced = vm.slice_axis_mut(Axis(0), SliceSpec::new(1, 3)).unwrap();
            // A unit-step slice of a contiguous 1-D array is contiguous; fail
            // loudly instead of silently skipping the writes.
            let s = sliced
                .as_slice_mut()
                .expect("unit-step slice of a contiguous array is contiguous");
            s[0] = 20;
            s[1] = 30;
        }
        assert_eq!(arr.as_slice().unwrap(), &[1, 20, 30, 4, 5]);
    }

    // -----------------------------------------------------------------------
    // 3D indexing
    // -----------------------------------------------------------------------

    #[test]
    fn index_axis_3d() {
        let arr = Array::<i32, Ix3>::from_vec(Ix3::new([2, 3, 4]), (0..24).collect()).unwrap();
        let plane = arr.index_axis(Axis(0), 1).unwrap();
        assert_eq!(plane.shape(), &[3, 4]);
        assert_eq!(*plane.get(&[0, 0]).unwrap(), 12);
    }
}