use ark_ff::PrimeField;
use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
use ark_std::{cfg_into_iter, collections::BTreeMap, rand::RngCore, vec, vec::Vec};
use digest::Digest;
use super::ParticipantId;
use crate::error::OTError;
use dock_crypto_utils::expect_equality;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Commitments to a batch of shares: one digest per share, in the same order as the shares
/// they commit to.
#[derive(Clone, Debug, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
pub struct Commitments(pub Vec<Vec<u8>>);
/// State of one participant in a commit-and-reveal coin toss that produces a batch of jointly
/// random field elements.
///
/// Each participant commits to random shares, later reveals them (with their salts), and the
/// joint randomness is the per-index sum of all participants' shares.
#[derive(Clone, Debug, PartialEq, CanonicalSerialize, CanonicalDeserialize)]
pub struct Party<F: PrimeField, const SALT_SIZE: usize> {
/// This participant's id.
pub id: ParticipantId,
/// Label prepended to every share before it is hashed into a commitment.
pub protocol_id: Vec<u8>,
/// This participant's random shares, each paired with the salt used in its commitment.
pub own_shares_and_salts: Vec<(F, [u8; SALT_SIZE])>,
/// Commitments received from other participants, keyed by sender id.
pub other_commitments: BTreeMap<ParticipantId, Commitments>,
/// Shares revealed by other participants (salts dropped), stored only after they matched the
/// sender's earlier commitments.
pub other_shares: BTreeMap<ParticipantId, Vec<F>>,
}
impl<F: PrimeField, const SALT_SIZE: usize> Party<F, SALT_SIZE> {
    /// Starts the protocol for participant `id` by sampling `batch_size` random shares, each
    /// with a fresh `SALT_SIZE`-byte salt, and committing to them.
    ///
    /// Each commitment is the `D` digest of `protocol_id || share || salt`, with the share
    /// serialized in compressed form. Returns the new party state together with the
    /// commitments to be sent to the other participants.
    pub fn commit<R: RngCore, D: Digest>(
        rng: &mut R,
        id: ParticipantId,
        batch_size: u32,
        protocol_id: Vec<u8>,
    ) -> (Self, Commitments) {
        let own_shares_and_salts: Vec<_> = (0..batch_size)
            .map(|_| {
                // The salt is drawn before the share; keep this order so that a seeded RNG
                // keeps producing the same shares and salts.
                let mut salt = [0u8; SALT_SIZE];
                rng.fill_bytes(&mut salt);
                (F::rand(rng), salt)
            })
            .collect();
        let commitments = Self::compute_commitments::<D>(&own_shares_and_salts, &protocol_id);
        (
            Self {
                id,
                protocol_id,
                own_shares_and_salts,
                other_commitments: Default::default(),
                other_shares: Default::default(),
            },
            Commitments(commitments),
        )
    }

    /// Records the commitments sent by participant `sender_id`.
    ///
    /// # Errors
    /// - `OTError::SenderIdCannotBeSameAsSelf` if `sender_id` is this party's own id.
    /// - `OTError::AlreadyHaveCommitmentFromParticipant` if commitments from `sender_id` were
    ///   already recorded.
    /// - `OTError::IncorrectNoOfCommitments` if the number of commitments differs from this
    ///   party's batch size.
    pub fn receive_commitment(
        &mut self,
        sender_id: ParticipantId,
        commitments: Commitments,
    ) -> Result<(), OTError> {
        if self.id == sender_id {
            return Err(OTError::SenderIdCannotBeSameAsSelf(sender_id, self.id));
        }
        if self.other_commitments.contains_key(&sender_id) {
            return Err(OTError::AlreadyHaveCommitmentFromParticipant(sender_id));
        }
        expect_equality!(
            self.own_shares_and_salts.len(),
            commitments.0.len(),
            OTError::IncorrectNoOfCommitments
        );
        self.other_commitments.insert(sender_id, commitments);
        Ok(())
    }

