Skip to main content

ferray_core/array/
index_impl.rs

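//! `Index` and `IndexMut` impls for `Array<T, IxN>` with `[usize; N]` indices.
//!
//! Enables `arr[[i, j]]` syntax for element access, matching ndarray's ergonomics.
//! Dynamic-rank arrays (`IxDyn`) are indexed with a `&[usize]` slice instead,
//! e.g. `arr[&[i, j][..]]`, and additionally panic if the slice length differs
//! from the array's rank.
//! Panics on out-of-bounds (standard `Index` trait semantics).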
5
6use std::ops::{Index, IndexMut};
7
8use crate::dtype::Element;
9
10use super::Array;
11
/// Returns the row-major (C-order) linear position of `indices` within `shape`.
///
/// This is the position the element would occupy in a C-contiguous buffer,
/// accumulated axis by axis as `((i0 * d1 + i1) * d2 + i2) * ...`.
/// An empty `indices`/`shape` pair yields `0`.
///
/// # Panics
/// Panics if any index is out of bounds for its axis; the message names the
/// first offending axis.
#[inline]
fn flat_offset(indices: &[usize], shape: &[usize]) -> usize {
    debug_assert_eq!(indices.len(), shape.len());
    indices
        .iter()
        .zip(shape)
        .enumerate()
        .fold(0, |acc, (i, (&idx, &dim))| {
            assert!(
                idx < dim,
                "index out of bounds: axis {i} index {idx} >= dimension {dim}"
            );
            acc * dim + idx
        })
}
29
/// Implements `Index<[usize; N]>` and `IndexMut<[usize; N]>` for the fixed-rank
/// dimension type `$ix` of rank `$n`.
///
/// Both impls share one addressing scheme. Every index is bounds-checked with
/// `flat_offset`, which panics on the first out-of-bounds axis. The indices are
/// then turned into an element offset from the array's base pointer via the
/// array's strides. Using the strided path for every layout, instead of
/// special-casing contiguous arrays, guarantees that `arr[idx]` and
/// `arr[idx] = v` address the same element.
macro_rules! impl_index {
    ($ix:ident, $n:expr) => {
        impl<T: Element> Index<[usize; $n]> for Array<T, crate::dimension::$ix> {
            type Output = T;

            #[inline]
            fn index(&self, idx: [usize; $n]) -> &T {
                // Panics with an axis-specific message if any index is out of bounds.
                let _ = flat_offset(&idx, self.shape());
                // Each index is < its axis length here, so the `as isize` cast
                // cannot wrap for any array that fits in memory.
                let offset: isize = idx
                    .iter()
                    .zip(self.strides().iter())
                    .map(|(&i, &s)| i as isize * s)
                    .sum();
                // SAFETY: every index was validated against its axis length above,
                // so the stride-weighted offset addresses an element of this array
                // (assuming `strides()` describes the buffer behind `as_ptr()`, the
                // same assumption `index_mut` relies on).
                unsafe { &*self.as_ptr().offset(offset) }
            }
        }

        impl<T: Element> IndexMut<[usize; $n]> for Array<T, crate::dimension::$ix> {
            #[inline]
            fn index_mut(&mut self, idx: [usize; $n]) -> &mut T {
                // The shared borrows taken by `shape()` and `strides()` end once the
                // plain `isize` offset is computed. Taking the mutable pointer
                // afterwards needs no copies of the shape or strides.
                let _ = flat_offset(&idx, self.shape()); // panics if out of bounds
                let offset: isize = idx
                    .iter()
                    .zip(self.strides().iter())
                    .map(|(&i, &s)| i as isize * s)
                    .sum();
                // SAFETY: all indices were validated in-bounds by `flat_offset`, so
                // the stride-weighted offset addresses an element of this array.
                unsafe { &mut *self.as_mut_ptr().offset(offset) }
            }
        }
    };
}
72
73impl_index!(Ix1, 1);
74impl_index!(Ix2, 2);
75impl_index!(Ix3, 3);
76impl_index!(Ix4, 4);
77impl_index!(Ix5, 5);
78impl_index!(Ix6, 6);
79
80// IxDyn uses &[usize] slice indexing instead of fixed-size arrays
81impl<T: Element> Index<&[usize]> for Array<T, crate::dimension::IxDyn> {
82    type Output = T;
83
84    #[inline]
85    fn index(&self, idx: &[usize]) -> &T {
86        assert_eq!(
87            idx.len(),
88            self.ndim(),
89            "index dimension mismatch: got {} indices for {}D array",
90            idx.len(),
91            self.ndim()
92        );
93        let offset = flat_offset(idx, self.shape());
94        if let Some(slice) = self.as_slice() {
95            &slice[offset]
96        } else {
97            let strides = self.strides();
98            let mut raw_offset: isize = 0;
99            for (&i, &s) in idx.iter().zip(strides.iter()) {
100                raw_offset += i as isize * s;
101            }
102            unsafe { &*self.as_ptr().offset(raw_offset) }
103        }
104    }
105}
106
107impl<T: Element> IndexMut<&[usize]> for Array<T, crate::dimension::IxDyn> {
108    #[inline]
109    fn index_mut(&mut self, idx: &[usize]) -> &mut T {
110        assert_eq!(
111            idx.len(),
112            self.ndim(),
113            "index dimension mismatch: got {} indices for {}D array",
114            idx.len(),
115            self.ndim()
116        );
117        let strides = self.strides().to_vec();
118        let shape = self.shape().to_vec();
119        let _ = flat_offset(idx, &shape); // panics if out of bounds
120        let mut raw_offset: isize = 0;
121        for (&i, &s) in idx.iter().zip(strides.iter()) {
122            raw_offset += i as isize * s;
123        }
124        // SAFETY: all indices validated in-bounds by flat_offset
125        unsafe { &mut *self.as_mut_ptr().offset(raw_offset) }
126    }
127}
128
#[cfg(test)]
mod tests {
    use crate::dimension::{Ix1, Ix2, Ix3, Ix4, IxDyn};

    use super::*;

    #[test]
    fn index_1d() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0])
            .unwrap();
        assert_eq!(arr[[0]], 10.0);
        assert_eq!(arr[[3]], 40.0);
    }

    #[test]
    fn index_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        assert_eq!(arr[[0, 0]], 1);
        assert_eq!(arr[[0, 2]], 3);
        assert_eq!(arr[[1, 0]], 4);
        assert_eq!(arr[[1, 2]], 6);
    }

    #[test]
    fn index_3d() {
        // 2x2x2 = 8 elements
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let arr = Array::<f32, Ix3>::from_vec(Ix3::new([2, 2, 2]), data).unwrap();
        assert_eq!(arr[[0, 0, 0]], 1.0);
        assert_eq!(arr[[0, 0, 1]], 2.0);
        assert_eq!(arr[[1, 1, 1]], 8.0);
    }

    #[test]
    fn index_3d_matches_row_major_layout() {
        // Non-cubic shape so that mixing up axes would be detected:
        // element (i, j, k) of a C-order 2x3x4 array sits at i*12 + j*4 + k.
        let data: Vec<i32> = (0..24).collect();
        let arr = Array::<i32, Ix3>::from_vec(Ix3::new([2, 3, 4]), data).unwrap();
        for i in 0..2 {
            for j in 0..3 {
                for k in 0..4 {
                    assert_eq!(arr[[i, j, k]], (i * 12 + j * 4 + k) as i32);
                }
            }
        }
    }

    #[test]
    fn index_4d() {
        // 2x2x2x2 = 16 elements
        let data: Vec<i32> = (0..16).collect();
        let arr = Array::<i32, Ix4>::from_vec(Ix4::new([2, 2, 2, 2]), data).unwrap();
        assert_eq!(arr[[0, 0, 0, 0]], 0);
        assert_eq!(arr[[0, 0, 0, 1]], 1);
        assert_eq!(arr[[1, 1, 1, 1]], 15);
    }

    #[test]
    fn index_mut_2d() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0; 6]).unwrap();
        arr[[0, 1]] = 42;
        arr[[1, 2]] = 99;
        assert_eq!(arr[[0, 1]], 42);
        assert_eq!(arr[[1, 2]], 99);
        assert_eq!(arr[[0, 0]], 0);
    }

    #[test]
    fn index_mut_writes_row_major_position() {
        let mut arr = Array::<i32, Ix3>::from_vec(Ix3::new([2, 3, 4]), vec![0; 24]).unwrap();
        arr[[0, 1, 2]] = 5; // flat position 0*12 + 1*4 + 2 = 6
        arr[[1, 2, 3]] = 7; // flat position 1*12 + 2*4 + 3 = 23 (last element)
        let flat = arr.as_slice().expect("freshly built array is contiguous");
        assert_eq!(flat[6], 5);
        assert_eq!(flat[23], 7);
        // Nothing else was touched.
        assert_eq!(flat.iter().filter(|&&x| x != 0).count(), 2);
    }

    #[test]
    fn index_dyn() {
        let arr =
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        assert_eq!(arr[&[0, 0][..]], 1.0);
        assert_eq!(arr[&[1, 2][..]], 6.0);
    }

    #[test]
    fn index_mut_dyn() {
        let mut arr = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[3]), vec![0, 0, 0]).unwrap();
        arr[&[1][..]] = 77;
        assert_eq!(arr[&[1][..]], 77);
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn index_out_of_bounds() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
        let _ = arr[[2, 0]]; // row 2 doesn't exist
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn index_mut_out_of_bounds() {
        let mut arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
        arr[[0, 3]] = 1.0; // column 3 doesn't exist
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn index_dyn_out_of_bounds() {
        let arr = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
        let _ = arr[&[1, 3][..]]; // column 3 doesn't exist
    }

    #[test]
    #[should_panic(expected = "index dimension mismatch")]
    fn index_dyn_wrong_ndim() {
        let arr = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
        let _ = arr[&[0][..]]; // 1 index for 2D array
    }

    #[test]
    #[should_panic(expected = "index dimension mismatch")]
    fn index_mut_dyn_wrong_ndim() {
        let mut arr = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0; 6]).unwrap();
        arr[&[0, 0, 0][..]] = 1; // 3 indices for 2D array
    }
}