1use std::cmp::Ordering;
2use std::collections::BTreeMap;
3use std::collections::HashMap;
4use std::fmt::Write;
5
6use super::Comparable;
7use crate::{ConstraintSystem, Index, LinearCombination, SynthesisError, Variable};
8use blake2s_simd::State as Blake2s;
9use byteorder::{BigEndian, ByteOrder};
10use ff::PrimeField;
11
12#[derive(Debug)]
13enum NamedObject {
14 Constraint(usize),
15 Var(Variable),
16 Namespace,
17}
18
19#[derive(Debug)]
21pub struct TestConstraintSystem<Scalar: PrimeField> {
22 named_objects: HashMap<String, NamedObject>,
23 current_namespace: Vec<String>,
24 #[allow(clippy::type_complexity)]
25 constraints: Vec<(
26 LinearCombination<Scalar>,
27 LinearCombination<Scalar>,
28 LinearCombination<Scalar>,
29 String,
30 )>,
31 inputs: Vec<(Scalar, String)>,
32 aux: Vec<(Scalar, String)>,
33}
34
35#[derive(Clone, Copy)]
36struct OrderedVariable(Variable);
37
38impl Eq for OrderedVariable {}
39impl PartialEq for OrderedVariable {
40 fn eq(&self, other: &OrderedVariable) -> bool {
41 match (self.0.get_unchecked(), other.0.get_unchecked()) {
42 (Index::Input(ref a), Index::Input(ref b)) => a == b,
43 (Index::Aux(ref a), Index::Aux(ref b)) => a == b,
44 _ => false,
45 }
46 }
47}
48impl PartialOrd for OrderedVariable {
49 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
50 Some(self.cmp(other))
51 }
52}
53impl Ord for OrderedVariable {
54 fn cmp(&self, other: &Self) -> Ordering {
55 match (self.0.get_unchecked(), other.0.get_unchecked()) {
56 (Index::Input(ref a), Index::Input(ref b)) => a.cmp(b),
57 (Index::Aux(ref a), Index::Aux(ref b)) => a.cmp(b),
58 (Index::Input(_), Index::Aux(_)) => Ordering::Less,
59 (Index::Aux(_), Index::Input(_)) => Ordering::Greater,
60 }
61 }
62}
63
64fn proc_lc<Scalar: PrimeField>(
65 terms: &LinearCombination<Scalar>,
66) -> BTreeMap<OrderedVariable, Scalar> {
67 let mut map = BTreeMap::new();
68 for (var, &coeff) in terms.iter() {
69 map.entry(OrderedVariable(var))
70 .or_insert_with(|| Scalar::ZERO)
71 .add_assign(&coeff);
72 }
73
74 let mut to_remove = vec![];
76 for (var, coeff) in map.iter() {
77 if coeff.is_zero().into() {
78 to_remove.push(*var)
79 }
80 }
81
82 for var in to_remove {
83 map.remove(&var);
84 }
85
86 map
87}
88
89fn hash_lc<Scalar: PrimeField>(terms: &LinearCombination<Scalar>, h: &mut Blake2s) {
90 let map = proc_lc::<Scalar>(terms);
91
92 let mut buf = [0u8; 9 + 32];
93 BigEndian::write_u64(&mut buf[0..8], map.len() as u64);
94 h.update(&buf[0..8]);
95
96 for (var, coeff) in map {
97 match var.0.get_unchecked() {
98 Index::Input(i) => {
99 buf[0] = b'I';
100 BigEndian::write_u64(&mut buf[1..9], i as u64);
101 }
102 Index::Aux(i) => {
103 buf[0] = b'A';
104 BigEndian::write_u64(&mut buf[1..9], i as u64);
105 }
106 }
107
108 let mut bytes = coeff.to_repr();
110 bytes.as_mut().reverse();
111 buf[9..].copy_from_slice(bytes.as_ref());
112
113 h.update(&buf[..]);
114 }
115}
116
117fn _eval_lc2<Scalar: PrimeField>(
118 terms: &LinearCombination<Scalar>,
119 inputs: &[Scalar],
120 aux: &[Scalar],
121) -> Scalar {
122 let mut acc = Scalar::ZERO;
123
124 for (var, coeff) in terms.iter() {
125 let mut tmp = match var.get_unchecked() {
126 Index::Input(index) => inputs[index],
127 Index::Aux(index) => aux[index],
128 };
129
130 tmp.mul_assign(coeff);
131 acc.add_assign(&tmp);
132 }
133
134 acc
135}
136
137fn eval_lc<Scalar: PrimeField>(
138 terms: &LinearCombination<Scalar>,
139 inputs: &[(Scalar, String)],
140 aux: &[(Scalar, String)],
141) -> Scalar {
142 let mut acc = Scalar::ZERO;
143
144 for (var, coeff) in terms.iter() {
145 let mut tmp = match var.get_unchecked() {
146 Index::Input(index) => inputs[index].0,
147 Index::Aux(index) => aux[index].0,
148 };
149
150 tmp.mul_assign(coeff);
151 acc.add_assign(&tmp);
152 }
153
154 acc
155}
156
157impl<Scalar: PrimeField> Default for TestConstraintSystem<Scalar> {
158 fn default() -> Self {
159 let mut map = HashMap::new();
160 map.insert(
161 "ONE".into(),
162 NamedObject::Var(TestConstraintSystem::<Scalar>::one()),
163 );
164
165 TestConstraintSystem {
166 named_objects: map,
167 current_namespace: vec![],
168 constraints: vec![],
169 inputs: vec![(Scalar::ONE, "ONE".into())],
170 aux: vec![],
171 }
172 }
173}
174
175impl<Scalar: PrimeField> TestConstraintSystem<Scalar> {
176 pub fn new() -> Self {
177 Default::default()
178 }
179
180 pub fn scalar_inputs(&self) -> Vec<Scalar> {
181 self.inputs
182 .iter()
183 .map(|(scalar, _string)| *scalar)
184 .collect()
185 }
186
187 pub fn scalar_aux(&self) -> Vec<Scalar> {
188 self.aux.iter().map(|(scalar, _string)| *scalar).collect()
189 }
190
191 pub fn pretty_print_list(&self) -> Vec<String> {
192 let mut result = Vec::new();
193
194 for input in &self.inputs {
195 result.push(format!("INPUT {}", input.1));
196 }
197 for aux in &self.aux {
198 result.push(format!("AUX {}", aux.1));
199 }
200
201 for (_a, _b, _c, name) in &self.constraints {
202 result.push(name.to_string());
203 }
204
205 result
206 }
207
208 pub fn pretty_print(&self) -> String {
209 let res = self.pretty_print_list();
210
211 res.join("\n")
212 }
213
214 pub fn hash(&self) -> String {
215 let mut h = Blake2s::new();
216 {
217 let mut buf = [0u8; 24];
218
219 BigEndian::write_u64(&mut buf[0..8], self.inputs.len() as u64);
220 BigEndian::write_u64(&mut buf[8..16], self.aux.len() as u64);
221 BigEndian::write_u64(&mut buf[16..24], self.constraints.len() as u64);
222 h.update(&buf);
223 }
224
225 for constraint in &self.constraints {
226 hash_lc::<Scalar>(&constraint.0, &mut h);
227 hash_lc::<Scalar>(&constraint.1, &mut h);
228 hash_lc::<Scalar>(&constraint.2, &mut h);
229 }
230
231 let mut s = String::new();
232 for b in h.finalize().as_ref() {
233 write!(s, "{:02x}", b).expect("writing to string never fails");
234 }
235
236 s
237 }
238
239 pub fn which_is_unsatisfied(&self) -> Option<&str> {
240 for (a, b, c, path) in &self.constraints {
241 let mut a = eval_lc::<Scalar>(a, &self.inputs, &self.aux);
242 let b = eval_lc::<Scalar>(b, &self.inputs, &self.aux);
243 let c = eval_lc::<Scalar>(c, &self.inputs, &self.aux);
244
245 a.mul_assign(&b);
246
247 if a != c {
248 return Some(path);
249 }
250 }
251
252 None
253 }
254
255 pub fn is_satisfied(&self) -> bool {
256 match self.which_is_unsatisfied() {
257 Some(b) => {
258 println!("fail: {:?}", b);
259 false
260 }
261 None => true,
262 }
263 }
265
266 pub fn num_constraints(&self) -> usize {
267 self.constraints.len()
268 }
269
270 pub fn set(&mut self, path: &str, to: Scalar) {
271 match self.named_objects.get(path) {
272 Some(NamedObject::Var(v)) => match v.get_unchecked() {
273 Index::Input(index) => self.inputs[index].0 = to,
274 Index::Aux(index) => self.aux[index].0 = to,
275 },
276 Some(e) => panic!(
277 "tried to set path `{}` to value, but `{:?}` already exists there.",
278 path, e
279 ),
280 _ => panic!("no variable exists at path: {}", path),
281 }
282 }
283
284 pub fn verify(&self, expected: &[Scalar]) -> bool {
285 assert_eq!(expected.len() + 1, self.inputs.len());
286 for (a, b) in self.inputs.iter().skip(1).zip(expected.iter()) {
287 if &a.0 != b {
288 return false;
289 }
290 }
291
292 true
293 }
294
295 pub fn num_inputs(&self) -> usize {
296 self.inputs.len()
297 }
298
299 pub fn get_input(&mut self, index: usize, path: &str) -> Scalar {
300 let (assignment, name) = self.inputs[index].clone();
301
302 assert_eq!(path, name);
303
304 assignment
305 }
306
307 pub fn get_inputs(&self) -> &[(Scalar, String)] {
308 &self.inputs[..]
309 }
310
311 pub fn get(&mut self, path: &str) -> Scalar {
312 match self.named_objects.get(path) {
313 Some(NamedObject::Var(v)) => match v.get_unchecked() {
314 Index::Input(index) => self.inputs[index].0,
315 Index::Aux(index) => self.aux[index].0,
316 },
317 Some(e) => panic!(
318 "tried to get value of path `{}`, but `{:?}` exists there (not a variable)",
319 path, e
320 ),
321 _ => panic!("no variable exists at path: {}", path),
322 }
323 }
324
325 fn set_named_obj(&mut self, path: String, to: NamedObject) {
326 assert!(
327 !self.named_objects.contains_key(&path),
328 "tried to create object at existing path: {}",
329 path
330 );
331
332 self.named_objects.insert(path, to);
333 }
334}
335
336impl<Scalar: PrimeField> Comparable<Scalar> for TestConstraintSystem<Scalar> {
337 fn num_inputs(&self) -> usize {
338 self.num_inputs()
339 }
340 fn num_constraints(&self) -> usize {
341 self.num_constraints()
342 }
343
344 fn aux(&self) -> Vec<String> {
345 self.aux
346 .iter()
347 .map(|(_scalar, string)| string.to_string())
348 .collect()
349 }
350
351 fn inputs(&self) -> Vec<String> {
352 self.inputs
353 .iter()
354 .map(|(_scalar, string)| string.to_string())
355 .collect()
356 }
357
358 fn constraints(&self) -> &[crate::util_cs::Constraint<Scalar>] {
359 &self.constraints
360 }
361}
362
/// Joins the namespace segments `ns` and the leaf name `this` with `/`.
///
/// # Panics
/// Panics if `this` itself contains a `/`.
fn compute_path(ns: &[String], this: &str) -> String {
    assert!(!this.contains('/'), "'/' is not allowed in names");

    let mut path = String::new();
    for segment in ns {
        path.push_str(segment);
        path.push('/');
    }
    path.push_str(this);
    path
}
376
377impl<Scalar: PrimeField> ConstraintSystem<Scalar> for TestConstraintSystem<Scalar> {
378 type Root = Self;
379
380 fn alloc<F, A, AR>(&mut self, annotation: A, f: F) -> Result<Variable, SynthesisError>
381 where
382 F: FnOnce() -> Result<Scalar, SynthesisError>,
383 A: FnOnce() -> AR,
384 AR: Into<String>,
385 {
386 let index = self.aux.len();
387 let path = compute_path(&self.current_namespace, &annotation().into());
388 self.aux.push((f()?, path.clone()));
389 let var = Variable::new_unchecked(Index::Aux(index));
390 self.set_named_obj(path, NamedObject::Var(var));
391
392 Ok(var)
393 }
394
395 fn alloc_input<F, A, AR>(&mut self, annotation: A, f: F) -> Result<Variable, SynthesisError>
396 where
397 F: FnOnce() -> Result<Scalar, SynthesisError>,
398 A: FnOnce() -> AR,
399 AR: Into<String>,
400 {
401 let index = self.inputs.len();
402 let path = compute_path(&self.current_namespace, &annotation().into());
403 self.inputs.push((f()?, path.clone()));
404 let var = Variable::new_unchecked(Index::Input(index));
405 self.set_named_obj(path, NamedObject::Var(var));
406
407 Ok(var)
408 }
409
410 fn enforce<A, AR, LA, LB, LC>(&mut self, annotation: A, a: LA, b: LB, c: LC)
411 where
412 A: FnOnce() -> AR,
413 AR: Into<String>,
414 LA: FnOnce(LinearCombination<Scalar>) -> LinearCombination<Scalar>,
415 LB: FnOnce(LinearCombination<Scalar>) -> LinearCombination<Scalar>,
416 LC: FnOnce(LinearCombination<Scalar>) -> LinearCombination<Scalar>,
417 {
418 let path = compute_path(&self.current_namespace, &annotation().into());
419 let index = self.constraints.len();
420 self.set_named_obj(path.clone(), NamedObject::Constraint(index));
421
422 let a = a(LinearCombination::zero());
423 let b = b(LinearCombination::zero());
424 let c = c(LinearCombination::zero());
425
426 self.constraints.push((a, b, c, path));
427 }
428
429 fn push_namespace<NR, N>(&mut self, name_fn: N)
430 where
431 NR: Into<String>,
432 N: FnOnce() -> NR,
433 {
434 let name = name_fn().into();
435 let path = compute_path(&self.current_namespace, &name);
436 self.set_named_obj(path, NamedObject::Namespace);
437 self.current_namespace.push(name);
438 }
439
440 fn pop_namespace(&mut self) {
441 assert!(self.current_namespace.pop().is_some());
442 }
443
444 fn get_root(&mut self) -> &mut Self::Root {
445 self
446 }
447}
448
#[cfg(test)]
mod tests {
    use super::*;

