Skip to main content

nova_snark/frontend/
test_shape_cs.rs

1//! Support for generating R1CS shape using bellpepper.
2//! `TestShapeCS` implements a superset of `ShapeCS`, adding non-trivial namespace support for use in testing.
3
4use std::{
5  cmp::Ordering,
6  collections::{BTreeMap, HashMap},
7};
8
9use crate::{
10  frontend::{ConstraintSystem, Index, LinearCombination, SynthesisError, Variable},
11  traits::Engine,
12};
13use core::fmt::Write;
14use ff::{Field, PrimeField};
15
16#[derive(Clone, Copy)]
17struct OrderedVariable(Variable);
18
19#[allow(unused)]
20#[derive(Debug)]
21enum NamedObject {
22  Constraint(usize),
23  Var(Variable),
24  Namespace,
25}
26
27impl Eq for OrderedVariable {}
28impl PartialEq for OrderedVariable {
29  fn eq(&self, other: &OrderedVariable) -> bool {
30    match (self.0.get_unchecked(), other.0.get_unchecked()) {
31      (Index::Input(ref a), Index::Input(ref b)) | (Index::Aux(ref a), Index::Aux(ref b)) => a == b,
32      _ => false,
33    }
34  }
35}
36impl PartialOrd for OrderedVariable {
37  fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
38    Some(self.cmp(other))
39  }
40}
41impl Ord for OrderedVariable {
42  fn cmp(&self, other: &Self) -> Ordering {
43    match (self.0.get_unchecked(), other.0.get_unchecked()) {
44      (Index::Input(ref a), Index::Input(ref b)) | (Index::Aux(ref a), Index::Aux(ref b)) => {
45        a.cmp(b)
46      }
47      (Index::Input(_), Index::Aux(_)) => Ordering::Less,
48      (Index::Aux(_), Index::Input(_)) => Ordering::Greater,
49    }
50  }
51}
52
53/// `TestShapeCS` is a `ConstraintSystem` for creating `R1CSShape`s for a circuit.
54pub struct TestShapeCS<E: Engine> {
55  named_objects: HashMap<String, NamedObject>,
56  current_namespace: Vec<String>,
57  /// All constraints added to the `TestShapeCS`.
58  pub constraints: Vec<(
59    LinearCombination<E::Scalar>,
60    LinearCombination<E::Scalar>,
61    LinearCombination<E::Scalar>,
62    String,
63  )>,
64  inputs: Vec<String>,
65  aux: Vec<String>,
66}
67
68fn proc_lc<Scalar: PrimeField>(
69  terms: &LinearCombination<Scalar>,
70) -> BTreeMap<OrderedVariable, Scalar> {
71  let mut map = BTreeMap::new();
72  for (var, &coeff) in terms.iter() {
73    map
74      .entry(OrderedVariable(var))
75      .or_insert_with(|| Scalar::ZERO)
76      .add_assign(&coeff);
77  }
78
79  // Remove terms that have a zero coefficient to normalize
80  let mut to_remove = vec![];
81  for (var, coeff) in map.iter() {
82    if coeff.is_zero().into() {
83      to_remove.push(*var)
84    }
85  }
86
87  for var in to_remove {
88    map.remove(&var);
89  }
90
91  map
92}
93
94impl<E: Engine> TestShapeCS<E>
95where
96  E::Scalar: PrimeField,
97{
98  #[allow(unused)]
99  /// Create a new, default `TestShapeCS`,
100  pub fn new() -> Self {
101    TestShapeCS::default()
102  }
103
104  /// Returns the number of constraints defined for this `TestShapeCS`.
105  pub fn num_constraints(&self) -> usize {
106    self.constraints.len()
107  }
108
109  /// Returns the number of inputs defined for this `TestShapeCS`.
110  pub fn num_inputs(&self) -> usize {
111    self.inputs.len()
112  }
113
114  /// Returns the number of aux inputs defined for this `TestShapeCS`.
115  pub fn num_aux(&self) -> usize {
116    self.aux.len()
117  }
118
119  /// Print all public inputs, aux inputs, and constraint names.
120  #[allow(dead_code)]
121  pub fn pretty_print_list(&self) -> Vec<String> {
122    let mut result = Vec::new();
123
124    for input in &self.inputs {
125      result.push(format!("INPUT {input}"));
126    }
127    for aux in &self.aux {
128      result.push(format!("AUX {aux}"));
129    }
130
131    for (_a, _b, _c, name) in &self.constraints {
132      result.push(name.to_string());
133    }
134
135    result
136  }
137
138  /// Print all inputs and a detailed representation of each constraint.
139  #[allow(dead_code)]
140  pub fn pretty_print(&self) -> String {
141    let mut s = String::new();
142
143    for input in &self.inputs {
144      writeln!(s, "INPUT {}", &input).unwrap()
145    }
146
147    let negone = -<E::Scalar>::ONE;
148
149    let powers_of_two = (0..E::Scalar::NUM_BITS)
150      .map(|i| E::Scalar::from(2u64).pow_vartime([u64::from(i)]))
151      .collect::<Vec<_>>();
152
153    let pp = |s: &mut String, lc: &LinearCombination<E::Scalar>| {
154      s.push('(');
155      let mut is_first = true;
156      for (var, coeff) in proc_lc::<E::Scalar>(lc) {
157        if coeff == negone {
158          s.push_str(" - ")
159        } else if !is_first {
160          s.push_str(" + ")
161        }
162        is_first = false;
163
164        if coeff != <E::Scalar>::ONE && coeff != negone {
165          for (i, x) in powers_of_two.iter().enumerate() {
166            if x == &coeff {
167              write!(s, "2^{i} . ").unwrap();
168              break;
169            }
170          }
171
172          write!(s, "{coeff:?} . ").unwrap()
173        }
174
175        match var.0.get_unchecked() {
176          Index::Input(i) => {
177            write!(s, "`I{}`", &self.inputs[i]).unwrap();
178          }
179          Index::Aux(i) => {
180            write!(s, "`A{}`", &self.aux[i]).unwrap();
181          }
182        }
183      }
184      if is_first {
185        // Nothing was visited, print 0.
186        s.push('0');
187      }
188      s.push(')');
189    };
190
191    for (a, b, c, name) in &self.constraints {
192      s.push('\n');
193
194      write!(s, "{name}: ").unwrap();
195      pp(&mut s, a);
196      write!(s, " * ").unwrap();
197      pp(&mut s, b);
198      s.push_str(" = ");
199      pp(&mut s, c);
200    }
201
202    s.push('\n');
203
204    s
205  }
206
207  /// Associate `NamedObject` with `path`.
208  /// `path` must not already have an associated object.
209  fn set_named_obj(&mut self, path: String, to: NamedObject) {
210    assert!(
211      !self.named_objects.contains_key(&path),
212      "tried to create object at existing path: {path}"
213    );
214
215    self.named_objects.insert(path, to);
216  }
217}
218
219impl<E: Engine> Default for TestShapeCS<E> {
220  fn default() -> Self {
221    let mut map = HashMap::new();
222    map.insert("ONE".into(), NamedObject::Var(TestShapeCS::<E>::one()));
223    TestShapeCS {
224      named_objects: map,
225      current_namespace: vec![],
226      constraints: vec![],
227      inputs: vec![String::from("ONE")],
228      aux: vec![],
229    }
230  }
231}
232
233impl<E: Engine> ConstraintSystem<E::Scalar> for TestShapeCS<E>
234where
235  E::Scalar: PrimeField,
236{
237  type Root = Self;
238
239  fn alloc<F, A, AR>(&mut self, annotation: A, _f: F) -> Result<Variable, SynthesisError>
240  where
241    F: FnOnce() -> Result<E::Scalar, SynthesisError>,
242    A: FnOnce() -> AR,
243    AR: Into<String>,
244  {
245    let path = compute_path(&self.current_namespace, &annotation().into());
246    self.aux.push(path);
247
248    Ok(Variable::new_unchecked(Index::Aux(self.aux.len() - 1)))
249  }
250
251  fn alloc_input<F, A, AR>(&mut self, annotation: A, _f: F) -> Result<Variable, SynthesisError>
252  where
253    F: FnOnce() -> Result<E::Scalar, SynthesisError>,
254    A: FnOnce() -> AR,
255    AR: Into<String>,
256  {
257    let path = compute_path(&self.current_namespace, &annotation().into());
258    self.inputs.push(path);
259
260    Ok(Variable::new_unchecked(Index::Input(self.inputs.len() - 1)))
261  }
262
263  fn enforce<A, AR, LA, LB, LC>(&mut self, annotation: A, a: LA, b: LB, c: LC)
264  where
265    A: FnOnce() -> AR,
266    AR: Into<String>,
267    LA: FnOnce(LinearCombination<E::Scalar>) -> LinearCombination<E::Scalar>,
268    LB: FnOnce(LinearCombination<E::Scalar>) -> LinearCombination<E::Scalar>,
269    LC: FnOnce(LinearCombination<E::Scalar>) -> LinearCombination<E::Scalar>,
270  {
271    let path = compute_path(&self.current_namespace, &annotation().into());
272    let index = self.constraints.len();
273    self.set_named_obj(path.clone(), NamedObject::Constraint(index));
274
275    let a = a(LinearCombination::zero());
276    let b = b(LinearCombination::zero());
277    let c = c(LinearCombination::zero());
278
279    self.constraints.push((a, b, c, path));
280  }
281
282  fn push_namespace<NR, N>(&mut self, name_fn: N)
283  where
284    NR: Into<String>,
285    N: FnOnce() -> NR,
286  {
287    let name = name_fn().into();
288    let path = compute_path(&self.current_namespace, &name);
289    self.set_named_obj(path, NamedObject::Namespace);
290    self.current_namespace.push(name);
291  }
292
293  fn pop_namespace(&mut self) {
294    assert!(self.current_namespace.pop().is_some());
295  }
296
297  fn get_root(&mut self) -> &mut Self::Root {
298    self
299  }
300}
301
/// Joins the active namespace segments and `this` into a `/`-separated path.
///
/// Namespaces `["a", "b"]` with name `"c"` produce `"a/b/c"`. With no active
/// namespace, the result is just `this`.
///
/// # Panics
///
/// Panics if `this` contains `/`.
fn compute_path(ns: &[String], this: &str) -> String {
  assert!(!this.contains('/'), "'/' is not allowed in names");

  // Each namespace segment contributes itself plus one trailing separator.
  let capacity = ns.iter().map(|segment| segment.len() + 1).sum::<usize>() + this.len();
  let mut path = String::with_capacity(capacity);
  for segment in ns {
    path.push_str(segment);
    path.push('/');
  }
  path.push_str(this);
  path
}