    /// Accepts the shares and salts revealed by participant `sender_id`.
    ///
    /// The shares are stored (without salts) only if they open the commitments previously
    /// received from the same sender. `D` must be the digest used when those commitments were
    /// created.
    ///
    /// # Errors
    /// - `OTError::SenderIdCannotBeSameAsSelf` if `sender_id` is this party's own id.
    /// - `OTError::MissingCommitmentFromParticipant` if no commitments were received from
    ///   `sender_id`.
    /// - `OTError::AlreadyHaveSharesFromParticipant` if shares from `sender_id` were already
    ///   accepted.
    /// - `OTError::IncorrectNoOfShares` if the number of shares differs from this party's batch
    ///   size.
    /// - `OTError::IncorrectCommitment` if the revealed values do not match the commitments.
    pub fn receive_shares<D: Digest>(
        &mut self,
        sender_id: ParticipantId,
        shares_and_salts: Vec<(F, [u8; SALT_SIZE])>,
    ) -> Result<(), OTError> {
        if self.id == sender_id {
            return Err(OTError::SenderIdCannotBeSameAsSelf(sender_id, self.id));
        }
        // Single lookup instead of `contains_key` followed by `get(..).unwrap()`.
        let commitments = self
            .other_commitments
            .get(&sender_id)
            .ok_or(OTError::MissingCommitmentFromParticipant(sender_id))?;
        if self.other_shares.contains_key(&sender_id) {
            return Err(OTError::AlreadyHaveSharesFromParticipant(sender_id));
        }
        expect_equality!(
            self.own_shares_and_salts.len(),
            shares_and_salts.len(),
            OTError::IncorrectNoOfShares
        );
        // Recompute the commitments from the revealed values; they must equal what the sender
        // committed to earlier.
        let expected_commitments =
            Self::compute_commitments::<D>(&shares_and_salts, &self.protocol_id);
        if expected_commitments != commitments.0 {
            return Err(OTError::IncorrectCommitment);
        }
        self.other_shares.insert(
            sender_id,
            shares_and_salts.into_iter().map(|(share, _)| share).collect(),
        );
        Ok(())
    }

    /// Consumes the party and returns the joint randomness. Element `i` is the sum of this
    /// party's `i`-th share and the `i`-th share of every participant whose shares were
    /// accepted.
    ///
    /// This does not check that shares were received from every participant that committed.
    /// Call [`Self::has_shares_from_all_who_committed`] first if that matters.
    pub fn compute_joint_randomness(self) -> Vec<F> {
        let Self {
            own_shares_and_salts,
            other_shares,
            ..
        } = self;
        // Every stored share vector has the same length as ours (enforced in `receive_shares`),
        // so indexing with `i` is in bounds.
        cfg_into_iter!(own_shares_and_salts)
            .enumerate()
            .map(|(i, (own_share, _))| {
                other_shares
                    .values()
                    .fold(own_share, |sum, shares| sum + shares[i])
            })
            .collect()
    }

    /// Returns `true` if commitments from participant `id` have been recorded.
    pub fn has_commitment_from(&self, id: &ParticipantId) -> bool {
        self.other_commitments.contains_key(id)
    }

    /// Returns `true` if shares from participant `id` have been accepted.
    pub fn has_shares_from(&self, id: &ParticipantId) -> bool {
        self.other_shares.contains_key(id)
    }

    /// Returns `true` if shares were accepted from every participant that sent commitments.
    ///
    /// Comparing sizes is sufficient because shares are only ever accepted from participants
    /// whose commitments were recorded first.
    pub fn has_shares_from_all_who_committed(&self) -> bool {
        self.other_shares.len() == self.other_commitments.len()
    }

