use alloc::{vec, vec::Vec};
use core::{iter::once, slice, slice::ChunksExact};
use curve25519_dalek::{
ristretto::CompressedRistretto,
traits::{Identity, MultiscalarMul, VartimeMultiscalarMul},
RistrettoPoint,
Scalar,
};
use itertools::{izip, Itertools};
use rand_core::CryptoRngCore;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use snafu::prelude::*;
use subtle::{ConditionallySelectable, ConstantTimeEq};
use zeroize::Zeroizing;
use crate::{
gray::GrayIterator,
transcript::ProofTranscript,
util::{delta, NullRng, OperationTiming},
Transcript,
TriptychStatement,
TriptychWitness,
};
/// Size in bytes of an encoded `Scalar` or `CompressedRistretto`, as used by `to_bytes`/`from_bytes`.
const SERIALIZED_BYTES: usize = 32;
/// A Triptych proof for a `TriptychStatement`.
///
/// Field order matters: it determines the `serde` encoding (when the `serde` feature is enabled) and the
/// derived `PartialEq`/`Debug` output, so fields must not be reordered.
#[allow(non_snake_case)]
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TriptychProof {
/// Commitment to the prover's random matrix `a`, blinded by `r_A`.
A: RistrettoPoint,
/// Commitment to the one-hot digit matrix `sigma` of the signer index, blinded by `r_B`.
B: RistrettoPoint,
/// Commitment to the matrix `a * (1 - 2 * sigma)`, blinded by `r_C`.
C: RistrettoPoint,
/// Commitment to the matrix `-a * a`, blinded by `r_D`.
D: RistrettoPoint,
/// `m` points combining the input set with polynomial coefficients, each masked by `rho[j] * G`.
X: Vec<RistrettoPoint>,
/// `m` points `rho[j] * J`.
Y: Vec<RistrettoPoint>,
/// Response matrix of `m` rows with `n - 1` scalars each; the first column is omitted and recomputed by the
/// verifier.
f: Vec<Vec<Scalar>>,
/// Response for the `A`/`B` commitments: `r_A + xi * r_B`.
z_A: Scalar,
/// Response for the `C`/`D` commitments: `xi * r_C + r_D`.
z_C: Scalar,
/// Response binding the witness scalar: `r * xi^m - sum_j rho[j] * xi^j`.
z: Scalar,
}
/// Errors that can arise when creating, verifying, or (de)serializing a [`TriptychProof`].
#[derive(Debug, Snafu)]
pub enum ProofError {
    /// An invalid parameter was provided (mismatched parameters, lengths, or a witness that does not
    /// match its statement).
    #[snafu(display("An invalid parameter was provided"))]
    InvalidParameter,
    /// A transcript challenge was invalid.
    #[snafu(display("A transcript challenge was invalid"))]
    InvalidChallenge,
    /// Proof deserialization failed.
    #[snafu(display("Proof deserialization failed"))]
    FailedDeserialization,
    /// Proof verification failed.
    #[snafu(display("Single proof verification failed"))]
    FailedVerification,
    /// Batch proof verification failed.
    #[snafu(display("Batch proof verification failed"))]
    FailedBatchVerification,
    /// Batch proof verification failed, and a search was made for one failing proof.
    #[snafu(display("Batch proof verification failed"))]
    FailedBatchVerificationWithSingleBlame {
        /// Index of a failing proof within the batch, or `None` if none could be identified.
        index: Option<usize>,
    },
    /// Batch proof verification failed, and every proof was checked individually.
    #[snafu(display("Batch proof verification failed"))]
    FailedBatchVerificationWithFullBlame {
        /// Indexes of all proofs within the batch that failed individual verification.
        indexes: Vec<usize>,
    },
}
impl TriptychProof {
    /// Generate a proof using the operating system RNG, allowing variable-time operations.
    ///
    /// Variable-time mode indexes the input set directly by the signer index and uses variable-time
    /// decomposition and multiscalar multiplication, so it may leak timing information about the witness.
    ///
    /// # Errors
    ///
    /// See [`TriptychProof::prove`].
    #[cfg(feature = "rand")]
    pub fn prove_vartime(
        witness: &TriptychWitness,
        statement: &TriptychStatement,
        transcript: &mut Transcript,
    ) -> Result<Self, ProofError> {
        use rand_core::OsRng;

        Self::prove_internal(witness, statement, &mut OsRng, transcript, OperationTiming::Variable)
    }

    /// Generate a proof using the supplied cryptographically secure RNG, allowing variable-time operations.
    ///
    /// # Errors
    ///
    /// See [`TriptychProof::prove`].
    pub fn prove_with_rng_vartime<R: CryptoRngCore>(
        witness: &TriptychWitness,
        statement: &TriptychStatement,
        rng: &mut R,
        transcript: &mut Transcript,
    ) -> Result<Self, ProofError> {
        Self::prove_internal(witness, statement, rng, transcript, OperationTiming::Variable)
    }

    /// Generate a proof using the operating system RNG, requesting constant-time operations.
    ///
    /// # Errors
    ///
    /// Returns [`ProofError::InvalidParameter`] if the witness and statement use different parameters or the
    /// witness does not open the statement; errors from the transcript are propagated.
    #[cfg(feature = "rand")]
    pub fn prove(
        witness: &TriptychWitness,
        statement: &TriptychStatement,
        transcript: &mut Transcript,
    ) -> Result<Self, ProofError> {
        use rand_core::OsRng;

        Self::prove_internal(witness, statement, &mut OsRng, transcript, OperationTiming::Constant)
    }

    /// Generate a proof using the supplied cryptographically secure RNG, requesting constant-time operations.
    ///
    /// # Errors
    ///
    /// See [`TriptychProof::prove`].
    pub fn prove_with_rng<R: CryptoRngCore>(
        witness: &TriptychWitness,
        statement: &TriptychStatement,
        rng: &mut R,
        transcript: &mut Transcript,
    ) -> Result<Self, ProofError> {
        Self::prove_internal(witness, statement, rng, transcript, OperationTiming::Constant)
    }

