Skip to main content

radiate_utils/array/
indices.rs

1use crate::Tensor;
2use std::ops::Index;
3
4#[inline]
5fn flat_index<T, const N: usize>(tensor: &Tensor<T>, index: &[usize; N]) -> usize {
6    debug_assert!((0..N).all(|i| index[i] < tensor.shape().dim_at(i)));
7
8    let strides = tensor.strides().as_slice();
9    let mut flat = 0;
10
11    // Const-generic N: LLVM typically unrolls for small N.
12    for i in 0..N {
13        flat += index[i] * strides[i];
14    }
15
16    flat
17}
18
19impl<T> Index<usize> for Tensor<T> {
20    type Output = T;
21
22    fn index(&self, index: usize) -> &Self::Output {
23        &self.data[index]
24    }
25}
26
27impl<T> Index<(usize, usize)> for Tensor<T> {
28    type Output = T;
29
30    fn index(&self, index: (usize, usize)) -> &Self::Output {
31        &self.data[index.0 * self.strides().stride_at(0) + index.1 * self.strides().stride_at(1)]
32    }
33}
34
35impl<T> Index<(usize, usize, usize)> for Tensor<T> {
36    type Output = T;
37
38    fn index(&self, index: (usize, usize, usize)) -> &Self::Output {
39        let flat_index = index.0 * self.strides().stride_at(0)
40            + index.1 * self.strides().stride_at(1)
41            + index.2 * self.strides().stride_at(2);
42        &self.data[flat_index]
43    }
44}
45
46impl<T> Index<(usize, usize, usize, usize)> for Tensor<T> {
47    type Output = T;
48
49    fn index(&self, index: (usize, usize, usize, usize)) -> &Self::Output {
50        let flat_index = index.0 * self.strides().stride_at(0)
51            + index.1 * self.strides().stride_at(1)
52            + index.2 * self.strides().stride_at(2)
53            + index.3 * self.strides().stride_at(3);
54        &self.data[flat_index]
55    }
56}
57
58impl<T, const N: usize> Index<[usize; N]> for Tensor<T> {
59    type Output = T;
60
61    fn index(&self, index: [usize; N]) -> &Self::Output {
62        let flat_index = flat_index(self, &index);
63        &self.data[flat_index]
64    }
65}
66
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_flat_index() {
        let tensor = Tensor::new(vec![0; 60], (3, 4, 5));
        let index = [2, 1, 3];
        let flat_idx = flat_index(&tensor, &index);

        // shape (3, 4, 5) => strides [20, 5, 1]; flat = 2*20 + 1*5 + 3
        let expected_idx = 48;
        assert_eq!(flat_idx, expected_idx);

        let value = &tensor[index];
        assert_eq!(value, &0);
    }

    #[test]
    fn test_index_1d() {
        let t = Tensor::new(vec![5, 6, 7, 8], 4);
        assert_eq!(t[[0]], 5);
        assert_eq!(t[[3]], 8);
    }

    #[test]
    fn test_index_2d_matches_row_major() {
        // shape (2, 3) row-major => strides [3, 1]
        // data layout:
        // [[0,1,2],
        //  [3,4,5]]
        let t = Tensor::new((0..6).collect::<Vec<i32>>(), (2, 3));
        assert_eq!(t[[0, 0]], 0);
        assert_eq!(t[[0, 2]], 2);
        assert_eq!(t[[1, 0]], 3);
        assert_eq!(t[[1, 2]], 5);
    }

    #[test]
    fn test_index_3d_matches_strides() {
        // shape (2, 3, 4) => strides [12, 4, 1]
        let t = Tensor::new((0..24).collect::<Vec<i32>>(), (2, 3, 4));
        assert_eq!(t.strides().as_slice(), &[12, 4, 1]);

        // flat = a*12 + b*4 + c
        assert_eq!(t[[0, 0, 0]], 0);
        assert_eq!(t[[0, 0, 3]], 3);
        assert_eq!(t[[0, 2, 1]], 9); // 0*12 + 2*4 + 1
        assert_eq!(t[[1, 0, 0]], 12);
        assert_eq!(t[[1, 2, 3]], 23); // 1*12 + 2*4 + 3
    }

    #[test]
    fn test_index_4d() {
        // shape (2, 2, 2, 2) => strides [8, 4, 2, 1]
        let t = Tensor::new((0..16).collect::<Vec<i32>>(), (2, 2, 2, 2));
        assert_eq!(t.strides().as_slice(), &[8, 4, 2, 1]);

        // pick a few spots
        assert_eq!(t[[0, 0, 0, 0]], 0);
        assert_eq!(t[[0, 0, 0, 1]], 1);
        assert_eq!(t[[0, 1, 0, 0]], 4);
        assert_eq!(t[[1, 0, 0, 0]], 8);
        assert_eq!(t[[1, 1, 1, 1]], 15);
    }

    #[test]
    fn test_index_rank_gt_5() {
        // shape (1,1,1,1,1,1,4) => rank 7
        // strides should be [4,4,4,4,4,4,1]
        let t = Tensor::new((0..4).collect::<Vec<i32>>(), (1, 1, 1, 1, 1, 1, 4));
        assert_eq!(t.strides().as_slice(), &[4, 4, 4, 4, 4, 4, 1]);

        assert_eq!(t[[0, 0, 0, 0, 0, 0, 0]], 0);
        assert_eq!(t[[0, 0, 0, 0, 0, 0, 3]], 3);
    }

    #[cfg(debug_assertions)]
    #[test]
    #[should_panic]
    fn test_index_rank_mismatch_panics_in_debug() {
        let t = Tensor::new(vec![0; 6], (2, 3)); // rank 2
        let _ = t[[0, 0, 0]]; // rank-3 index on a rank-2 tensor => must panic
    }

    #[test]
    fn test_index_tuples() {
        let t = Tensor::new((0..(2 * 3 * 4)).collect::<Vec<i32>>(), (2, 3, 4));

        assert_eq!(t[(0, 0, 0)], 0);
        assert_eq!(t[(0, 1, 2)], 6);
        assert_eq!(t[(1, 0, 0)], 12);
        assert_eq!(t[(1, 2, 3)], 23);
    }

    #[test]
    fn test_index_tuples_2d_and_4d() {
        // shape (2, 3) => strides [3, 1]
        let m = Tensor::new((0..6).collect::<Vec<i32>>(), (2, 3));
        assert_eq!(m[(0, 0)], 0);
        assert_eq!(m[(0, 2)], 2);
        assert_eq!(m[(1, 1)], 4);
        assert_eq!(m[(1, 2)], m[[1, 2]]);

        // shape (2, 2, 2, 2) => strides [8, 4, 2, 1]
        let t = Tensor::new((0..16).collect::<Vec<i32>>(), (2, 2, 2, 2));
        assert_eq!(t[(0, 0, 0, 1)], 1);
        assert_eq!(t[(1, 0, 1, 0)], 10); // 1*8 + 0*4 + 1*2 + 0
        assert_eq!(t[(1, 1, 1, 1)], 15);
        assert_eq!(t[(0, 1, 1, 0)], t[[0, 1, 1, 0]]);
    }
}