    /// Computes one commitment per `(share, salt)` pair, in input order, as the `D` digest of
    /// `label || share || salt`.
    ///
    /// NOTE(review): the committing participant's id is not part of the hashed input, so equal
    /// `(share, salt)` pairs give equal commitments whoever sends them. Confirm that the
    /// protocol tolerates one party echoing another party's commitments.
    fn compute_commitments<D: Digest>(
        shares_and_salts: &[(F, [u8; SALT_SIZE])],
        label: &[u8],
    ) -> Vec<Vec<u8>> {
        cfg_into_iter!(shares_and_salts)
            .map(|(share, salt)| hash::<F, D>(label, share, salt))
            .collect()
    }
}
/// Returns the `D` digest of `label || share || salt`, where `share` is serialized in
/// compressed form.
///
/// The three parts are fed to the hasher one after another. This gives the same digest as
/// hashing their concatenation, without building the concatenated buffer.
fn hash<F: PrimeField, D: Digest>(label: &[u8], share: &F, salt: &[u8]) -> Vec<u8> {
    let mut share_bytes = Vec::new();
    // Serializing into an in-memory buffer is not expected to fail.
    share.serialize_compressed(&mut share_bytes).unwrap();

    let mut hasher = D::new();
    hasher.update(label);
    hasher.update(&share_bytes);
    hasher.update(salt);
    hasher.finalize().to_vec()
}
#[cfg(test)]
pub mod tests {
    use super::*;
    use ark_bls12_381::Bls12_381;
    use ark_ec::pairing::Pairing;
    use ark_std::rand::{rngs::StdRng, SeedableRng};
    use sha3::Sha3_256;
    use std::time::Instant;

    type Fr = <Bls12_381 as Pairing>::ScalarField;

    /// Runs the full protocol for several batch sizes and party counts and checks that all
    /// parties agree on the joint randomness.
    #[test]
    fn cointoss() {
        let mut rng = StdRng::seed_from_u64(0u64);

        /// Runs one protocol instance among parties with ids `1..=num_parties`. Checks that
        /// every party ends with the same `batch_size` outputs, and prints per-phase timings
        /// summed over all parties.
        fn check<const SALT_SIZE: usize>(rng: &mut StdRng, batch_size: u32, num_parties: u16) {
            let label = b"test".to_vec();
            // The party with id `id` is stored at index `id - 1`.
            let idx = |id: u16| id as usize - 1;

            // Phase 1: every party samples its shares and commits to them.
            let start = Instant::now();
            let mut parties = Vec::with_capacity(num_parties as usize);
            let mut commitments = Vec::with_capacity(num_parties as usize);
            for id in 1..=num_parties {
                let (party, comm) = Party::<Fr, SALT_SIZE>::commit::<_, Sha3_256>(
                    rng,
                    id,
                    batch_size,
                    label.clone(),
                );
                parties.push(party);
                commitments.push(comm);
            }
            let commit_time = start.elapsed();

            // Phase 2: every party receives the commitments of all the others.
            let start = Instant::now();
            for receiver_id in 1..=num_parties {
                for sender_id in 1..=num_parties {
                    if receiver_id != sender_id {
                        parties[idx(receiver_id)]
                            .receive_commitment(sender_id, commitments[idx(sender_id)].clone())
                            .unwrap();
                    }
                }
            }
            let process_commit_time = start.elapsed();

            // Phase 3: every party reveals its shares and salts to all the others.
            let start = Instant::now();
            for receiver_id in 1..=num_parties {
                for sender_id in 1..=num_parties {
                    if receiver_id != sender_id {
                        // At least the current sender's shares are still outstanding.
                        assert!(!parties[idx(receiver_id)].has_shares_from_all_who_committed());
                        // Cloned because the receiver is borrowed mutably from the same Vec.
                        let shares = parties[idx(sender_id)].own_shares_and_salts.clone();
                        parties[idx(receiver_id)]
                            .receive_shares::<Sha3_256>(sender_id, shares)
                            .unwrap();
                    }
                }
                assert!(parties[idx(receiver_id)].has_shares_from_all_who_committed());
            }
            let process_shares_time = start.elapsed();

            // Each party stored exactly the shares (without salts) that every other party revealed.
            for sender in &parties {
                let expected = sender
                    .own_shares_and_salts
                    .iter()
                    .map(|(share, _)| *share)
                    .collect::<Vec<_>>();
                for receiver in parties.iter().filter(|p| p.id != sender.id) {
                    assert_eq!(receiver.other_shares.get(&sender.id).unwrap(), &expected);
                }
            }
            for party in &parties {
                assert!(party.has_shares_from_all_who_committed());
            }

            let start = Instant::now();
            let joint_randomness = parties
                .into_iter()
                .map(|party| party.compute_joint_randomness())
                .collect::<Vec<_>>();
            let compute_randomness_time = start.elapsed();

            // All parties must derive the same `batch_size` outputs.
            for randomness in &joint_randomness {
                assert_eq!(randomness.len(), batch_size as usize);
                assert_eq!(randomness, &joint_randomness[0]);
            }

            println!("For a batch size of {} and {} parties, below is the total time taken by all parties", batch_size, num_parties);
            println!("Commitment time {:?}", commit_time);
            println!("Processing commitment time {:?}", process_commit_time);
            println!("Processing shares time {:?}", process_shares_time);
            println!(
                "Computing joint randomness time {:?}",
                compute_randomness_time
            );
        }

        check::<256>(&mut rng, 10, 5);
        check::<256>(&mut rng, 20, 5);
        check::<256>(&mut rng, 30, 5);
        check::<256>(&mut rng, 10, 10);
        check::<256>(&mut rng, 10, 20);
    }

