Skip to main content

provekit_common/
r1cs.rs

1use {
2    crate::{FieldElement, HydratedSparseMatrix, Interner, SparseMatrix},
3    ark_std::Zero,
4    serde::{Deserialize, Serialize},
5    sha3::{Digest, Sha3_256},
6};
7
/// Returns `true` if any witness index (the second tuple element) occurs more
/// than once in `terms`.
///
/// Coefficients are never inspected, so this works for any coefficient type.
fn has_duplicate_witnesses<T>(terms: &[(T, usize)]) -> bool {
    // Short linear combinations are the common case: a pairwise scan is
    // allocation-free and fastest there. Past this size, switch to a hash set
    // so long combinations cost O(n) instead of O(n²). The exact value is a
    // heuristic.
    const PAIRWISE_LIMIT: usize = 32;

    if terms.len() <= PAIRWISE_LIMIT {
        return terms
            .iter()
            .enumerate()
            .any(|(i, (_, w))| terms[i + 1..].iter().any(|(_, x)| x == w));
    }

    let mut seen = std::collections::HashSet::with_capacity(terms.len());
    // `insert` returns `false` when the index was already present.
    terms.iter().any(|&(_, w)| !seen.insert(w))
}
18
19/// Merge duplicate witness indices and drop zero-coefficient entries.
20fn canonicalize_terms(terms: &[(FieldElement, usize)]) -> Vec<(FieldElement, usize)> {
21    if !has_duplicate_witnesses(terms) {
22        return terms
23            .iter()
24            .filter(|(c, _)| !c.is_zero())
25            .copied()
26            .collect();
27    }
28
29    let mut sorted: Vec<(FieldElement, usize)> = terms.to_vec();
30    sorted.sort_unstable_by_key(|&(_c, w)| w);
31
32    let mut result: Vec<(FieldElement, usize)> = Vec::with_capacity(sorted.len());
33    let mut acc_coeff = sorted[0].0;
34    let mut acc_witness = sorted[0].1;
35
36    for &(coeff, witness) in &sorted[1..] {
37        if witness == acc_witness {
38            acc_coeff += coeff;
39        } else {
40            if !acc_coeff.is_zero() {
41                result.push((acc_coeff, acc_witness));
42            }
43            acc_coeff = coeff;
44            acc_witness = witness;
45        }
46    }
47
48    if !acc_coeff.is_zero() {
49        result.push((acc_coeff, acc_witness));
50    }
51
52    result
53}
54
55/// Represents a R1CS constraint system.
56#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
57pub struct R1CS {
58    pub num_public_inputs: usize,
59    pub interner:          Interner,
60    pub a:                 SparseMatrix,
61    pub b:                 SparseMatrix,
62    pub c:                 SparseMatrix,
63}
64
65impl Default for R1CS {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl R1CS {
72    #[must_use]
73    pub fn new() -> Self {
74        Self {
75            num_public_inputs: 0,
76            interner:          Interner::new(),
77            a:                 SparseMatrix::new(0, 0),
78            b:                 SparseMatrix::new(0, 0),
79            c:                 SparseMatrix::new(0, 0),
80        }
81    }
82
83    #[must_use]
84    pub const fn a(&self) -> HydratedSparseMatrix<'_> {
85        self.a.hydrate(&self.interner)
86    }
87
88    #[must_use]
89    pub const fn b(&self) -> HydratedSparseMatrix<'_> {
90        self.b.hydrate(&self.interner)
91    }
92
93    #[must_use]
94    pub const fn c(&self) -> HydratedSparseMatrix<'_> {
95        self.c.hydrate(&self.interner)
96    }
97
98    /// The number of constraints in the R1CS instance.
99    pub const fn num_constraints(&self) -> usize {
100        self.a.num_rows
101    }
102
103    /// The number of witnesses in the R1CS instance (including the constant one
104    /// witness).
105    pub const fn num_witnesses(&self) -> usize {
106        self.a.num_cols
107    }
108
109    /// Hash of the R1CS
110    #[must_use]
111    pub fn hash(&self) -> [u8; 32] {
112        let bytes = postcard::to_stdvec(self).expect("R1CS serialization failed");
113        Sha3_256::digest(&bytes).into()
114    }
115
116    // Increase the size of the R1CS matrices to the specified dimensions.
117    pub fn grow_matrices(&mut self, num_rows: usize, num_cols: usize) {
118        self.a.grow(num_rows, num_cols);
119        self.b.grow(num_rows, num_cols);
120        self.c.grow(num_rows, num_cols);
121    }
122
123    /// Add a new witnesses to the R1CS instance.
124    pub fn add_witnesses(&mut self, count: usize) {
125        self.grow_matrices(self.num_constraints(), self.num_witnesses() + count);
126    }
127
128    /// Add an R1CS constraint. Duplicate witness indices within each linear
129    /// combination are merged (coefficients summed) and zeros are dropped.
130    pub fn add_constraint(
131        &mut self,
132        a: &[(FieldElement, usize)],
133        b: &[(FieldElement, usize)],
134        c: &[(FieldElement, usize)],
135    ) {
136        let a = canonicalize_terms(a);
137        let b = canonicalize_terms(b);
138        let c = canonicalize_terms(c);
139
140        let next_constraint_idx = self.num_constraints();
141        self.grow_matrices(self.num_constraints() + 1, self.num_witnesses());
142
143        for (coeff, witness_idx) in &a {
144            self.a.set(
145                next_constraint_idx,
146                *witness_idx,
147                self.interner.intern(*coeff),
148            );
149        }
150        for (coeff, witness_idx) in &b {
151            self.b.set(
152                next_constraint_idx,
153                *witness_idx,
154                self.interner.intern(*coeff),
155            );
156        }
157        for (coeff, witness_idx) in &c {
158            self.c.set(
159                next_constraint_idx,
160                *witness_idx,
161                self.interner.intern(*coeff),
162            );
163        }
164    }
165}
166
#[cfg(test)]
mod tests {
    use {super::*, ark_std::One};

