1use {
2 crate::{FieldElement, HydratedSparseMatrix, Interner, SparseMatrix},
3 ark_std::Zero,
4 serde::{Deserialize, Serialize},
5 sha3::{Digest, Sha3_256},
6};
7
/// Returns `true` if any witness index occurs more than once in `terms`.
///
/// Only the witness index (the second tuple element) is inspected. Coefficients
/// are never read, which is why their type is left generic.
///
/// Short slices use an allocation-free pairwise scan. Longer slices use a hash
/// set, so large linear combinations are checked in linear time instead of
/// quadratic time.
fn has_duplicate_witnesses<F>(terms: &[(F, usize)]) -> bool {
    // Below this length the pairwise scan is cheaper than allocating a set,
    // and typical constraint rows have only a handful of terms.
    const PAIRWISE_SCAN_LIMIT: usize = 16;

    if terms.len() <= PAIRWISE_SCAN_LIMIT {
        return terms
            .iter()
            .enumerate()
            .any(|(i, (_, w))| terms[i + 1..].iter().any(|(_, other)| other == w));
    }

    let mut seen = std::collections::HashSet::with_capacity(terms.len());
    // `insert` returns false when the index was already present.
    !terms.iter().all(|(_, w)| seen.insert(*w))
}
18
19fn canonicalize_terms(terms: &[(FieldElement, usize)]) -> Vec<(FieldElement, usize)> {
21 if !has_duplicate_witnesses(terms) {
22 return terms
23 .iter()
24 .filter(|(c, _)| !c.is_zero())
25 .copied()
26 .collect();
27 }
28
29 let mut sorted: Vec<(FieldElement, usize)> = terms.to_vec();
30 sorted.sort_unstable_by_key(|&(_c, w)| w);
31
32 let mut result: Vec<(FieldElement, usize)> = Vec::with_capacity(sorted.len());
33 let mut acc_coeff = sorted[0].0;
34 let mut acc_witness = sorted[0].1;
35
36 for &(coeff, witness) in &sorted[1..] {
37 if witness == acc_witness {
38 acc_coeff += coeff;
39 } else {
40 if !acc_coeff.is_zero() {
41 result.push((acc_coeff, acc_witness));
42 }
43 acc_coeff = coeff;
44 acc_witness = witness;
45 }
46 }
47
48 if !acc_coeff.is_zero() {
49 result.push((acc_coeff, acc_witness));
50 }
51
52 result
53}
54
/// A rank-1 constraint system stored as three sparse matrices `A`, `B`, `C`
/// that share one coefficient interner.
///
/// Each call to `R1CS::add_constraint` appends one row to all three matrices.
/// Presumably row `i` encodes the standard R1CS relation
/// `(A·w)_i * (B·w)_i = (C·w)_i` over the witness vector `w`. That relation is
/// not checked anywhere in this type.
///
/// NOTE: `R1CS::hash` digests the serde (postcard) serialization of this
/// struct. Changing the field order or field types therefore changes the hash.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct R1CS {
    /// Number of public inputs. In this type it is only initialized (to 0);
    /// presumably callers set it.
    pub num_public_inputs: usize,
    /// Interner for the coefficient values stored in `a`, `b` and `c`.
    pub interner: Interner,
    /// Matrix `A`. Entries are the handles returned by `Interner::intern`.
    pub a: SparseMatrix,
    /// Matrix `B`. Entries are the handles returned by `Interner::intern`.
    pub b: SparseMatrix,
    /// Matrix `C`. Entries are the handles returned by `Interner::intern`.
    pub c: SparseMatrix,
}
64
65impl Default for R1CS {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl R1CS {
72 #[must_use]
73 pub fn new() -> Self {
74 Self {
75 num_public_inputs: 0,
76 interner: Interner::new(),
77 a: SparseMatrix::new(0, 0),
78 b: SparseMatrix::new(0, 0),
79 c: SparseMatrix::new(0, 0),
80 }
81 }
82
83 #[must_use]
84 pub const fn a(&self) -> HydratedSparseMatrix<'_> {
85 self.a.hydrate(&self.interner)
86 }
87
88 #[must_use]
89 pub const fn b(&self) -> HydratedSparseMatrix<'_> {
90 self.b.hydrate(&self.interner)
91 }
92
93 #[must_use]
94 pub const fn c(&self) -> HydratedSparseMatrix<'_> {
95 self.c.hydrate(&self.interner)
96 }
97
98 pub const fn num_constraints(&self) -> usize {
100 self.a.num_rows
101 }
102
103 pub const fn num_witnesses(&self) -> usize {
106 self.a.num_cols
107 }
108
109 #[must_use]
111 pub fn hash(&self) -> [u8; 32] {
112 let bytes = postcard::to_stdvec(self).expect("R1CS serialization failed");
113 Sha3_256::digest(&bytes).into()
114 }
115
116 pub fn grow_matrices(&mut self, num_rows: usize, num_cols: usize) {
118 self.a.grow(num_rows, num_cols);
119 self.b.grow(num_rows, num_cols);
120 self.c.grow(num_rows, num_cols);
121 }
122
123 pub fn add_witnesses(&mut self, count: usize) {
125 self.grow_matrices(self.num_constraints(), self.num_witnesses() + count);
126 }
127
128 pub fn add_constraint(
131 &mut self,
132 a: &[(FieldElement, usize)],
133 b: &[(FieldElement, usize)],
134 c: &[(FieldElement, usize)],
135 ) {
136 let a = canonicalize_terms(a);
137 let b = canonicalize_terms(b);
138 let c = canonicalize_terms(c);
139
140 let next_constraint_idx = self.num_constraints();
141 self.grow_matrices(self.num_constraints() + 1, self.num_witnesses());
142
143 for (coeff, witness_idx) in &a {
144 self.a.set(
145 next_constraint_idx,
146 *witness_idx,
147 self.interner.intern(*coeff),
148 );
149 }
150 for (coeff, witness_idx) in &b {
151 self.b.set(
152 next_constraint_idx,
153 *witness_idx,
154 self.interner.intern(*coeff),
155 );
156 }
157 for (coeff, witness_idx) in &c {
158 self.c.set(
159 next_constraint_idx,
160 *witness_idx,
161 self.interner.intern(*coeff),
162 );
163 }
164 }
165}
166
#[cfg(test)]
mod tests {
    use {super::*, ark_std::One};

