opensrdk_linear_algebra/tensor/sparse/operators/index.rs

1use std::ops::{Index, IndexMut};
2
3use crate::{sparse::SparseTensor, Number, TensorError};
4
5impl<T> Index<&[usize]> for SparseTensor<T>
6where
7    T: Number,
8{
9    type Output = T;
10
11    fn index(&self, index: &[usize]) -> &Self::Output {
12        if index.len() != self.sizes.len() {
13            panic!("{}", TensorError::RankMismatch);
14        }
15        for (rank, &d) in index.iter().enumerate() {
16            if self.sizes[rank] <= d {
17                panic!("{}", TensorError::OutOfRange);
18            }
19        }
20
21        self.elems.get(index).unwrap_or(&self.default)
22    }
23}
24
25impl<T> IndexMut<&[usize]> for SparseTensor<T>
26where
27    T: Number,
28{
29    fn index_mut(&mut self, index: &[usize]) -> &mut Self::Output {
30        if index.len() != self.sizes.len() {
31            panic!("{}", TensorError::RankMismatch);
32        }
33        for (rank, &d) in index.iter().enumerate() {
34            if self.sizes[rank] <= d {
35                panic!("{}", TensorError::OutOfRange);
36            }
37        }
38
39        self.elems.entry(index.to_vec()).or_default()
40    }
41}
42
#[cfg(test)]
mod tests {
    use crate::{sparse::SparseTensor, *};

    #[test]
    fn index() {
        let mut tensor = SparseTensor::new(vec![2, 3]);
        tensor[&[0, 0]] = 1.0;
        tensor[&[1, 1]] = 2.0;
        tensor[&[1, 2]] = 3.0;

        assert_eq!(tensor[&[0, 0]], 1.0);
        assert_eq!(tensor[&[0, 1]], 0.0);
        assert_eq!(tensor[&[0, 2]], 0.0);
        assert_eq!(tensor[&[1, 0]], 0.0);
        assert_eq!(tensor[&[1, 1]], 2.0);
        assert_eq!(tensor[&[1, 2]], 3.0);
    }

    #[test]
    fn index_mut() {
        let mut tensor = SparseTensor::new(vec![2, 3]);
        tensor[&[0, 0]] = 1.0;
        tensor[&[1, 1]] = 2.0;
        tensor[&[1, 2]] = 3.0;

        assert_eq!(tensor[&[0, 0]], 1.0);
        assert_eq!(tensor[&[0, 1]], 0.0);
        assert_eq!(tensor[&[0, 2]], 0.0);
        assert_eq!(tensor[&[1, 0]], 0.0);
        assert_eq!(tensor[&[1, 1]], 2.0);
        assert_eq!(tensor[&[1, 2]], 3.0);

        tensor[&[0, 0]] = 0.0;
        tensor[&[1, 1]] = 0.0;
        tensor[&[1, 2]] = 0.0;

        assert_eq!(tensor[&[0, 0]], 0.0);
        assert_eq!(tensor[&[0, 1]], 0.0);
        assert_eq!(tensor[&[0, 2]], 0.0);
        assert_eq!(tensor[&[1, 0]], 0.0);
        assert_eq!(tensor[&[1, 1]], 0.0);
        assert_eq!(tensor[&[1, 2]], 0.0);
    }

    /// Reading a tensor that was never written yields zero at every position.
    #[test]
    fn index_untouched_tensor_reads_zero() {
        let tensor: SparseTensor<f64> = SparseTensor::new(vec![2, 3]);
        for i in 0..2usize {
            for j in 0..3usize {
                assert_eq!(tensor[&[i, j]], 0.0);
            }
        }
    }

    /// Compound assignment through `IndexMut` accumulates from an unset entry.
    #[test]
    fn index_mut_compound_assignment() {
        let mut tensor: SparseTensor<f64> = SparseTensor::new(vec![2, 3]);
        tensor[&[1, 0]] += 2.5;
        tensor[&[1, 0]] += 2.5;
        assert_eq!(tensor[&[1, 0]], 5.0);
        assert_eq!(tensor[&[0, 0]], 0.0);
    }

    /// An index with fewer coordinates than the tensor's rank must be rejected.
    #[test]
    #[should_panic]
    fn index_rank_mismatch_panics() {
        let tensor: SparseTensor<f64> = SparseTensor::new(vec![2, 3]);
        let _value = tensor[&[0]];
    }

    /// A coordinate equal to its dimension's size is out of range.
    #[test]
    #[should_panic]
    fn index_out_of_range_panics() {
        let tensor: SparseTensor<f64> = SparseTensor::new(vec![2, 3]);
        let _value = tensor[&[0, 3]];
    }

    /// Writes with too many coordinates must be rejected.
    #[test]
    #[should_panic]
    fn index_mut_rank_mismatch_panics() {
        let mut tensor: SparseTensor<f64> = SparseTensor::new(vec![2, 3]);
        tensor[&[0, 0, 0]] = 1.0;
    }

    /// Writes past the end of the first axis must be rejected.
    #[test]
    #[should_panic]
    fn index_mut_out_of_range_panics() {
        let mut tensor: SparseTensor<f64> = SparseTensor::new(vec![2, 3]);
        tensor[&[2, 0]] = 1.0;
    }
}