    /// Shared implementation of every `prove*` entry point.
    ///
    /// `timing` selects between constant-time operations (conditional selection over the input set,
    /// `GrayIterator::decompose`, `multiscalar_mul`) and their variable-time counterparts.
    ///
    /// # Errors
    ///
    /// Returns [`ProofError::InvalidParameter`] if the witness and statement parameters differ, if
    /// `M[l] != r * G` or `r * J != U`, or if a parameter-dependent helper rejects the parameters. Errors
    /// from the transcript commitment phase are propagated unchanged.
    #[allow(clippy::too_many_lines, non_snake_case)]
    fn prove_internal<R: CryptoRngCore>(
        witness: &TriptychWitness,
        statement: &TriptychStatement,
        rng: &mut R,
        transcript: &mut Transcript,
        timing: OperationTiming,
    ) -> Result<Self, ProofError> {
        if witness.get_params() != statement.get_params() {
            return Err(ProofError::InvalidParameter);
        }

        let r = witness.get_r();
        let l = witness.get_l();
        let M = statement.get_input_set().get_keys();
        let params = statement.get_params();
        let J = statement.get_J();

        // Fetch `M[l]`; in constant-time mode every entry is touched so the access pattern does not reveal `l`
        let mut M_l = RistrettoPoint::identity();
        match timing {
            OperationTiming::Constant => {
                for (index, item) in M.iter().enumerate() {
                    M_l.conditional_assign(item, index.ct_eq(&(l as usize)));
                }
            },
            OperationTiming::Variable => {
                M_l = M[l as usize];
            },
        }

        // The witness must actually open the statement
        if M_l != r * params.get_G() {
            return Err(ProofError::InvalidParameter);
        }
        if &(r * J) != params.get_U() {
            return Err(ProofError::InvalidParameter);
        }

        let mut transcript = ProofTranscript::new(transcript, statement, rng, Some(witness));

        // Commitment `A` to a random matrix whose rows each sum to zero; the verifier relies on this when it
        // reconstructs the omitted first column of `f`
        let r_A = Scalar::random(transcript.as_mut_rng());
        let mut a = (0..params.get_m())
            .map(|_| {
                (0..params.get_n())
                    .map(|_| Scalar::random(transcript.as_mut_rng()))
                    .collect::<Vec<Scalar>>()
            })
            .collect::<Vec<Vec<Scalar>>>();
        for j in (0..params.get_m()).map(|j| j as usize) {
            a[j][0] = -a[j][1..].iter().sum::<Scalar>();
        }
        let A = params
            .commit_matrix(&a, &r_A, timing)
            .map_err(|_| ProofError::InvalidParameter)?;

        // Commitment `B` to the one-hot encoding `sigma` of the base-`n` digits of `l`
        let r_B = Scalar::random(transcript.as_mut_rng());
        let l_decomposed = match timing {
            OperationTiming::Constant => {
                GrayIterator::decompose(params.get_n(), params.get_m(), l).ok_or(ProofError::InvalidParameter)?
            },
            OperationTiming::Variable => GrayIterator::decompose_vartime(params.get_n(), params.get_m(), l)
                .ok_or(ProofError::InvalidParameter)?,
        };
        let sigma = (0..params.get_m())
            .map(|j| {
                (0..params.get_n())
                    .map(|i| delta(l_decomposed[j as usize], i, timing))
                    .collect::<Vec<Scalar>>()
            })
            .collect::<Vec<Vec<Scalar>>>();
        let B = params
            .commit_matrix(&sigma, &r_B, timing)
            .map_err(|_| ProofError::InvalidParameter)?;

        // Commitment `C` to `a * (1 - 2 * sigma)`
        let two = Scalar::from(2u32);
        let r_C = Scalar::random(transcript.as_mut_rng());
        let a_sigma = (0..params.get_m())
            .map(|j| {
                (0..params.get_n())
                    .map(|i| a[j as usize][i as usize] * (Scalar::ONE - two * sigma[j as usize][i as usize]))
                    .collect::<Vec<Scalar>>()
            })
            .collect::<Vec<Vec<Scalar>>>();
        let C = params
            .commit_matrix(&a_sigma, &r_C, timing)
            .map_err(|_| ProofError::InvalidParameter)?;

        // Commitment `D` to `-a * a`
        let r_D = Scalar::random(transcript.as_mut_rng());
        let a_square = (0..params.get_m())
            .map(|j| {
                (0..params.get_n())
                    .map(|i| -a[j as usize][i as usize] * a[j as usize][i as usize])
                    .collect::<Vec<Scalar>>()
            })
            .collect::<Vec<Vec<Scalar>>>();
        let D = params
            .commit_matrix(&a_square, &r_D, timing)
            .map_err(|_| ProofError::InvalidParameter)?;

        // Masks for `X` and `Y`; zeroized on drop
        let rho = Zeroizing::new(
            (0..params.get_m())
                .map(|_| Scalar::random(transcript.as_mut_rng()))
                .collect::<Vec<Scalar>>(),
        );

        // For each input-set index `k` (visited in Gray code order), compute the coefficients of the degree-`m`
        // polynomial `p_k(x) = prod_j (a[j][k_j] + sigma[j][k_j] * x)`, one linear factor per digit
        let mut p = Vec::<Vec<Scalar>>::with_capacity(params.get_N() as usize);
        let mut k_decomposed = vec![0; params.get_m() as usize];
        for (gray_index, _, gray_new) in
            GrayIterator::new(params.get_n(), params.get_m()).ok_or(ProofError::InvalidParameter)?
        {
            // Only one digit changes between consecutive Gray code values
            k_decomposed[gray_index] = gray_new;

            let mut coefficients = Vec::new();
            coefficients.resize(
                (params.get_m() as usize)
                    .checked_add(1)
                    .ok_or(ProofError::InvalidParameter)?,
                Scalar::ZERO,
            );
            coefficients[0] = a[0][k_decomposed[0] as usize];
            coefficients[1] = sigma[0][k_decomposed[0] as usize];

            for j in 1..params.get_m() {
                let degree_0_portion = coefficients
                    .iter()
                    .map(|c| a[j as usize][k_decomposed[j as usize] as usize] * c)
                    .collect::<Vec<Scalar>>();
                // Multiply by `x`: the highest coefficient is still zero here, so a right rotation is a shift
                let mut shifted_coefficients = coefficients.clone();
                shifted_coefficients.rotate_right(1);
                let degree_1_portion = shifted_coefficients
                    .iter()
                    .map(|c| sigma[j as usize][k_decomposed[j as usize] as usize] * c)
                    .collect::<Vec<Scalar>>();

                coefficients = degree_0_portion
                    .iter()
                    .zip(degree_1_portion.iter())
                    .map(|(x, y)| x + y)
                    .collect::<Vec<Scalar>>();
            }

            p.push(coefficients);
        }

        // `X[j] = sum_k p_k[j] * M[k] + rho[j] * G`
        let X = rho
            .iter()
            .enumerate()
            .map(|(j, rho)| {
                let X_points = M.iter().chain(once(params.get_G()));
                let X_scalars = p.iter().map(|p| &p[j]).chain(once(rho));

                match timing {
                    OperationTiming::Constant => RistrettoPoint::multiscalar_mul(X_scalars, X_points),
                    OperationTiming::Variable => RistrettoPoint::vartime_multiscalar_mul(X_scalars, X_points),
                }
            })
            .collect::<Vec<RistrettoPoint>>();

        // `Y[j] = rho[j] * J`
        let Y = rho.iter().map(|rho| rho * J).collect::<Vec<RistrettoPoint>>();

        // Fiat-Shamir challenge powers `xi^0, xi^1, ...`
        let xi_powers = transcript.commit(params, &A, &B, &C, &D, &X, &Y)?;

        // Responses; the first column of `f` is omitted since the verifier can recompute it
        let f = (0..params.get_m())
            .map(|j| {
                (1..params.get_n())
                    .map(|i| sigma[j as usize][i as usize] * xi_powers[1] + a[j as usize][i as usize])
                    .collect::<Vec<Scalar>>()
            })
            .collect::<Vec<Vec<Scalar>>>();

        let z_A = r_A + xi_powers[1] * r_B;
        let z_C = xi_powers[1] * r_C + r_D;

        // `z = r * xi^m - sum_j rho[j] * xi^j`; `zip` stops after the `m` entries of `rho`
        let z = r * xi_powers[params.get_m() as usize] -
            rho.iter()
                .zip(xi_powers.iter())
                .map(|(rho, xi_power)| rho * xi_power)
                .sum::<Scalar>();

        Ok(Self {
            A,
            B,
            C,
            D,
            X,
            Y,
            f,
            z_A,
            z_C,
            z,
        })
    }

