use alloc::{sync::Arc, vec, vec::Vec};
use curve25519_dalek::{traits::Identity, RistrettoPoint};
use snafu::prelude::*;
use crate::{Transcript, TriptychParameters, TRANSCRIPT_HASH_BYTES};
/// An input set of verification keys `M`, together with a transcript hash that binds the keys and the
/// unpadded size of the set.
#[allow(non_snake_case)] // permits the uppercase field name `M`
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TriptychInputSet {
    M: Vec<RistrettoPoint>,
    hash: Vec<u8>,
}

impl TriptychInputSet {
    // Domain separator and version used when hashing an input set.
    const DOMAIN: &'static str = "Triptych input set";
    const VERSION: u64 = 0;

    /// Generate a new input set from the verification keys `M`, without padding.
    ///
    /// The full length of `M` is recorded as the unpadded size.
    ///
    /// # Errors
    ///
    /// Returns [`StatementError::InvalidParameter`] if `M.len()` does not fit in a `u32`.
    #[allow(non_snake_case)]
    pub fn new(M: &[RistrettoPoint]) -> Result<Self, StatementError> {
        Self::new_internal(M.to_vec(), M.len())
    }

    /// Generate a new input set from `M`, padded up to the size `params.get_N()`.
    ///
    /// Padding repeats the last key of `M`. The unpadded size is bound into the hash, so the result hashes
    /// differently from an unpadded input set containing the same (already padded) keys.
    ///
    /// # Errors
    ///
    /// Returns [`StatementError::InvalidParameter`] if `M` has more than `params.get_N()` keys, if `M` is empty,
    /// or if `M.len()` does not fit in a `u32`.
    #[allow(non_snake_case)]
    pub fn new_with_padding(M: &[RistrettoPoint], params: &TriptychParameters) -> Result<Self, StatementError> {
        let unpadded_size = M.len();
        let padded_size = params.get_N() as usize;
        if unpadded_size > padded_size {
            return Err(StatementError::InvalidParameter);
        }

        // Padding repeats the last key, so an empty input has nothing to pad with
        let last = M.last().ok_or(StatementError::InvalidParameter)?;

        // Allocate the padded vector once at its final size and hand it off without a further copy
        let mut M_padded = Vec::with_capacity(padded_size);
        M_padded.extend_from_slice(M);
        M_padded.resize(padded_size, *last);

        Self::new_internal(M_padded, unpadded_size)
    }

    /// Hash the (possibly padded) keys `M` together with `unpadded_size` and build the input set.
    ///
    /// Takes ownership of `M` so that callers which already own a vector do not pay for another copy.
    #[allow(non_snake_case)]
    fn new_internal(M: Vec<RistrettoPoint>, unpadded_size: usize) -> Result<Self, StatementError> {
        let unpadded_size = u32::try_from(unpadded_size).map_err(|_| StatementError::InvalidParameter)?;

        let mut transcript = Transcript::new(Self::DOMAIN.as_bytes());
        transcript.append_u64(b"version", Self::VERSION);
        transcript.append_message(b"unpadded_size", &unpadded_size.to_le_bytes());
        for item in &M {
            transcript.append_message(b"M", item.compress().as_bytes());
        }

        let mut hash = vec![0u8; TRANSCRIPT_HASH_BYTES];
        transcript.challenge_bytes(b"hash", &mut hash);

        Ok(Self { M, hash })
    }

    /// Get the verification keys of this input set, including any padding.
    pub fn get_keys(&self) -> &[RistrettoPoint] {
        &self.M
    }

    /// Get the transcript hash over the keys and the unpadded size.
    pub(crate) fn get_hash(&self) -> &[u8] {
        &self.hash
    }
}
/// A Triptych statement: shared parameters, a shared input set, and the point `J`, bound together by a
/// transcript hash.
#[derive(PartialEq, Eq, Clone)]
#[allow(non_snake_case)] // permits the uppercase field name `J`
pub struct TriptychStatement {
    /// Parameters this statement was built against.
    params: Arc<TriptychParameters>,
    /// Input set of verification keys.
    input_set: Arc<TriptychInputSet>,
    /// The point `J` (presumably the linking tag; confirm against callers).
    J: RistrettoPoint,
    /// Transcript hash over the parameter hash, the input set hash, and `J`.
    hash: Vec<u8>,
}
/// Errors returned when building an input set or a statement.
#[derive(Snafu, Debug)]
pub enum StatementError {
    /// A supplied parameter failed validation.
    #[snafu(display("An invalid parameter was provided"))]
    InvalidParameter,
}
impl TriptychStatement {
    // Domain separator and version used when hashing a statement.
    const DOMAIN: &'static str = "Triptych statement";
    const VERSION: u64 = 0;

