1use std::ops::{Index, IndexMut};
7
8use crate::dtype::Element;
9
10use super::Array;
11
/// Maps a multi-dimensional index to its linear offset in row-major (C) order.
///
/// The offset is accumulated Horner-style from the first axis to the last, so
/// each axis contributes `indices[k] * (product of the extents after axis k)`.
///
/// # Panics
/// Panics with "index out of bounds: ..." if any `indices[k] >= shape[k]`.
/// In debug builds, also panics if `indices` and `shape` differ in length.
#[inline]
fn flat_offset(indices: &[usize], shape: &[usize]) -> usize {
    debug_assert_eq!(indices.len(), shape.len());
    indices
        .iter()
        .zip(shape)
        .enumerate()
        .fold(0, |acc, (i, (&idx, &dim))| {
            assert!(
                idx < dim,
                "index out of bounds: axis {i} index {idx} >= dimension {dim}"
            );
            acc * dim + idx
        })
}
29
30macro_rules! impl_index {
31 ($ix:ident, $n:expr) => {
32 impl<T: Element> Index<[usize; $n]> for Array<T, crate::dimension::$ix> {
33 type Output = T;
34
35 #[inline]
36 fn index(&self, idx: [usize; $n]) -> &T {
37 let offset = flat_offset(&idx, self.shape());
38 if let Some(slice) = self.as_slice() {
42 &slice[offset]
43 } else {
44 let strides = self.strides();
46 let mut raw_offset: isize = 0;
47 for (&i, &s) in idx.iter().zip(strides.iter()) {
48 raw_offset += i as isize * s;
49 }
50 unsafe { &*self.as_ptr().offset(raw_offset) }
51 }
52 }
53 }
54
55 impl<T: Element> IndexMut<[usize; $n]> for Array<T, crate::dimension::$ix> {
56 #[inline]
57 fn index_mut(&mut self, idx: [usize; $n]) -> &mut T {
58 let strides = self.strides().to_vec();
60 let shape = self.shape().to_vec();
61 let _ = flat_offset(&idx, &shape); let mut raw_offset: isize = 0;
63 for (&i, &s) in idx.iter().zip(strides.iter()) {
64 raw_offset += i as isize * s;
65 }
66 unsafe { &mut *self.as_mut_ptr().offset(raw_offset) }
68 }
69 }
70 };
71}
72
73impl_index!(Ix1, 1);
74impl_index!(Ix2, 2);
75impl_index!(Ix3, 3);
76impl_index!(Ix4, 4);
77impl_index!(Ix5, 5);
78impl_index!(Ix6, 6);
79
80impl<T: Element> Index<&[usize]> for Array<T, crate::dimension::IxDyn> {
82 type Output = T;
83
84 #[inline]
85 fn index(&self, idx: &[usize]) -> &T {
86 assert_eq!(
87 idx.len(),
88 self.ndim(),
89 "index dimension mismatch: got {} indices for {}D array",
90 idx.len(),
91 self.ndim()
92 );
93 let offset = flat_offset(idx, self.shape());
94 if let Some(slice) = self.as_slice() {
95 &slice[offset]
96 } else {
97 let strides = self.strides();
98 let mut raw_offset: isize = 0;
99 for (&i, &s) in idx.iter().zip(strides.iter()) {
100 raw_offset += i as isize * s;
101 }
102 unsafe { &*self.as_ptr().offset(raw_offset) }
103 }
104 }
105}
106
107impl<T: Element> IndexMut<&[usize]> for Array<T, crate::dimension::IxDyn> {
108 #[inline]
109 fn index_mut(&mut self, idx: &[usize]) -> &mut T {
110 assert_eq!(
111 idx.len(),
112 self.ndim(),
113 "index dimension mismatch: got {} indices for {}D array",
114 idx.len(),
115 self.ndim()
116 );
117 let strides = self.strides().to_vec();
118 let shape = self.shape().to_vec();
119 let _ = flat_offset(idx, &shape); let mut raw_offset: isize = 0;
121 for (&i, &s) in idx.iter().zip(strides.iter()) {
122 raw_offset += i as isize * s;
123 }
124 unsafe { &mut *self.as_mut_ptr().offset(raw_offset) }
126 }
127}
128
#[cfg(test)]
mod tests {
    use crate::dimension::{Ix1, Ix2, Ix3, Ix4, IxDyn};

