use alloc::{vec, vec::Vec};
use core::iter::once;
use blake3::Hasher;
use curve25519_dalek::{
constants::RISTRETTO_BASEPOINT_POINT,
traits::{MultiscalarMul, VartimeMultiscalarMul},
RistrettoPoint,
Scalar,
};
use snafu::prelude::*;
use crate::{util::OperationTiming, Transcript, TRANSCRIPT_HASH_BYTES};
/// Public parameters for Triptych: dimensions, group generators and a digest binding them together.
///
/// The base `n` and exponent `m` define `N = n^m` (see `get_N`). Both are at least 2, and `n^m` fits in a
/// `u32`; the constructors in this file enforce this. Instances are built with `TriptychParameters::new` or
/// `TriptychParameters::new_with_generators`.
// Field names (`G`, `U`, `CommitmentG`, `CommitmentH`) follow the uppercase notation used for group elements,
// hence the lint allowance.
#[allow(non_snake_case)]
#[derive(Clone, Eq, PartialEq)]
pub struct TriptychParameters {
    /// Base of the matrix dimensions; each matrix row committed by `commit_matrix` has `n` entries.
    n: u32,
    /// Exponent; matrices committed by `commit_matrix` have `m` rows.
    m: u32,
    /// Generator `G`: the Ristretto basepoint when built via `new`, otherwise caller-supplied.
    G: RistrettoPoint,
    /// Generator `U`: derived from the label `Triptych U` when built via `new`, otherwise caller-supplied.
    U: RistrettoPoint,
    /// `n * m` generators, derived from a BLAKE3 XOF keyed on `n` and `m`, paired with matrix entries in
    /// `commit_matrix`.
    CommitmentG: Vec<RistrettoPoint>,
    /// Generator paired with the blinding mask in `commit_matrix`.
    CommitmentH: RistrettoPoint,
    /// `TRANSCRIPT_HASH_BYTES`-byte transcript digest over the version, `n`, `m` and every generator above.
    hash: Vec<u8>,
}
/// Errors produced when constructing or using `TriptychParameters`.
#[derive(Debug, Snafu)]
pub enum ParameterError {
    /// An input failed validation. In this file, this happens when `n` or `m` is below 2, when `n^m`
    /// overflows a `u32`, or when a matrix passed to `commit_matrix` is not `m` rows of `n` entries.
    #[snafu(display("An invalid parameter was provided"))]
    InvalidParameter,
}
impl TriptychParameters {
const DOMAIN: &'static str = "Triptych parameters";
const VERSION: u64 = 0;
#[allow(non_snake_case)]
pub fn new(n: u32, m: u32) -> Result<Self, ParameterError> {
let G = RISTRETTO_BASEPOINT_POINT;
let mut U_bytes = [0u8; 64];
let mut hasher = Hasher::new();
hasher.update(b"Triptych U");
hasher.finalize_xof().fill(&mut U_bytes);
let U = RistrettoPoint::from_uniform_bytes(&U_bytes);
Self::new_with_generators(n, m, &G, &U)
}
#[allow(non_snake_case)]
pub fn new_with_generators(n: u32, m: u32, G: &RistrettoPoint, U: &RistrettoPoint) -> Result<Self, ParameterError> {
if n < 2 || m < 2 {
return Err(ParameterError::InvalidParameter);
}
if n.checked_pow(m).is_none() {
return Err(ParameterError::InvalidParameter);
}
let mut CommitmentH_bytes = [0u8; 64];
let mut hasher = Hasher::new();
hasher.update(b"Triptych CommitmentH");
hasher.finalize_xof().fill(&mut CommitmentH_bytes);
let CommitmentH = RistrettoPoint::from_uniform_bytes(&CommitmentH_bytes);
let mut hasher = Hasher::new();
hasher.update(b"Triptych CommitmentG");
hasher.update(&n.to_le_bytes());
hasher.update(&m.to_le_bytes());
let mut hasher_xof = hasher.finalize_xof();
let mut CommitmentG_bytes = [0u8; 64];
let CommitmentG = (0..n.checked_mul(m).ok_or(ParameterError::InvalidParameter)?)
.map(|_| {
hasher_xof.fill(&mut CommitmentG_bytes);
RistrettoPoint::from_uniform_bytes(&CommitmentG_bytes)
})
.collect::<Vec<RistrettoPoint>>();
let mut transcript = Transcript::new(Self::DOMAIN.as_bytes());
transcript.append_u64(b"version", Self::VERSION);
transcript.append_message(b"n", &n.to_le_bytes());
transcript.append_message(b"m", &m.to_le_bytes());
transcript.append_message(b"G", G.compress().as_bytes());
transcript.append_message(b"U", U.compress().as_bytes());
for item in &CommitmentG {
transcript.append_message(b"CommitmentG", item.compress().as_bytes());
}
transcript.append_message(b"CommitmentH", CommitmentH.compress().as_bytes());
let mut hash = vec![0u8; TRANSCRIPT_HASH_BYTES];
transcript.challenge_bytes(b"hash", &mut hash);
Ok(TriptychParameters {
n,
m,
G: *G,
U: *U,
CommitmentG,
CommitmentH,
hash,
})
}
pub(crate) fn commit_matrix(
&self,
matrix: &[Vec<Scalar>],
mask: &Scalar,
timing: OperationTiming,
) -> Result<RistrettoPoint, ParameterError> {
if matrix.len() != (self.m as usize) || matrix.iter().any(|m| m.len() != (self.n as usize)) {
return Err(ParameterError::InvalidParameter);
}
let scalars = matrix.iter().flatten().chain(once(mask)).collect::<Vec<&Scalar>>();
let points = self.get_CommitmentG().iter().chain(once(self.get_CommitmentH()));
match timing {
OperationTiming::Constant => Ok(RistrettoPoint::multiscalar_mul(scalars, points)),
OperationTiming::Variable => Ok(RistrettoPoint::vartime_multiscalar_mul(scalars, points)),
}
}
#[allow(non_snake_case)]
pub fn get_G(&self) -> &RistrettoPoint {
&self.G
}
#[allow(non_snake_case)]
pub fn get_U(&self) -> &RistrettoPoint {
&self.U
}
pub fn get_n(&self) -> u32 {
self.n
}
pub fn get_m(&self) -> u32 {
self.m
}
#[allow(non_snake_case)]
pub fn get_N(&self) -> u32 {
self.n.pow(self.m)
}
#[allow(non_snake_case)]
pub(crate) fn get_CommitmentG(&self) -> &Vec<RistrettoPoint> {
&self.CommitmentG
}
#[allow(non_snake_case)]
pub(crate) fn get_CommitmentH(&self) -> &RistrettoPoint {
&self.CommitmentH
}
pub(crate) fn get_hash(&self) -> &[u8] {
&self.hash
}
}