use ark_ec::{AffineRepr, CurveGroup, Group, VariableBaseMSM};
use ark_ff::{Field, One, PrimeField, Zero};
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use ark_std::{cfg_iter, ops::MulAssign, rand::RngCore, vec, vec::Vec, UniformRand};
use digest::Digest;
use crate::{error::CompSigmaError, transforms::Homomorphism};
use dock_crypto_utils::hashing_utils::field_elem_from_try_and_incr;
use crate::utils::{elements_to_element_products, get_g_multiples_for_verifying_compression};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Prover's first message in the sigma protocol: a commitment to random blindings.
/// Created by [`RandomCommitment::new`], consumed by [`RandomCommitment::response`].
#[derive(Clone, Debug, PartialEq, Eq, CanonicalSerialize, CanonicalDeserialize)]
pub struct RandomCommitment<G: AffineRepr> {
// Blindings, one per generator in `g`; either caller-supplied or sampled at random.
pub r: Vec<G::ScalarField>,
// Multi-scalar multiplication of the generators `g` with the blindings `r`.
pub A_hat: G,
// The homomorphism evaluated at the blinding vector `r`.
pub t: G,
}
/// Prover's compressed (logarithmic-size) response to the verifier's challenge.
/// Produced by `log2(n) - 1` halving rounds in `RandomCommitment::compressed_response`;
/// each round contributes one element to `A`, `B`, `a` and `b`.
#[derive(Clone, Debug, PartialEq, Eq, CanonicalSerialize, CanonicalDeserialize)]
pub struct Response<G: AffineRepr> {
// First element of the final length-2 compressed response vector `z`.
pub z_prime_0: G::ScalarField,
// Second element of the final length-2 compressed response vector `z`.
pub z_prime_1: G::ScalarField,
// Per-round cross-term commitment: MSM of the right generator half with the left scalar half.
pub A: Vec<G>,
// Per-round cross-term commitment: MSM of the left generator half with the right scalar half.
pub B: Vec<G>,
// Per-round evaluation of the right homomorphism half at the left scalar half.
pub a: Vec<G>,
// Per-round evaluation of the left homomorphism half at the right scalar half.
pub b: Vec<G>,
}
impl<G> RandomCommitment<G>
where
    G: AffineRepr,
{
    /// Creates the prover's first message: samples (or accepts) a blinding vector `r`,
    /// commits to it as `A_hat = <g, r>` and evaluates the homomorphism at `r` to get `t`.
    ///
    /// # Errors
    /// - `UncompressedNotPowerOf2` if `g.len()` is not a power of 2 (the recursive
    ///   compression halves the vector each round, so a power-of-2 length is required).
    /// - `VectorLenMismatch` if caller-supplied `blindings` differ in length from `g`.
    pub fn new<R: RngCore, F: Homomorphism<G::ScalarField, Output = G>>(
        rng: &mut R,
        g: &[G],
        homomorphism: &F,
        blindings: Option<Vec<G::ScalarField>>,
    ) -> Result<Self, CompSigmaError> {
        if !g.len().is_power_of_two() {
            return Err(CompSigmaError::UncompressedNotPowerOf2);
        }
        let r = if let Some(blindings) = blindings {
            if blindings.len() != g.len() {
                return Err(CompSigmaError::VectorLenMismatch);
            }
            blindings
        } else {
            (0..g.len()).map(|_| G::ScalarField::rand(rng)).collect()
        };
        let t = homomorphism.eval(&r).unwrap();
        let A_hat = G::Group::msm_unchecked(g, &r);
        Ok(Self {
            r,
            A_hat: A_hat.into_affine(),
            t,
        })
    }

    /// Computes the Schnorr-style response `z_i = x_i * challenge + r_i` and then
    /// compresses it to logarithmic size via `compressed_response`.
    ///
    /// # Errors
    /// - `UncompressedNotPowerOf2` if `g.len()` or `f.size()` is not a power of 2.
    /// - `VectorLenMismatch` if `g`, `x`, `f` or the stored blindings `r` disagree in length.
    pub fn response<D: Digest, F: Homomorphism<G::ScalarField, Output = G> + Clone>(
        &self,
        g: &[G],
        f: &F,
        x: &[G::ScalarField],
        challenge: &G::ScalarField,
    ) -> Result<Response<G>, CompSigmaError> {
        if !g.len().is_power_of_two() {
            return Err(CompSigmaError::UncompressedNotPowerOf2);
        }
        // FIX: length mismatches previously reported `UncompressedNotPowerOf2`;
        // report `VectorLenMismatch`, consistent with `new`.
        if g.len() != x.len() {
            return Err(CompSigmaError::VectorLenMismatch);
        }
        if !f.size().is_power_of_two() {
            return Err(CompSigmaError::UncompressedNotPowerOf2);
        }
        if f.size() != x.len() {
            return Err(CompSigmaError::VectorLenMismatch);
        }
        // FIX: the zip below would silently truncate if `r` were shorter than `x`;
        // make the invariant explicit instead.
        if self.r.len() != x.len() {
            return Err(CompSigmaError::VectorLenMismatch);
        }
        let z = x
            .iter()
            .zip(self.r.iter())
            .map(|(x_, r)| *x_ * challenge + r)
            .collect::<Vec<_>>();
        Ok(Self::compressed_response::<D, F>(z, g.to_vec(), f.clone()))
    }

    /// Compresses the response vector `z` from length `n` down to 2 in `log2(n) - 1`
    /// rounds. Each round splits `z`, `g` and `f` in half, commits to the cross terms
    /// `(A, B, a, b)`, derives a Fiat-Shamir challenge `c` from the accumulated
    /// transcript, and folds: `g = g_l * c + g_r`, `f = f_l * c + f_r`, `z = z_l + z_r * c`.
    pub fn compressed_response<D: Digest, F: Homomorphism<G::ScalarField, Output = G> + Clone>(
        mut z: Vec<G::ScalarField>,
        mut g: Vec<G>,
        mut f: F,
    ) -> Response<G> {
        // Fiat-Shamir transcript. Bytes accumulate across rounds (never cleared);
        // the verifier's challenge recomputation accumulates identically.
        let mut bytes = vec![];
        let mut As = vec![];
        let mut Bs = vec![];
        let mut as_ = vec![];
        let mut bs = vec![];
        while z.len() > 2 {
            let m = g.len();
            // After split_off, `g`/`z` hold the left halves, `g_r`/`z_r` the right halves.
            let g_r = g.split_off(m / 2);
            let z_r = z.split_off(m / 2);
            let (f_l, f_r) = f.split_in_half();
            // Cross terms: right generators with left scalars, and vice versa.
            let A = G::Group::msm_unchecked(&g_r, &z);
            let B = G::Group::msm_unchecked(&g, &z_r);
            let a = f_r.eval(&z).unwrap();
            let b = f_l.eval(&z_r).unwrap();
            A.serialize_compressed(&mut bytes).unwrap();
            B.serialize_compressed(&mut bytes).unwrap();
            a.serialize_compressed(&mut bytes).unwrap();
            b.serialize_compressed(&mut bytes).unwrap();
            let c = field_elem_from_try_and_incr::<G::ScalarField, D>(&bytes);
            let c_repr = c.into_bigint();
            // Fold the two halves under the challenge.
            g = g
                .iter()
                .zip(g_r.iter())
                .map(|(l, r)| (l.mul_bigint(c_repr) + r).into_affine())
                .collect::<Vec<_>>();
            f = f_l.scale(&c).add(&f_r).unwrap();
            z = z
                .iter()
                .zip(z_r.iter())
                .map(|(l, r)| *l + *r * c)
                .collect::<Vec<_>>();
            As.push(A);
            Bs.push(B);
            as_.push(a);
            bs.push(b);
        }
        Response {
            z_prime_0: z[0],
            z_prime_1: z[1],
            // Single batched projective->affine conversion for all round commitments.
            A: G::Group::normalize_batch(&As),
            B: G::Group::normalize_batch(&Bs),
            a: as_,
            b: bs,
        }
    }
}
impl<G> Response<G>
where
G: AffineRepr,
{
/// Verifies the compressed proof by replaying the prover's per-round folding of the
/// generators and the homomorphism (`log2(n) - 1` rounds). Straightforward but slower
/// than [`Self::is_valid`], which batches the work into a few large MSMs.
pub fn is_valid_recursive<D: Digest, F: Homomorphism<G::ScalarField, Output = G> + Clone>(
&self,
g: &[G],
P: &G,
y: &G,
f: &F,
A_hat: &G,
t: &G,
challenge: &G::ScalarField,
) -> Result<(), CompSigmaError> {
self.check_sizes(g, f)?;
// Fold statement and first message under the challenge: Q = P*c + A_hat, Y = y*c + t.
let (Q, Y) = calculate_Q_and_Y(P, y, A_hat, t, challenge);
self.recursively_validate_compressed::<D, F>(Q, Y, g.to_vec(), f.clone())
}
/// Verifies the compressed proof using batched multi-scalar multiplications instead of
/// round-by-round folding; same acceptance condition as [`Self::is_valid_recursive`].
pub fn is_valid<D: Digest, F: Homomorphism<G::ScalarField, Output = G> + Clone>(
&self,
g: &[G],
P: &G,
y: &G,
f: &F,
A_hat: &G,
t: &G,
challenge: &G::ScalarField,
) -> Result<(), CompSigmaError> {
self.check_sizes(g, f)?;
let (Q, Y) = calculate_Q_and_Y(P, y, A_hat, t, challenge);
self.validate_compressed::<D, F>(Q, Y, g.to_vec(), f.clone())
}
/// Round-by-round verification: recomputes each Fiat-Shamir challenge from the proof
/// elements, folds `g` and `f` exactly as the prover did, and folds the folded
/// statement `(Q, Y)` alongside. Accepts iff the final length-2 relation holds.
pub fn recursively_validate_compressed<
D: Digest,
F: Homomorphism<G::ScalarField, Output = G> + Clone,
>(
&self,
mut Q: G::Group,
mut Y: G::Group,
mut g: Vec<G>,
mut f: F,
) -> Result<(), CompSigmaError> {
// Transcript bytes accumulate across rounds without clearing — this must mirror
// the prover's `compressed_response` exactly or the challenges diverge.
let mut bytes = vec![];
for i in 0..self.A.len() {
let A = &self.A[i];
let B = &self.B[i];
let a = &self.a[i];
let b = &self.b[i];
A.serialize_compressed(&mut bytes).unwrap();
B.serialize_compressed(&mut bytes).unwrap();
a.serialize_compressed(&mut bytes).unwrap();
b.serialize_compressed(&mut bytes).unwrap();
let c = field_elem_from_try_and_incr::<G::ScalarField, D>(&bytes);
let c_repr = c.into_bigint();
let m = g.len();
// Same folding as the prover: g = g_l * c + g_r.
let g_r = g.split_off(m / 2);
g = g
.iter()
.zip(g_r.iter())
.map(|(l, r)| (l.mul_bigint(c_repr) + r).into_affine())
.collect::<Vec<_>>();
let (f_l, f_r) = f.split_in_half();
f = f_l.scale(&c).add(&f_r).unwrap();
// Fold the statement with the round's cross terms:
// Q = A + Q*c + B*c^2, Y = a + Y*c + b*c^2.
let c_sq = c.square().into_bigint();
Q = A.into_group() + Q.mul_bigint(c_repr) + B.mul_bigint(c_sq);
Y = a.into_group() + Y.mul_bigint(c_repr) + b.mul_bigint(c_sq);
}
// After all rounds both the generator vector and the homomorphism must have size 2.
if (g.len() != 2) || (f.size() != 2) {
return Err(CompSigmaError::UncompressedNotPowerOf2);
}
// Check <g, (z'_0, z'_1)> == Q ...
if G::Group::msm_unchecked(&g, &[self.z_prime_0, self.z_prime_1]) != Q {
return Err(CompSigmaError::InvalidResponse);
}
// ... and f(z'_0, z'_1) == Y.
let f_prime_z_prime = f
.eval(&[self.z_prime_0, self.z_prime_1])
.unwrap()
.into_group();
if Y != f_prime_z_prime {
return Err(CompSigmaError::InvalidResponse);
}
Ok(())
}
/// Batched verification: recomputes all challenges up front, then expresses the final
/// round's checks as MSMs over the ORIGINAL (unfolded) `g` and proof vectors, avoiding
/// the per-round generator folding entirely.
pub fn validate_compressed<D: Digest, F: Homomorphism<G::ScalarField, Output = G> + Clone>(
&self,
mut Q: G::Group,
mut Y: G::Group,
g: Vec<G>,
f: F,
) -> Result<(), CompSigmaError> {
let mut challenges = vec![];
let mut challenge_squares = vec![];
// Transcript accumulates exactly as in the prover / recursive verifier.
let mut bytes = vec![];
for i in 0..self.A.len() {
let A = &self.A[i];
let B = &self.B[i];
let a = &self.a[i];
let b = &self.b[i];
A.serialize_compressed(&mut bytes).unwrap();
B.serialize_compressed(&mut bytes).unwrap();
a.serialize_compressed(&mut bytes).unwrap();
b.serialize_compressed(&mut bytes).unwrap();
let c = field_elem_from_try_and_incr::<G::ScalarField, D>(&bytes);
challenge_squares.push(c.square());
challenges.push(c);
}
let g_len = g.len();
// Per-generator scalars combining the challenges with (z'_0, z'_1); the exact
// expansion lives in `get_g_multiples_for_verifying_compression` (utils).
let g_multiples = get_g_multiples_for_verifying_compression(
g_len,
&challenges,
&self.z_prime_0,
&self.z_prime_1,
);
// NOTE(review): presumably running products of the challenges, with element 0 the
// product of ALL of them — confirm against `elements_to_element_products` in utils.
let mut challenge_products = elements_to_element_products(challenges);
let all_challenges_product = challenge_products.remove(0);
// Round i's B (and b) contributes with weight c_i^2 times the product of later challenges.
let B_multiples = cfg_iter!(challenge_products)
.zip(cfg_iter!(challenge_squares))
.map(|(c, c_sqr)| (*c * c_sqr).into_bigint())
.collect::<Vec<_>>();
let challenges_repr = cfg_iter!(challenge_products)
.map(|c| c.into_bigint())
.collect::<Vec<_>>();
// Scale the initial folded statement by the product of all challenges, then add the
// weighted round commitments — a closed form of the recursive Q/Y folding.
Q.mul_assign(all_challenges_product);
let Q_prime = G::Group::msm_bigint(&self.A, &challenges_repr)
+ G::Group::msm_bigint(&self.B, &B_multiples)
+ Q;
if G::Group::msm_unchecked(&g, &g_multiples) != Q_prime {
return Err(CompSigmaError::InvalidResponse);
}
Y.mul_assign(all_challenges_product);
let Y_prime = G::Group::msm_bigint(&self.a, &challenges_repr)
+ G::Group::msm_bigint(&self.b, &B_multiples)
+ Y;
// Evaluate the ORIGINAL homomorphism at the expanded scalar vector — equivalent to
// evaluating the folded homomorphism at (z'_0, z'_1).
let f_prime_z_prime = f.eval(&g_multiples).unwrap().into_group();
if Y_prime != f_prime_z_prime {
return Err(CompSigmaError::InvalidResponse);
}
Ok(())
}
/// Sanity checks on proof shape: all four per-round vectors must have equal length,
/// and `g` must have exactly `2^(rounds + 1)` elements since each round halves the
/// vector and compression stops at length 2.
fn check_sizes<F: Homomorphism<G::ScalarField, Output = G> + Clone>(
&self,
g: &[G],
f: &F,
) -> Result<(), CompSigmaError> {
if !g.len().is_power_of_two() {
return Err(CompSigmaError::UncompressedNotPowerOf2);
}
if self.A.len() != self.B.len() {
return Err(CompSigmaError::VectorLenMismatch);
}
if self.a.len() != self.b.len() {
return Err(CompSigmaError::VectorLenMismatch);
}
if self.A.len() != self.a.len() {
return Err(CompSigmaError::VectorLenMismatch);
}
if g.len() != 1 << (self.A.len() + 1) {
return Err(CompSigmaError::WrongRecursionLevel);
}
if !f.size().is_power_of_two() {
return Err(CompSigmaError::UncompressedNotPowerOf2);
}
Ok(())
}
}
/// Folds the statement `(P, Y)` with the prover's first message `(A, t)` under the
/// verifier's challenge `c`, returning `(P*c + A, Y*c + t)` in projective form.
pub fn calculate_Q_and_Y<G: AffineRepr>(
    P: &G,
    Y: &G,
    A: &G,
    t: &G,
    challenge: &G::ScalarField,
) -> (G::Group, G::Group) {
    // Convert the challenge once; `mul_bigint` avoids a redundant reduction per use.
    let c = challenge.into_bigint();
    let q = P.mul_bigint(c) + A;
    let y = Y.mul_bigint(c) + t;
    (q, y)
}
#[cfg(test)]
mod tests {
use super::*;
use ark_bls12_381::Bls12_381;
use ark_ec::pairing::Pairing;
use ark_std::{
rand::{rngs::StdRng, SeedableRng},
UniformRand,
};
use blake2::Blake2b512;
use std::time::Instant;
type Fr = <Bls12_381 as Pairing>::ScalarField;
type G1 = <Bls12_381 as Pairing>::G1Affine;
// Simple linear homomorphism x -> <constants, x>, used as the test instance.
#[derive(Clone)]
struct TestHom<G: AffineRepr> {
pub constants: Vec<G>,
}
// Macro (defined elsewhere in the crate) generates the `Homomorphism` impl for TestHom.
impl_simple_homomorphism!(TestHom, Fr, G1);
#[test]
fn compression() {
// Runs a full prove/verify cycle for a witness of length `size`, exercising both
// the recursive and the batched verifier, and prints timings for comparison.
fn check_compression(size: usize) {
let mut rng = StdRng::seed_from_u64(0u64);
let mut homomorphism = TestHom {
constants: (0..size)
.map(|_| <Bls12_381 as Pairing>::G1::rand(&mut rng).into_affine())
.collect::<Vec<_>>(),
};
let mut x = (0..size).map(|_| Fr::rand(&mut rng)).collect::<Vec<_>>();
let mut g = (0..size)
.map(|_| <Bls12_381 as Pairing>::G1::rand(&mut rng).into_affine())
.collect::<Vec<_>>();
// The protocol requires power-of-2 lengths: pad the homomorphism, witness
// (with zeros, leaving the evaluation unchanged) and generators up to the
// next power of 2. ("pod_size" is a typo for pad size; kept as-is.)
if !size.is_power_of_two() {
let new_size = size.next_power_of_two();
let pod_size = new_size - size;
homomorphism = homomorphism.pad(new_size);
for _ in 0..pod_size {
x.push(Fr::zero());
g.push(<Bls12_381 as Pairing>::G1::rand(&mut rng).into_affine());
}
}
// Statement: P = <g, x> and y = f(x).
let P = <Bls12_381 as Pairing>::G1::msm_unchecked(&g, &x).into_affine();
let y = homomorphism.eval(&x).unwrap();
let rand_comm = RandomCommitment::new(&mut rng, &g, &homomorphism, None).unwrap();
let challenge = Fr::rand(&mut rng);
let response = rand_comm
.response::<Blake2b512, _>(&g, &homomorphism, &x, &challenge)
.unwrap();
// Verify with the round-by-round (recursive) verifier.
let start = Instant::now();
response
.is_valid_recursive::<Blake2b512, _>(
&g,
&P,
&y,
&homomorphism,
&rand_comm.A_hat,
&rand_comm.t,
&challenge,
)
.unwrap();
println!(
"Recursive verification for compressed homomorphism form of size {} takes: {:?}",
size,
start.elapsed()
);
// Verify the same proof with the batched-MSM verifier.
let start = Instant::now();
response
.is_valid::<Blake2b512, _>(
&g,
&P,
&y,
&homomorphism,
&rand_comm.A_hat,
&rand_comm.t,
&challenge,
)
.unwrap();
println!(
"Verification for compressed homomorphism form of size {} takes: {:?}",
size,
start.elapsed()
);
}
// Mix of power-of-2 and non-power-of-2 sizes to exercise the padding path.
check_compression(4);
check_compression(5);
check_compression(6);
check_compression(7);
check_compression(8);
check_compression(9);
check_compression(11);
check_compression(15);
check_compression(16);
check_compression(17);
check_compression(18);
check_compression(20);
check_compression(25);
check_compression(31);
check_compression(32);
check_compression(48);
check_compression(63);
check_compression(64);
}
}