use super::{inner_product_proof::*, utils::*};
use crate::random_oracle::TranscriptProtocol;
use crate::{
common::*,
curve_arithmetic::{multiexp, Curve, Field, MultiExp},
id::id_proof_types::ProofVersion,
pedersen_commitment::*,
};
use rand::*;
use std::{convert::TryInto, iter::once};
/// Proof that a committed value is a member of a public set, without revealing
/// which member it is.
///
/// Produced by [`prove`] and checked by [`verify`]; field names follow the notation
/// used in those functions.
///
/// NOTE(review): the derived serialization presumably follows declaration order, so
/// the fields should not be reordered.
#[derive(Clone, Eq, PartialEq, Serialize, SerdeBase16Serialize, Debug)]
#[allow(non_snake_case)]
pub struct SetMembershipProof<C: Curve> {
    /// `A = <a_L, G> + <a_R, H> + a_tilde * B_tilde`, a commitment to the witness vectors.
    A: C,
    /// `S = <s_L, G> + <s_R, H> + s_tilde * B_tilde`, a commitment to the blinding vectors.
    S: C,
    /// `T_1 = t_1 * B + t_1_tilde * B_tilde`, a commitment to the linear coefficient of `t(X)`.
    T_1: C,
    /// `T_2 = t_2 * B + t_2_tilde * B_tilde`, a commitment to the quadratic coefficient of `t(X)`.
    T_2: C,
    /// `t(x) = t_0 + t_1 x + t_2 x^2` evaluated at the challenge `x`.
    tx: C::Scalar,
    /// Blinding of `tx`: `z^2 * v_rand + t_1_tilde x + t_2_tilde x^2`.
    tx_tilde: C::Scalar,
    /// Combined blinding of `A` and `S`: `a_tilde + x * s_tilde`.
    e_tilde: C::Scalar,
    /// Inner-product argument for the vectors `l(x)` and `r(x)`.
    ip_proof: InnerProductProof<C>,
}
/// Reasons why [`prove`] can fail to produce a [`SetMembershipProof`].
#[derive(Debug, PartialEq, Eq)]
pub enum ProverError {
    /// `gens` holds fewer generator pairs than the (power-of-two padded) set length.
    NotEnoughGenerators,
    /// The value whose membership should be proved does not occur in the set.
    CouldNotFindValueInSet,
    /// The inner-product prover did not return a proof.
    InnerProductProofFailure,
    /// The challenge `y` has no multiplicative inverse.
    DivisionError,
}
/// Builds the witness vectors `(a_L, a_R)` for the membership of `v` in `set_slice`.
///
/// `a_L` is the indicator vector of the FIRST index `i` with `set_slice[i] == v` (one at
/// that index, zero everywhere else), and `a_R` is `a_L` with one subtracted from every
/// entry. Returns `None` if `v` does not occur in `set_slice`.
#[allow(non_snake_case)]
fn a_L_a_R<F: Field>(v: &F, set_slice: &[F]) -> Option<(Vec<F>, Vec<F>)> {
    let position = set_slice.iter().position(|si| v == si)?;
    let a_L: Vec<F> = (0..set_slice.len())
        .map(|idx| if idx == position { F::one() } else { F::zero() })
        .collect();
    let a_R: Vec<F> = a_L
        .iter()
        .map(|bit| {
            let mut shifted = *bit;
            shifted.sub_assign(&F::one());
            shifted
        })
        .collect();
    Some((a_L, a_R))
}
/// Proves that the value `v`, committed to as `V = v_keys.hide(v, v_rand)`, is a member of
/// `the_set`, without revealing which member it is.
///
/// `the_set` is padded to a power-of-two length `n` with `pad_vector_to_power_of_two`, and
/// the first `n` generator pairs of `gens` are used. Every public input and every prover
/// message is appended to `transcript` before the challenges `y`, `z`, `x` and `w` are
/// extracted. [`verify`] must therefore be called with a transcript in the same state and
/// with the same `version`. For `version >= ProofVersion::Version2` the generators `G`, `H`
/// and `v_keys` are also bound into the transcript.
///
/// # Errors
///
/// * [`ProverError::NotEnoughGenerators`] if `gens` has fewer than `n` generator pairs.
/// * [`ProverError::CouldNotFindValueInSet`] if `v` does not occur in `the_set`.
/// * [`ProverError::DivisionError`] if the challenge `y` has no inverse.
/// * [`ProverError::InnerProductProofFailure`] if the inner-product prover fails.
#[allow(non_snake_case, clippy::too_many_arguments)]
pub fn prove<C: Curve, R: Rng>(
    version: ProofVersion,
    transcript: &mut impl TranscriptProtocol,
    csprng: &mut R,
    the_set: &[C::Scalar],
    v: C::Scalar,
    gens: &Generators<C>,
    v_keys: &CommitmentKey<C>,
    v_rand: &Randomness<C>,
) -> Result<SetMembershipProof<C>, ProverError> {
    transcript.append_label(b"SetMembershipProof");
    // The inner-product argument works on vectors of power-of-two length.
    let mut set_vec = the_set.to_vec();
    pad_vector_to_power_of_two(&mut set_vec);
    let n = set_vec.len();
    if gens.G_H.len() < n {
        return Err(ProverError::NotEnoughGenerators);
    }
    let (G, H): (Vec<_>, Vec<_>) = gens.G_H.iter().take(n).cloned().unzip();
    if version >= ProofVersion::Version2 {
        // Bind the generators and the commitment key into all later challenges.
        transcript.append_message(b"G", &G);
        transcript.append_message(b"H", &H);
        transcript.append_message(b"v_keys", &v_keys);
    }
    let v_value = Value::<C>::new(v);
    let V = v_keys.hide(&v_value, v_rand);
    transcript.append_message(b"V", &V.0);
    transcript.append_message(b"theSet", &set_vec);
    let B = v_keys.g;
    let B_tilde = v_keys.h;

    // Witness: a_L is the indicator vector of v's position in the set, a_R = a_L - 1.
    let (a_L, a_R) = a_L_a_R(&v, &set_vec).ok_or(ProverError::CouldNotFindValueInSet)?;

    // Random blinding vectors/scalars for the commitments A and S.
    let mut s_L = Vec::with_capacity(n);
    let mut s_R = Vec::with_capacity(n);
    for _ in 0..n {
        s_L.push(C::generate_scalar(csprng));
        s_R.push(C::generate_scalar(csprng));
    }
    let a_tilde = C::generate_scalar(csprng);
    let s_tilde = C::generate_scalar(csprng);

    // A = <a_L, G> + <a_R, H> + a_tilde * B_tilde and S = <s_L, G> + <s_R, H> + s_tilde * B_tilde,
    // computed with one shared multi-exponentiation table over (G, H, B_tilde).
    let A_scalars: Vec<C::Scalar> = a_L
        .iter()
        .chain(a_R.iter())
        .copied()
        .chain(once(a_tilde))
        .collect();
    let S_scalars: Vec<C::Scalar> = s_L
        .iter()
        .chain(s_R.iter())
        .copied()
        .chain(once(s_tilde))
        .collect();
    let GH_B_tilde: Vec<C> = G
        .iter()
        .chain(H.iter())
        .copied()
        .chain(once(B_tilde))
        .collect();
    let mexp = C::new_multiexp(&GH_B_tilde);
    let A = mexp.multiexp(&A_scalars);
    let S = mexp.multiexp(&S_scalars);
    transcript.append_message(b"A", &A);
    transcript.append_message(b"S", &S);
    let y: C::Scalar = transcript.extract_challenge_scalar::<C>(b"y");
    let z: C::Scalar = transcript.extract_challenge_scalar::<C>(b"z");
    // Vector of powers of y as produced by `z_vec` (written y^n below).
    let y_n = z_vec(y, 0, n);
    let z_sq = {
        let mut z_sq = z;
        z_sq.mul_assign(&z);
        z_sq
    };
    let z_cb = {
        let mut z_cb = z_sq;
        z_cb.mul_assign(&z);
        z_cb
    };

    // l(X) = l_0 + l_1 X with l_0 = a_L - z * 1 and l_1 = s_L.
    let mut l_0 = Vec::with_capacity(n);
    let mut l_1 = Vec::with_capacity(n);
    for i in 0..n {
        let mut l_0_i = a_L[i];
        l_0_i.sub_assign(&z);
        l_0.push(l_0_i);
        l_1.push(s_L[i]);
    }
    // r(X) = r_0 + r_1 X with (o = componentwise product)
    //   r_0 = y^n o (a_R + z * 1) + z^2 * s + z^3 * 1
    //   r_1 = y^n o s_R
    let mut r_0 = Vec::with_capacity(n);
    let mut r_1 = Vec::with_capacity(n);
    for i in 0..n {
        let mut r_0_i = a_R[i];
        r_0_i.add_assign(&z);
        r_0_i.mul_assign(&y_n[i]);
        r_0_i.add_assign(&z_cb);
        let mut z_sq_si = z_sq;
        z_sq_si.mul_assign(&set_vec[i]);
        r_0_i.add_assign(&z_sq_si);
        r_0.push(r_0_i);
        let mut r_1_i = y_n[i];
        r_1_i.mul_assign(&s_R[i]);
        r_1.push(r_1_i);
    }

    // t(X) = <l(X), r(X)> = t_0 + t_1 X + t_2 X^2. The middle coefficient is computed as
    // <l_0 + l_1, r_0 + r_1> - t_0 - t_2, i.e. t(1) - t_0 - t_2.
    let t_0 = inner_product(&l_0, &r_0);
    let t_2 = inner_product(&l_1, &r_1);
    let mut t_1 = C::Scalar::zero();
    for i in 0..n {
        let mut l_side = l_0[i];
        l_side.add_assign(&l_1[i]);
        let mut r_side = r_0[i];
        r_side.add_assign(&r_1[i]);
        let mut prod = l_side;
        prod.mul_assign(&r_side);
        t_1.add_assign(&prod);
    }
    t_1.sub_assign(&t_0);
    t_1.sub_assign(&t_2);

    // Pedersen commitments T_i = t_i * B + t_i_tilde * B_tilde to t_1 and t_2.
    let t_1_tilde = C::generate_scalar(csprng);
    let t_2_tilde = C::generate_scalar(csprng);
    let T_1 = B
        .mul_by_scalar(&t_1)
        .plus_point(&B_tilde.mul_by_scalar(&t_1_tilde));
    let T_2 = B
        .mul_by_scalar(&t_2)
        .plus_point(&B_tilde.mul_by_scalar(&t_2_tilde));
    transcript.append_message(b"T1", &T_1);
    transcript.append_message(b"T2", &T_2);
    let x: C::Scalar = transcript.extract_challenge_scalar::<C>(b"x");
    let mut x_sq = x;
    x_sq.mul_assign(&x);

    // Evaluate l and r at the challenge x.
    let mut lx = Vec::with_capacity(n);
    let mut rx = Vec::with_capacity(n);
    for i in 0..n {
        let mut lx_i = l_1[i];
        lx_i.mul_assign(&x);
        lx_i.add_assign(&l_0[i]);
        lx.push(lx_i);
        let mut rx_i = r_1[i];
        rx_i.mul_assign(&x);
        rx_i.add_assign(&r_0[i]);
        rx.push(rx_i);
    }

    // tx = t(x) = t_0 + t_1 x + t_2 x^2.
    let mut tx = t_0;
    let mut tx_1 = t_1;
    tx_1.mul_assign(&x);
    tx.add_assign(&tx_1);
    let mut tx_2 = t_2;
    tx_2.mul_assign(&x_sq);
    tx.add_assign(&tx_2);
    // tx_tilde = z^2 * v_rand + t_1_tilde x + t_2_tilde x^2, the blinding of tx.
    let mut tx_tilde = z_sq;
    tx_tilde.mul_assign(v_rand);
    let mut tx_s1 = t_1_tilde;
    tx_s1.mul_assign(&x);
    tx_tilde.add_assign(&tx_s1);
    let mut tx_s2 = t_2_tilde;
    tx_s2.mul_assign(&x_sq);
    tx_tilde.add_assign(&tx_s2);
    // e_tilde = a_tilde + x * s_tilde, the combined blinding of A + x S.
    let mut e_tilde = s_tilde;
    e_tilde.mul_assign(&x);
    e_tilde.add_assign(&a_tilde);
    transcript.append_message(b"tx", &tx);
    transcript.append_message(b"tx_tilde", &tx_tilde);
    transcript.append_message(b"e_tilde", &e_tilde);
    let w: C::Scalar = transcript.extract_challenge_scalar::<C>(b"w");

    // Inner-product argument for <lx, rx> = tx with base Q = w * B. The H generators are
    // rescaled by the powers of y^{-1} (passed as scalars rather than as new points).
    let Q = B.mul_by_scalar(&w);
    let y_inv = y.inverse().ok_or(ProverError::DivisionError)?;
    let H_prime_scalars = z_vec(y_inv, 0, n);
    let ip_proof =
        prove_inner_product_with_scalars(transcript, &G, &H, &H_prime_scalars, &Q, &lx, &rx)
            .ok_or(ProverError::InnerProductProofFailure)?;
    Ok(SetMembershipProof {
        A,
        S,
        T_1,
        T_2,
        tx,
        tx_tilde,
        e_tilde,
        ip_proof,
    })
}
/// Reasons why [`verify`] can reject a [`SetMembershipProof`].
#[derive(Debug, PartialEq, Eq)]
pub enum VerificationError {
    /// The padded set length does not fit in a `u64`.
    SetTooLarge,
    /// `gens` holds fewer generator pairs than the (power-of-two padded) set length.
    NotEnoughGenerators,
    /// The check tying `V`, `T_1`, `T_2`, `tx` and `tx_tilde` together failed.
    InconsistentT0,
    /// The challenge `y` has no multiplicative inverse.
    DivisionError,
    /// The inner-product argument did not verify.
    IPVerificationError,
}
/// Verifies a [`SetMembershipProof`] that the value committed to in `V` is a member of
/// `the_set`.
///
/// The set is padded exactly as in [`prove`]. `transcript` must be in the same state as the
/// prover's, and `version` must match, so that the challenges `y`, `z`, `x` and `w` are
/// reproduced. With `B = v_keys.g` and `B_tilde = v_keys.h`, verification performs two checks:
///
/// 1. `z^2 V + (delta(y, z) - tx) B + x T_1 + x^2 T_2 - tx_tilde B_tilde == 0`, where
///    `delta(y, z) = (z - z^2) <1, y^n> + z^3 (1 - z n - <1, s>)`;
/// 2. the inner-product argument, against a point `P'` assembled from `G`, `H`, `w B`,
///    `B_tilde`, `A` and `S`.
///
/// # Errors
///
/// * [`VerificationError::NotEnoughGenerators`] if `gens` has fewer than `n` generator pairs.
/// * [`VerificationError::SetTooLarge`] if `n` does not fit in a `u64`.
/// * [`VerificationError::InconsistentT0`] if check 1 fails.
/// * [`VerificationError::DivisionError`] if the challenge `y` has no inverse.
/// * [`VerificationError::IPVerificationError`] if check 2 fails.
#[allow(non_snake_case)]
pub fn verify<C: Curve>(
    version: ProofVersion,
    transcript: &mut impl TranscriptProtocol,
    the_set: &[C::Scalar],
    V: &Commitment<C>,
    proof: &SetMembershipProof<C>,
    gens: &Generators<C>,
    v_keys: &CommitmentKey<C>,
) -> Result<(), VerificationError> {
    let mut set_vec = the_set.to_vec();
    pad_vector_to_power_of_two(&mut set_vec);
    let n = set_vec.len();
    if gens.G_H.len() < n {
        return Err(VerificationError::NotEnoughGenerators);
    }
    let (G, H): (Vec<_>, Vec<_>) = gens.G_H.iter().take(n).cloned().unzip();

    // Replay the prover's transcript, in the same order, to recompute the challenges.
    transcript.append_label(b"SetMembershipProof");
    if version >= ProofVersion::Version2 {
        transcript.append_message(b"G", &G);
        transcript.append_message(b"H", &H);
        transcript.append_message(b"v_keys", &v_keys);
    }
    transcript.append_message(b"V", &V.0);
    transcript.append_message(b"theSet", &set_vec);
    let A = proof.A;
    let S = proof.S;
    transcript.append_message(b"A", &A);
    transcript.append_message(b"S", &S);
    let y: C::Scalar = transcript.extract_challenge_scalar::<C>(b"y");
    let z: C::Scalar = transcript.extract_challenge_scalar::<C>(b"z");
    let T_1 = proof.T_1;
    let T_2 = proof.T_2;
    transcript.append_message(b"T1", &T_1);
    transcript.append_message(b"T2", &T_2);
    let x: C::Scalar = transcript.extract_challenge_scalar::<C>(b"x");
    let tx = proof.tx;
    let tx_tilde = proof.tx_tilde;
    let e_tilde = proof.e_tilde;
    transcript.append_message(b"tx", &tx);
    transcript.append_message(b"tx_tilde", &tx_tilde);
    transcript.append_message(b"e_tilde", &e_tilde);
    let w: C::Scalar = transcript.extract_challenge_scalar::<C>(b"w");

    let mut z2 = z;
    z2.mul_assign(&z);
    let mut z3 = z2;
    z3.mul_assign(&z);

    // delta(y, z) = (z - z^2) * <1, y^n> + z^3 * (1 - z * n - <1, s>).
    let n64: u64 = n.try_into().map_err(|_| VerificationError::SetTooLarge)?;
    let ns = C::scalar_from_u64(n64);
    // <1, y^n> = sum of y^i for i in 0..n.
    let mut yi = C::Scalar::one();
    let mut ip_1_yn = C::Scalar::zero();
    for _ in 0..n {
        ip_1_yn.add_assign(&yi);
        yi.mul_assign(&y);
    }
    let mut delta_yz = z;
    delta_yz.sub_assign(&z2);
    delta_yz.mul_assign(&ip_1_yn);
    // <1, s> = sum of the (padded) set elements.
    let mut ip_1_s = C::Scalar::zero();
    for si in &set_vec {
        ip_1_s.add_assign(si);
    }
    let mut zn = ns;
    zn.mul_assign(&z);
    let mut z3_term = C::Scalar::one();
    z3_term.sub_assign(&zn);
    z3_term.sub_assign(&ip_1_s);
    z3_term.mul_assign(&z3);
    delta_yz.add_assign(&z3_term);

    // Check 1: z^2 V + (delta - tx) B + x T_1 + x^2 T_2 - tx_tilde B_tilde must be zero.
    let mut delta_minus_tx = delta_yz;
    delta_minus_tx.sub_assign(&tx);
    let mut x2 = x;
    x2.mul_assign(&x);
    let mut minus_tx_tilde = tx_tilde;
    minus_tx_tilde.negate();
    let t0_check_base_points = vec![V.0, v_keys.g, T_1, T_2, v_keys.h];
    let t0_check_exponents = vec![z2, delta_minus_tx, x, x2, minus_tx_tilde];
    let rhs = multiexp(&t0_check_base_points, &t0_check_exponents);
    if !rhs.is_zero_point() {
        return Err(VerificationError::InconsistentT0);
    }

    // Check 2: assemble P' = sum of P_prime_exps[i] * P_prime_bases[i] and run the
    // inner-product verifier on it. The H generators are rescaled by the powers of
    // y^{-1} produced by `z_vec`.
    let g_hat = v_keys.g.mul_by_scalar(&w);
    let y_inv = y.inverse().ok_or(VerificationError::DivisionError)?;
    let y_inv_n = z_vec(y_inv, 0, n);
    let mut minus_e_tilde = e_tilde;
    minus_e_tilde.negate();
    let mut minus_z = z;
    minus_z.negate();
    let mut P_prime_exps = Vec::with_capacity(2 * n + 4);
    // Exponent of every G_i: -z.
    P_prime_exps.extend(std::iter::repeat(minus_z).take(n));
    // Exponent of H_i: z + z^2 * y^{-i} * s_i + z^3 * y^{-i}.
    for (y_inv_i, s_i) in y_inv_n.iter().zip(&set_vec) {
        let mut hexp = z;
        let mut z2_yinv_si = z2;
        z2_yinv_si.mul_assign(y_inv_i);
        z2_yinv_si.mul_assign(s_i);
        hexp.add_assign(&z2_yinv_si);
        let mut z3_yinv = z3;
        z3_yinv.mul_assign(y_inv_i);
        hexp.add_assign(&z3_yinv);
        P_prime_exps.push(hexp);
    }
    // Remaining terms: tx * (w B) - e_tilde * B_tilde + A + x S.
    P_prime_exps.push(tx);
    P_prime_exps.push(minus_e_tilde);
    P_prime_exps.push(C::Scalar::one());
    P_prime_exps.push(x);
    let mut P_prime_bases = Vec::with_capacity(2 * n + 4);
    P_prime_bases.extend(G);
    P_prime_bases.extend(H);
    P_prime_bases.push(g_hat);
    P_prime_bases.push(v_keys.h);
    P_prime_bases.push(A);
    P_prime_bases.push(S);
    let ip_verification = verify_inner_product_with_scalars(
        transcript,
        &y_inv_n,
        &P_prime_bases,
        &P_prime_exps,
        &proof.ip_proof,
    );
    if !ip_verification {
        return Err(VerificationError::IPVerificationError);
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{curve_arithmetic::arkworks_instances::ArkGroup, random_oracle::RandomOracle};
    use ark_bls12_381::G1Projective;

    type SomeCurve = ArkGroup<G1Projective>;
    type SomeScalar = <SomeCurve as Curve>::Scalar;

    /// Maps small integers to curve scalars.
    fn get_set_vector<C: Curve>(the_set: &[u64]) -> Vec<C::Scalar> {
        the_set.iter().copied().map(C::scalar_from_u64).collect()
    }

    /// Generates `n` generator pairs, a random commitment key and commitment randomness.
    fn generate_helper_values(
        n: usize,
    ) -> (
        Generators<SomeCurve>,
        CommitmentKey<SomeCurve>,
        Randomness<SomeCurve>,
    ) {
        let rng = &mut thread_rng();
        let gens = Generators::generate(n, rng);
        let b = SomeCurve::generate(rng);
        let b_tilde = SomeCurve::generate(rng);
        let v_keys = CommitmentKey { g: b, h: b_tilde };
        let v_rand = Randomness::generate(rng);
        (gens, v_keys, v_rand)
    }

    /// Commits to `v` under `v_keys` with randomness `v_rand`.
    fn get_v_com(
        v: &SomeScalar,
        v_keys: &CommitmentKey<SomeCurve>,
        v_rand: &Randomness<SomeCurve>,
    ) -> Commitment<SomeCurve> {
        let v_value = Value::<SomeCurve>::new(*v);
        v_keys.hide(&v_value, v_rand)
    }

    /// Runs `prove` with a fresh transcript and a thread-local RNG.
    fn prove_fresh(
        version: ProofVersion,
        the_set: &[SomeScalar],
        v: SomeScalar,
        gens: &Generators<SomeCurve>,
        v_keys: &CommitmentKey<SomeCurve>,
        v_rand: &Randomness<SomeCurve>,
    ) -> Result<SetMembershipProof<SomeCurve>, ProverError> {
        let mut transcript = RandomOracle::empty();
        prove(version, &mut transcript, &mut thread_rng(), the_set, v, gens, v_keys, v_rand)
    }

    /// Runs `verify` with a fresh transcript.
    fn verify_fresh(
        version: ProofVersion,
        the_set: &[SomeScalar],
        v_com: &Commitment<SomeCurve>,
        proof: &SetMembershipProof<SomeCurve>,
        gens: &Generators<SomeCurve>,
        v_keys: &CommitmentKey<SomeCurve>,
    ) -> Result<(), VerificationError> {
        let mut transcript = RandomOracle::empty();
        verify(version, &mut transcript, the_set, v_com, proof, gens, v_keys)
    }

    /// Proves membership of `v` in `the_set` under both proof versions and asserts that
    /// each proof verifies against a commitment to `v`.
    fn assert_prove_verify_all_versions(
        the_set: &[SomeScalar],
        v: SomeScalar,
        gens: &Generators<SomeCurve>,
        v_keys: &CommitmentKey<SomeCurve>,
        v_rand: &Randomness<SomeCurve>,
    ) {
        let v_com = get_v_com(&v, v_keys, v_rand);

        let proof = prove_fresh(ProofVersion::Version1, the_set, v, gens, v_keys, v_rand)
            .expect("Version 1 proof should be produced.");
        let result = verify_fresh(ProofVersion::Version1, the_set, &v_com, &proof, gens, v_keys);
        assert!(result.is_ok(), "Version 1 proof should verify.");

        let proof = prove_fresh(ProofVersion::Version2, the_set, v, gens, v_keys, v_rand)
            .expect("Version 2 proof should be produced.");
        let result = verify_fresh(ProofVersion::Version2, the_set, &v_com, &proof, gens, v_keys);
        assert!(result.is_ok(), "Version 2 proof should verify.");
    }

    #[test]
    fn test_smp_prove_verify() {
        let the_set = get_set_vector::<SomeCurve>(&[1, 7, 3, 5]);
        let v = SomeCurve::scalar_from_u64(3);
        let (gens, v_keys, v_rand) = generate_helper_values(the_set.len());
        assert_prove_verify_all_versions(&the_set, v, &gens, &v_keys, &v_rand);
    }

    #[test]
    fn test_smp_prove_not_power_of_two() {
        let the_set = get_set_vector::<SomeCurve>(&[1, 7, 3, 5, 6]);
        let v = SomeCurve::scalar_from_u64(3);
        // The set is padded internally, so enough generators for the padded length are needed.
        let k = the_set.len().next_power_of_two();
        let (gens, v_keys, v_rand) = generate_helper_values(k);
        assert_prove_verify_all_versions(&the_set, v, &gens, &v_keys, &v_rand);
    }

    #[test]
    fn test_smp_prove_not_in_set() {
        let the_set = get_set_vector::<SomeCurve>(&[1, 7, 3, 5]);
        let v = SomeCurve::scalar_from_u64(4);
        let (gens, v_keys, v_rand) = generate_helper_values(the_set.len());

        let proof = prove_fresh(ProofVersion::Version1, &the_set, v, &gens, &v_keys, &v_rand);
        assert!(matches!(proof, Err(ProverError::CouldNotFindValueInSet)));

        let proof = prove_fresh(ProofVersion::Version2, &the_set, v, &gens, &v_keys, &v_rand);
        assert!(matches!(proof, Err(ProverError::CouldNotFindValueInSet)));
    }

    #[test]
    fn test_smp_verify_different_value() {
        let the_set = get_set_vector::<SomeCurve>(&[1, 7, 3, 5]);
        let v = SomeCurve::scalar_from_u64(3);
        let (gens, v_keys, v_rand) = generate_helper_values(the_set.len());
        let proof = prove_fresh(ProofVersion::Version1, &the_set, v, &gens, &v_keys, &v_rand)
            .expect("Version 1 proof should be produced.");

        // A commitment to another set member (same randomness) must not match the proof.
        let other_v = SomeCurve::scalar_from_u64(5);
        let v_com = get_v_com(&other_v, &v_keys, &v_rand);
        let result = verify_fresh(ProofVersion::Version1, &the_set, &v_com, &proof, &gens, &v_keys);
        assert_eq!(result, Err(VerificationError::InconsistentT0));
    }

    #[test]
    fn test_smp_verify_different_set() {
        let the_set = get_set_vector::<SomeCurve>(&[1, 7, 3, 5]);
        let v = SomeCurve::scalar_from_u64(3);
        let (gens, v_keys, v_rand) = generate_helper_values(the_set.len());
        let proof = prove_fresh(ProofVersion::Version1, &the_set, v, &gens, &v_keys, &v_rand)
            .expect("Version 1 proof should be produced.");

        let new_set = get_set_vector::<SomeCurve>(&[2, 7, 3, 5]);
        let v_com = get_v_com(&v, &v_keys, &v_rand);
        let result = verify_fresh(ProofVersion::Version1, &new_set, &v_com, &proof, &gens, &v_keys);
        assert_eq!(result, Err(VerificationError::InconsistentT0));
    }

    #[test]
    fn test_smp_verify_invalid_inner_product() {
        let the_set = get_set_vector::<SomeCurve>(&[1, 7, 3, 5]);
        let v = SomeCurve::scalar_from_u64(3);
        let (gens, v_keys, v_rand) = generate_helper_values(the_set.len());
        let mut proof = prove_fresh(ProofVersion::Version1, &the_set, v, &gens, &v_keys, &v_rand)
            .expect("Version 1 proof should be produced.");

        // Tamper with the inner-product argument only.
        proof.ip_proof.a.negate();
        let v_com = get_v_com(&v, &v_keys, &v_rand);
        let result = verify_fresh(ProofVersion::Version1, &the_set, &v_com, &proof, &gens, &v_keys);
        assert_eq!(result, Err(VerificationError::IPVerificationError));
    }

    #[test]
    fn test_smp_prove_many_generators() {
        let the_set = get_set_vector::<SomeCurve>(&[1, 7, 3, 5]);
        let v = SomeCurve::scalar_from_u64(3);
        // Far more generators than needed: only the first `n` must be used.
        let num_gens = 2112;
        let (gens, v_keys, v_rand) = generate_helper_values(num_gens);
        assert_prove_verify_all_versions(&the_set, v, &gens, &v_keys, &v_rand);
    }
}