    /// Duplicate witness coefficients are summed, not overwritten.
    #[test]
    fn duplicate_witnesses_are_merged() {
        let mut r1cs = R1CS::new();
        r1cs.add_witnesses(3);

        let a = vec![(FieldElement::from(3u64), 1), (FieldElement::from(5u64), 1)];
        let b = vec![(FieldElement::one(), 0)];
        let c = vec![(FieldElement::from(8u64), 1)];

        r1cs.add_constraint(&a, &b, &c);

        let a_entries: Vec<_> = r1cs.a().iter_row(0).collect();
        assert_eq!(a_entries.len(), 1);
        assert_eq!(a_entries[0], (1, FieldElement::from(8u64)));
    }

    /// Opposite-sign duplicates cancel to zero and produce no entry.
    #[test]
    fn cancelling_duplicates_produce_no_entry() {
        let mut r1cs = R1CS::new();
        r1cs.add_witnesses(3);

        let five = FieldElement::from(5u64);
        let neg_five = FieldElement::zero() - five;
        let a = vec![(five, 1), (neg_five, 1)];
        let b = vec![(FieldElement::one(), 0)];
        let c: Vec<(FieldElement, usize)> = vec![];

        r1cs.add_constraint(&a, &b, &c);

        let a_entries: Vec<_> = r1cs.a().iter_row(0).collect();
        assert!(a_entries.is_empty());
    }

    /// Only duplicate witnesses are merged; distinct witnesses are preserved.
    #[test]
    fn mixed_unique_and_duplicate_witnesses() {
        let mut r1cs = R1CS::new();
        r1cs.add_witnesses(4);

        let a = vec![
            (FieldElement::from(2u64), 1),
            (FieldElement::from(7u64), 2),
            (FieldElement::from(3u64), 1),
            (FieldElement::from(11u64), 3),
        ];
        let b = vec![(FieldElement::one(), 0)];
        let c = vec![];

        r1cs.add_constraint(&a, &b, &c);

        let mut a_entries: Vec<_> = r1cs.a().iter_row(0).collect();
        a_entries.sort_by_key(|(col, _)| *col);
        assert_eq!(a_entries.len(), 3);
        assert_eq!(a_entries[0], (1, FieldElement::from(5u64)));
        assert_eq!(a_entries[1], (2, FieldElement::from(7u64)));
        assert_eq!(a_entries[2], (3, FieldElement::from(11u64)));
    }

