1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
use std::collections::HashSet;
use std::iter::FromIterator;

use bellman::{Circuit, ConstraintSystem, LinearCombination, SynthesisError, Variable};
use bellman::Index::{Aux, Input};
use pairing::Engine;
use r1cs::{Constraint, Element, Expression, Field, Gadget, Wire};

/// Adapts an r1cs [`Gadget`] so that it can be synthesized as a bellman [`Circuit`].
///
/// * `gadget` — the gadget whose constraints are replayed into bellman's constraint system.
/// * `public_inputs` — wires that should be exposed to bellman as public (`Input`) variables.
/// * `convert_field` — maps an r1cs field element onto the engine's scalar field `E::Fr`.
pub struct WrappedCircuit<F, E>
where
    F: Field,
    E: Engine,
{
    gadget: Gadget<F>,
    public_inputs: Vec<Wire>,
    convert_field: fn(&Element<F>) -> E::Fr,
}

impl<F: Field, E: Engine> Circuit<E> for WrappedCircuit<F, E> {
    fn synthesize<CS: ConstraintSystem<E>>(self, cs: &mut CS) -> Result<(), SynthesisError> {
        let WrappedCircuit { gadget, public_inputs, convert_field } = self;
        let public_inputs = HashSet::from_iter(public_inputs);
        for constraint in gadget.constraints {
            let Constraint { a, b, c } = constraint;
            cs.enforce(
                || "generated by r1cs-bellman",
                |_lc| convert_lc(a, convert_field, &public_inputs),
                |_lc| convert_lc(b, convert_field, &public_inputs),
                |_lc| convert_lc(c, convert_field, &public_inputs),
            );
        }
        Ok(())
    }
}

/// Translates an r1cs [`Expression`] into the equivalent bellman [`LinearCombination`].
///
/// Every `(wire, coefficient)` term of `exp` contributes `convert_field(coefficient) * variable`,
/// where the variable for the wire is chosen by `convert_wire` using `public_inputs`.
fn convert_lc<F: Field, E: Engine>(
    exp: Expression<F>,
    convert_field: fn(&Element<F>) -> E::Fr,
    public_inputs: &HashSet<Wire>,
) -> LinearCombination<E> {
    // bellman offers no constructor that accepts a complete variable -> coefficient map, so the
    // combination has to be accumulated one `(coefficient, variable)` term at a time.
    exp.coefficients()
        .into_iter()
        .fold(LinearCombination::<E>::zero(), |acc, (wire, coeff)| {
            acc + (convert_field(coeff), convert_wire(*wire, public_inputs))
        })
}

fn convert_wire(wire: Wire, public_inputs: &HashSet<Wire>) -> Variable {
    let wire_index = wire.index as usize;
    let index = if public_inputs.contains(&wire) {
        Input(wire_index)
    } else {
        Aux(wire_index)
    };
    Variable::new_unchecked(index)
}

// End-to-end test: wraps a tiny r1cs gadget over BLS12-381, then runs Groth16 parameter
// generation, proving, proof (de)serialization and verification through bellman.
#[cfg(test)]
mod tests {
    use bellman::groth16::{create_random_proof, generate_random_parameters, prepare_verifying_key, Proof, verify_proof};
    use num::{BigUint, Integer, One, ToPrimitive};
    use pairing::bls12_381::{Bls12, FrRepr};
    use pairing::PrimeField;
    use r1cs::{Bls12_381, Element, Gadget, GadgetBuilder};
    use rand::thread_rng;

    use crate::WrappedCircuit;

    /// Runs the full Groth16 pipeline on the circuit from `build_circuit` and expects the proof
    /// to verify against the single public input `42`.
    #[test]
    fn random_proof() {
        let rng = &mut thread_rng();

        // Generate random parameters.
        // `synthesize` takes the circuit by value, so a fresh circuit is built for each phase.
        let circuit = build_circuit();
        let params = generate_random_parameters::<Bls12, _, _>(circuit, rng).unwrap();
        let pvk = prepare_verifying_key(&params.vk);

        // Generate a random proof.
        let circuit = build_circuit();
        let proof = create_random_proof(circuit, &params, rng).unwrap();

        // Serialize and deserialize the proof.
        let mut proof_out = vec![];
        proof.write(&mut proof_out).unwrap();
        let proof = Proof::read(&proof_out[..]).unwrap();

        // Verify the proof.
        // NOTE(review): the value 42 is never assigned to wire `x` in this test (`WrappedCircuit`
        // carries no witness values), so it is unclear how the proof commits to, or the verifying
        // key expects, this public input — confirm the assertion below actually holds.
        let public_inputs = &[convert_bls12_381(&Element::from(42u8))];
        assert!(verify_proof(&pvk, &proof, public_inputs).unwrap());
    }

    /// Builds a circuit with a single public wire `x` and no constraints.
    fn build_circuit() -> WrappedCircuit<r1cs::Bls12_381, pairing::bls12_381::Bls12> {
        let mut builder = GadgetBuilder::<Bls12_381>::new();
        let x = builder.wire();
        let gadget = builder.build();

        WrappedCircuit {
            gadget,
            public_inputs: vec![x],
            convert_field: convert_bls12_381,
        }
    }

    /// Converts an r1cs BLS12-381 element into pairing's `Fr`.
    ///
    /// The element's integer value is split into four 64-bit limbs, least-significant first, and
    /// then passed to `Fr::from_repr`. Bits at or above 2^256 are dropped; presumably r1cs elements
    /// are always reduced below the modulus, so this does not matter — confirm.
    ///
    /// # Panics
    /// Panics if `from_repr` rejects the limbs (presumably when the value is not a canonical field
    /// element). The `to_u64().unwrap()` calls cannot fail because each limb is reduced mod 2^64.
    fn convert_bls12_381(n: &Element<r1cs::Bls12_381>) -> pairing::bls12_381::Fr {
        let n = n.to_biguint();
        // FrRepr stores its u64 limbs in little-endian order (least-significant limb first).
        let u64_size = BigUint::one() << 64;
        let chunks = [
            n.mod_floor(&u64_size).to_u64().unwrap(),
            (n >> 64).mod_floor(&u64_size).to_u64().unwrap(),
            (n >> 64 * 2).mod_floor(&u64_size).to_u64().unwrap(),
            (n >> 64 * 3).mod_floor(&u64_size).to_u64().unwrap(),
        ];
        pairing::bls12_381::Fr::from_repr(FrRepr(chunks)).unwrap()
    }
}