    /// Verify a single proof against its statement and transcript.
    ///
    /// This is a batch verification of size one.
    ///
    /// # Errors
    ///
    /// Returns [`ProofError::FailedVerification`] if the proof is invalid, or another [`ProofError`] if the
    /// proof is malformed for the statement's parameters.
    pub fn verify(&self, statement: &TriptychStatement, transcript: &mut Transcript) -> Result<(), ProofError> {
        Self::verify_batch(
            slice::from_ref(statement),
            slice::from_ref(self),
            slice::from_mut(transcript),
        )
    }

    /// Verify a batch of proofs; if the batch fails, binary-search for one failing proof.
    ///
    /// All verification runs on clones of `transcripts`, so the caller's transcripts are not passed to the
    /// verifier by this function.
    ///
    /// # Errors
    ///
    /// Returns [`ProofError::InvalidParameter`] if the three slices differ in length. Otherwise, on failure
    /// returns [`ProofError::FailedBatchVerificationWithSingleBlame`] with the index of a failing proof, or
    /// `None` if the search could not identify one.
    pub fn verify_batch_with_single_blame(
        statements: &[TriptychStatement],
        proofs: &[TriptychProof],
        transcripts: &mut [Transcript],
    ) -> Result<(), ProofError> {
        // The search below slices all three inputs with the same bounds; mismatched lengths would cause an
        // out-of-bounds panic, so reject them up front
        if statements.len() != proofs.len() || statements.len() != transcripts.len() {
            return Err(ProofError::InvalidParameter);
        }

        // Try the whole batch first
        if Self::verify_batch(statements, proofs, &mut transcripts.to_vec()).is_ok() {
            return Ok(());
        }

        // Binary search over `[left, right)`, which always contains a failing proof
        let mut left = 0;
        let mut right = proofs.len();
        while left < right {
            // `right > left` is guaranteed by the loop condition, so the subtraction cannot underflow
            #[allow(clippy::arithmetic_side_effects)]
            let average = left
                .checked_add(
                    (right - left) / 2,
                )
                .ok_or(ProofError::FailedBatchVerificationWithSingleBlame { index: None })?;

            // Round the midpoint up so the left half is never empty
            #[allow(clippy::arithmetic_side_effects)]
            let mid = if (right - left) % 2 == 0 {
                average
            } else {
                average
                    .checked_add(1)
                    .ok_or(ProofError::FailedBatchVerificationWithSingleBlame { index: None })?
            };

            // Clone only the transcripts for the half being checked
            let failure_on_left = Self::verify_batch(
                &statements[left..mid],
                &proofs[left..mid],
                &mut transcripts[left..mid].to_vec(),
            )
            .is_err();

            if failure_on_left {
                let left_check = mid
                    .checked_sub(1)
                    .ok_or(ProofError::FailedBatchVerificationWithSingleBlame { index: None })?;
                if left == left_check {
                    return Err(ProofError::FailedBatchVerificationWithSingleBlame { index: Some(left) });
                }

                right = mid;
            } else {
                let right_check = mid
                    .checked_add(1)
                    .ok_or(ProofError::FailedBatchVerificationWithSingleBlame { index: None })?;
                if right == right_check {
                    let right_result = right
                        .checked_sub(1)
                        .ok_or(ProofError::FailedBatchVerificationWithSingleBlame { index: None })?;
                    return Err(ProofError::FailedBatchVerificationWithSingleBlame {
                        index: Some(right_result),
                    });
                }

                left = mid
            }
        }

        // Not expected to be reached for a non-empty failing batch
        Err(ProofError::FailedBatchVerificationWithSingleBlame { index: None })
    }

