opensrdk_linear_algebra/tensor/sparse/
mod.rs1pub mod operations;
2pub mod operators;
3
4pub use operations::*;
5
6use crate::{Matrix, Number, RankIndex, Tensor, TensorError};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// A tensor of arbitrary rank that stores only explicitly set elements in a hash map.
///
/// `sizes[r]` is the extent of rank `r`, so `sizes.len()` is the tensor's rank. Each key of
/// `elems` is a full multi-index with one component per rank; dense views are produced by
/// `to_vec` (rank 1) and `to_mat` (rank 2).
///
/// NOTE(review): `elems` is a map whose keys are `Vec<usize>`; serde formats that only accept
/// string map keys (e.g. JSON) will likely reject it on serialization — confirm if needed.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SparseTensor<T = f64>
where
    T: Number,
{
    /// Extent of each rank; its length is the tensor's rank.
    sizes: Vec<usize>,
    /// Explicitly stored elements, keyed by their full multi-index.
    elems: HashMap<Vec<usize>, T>,
    /// Value the constructors in this file set to `T::default()`; presumably the value handed
    /// out for indices missing from `elems` (the `Index` impl lives elsewhere) — confirm.
    default: T,
}
19
20impl<T> SparseTensor<T>
21where
22 T: Number,
23{
24 pub fn new(sizes: Vec<usize>) -> Self {
25 Self {
26 sizes,
27 elems: HashMap::new(),
28 default: T::default(),
29 }
30 }
31
32 pub fn from(sizes: Vec<usize>, elems: HashMap<Vec<usize>, T>) -> Result<Self, TensorError> {
33 for (index, _) in elems.iter() {
34 if index.len() != sizes.len() {
35 return Err(TensorError::RankMismatch);
36 }
37 for (rank, &d) in index.iter().enumerate() {
38 if sizes[rank] <= d {
39 return Err(TensorError::OutOfRange);
40 }
41 }
42 }
43 Ok(Self {
44 sizes,
45 elems,
46 default: T::default(),
47 })
48 }
49
50 pub fn is_same_size(&self, other: &SparseTensor<T>) -> bool {
51 self.sizes == other.sizes
52 }
53
54 pub fn total_size(&self) -> usize {
55 self.sizes.iter().product()
56 }
57
58 pub fn not_1dimension_ranks(&self) -> usize {
59 self.sizes.iter().filter(|&d| *d != 1).count()
60 }
61
62 pub fn reduce_1dimension_rank(&self) -> Self {
63 let mut new_dims = vec![];
64 for d in self.sizes.iter() {
65 if *d != 1 {
66 new_dims.push(*d);
67 }
68 }
69
70 let mut new_elems = HashMap::new();
72 for (k, v) in self.elems.iter() {
73 let mut new_k = vec![];
74 for (i, d) in k.iter().enumerate() {
75 if self.sizes[i] != 1 {
76 new_k.push(*d);
77 }
78 }
79 new_elems.insert(new_k, *v);
80 }
81
82 Self {
83 sizes: new_dims,
84 elems: new_elems,
85 default: self.default,
86 }
87 }
88
89 pub fn to_vec(&self) -> Vec<T> {
90 if self.rank() != 1 {
91 panic!("SparseTensor::to_vec() is only available for rank 1 tensor.");
92 }
93
94 let mut vec = vec![T::default(); self.sizes[0]];
95 for (k, v) in self.elems.iter() {
96 vec[k[0]] = *v;
97 }
98
99 vec
100 }
101
102 pub fn to_mat(&self) -> Matrix<T> {
103 if self.rank() != 2 {
104 panic!("SparseTensor::to_mat() is only available for rank 2 tensor.");
105 }
106
107 let mut mat = Matrix::new(self.sizes[0], self.sizes[1]);
108 for (k, v) in self.elems.iter() {
109 mat[(k[0], k[1])] = *v;
110 }
111 mat
112 }
113
114 pub fn elems(&self) -> &HashMap<Vec<usize>, T> {
115 &self.elems
116 }
117
118 pub fn elems_mut(&mut self) -> &mut HashMap<Vec<usize>, T> {
119 &mut self.elems
120 }
121
122 pub fn eject(self) -> (Vec<usize>, HashMap<Vec<usize>, T>) {
123 (self.sizes, self.elems)
124 }
125}
126
127impl<T> Tensor<T> for SparseTensor<T>
128where
129 T: Number,
130{
131 fn rank(&self) -> usize {
132 self.sizes.len()
133 }
134
135 fn size(&self, rank: RankIndex) -> usize {
136 self.sizes[rank]
137 }
138
139 fn elem(&self, indices: &[usize]) -> T {
140 self[indices]
141 }
142
143 fn elem_mut(&mut self, indices: &[usize]) -> &mut T {
144 &mut self[indices]
145 }
146}
147
148impl<T> From<Vec<T>> for SparseTensor<T>
149where
150 T: Number,
151{
152 fn from(vec: Vec<T>) -> Self {
153 let sizes = vec![vec.len()];
154 let elems = vec
155 .into_iter()
156 .enumerate()
157 .map(|(i, v)| (vec![i], v))
158 .collect();
159
160 Self {
161 sizes,
162 elems,
163 default: T::default(),
164 }
165 }
166}