use std::collections::HashMap;
use primitives::{
algebra::{
elliptic_curve::Curve,
field::{binary::Gf2_128, mersenne::Mersenne107, FieldExtension},
BoxedUint,
},
random::Seed,
types::Positive,
};
use crate::{
circuit::Circuit,
preprocessing::{
dealer::{global::GlobalDealer, TrustedGenerator},
iterator::{
CerberusPreprocessingIterator,
NextDaBit,
NextElement,
NextPowPair,
NextSinglet,
NextTriple,
},
Preprocessing,
PreprocessingBundler,
},
};
/// Mock preprocessing source backed by an in-process [`GlobalDealer`].
///
/// The wrapped dealer produces material for every party. This type keeps only the
/// entry at `local_peer_pos` of each per-party batch, i.e. the local party's share.
pub struct MockDealer<M: Positive> {
    /// Dealer that generates the per-party batches of preprocessing material.
    global: GlobalDealer<M>,
    /// Index of the local party within each per-party batch returned by `global`.
    local_peer_pos: usize,
}
/// Serves a circuit's preprocessing needs entirely from the local mock dealer.
///
/// Every element listed by [`Circuit::required_preprocessing`] is requested eagerly.
/// The returned iterator yields the local party's share of each element as an
/// already-resolved future.
impl<C: Curve, M: Positive> PreprocessingBundler<C> for MockDealer<M> {
    fn fetch_for(&mut self, circuit: &Circuit<C>) -> CerberusPreprocessingIterator<C> {
        let needs = circuit.required_preprocessing();

        // Struct-literal field initialisers are evaluated in the order they are written.
        // That order fixes the sequence in which material is drawn from the dealer:
        // base field, binary, Mersenne107, scalar. The pow pairs are drawn last, below.
        let mut bundle = CerberusPreprocessingIterator {
            base_field_dabits: self
                .request_n_dabits::<C::BaseField>(needs.base_field.dabits)
                .into_iter(),
            base_field_singlets: self
                .request_n_singlets::<C::BaseField>(needs.base_field.singlets)
                .into_iter(),
            base_field_triples: self
                .request_n_triples::<C::BaseField>(needs.base_field.triples)
                .into_iter(),
            binary_singlets: self.request_n_singlets::<Gf2_128>(needs.bit_singlets).into_iter(),
            binary_triples: self.request_n_triples::<Gf2_128>(needs.bit_triples).into_iter(),
            mersenne107_dabits: self
                .request_n_dabits::<Mersenne107>(needs.mersenne107.dabits)
                .into_iter(),
            mersenne107_singlets: self
                .request_n_singlets::<Mersenne107>(needs.mersenne107.singlets)
                .into_iter(),
            mersenne107_triples: self
                .request_n_triples::<Mersenne107>(needs.mersenne107.triples)
                .into_iter(),
            scalar_dabits: self.request_n_dabits::<C::Scalar>(needs.scalar.dabits).into_iter(),
            scalar_singlets: self
                .request_n_singlets::<C::Scalar>(needs.scalar.singlets)
                .into_iter(),
            scalar_triples: self
                .request_n_triples::<C::Scalar>(needs.scalar.triples)
                .into_iter(),
            base_field_pow_preprocessing: HashMap::default(),
        };

        // One batch of pow pairs per requested exponent, keyed by that exponent.
        for (exponent, count) in needs.base_field_pow_pairs {
            let pairs = self.request_n_pow_pairs::<C::BaseField>(count, exponent.clone());
            bundle
                .base_field_pow_preprocessing
                .insert(exponent, pairs.into_iter());
        }

        bundle
    }
}
impl<M: Positive> MockDealer<M> {
    /// Creates a mock dealer for `n_parties` parties that hands out the shares of the
    /// party at `local_peer_pos`.
    ///
    /// `n_parties` and `alphas_seed` are forwarded unchanged to
    /// [`GlobalDealer::new_with`].
    pub fn new(n_parties: usize, local_peer_pos: usize, alphas_seed: Seed) -> Self {
        Self {
            global: GlobalDealer::new_with(n_parties, alphas_seed),
            local_peer_pos,
        }
    }

    /// Asks the global dealer for `n` elements of kind `P` for every party.
    ///
    /// Returns only the local party's batch, with each element wrapped in an
    /// already-completed `Ok` future.
    ///
    /// # Panics
    ///
    /// Panics if `local_peer_pos` is out of range for the per-party output of
    /// `generate_n_for_each` (from the `swap_remove` call).
    fn request_n<P: Preprocessing + 'static>(
        &mut self,
        n: usize,
        associated_data: P::AssociatedData,
    ) -> Vec<NextElement<P>>
    where
        GlobalDealer<M>: TrustedGenerator<P>,
    {
        let mut per_party = self.global.generate_n_for_each(n, associated_data);
        // Order of the other parties' batches is irrelevant: they are discarded.
        let own_batch = per_party.swap_remove(self.local_peer_pos);
        own_batch
            .into_iter()
            .map(|element| Box::pin(std::future::ready(Ok(element))) as NextElement<P>)
            .collect()
    }

    /// Requests `n` singlets over `F` and returns the local party's share of them.
    pub fn request_n_singlets<F: FieldExtension>(&mut self, n: usize) -> Vec<NextSinglet<F>> {
        Self::request_n(self, n, ())
    }

    /// Requests `n` triples over `F` and returns the local party's share of them.
    pub fn request_n_triples<F: FieldExtension>(&mut self, n: usize) -> Vec<NextTriple<F>> {
        Self::request_n(self, n, ())
    }

    /// Requests `n` pow pairs over `F` for the given `exponent` and returns the local
    /// party's share of them.
    pub fn request_n_pow_pairs<F: FieldExtension>(
        &mut self,
        n: usize,
        exponent: BoxedUint,
    ) -> Vec<NextPowPair<F>> {
        Self::request_n(self, n, exponent)
    }

    /// Requests `n` daBits over `F` and returns the local party's share of them.
    pub fn request_n_dabits<F: FieldExtension>(&mut self, n: usize) -> Vec<NextDaBit<F>> {
        Self::request_n(self, n, ())
    }
}