    /// Verify a batch of proofs; if the batch fails, verify every proof individually and report all failures.
    ///
    /// The initial batch check runs on cloned transcripts. If it fails, each proof is then verified against
    /// the caller's own transcript.
    ///
    /// # Errors
    ///
    /// Returns [`ProofError::InvalidParameter`] if the three slices differ in length. Otherwise, on failure
    /// returns [`ProofError::FailedBatchVerificationWithFullBlame`] listing every failing index.
    pub fn verify_batch_with_full_blame(
        statements: &[TriptychStatement],
        proofs: &[TriptychProof],
        transcripts: &mut [Transcript],
    ) -> Result<(), ProofError> {
        // `izip!` would silently truncate to the shortest input and report a misleading blame list
        if statements.len() != proofs.len() || statements.len() != transcripts.len() {
            return Err(ProofError::InvalidParameter);
        }

        if Self::verify_batch(statements, proofs, &mut transcripts.to_vec()).is_ok() {
            return Ok(());
        }

        let mut failures = Vec::with_capacity(proofs.len());
        for (index, (statement, proof, transcript)) in izip!(statements, proofs, transcripts.iter_mut()).enumerate() {
            if proof.verify(statement, transcript).is_err() {
                failures.push(index);
            }
        }

        Err(ProofError::FailedBatchVerificationWithFullBlame { indexes: failures })
    }

    /// Verify a batch of proofs with a single multiscalar multiplication.
    ///
    /// All statements must share the same input set and parameters. Each proof's equations are weighted by
    /// nonzero scalars drawn from a transcript over every proof's responses. An empty batch verifies.
    ///
    /// # Errors
    ///
    /// Returns [`ProofError::InvalidParameter`] for mismatched input lengths, statements with differing input
    /// sets or parameters, proofs of the wrong shape, or a zero entry in a reconstructed `f` matrix. Returns
    /// [`ProofError::FailedVerification`] if the combined check fails. Transcript errors are propagated.
    #[allow(clippy::too_many_lines, non_snake_case)]
    pub fn verify_batch(
        statements: &[TriptychStatement],
        proofs: &[TriptychProof],
        transcripts: &mut [Transcript],
    ) -> Result<(), ProofError> {
        if statements.len() != proofs.len() {
            return Err(ProofError::InvalidParameter);
        }
        if statements.len() != transcripts.len() {
            return Err(ProofError::InvalidParameter);
        }

        let first_statement = match statements.first() {
            Some(statement) => statement,
            None => return Ok(()),
        };

        // Batching relies on a shared input set and shared parameters
        if !statements.iter().map(|s| s.get_input_set().get_hash()).all_equal() {
            return Err(ProofError::InvalidParameter);
        }
        if !statements.iter().map(|s| s.get_params().get_hash()).all_equal() {
            return Err(ProofError::InvalidParameter);
        }

        let M = first_statement.get_input_set().get_keys();
        let params = first_statement.get_params();

        // Proof shapes must match the parameters before anything is indexed
        for proof in proofs {
            if proof.X.len() != params.get_m() as usize {
                return Err(ProofError::InvalidParameter);
            }
            if proof.Y.len() != params.get_m() as usize {
                return Err(ProofError::InvalidParameter);
            }
            if proof.f.len() != params.get_m() as usize {
                return Err(ProofError::InvalidParameter);
            }
            for f_row in &proof.f {
                if f_row.len() != params.get_n().checked_sub(1).ok_or(ProofError::InvalidParameter)? as usize {
                    return Err(ProofError::InvalidParameter);
                }
            }
        }

        // Capacity hint for the scalar vector; only used to size an allocation and unlikely to overflow for
        // realistic batch sizes
        let batch_size = u32::try_from(proofs.len()).map_err(|_| ProofError::InvalidParameter)?;
        #[allow(clippy::arithmetic_side_effects)]
        let final_size = usize::try_from(
            1 + params.get_n() * params.get_m() + 1 + params.get_N() + 1 + batch_size * (
                4 + 1 + 2 * params.get_m() ),
        )
        .map_err(|_| ProofError::InvalidParameter)?;

        // Point layout: per proof `A, B, C, D, J, X[..], Y[..]`, then `G, CommitmentG[..], CommitmentH, M[..], U`;
        // `scalars` is filled in exactly this order
        let points = proofs
            .iter()
            .zip(statements.iter())
            .flat_map(|(p, s)| {
                once(&p.A)
                    .chain(once(&p.B))
                    .chain(once(&p.C))
                    .chain(once(&p.D))
                    .chain(once(s.get_J()))
                    .chain(p.X.iter())
                    .chain(p.Y.iter())
            })
            .chain(once(params.get_G()))
            .chain(params.get_CommitmentG().iter())
            .chain(once(params.get_CommitmentH()))
            .chain(M.iter())
            .chain(once(params.get_U()))
            .collect::<Vec<&RistrettoPoint>>();

        // Scalars for the shared points accumulate across proofs
        let mut scalars = Vec::with_capacity(final_size);
        let mut G_scalar = Scalar::ZERO;
        let mut CommitmentG_scalars = vec![Scalar::ZERO; params.get_CommitmentG().len()];
        let mut CommitmentH_scalar = Scalar::ZERO;
        let mut M_scalars = vec![Scalar::ZERO; M.len()];
        let mut U_scalar = Scalar::ZERO;

        // Derive each proof's challenges and feed its responses into a transcript used to generate weights
        let mut transcript_weights = Transcript::new(b"Triptych verifier weights");
        let mut null_rng = NullRng;
        let mut xi_powers_all = Vec::with_capacity(proofs.len());
        for (statement, proof, transcript) in izip!(statements.iter(), proofs.iter(), transcripts.iter_mut()) {
            let mut transcript = ProofTranscript::new(transcript, statement, &mut null_rng, None);
            xi_powers_all.push(transcript.commit(params, &proof.A, &proof.B, &proof.C, &proof.D, &proof.X, &proof.Y)?);
            let mut transcript_rng = transcript.response(&proof.f, &proof.z_A, &proof.z_C, &proof.z);
            transcript_weights.append_u64(b"proof", transcript_rng.as_rngcore().next_u64());
        }
        let mut transcript_weights_rng = transcript_weights.build_rng().finalize(&mut null_rng);

        for (proof, xi_powers) in proofs.iter().zip(xi_powers_all.iter()) {
            // Rebuild the full `f` matrix; each row must sum to `xi`
            let f = (0..params.get_m())
                .map(|j| {
                    let mut f_j = Vec::with_capacity(params.get_n() as usize);
                    f_j.push(xi_powers[1] - proof.f[j as usize].iter().sum::<Scalar>());
                    f_j.extend(proof.f[j as usize].iter());
                    f_j
                })
                .collect::<Vec<Vec<Scalar>>>();

            // A zero entry would break the batch inversion below
            for f_row in &f {
                if f_row.contains(&Scalar::ZERO) {
                    return Err(ProofError::InvalidParameter);
                }
            }

            // Independent nonzero weights for this proof's four verification equations
            let mut w1 = Scalar::ZERO;
            let mut w2 = Scalar::ZERO;
            let mut w3 = Scalar::ZERO;
            let mut w4 = Scalar::ZERO;
            while w1 == Scalar::ZERO || w2 == Scalar::ZERO || w3 == Scalar::ZERO || w4 == Scalar::ZERO {
                w1 = Scalar::random(&mut transcript_weights_rng);
                w2 = Scalar::random(&mut transcript_weights_rng);
                w3 = Scalar::random(&mut transcript_weights_rng);
                w4 = Scalar::random(&mut transcript_weights_rng);
            }

            let xi = xi_powers[1];

            // G
            G_scalar -= w3 * proof.z;

            // CommitmentG: `w1` checks `A + xi * B`, `w2` checks `xi * C + D`
            for (CommitmentG_scalar, f_item) in CommitmentG_scalars
                .iter_mut()
                .zip(f.iter().flatten().map(|f| w1 * f + w2 * f * (xi - f)))
            {
                *CommitmentG_scalar += f_item;
            }

            // CommitmentH
            CommitmentH_scalar += w1 * proof.z_A + w2 * proof.z_C;

            // A, B, C, D, J
            scalars.push(-w1);
            scalars.push(-w1 * xi_powers[1]);
            scalars.push(-w2 * xi_powers[1]);
            scalars.push(-w2);
            scalars.push(-w4 * proof.z);

            // X
            for xi_power in &xi_powers[0..(params.get_m() as usize)] {
                scalars.push(-w3 * xi_power);
            }

            // Y
            for xi_power in &xi_powers[0..(params.get_m() as usize)] {
                scalars.push(-w4 * xi_power);
            }

            // Evaluate `prod_j f[j][k_j]` for every index `k`, updating one factor per Gray code step
            let mut f_product = f.iter().map(|f_row| f_row[0]).product::<Scalar>();
            let gray_iterator =
                GrayIterator::new(params.get_n(), params.get_m()).ok_or(ProofError::InvalidParameter)?;

            let mut f_inverse_flat = f.iter().flatten().copied().collect::<Vec<Scalar>>();
            Scalar::batch_invert(&mut f_inverse_flat);
            let f_inverse = f_inverse_flat
                .chunks_exact(params.get_n() as usize)
                .collect::<Vec<&[Scalar]>>();

            // M and U
            let mut U_scalar_proof = Scalar::ZERO;
            for (M_scalar, (gray_index, gray_old, gray_new)) in M_scalars.iter_mut().zip(gray_iterator) {
                f_product *= f_inverse[gray_index][gray_old as usize] * f[gray_index][gray_new as usize];
                *M_scalar += w3 * f_product;
                U_scalar_proof += f_product;
            }
            U_scalar += w4 * U_scalar_proof;
        }

        scalars.push(G_scalar);
        scalars.extend(CommitmentG_scalars);
        scalars.push(CommitmentH_scalar);
        scalars.extend(M_scalars);
        scalars.push(U_scalar);

        // Only public data is involved, so variable time is fine here
        if RistrettoPoint::vartime_multiscalar_mul(scalars.iter(), points) == RistrettoPoint::identity() {
            Ok(())
        } else {
            Err(ProofError::FailedVerification)
        }
    }

