1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
use crate::{
    shapes::{Shape, Unit},
    tensor::{Cpu, Tensor},
};
use std::sync::Arc;

/// Maps a multi-dimensional `index` to a flat offset by summing, per axis,
/// the product of the stride and the index component.
///
/// # Panics
///
/// Panics with an "Index out of bounds" message when any component of
/// `index` is greater than or equal to the matching entry of
/// `shape.concrete()`.
pub(crate) fn index_to_i<S: Shape>(shape: &S, strides: &S::Concrete, index: S::Concrete) -> usize {
    let dims = shape.concrete();

    // Validate every axis before computing anything; stops at the first offender.
    let out_of_bounds = index
        .into_iter()
        .enumerate()
        .any(|(axis, pos)| pos >= dims[axis]);
    if out_of_bounds {
        panic!("Index out of bounds: index={index:?} shape={shape:?}");
    }

    // Accumulate stride * position across all axes.
    let mut offset: usize = 0;
    for (stride, pos) in strides.into_iter().zip(index) {
        offset += stride * pos;
    }
    offset
}

/// Read-only element access for CPU tensors using a concrete
/// multi-dimensional index.
impl<S: Shape, E: Unit, T> std::ops::Index<S::Concrete> for Tensor<S, E, Cpu, T> {
    type Output = E;

    /// Returns a reference to the element at `index`.
    ///
    /// The flat position is computed by `index_to_i` from this tensor's
    /// shape and strides; that helper panics when `index` is out of bounds.
    #[inline(always)]
    fn index(&self, index: S::Concrete) -> &Self::Output {
        &self.data[index_to_i(&self.shape, &self.strides, index)]
    }
}

/// Mutable element access for CPU tensors using a concrete
/// multi-dimensional index.
impl<S: Shape, E: Unit, T> std::ops::IndexMut<S::Concrete> for Tensor<S, E, Cpu, T> {
    /// Returns a mutable reference to the element at `index`.
    ///
    /// The flat position is resolved first via `index_to_i` (which panics on
    /// an out-of-bounds index), so a bad index fails before the storage is
    /// touched. `Arc::make_mut` then provides unique access to the data,
    /// cloning it beforehand if the `Arc` is currently shared.
    #[inline(always)]
    fn index_mut(&mut self, index: S::Concrete) -> &mut Self::Output {
        let offset = index_to_i(&self.shape, &self.strides, index);
        &mut Arc::make_mut(&mut self.data)[offset]
    }
}