    /// Checks that malformed or out-of-order protocol messages are rejected without changing
    /// the receiver's state.
    #[test]
    fn rejects_invalid_messages() {
        let mut rng = StdRng::seed_from_u64(1u64);
        let label = b"test".to_vec();
        let (mut p1, c1) = Party::<Fr, 32>::commit::<_, Sha3_256>(&mut rng, 1, 4, label.clone());
        let (p2, c2) = Party::<Fr, 32>::commit::<_, Sha3_256>(&mut rng, 2, 4, label);
        let p2_shares = p2.own_shares_and_salts;

        // A party cannot receive its own commitments.
        assert!(matches!(
            p1.receive_commitment(1, c1),
            Err(OTError::SenderIdCannotBeSameAsSelf(..))
        ));
        // Shares are rejected until the sender's commitments are known.
        assert!(matches!(
            p1.receive_shares::<Sha3_256>(2, p2_shares.clone()),
            Err(OTError::MissingCommitmentFromParticipant(_))
        ));

        p1.receive_commitment(2, c2.clone()).unwrap();
        assert!(p1.has_commitment_from(&2));
        assert!(matches!(
            p1.receive_commitment(2, c2),
            Err(OTError::AlreadyHaveCommitmentFromParticipant(_))
        ));

        // Wrong number of shares.
        assert!(p1
            .receive_shares::<Sha3_256>(2, p2_shares[..3].to_vec())
            .is_err());
        // A share that does not open its commitment.
        let mut tampered = p2_shares.clone();
        tampered[0].0 += Fr::from(1u64);
        assert!(matches!(
            p1.receive_shares::<Sha3_256>(2, tampered),
            Err(OTError::IncorrectCommitment)
        ));
        assert!(!p1.has_shares_from(&2));
        assert!(!p1.has_shares_from_all_who_committed());

        p1.receive_shares::<Sha3_256>(2, p2_shares.clone()).unwrap();
        assert!(p1.has_shares_from(&2));
        assert!(p1.has_shares_from_all_who_committed());
        assert!(matches!(
            p1.receive_shares::<Sha3_256>(2, p2_shares.clone()),
            Err(OTError::AlreadyHaveSharesFromParticipant(_))
        ));

        // The joint randomness is the element-wise sum of both parties' shares.
        let expected = p1
            .own_shares_and_salts
            .iter()
            .zip(&p2_shares)
            .map(|((a, _), (b, _))| *a + *b)
            .collect::<Vec<_>>();
        assert_eq!(p1.compute_joint_randomness(), expected);
    }
}