    /// Serialize the proof to bytes.
    ///
    /// Layout: `n - 1` and `m` as little-endian `u32`, then `A, B, C, D` (compressed), `z_A, z_C, z`,
    /// `X[..]`, `Y[..]`, and the rows of `f`, each element occupying `SERIALIZED_BYTES` bytes.
    #[allow(non_snake_case)]
    pub fn to_bytes(&self) -> Vec<u8> {
        // A proof built through the `serde` derive is not validated and may have an empty `f`; avoid
        // indexing `f[0]` directly so serialization cannot panic
        let n_minus_1_len = self.f.first().map_or(0, Vec::len);

        // Size the buffer from the actual element count rather than `rows * first_row_len`, which could
        // request an enormous allocation for ragged rows
        let f_len = self.f.iter().map(Vec::len).sum::<usize>();

        // The counts describe vectors already held in memory, so this cannot realistically overflow
        #[allow(clippy::arithmetic_side_effects)]
        let mut result = Vec::with_capacity(
            8 + SERIALIZED_BYTES * (
                4 + self.X.len()
                + self.Y.len()
                + 3 + f_len
            ),
        );

        // Dimensions originate from `u32` parameters for well-formed proofs
        #[allow(clippy::cast_possible_truncation)]
        let n_minus_1 = n_minus_1_len as u32;
        #[allow(clippy::cast_possible_truncation)]
        let m = self.f.len() as u32;
        result.extend(n_minus_1.to_le_bytes());
        result.extend(m.to_le_bytes());

        result.extend_from_slice(self.A.compress().as_bytes());
        result.extend_from_slice(self.B.compress().as_bytes());
        result.extend_from_slice(self.C.compress().as_bytes());
        result.extend_from_slice(self.D.compress().as_bytes());
        result.extend_from_slice(self.z_A.as_bytes());
        result.extend_from_slice(self.z_C.as_bytes());
        result.extend_from_slice(self.z.as_bytes());
        for X in &self.X {
            result.extend_from_slice(X.compress().as_bytes());
        }
        for Y in &self.Y {
            result.extend_from_slice(Y.compress().as_bytes());
        }
        for f_row in &self.f {
            for f in f_row {
                result.extend_from_slice(f.as_bytes());
            }
        }

        result
    }

