rings_snark/circuit/
mod.rs1use std::cell::RefCell;
4use std::iter::Iterator;
5use std::ops::Deref;
6use std::ops::DerefMut;
7use std::rc::Rc;
8use std::sync::Arc;
9
10use bellpepper_core::num::AllocatedNum;
11use bellpepper_core::ConstraintSystem;
12use bellpepper_core::LinearCombination;
13use bellpepper_core::SynthesisError;
14use ff::PrimeField;
15use nova_snark::traits::circuit::StepCircuit;
16use serde::Deserialize;
17use serde::Serialize;
18
19use crate::error::Result;
20use crate::r1cs::R1CS;
21use crate::witness::calculator::WitnessCalculator;
22
23pub mod bellman;
24pub mod bellpepper;
25
26#[derive(Serialize, Deserialize, Clone)]
28pub struct Input<F: PrimeField> {
29 pub input: Vec<(String, Vec<F>)>,
31}
32
33impl<F: PrimeField> AsRef<Input<F>> for Input<F> {
34 fn as_ref(&self) -> &Self {
35 self
36 }
37}
38
39impl<F: PrimeField> Deref for Input<F> {
40 type Target = Vec<(String, Vec<F>)>;
41 fn deref(&self) -> &Self::Target {
42 &self.input
43 }
44}
45
46impl<F: PrimeField> DerefMut for Input<F> {
47 fn deref_mut(&mut self) -> &mut Self::Target {
48 &mut self.input
49 }
50}
51
52impl<F: PrimeField> Input<F> {
53 pub fn flat(&self) -> Vec<F> {
55 self.input
56 .clone()
57 .into_iter()
58 .flat_map(|(_, v)| v)
59 .collect()
60 }
61
62 #[allow(clippy::len_without_is_empty)]
64 pub fn len(&self) -> usize {
65 self.input
66 .iter()
67 .flat_map(|(_, v)| v)
68 .collect::<Vec<&F>>()
69 .len()
70 }
71}
72
73impl<F: PrimeField> IntoIterator for Input<F> {
74 type Item = (String, Vec<F>);
75 type IntoIter = <Vec<Self::Item> as IntoIterator>::IntoIter;
76 fn into_iter(self) -> Self::IntoIter {
77 self.input.into_iter()
78 }
79}
80
81impl<'a, F: PrimeField> IntoIterator for &'a Input<F> {
82 type Item = <&'a Vec<(String, Vec<F>)> as IntoIterator>::Item;
83 type IntoIter = <&'a Vec<(String, Vec<F>)> as IntoIterator>::IntoIter;
84
85 fn into_iter(self) -> Self::IntoIter {
86 self.input.iter()
87 }
88}
89
90impl<F: PrimeField> From<Vec<(String, Vec<F>)>> for Input<F> {
91 fn from(input: Vec<(String, Vec<F>)>) -> Self {
92 Self { input }
93 }
94}
95
96#[derive(Serialize, Deserialize, Clone, Debug)]
98pub struct Circuit<F: PrimeField> {
99 r1cs: Arc<R1CS<F>>,
100 witness: Vec<F>,
101}
102
103impl<F: PrimeField> AsRef<Circuit<F>> for &Circuit<F> {
104 fn as_ref(&self) -> &Circuit<F> {
105 self
106 }
107}
108
109pub struct WasmCircuitGenerator<F: PrimeField> {
111 r1cs: Arc<R1CS<F>>,
112 calculator: Rc<RefCell<WitnessCalculator>>,
113}
114
115impl<F: PrimeField> WasmCircuitGenerator<F> {
116 pub fn new(r1cs: R1CS<F>, calculator: WitnessCalculator) -> Self {
118 Self {
119 r1cs: Arc::new(r1cs),
120 calculator: Rc::new(RefCell::new(calculator)),
121 }
122 }
123
124 pub fn gen_circuit(&self, input: Input<F>, sanity_check: bool) -> Result<Circuit<F>>
127 where F: PrimeField {
128 let mut calc = self.calculator.borrow_mut();
129 let witness: Vec<F> = calc.calculate_witness::<F>(input.to_vec(), sanity_check)?;
130 let circom = Circuit::<F> {
131 r1cs: self.r1cs.clone(),
132 witness,
133 };
134 Ok(circom)
135 }
136
137 pub fn gen_recursive_circuit(
140 &self,
141 public_input: Input<F>,
142 private_inputs: Vec<Input<F>>,
143 times: usize,
144 sanity_check: bool,
145 ) -> Result<Vec<Circuit<F>>>
146 where
147 F: PrimeField,
148 {
149 fn reshape<F: PrimeField>(input: &[(String, Vec<F>)], output: &[F]) -> Input<F> {
150 let mut ret = vec![];
151 let mut iter = output.iter();
152
153 for (val, vec) in input.iter() {
154 let size = vec.len();
155 let mut new_vec: Vec<F> = Vec::with_capacity(size);
156 for _ in 0..size {
157 if let Some(item) = iter.next() {
158 new_vec.push(*item);
159 } else {
160 panic!(
161 "Failed on reshape output {:?} as input format {:?}",
162 output, input
163 )
164 }
165 }
166 ret.push((val.clone(), new_vec));
167 }
168 ret.into()
169 }
170
171 let mut ret = vec![];
172 let mut calc = self.calculator.borrow_mut();
173 let mut latest_output: Input<F> = vec![].into();
174 for i in 0..times {
175 let witness: Vec<F> = if latest_output.is_empty() {
176 let mut input = public_input.clone();
177 if let Some(p) = private_inputs.get(i) {
178 input.input.extend(p.to_owned());
179 }
180 calc.calculate_witness::<F>(input.to_vec(), sanity_check)?
181 } else {
182 let mut input = latest_output.clone();
183 if let Some(p) = private_inputs.get(i) {
184 input.input.extend(p.to_owned());
185 }
186 calc.calculate_witness::<F>(input.to_vec(), sanity_check)?
187 };
188 let circom = Circuit::<F> {
189 r1cs: self.r1cs.clone(),
190 witness: witness.clone(),
191 };
192 log::trace!("witness: {:?}, r1cs: {:?}", witness, self.r1cs);
193 latest_output = reshape(&public_input, &circom.get_public_outputs());
194 ret.push(circom);
195 }
196 Ok(ret)
197 }
198}
199
200impl<F: PrimeField> Circuit<F> {
201 pub fn new(r1cs: Arc<R1CS<F>>, witness: Vec<F>) -> Self {
203 Self { r1cs, witness }
204 }
205
206 pub fn get_public_outputs(&self) -> Vec<F> {
208 let output_count = (self.r1cs.num_inputs - 1) / 2;
211 self.witness[1..output_count + 1].to_vec()
212 }
213
214 pub fn get_public_inputs(&self) -> Vec<F> {
216 let output_count = (self.r1cs.num_inputs - 1) / 2;
219 self.witness[1 + output_count..self.r1cs.num_inputs].to_vec()
220 }
221}
222
223impl<F: PrimeField> StepCircuit<F> for Circuit<F> {
228 fn arity(&self) -> usize {
229 (self.r1cs.num_inputs - 1) / 2
230 }
231
232 fn synthesize<CS: ConstraintSystem<F>>(
234 &self,
235 cs: &mut CS,
236 z: &[AllocatedNum<F>],
237 ) -> core::result::Result<Vec<AllocatedNum<F>>, SynthesisError> {
238 let mut vars: Vec<AllocatedNum<F>> = vec![];
239 let mut z_out: Vec<AllocatedNum<F>> = vec![];
240 let pub_output_count = (self.r1cs.num_inputs - 1) / 2;
241
242 for i in 1..self.r1cs.num_inputs {
243 let f: F = self.witness[i];
245 let v = AllocatedNum::alloc(cs.namespace(|| format!("public_{}", i)), || Ok(f))?;
246
247 vars.push(v.clone());
248 if i <= pub_output_count {
249 z_out.push(v);
251 }
252 }
253 for i in 0..self.r1cs.num_aux {
254 let f: F = self.witness[i + self.r1cs.num_inputs];
256 let v = AllocatedNum::alloc(cs.namespace(|| format!("aux_{}", i)), || Ok(f))?;
257 vars.push(v);
258 }
259
260 let make_lc = |lc_data: Vec<(usize, F)>| {
261 let res = lc_data.iter().fold(
262 LinearCombination::<F>::zero(),
263 |lc: LinearCombination<F>, (index, coeff)| {
264 lc + if *index > 0_usize {
265 (*coeff, vars[*index - 1].get_variable())
266 } else {
267 (*coeff, CS::one())
268 }
269 },
270 );
271 res
272 };
273 for (i, constraint) in self.r1cs.constraints.iter().enumerate() {
274 cs.enforce(
275 || format!("constraint {}", i),
276 |_| make_lc(constraint.0.clone()),
277 |_| make_lc(constraint.1.clone()),
278 |_| make_lc(constraint.2.clone()),
279 );
280 }
281
282 for i in (pub_output_count + 1)..self.r1cs.num_inputs {
283 cs.enforce(
284 || format!("pub input enforce {}", i),
285 |lc| lc + z[i - 1 - pub_output_count].get_variable(),
286 |lc| lc + CS::one(),
287 |lc| lc + vars[i - 1].get_variable(),
288 );
289 }
290
291 Ok(z_out)
292 }
293}