1use std::{
5 cmp::Ordering,
6 collections::{BTreeMap, HashMap},
7};
8
9use crate::{
10 frontend::{ConstraintSystem, Index, LinearCombination, SynthesisError, Variable},
11 traits::Engine,
12};
13use core::fmt::Write;
14use ff::{Field, PrimeField};
15
16#[derive(Clone, Copy)]
17struct OrderedVariable(Variable);
18
19#[allow(unused)]
20#[derive(Debug)]
21enum NamedObject {
22 Constraint(usize),
23 Var(Variable),
24 Namespace,
25}
26
27impl Eq for OrderedVariable {}
28impl PartialEq for OrderedVariable {
29 fn eq(&self, other: &OrderedVariable) -> bool {
30 match (self.0.get_unchecked(), other.0.get_unchecked()) {
31 (Index::Input(ref a), Index::Input(ref b)) | (Index::Aux(ref a), Index::Aux(ref b)) => a == b,
32 _ => false,
33 }
34 }
35}
36impl PartialOrd for OrderedVariable {
37 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
38 Some(self.cmp(other))
39 }
40}
41impl Ord for OrderedVariable {
42 fn cmp(&self, other: &Self) -> Ordering {
43 match (self.0.get_unchecked(), other.0.get_unchecked()) {
44 (Index::Input(ref a), Index::Input(ref b)) | (Index::Aux(ref a), Index::Aux(ref b)) => {
45 a.cmp(b)
46 }
47 (Index::Input(_), Index::Aux(_)) => Ordering::Less,
48 (Index::Aux(_), Index::Input(_)) => Ordering::Greater,
49 }
50 }
51}
52
53pub struct TestShapeCS<E: Engine> {
55 named_objects: HashMap<String, NamedObject>,
56 current_namespace: Vec<String>,
57 pub constraints: Vec<(
59 LinearCombination<E::Scalar>,
60 LinearCombination<E::Scalar>,
61 LinearCombination<E::Scalar>,
62 String,
63 )>,
64 inputs: Vec<String>,
65 aux: Vec<String>,
66}
67
68fn proc_lc<Scalar: PrimeField>(
69 terms: &LinearCombination<Scalar>,
70) -> BTreeMap<OrderedVariable, Scalar> {
71 let mut map = BTreeMap::new();
72 for (var, &coeff) in terms.iter() {
73 map
74 .entry(OrderedVariable(var))
75 .or_insert_with(|| Scalar::ZERO)
76 .add_assign(&coeff);
77 }
78
79 let mut to_remove = vec![];
81 for (var, coeff) in map.iter() {
82 if coeff.is_zero().into() {
83 to_remove.push(*var)
84 }
85 }
86
87 for var in to_remove {
88 map.remove(&var);
89 }
90
91 map
92}
93
94impl<E: Engine> TestShapeCS<E>
95where
96 E::Scalar: PrimeField,
97{
98 #[allow(unused)]
99 pub fn new() -> Self {
101 TestShapeCS::default()
102 }
103
104 pub fn num_constraints(&self) -> usize {
106 self.constraints.len()
107 }
108
109 pub fn num_inputs(&self) -> usize {
111 self.inputs.len()
112 }
113
114 pub fn num_aux(&self) -> usize {
116 self.aux.len()
117 }
118
119 #[allow(dead_code)]
121 pub fn pretty_print_list(&self) -> Vec<String> {
122 let mut result = Vec::new();
123
124 for input in &self.inputs {
125 result.push(format!("INPUT {input}"));
126 }
127 for aux in &self.aux {
128 result.push(format!("AUX {aux}"));
129 }
130
131 for (_a, _b, _c, name) in &self.constraints {
132 result.push(name.to_string());
133 }
134
135 result
136 }
137
138 #[allow(dead_code)]
140 pub fn pretty_print(&self) -> String {
141 let mut s = String::new();
142
143 for input in &self.inputs {
144 writeln!(s, "INPUT {}", &input).unwrap()
145 }
146
147 let negone = -<E::Scalar>::ONE;
148
149 let powers_of_two = (0..E::Scalar::NUM_BITS)
150 .map(|i| E::Scalar::from(2u64).pow_vartime([u64::from(i)]))
151 .collect::<Vec<_>>();
152
153 let pp = |s: &mut String, lc: &LinearCombination<E::Scalar>| {
154 s.push('(');
155 let mut is_first = true;
156 for (var, coeff) in proc_lc::<E::Scalar>(lc) {
157 if coeff == negone {
158 s.push_str(" - ")
159 } else if !is_first {
160 s.push_str(" + ")
161 }
162 is_first = false;
163
164 if coeff != <E::Scalar>::ONE && coeff != negone {
165 for (i, x) in powers_of_two.iter().enumerate() {
166 if x == &coeff {
167 write!(s, "2^{i} . ").unwrap();
168 break;
169 }
170 }
171
172 write!(s, "{coeff:?} . ").unwrap()
173 }
174
175 match var.0.get_unchecked() {
176 Index::Input(i) => {
177 write!(s, "`I{}`", &self.inputs[i]).unwrap();
178 }
179 Index::Aux(i) => {
180 write!(s, "`A{}`", &self.aux[i]).unwrap();
181 }
182 }
183 }
184 if is_first {
185 s.push('0');
187 }
188 s.push(')');
189 };
190
191 for (a, b, c, name) in &self.constraints {
192 s.push('\n');
193
194 write!(s, "{name}: ").unwrap();
195 pp(&mut s, a);
196 write!(s, " * ").unwrap();
197 pp(&mut s, b);
198 s.push_str(" = ");
199 pp(&mut s, c);
200 }
201
202 s.push('\n');
203
204 s
205 }
206
207 fn set_named_obj(&mut self, path: String, to: NamedObject) {
210 assert!(
211 !self.named_objects.contains_key(&path),
212 "tried to create object at existing path: {path}"
213 );
214
215 self.named_objects.insert(path, to);
216 }
217}
218
219impl<E: Engine> Default for TestShapeCS<E> {
220 fn default() -> Self {
221 let mut map = HashMap::new();
222 map.insert("ONE".into(), NamedObject::Var(TestShapeCS::<E>::one()));
223 TestShapeCS {
224 named_objects: map,
225 current_namespace: vec![],
226 constraints: vec![],
227 inputs: vec![String::from("ONE")],
228 aux: vec![],
229 }
230 }
231}
232
233impl<E: Engine> ConstraintSystem<E::Scalar> for TestShapeCS<E>
234where
235 E::Scalar: PrimeField,
236{
237 type Root = Self;
238
239 fn alloc<F, A, AR>(&mut self, annotation: A, _f: F) -> Result<Variable, SynthesisError>
240 where
241 F: FnOnce() -> Result<E::Scalar, SynthesisError>,
242 A: FnOnce() -> AR,
243 AR: Into<String>,
244 {
245 let path = compute_path(&self.current_namespace, &annotation().into());
246 self.aux.push(path);
247
248 Ok(Variable::new_unchecked(Index::Aux(self.aux.len() - 1)))
249 }
250
251 fn alloc_input<F, A, AR>(&mut self, annotation: A, _f: F) -> Result<Variable, SynthesisError>
252 where
253 F: FnOnce() -> Result<E::Scalar, SynthesisError>,
254 A: FnOnce() -> AR,
255 AR: Into<String>,
256 {
257 let path = compute_path(&self.current_namespace, &annotation().into());
258 self.inputs.push(path);
259
260 Ok(Variable::new_unchecked(Index::Input(self.inputs.len() - 1)))
261 }
262
263 fn enforce<A, AR, LA, LB, LC>(&mut self, annotation: A, a: LA, b: LB, c: LC)
264 where
265 A: FnOnce() -> AR,
266 AR: Into<String>,
267 LA: FnOnce(LinearCombination<E::Scalar>) -> LinearCombination<E::Scalar>,
268 LB: FnOnce(LinearCombination<E::Scalar>) -> LinearCombination<E::Scalar>,
269 LC: FnOnce(LinearCombination<E::Scalar>) -> LinearCombination<E::Scalar>,
270 {
271 let path = compute_path(&self.current_namespace, &annotation().into());
272 let index = self.constraints.len();
273 self.set_named_obj(path.clone(), NamedObject::Constraint(index));
274
275 let a = a(LinearCombination::zero());
276 let b = b(LinearCombination::zero());
277 let c = c(LinearCombination::zero());
278
279 self.constraints.push((a, b, c, path));
280 }
281
282 fn push_namespace<NR, N>(&mut self, name_fn: N)
283 where
284 NR: Into<String>,
285 N: FnOnce() -> NR,
286 {
287 let name = name_fn().into();
288 let path = compute_path(&self.current_namespace, &name);
289 self.set_named_obj(path, NamedObject::Namespace);
290 self.current_namespace.push(name);
291 }
292
293 fn pop_namespace(&mut self) {
294 assert!(self.current_namespace.pop().is_some());
295 }
296
297 fn get_root(&mut self) -> &mut Self::Root {
298 self
299 }
300}
301
/// Joins the namespace stack `ns` and the leaf name `this` with `/`.
///
/// For example, `["a", "b"]` and `"c"` give `"a/b/c"`, while an empty `ns`
/// yields `this` unchanged.
///
/// # Panics
///
/// Panics if `this` contains `/`, because `/` is the path separator.
fn compute_path(ns: &[String], this: &str) -> String {
    assert!(!this.contains('/'), "'/' is not allowed in names");

    ns.iter()
        .map(String::as_str)
        .chain(std::iter::once(this))
        .collect::<Vec<&str>>()
        .join("/")
}