use alloc::{sync::Arc, vec, vec::Vec};
use curve25519_dalek::{traits::Identity, RistrettoPoint};
use snafu::prelude::*;
use crate::{parallel::TriptychParameters, Transcript, TRANSCRIPT_HASH_BYTES};
/// A set of keys and matching auxiliary keys, together with a hash that binds them.
///
/// Instances are only built through [`TriptychInputSet::new`] or [`TriptychInputSet::new_with_padding`], both of
/// which require the two key vectors to have the same length.
#[allow(non_snake_case)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TriptychInputSet {
/// The keys (after any padding has been applied)
M: Vec<RistrettoPoint>,
/// The auxiliary keys (after any padding has been applied); always the same length as `M`
M1: Vec<RistrettoPoint>,
/// Transcript-derived hash over the version, the unpadded size, and every element of `M` and `M1`
hash: Vec<u8>,
}
impl TriptychInputSet {
    // Domain separator used to initialize the hashing transcript
    const DOMAIN: &'static str = "Parallel Triptych input set";
    // Version identifier bound into the hash
    const VERSION: u64 = 0;

    /// Generate a new [`TriptychInputSet`] from a slice `M` of keys and a slice `M1` of auxiliary keys.
    ///
    /// No padding is applied; the unpadded size bound into the hash is `M.len()`.
    ///
    /// # Errors
    /// Returns [`StatementError::InvalidParameter`] if `M` and `M1` have different lengths, or if that length does
    /// not fit in a `u32`.
    #[allow(non_snake_case)]
    pub fn new(M: &[RistrettoPoint], M1: &[RistrettoPoint]) -> Result<Self, StatementError> {
        if M.len() != M1.len() {
            return Err(StatementError::InvalidParameter);
        }

        Self::new_internal(M.to_vec(), M1.to_vec(), M.len())
    }

    /// Generate a new padded [`TriptychInputSet`] from a slice `M` of keys and a slice `M1` of auxiliary keys.
    ///
    /// Both vectors are extended to `params.get_N()` elements by repeating their respective final element.
    /// The original (unpadded) length is bound into the hash, so a padded set hashes differently from an unpadded
    /// set that happens to contain the same elements.
    ///
    /// # Errors
    /// Returns [`StatementError::InvalidParameter`] if `M` and `M1` have different lengths, if they are empty, or if
    /// they contain more than `params.get_N()` elements.
    #[allow(non_snake_case)]
    pub fn new_with_padding(
        M: &[RistrettoPoint],
        M1: &[RistrettoPoint],
        params: &TriptychParameters,
    ) -> Result<Self, StatementError> {
        if M.len() != M1.len() {
            return Err(StatementError::InvalidParameter);
        }

        // Padding can only grow the vectors, never shrink them
        let unpadded_size = M.len();
        let padded_size = params.get_N() as usize;
        if unpadded_size > padded_size {
            return Err(StatementError::InvalidParameter);
        }

        // The final elements are the padding values; an empty input has none, so it is rejected
        let last = *M.last().ok_or(StatementError::InvalidParameter)?;
        let last1 = *M1.last().ok_or(StatementError::InvalidParameter)?;

        // The padded vectors are handed over by value so they are not copied a second time
        Self::new_internal(
            Self::pad(M, last, padded_size),
            Self::pad(M1, last1, padded_size),
            unpadded_size,
        )
    }

    /// Copy `items` into a freshly allocated vector of length `size`, filling any extra slots with `filler`.
    fn pad(items: &[RistrettoPoint], filler: RistrettoPoint, size: usize) -> Vec<RistrettoPoint> {
        let mut padded = Vec::with_capacity(size);
        padded.extend_from_slice(items);
        padded.resize(size, filler);

        padded
    }