    /// Duplicates are merged independently in all three matrices.
    #[test]
    fn duplicates_in_all_matrices() {
        let mut r1cs = R1CS::new();
        r1cs.add_witnesses(3);

        let a = vec![(FieldElement::from(1u64), 1), (FieldElement::from(2u64), 1)];
        let b = vec![(FieldElement::from(3u64), 2), (FieldElement::from(4u64), 2)];
        let c = vec![(FieldElement::from(5u64), 1), (FieldElement::from(6u64), 1)];

        r1cs.add_constraint(&a, &b, &c);

        let a_entries: Vec<_> = r1cs.a().iter_row(0).collect();
        assert_eq!(a_entries, vec![(1, FieldElement::from(3u64))]);

        let b_entries: Vec<_> = r1cs.b().iter_row(0).collect();
        assert_eq!(b_entries, vec![(2, FieldElement::from(7u64))]);

        let c_entries: Vec<_> = r1cs.c().iter_row(0).collect();
        assert_eq!(c_entries, vec![(1, FieldElement::from(11u64))]);
    }

    /// Zero coefficients are dropped even when there are no duplicates to
    /// merge (the canonicalization fast path).
    #[test]
    fn zero_coefficients_dropped_without_duplicates() {
        let mut r1cs = R1CS::new();
        r1cs.add_witnesses(3);

        let a = vec![(FieldElement::zero(), 1), (FieldElement::from(3u64), 2)];
        let b = vec![(FieldElement::one(), 0)];
        let c: Vec<(FieldElement, usize)> = vec![];

        r1cs.add_constraint(&a, &b, &c);

        let a_entries: Vec<_> = r1cs.a().iter_row(0).collect();
        assert_eq!(a_entries, vec![(2, FieldElement::from(3u64))]);
    }

    /// Each call to `add_constraint` fills the next row.
    #[test]
    fn constraints_occupy_consecutive_rows() {
        let mut r1cs = R1CS::new();
        r1cs.add_witnesses(3);

        let one = vec![(FieldElement::one(), 0)];
        r1cs.add_constraint(&one, &one, &one);
        r1cs.add_constraint(&[(FieldElement::from(9u64), 2)], &one, &one);

        assert_eq!(r1cs.num_constraints(), 2);
        assert_eq!(r1cs.num_witnesses(), 3);
        let row1: Vec<_> = r1cs.a().iter_row(1).collect();
        assert_eq!(row1, vec![(2, FieldElement::from(9u64))]);
    }

    #[test]
    fn has_duplicate_witnesses_detects_repeats() {
        let one = FieldElement::one();
        let empty: &[(FieldElement, usize)] = &[];
        assert!(!has_duplicate_witnesses(empty));
        assert!(!has_duplicate_witnesses(&[(one, 0), (one, 1), (one, 2)]));
        assert!(has_duplicate_witnesses(&[(one, 4), (one, 1), (one, 4)]));

        // Long linear combinations must be handled as well as short ones.
        let long: Vec<(FieldElement, usize)> = (0..100).map(|w| (one, w)).collect();
        assert!(!has_duplicate_witnesses(&long));
        let mut long_dup = long.clone();
        long_dup.push((one, 57));
        assert!(has_duplicate_witnesses(&long_dup));
    }

    #[test]
    fn canonicalize_terms_basics() {
        assert!(canonicalize_terms(&[]).is_empty());
        assert!(canonicalize_terms(&[(FieldElement::zero(), 0)]).is_empty());

        let result = canonicalize_terms(&[(FieldElement::from(42u64), 5)]);
        assert_eq!(result, vec![(FieldElement::from(42u64), 5)]);

        // 1 + 2 + 3 = 6
        let result = canonicalize_terms(&[
            (FieldElement::from(1u64), 7),
            (FieldElement::from(2u64), 7),
            (FieldElement::from(3u64), 7),
        ]);
        assert_eq!(result, vec![(FieldElement::from(6u64), 7)]);
    }

    /// Without duplicates, surviving terms keep their input order.
    #[test]
    fn canonicalize_terms_preserves_order_without_duplicates() {
        let result = canonicalize_terms(&[
            (FieldElement::from(5u64), 3),
            (FieldElement::zero(), 1),
            (FieldElement::from(7u64), 2),
        ]);
        assert_eq!(result, vec![
            (FieldElement::from(5u64), 3),
            (FieldElement::from(7u64), 2),
        ]);
    }
}