use ark_ec::CurveGroup;
use itertools::Itertools;
use crate::{algebra::Scalar, network::PartyId};
/// A source of secret-shared correlated randomness.
///
/// Each method yields the local party's share(s) of the requested value(s). Implementors only
/// need to provide the single-value methods; every `*_batch` method has a default that calls
/// its single-value counterpart the requested number of times, in order.
pub trait SharedValueSource<C: CurveGroup>: Send + Sync {
    /// Fetch a share of a random bit (a value that opens to either 0 or 1)
    fn next_shared_bit(&mut self) -> Scalar<C>;

    /// Fetch `num_values` shared bits by repeatedly calling [`Self::next_shared_bit`]
    fn next_shared_bit_batch(&mut self, num_values: usize) -> Vec<Scalar<C>> {
        std::iter::repeat_with(|| self.next_shared_bit())
            .take(num_values)
            .collect_vec()
    }

    /// Fetch a share of a random value
    fn next_shared_value(&mut self) -> Scalar<C>;

    /// Fetch `num_values` shared values by repeatedly calling [`Self::next_shared_value`]
    fn next_shared_value_batch(&mut self, num_values: usize) -> Vec<Scalar<C>> {
        std::iter::repeat_with(|| self.next_shared_value())
            .take(num_values)
            .collect_vec()
    }

    /// Fetch shares of a pair `(r, r^-1)`: a value and its multiplicative inverse
    fn next_shared_inverse_pair(&mut self) -> (Scalar<C>, Scalar<C>);

    /// Fetch `num_pairs` inverse pairs, split into a vector of values and a vector of inverses
    fn next_shared_inverse_pair_batch(
        &mut self,
        num_pairs: usize,
    ) -> (Vec<Scalar<C>>, Vec<Scalar<C>>) {
        std::iter::repeat_with(|| self.next_shared_inverse_pair())
            .take(num_pairs)
            .unzip()
    }

    /// Fetch shares of a multiplication triplet `(a, b, c)` with `c = a * b`
    fn next_triplet(&mut self) -> (Scalar<C>, Scalar<C>, Scalar<C>);

    /// Fetch `num_triplets` triplets, returned column-wise as `(a_vals, b_vals, c_vals)`
    // The three-vector tuple is the intended column-wise layout; allow the lint rather than
    // introducing a type alias for it.
    #[allow(clippy::type_complexity)]
    fn next_triplet_batch(
        &mut self,
        num_triplets: usize,
    ) -> (Vec<Scalar<C>>, Vec<Scalar<C>>, Vec<Scalar<C>>) {
        let mut columns = (
            Vec::with_capacity(num_triplets),
            Vec::with_capacity(num_triplets),
            Vec::with_capacity(num_triplets),
        );

        for _ in 0..num_triplets {
            let (a_share, b_share, c_share) = self.next_triplet();
            columns.0.push(a_share);
            columns.1.push(b_share);
            columns.2.push(c_share);
        }

        columns
    }
}
/// A deterministic, test-only shared value source whose shares are a function of the party ID.
///
/// Designed for two-party tests with IDs 0 and 1: the per-party outputs are chosen so that,
/// summed across both parties, they form valid correlated values.
#[cfg(any(feature = "test_helpers", test))]
#[derive(Debug, Clone, Default)]
pub struct PartyIDBeaverSource {
    /// The ID of the local party; every share handed out is derived from it
    party_id: u64,
}
#[cfg(any(feature = "test_helpers", test))]
impl PartyIDBeaverSource {
    /// Build a source that produces the shares belonging to the party `party_id`
    pub fn new(party_id: u64) -> Self {
        PartyIDBeaverSource { party_id }
    }
}
#[cfg(any(feature = "test_helpers", test))]
impl<C: CurveGroup> SharedValueSource<C> for PartyIDBeaverSource {
    /// Returns the party ID itself; summed over parties 0 and 1 the bit is 1.
    ///
    /// # Panics
    /// Panics if the party ID is neither 0 nor 1.
    fn next_shared_bit(&mut self) -> Scalar<C> {
        assert!(self.party_id == 0 || self.party_id == 1);
        Scalar::from(self.party_id)
    }

    /// Returns the party ID as this party's share of the value
    fn next_shared_value(&mut self) -> Scalar<C> {
        Scalar::from(self.party_id)
    }

    /// Returns the party ID for both components; summed over parties 0 and 1 the pair is
    /// `(1, 1)`, and 1 is its own inverse
    fn next_shared_inverse_pair(&mut self) -> (Scalar<C>, Scalar<C>) {
        let share = Scalar::from(self.party_id);
        (share, share)
    }

    /// Returns fixed triplet shares: summed over parties 0 and 1 they give
    /// `a = 2`, `b = 3`, `c = 6`, so `c = a * b`
    fn next_triplet(&mut self) -> (Scalar<C>, Scalar<C>, Scalar<C>) {
        let (a, b, c): (u64, u64, u64) = match self.party_id {
            0 => (1, 3, 2),
            _ => (1, 0, 4),
        };
        (Scalar::from(a), Scalar::from(b), Scalar::from(c))
    }
}
/// A test-only shared value source that hands out zero shares wherever that is valid.
///
/// Bits, values, and triplets are all zero. Inverse pairs cannot be zero, so those are
/// derived from the party ID instead.
#[cfg(any(feature = "test_helpers", test))]
#[derive(Clone, Debug)]
pub struct ZeroBeaverSource {
    /// The ID of the local party; only used to build inverse-pair shares
    party_id: PartyId,
}
#[cfg(any(feature = "test_helpers", test))]
impl ZeroBeaverSource {
    /// Build a zero-valued source for the party `party_id`
    pub fn new(party_id: PartyId) -> Self {
        ZeroBeaverSource { party_id }
    }
}
#[cfg(any(feature = "test_helpers", test))]
impl<C: CurveGroup> SharedValueSource<C> for ZeroBeaverSource {
    /// Always returns a zero share
    fn next_shared_value(&mut self) -> Scalar<C> {
        Scalar::zero()
    }

    /// Always returns a zero share; zero is a valid bit
    fn next_shared_bit(&mut self) -> Scalar<C> {
        Scalar::zero()
    }

    /// Returns the party ID for both components.
    // Zero has no multiplicative inverse, so an all-zero pair would be invalid here. With
    // party IDs 0 and 1 the shares sum to (1, 1), which is a valid (r, r^-1) pair.
    fn next_shared_inverse_pair(&mut self) -> (Scalar<C>, Scalar<C>) {
        (Scalar::from(self.party_id), Scalar::from(self.party_id))
    }

    /// Returns an all-zero triplet; `0 * 0 = 0` keeps it consistent
    fn next_triplet(&mut self) -> (Scalar<C>, Scalar<C>, Scalar<C>) {
        let zero = Scalar::zero();
        (zero, zero, zero)
    }
}