    /// Compute the binding hash and assemble the input set, taking ownership of the (possibly padded) vectors.
    ///
    /// Callers are responsible for ensuring `M` and `M1` have equal lengths.
    ///
    /// # Errors
    /// Returns [`StatementError::InvalidParameter`] if `unpadded_size` does not fit in a `u32`.
    #[allow(non_snake_case)]
    fn new_internal(
        M: Vec<RistrettoPoint>,
        M1: Vec<RistrettoPoint>,
        unpadded_size: usize,
    ) -> Result<Self, StatementError> {
        // The size is hashed as a fixed-width value, so it must fit in 32 bits
        let unpadded_size = u32::try_from(unpadded_size).map_err(|_| StatementError::InvalidParameter)?;

        // Bind the version, the unpadded size, and every (padded) element into the transcript
        let mut transcript = Transcript::new(Self::DOMAIN.as_bytes());
        transcript.append_u64(b"version", Self::VERSION);
        transcript.append_message(b"unpadded_size", &unpadded_size.to_le_bytes());
        for item in &M {
            transcript.append_message(b"M", item.compress().as_bytes());
        }
        for item in &M1 {
            transcript.append_message(b"M1", item.compress().as_bytes());
        }
        let mut hash = vec![0u8; TRANSCRIPT_HASH_BYTES];
        transcript.challenge_bytes(b"hash", &mut hash);

        Ok(Self { M, M1, hash })
    }

    /// Get the keys for this [`TriptychInputSet`], including any padding.
    pub fn get_keys(&self) -> &[RistrettoPoint] {
        &self.M
    }

    /// Get the auxiliary keys for this [`TriptychInputSet`], including any padding.
    pub fn get_auxiliary_keys(&self) -> &[RistrettoPoint] {
        &self.M1
    }

    /// Get the hash that binds the version, unpadded size, keys, and auxiliary keys of this [`TriptychInputSet`].
    pub(crate) fn get_hash(&self) -> &[u8] {
        &self.hash
    }
}
/// A statement composed of parameters, an input set, an offset, and a point `J`, together with a hash that binds
/// all of them.
///
/// Instances are only built through [`TriptychStatement::new`], which validates the input set against the
/// parameters and the offset.
#[allow(non_snake_case)]
#[derive(Clone, Eq, PartialEq)]
pub struct TriptychStatement {
/// Shared parameters the input set was validated against
params: Arc<TriptychParameters>,
/// Shared input set of keys and auxiliary keys
input_set: Arc<TriptychInputSet>,
/// Offset that no auxiliary key in the input set is allowed to equal
offset: RistrettoPoint,
/// The point `J`; it is bound into the hash but not otherwise validated at construction
J: RistrettoPoint,
/// Transcript-derived hash over the version, parameter hash, input set hash, offset, and `J`
hash: Vec<u8>,
}
/// Errors that can arise when constructing a [`TriptychInputSet`] or a [`TriptychStatement`].
#[derive(Debug, Snafu)]
pub enum StatementError {
/// An invalid parameter was provided.
#[snafu(display("An invalid parameter was provided"))]
InvalidParameter,
}
impl TriptychStatement {
const DOMAIN: &'static str = "Parallel Triptych statement";
const VERSION: u64 = 0;
#[allow(non_snake_case)]
pub fn new(
params: &Arc<TriptychParameters>,
input_set: &Arc<TriptychInputSet>,
offset: &RistrettoPoint,
J: &RistrettoPoint,
) -> Result<Self, StatementError> {
if input_set.get_keys().len() != params.get_N() as usize {
return Err(StatementError::InvalidParameter);
}
if input_set.get_keys().contains(&RistrettoPoint::identity()) {
return Err(StatementError::InvalidParameter);
}
if input_set
.get_auxiliary_keys()
.iter()
.map(|p| p - offset)
.collect::<Vec<RistrettoPoint>>()
.contains(&RistrettoPoint::identity())
{
return Err(StatementError::InvalidParameter);
}
let mut transcript = Transcript::new(Self::DOMAIN.as_bytes());
transcript.append_u64(b"version", Self::VERSION);
transcript.append_message(b"params", params.get_hash());
transcript.append_message(b"input_set", input_set.get_hash());
transcript.append_message(b"offset", offset.compress().as_bytes());
transcript.append_message(b"J", J.compress().as_bytes());
let mut hash = vec![0u8; TRANSCRIPT_HASH_BYTES];
transcript.challenge_bytes(b"hash", &mut hash);
Ok(Self {
params: params.clone(),
input_set: input_set.clone(),
offset: *offset,
J: *J,
hash,
})
}
pub fn get_params(&self) -> &Arc<TriptychParameters> {
&self.params
}
pub fn get_input_set(&self) -> &Arc<TriptychInputSet> {
&self.input_set
}
pub fn get_offset(&self) -> &RistrettoPoint {
&self.offset
}
#[allow(non_snake_case)]
pub fn get_J(&self) -> &RistrettoPoint {
&self.J
}
pub(crate) fn get_hash(&self) -> &[u8] {
&self.hash
}
}
#[cfg(test)]
mod test {
    use alloc::{borrow::ToOwned, vec::Vec};

