math/tensor/
indexable_tensor.rs1use crate::tensor::{has_tensor_shape_data::HasTensorShapeData, Unitless};
2
3pub trait IndexableTensor<Dtype>: HasTensorShapeData<Dtype>
4where
5 Dtype: Copy, {
6 fn at<T: AsRef<[Unitless]>>(&self, coord: T) -> Dtype {
7 self.data()[self.coord_to_index(coord.as_ref()) as usize]
8 }
9}
10
11impl<Dtype, T> IndexableTensor<Dtype> for T
12where
13 T: HasTensorShapeData<Dtype>,
14 Dtype: Copy,
15{
16}
17
#[cfg(test)]
mod tests {
    use crate::tensor::ephemeral_view::ToEphemeralView;
    use crate::tensor::indexable_tensor::IndexableTensor;
    use crate::tensor::tensor_storage::IntoTensorStorage;

    /// Reads all six cells of a `[2, 3]`-shaped view over `1..=6` and checks
    /// that they come back in row-major order, passing the coordinate in
    /// each supported form (array, `&array`, `Vec`, `&Vec`).
    #[test]
    fn test_indexing() {
        let backing = vec![1, 2, 3, 4, 5, 6].into_tensor_storage();
        let grid = backing.as_shape([2, 3]);

        // Row 0: coordinates given as an owned array and as a borrowed array.
        assert_eq!(grid.at([0, 0]), 1);
        assert_eq!(grid.at(&[0, 1]), 2);
        assert_eq!(grid.at([0, 2]), 3);

        // Row 1: coordinate given as a temporary `Vec`.
        assert_eq!(grid.at(vec![1, 0]), 4);

        // Row 1: coordinate given as a named `Vec`, moved into the call.
        let moved_coord = vec![1, 1];
        assert_eq!(grid.at(moved_coord), 5);

        // Row 1: coordinate given as a reference to a named `Vec`.
        let borrowed_coord = vec![1, 2];
        assert_eq!(grid.at(&borrowed_coord), 6);
    }
}