    /// Build a statement from parameters `params`, an input set `input_set`, and the point `J`.
    ///
    /// # Errors
    ///
    /// Returns [`StatementError::InvalidParameter`] if the number of keys in `input_set` is not
    /// `params.get_N()`, or if any key in `input_set` is the identity point.
    #[allow(non_snake_case)]
    pub fn new(
        params: &Arc<TriptychParameters>,
        input_set: &Arc<TriptychInputSet>,
        J: &RistrettoPoint,
    ) -> Result<Self, StatementError> {
        let keys = input_set.get_keys();
        let size_matches = keys.len() == params.get_N() as usize;

        // The size check short-circuits, so the identity scan only runs on correctly-sized sets
        if !size_matches || keys.contains(&RistrettoPoint::identity()) {
            return Err(StatementError::InvalidParameter);
        }

        // Bind the parameter hash, the input set hash, and `J` into a single digest
        let hash = {
            let mut transcript = Transcript::new(Self::DOMAIN.as_bytes());
            transcript.append_u64(b"version", Self::VERSION);
            transcript.append_message(b"params", params.get_hash());
            transcript.append_message(b"input_set", input_set.get_hash());
            transcript.append_message(b"J", J.compress().as_bytes());

            let mut digest = vec![0u8; TRANSCRIPT_HASH_BYTES];
            transcript.challenge_bytes(b"hash", &mut digest);
            digest
        };

        Ok(Self {
            params: Arc::clone(params),
            input_set: Arc::clone(input_set),
            J: *J,
            hash,
        })
    }

    /// Get the parameters this statement was built against.
    pub fn get_params(&self) -> &Arc<TriptychParameters> {
        &self.params
    }

    /// Get the input set of this statement.
    pub fn get_input_set(&self) -> &Arc<TriptychInputSet> {
        &self.input_set
    }

    /// Get the point `J` of this statement.
    #[allow(non_snake_case)]
    pub fn get_J(&self) -> &RistrettoPoint {
        &self.J
    }

    /// Get the transcript hash of this statement.
    pub(crate) fn get_hash(&self) -> &[u8] {
        &self.hash
    }
}
#[cfg(test)]
mod test {
    use alloc::{borrow::ToOwned, vec::Vec};

    use curve25519_dalek::RistrettoPoint;
    use rand_chacha::ChaCha12Rng;
    use rand_core::SeedableRng;

    use crate::{TriptychInputSet, TriptychParameters};

    // Generate a deterministic vector of random points; the fixed seed means vectors of different sizes share a prefix
    fn random_vector(size: usize) -> Vec<RistrettoPoint> {
        let mut rng = ChaCha12Rng::seed_from_u64(8675309);
        (0..size)
            .map(|_| RistrettoPoint::random(&mut rng))
            .collect::<Vec<RistrettoPoint>>()
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_padding() {
        let params = TriptychParameters::new(2, 4).unwrap();
        let N = params.get_N() as usize;

        // An empty input set has no key to pad with
        assert!(TriptychInputSet::new_with_padding(&[], &params).is_err());

        // An input set larger than `N` cannot be padded down
        let M = random_vector(N + 1);
        assert!(TriptychInputSet::new_with_padding(&M, &params).is_err());

        // An input set of exactly `N` keys needs no padding and matches the unpadded constructor
        let M = random_vector(N);
        assert_eq!(
            TriptychInputSet::new_with_padding(&M, &params).unwrap(),
            TriptychInputSet::new(&M).unwrap()
        );

        // Padding repeats the last key, but the unpadded size is bound into the hash
        let M = random_vector(N - 1);
        let mut M_padded = M.clone();
        M_padded.push(M.last().unwrap().to_owned());
        let padded = TriptychInputSet::new_with_padding(&M, &params).unwrap();
        let unpadded = TriptychInputSet::new(&M_padded).unwrap();
        assert_eq!(padded.get_keys().len(), N);
        assert_eq!(padded.get_keys(), unpadded.get_keys());
        assert_ne!(padded.get_hash(), unpadded.get_hash());
    }
}