1use crate::Tensor;
2use std::ops::Index;
3
4#[inline]
5fn flat_index<T, const N: usize>(tensor: &Tensor<T>, index: &[usize; N]) -> usize {
6 debug_assert!((0..N).all(|i| index[i] < tensor.shape().dim_at(i)));
7
8 let strides = tensor.strides().as_slice();
9 let mut flat = 0;
10
11 for i in 0..N {
13 flat += index[i] * strides[i];
14 }
15
16 flat
17}
18
19impl<T> Index<usize> for Tensor<T> {
20 type Output = T;
21
22 fn index(&self, index: usize) -> &Self::Output {
23 &self.data[index]
24 }
25}
26
27impl<T> Index<(usize, usize)> for Tensor<T> {
28 type Output = T;
29
30 fn index(&self, index: (usize, usize)) -> &Self::Output {
31 &self.data[index.0 * self.strides().stride_at(0) + index.1 * self.strides().stride_at(1)]
32 }
33}
34
35impl<T> Index<(usize, usize, usize)> for Tensor<T> {
36 type Output = T;
37
38 fn index(&self, index: (usize, usize, usize)) -> &Self::Output {
39 let flat_index = index.0 * self.strides().stride_at(0)
40 + index.1 * self.strides().stride_at(1)
41 + index.2 * self.strides().stride_at(2);
42 &self.data[flat_index]
43 }
44}
45
46impl<T> Index<(usize, usize, usize, usize)> for Tensor<T> {
47 type Output = T;
48
49 fn index(&self, index: (usize, usize, usize, usize)) -> &Self::Output {
50 let flat_index = index.0 * self.strides().stride_at(0)
51 + index.1 * self.strides().stride_at(1)
52 + index.2 * self.strides().stride_at(2)
53 + index.3 * self.strides().stride_at(3);
54 &self.data[flat_index]
55 }
56}
57
58impl<T, const N: usize> Index<[usize; N]> for Tensor<T> {
59 type Output = T;
60
61 fn index(&self, index: [usize; N]) -> &Self::Output {
62 let flat_index = flat_index(self, &index);
63 &self.data[flat_index]
64 }
65}
66
#[cfg(test)]
mod tests {
    use super::*;

    /// `flat_index` must match the hand-computed offset, and indexing with the
    /// same array must reach an element of the tensor.
    #[test]
    fn test_flat_index() {
        let tensor = Tensor::new(vec![0; 60], (3, 4, 5));
        let index = [2, 1, 3];
        let flat_idx = flat_index(&tensor, &index);

        // Offset for strides (20, 5, 1): 2 * 20 + 1 * 5 + 3 * 1.
        let expected_idx = 2 * 20 + 5 + 3;
        assert_eq!(flat_idx, expected_idx);

        let value = &tensor[index];
        assert_eq!(value, &0);
    }

    /// A one-element array index addresses a rank-1 tensor.
    #[test]
    fn test_index_1d() {
        let t = Tensor::new(vec![5, 6, 7, 8], 4);
        assert_eq!(t[[0]], 5);
        assert_eq!(t[[3]], 8);
    }

    /// A bare `usize` index reads the backing data directly.
    #[test]
    fn test_index_flat_usize() {
        let t = Tensor::new(vec![5, 6, 7, 8], 4);
        assert_eq!(t[0usize], 5);
        assert_eq!(t[3usize], 8);
    }

    /// Array indexing of a (2, 3) tensor filled with 0..6 yields the
    /// row-major position as the value.
    #[test]
    fn test_index_2d_matches_row_major() {
        let t = Tensor::new((0..6).collect::<Vec<i32>>(), (2, 3));
        assert_eq!(t[[0, 0]], 0);
        assert_eq!(t[[0, 2]], 2);
        assert_eq!(t[[1, 0]], 3);
        assert_eq!(t[[1, 2]], 5);
    }

