use crate::Tensor;
use std::ops::Index;
/// Turns `N` per-axis coordinates into a single offset into the tensor's data.
///
/// The result is the sum of `index[axis] * strides[axis]` over `axis` in
/// `0..N`, where `strides` is `tensor.strides().as_slice()`. With `N == 0`
/// the result is `0`.
///
/// In debug builds every coordinate is first asserted to be smaller than
/// `tensor.shape().dim_at(axis)`; release builds skip that check.
#[inline]
fn flat_index<T, const N: usize>(tensor: &Tensor<T>, index: &[usize; N]) -> usize {
    // Debug-only guard: each coordinate must lie below its axis extent.
    debug_assert!((0..N).all(|i| index[i] < tensor.shape().dim_at(i)));

    let strides = tensor.strides().as_slice();

    // Accumulate axis by axis, in order, starting from offset zero.
    (0..N).fold(0, |offset, axis| offset + index[axis] * strides[axis])
}
/// Indexes straight into the tensor's `data` storage.
///
/// `index` is used as-is as a position in `data`; this impl does not consult
/// the tensor's shape or strides.
impl<T> Index<usize> for Tensor<T> {
    type Output = T;

    fn index(&self, index: usize) -> &Self::Output {
        let storage = &self.data;
        &storage[index]
    }
}
impl<T> Index<(usize, usize)> for Tensor<T> {
type Output = T;
fn index(&self, index: (usize, usize)) -> &Self::Output {
&self.data[index.0 * self.strides().stride_at(0) + index.1 * self.strides().stride_at(1)]
}
}
impl<T> Index<(usize, usize, usize)> for Tensor<T> {
type Output = T;
fn index(&self, index: (usize, usize, usize)) -> &Self::Output {
let flat_index = index.0 * self.strides().stride_at(0)
+ index.1 * self.strides().stride_at(1)
+ index.2 * self.strides().stride_at(2);
&self.data[flat_index]
}
}
impl<T> Index<(usize, usize, usize, usize)> for Tensor<T> {
type Output = T;
fn index(&self, index: (usize, usize, usize, usize)) -> &Self::Output {
let flat_index = index.0 * self.strides().stride_at(0)
+ index.1 * self.strides().stride_at(1)
+ index.2 * self.strides().stride_at(2)
+ index.3 * self.strides().stride_at(3);
&self.data[flat_index]
}
}
/// Indexes a tensor with an array of `N` per-axis coordinates.
///
/// The coordinates are converted to an offset by `flat_index` (including its
/// debug-build bounds assertion), and that offset is then looked up in
/// `self.data`.
impl<T, const N: usize> Index<[usize; N]> for Tensor<T> {
    type Output = T;

