use std::cell::RefCell;
use std::iter::Iterator;
use std::ops::Deref;
use std::ops::DerefMut;
use std::rc::Rc;
use std::sync::Arc;
use bellpepper_core::num::AllocatedNum;
use bellpepper_core::ConstraintSystem;
use bellpepper_core::LinearCombination;
use bellpepper_core::SynthesisError;
use ff::PrimeField;
use nova_snark::traits::circuit::StepCircuit;
use serde::Deserialize;
use serde::Serialize;
use crate::error::Result;
use crate::r1cs::R1CS;
use crate::witness::calculator::WitnessCalculator;
// Public submodules declared by this file; their contents live in separate
// files that are not visible from here.
// NOTE(review): the names suggest backend-specific glue for the bellman and
// bellpepper constraint-system crates — confirm against the module sources.
pub mod bellman;
pub mod bellpepper;
/// Keyed input values that are handed to the witness calculator.
///
/// Each entry pairs a `String` key with the field elements assigned to it.
/// NOTE(review): the key is presumably the circuit's signal name — confirm
/// against `WitnessCalculator::calculate_witness`.
#[derive(Serialize, Deserialize, Clone)]
pub struct Input<F: PrimeField> {
// Entries in insertion order; `flat` and `len` walk them in this order.
pub input: Vec<(String, Vec<F>)>,
}
/// Identity conversion: an `Input<F>` can be borrowed as itself, so it can be
/// passed wherever an `AsRef<Input<F>>` is accepted.
impl<F: PrimeField> AsRef<Input<F>> for Input<F> {
    fn as_ref(&self) -> &Input<F> {
        self
    }
}
impl<F: PrimeField> Deref for Input<F> {
type Target = Vec<(String, Vec<F>)>;
fn deref(&self) -> &Self::Target {
&self.input
}
}
impl<F: PrimeField> DerefMut for Input<F> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.input
}
}
impl<F: PrimeField> Input<F> {
    /// Concatenates the values of every entry, in entry order, into one vector.
    ///
    /// The keys are ignored; only the field elements are copied out.
    pub fn flat(&self) -> Vec<F> {
        // Copy elements straight out of the borrowed entries rather than
        // deep-cloning the whole list (keys included) just to consume it.
        self.input
            .iter()
            .flat_map(|(_, values)| values.iter().copied())
            .collect()
    }

    /// Returns the total number of field elements across all entries.
    ///
    /// As an inherent method this takes precedence over the `Vec::len` that is
    /// reachable through `Deref`, so it counts elements rather than entries.
    // No matching `is_empty` is defined on purpose: callers reach
    // `Vec::is_empty` (an entry-count check) through `Deref`.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        // Sum the per-entry lengths; no intermediate collection is needed.
        self.input.iter().map(|(_, values)| values.len()).sum()
    }
}
impl<F: PrimeField> IntoIterator for Input<F> {
type Item = (String, Vec<F>);
type IntoIter = <Vec<Self::Item> as IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
self.input.into_iter()
}
}
/// Borrowing iteration: yields a reference to each `(key, values)` entry in
/// order, leaving the `Input` intact.
impl<'a, F: PrimeField> IntoIterator for &'a Input<F> {
    type Item = &'a (String, Vec<F>);
    type IntoIter = std::slice::Iter<'a, (String, Vec<F>)>;

    fn into_iter(self) -> Self::IntoIter {
        let entries: &'a Vec<(String, Vec<F>)> = &self.input;
        entries.iter()
    }
}
impl<F: PrimeField> From<Vec<(String, Vec<F>)>> for Input<F> {
fn from(input: Vec<(String, Vec<F>)>) -> Self {
Self { input }
}
}
/// An R1CS instance bundled with one witness assignment for it.
///
/// The code in this file indexes `witness` as follows: index 0 is skipped
/// (linear-combination index 0 maps to the constant-one variable), indices
/// `1..num_inputs` are allocated as the "public" wires, and the following
/// `num_aux` entries as auxiliary wires.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Circuit<F: PrimeField> {
// Shared constraint system; `Arc` lets many circuits reference one copy.
r1cs: Arc<R1CS<F>>,
// Witness values, indexed as described on the struct.
witness: Vec<F>,
}
/// Lets a `&Circuit<F>` be passed wherever an `AsRef<Circuit<F>>` is accepted.
impl<F: PrimeField> AsRef<Circuit<F>> for &Circuit<F> {
    fn as_ref(&self) -> &Circuit<F> {
        // `self` is a reference to the reference; peel one layer explicitly.
        *self
    }
}
/// Produces `Circuit`s by running a `WitnessCalculator` against a fixed R1CS.
///
/// NOTE(review): the name suggests the calculator is WASM-backed — confirm
/// against `crate::witness::calculator`.
pub struct WasmCircuitGenerator<F: PrimeField> {
// Constraint system shared (by `Arc` clone) with every generated circuit.
r1cs: Arc<R1CS<F>>,
// Wrapped in `RefCell` because the generator methods take `&self` but need
// mutable access to the calculator (`borrow_mut`).
calculator: Rc<RefCell<WitnessCalculator>>,
}
impl<F: PrimeField> WasmCircuitGenerator<F> {
    /// Wraps `r1cs` and `calculator` so they can be reused for every circuit
    /// this generator produces.
    pub fn new(r1cs: R1CS<F>, calculator: WitnessCalculator) -> Self {
        Self {
            r1cs: Arc::new(r1cs),
            calculator: Rc::new(RefCell::new(calculator)),
        }
    }