    use super::*;

    #[test]
    fn index_1d() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([4]), vec![10.0, 20.0, 30.0, 40.0]).unwrap();
        assert_eq!(arr[[0]], 10.0);
        assert_eq!(arr[[3]], 40.0);
    }

    #[test]
    fn index_2d() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![1, 2, 3, 4, 5, 6]).unwrap();
        assert_eq!(arr[[0, 0]], 1);
        assert_eq!(arr[[0, 2]], 3);
        assert_eq!(arr[[1, 0]], 4);
        assert_eq!(arr[[1, 2]], 6);
    }

    #[test]
    fn index_3d() {
        let arr = Array::<f32, Ix3>::from_vec(
            Ix3::new([2, 2, 2]),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        )
        .unwrap();
        assert_eq!(arr[[0, 0, 0]], 1.0);
        assert_eq!(arr[[0, 0, 1]], 2.0);
        assert_eq!(arr[[1, 1, 1]], 8.0);
    }

    #[test]
    fn index_4d() {
        let data: Vec<i32> = (0..16).collect();
        let arr = Array::<i32, Ix4>::from_vec(Ix4::new([2, 2, 2, 2]), data).unwrap();
        assert_eq!(arr[[0, 0, 0, 0]], 0);
        assert_eq!(arr[[0, 0, 0, 1]], 1);
        assert_eq!(arr[[1, 1, 1, 1]], 15);
    }

    #[test]
    fn index_mut_2d() {
        let mut arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 3]), vec![0; 6]).unwrap();
        arr[[0, 1]] = 42;
        arr[[1, 2]] = 99;
        assert_eq!(arr[[0, 1]], 42);
        assert_eq!(arr[[1, 2]], 99);
        assert_eq!(arr[[0, 0]], 0);
    }

    // The write path is stride-based while reads may take the slice fast path;
    // check that a write lands on exactly one element and both paths agree.
    #[test]
    fn index_mut_3d_writes_single_element() {
        let mut arr = Array::<i32, Ix3>::from_vec(Ix3::new([2, 2, 2]), vec![0; 8]).unwrap();
        arr[[1, 0, 1]] = 7;
        for i in 0..2 {
            for j in 0..2 {
                for k in 0..2 {
                    let expected = if [i, j, k] == [1, 0, 1] { 7 } else { 0 };
                    assert_eq!(arr[[i, j, k]], expected, "at [{i}, {j}, {k}]");
                }
            }
        }
    }

    #[test]
    fn index_dyn() {
        let arr =
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
                .unwrap();
        assert_eq!(arr[&[0, 0][..]], 1.0);
        assert_eq!(arr[&[1, 2][..]], 6.0);
    }

    #[test]
    fn index_mut_dyn() {
        let mut arr = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[3]), vec![0, 0, 0]).unwrap();
        arr[&[1][..]] = 77;
        assert_eq!(arr[&[1][..]], 77);
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn index_out_of_bounds() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
        let _ = arr[[2, 0]];
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn index_mut_out_of_bounds() {
        let mut arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![0.0; 6]).unwrap();
        arr[[0, 3]] = 1.0;
    }

    #[test]
    #[should_panic(expected = "index dimension mismatch")]
    fn index_dyn_wrong_ndim() {
        let arr = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
        let _ = arr[&[0][..]];
    }

    #[test]
    #[should_panic(expected = "index out of bounds")]
    fn index_dyn_out_of_bounds() {
        let arr = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
        let _value: f64 = arr[&[1, 3][..]];
    }

    #[test]
    #[should_panic(expected = "index dimension mismatch")]
    fn index_mut_dyn_wrong_ndim() {
        let mut arr = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2, 3]), vec![0.0; 6]).unwrap();
        arr[&[0][..]] = 1.0;
    }
}