use super::*;
use crate::compute::get_curve25519_generators;
use core::{mem, slice};
use curve25519_dalek::{ristretto::RistrettoPoint, scalar::Scalar};
use merlin::Transcript;
use rand_core::SeedableRng;
/// Views the in-memory representation of `point` as raw bytes.
///
/// The tests use this to replay proof elements into an expected transcript
/// byte-for-byte. It is only meaningful for types whose every byte is
/// initialized, i.e. types without padding.
fn as_byte_slice<T>(point: &T) -> &[u8] {
    let byte_ptr = (point as *const T).cast::<u8>();
    let byte_len = mem::size_of_val(point);
    // SAFETY: `byte_ptr` is derived from a live shared reference. It is therefore
    // non-null, aligned for `u8`, and valid for reads of `byte_len` bytes for as long
    // as `point` is borrowed, which is the lifetime of the returned slice.
    // Interpreting those bytes as `u8` additionally requires `T` to contain no
    // padding (uninitialized bytes). Callers must only pass padding-free types.
    unsafe { slice::from_raw_parts(byte_ptr, byte_len) }
}
/// Creates an `InnerProductProof` for random vectors of length `n` and checks
/// that it verifies on honest inputs and is rejected under each kind of tampering.
///
/// Generators are taken starting at index `generators_offset`. The RNG is seeded
/// with `n`, so every run for a given `n` uses the same vectors.
///
/// # Panics
///
/// Panics (failing the calling test) if `n == 0` or if any check does not hold.
fn test_prove_and_verify_with_given_n_and_generators_offset(n: u64, generators_offset: u64) {
    assert!(n > 0);
    let mut rng = rand::rngs::StdRng::seed_from_u64(n);
    let a: Vec<_> = (0..n).map(|_| Scalar::random(&mut rng)).collect();
    let b: Vec<_> = (0..n).map(|_| Scalar::random(&mut rng)).collect();
    let g = {
        let mut temp_g = vec![RistrettoPoint::default(); n as usize];
        get_curve25519_generators(&mut temp_g, generators_offset);
        temp_g
    };
    let mut transcript = Transcript::new(b"innerproducttest");
    let proof = InnerProductProof::create(&mut transcript, &a, &b, generators_offset);

    // The public claims checked by the verifier: <a, b> and sum_i a_i * g_i.
    let product = a.iter().zip(&b).map(|(a_i, b_i)| a_i * b_i).sum::<Scalar>();
    let a_commit = a
        .iter()
        .zip(&g)
        .map(|(a_i, g_i)| a_i * g_i)
        .sum::<RistrettoPoint>();

    // Honest inputs verify.
    {
        let mut transcript = Transcript::new(b"innerproducttest");
        assert!(proof
            .verify(&mut transcript, &a_commit, &product, &b, generators_offset)
            .is_ok());
    }

    // A transcript with a different label must be rejected. For n == 1 there are
    // zero folding rounds (see `num_rounds` below), so no challenges are drawn and
    // the label is not expected to influence the result.
    if n > 1 {
        let mut transcript = Transcript::new(b"invalid");
        assert!(proof
            .verify(&mut transcript, &a_commit, &product, &b, generators_offset)
            .is_err());
    }

    // A wrong commitment is rejected.
    {
        let mut transcript = Transcript::new(b"innerproducttest");
        let a_commit_p = Scalar::from(123_u64) * g[0];
        assert!(proof
            .verify(
                &mut transcript,
                &a_commit_p,
                &product,
                &b,
                generators_offset
            )
            .is_err());
    }

    // A wrong inner-product value is rejected.
    {
        let mut transcript = Transcript::new(b"innerproducttest");
        let product_p = product + Scalar::from(123_u64);
        assert!(proof
            .verify(
                &mut transcript,
                &a_commit,
                &product_p,
                &b,
                generators_offset
            )
            .is_err());
    }

    // A wrong public vector (`a` passed in place of `b`) is rejected.
    {
        let mut transcript = Transcript::new(b"innerproducttest");
        assert!(proof
            .verify(&mut transcript, &a_commit, &product, &a, generators_offset)
            .is_err());
    }

    // After a successful verification, the transcript must be in exactly the state
    // rebuilt here: domain separator, length, then one (L, R, challenge) step per
    // folding round. Matching 16 subsequent challenges confirms the states agree.
    {
        let mut transcript = Transcript::new(b"innerproducttest");
        assert!(proof
            .verify(&mut transcript, &a_commit, &product, &b, generators_offset)
            .is_ok());
        let mut expected_transcript = Transcript::new(b"innerproducttest");
        expected_transcript.append_message(b"domain-sep", b"inner product proof v1");
        expected_transcript.append_u64(b"n", n);
        // ceil(log2(n)) folding rounds.
        let num_rounds = n.next_power_of_two().trailing_zeros() as usize;
        // Pin the proof shape up front. A malformed proof then fails here with a
        // clear message instead of an index panic, and extra trailing elements
        // cannot go unnoticed.
        assert_eq!(proof.l_vector.len(), num_rounds);
        assert_eq!(proof.r_vector.len(), num_rounds);
        for (l, r) in proof.l_vector.iter().zip(&proof.r_vector) {
            expected_transcript.append_message(b"L", as_byte_slice(l));
            expected_transcript.append_message(b"R", as_byte_slice(r));
            // Draw (and discard) the round challenge to advance the state.
            let mut challenge = [0u8; 32];
            expected_transcript.challenge_bytes(b"x", &mut challenge);
        }
        for _ in 0..16 {
            let mut buf = [0u8; 128];
            let mut expected_buf = [0u8; 128];
            transcript.challenge_bytes(b"test", &mut buf);
            expected_transcript.challenge_bytes(b"test", &mut expected_buf);
            assert_eq!(buf, expected_buf);
        }
        // Sanity check that the comparison is sensitive to divergence.
        transcript.append_message(b"tampering with transcript", b"should fail");
        let mut buf = [0u8; 128];
        let mut expected_buf = [0u8; 128];
        transcript.challenge_bytes(b"test", &mut buf);
        expected_transcript.challenge_bytes(b"test", &mut expected_buf);
        assert_ne!(buf, expected_buf);
    }

    // Structurally malformed proofs are rejected: first with the L commitments
    // stripped, then with the R commitments stripped. For n == 1 both vectors are
    // already empty (asserted above), so there is nothing to strip.
    if n > 1 {
        let mut tampered_proof = proof;

        let l_vector = mem::take(&mut tampered_proof.l_vector);
        let mut transcript = Transcript::new(b"innerproducttest");
        assert!(tampered_proof
            .verify(&mut transcript, &a_commit, &product, &b, generators_offset)
            .is_err());

        tampered_proof.l_vector = l_vector;
        tampered_proof.r_vector = Vec::new();
        let mut transcript = Transcript::new(b"innerproducttest");
        assert!(tampered_proof
            .verify(&mut transcript, &a_commit, &product, &b, generators_offset)
            .is_err());
    }
}
/// Length-one vectors: the degenerate case with zero folding rounds, checked at
/// two generator offsets.
#[test]
fn test_prove_and_verify_with_a_single_element() {
    for generators_offset in [0, 1] {
        test_prove_and_verify_with_given_n_and_generators_offset(1, generators_offset);
    }
}
/// Length-two vectors: the smallest case with a folding round, checked at
/// generator offsets 0 and 3.
#[test]
fn test_prove_and_verify_with_two_elements() {
    for generators_offset in [0, 3] {
        test_prove_and_verify_with_given_n_and_generators_offset(2, generators_offset);
    }
}
/// Sweeps lengths 3 through 15, covering both powers of two and non-powers of two.
/// Each length is checked at generator offset 0 and at an offset equal to the
/// length itself.
#[test]
fn test_prove_and_verify_random_proofs_of_varying_size() {
    (3_u64..16).for_each(|len| {
        test_prove_and_verify_with_given_n_and_generators_offset(len, 0);
        test_prove_and_verify_with_given_n_and_generators_offset(len, len);
    });
}