    /// Computes the witness for `input` and pairs it with the shared R1CS.
    ///
    /// `sanity_check` is forwarded unchanged to the witness calculator.
    ///
    /// # Errors
    /// Propagates any error returned by the witness calculator.
    ///
    /// # Panics
    /// Panics if the calculator is already borrowed (`RefCell::borrow_mut`).
    pub fn gen_circuit(&self, input: Input<F>, sanity_check: bool) -> Result<Circuit<F>> {
        let mut calc = self.calculator.borrow_mut();
        // `input` is owned, so hand its entries over instead of cloning them.
        let witness: Vec<F> = calc.calculate_witness::<F>(input.input, sanity_check)?;
        Ok(Circuit::<F> {
            r1cs: Arc::clone(&self.r1cs),
            witness,
        })
    }

    /// Generates `times` chained circuits, feeding each step's public outputs
    /// back in as the next step's public input.
    ///
    /// Step 0 starts from `public_input`. After every step the circuit's public
    /// outputs are regrouped to match the keys and per-key lengths of
    /// `public_input` and become the public part of the next step's input.
    /// For step `i`, the entries of `private_inputs[i]` (when present) are
    /// appended after the public entries; steps without a matching element
    /// run with the public entries only.
    ///
    /// `sanity_check` is forwarded unchanged to the witness calculator.
    ///
    /// # Errors
    /// Propagates the first error returned by the witness calculator.
    ///
    /// # Panics
    /// Panics if a step (including the last one) yields fewer public outputs
    /// than `public_input` has values, if the calculator is already borrowed,
    /// or if a witness is too short for `Circuit::get_public_outputs`.
    pub fn gen_recursive_circuit(
        &self,
        public_input: Input<F>,
        private_inputs: Vec<Input<F>>,
        times: usize,
        sanity_check: bool,
    ) -> Result<Vec<Circuit<F>>> {
        /// Regroups the flat `output` values under the keys of `input`, giving
        /// each key as many values as it has in `input`. Surplus values in
        /// `output` are ignored; running out of values panics.
        fn reshape<F: PrimeField>(input: &[(String, Vec<F>)], output: &[F]) -> Input<F> {
            let mut remaining = output.iter().copied();
            let entries: Vec<(String, Vec<F>)> = input
                .iter()
                .map(|(name, template)| {
                    let values: Vec<F> = remaining.by_ref().take(template.len()).collect();
                    if values.len() != template.len() {
                        panic!(
                            "Failed on reshape output {:?} as input format {:?}",
                            output, input
                        )
                    }
                    (name.clone(), values)
                })
                .collect();
            entries.into()
        }

        let mut circuits = Vec::new();
        let mut calc = self.calculator.borrow_mut();
        // Public entries for the upcoming step: the caller's input for step 0,
        // afterwards the previous step's reshaped outputs. `public_input`
        // itself is kept as the template for `reshape`.
        let mut step_input = public_input.clone();
        for step in 0..times {
            let mut input = step_input;
            if let Some(private) = private_inputs.get(step) {
                input.input.extend(private.iter().cloned());
            }
            let witness: Vec<F> = calc.calculate_witness::<F>(input.input, sanity_check)?;
            // Log before the witness is moved into the circuit; this avoids
            // cloning it just for the trace line.
            log::trace!("witness: {:?}, r1cs: {:?}", witness, self.r1cs);
            let circom = Circuit::<F> {
                r1cs: Arc::clone(&self.r1cs),
                witness,
            };
            step_input = reshape(&public_input, &circom.get_public_outputs());
            circuits.push(circom);
        }
        Ok(circuits)
    }
}
impl<F: PrimeField> Circuit<F> {
pub fn new(r1cs: Arc<R1CS<F>>, witness: Vec<F>) -> Self {
Self { r1cs, witness }
}
pub fn get_public_outputs(&self) -> Vec<F> {
let output_count = (self.r1cs.num_inputs - 1) / 2;
self.witness[1..output_count + 1].to_vec()
}
pub fn get_public_inputs(&self) -> Vec<F> {
let output_count = (self.r1cs.num_inputs - 1) / 2;
self.witness[1 + output_count..self.r1cs.num_inputs].to_vec()
}
}
impl<F: PrimeField> StepCircuit<F> for Circuit<F> {
    /// Number of state elements carried between steps:
    /// `(num_inputs - 1) / 2`, i.e. half of the non-constant public wires.
    fn arity(&self) -> usize {
        (self.r1cs.num_inputs - 1) / 2
    }