    /// Deserialize a proof from the layout produced by [`TriptychProof::to_bytes`].
    ///
    /// # Errors
    ///
    /// Returns [`ProofError::FailedDeserialization`] if `n < 2` or `m < 2`, if any scalar is non-canonical or
    /// any point fails to decompress, if the input is truncated, or if trailing bytes remain.
    #[allow(non_snake_case)]
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, ProofError> {
        let parse_u32 = |iter: &mut dyn Iterator<Item = &u8>| {
            let bytes = iter.take(4).copied().collect::<Vec<u8>>();
            if bytes.len() != 4 {
                return Err(ProofError::FailedDeserialization);
            }
            let array: [u8; 4] = bytes.try_into().map_err(|_| ProofError::FailedDeserialization)?;

            Ok(u32::from_le_bytes(array))
        };

        // Scalars must be canonical encodings
        let parse_scalar = |chunks: &mut ChunksExact<'_, u8>| -> Result<Scalar, ProofError> {
            chunks
                .next()
                .ok_or(ProofError::FailedDeserialization)
                .and_then(|slice| {
                    let bytes: [u8; SERIALIZED_BYTES] =
                        slice.try_into().map_err(|_| ProofError::FailedDeserialization)?;
                    Option::<Scalar>::from(Scalar::from_canonical_bytes(bytes)).ok_or(ProofError::FailedDeserialization)
                })
        };

        // Points must decompress to valid Ristretto elements
        let parse_point = |chunks: &mut ChunksExact<'_, u8>| -> Result<RistrettoPoint, ProofError> {
            chunks
                .next()
                .ok_or(ProofError::FailedDeserialization)
                .and_then(|slice| {
                    let bytes: [u8; SERIALIZED_BYTES] =
                        slice.try_into().map_err(|_| ProofError::FailedDeserialization)?;
                    CompressedRistretto::from_slice(&bytes)
                        .map_err(|_| ProofError::FailedDeserialization)?
                        .decompress()
                        .ok_or(ProofError::FailedDeserialization)
                })
        };

        let mut iter = bytes.iter();

        // Require `n >= 2` (rejecting `n - 1 == u32::MAX`, which would overflow) and `m >= 2`
        let n_minus_1 = parse_u32(&mut iter)?;
        if n_minus_1.checked_add(1).ok_or(ProofError::FailedDeserialization)? < 2 {
            return Err(ProofError::FailedDeserialization);
        }
        let m = parse_u32(&mut iter)?;
        if m < 2 {
            return Err(ProofError::FailedDeserialization);
        }

        let mut chunks = iter.as_slice().chunks_exact(SERIALIZED_BYTES);
        let A = parse_point(&mut chunks)?;
        let B = parse_point(&mut chunks)?;
        let C = parse_point(&mut chunks)?;
        let D = parse_point(&mut chunks)?;
        let z_A = parse_scalar(&mut chunks)?;
        let z_C = parse_scalar(&mut chunks)?;
        let z = parse_scalar(&mut chunks)?;
        let X = (0..m)
            .map(|_| parse_point(&mut chunks))
            .collect::<Result<Vec<RistrettoPoint>, ProofError>>()?;
        let Y = (0..m)
            .map(|_| parse_point(&mut chunks))
            .collect::<Result<Vec<RistrettoPoint>, ProofError>>()?;
        let f = (0..m)
            .map(|_| {
                (0..n_minus_1)
                    .map(|_| parse_scalar(&mut chunks))
                    .collect::<Result<Vec<Scalar>, ProofError>>()
            })
            .collect::<Result<Vec<Vec<Scalar>>, ProofError>>()?;

        // No data may be left over, either as a partial chunk or as whole unread chunks
        if !chunks.remainder().is_empty() {
            return Err(ProofError::FailedDeserialization);
        }
        if chunks.next().is_some() {
            return Err(ProofError::FailedDeserialization);
        }

        // Defensive shape checks
        if X.len() != m as usize || Y.len() != m as usize {
            return Err(ProofError::FailedDeserialization);
        }
        if f.len() != m as usize {
            return Err(ProofError::FailedDeserialization);
        }
        for f_row in &f {
            if f_row.len() != n_minus_1 as usize {
                return Err(ProofError::FailedDeserialization);
            }
        }

        Ok(TriptychProof {
            A,
            B,
            C,
            D,
            X,
            Y,
            f,
            z_A,
            z_C,
            z,
        })
    }
}
#[cfg(test)]
mod test {
    use alloc::{sync::Arc, vec::Vec};

    use curve25519_dalek::{traits::Identity, RistrettoPoint, Scalar};
    use itertools::izip;
    use rand_chacha::ChaCha12Rng;
    use rand_core::{CryptoRngCore, SeedableRng};

    use crate::{
        proof::{ProofError, SERIALIZED_BYTES},
        Transcript,
        TriptychInputSet,
        TriptychParameters,
        TriptychProof,
        TriptychStatement,
        TriptychWitness,
    };

    /// `SERIALIZED_BYTES` must equal the encoded size of both scalars and compressed points.
    #[test]
    fn test_serialized_bytes() {
        assert_eq!(Scalar::ZERO.as_bytes().len(), SERIALIZED_BYTES);
        assert_eq!(RistrettoPoint::identity().compress().as_bytes().len(), SERIALIZED_BYTES);
    }