    fn index(&self, index: [usize; N]) -> &Self::Output {
        &self.data[flat_index(self, &index)]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `flat_index` combines each coordinate with its stride, and array
    /// indexing reads the element at that offset.
    #[test]
    fn test_flat_index() {
        let tensor = Tensor::new(vec![0; 60], (3, 4, 5));
        let index = [2, 1, 3];
        let flat_idx = flat_index(&tensor, &index);
        // Per-axis contributions: 2 * 20, 1 * 5 and 3 * 1.
        let expected_idx = 2 * 20 + 5 + 3;
        assert_eq!(flat_idx, expected_idx);
        let value = &tensor[index];
        assert_eq!(value, &0);
    }

    /// A bare `usize` indexes the underlying storage directly.
    #[test]
    fn test_index_flat_usize() {
        let t = Tensor::new(vec![5, 6, 7, 8], 4);
        assert_eq!(t[0usize], 5);
        assert_eq!(t[3usize], 8);
    }

    #[test]
    fn test_index_1d() {
        let t = Tensor::new(vec![5, 6, 7, 8], 4);
        assert_eq!(t[[0]], 5);
        assert_eq!(t[[3]], 8);
    }

    #[test]
    fn test_index_2d_matches_row_major() {
        let t = Tensor::new((0..6).collect::<Vec<i32>>(), (2, 3));
        assert_eq!(t[[0, 0]], 0);
        assert_eq!(t[[0, 2]], 2);
        assert_eq!(t[[1, 0]], 3);
        assert_eq!(t[[1, 2]], 5);
    }

    #[test]
    fn test_index_3d_matches_strides() {
        let t = Tensor::new((0..24).collect::<Vec<i32>>(), (2, 3, 4));
        assert_eq!(t.strides().as_slice(), &[12, 4, 1]);
        assert_eq!(t[[0, 0, 0]], 0);
        assert_eq!(t[[0, 0, 3]], 3);
        // Expected values are spelled out against the strides [12, 4, 1].
        assert_eq!(t[[0, 2, 1]], 2 * 4 + 1);
        assert_eq!(t[[1, 0, 0]], 12);
        assert_eq!(t[[1, 2, 3]], 12 + 2 * 4 + 3);
    }

    #[test]
    fn test_index_4d() {
        let t = Tensor::new((0..16).collect::<Vec<i32>>(), (2, 2, 2, 2));
        assert_eq!(t.strides().as_slice(), &[8, 4, 2, 1]);
        assert_eq!(t[[0, 0, 0, 0]], 0);
        assert_eq!(t[[0, 0, 0, 1]], 1);
        assert_eq!(t[[0, 1, 0, 0]], 4);
        assert_eq!(t[[1, 0, 0, 0]], 8);
        assert_eq!(t[[1, 1, 1, 1]], 15);
    }

    #[test]
    fn test_index_rank_gt_5() {
        let t = Tensor::new((0..4).collect::<Vec<i32>>(), (1, 1, 1, 1, 1, 1, 4));
        assert_eq!(t.strides().as_slice(), &[4, 4, 4, 4, 4, 4, 1]);
        assert_eq!(t[[0, 0, 0, 0, 0, 0, 0]], 0);
        assert_eq!(t[[0, 0, 0, 0, 0, 0, 3]], 3);
    }

    /// Supplying more coordinates than the tensor has axes must panic when
    /// debug assertions are enabled.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic]
    fn test_index_rank_mismatch_panics_in_debug() {
        let t = Tensor::new(vec![0; 6], (2, 3));
        let _ = t[[0, 0, 0]];
    }

    /// Pair indexing agrees with the array form checked in
    /// `test_index_2d_matches_row_major`.
    #[test]
    fn test_index_tuple_2d() {
        let t = Tensor::new((0..6).collect::<Vec<i32>>(), (2, 3));
        assert_eq!(t[(0, 0)], 0);
        assert_eq!(t[(0, 2)], 2);
        assert_eq!(t[(1, 0)], 3);
        assert_eq!(t[(1, 2)], 5);
    }

    #[test]
    fn test_index_tuples() {
        let t = Tensor::new((0..(2 * 3 * 4)).collect::<Vec<i32>>(), (2, 3, 4));
        assert_eq!(t[(0, 0, 0)], 0);
        assert_eq!(t[(0, 1, 2)], 6);
        assert_eq!(t[(1, 0, 0)], 12);
        assert_eq!(t[(1, 2, 3)], 23);
    }

    /// Four-element tuple indexing agrees with the array form checked in
    /// `test_index_4d`.
    #[test]
    fn test_index_tuple_4d() {
        let t = Tensor::new((0..16).collect::<Vec<i32>>(), (2, 2, 2, 2));
        assert_eq!(t[(0, 0, 0, 0)], 0);
        assert_eq!(t[(0, 0, 0, 1)], 1);
        assert_eq!(t[(0, 1, 0, 0)], 4);
        assert_eq!(t[(1, 0, 0, 0)], 8);
        assert_eq!(t[(1, 1, 1, 1)], 15);
    }
}