    /// Two terms on the same witness collapse into a single entry whose
    /// coefficient is the sum.
    #[test]
    fn duplicate_witnesses_are_merged() {
        let mut r1cs = R1CS::new();
        r1cs.add_witnesses(3);

        let a = vec![(FieldElement::from(3u64), 1), (FieldElement::from(5u64), 1)];
        let b = vec![(FieldElement::one(), 0)];
        let c = vec![(FieldElement::from(8u64), 1)];

        r1cs.add_constraint(&a, &b, &c);

        let a_entries: Vec<_> = r1cs.a().iter_row(0).collect();
        assert_eq!(a_entries.len(), 1);
        assert_eq!(a_entries[0], (1, FieldElement::from(8u64)));
    }

    /// Duplicates whose coefficients sum to zero leave no entry in the row.
    #[test]
    fn cancelling_duplicates_produce_no_entry() {
        let mut r1cs = R1CS::new();
        r1cs.add_witnesses(3);

        let five = FieldElement::from(5u64);
        let neg_five = FieldElement::zero() - five;
        let a = vec![(five, 1), (neg_five, 1)];
        let b = vec![(FieldElement::one(), 0)];
        let c: Vec<(FieldElement, usize)> = vec![];

        r1cs.add_constraint(&a, &b, &c);

        let a_entries: Vec<_> = r1cs.a().iter_row(0).collect();
        assert!(a_entries.is_empty());
    }

    /// Unique witnesses pass through unchanged while duplicated ones are merged.
    #[test]
    fn mixed_unique_and_duplicate_witnesses() {
        let mut r1cs = R1CS::new();
        r1cs.add_witnesses(4);

        let a = vec![
            (FieldElement::from(2u64), 1),
            (FieldElement::from(7u64), 2),
            (FieldElement::from(3u64), 1),
            (FieldElement::from(11u64), 3),
        ];
        let b = vec![(FieldElement::one(), 0)];
        let c = vec![];

        r1cs.add_constraint(&a, &b, &c);

        let mut a_entries: Vec<_> = r1cs.a().iter_row(0).collect();
        a_entries.sort_by_key(|(col, _)| *col);
        assert_eq!(a_entries.len(), 3);
        assert_eq!(a_entries[0], (1, FieldElement::from(5u64)));
        assert_eq!(a_entries[1], (2, FieldElement::from(7u64)));
        assert_eq!(a_entries[2], (3, FieldElement::from(11u64)));
    }

    /// Canonicalization is applied independently to each of `A`, `B` and `C`.
    #[test]
    fn duplicates_in_all_matrices() {
        let mut r1cs = R1CS::new();
        r1cs.add_witnesses(3);

        let a = vec![(FieldElement::from(1u64), 1), (FieldElement::from(2u64), 1)];
        let b = vec![(FieldElement::from(3u64), 2), (FieldElement::from(4u64), 2)];
        let c = vec![(FieldElement::from(5u64), 1), (FieldElement::from(6u64), 1)];

        r1cs.add_constraint(&a, &b, &c);

        let a_entries: Vec<_> = r1cs.a().iter_row(0).collect();
        assert_eq!(a_entries, vec![(1, FieldElement::from(3u64))]);

        let b_entries: Vec<_> = r1cs.b().iter_row(0).collect();
        assert_eq!(b_entries, vec![(2, FieldElement::from(7u64))]);

        let c_entries: Vec<_> = r1cs.c().iter_row(0).collect();
        assert_eq!(c_entries, vec![(1, FieldElement::from(11u64))]);
    }

    /// Edge cases of `canonicalize_terms`: empty input, a lone zero, a single
    /// term, and a run of three duplicates.
    #[test]
    fn canonicalize_terms_basics() {
        assert!(canonicalize_terms(&[]).is_empty());
        assert!(canonicalize_terms(&[(FieldElement::zero(), 0)]).is_empty());

        let result = canonicalize_terms(&[(FieldElement::from(42u64), 5)]);
        assert_eq!(result, vec![(FieldElement::from(42u64), 5)]);

        let result = canonicalize_terms(&[
            (FieldElement::from(1u64), 7),
            (FieldElement::from(2u64), 7),
            (FieldElement::from(3u64), 7),
        ]);
        assert_eq!(result, vec![(FieldElement::from(6u64), 7)]);
    }

    /// Explicit zero coefficients are dropped on both canonicalization paths:
    /// when all witnesses are unique, and when other witnesses are duplicated.
    #[test]
    fn zero_coefficients_are_dropped() {
        let zero = FieldElement::zero();
        let four = FieldElement::from(4u64);

        // No duplicate witnesses: only zero filtering applies.
        assert_eq!(canonicalize_terms(&[(zero, 1), (four, 2)]), vec![(four, 2)]);

        // Witness 1 is duplicated, so the merge path runs; the explicit zero on
        // the unique witness 3 must still be removed.
        assert_eq!(
            canonicalize_terms(&[
                (zero, 3),
                (FieldElement::from(2u64), 1),
                (FieldElement::from(5u64), 1),
            ]),
            vec![(FieldElement::from(7u64), 1)]
        );
    }
}