opensrdk_linear_algebra/tensor/sparse/mod.rs

1pub mod operations;
2pub mod operators;
3
4pub use operations::*;
5
6use crate::{Matrix, Number, RankIndex, Tensor, TensorError};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Sparse tensor that stores only explicitly set entries, keyed by their full index.
///
/// NOTE(review): the `Vec<usize>` map keys cannot be serialized by formats that require
/// string map keys (e.g. `serde_json`) — confirm the intended serializer.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SparseTensor<T = f64>
where
    T: Number,
{
    /// Size of each rank; `sizes.len()` is the tensor's rank.
    sizes: Vec<usize>,
    /// Stored entries, keyed by an index holding one coordinate per rank.
    elems: HashMap<Vec<usize>, T>,
    /// Set to `T::default()` by the constructors in this file; presumably the value read for
    /// indices with no stored entry (the `Index` impl is not visible here — confirm).
    default: T,
}
19
20impl<T> SparseTensor<T>
21where
22    T: Number,
23{
24    pub fn new(sizes: Vec<usize>) -> Self {
25        Self {
26            sizes,
27            elems: HashMap::new(),
28            default: T::default(),
29        }
30    }
31
32    pub fn from(sizes: Vec<usize>, elems: HashMap<Vec<usize>, T>) -> Result<Self, TensorError> {
33        for (index, _) in elems.iter() {
34            if index.len() != sizes.len() {
35                return Err(TensorError::RankMismatch);
36            }
37            for (rank, &d) in index.iter().enumerate() {
38                if sizes[rank] <= d {
39                    return Err(TensorError::OutOfRange);
40                }
41            }
42        }
43        Ok(Self {
44            sizes,
45            elems,
46            default: T::default(),
47        })
48    }
49
50    pub fn is_same_size(&self, other: &SparseTensor<T>) -> bool {
51        self.sizes == other.sizes
52    }
53
54    pub fn total_size(&self) -> usize {
55        self.sizes.iter().product()
56    }
57
58    pub fn not_1dimension_ranks(&self) -> usize {
59        self.sizes.iter().filter(|&d| *d != 1).count()
60    }
61
62    pub fn reduce_1dimension_rank(&self) -> Self {
63        let mut new_dims = vec![];
64        for d in self.sizes.iter() {
65            if *d != 1 {
66                new_dims.push(*d);
67            }
68        }
69
70        // TODO: parallelize
71        let mut new_elems = HashMap::new();
72        for (k, v) in self.elems.iter() {
73            let mut new_k = vec![];
74            for (i, d) in k.iter().enumerate() {
75                if self.sizes[i] != 1 {
76                    new_k.push(*d);
77                }
78            }
79            new_elems.insert(new_k, *v);
80        }
81
82        Self {
83            sizes: new_dims,
84            elems: new_elems,
85            default: self.default,
86        }
87    }
88
89    pub fn to_vec(&self) -> Vec<T> {
90        if self.rank() != 1 {
91            panic!("SparseTensor::to_vec() is only available for rank 1 tensor.");
92        }
93
94        let mut vec = vec![T::default(); self.sizes[0]];
95        for (k, v) in self.elems.iter() {
96            vec[k[0]] = *v;
97        }
98
99        vec
100    }
101
102    pub fn to_mat(&self) -> Matrix<T> {
103        if self.rank() != 2 {
104            panic!("SparseTensor::to_mat() is only available for rank 2 tensor.");
105        }
106
107        let mut mat = Matrix::new(self.sizes[0], self.sizes[1]);
108        for (k, v) in self.elems.iter() {
109            mat[(k[0], k[1])] = *v;
110        }
111        mat
112    }
113
114    pub fn elems(&self) -> &HashMap<Vec<usize>, T> {
115        &self.elems
116    }
117
118    pub fn elems_mut(&mut self) -> &mut HashMap<Vec<usize>, T> {
119        &mut self.elems
120    }
121
122    pub fn eject(self) -> (Vec<usize>, HashMap<Vec<usize>, T>) {
123        (self.sizes, self.elems)
124    }
125}
126
127impl<T> Tensor<T> for SparseTensor<T>
128where
129    T: Number,
130{
131    fn rank(&self) -> usize {
132        self.sizes.len()
133    }
134
135    fn size(&self, rank: RankIndex) -> usize {
136        self.sizes[rank]
137    }
138
139    fn elem(&self, indices: &[usize]) -> T {
140        self[indices]
141    }
142
143    fn elem_mut(&mut self, indices: &[usize]) -> &mut T {
144        &mut self[indices]
145    }
146}
147
148impl<T> From<Vec<T>> for SparseTensor<T>
149where
150    T: Number,
151{
152    fn from(vec: Vec<T>) -> Self {
153        let sizes = vec![vec.len()];
154        let elems = vec
155            .into_iter()
156            .enumerate()
157            .map(|(i, v)| (vec![i], v))
158            .collect();
159
160        Self {
161            sizes,
162            elems,
163            default: T::default(),
164        }
165    }
166}