    use curve25519_dalek::RistrettoPoint;
    use rand_chacha::ChaCha12Rng;
    use rand_core::SeedableRng;

    use crate::parallel::{TriptychInputSet, TriptychParameters};

    // Helper function to generate a vector of `size` pseudorandom points.
    //
    // The generator is reseeded with the same constant on every call, so the output is deterministic and two calls
    // with the same `size` return identical vectors.
    fn random_vector(size: usize) -> Vec<RistrettoPoint> {
        let mut rng = ChaCha12Rng::seed_from_u64(8675309);

        (0..size)
            .map(|_| RistrettoPoint::random(&mut rng))
            .collect::<Vec<RistrettoPoint>>()
    }

    // Exercise the padding constructor against the unpadded constructor
    #[test]
    #[allow(non_snake_case)]
    fn test_padding() {
        // Generate parameters
        let params = TriptychParameters::new(2, 4).unwrap();
        let N = params.get_N() as usize;

        // Vectors are empty, so there is nothing to pad with
        assert!(TriptychInputSet::new_with_padding(&[], &[], &params).is_err());

        // Vectors are too long
        let M = random_vector(N + 1);
        let M1 = random_vector(N + 1);
        assert!(TriptychInputSet::new_with_padding(&M, &M1, &params).is_err());

        // Vectors are the right size, so padding is a no-op and the full sets (hash included) match
        let M = random_vector(N);
        let M1 = random_vector(N);
        assert_eq!(
            TriptychInputSet::new_with_padding(&M, &M1, &params).unwrap(),
            TriptychInputSet::new(&M, &M1).unwrap()
        );

        // Vectors are one element short, so the final element is repeated once
        let M = random_vector(N - 1);
        let mut M_padded = M.clone();
        M_padded.push(M.last().unwrap().to_owned());
        let M1 = random_vector(N - 1);
        let mut M1_padded = M1.clone();
        M1_padded.push(M1.last().unwrap().to_owned());

        // The padded keys match the manually padded keys
        assert_eq!(
            TriptychInputSet::new_with_padding(&M, &M1, &params).unwrap().get_keys(),
            TriptychInputSet::new(&M_padded, &M1_padded).unwrap().get_keys()
        );

        // The padded auxiliary keys match the manually padded auxiliary keys
        assert_eq!(
            TriptychInputSet::new_with_padding(&M, &M1, &params)
                .unwrap()
                .get_auxiliary_keys(),
            TriptychInputSet::new(&M_padded, &M1_padded)
                .unwrap()
                .get_auxiliary_keys()
        );

        // The hashes differ because the unpadded size is bound into the hash
        assert_ne!(
            TriptychInputSet::new_with_padding(&M, &M1, &params).unwrap().get_hash(),
            TriptychInputSet::new(&M_padded, &M1_padded).unwrap().get_hash()
        );
    }
}