    use blstrs::Scalar as Fr;
    use ff::Field;

    #[test]
    fn test_compute_path() {
        let namespace: Vec<String> = ["hello", "world", "things"]
            .iter()
            .map(|segment| segment.to_string())
            .collect();

        assert_eq!(
            compute_path(&namespace, "thing"),
            "hello/world/things/thing"
        );
    }

    #[test]
    fn test_cs() {
        let mut cs = TestConstraintSystem::<Fr>::new();
        assert!(cs.is_satisfied());
        assert_eq!(cs.num_constraints(), 0);

        let a = cs
            .namespace(|| "a")
            .alloc(|| "var", || Ok(Fr::from(10u64)))
            .unwrap();
        let b = cs
            .namespace(|| "b")
            .alloc(|| "var", || Ok(Fr::from(4u64)))
            .unwrap();
        let c = cs.alloc(|| "product", || Ok(Fr::from(40u64))).unwrap();

        // 10 * 4 == 40
        cs.enforce(|| "mult", |lc| lc + a, |lc| lc + b, |lc| lc + c);
        assert!(cs.is_satisfied());
        assert_eq!(cs.num_constraints(), 1);

        // a = 4 breaks "mult" (4 * 4 != 40) but satisfies "eq" (4 * 1 == 4).
        cs.set("a/var", Fr::from(4u64));

        let one = TestConstraintSystem::<Fr>::one();
        cs.enforce(|| "eq", |lc| lc + a, |lc| lc + one, |lc| lc + b);

        assert!(!cs.is_satisfied());
        assert_eq!(cs.which_is_unsatisfied(), Some("mult"));
        assert_eq!(cs.get("product"), Fr::from(40u64));

        // Fixing the product restores satisfiability: 4 * 4 == 16.
        cs.set("product", Fr::from(16u64));
        assert!(cs.is_satisfied());

        {
            let mut outer = cs.namespace(|| "test1");
            let mut inner = outer.namespace(|| "test2");
            inner.alloc(|| "hehe", || Ok(Fr::ONE)).unwrap();
        }

        assert_eq!(cs.get("test1/test2/hehe"), Fr::ONE);
    }
}