    /// Build `b` witnesses with consecutive indexes (mod `N`), a shared input set containing their
    /// verification keys, a statement per witness, and a distinct transcript per witness.
    #[allow(non_snake_case)]
    #[allow(clippy::arithmetic_side_effects)]
    fn generate_data<R: CryptoRngCore>(
        n: u32,
        m: u32,
        b: usize,
        rng: &mut R,
    ) -> (Vec<TriptychWitness>, Vec<TriptychStatement>, Vec<Transcript>) {
        let params = Arc::new(TriptychParameters::new(n, m).unwrap());

        // Every witness needs its own slot in the input set
        assert!(b <= params.get_N() as usize);

        let mut witnesses = Vec::with_capacity(b);
        witnesses.push(TriptychWitness::random(&params, rng));
        for _ in 1..b {
            let r = Scalar::random(rng);
            let l = (witnesses.last().unwrap().get_l() + 1) % params.get_N();
            witnesses.push(TriptychWitness::new(&params, l, &r).unwrap());
        }

        // Random input set, with each witness's key placed at its index
        let mut M = (0..params.get_N())
            .map(|_| RistrettoPoint::random(rng))
            .collect::<Vec<RistrettoPoint>>();
        for witness in &witnesses {
            M[witness.get_l() as usize] = witness.compute_verification_key();
        }
        let input_set = Arc::new(TriptychInputSet::new(&M).unwrap());

        let mut statements = Vec::with_capacity(b);
        for witness in &witnesses {
            let J = witness.compute_linking_tag();
            statements.push(TriptychStatement::new(&params, &input_set, &J).unwrap());
        }

        let transcripts = (0..b)
            .map(|i| {
                let mut transcript = Transcript::new(b"Test transcript");
                transcript.append_u64(b"index", i as u64);
                transcript
            })
            .collect::<Vec<Transcript>>();

        (witnesses, statements, transcripts)
    }

    #[test]
    #[cfg(feature = "rand")]
    #[allow(non_snake_case, non_upper_case_globals)]
    fn test_prove_verify() {
        const n: u32 = 2;
        const m: u32 = 4;
        let mut rng = ChaCha12Rng::seed_from_u64(8675309);

        let (witnesses, statements, mut transcripts) = generate_data(n, m, 1, &mut rng);

        let proof = TriptychProof::prove(&witnesses[0], &statements[0], &mut transcripts[0].clone()).unwrap();

        assert!(proof.verify(&statements[0], &mut transcripts[0]).is_ok());
    }

    #[test]
    #[allow(non_snake_case, non_upper_case_globals)]
    fn test_prove_verify_with_rng() {
        const n: u32 = 2;
        const m: u32 = 4;
        let mut rng = ChaCha12Rng::seed_from_u64(8675309);

        let (witnesses, statements, mut transcripts) = generate_data(n, m, 1, &mut rng);

        let proof = TriptychProof::prove_with_rng(&witnesses[0], &statements[0], &mut rng, &mut transcripts[0].clone())
            .unwrap();

        assert!(proof.verify(&statements[0], &mut transcripts[0]).is_ok());
    }

    #[test]
    #[cfg(feature = "rand")]
    #[allow(non_snake_case, non_upper_case_globals)]
    fn test_prove_verify_vartime() {
        const n: u32 = 2;
        const m: u32 = 4;
        let mut rng = ChaCha12Rng::seed_from_u64(8675309);

        let (witnesses, statements, mut transcripts) = generate_data(n, m, 1, &mut rng);

        let proof = TriptychProof::prove_vartime(&witnesses[0], &statements[0], &mut transcripts[0].clone()).unwrap();

        assert!(proof.verify(&statements[0], &mut transcripts[0]).is_ok());
    }

    #[test]
    #[allow(non_snake_case, non_upper_case_globals)]
    fn test_prove_verify_vartime_with_rng() {
        const n: u32 = 2;
        const m: u32 = 4;
        let mut rng = ChaCha12Rng::seed_from_u64(8675309);

        let (witnesses, statements, mut transcripts) = generate_data(n, m, 1, &mut rng);

        let proof =
            TriptychProof::prove_with_rng_vartime(&witnesses[0], &statements[0], &mut rng, &mut transcripts[0].clone())
                .unwrap();

        assert!(proof.verify(&statements[0], &mut transcripts[0]).is_ok());
    }

    /// A proof must round-trip through `to_bytes`/`from_bytes` unchanged.
    #[test]
    #[allow(non_snake_case, non_upper_case_globals)]
    fn test_serialize_deserialize() {
        const n: u32 = 2;
        const m: u32 = 4;
        let mut rng = ChaCha12Rng::seed_from_u64(8675309);

        let (witnesses, statements, mut transcripts) = generate_data(n, m, 1, &mut rng);

        let proof =
            TriptychProof::prove_with_rng_vartime(&witnesses[0], &statements[0], &mut rng, &mut transcripts[0].clone())
                .unwrap();
        assert!(proof.verify(&statements[0], &mut transcripts[0]).is_ok());

        let serialized = proof.to_bytes();
        let deserialized = TriptychProof::from_bytes(&serialized).unwrap();
        assert_eq!(deserialized, proof);
    }

    #[test]
    #[allow(non_snake_case, non_upper_case_globals)]
    fn test_prove_verify_batch() {
        const n: u32 = 2;
        const m: u32 = 4;
        const batch: usize = 3;
        let mut rng = ChaCha12Rng::seed_from_u64(8675309);

        let (witnesses, statements, mut transcripts) = generate_data(n, m, batch, &mut rng);

        let proofs = izip!(witnesses.iter(), statements.iter(), transcripts.clone().iter_mut())
            .map(|(w, s, t)| TriptychProof::prove_with_rng_vartime(w, s, &mut rng, t).unwrap())
            .collect::<Vec<TriptychProof>>();

        assert!(TriptychProof::verify_batch(&statements, &proofs, &mut transcripts.clone()).is_ok());
        assert!(TriptychProof::verify_batch_with_single_blame(&statements, &proofs, &mut transcripts.clone()).is_ok());
        assert!(TriptychProof::verify_batch_with_full_blame(&statements, &proofs, &mut transcripts).is_ok());
    }

    /// An empty batch verifies trivially.
    #[test]
    fn test_prove_verify_empty_batch() {
        assert!(TriptychProof::verify_batch(&[], &[], &mut []).is_ok());
        assert!(TriptychProof::verify_batch_with_single_blame(&[], &[], &mut []).is_ok());
        assert!(TriptychProof::verify_batch_with_full_blame(&[], &[], &mut []).is_ok());
    }

