rings_snark/circuit/
mod.rs

1//! Implementation of Circuit
2//! ==========================
3use std::cell::RefCell;
4use std::iter::Iterator;
5use std::ops::Deref;
6use std::ops::DerefMut;
7use std::rc::Rc;
8use std::sync::Arc;
9
10use bellpepper_core::num::AllocatedNum;
11use bellpepper_core::ConstraintSystem;
12use bellpepper_core::LinearCombination;
13use bellpepper_core::SynthesisError;
14use ff::PrimeField;
15use nova_snark::traits::circuit::StepCircuit;
16use serde::Deserialize;
17use serde::Serialize;
18
19use crate::error::Result;
20use crate::r1cs::R1CS;
21use crate::witness::calculator::WitnessCalculator;
22
23pub mod bellman;
24pub mod bellpepper;
25
26/// Input of witness
27#[derive(Serialize, Deserialize, Clone)]
28pub struct Input<F: PrimeField> {
29    /// inner input
30    pub input: Vec<(String, Vec<F>)>,
31}
32
33impl<F: PrimeField> AsRef<Input<F>> for Input<F> {
34    fn as_ref(&self) -> &Self {
35        self
36    }
37}
38
39impl<F: PrimeField> Deref for Input<F> {
40    type Target = Vec<(String, Vec<F>)>;
41    fn deref(&self) -> &Self::Target {
42        &self.input
43    }
44}
45
46impl<F: PrimeField> DerefMut for Input<F> {
47    fn deref_mut(&mut self) -> &mut Self::Target {
48        &mut self.input
49    }
50}
51
52impl<F: PrimeField> Input<F> {
53    /// flat input
54    pub fn flat(&self) -> Vec<F> {
55        self.input
56            .clone()
57            .into_iter()
58            .flat_map(|(_, v)| v)
59            .collect()
60    }
61
62    /// Get flat length of input
63    #[allow(clippy::len_without_is_empty)]
64    pub fn len(&self) -> usize {
65        self.input
66            .iter()
67            .flat_map(|(_, v)| v)
68            .collect::<Vec<&F>>()
69            .len()
70    }
71}
72
73impl<F: PrimeField> IntoIterator for Input<F> {
74    type Item = (String, Vec<F>);
75    type IntoIter = <Vec<Self::Item> as IntoIterator>::IntoIter;
76    fn into_iter(self) -> Self::IntoIter {
77        self.input.into_iter()
78    }
79}
80
81impl<'a, F: PrimeField> IntoIterator for &'a Input<F> {
82    type Item = <&'a Vec<(String, Vec<F>)> as IntoIterator>::Item;
83    type IntoIter = <&'a Vec<(String, Vec<F>)> as IntoIterator>::IntoIter;
84
85    fn into_iter(self) -> Self::IntoIter {
86        self.input.iter()
87    }
88}
89
90impl<F: PrimeField> From<Vec<(String, Vec<F>)>> for Input<F> {
91    fn from(input: Vec<(String, Vec<F>)>) -> Self {
92        Self { input }
93    }
94}
95
96/// Circuit
97#[derive(Serialize, Deserialize, Clone, Debug)]
98pub struct Circuit<F: PrimeField> {
99    r1cs: Arc<R1CS<F>>,
100    witness: Vec<F>,
101}
102
103impl<F: PrimeField> AsRef<Circuit<F>> for &Circuit<F> {
104    fn as_ref(&self) -> &Circuit<F> {
105        self
106    }
107}
108
109/// Wasm based circuit generator
110pub struct WasmCircuitGenerator<F: PrimeField> {
111    r1cs: Arc<R1CS<F>>,
112    calculator: Rc<RefCell<WitnessCalculator>>,
113}
114
115impl<F: PrimeField> WasmCircuitGenerator<F> {
116    /// Crate new instance
117    pub fn new(r1cs: R1CS<F>, calculator: WitnessCalculator) -> Self {
118        Self {
119            r1cs: Arc::new(r1cs),
120            calculator: Rc::new(RefCell::new(calculator)),
121        }
122    }
123
124    /// Generate iterator circuit list
125    /// Which iterate inputs and generate circuit
126    pub fn gen_circuit(&self, input: Input<F>, sanity_check: bool) -> Result<Circuit<F>>
127    where F: PrimeField {
128        let mut calc = self.calculator.borrow_mut();
129        let witness: Vec<F> = calc.calculate_witness::<F>(input.to_vec(), sanity_check)?;
130        let circom = Circuit::<F> {
131            r1cs: self.r1cs.clone(),
132            witness,
133        };
134        Ok(circom)
135    }
136
137    /// Generate recursive circuit list
138    /// Which use $output_{i-1}$ as $input_i$
139    pub fn gen_recursive_circuit(
140        &self,
141        public_input: Input<F>,
142        private_inputs: Vec<Input<F>>,
143        times: usize,
144        sanity_check: bool,
145    ) -> Result<Vec<Circuit<F>>>
146    where
147        F: PrimeField,
148    {
149        fn reshape<F: PrimeField>(input: &[(String, Vec<F>)], output: &[F]) -> Input<F> {
150            let mut ret = vec![];
151            let mut iter = output.iter();
152
153            for (val, vec) in input.iter() {
154                let size = vec.len();
155                let mut new_vec: Vec<F> = Vec::with_capacity(size);
156                for _ in 0..size {
157                    if let Some(item) = iter.next() {
158                        new_vec.push(*item);
159                    } else {
160                        panic!(
161                            "Failed on reshape output {:?} as input format {:?}",
162                            output, input
163                        )
164                    }
165                }
166                ret.push((val.clone(), new_vec));
167            }
168            ret.into()
169        }
170
171        let mut ret = vec![];
172        let mut calc = self.calculator.borrow_mut();
173        let mut latest_output: Input<F> = vec![].into();
174        for i in 0..times {
175            let witness: Vec<F> = if latest_output.is_empty() {
176                let mut input = public_input.clone();
177                if let Some(p) = private_inputs.get(i) {
178                    input.input.extend(p.to_owned());
179                }
180                calc.calculate_witness::<F>(input.to_vec(), sanity_check)?
181            } else {
182                let mut input = latest_output.clone();
183                if let Some(p) = private_inputs.get(i) {
184                    input.input.extend(p.to_owned());
185                }
186                calc.calculate_witness::<F>(input.to_vec(), sanity_check)?
187            };
188            let circom = Circuit::<F> {
189                r1cs: self.r1cs.clone(),
190                witness: witness.clone(),
191            };
192            log::trace!("witness: {:?}, r1cs: {:?}", witness, self.r1cs);
193            latest_output = reshape(&public_input, &circom.get_public_outputs());
194            ret.push(circom);
195        }
196        Ok(ret)
197    }
198}
199
200impl<F: PrimeField> Circuit<F> {
201    /// Create a new instance
202    pub fn new(r1cs: Arc<R1CS<F>>, witness: Vec<F>) -> Self {
203        Self { r1cs, witness }
204    }
205
206    /// get public outputs from witness
207    pub fn get_public_outputs(&self) -> Vec<F> {
208        // witness: <1> <Outputs> <Inputs> <Auxs>
209        // NOTE: assumes exactly half of the (public inputs + outputs) are outputs
210        let output_count = (self.r1cs.num_inputs - 1) / 2;
211        self.witness[1..output_count + 1].to_vec()
212    }
213
214    /// get public inputs from witness
215    pub fn get_public_inputs(&self) -> Vec<F> {
216        // witness: <1> <Outputs> <Inputs> <Auxs>
217        // NOTE: assumes exactly half of the (public inputs + outputs) are outputs
218        let output_count = (self.r1cs.num_inputs - 1) / 2;
219        self.witness[1 + output_count..self.r1cs.num_inputs].to_vec()
220    }
221}
222
223/// Implement StepCircuit for our Circuit
224/// Reference work: Nota-Scotia :: CircomCircuit
225/// `<https://github.com/nalinbhardwaj/Nova-Scotia/blob/main/src/circom/circuit.rs>`
226/// NOTE: assumes exactly half of the (public inputs + outputs) are outputs
227impl<F: PrimeField> StepCircuit<F> for Circuit<F> {
228    fn arity(&self) -> usize {
229        (self.r1cs.num_inputs - 1) / 2
230    }
231
232    /// Simple synthesize
233    fn synthesize<CS: ConstraintSystem<F>>(
234        &self,
235        cs: &mut CS,
236        z: &[AllocatedNum<F>],
237    ) -> core::result::Result<Vec<AllocatedNum<F>>, SynthesisError> {
238        let mut vars: Vec<AllocatedNum<F>> = vec![];
239        let mut z_out: Vec<AllocatedNum<F>> = vec![];
240        let pub_output_count = (self.r1cs.num_inputs - 1) / 2;
241
242        for i in 1..self.r1cs.num_inputs {
243            // Public inputs do not exist, so we alloc, and later enforce equality from z values
244            let f: F = self.witness[i];
245            let v = AllocatedNum::alloc(cs.namespace(|| format!("public_{}", i)), || Ok(f))?;
246
247            vars.push(v.clone());
248            if i <= pub_output_count {
249                // public output
250                z_out.push(v);
251            }
252        }
253        for i in 0..self.r1cs.num_aux {
254            // Private witness trace
255            let f: F = self.witness[i + self.r1cs.num_inputs];
256            let v = AllocatedNum::alloc(cs.namespace(|| format!("aux_{}", i)), || Ok(f))?;
257            vars.push(v);
258        }
259
260        let make_lc = |lc_data: Vec<(usize, F)>| {
261            let res = lc_data.iter().fold(
262                LinearCombination::<F>::zero(),
263                |lc: LinearCombination<F>, (index, coeff)| {
264                    lc + if *index > 0_usize {
265                        (*coeff, vars[*index - 1].get_variable())
266                    } else {
267                        (*coeff, CS::one())
268                    }
269                },
270            );
271            res
272        };
273        for (i, constraint) in self.r1cs.constraints.iter().enumerate() {
274            cs.enforce(
275                || format!("constraint {}", i),
276                |_| make_lc(constraint.0.clone()),
277                |_| make_lc(constraint.1.clone()),
278                |_| make_lc(constraint.2.clone()),
279            );
280        }
281
282        for i in (pub_output_count + 1)..self.r1cs.num_inputs {
283            cs.enforce(
284                || format!("pub input enforce {}", i),
285                |lc| lc + z[i - 1 - pub_output_count].get_variable(),
286                |lc| lc + CS::one(),
287                |lc| lc + vars[i - 1].get_variable(),
288            );
289        }
290
291        Ok(z_out)
292    }
293}