// opensrdk_linear_algebra/tensor/sparse/operators/index.rs
use std::ops::{Index, IndexMut};

use crate::{sparse::SparseTensor, Number, TensorError};
5impl<T> Index<&[usize]> for SparseTensor<T>
6where
7 T: Number,
8{
9 type Output = T;
10
11 fn index(&self, index: &[usize]) -> &Self::Output {
12 if index.len() != self.sizes.len() {
13 panic!("{}", TensorError::RankMismatch);
14 }
15 for (rank, &d) in index.iter().enumerate() {
16 if self.sizes[rank] <= d {
17 panic!("{}", TensorError::OutOfRange);
18 }
19 }
20
21 self.elems.get(index).unwrap_or(&self.default)
22 }
23}
24
25impl<T> IndexMut<&[usize]> for SparseTensor<T>
26where
27 T: Number,
28{
29 fn index_mut(&mut self, index: &[usize]) -> &mut Self::Output {
30 if index.len() != self.sizes.len() {
31 panic!("{}", TensorError::RankMismatch);
32 }
33 for (rank, &d) in index.iter().enumerate() {
34 if self.sizes[rank] <= d {
35 panic!("{}", TensorError::OutOfRange);
36 }
37 }
38
39 self.elems.entry(index.to_vec()).or_default()
40 }
41}
42
#[cfg(test)]
mod tests {
    use crate::{sparse::SparseTensor, *};

    /// Reading a freshly constructed tensor yields the default value at every
    /// valid position, and written values are read back through `Index`.
    #[test]
    fn index() {
        let fresh: SparseTensor<f64> = SparseTensor::new(vec![2, 3]);
        for i in 0..2 {
            for j in 0..3 {
                assert_eq!(fresh[&[i, j]], 0.0);
            }
        }

        let mut tensor = SparseTensor::new(vec![2, 3]);
        tensor[&[0, 0]] = 1.0;
        tensor[&[1, 1]] = 2.0;
        tensor[&[1, 2]] = 3.0;

        assert_eq!(tensor[&[0, 0]], 1.0);
        assert_eq!(tensor[&[0, 1]], 0.0);
        assert_eq!(tensor[&[0, 2]], 0.0);
        assert_eq!(tensor[&[1, 0]], 0.0);
        assert_eq!(tensor[&[1, 1]], 2.0);
        assert_eq!(tensor[&[1, 2]], 3.0);
    }

    /// Writes, overwrites and in-place updates through `IndexMut`, on a
    /// rank-3 tensor as well as a matrix.
    #[test]
    fn index_mut() {
        let mut tensor = SparseTensor::new(vec![2, 3]);
        tensor[&[0, 0]] = 1.0;
        tensor[&[1, 1]] = 2.0;
        tensor[&[1, 2]] = 3.0;

        assert_eq!(tensor[&[0, 0]], 1.0);
        assert_eq!(tensor[&[0, 1]], 0.0);
        assert_eq!(tensor[&[0, 2]], 0.0);
        assert_eq!(tensor[&[1, 0]], 0.0);
        assert_eq!(tensor[&[1, 1]], 2.0);
        assert_eq!(tensor[&[1, 2]], 3.0);

        tensor[&[0, 0]] = 0.0;
        tensor[&[1, 1]] = 0.0;
        tensor[&[1, 2]] = 0.0;

        assert_eq!(tensor[&[0, 0]], 0.0);
        assert_eq!(tensor[&[0, 1]], 0.0);
        assert_eq!(tensor[&[0, 2]], 0.0);
        assert_eq!(tensor[&[1, 0]], 0.0);
        assert_eq!(tensor[&[1, 1]], 0.0);
        assert_eq!(tensor[&[1, 2]], 0.0);

        // Compound assignment on an untouched position starts from the seed
        // value inserted by `IndexMut`.
        tensor[&[1, 0]] += 2.5;
        tensor[&[1, 0]] += 2.5;
        assert_eq!(tensor[&[1, 0]], 5.0);

        let mut cube = SparseTensor::new(vec![2, 2, 2]);
        cube[&[1, 0, 1]] = 7.0;
        assert_eq!(cube[&[1, 0, 1]], 7.0);
        assert_eq!(cube[&[0, 1, 1]], 0.0);
    }

    #[test]
    #[should_panic]
    fn index_with_too_many_coordinates_panics() {
        let tensor: SparseTensor<f64> = SparseTensor::new(vec![2, 3]);
        let _value = tensor[&[0, 0, 0]];
    }

    #[test]
    #[should_panic]
    fn index_with_too_few_coordinates_panics() {
        let tensor: SparseTensor<f64> = SparseTensor::new(vec![2, 3]);
        let _value = tensor[&[0]];
    }

    #[test]
    #[should_panic]
    fn index_out_of_range_panics() {
        let tensor: SparseTensor<f64> = SparseTensor::new(vec![2, 3]);
        // The first axis has size 2, so coordinate 2 is one past the end.
        let _value = tensor[&[2, 0]];
    }

    #[test]
    #[should_panic]
    fn index_mut_rank_mismatch_panics() {
        let mut tensor: SparseTensor<f64> = SparseTensor::new(vec![2, 3]);
        tensor[&[0, 0, 0]] = 1.0;
    }

    #[test]
    #[should_panic]
    fn index_mut_out_of_range_panics() {
        let mut tensor: SparseTensor<f64> = SparseTensor::new(vec![2, 3]);
        // The second axis has size 3, so coordinate 3 is one past the end.
        tensor[&[0, 3]] = 1.0;
    }
}