    #[test]
    #[allow(non_snake_case, non_upper_case_globals)]
    fn test_prove_verify_invalid_batch() {
        const n: u32 = 2;
        const m: u32 = 4;
        const batch: usize = 3;
        let mut rng = ChaCha12Rng::seed_from_u64(8675309);

        let (witnesses, statements, mut transcripts) = generate_data(n, m, batch, &mut rng);

        let proofs = izip!(witnesses.iter(), statements.iter(), transcripts.clone().iter_mut())
            .map(|(w, s, t)| TriptychProof::prove_with_rng_vartime(w, s, &mut rng, t).unwrap())
            .collect::<Vec<TriptychProof>>();

        // Corrupt one transcript so its proof no longer verifies
        transcripts[0] = Transcript::new(b"Evil transcript");

        assert!(TriptychProof::verify_batch(&statements, &proofs, &mut transcripts).is_err());
    }

    /// Single-blame verification must identify the one corrupted proof, for even and odd batch sizes.
    #[test]
    #[allow(non_snake_case, non_upper_case_globals)]
    fn test_prove_verify_invalid_batch_single_blame() {
        const n: u32 = 2;
        const m: u32 = 4;

        for batch in [4, 5] {
            let mut rng = ChaCha12Rng::seed_from_u64(8675309);

            let (witnesses, statements, transcripts) = generate_data(n, m, batch, &mut rng);

            let proofs = izip!(witnesses.iter(), statements.iter(), transcripts.clone().iter_mut())
                .map(|(w, s, t)| TriptychProof::prove_with_rng_vartime(w, s, &mut rng, t).unwrap())
                .collect::<Vec<TriptychProof>>();

            for i in 0..proofs.len() {
                let mut evil_transcripts = transcripts.clone();
                evil_transcripts[i] = Transcript::new(b"Evil transcript");

                let error = TriptychProof::verify_batch_with_single_blame(&statements, &proofs, &mut evil_transcripts)
                    .unwrap_err();
                if let ProofError::FailedBatchVerificationWithSingleBlame { index: Some(index) } = error {
                    assert_eq!(index, i);
                } else {
                    panic!();
                }
            }
        }
    }

    /// Full-blame verification must report exactly the corrupted proofs.
    #[test]
    #[allow(non_snake_case, non_upper_case_globals)]
    fn test_prove_verify_invalid_batch_full_blame() {
        const n: u32 = 2;
        const m: u32 = 4;
        const batch: usize = 4;
        const failures: [usize; 2] = [1, 3];
        let mut rng = ChaCha12Rng::seed_from_u64(8675309);

        let (witnesses, statements, mut transcripts) = generate_data(n, m, batch, &mut rng);

        let proofs = izip!(witnesses.iter(), statements.iter(), transcripts.clone().iter_mut())
            .map(|(w, s, t)| TriptychProof::prove_with_rng_vartime(w, s, &mut rng, t).unwrap())
            .collect::<Vec<TriptychProof>>();

        for i in failures {
            transcripts[i] = Transcript::new(b"Evil transcript");
        }

        let error = TriptychProof::verify_batch_with_full_blame(&statements, &proofs, &mut transcripts).unwrap_err();
        if let ProofError::FailedBatchVerificationWithFullBlame { indexes } = error {
            assert_eq!(indexes, failures);
        } else {
            panic!();
        }
    }

    /// A proof must not verify against a different transcript.
    #[test]
    #[allow(non_snake_case, non_upper_case_globals)]
    fn test_evil_message() {
        const n: u32 = 2;
        const m: u32 = 4;
        let mut rng = ChaCha12Rng::seed_from_u64(8675309);

        let (witnesses, statements, mut transcripts) = generate_data(n, m, 1, &mut rng);

        let proof = TriptychProof::prove_with_rng_vartime(&witnesses[0], &statements[0], &mut rng, &mut transcripts[0])
            .unwrap();

        let mut evil_transcript = Transcript::new(b"Evil transcript");
        assert!(proof.verify(&statements[0], &mut evil_transcript).is_err());
    }

    /// A proof must not verify against an input set with a modified (non-signer) key.
    #[test]
    #[allow(non_snake_case, non_upper_case_globals)]
    fn test_evil_input_set() {
        const n: u32 = 2;
        const m: u32 = 4;
        let mut rng = ChaCha12Rng::seed_from_u64(8675309);

        let (witnesses, statements, mut transcripts) = generate_data(n, m, 1, &mut rng);

        let proof =
            TriptychProof::prove_with_rng_vartime(&witnesses[0], &statements[0], &mut rng, &mut transcripts[0].clone())
                .unwrap();

        let mut M = statements[0].get_input_set().get_keys().to_vec();
        let index = ((witnesses[0].get_l() + 1) % witnesses[0].get_params().get_N()) as usize;
        M[index] = RistrettoPoint::random(&mut rng);
        let evil_input_set = Arc::new(TriptychInputSet::new(&M).unwrap());
        let evil_statement =
            TriptychStatement::new(statements[0].get_params(), &evil_input_set, statements[0].get_J()).unwrap();

        assert!(proof.verify(&evil_statement, &mut transcripts[0]).is_err());
    }

    /// A proof must not verify against a different linking tag.
    #[test]
    #[allow(non_snake_case, non_upper_case_globals)]
    fn test_evil_linking_tag() {
        const n: u32 = 2;
        const m: u32 = 4;
        let mut rng = ChaCha12Rng::seed_from_u64(8675309);

        let (witnesses, statements, mut transcripts) = generate_data(n, m, 1, &mut rng);

        let proof =
            TriptychProof::prove_with_rng_vartime(&witnesses[0], &statements[0], &mut rng, &mut transcripts[0].clone())
                .unwrap();

        let evil_statement = TriptychStatement::new(
            statements[0].get_params(),
            statements[0].get_input_set(),
            &RistrettoPoint::random(&mut rng),
        )
        .unwrap();

        assert!(proof.verify(&evil_statement, &mut transcripts[0]).is_err());
    }
}