// math/tensor/indexable_tensor.rs

1use crate::tensor::{has_tensor_shape_data::HasTensorShapeData, Unitless};
2
3pub trait IndexableTensor<Dtype>: HasTensorShapeData<Dtype>
4where
5    Dtype: Copy, {
6    fn at<T: AsRef<[Unitless]>>(&self, coord: T) -> Dtype {
7        self.data()[self.coord_to_index(coord.as_ref()) as usize]
8    }
9}
10
11impl<Dtype, T> IndexableTensor<Dtype> for T
12where
13    T: HasTensorShapeData<Dtype>,
14    Dtype: Copy,
15{
16}
17
#[cfg(test)]
mod tests {
    use crate::tensor::{
        ephemeral_view::ToEphemeralView, indexable_tensor::IndexableTensor,
        tensor_storage::IntoTensorStorage,
    };

    /// `at` must return the right element for a 2x3 view and accept every
    /// common `AsRef<[Unitless]>` coordinate container: owned and borrowed
    /// arrays, and owned and borrowed `Vec`s.
    #[test]
    fn test_indexing() {
        let backing = vec![1, 2, 3, 4, 5, 6].into_tensor_storage();
        let grid = backing.as_shape([2, 3]);

        // First row: owned array, borrowed array, owned array.
        assert_eq!(grid.at([0, 0]), 1);
        assert_eq!(grid.at(&[0, 1]), 2);
        assert_eq!(grid.at([0, 2]), 3);

        // Second row: temporary Vec, named owned Vec, borrowed Vec.
        assert_eq!(grid.at(vec![1, 0]), 4);
        let owned_coord = vec![1, 1];
        assert_eq!(grid.at(owned_coord), 5);
        let borrowed_coord = vec![1, 2];
        assert_eq!(grid.at(&borrowed_coord), 6);
    }
}