    /// Array indexing of a (2, 3, 4) tensor follows strides (12, 4, 1).
    #[test]
    fn test_index_3d_matches_strides() {
        let t = Tensor::new((0..24).collect::<Vec<i32>>(), (2, 3, 4));
        assert_eq!(t.strides().as_slice(), &[12, 4, 1]);

        assert_eq!(t[[0, 0, 0]], 0);
        assert_eq!(t[[0, 0, 3]], 3);
        // 0 * 12 + 2 * 4 + 1
        assert_eq!(t[[0, 2, 1]], 9);
        assert_eq!(t[[1, 0, 0]], 12);
        // 1 * 12 + 2 * 4 + 3
        assert_eq!(t[[1, 2, 3]], 23);
    }

    /// Array indexing of a (2, 2, 2, 2) tensor follows strides (8, 4, 2, 1).
    #[test]
    fn test_index_4d() {
        let t = Tensor::new((0..16).collect::<Vec<i32>>(), (2, 2, 2, 2));
        assert_eq!(t.strides().as_slice(), &[8, 4, 2, 1]);

        assert_eq!(t[[0, 0, 0, 0]], 0);
        assert_eq!(t[[0, 0, 0, 1]], 1);
        assert_eq!(t[[0, 1, 0, 0]], 4);
        assert_eq!(t[[1, 0, 0, 0]], 8);
        assert_eq!(t[[1, 1, 1, 1]], 15);
    }

    /// Array indexing is not limited to the ranks covered by the tuple impls;
    /// a rank-7 tensor is addressed through the const-generic array impl.
    #[test]
    fn test_index_rank_gt_5() {
        let t = Tensor::new((0..4).collect::<Vec<i32>>(), (1, 1, 1, 1, 1, 1, 4));
        assert_eq!(t.strides().as_slice(), &[4, 4, 4, 4, 4, 4, 1]);

        assert_eq!(t[[0, 0, 0, 0, 0, 0, 0]], 0);
        assert_eq!(t[[0, 0, 0, 0, 0, 0, 3]], 3);
    }

    /// Indexing a rank-2 tensor with three coordinates is expected to panic.
    /// The test is only compiled when debug assertions are enabled.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic]
    fn test_index_rank_mismatch_panics_in_debug() {
        let t = Tensor::new(vec![0; 6], (2, 3));
        let _ = t[[0, 0, 0]];
    }

    /// Two-element tuple indexing agrees with the array form on a (2, 3) tensor.
    #[test]
    fn test_index_tuple_2d() {
        let t = Tensor::new((0..6).collect::<Vec<i32>>(), (2, 3));

        assert_eq!(t[(0, 0)], 0);
        assert_eq!(t[(0, 2)], 2);
        assert_eq!(t[(1, 0)], 3);
        assert_eq!(t[(1, 2)], 5);
    }

    /// Three-element tuple indexing on a (2, 3, 4) tensor.
    #[test]
    fn test_index_tuples() {
        let t = Tensor::new((0..(2 * 3 * 4)).collect::<Vec<i32>>(), (2, 3, 4));

        assert_eq!(t[(0, 0, 0)], 0);
        assert_eq!(t[(0, 1, 2)], 6);
        assert_eq!(t[(1, 0, 0)], 12);
        assert_eq!(t[(1, 2, 3)], 23);
    }

    /// Four-element tuple indexing on a (2, 2, 2, 2) tensor.
    #[test]
    fn test_index_tuple_4d() {
        let t = Tensor::new((0..16).collect::<Vec<i32>>(), (2, 2, 2, 2));

        assert_eq!(t[(0, 0, 0, 0)], 0);
        // 0 * 8 + 1 * 4 + 0 * 2 + 1
        assert_eq!(t[(0, 1, 0, 1)], 5);
        // 1 * 8 + 0 * 4 + 1 * 2 + 0
        assert_eq!(t[(1, 0, 1, 0)], 10);
        assert_eq!(t[(1, 1, 1, 1)], 15);
    }
}