    /// Replays the stored R1CS and witness into `cs` and returns the allocated
    /// public outputs as the next state.
    ///
    /// Steps, in order:
    /// 1. Allocate witness entries `1..num_inputs` as `public_{i}`; the first
    ///    `arity` of them are also collected as the returned outputs.
    /// 2. Allocate the following `num_aux` witness entries as `aux_{i}`.
    /// 3. Enforce every R1CS constraint. Wire index 0 maps to `CS::one()`,
    ///    wire index `k > 0` to the `k`-th allocated variable.
    /// 4. Tie each incoming state element `z[j]` to the matching allocated
    ///    public-input wire via `z[j] * 1 = wire`.
    ///
    /// # Errors
    /// Returns any `SynthesisError` raised while allocating variables.
    ///
    /// # Panics
    /// Panics if the witness is shorter than `num_inputs + num_aux`, if a
    /// constraint references a wire index that was not allocated, or if `z`
    /// has fewer elements than there are public-input wires.
    fn synthesize<CS: ConstraintSystem<F>>(
        &self,
        cs: &mut CS,
        z: &[AllocatedNum<F>],
    ) -> core::result::Result<Vec<AllocatedNum<F>>, SynthesisError> {
        let num_inputs = self.r1cs.num_inputs;
        let pub_output_count = (num_inputs - 1) / 2;
        let mut vars: Vec<AllocatedNum<F>> = vec![];
        let mut z_out: Vec<AllocatedNum<F>> = vec![];

        // Wire 0 is the constant one and is never allocated, so start at 1.
        for i in 1..num_inputs {
            let f: F = self.witness[i];
            let v = AllocatedNum::alloc(cs.namespace(|| format!("public_{}", i)), || Ok(f))?;
            // Outputs occupy the first `pub_output_count` public wires.
            if i <= pub_output_count {
                z_out.push(v.clone());
            }
            vars.push(v);
        }

        for i in 0..self.r1cs.num_aux {
            let f: F = self.witness[i + num_inputs];
            let v = AllocatedNum::alloc(cs.namespace(|| format!("aux_{}", i)), || Ok(f))?;
            vars.push(v);
        }

        // Builds a linear combination from `(wire index, coefficient)` terms.
        // Borrowing the terms avoids cloning each constraint's vectors.
        let make_lc = |terms: &[(usize, F)]| {
            terms
                .iter()
                .fold(LinearCombination::<F>::zero(), |lc, &(index, coeff)| {
                    let variable = if index > 0 {
                        // `vars` holds wires 1.., hence the shift by one.
                        vars[index - 1].get_variable()
                    } else {
                        CS::one()
                    };
                    lc + (coeff, variable)
                })
        };

        for (i, constraint) in self.r1cs.constraints.iter().enumerate() {
            cs.enforce(
                || format!("constraint {}", i),
                |_| make_lc(&constraint.0),
                |_| make_lc(&constraint.1),
                |_| make_lc(&constraint.2),
            );
        }

        // Bind the incoming state to the public-input wires, which follow the
        // output wires in `vars`.
        for i in (pub_output_count + 1)..num_inputs {
            cs.enforce(
                || format!("pub input enforce {}", i),
                |lc| lc + z[i - 1 - pub_output_count].get_variable(),
                |lc| lc + CS::one(),
                |lc| lc + vars[i - 1].get_variable(),
            );
        }

        Ok(z_out)
    }
}