Skip to main content

nova_snark/r1cs/
sparse.rs

//! # Sparse Matrices
//!
//! This module defines a custom implementation of CSR (compressed sparse row) matrices.
//! Specifically, we implement sparse matrix / dense vector multiplication
//! to compute the products `A z`, `B z`, and `C z` in Nova.
6use ff::PrimeField;
7use rayon::prelude::*;
8use serde::{Deserialize, Serialize};
9
10/// CSR format sparse matrix, We follow the names used by scipy.
11/// Detailed explanation here: <https://stackoverflow.com/questions/52299420/scipy-csr-matrix-understand-indptr>
12#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
13pub struct SparseMatrix<F: PrimeField> {
14  /// all non-zero values in the matrix
15  pub data: Vec<F>,
16  /// column indices
17  pub indices: Vec<usize>,
18  /// row information
19  pub indptr: Vec<usize>,
20  /// number of columns
21  pub cols: usize,
22}
23
24impl<F: PrimeField> SparseMatrix<F> {
25  /// 0x0 empty matrix
26  pub fn empty() -> Self {
27    SparseMatrix {
28      data: vec![],
29      indices: vec![],
30      indptr: vec![0],
31      cols: 0,
32    }
33  }
34
35  /// Construct from the COO representation; Vec<usize(row), usize(col), F>.
36  /// We assume that the rows are sorted during construction.
37  pub fn new(matrix: &[(usize, usize, F)], rows: usize, cols: usize) -> Self {
38    let mut new_matrix = vec![vec![]; rows];
39    for (row, col, val) in matrix {
40      new_matrix[*row].push((*col, *val));
41    }
42
43    for row in new_matrix.iter() {
44      assert!(row.windows(2).all(|w| w[0].0 < w[1].0));
45    }
46
47    let mut indptr = vec![0; rows + 1];
48    for (i, col) in new_matrix.iter().enumerate() {
49      indptr[i + 1] = indptr[i] + col.len();
50    }
51
52    let mut indices = vec![];
53    let mut data = vec![];
54    for col in new_matrix {
55      let (idx, val): (Vec<_>, Vec<_>) = col.into_iter().unzip();
56      indices.extend(idx);
57      data.extend(val);
58    }
59
60    SparseMatrix {
61      data,
62      indices,
63      indptr,
64      cols,
65    }
66  }
67
68  /// Retrieves the data for row slice [i..j] from `ptrs`.
69  /// We assume that `ptrs` is indexed from `indptrs` and do not check if the
70  /// returned slice is actually a valid row.
71  pub fn get_row_unchecked(&self, ptrs: &[usize; 2]) -> impl Iterator<Item = (&F, &usize)> {
72    self.data[ptrs[0]..ptrs[1]]
73      .iter()
74      .zip(&self.indices[ptrs[0]..ptrs[1]])
75  }
76
77  /// Multiply by a dense vector; uses rayon/gpu.
78  pub fn multiply_vec(&self, vector: &[F]) -> Vec<F> {
79    assert_eq!(self.cols, vector.len(), "invalid shape");
80
81    self.multiply_vec_unchecked(vector)
82  }
83
84  /// Multiply by a dense vector; uses rayon/gpu.
85  /// This does not check that the shape of the matrix/vector are compatible.
86  pub fn multiply_vec_unchecked(&self, vector: &[F]) -> Vec<F> {
87    self
88      .indptr
89      .par_windows(2)
90      .map(|ptrs| {
91        self
92          .get_row_unchecked(ptrs.try_into().unwrap())
93          .map(|(val, col_idx)| *val * vector[*col_idx])
94          .sum()
95      })
96      .collect()
97  }
98
99  /// number of non-zero entries
100  pub fn len(&self) -> usize {
101    *self.indptr.last().unwrap()
102  }
103
104  /// empty matrix
105  pub fn is_empty(&self) -> bool {
106    self.len() == 0
107  }
108
109  /// returns a custom iterator
110  pub fn iter(&self) -> Iter<'_, F> {
111    let nnz = *self.indptr.last().unwrap();
112    if nnz == 0 {
113      return Iter {
114        matrix: self,
115        row: 0,
116        i: 0,
117        nnz,
118      };
119    }
120
121    let mut row = 0;
122    while row + 1 < self.indptr.len() && self.indptr[row + 1] == 0 {
123      row += 1;
124    }
125    Iter {
126      matrix: self,
127      row,
128      i: 0,
129      nnz,
130    }
131  }
132}
133
134/// Iterator for sparse matrix
135pub struct Iter<'a, F: PrimeField> {
136  matrix: &'a SparseMatrix<F>,
137  row: usize,
138  i: usize,
139  nnz: usize,
140}
141
142impl<F: PrimeField> Iterator for Iter<'_, F> {
143  type Item = (usize, usize, F);
144
145  fn next(&mut self) -> Option<Self::Item> {
146    // are we at the end?
147    if self.i == self.nnz {
148      return None;
149    }
150
151    // compute current item
152    let curr_item = (
153      self.row,
154      self.matrix.indices[self.i],
155      self.matrix.data[self.i],
156    );
157
158    // advance the iterator
159    self.i += 1;
160    // edge case at the end
161    if self.i == self.nnz {
162      return Some(curr_item);
163    }
164    // if `i` has moved to next row
165    while self.i >= self.matrix.indptr[self.row + 1] {
166      self.row += 1;
167    }
168
169    Some(curr_item)
170  }
171}
172
#[cfg(test)]
mod tests {
  use super::*;
  use crate::{
    provider::PallasEngine,
    traits::{Engine, Group},
  };
  use ff::PrimeField;
  use proptest::{
    prelude::*,
    strategy::{BoxedStrategy, Just, Strategy},
  };

