use ark_ec::{AffineRepr, CurveGroup, PrimeGroup, VariableBaseMSM};
use ark_ff::Field;
use ark_std::log2;
use rand::rngs::OsRng;
use spongefish::codecs::arkworks_algebra::{
CommonGroupToUnit, DomainSeparator, FieldDomainSeparator, FieldToUnitDeserialize,
FieldToUnitSerialize, GroupDomainSeparator, GroupToUnitDeserialize, GroupToUnitSerialize,
ProofError, ProofResult, ProverState, UnitToField, VerifierState,
};
/// Transcript layout for the inner-product argument implemented by `prove` / `verify`.
///
/// Each method takes the domain separator by value and returns it with new entries
/// appended, so calls can be chained.
trait BulletproofDomainSeparator<G: CurveGroup> {
    /// Appends the public statement: one group element labelled `"Pedersen commitment"`.
    fn bulletproof_statement(self) -> Self;
    /// Appends the proof messages for vectors of length `len`: `log2(len)` rounds, each of
    /// two points followed by one challenge scalar, then two scalars for the final message.
    fn add_bulletproof(self, len: usize) -> Self;
}

impl<G> BulletproofDomainSeparator<G> for DomainSeparator
where
    G: CurveGroup,
    Self: GroupDomainSeparator<G> + FieldDomainSeparator<G::ScalarField>,
{
    fn bulletproof_statement(self) -> Self {
        let commitment_count = 1;
        self.add_points(commitment_count, "Pedersen commitment")
    }

    fn add_bulletproof(self, len: usize) -> Self {
        // One round per halving performed by the prover; the count comes from `ark_std::log2`.
        let rounds = log2(len);
        let after_rounds = (0..rounds).fold(self, |acc: Self, _round| {
            acc.add_points(2, "round-message")
                .challenge_scalars(1, "challenge")
        });
        after_rounds.add_scalars(2, "final-message")
    }
}
/// Produces the inner-product argument by recursively halving the instance.
///
/// Each round appends two group elements (`left`, `right`) to the transcript and squeezes
/// one challenge `x`. The generators, witness and statement are then folded with `x` and
/// `x^-1`, and the function recurses. When one element remains, the two final scalars are
/// appended. The returned slice is the prover's NARG string.
///
/// `statement` is never written to the transcript here. It is only folded and handed to
/// the recursive call.
///
/// # Errors
/// Propagates any transcript error from the prover state (e.g. a domain-separator mismatch).
///
/// # Panics
/// Panics if the two witness vectors differ in length or their length is not a non-zero
/// power of two. Also panics if either generator vector is shorter than the witness, or
/// if a length-1 witness is paired with more than one `g` generator.
fn prove<'a, G: CurveGroup>(
    prover_state: &'a mut ProverState,
    generators: (&[G::Affine], &[G::Affine], &G::Affine),
    statement: &G,
    witness: (&[G::ScalarField], &[G::ScalarField]),
) -> ProofResult<&'a [u8]>
where
    ProverState: GroupToUnitSerialize<G> + UnitToField<G::ScalarField>,
{
    let len = witness.0.len();
    assert_eq!(len, witness.1.len());
    // Without this check an empty witness recursed forever, and odd lengths silently
    // dropped elements in the zips below.
    assert!(
        len.is_power_of_two(),
        "witness length must be a non-zero power of two, got {}",
        len
    );
    // `msm_unchecked` and `zip` truncate, so short generator vectors would otherwise
    // yield a malformed proof without any error.
    assert!(
        generators.0.len() >= len && generators.1.len() >= len,
        "need at least as many generators in `g` and `h` as witness elements"
    );

    if len == 1 {
        assert_eq!(generators.0.len(), 1);
        prover_state.add_scalars(&[witness.0[0], witness.1[0]])?;
        return Ok(prover_state.narg_string());
    }

    let n = len / 2;
    let (a_left, a_right) = witness.0.split_at(n);
    let (b_left, b_right) = witness.1.split_at(n);
    let (g_left, g_right) = generators.0.split_at(n);
    let (h_left, h_right) = generators.1.split_at(n);
    let u = *generators.2;

    // Cross terms: `left` pairs the left half of `a` with the right half of `b`,
    // and `right` the reverse, each with the opposite generator halves.
    let left = u * dot_prod(a_left, b_right)
        + G::msm_unchecked(g_right, a_left)
        + G::msm_unchecked(h_left, b_right);
    let right = u * dot_prod(a_right, b_left)
        + G::msm_unchecked(g_left, a_right)
        + G::msm_unchecked(h_right, b_left);
    prover_state.add_points(&[left, right])?;

    let [x]: [G::ScalarField; 1] = prover_state.challenge_scalars()?;
    let x_inv = x.inverse().expect("You just won the lottery!");

    // Fold with the same coefficients that `verify` uses, so both sides agree.
    let new_g = fold_generators(g_left, g_right, &x_inv, &x);
    let new_h = fold_generators(h_left, h_right, &x, &x_inv);
    let new_generators = (&new_g[..], &new_h[..], generators.2);
    let new_a = fold(a_left, a_right, &x, &x_inv);
    let new_b = fold(b_left, b_right, &x_inv, &x);
    let new_witness = (&new_a[..], &new_b[..]);
    let new_statement = *statement + left * x.square() + right * x_inv.square();

    prove(prover_state, new_generators, &new_statement, new_witness)
}
/// Verifies a proof produced by `prove` for vectors of length `n`.
///
/// Replays the transcript. Each round reads the two round points and squeezes the
/// challenge `x`, then folds the generators and the statement exactly as the prover did.
/// Finally it reads the scalars `(a, b)` and checks
/// `g[0]*a + h[0]*b + u*(a*b) == statement`.
///
/// # Errors
/// Returns the transcript error if the proof cannot be read (truncated or malformed data,
/// domain-separator mismatch). Returns `ProofError::InvalidProof` if the final check fails.
///
/// # Panics
/// Panics if `n` is not a non-zero power of two or either generator vector is shorter
/// than `n`. Both are caller errors, not properties of the proof.
fn verify<G: CurveGroup>(
    verifier_state: &mut VerifierState,
    generators: (&[G::Affine], &[G::Affine], &G::Affine),
    mut n: usize,
    statement: &G,
) -> ProofResult<()>
where
    for<'a> VerifierState<'a>: GroupToUnitDeserialize<G> + UnitToField<G::ScalarField>,
{
    assert!(
        n.is_power_of_two(),
        "vector length must be a non-zero power of two, got {}",
        n
    );
    assert!(
        generators.0.len() >= n && generators.1.len() >= n,
        "need at least `n` generators in both `g` and `h`"
    );

    let mut g = generators.0.to_vec();
    let mut h = generators.1.to_vec();
    let u = *generators.2;
    let mut statement = *statement;

    while n != 1 {
        // A malformed or truncated proof now surfaces as `Err` instead of a panic.
        let [left, right]: [G; 2] = verifier_state.next_points()?;
        n /= 2;
        let (g_left, g_right) = g.split_at(n);
        let (h_left, h_right) = h.split_at(n);
        let [x]: [G::ScalarField; 1] = verifier_state.challenge_scalars()?;
        let x_inv = x.inverse().expect("You just won the lottery!");
        g = fold_generators(g_left, g_right, &x_inv, &x);
        h = fold_generators(h_left, h_right, &x, &x_inv);
        statement = statement + left * x.square() + right * x_inv.square();
    }

    let [a, b]: [G::ScalarField; 2] = verifier_state.next_scalars()?;
    let c = a * b;
    if (g[0] * a + h[0] * b + u * c - statement).is_zero() {
        Ok(())
    } else {
        Err(ProofError::InvalidProof)
    }
}
/// Combines two generator slices pairwise as `a[i] * x + b[i] * y`.
///
/// Each combined point is converted back to affine form. Pairing stops at the end of the
/// shorter slice.
fn fold_generators<A: AffineRepr>(
    a: &[A],
    b: &[A],
    x: &A::ScalarField,
    y: &A::ScalarField,
) -> Vec<A> {
    let mut folded = Vec::with_capacity(a.len().min(b.len()));
    for (lo, hi) in a.iter().zip(b) {
        let combined = *lo * x + *hi * y;
        folded.push(combined.into_affine());
    }
    folded
}
/// Inner product `sum_i a[i] * b[i]`, taken over the shorter of the two slices.
fn dot_prod<F: Field>(a: &[F], b: &[F]) -> F {
    a.iter()
        .zip(b)
        .fold(F::zero(), |acc, (lhs, rhs)| acc + *lhs * rhs)
}
/// Combines two scalar slices pairwise as `a[i] * x + b[i] * y`.
///
/// Pairing stops at the end of the shorter slice.
fn fold<F: Field>(a: &[F], b: &[F], x: &F, y: &F) -> Vec<F> {
    let mut folded = Vec::with_capacity(a.len().min(b.len()));
    for (&lo, &hi) in a.iter().zip(b) {
        folded.push(lo * x + hi * y);
    }
    folded
}
/// Demo driver: commits to two fixed length-8 vectors under random generators, proves the
/// relation, prints the proof as hex, then verifies it. Panics on any failure.
fn main() {
    use ark_curve25519::EdwardsProjective as G;
    use ark_std::UniformRand;
    type F = <G as PrimeGroup>::ScalarField;
    type GAffine = <G as CurveGroup>::Affine;

    let size: usize = 8;

    // Transcript: statement, ratchet, then the folding rounds and final message.
    let domain_separator = BulletproofDomainSeparator::<G>::add_bulletproof(
        BulletproofDomainSeparator::<G>::bulletproof_statement(DomainSeparator::new(
            "example.com",
        ))
        .ratchet(),
        size,
    );

    // Witness vectors a = [0, 1, ..., 7] and b = [42, 43, ..., 49].
    let a: Vec<F> = (0..size as u32).map(|i| F::from(i)).collect();
    let b: Vec<F> = (0..size as u32).map(|i| F::from(i + 42)).collect();
    let ab = dot_prod(&a, &b);

    let sample_points = |count: usize| -> Vec<GAffine> {
        std::iter::repeat_with(|| GAffine::rand(&mut OsRng))
            .take(count)
            .collect()
    };
    let g = sample_points(a.len());
    let h = sample_points(b.len());
    let u = GAffine::rand(&mut OsRng);

    // Public statement: <g, a> + <h, b> + u * <a, b>.
    let statement = G::msm_unchecked(&g, &a) + G::msm_unchecked(&h, &b) + u * ab;
    let generators = (&g[..], &h[..], &u);
    let witness = (&a[..], &b[..]);

    let mut prover_state = domain_separator.to_prover_state();
    prover_state.public_points(&[statement]).unwrap();
    prover_state.ratchet().unwrap();
    let proof = prove(&mut prover_state, generators, &statement, witness).expect("Error proving");
    println!(
        "Here's a bulletproof for {} elements:\n{}",
        size,
        hex::encode(proof)
    );

    let mut verifier_state = domain_separator.to_verifier_state(proof);
    verifier_state.public_points(&[statement]).unwrap();
    verifier_state.ratchet().unwrap();
    verify(&mut verifier_state, generators, size, &statement).expect("Invalid proof");
}