1use ff::PrimeField;
7use rayon::prelude::*;
8use serde::{Deserialize, Serialize};
9
10#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
13pub struct SparseMatrix<F: PrimeField> {
14 pub data: Vec<F>,
16 pub indices: Vec<usize>,
18 pub indptr: Vec<usize>,
20 pub cols: usize,
22}
23
24impl<F: PrimeField> SparseMatrix<F> {
25 pub fn empty() -> Self {
27 SparseMatrix {
28 data: vec![],
29 indices: vec![],
30 indptr: vec![0],
31 cols: 0,
32 }
33 }
34
35 pub fn new(matrix: &[(usize, usize, F)], rows: usize, cols: usize) -> Self {
38 let mut new_matrix = vec![vec![]; rows];
39 for (row, col, val) in matrix {
40 new_matrix[*row].push((*col, *val));
41 }
42
43 for row in new_matrix.iter() {
44 assert!(row.windows(2).all(|w| w[0].0 < w[1].0));
45 }
46
47 let mut indptr = vec![0; rows + 1];
48 for (i, col) in new_matrix.iter().enumerate() {
49 indptr[i + 1] = indptr[i] + col.len();
50 }
51
52 let mut indices = vec![];
53 let mut data = vec![];
54 for col in new_matrix {
55 let (idx, val): (Vec<_>, Vec<_>) = col.into_iter().unzip();
56 indices.extend(idx);
57 data.extend(val);
58 }
59
60 SparseMatrix {
61 data,
62 indices,
63 indptr,
64 cols,
65 }
66 }
67
68 pub fn get_row_unchecked(&self, ptrs: &[usize; 2]) -> impl Iterator<Item = (&F, &usize)> {
72 self.data[ptrs[0]..ptrs[1]]
73 .iter()
74 .zip(&self.indices[ptrs[0]..ptrs[1]])
75 }
76
77 pub fn multiply_vec(&self, vector: &[F]) -> Vec<F> {
79 assert_eq!(self.cols, vector.len(), "invalid shape");
80
81 self.multiply_vec_unchecked(vector)
82 }
83
84 pub fn multiply_vec_unchecked(&self, vector: &[F]) -> Vec<F> {
87 self
88 .indptr
89 .par_windows(2)
90 .map(|ptrs| {
91 self
92 .get_row_unchecked(ptrs.try_into().unwrap())
93 .map(|(val, col_idx)| *val * vector[*col_idx])
94 .sum()
95 })
96 .collect()
97 }
98
99 pub fn len(&self) -> usize {
101 *self.indptr.last().unwrap()
102 }
103
104 pub fn is_empty(&self) -> bool {
106 self.len() == 0
107 }
108
109 pub fn iter(&self) -> Iter<'_, F> {
111 let nnz = *self.indptr.last().unwrap();
112 if nnz == 0 {
113 return Iter {
114 matrix: self,
115 row: 0,
116 i: 0,
117 nnz,
118 };
119 }
120
121 let mut row = 0;
122 while row + 1 < self.indptr.len() && self.indptr[row + 1] == 0 {
123 row += 1;
124 }
125 Iter {
126 matrix: self,
127 row,
128 i: 0,
129 nnz,
130 }
131 }
132}
133
134pub struct Iter<'a, F: PrimeField> {
136 matrix: &'a SparseMatrix<F>,
137 row: usize,
138 i: usize,
139 nnz: usize,
140}
141
142impl<F: PrimeField> Iterator for Iter<'_, F> {
143 type Item = (usize, usize, F);
144
145 fn next(&mut self) -> Option<Self::Item> {
146 if self.i == self.nnz {
148 return None;
149 }
150
151 let curr_item = (
153 self.row,
154 self.matrix.indices[self.i],
155 self.matrix.data[self.i],
156 );
157
158 self.i += 1;
160 if self.i == self.nnz {
162 return Some(curr_item);
163 }
164 while self.i >= self.matrix.indptr[self.row + 1] {
166 self.row += 1;
167 }
168
169 Some(curr_item)
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        provider::PallasEngine,
        traits::{Engine, Group},
    };
    use ff::PrimeField;
    use proptest::{
        prelude::*,
        strategy::{BoxedStrategy, Strategy},
    };

    type G = <PallasEngine as Engine>::GE;
    type Fr = <G as Group>::Scalar;

    /// Newtype around a field element, so that it can implement proptest's `Arbitrary`.
    #[derive(Clone, Debug, PartialEq, Eq)]
    pub struct FWrap<F: PrimeField>(pub F);

    impl<F: PrimeField> Copy for FWrap<F> {}

    #[cfg(not(target_arch = "wasm32"))]
    impl<F: PrimeField> Arbitrary for FWrap<F> {
        type Parameters = ();
        type Strategy = BoxedStrategy<Self>;

        fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
            use rand::rngs::StdRng;
            use rand_core::SeedableRng;

            // Derive the element from a proptest-chosen seed, so failures are reproducible.
            let strategy = any::<[u8; 32]>()
                .prop_map(|seed| FWrap(F::random(StdRng::from_seed(seed))))
                .no_shrink();
            strategy.boxed()
        }
    }

    #[test]
    fn test_matrix_creation() {
        let matrix_data = vec![
            (0, 1, Fr::from(2)),
            (1, 2, Fr::from(3)),
            (2, 0, Fr::from(4)),
        ];
        let sparse_matrix = SparseMatrix::<Fr>::new(&matrix_data, 3, 3);

        assert_eq!(
            sparse_matrix.data,
            vec![Fr::from(2), Fr::from(3), Fr::from(4)]
        );
        assert_eq!(sparse_matrix.indices, vec![1, 2, 0]);
        assert_eq!(sparse_matrix.indptr, vec![0, 1, 2, 3]);
    }

    #[test]
    #[should_panic]
    fn test_matrix_creation_rejects_unsorted_row() {
        // Columns within a row must be strictly increasing.
        let matrix_data = vec![(0, 2, Fr::from(1)), (0, 1, Fr::from(2))];
        let _ = SparseMatrix::<Fr>::new(&matrix_data, 1, 3);
    }

    #[test]
    fn test_matrix_vector_multiplication() {
        let matrix_data = vec![
            (0, 1, Fr::from(2)),
            (0, 2, Fr::from(7)),
            (1, 2, Fr::from(3)),
            (2, 0, Fr::from(4)),
        ];
        let sparse_matrix = SparseMatrix::<Fr>::new(&matrix_data, 3, 3);
        let vector = vec![Fr::from(1), Fr::from(2), Fr::from(3)];

        let result = sparse_matrix.multiply_vec(&vector);

        assert_eq!(result, vec![Fr::from(25), Fr::from(9), Fr::from(4)]);
    }

    #[test]
    #[should_panic(expected = "invalid shape")]
    fn test_matrix_vector_multiplication_shape_mismatch() {
        let sparse_matrix = SparseMatrix::<Fr>::new(&[(0, 0, Fr::from(1))], 1, 2);
        let _ = sparse_matrix.multiply_vec(&[Fr::from(1)]);
    }

    #[test]
    fn test_empty_matrix() {
        let matrix = SparseMatrix::<Fr>::empty();
        assert!(matrix.is_empty());
        assert_eq!(matrix.len(), 0);
        assert_eq!(matrix.iter().count(), 0);
        assert!(matrix.multiply_vec(&[]).is_empty());
    }

    #[test]
    fn test_iter_skips_empty_rows() {
        // Rows 0, 2 and 4 are empty: a leading, a middle and a trailing empty row.
        let matrix_data = vec![(1, 0, Fr::from(5)), (3, 2, Fr::from(6))];
        let matrix = SparseMatrix::<Fr>::new(&matrix_data, 5, 3);

        assert_eq!(matrix.len(), 2);
        assert_eq!(matrix.iter().collect::<Vec<_>>(), matrix_data);
    }

    /// Random COO entries of a 100x100 matrix. Between 0 and 10 entries are
    /// generated, so the empty-matrix path of `SparseMatrix::iter` is exercised too.
    fn coo_strategy() -> BoxedStrategy<Vec<(usize, usize, FWrap<Fr>)>> {
        let entry = (0usize..100, 0usize..100, any::<FWrap<Fr>>());
        proptest::collection::vec(entry, 0..=10).boxed()
    }

    proptest! {
        #[test]
        fn test_matrix_iter(mut coo_matrix in coo_strategy()) {
            // `new` requires strictly increasing columns within a row: sort, then drop duplicates.
            coo_matrix.sort_by_key(|(row, col, _val)| (*row, *col));
            coo_matrix.dedup_by_key(|(row, col, _val)| (*row, *col));
            let coo_matrix = coo_matrix
                .into_iter()
                .map(|(row, col, val)| (row, col, val.0))
                .collect::<Vec<_>>();

            let matrix = SparseMatrix::new(&coo_matrix, 100, 100);

            prop_assert_eq!(matrix.len(), coo_matrix.len());
            prop_assert_eq!(coo_matrix, matrix.iter().collect::<Vec<_>>());
        }
    }
}