  type G = <PallasEngine as Engine>::GE;
  type Fr = <G as Group>::Scalar;

  /// Wrapper around a field element, so that additional traits (such as proptest's
  /// `Arbitrary`) can be implemented for it.
  #[derive(Clone, Debug, PartialEq, Eq)]
  pub struct FWrap<F: PrimeField>(pub F);

  impl<F: PrimeField> Copy for FWrap<F> {}

  #[cfg(not(target_arch = "wasm32"))]
  /// Generates `FWrap<F>` values by sampling a random field element from an RNG
  /// seeded with a proptest-chosen 32-byte seed (no shrinking).
  impl<F: PrimeField> Arbitrary for FWrap<F> {
    type Parameters = ();
    type Strategy = BoxedStrategy<Self>;

    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
      use rand::rngs::StdRng;
      use rand_core::SeedableRng;

      let strategy = any::<[u8; 32]>()
        .prop_map(|seed| FWrap(F::random(StdRng::from_seed(seed))))
        .no_shrink();
      strategy.boxed()
    }
  }

  #[test]
  fn test_matrix_creation() {
    let matrix_data = vec![
      (0, 1, Fr::from(2)),
      (1, 2, Fr::from(3)),
      (2, 0, Fr::from(4)),
    ];
    let sparse_matrix = SparseMatrix::<Fr>::new(&matrix_data, 3, 3);

    assert_eq!(
      sparse_matrix.data,
      vec![Fr::from(2), Fr::from(3), Fr::from(4)]
    );
    assert_eq!(sparse_matrix.indices, vec![1, 2, 0]);
    assert_eq!(sparse_matrix.indptr, vec![0, 1, 2, 3]);
  }

  #[test]
  fn test_matrix_creation_unsorted_and_empty_rows() {
    // Rows 0 and 2 are empty, and the entries are supplied out of row order.
    let matrix_data = vec![
      (3, 0, Fr::from(5)),
      (1, 0, Fr::from(2)),
      (1, 2, Fr::from(3)),
    ];
    let sparse_matrix = SparseMatrix::<Fr>::new(&matrix_data, 4, 3);

    assert_eq!(sparse_matrix.indptr, vec![0, 0, 2, 2, 3]);
    assert_eq!(sparse_matrix.indices, vec![0, 2, 0]);
    assert_eq!(
      sparse_matrix.data,
      vec![Fr::from(2), Fr::from(3), Fr::from(5)]
    );
    assert_eq!(
      sparse_matrix.iter().collect::<Vec<_>>(),
      vec![
        (1, 0, Fr::from(2)),
        (1, 2, Fr::from(3)),
        (3, 0, Fr::from(5)),
      ]
    );

    let vector = vec![Fr::from(1), Fr::from(2), Fr::from(3)];
    assert_eq!(
      sparse_matrix.multiply_vec(&vector),
      vec![Fr::from(0), Fr::from(11), Fr::from(0), Fr::from(5)]
    );
  }

  #[test]
  fn test_empty_matrix() {
    let matrix = SparseMatrix::<Fr>::empty();
    assert!(matrix.is_empty());
    assert_eq!(matrix.len(), 0);
    assert_eq!(matrix.iter().count(), 0);
    assert!(matrix.multiply_vec(&[]).is_empty());
  }

  #[test]
  #[should_panic]
  fn test_matrix_creation_rejects_unsorted_columns() {
    let matrix_data = vec![(0, 2, Fr::from(1)), (0, 1, Fr::from(1))];
    let _ = SparseMatrix::<Fr>::new(&matrix_data, 1, 3);
  }

  #[test]
  #[should_panic]
  fn test_matrix_creation_rejects_duplicate_entries() {
    let matrix_data = vec![(0, 1, Fr::from(1)), (0, 1, Fr::from(2))];
    let _ = SparseMatrix::<Fr>::new(&matrix_data, 1, 3);
  }

  #[test]
  #[should_panic(expected = "invalid shape")]
  fn test_multiply_vec_rejects_wrong_length() {
    let matrix_data = vec![(0, 0, Fr::from(1))];
    let sparse_matrix = SparseMatrix::<Fr>::new(&matrix_data, 1, 3);
    let _ = sparse_matrix.multiply_vec(&[Fr::from(1)]);
  }

  #[test]
  fn test_matrix_vector_multiplication() {
    let matrix_data = vec![
      (0, 1, Fr::from(2)),
      (0, 2, Fr::from(7)),
      (1, 2, Fr::from(3)),
      (2, 0, Fr::from(4)),
    ];
    let sparse_matrix = SparseMatrix::<Fr>::new(&matrix_data, 3, 3);
    let vector = vec![Fr::from(1), Fr::from(2), Fr::from(3)];

    let result = sparse_matrix.multiply_vec(&vector);

    assert_eq!(result, vec![Fr::from(25), Fr::from(9), Fr::from(4)]);
  }

  // Uses the `Arbitrary` impl above, so it carries the same target gate.
  #[cfg(not(target_arch = "wasm32"))]
  fn coo_strategy() -> BoxedStrategy<Vec<(usize, usize, FWrap<Fr>)>> {
    let coo_strategy = any::<FWrap<Fr>>().prop_flat_map(|f| (0usize..100, 0usize..100, Just(f)));
    proptest::collection::vec(coo_strategy, 10).boxed()
  }

  #[cfg(not(target_arch = "wasm32"))]
  proptest! {
      #[test]
      fn test_matrix_iter(mut coo_matrix in coo_strategy()) {
        // Sort and dedup the random COO entries so they form a valid matrix.
        coo_matrix.sort_by_key(|(row, col, _val)| (*row, *col));
        coo_matrix.dedup_by_key(|(row, col, _val)| (*row, *col));
        let coo_matrix = coo_matrix.into_iter().map(|(row, col, val)| { (row, col, val.0) }).collect::<Vec<_>>();

        let matrix = SparseMatrix::new(&coo_matrix, 100, 100);

        prop_assert_eq!(coo_matrix, matrix.iter().collect::<Vec<_>>());
    }
  }
}