use std::{collections::HashMap, fmt::Debug, future::Future, pin::Pin, vec::IntoIter};
use futures::future::{BoxFuture, Shared};
use primitives::{
algebra::{
elliptic_curve::Curve,
field::{binary::Gf2_128, mersenne::Mersenne107},
BoxedUint,
},
correlated_randomness::{dabits::DaBit, pow::PowPair, singlets::Singlet, triples::Triple},
};
use crate::{
circuit::{CircuitPreprocessing, FieldCircuitPreprocessing},
errors::AbortError,
preprocessing::errors::PreprocessingBundlerError,
};
/// A single pending preprocessing element: a pinned, boxed, `Send + 'static`
/// future that resolves to a `P` or to an [`AbortError`].
pub type NextElement<P> = Pin<Box<dyn Future<Output = Result<P, AbortError>> + Send + 'static>>;
/// A pending batch of preprocessing elements: a boxed future resolving to a
/// `Vec<P>` or to an [`AbortError`], wrapped in [`Shared`] so the same batch
/// future can be cloned and awaited by more than one holder.
pub type NextBatch<P> = Shared<BoxFuture<'static, Result<Vec<P>, AbortError>>>;
/// A single pending [`Singlet`] over field `F`.
pub type NextSinglet<F> = NextElement<Singlet<F>>;
/// A shared, pending batch of [`Singlet`]s over field `F`.
pub type NextSingletBatch<F> = NextBatch<Singlet<F>>;
/// A single pending [`Triple`] over field `F`.
pub type NextTriple<F> = NextElement<Triple<F>>;
/// A shared, pending batch of [`Triple`]s over field `F`.
pub type NextTripleBatch<F> = NextBatch<Triple<F>>;
/// A single pending [`DaBit`] over field `F`.
pub type NextDaBit<F> = NextElement<DaBit<F>>;
/// A shared, pending batch of [`DaBit`]s over field `F`.
pub type NextDaBitBatch<F> = NextBatch<DaBit<F>>;
/// A single pending [`PowPair`] over field `F`.
pub type NextPowPair<F> = NextElement<PowPair<F>>;
/// A shared, pending batch of [`PowPair`]s over field `F`.
pub type NextPowPairBatch<F> = NextBatch<PowPair<F>>;
/// Per-kind queues of pending preprocessing elements for curve `C`.
///
/// Each field is a consuming iterator over not-yet-resolved futures. Elements
/// are taken one at a time through the `next_*` methods, and the remaining
/// counts can be inspected with `len` / `is_empty`.
pub struct CerberusPreprocessingIterator<C: Curve> {
    /// DaBits over the curve's base field.
    pub base_field_dabits: IntoIter<NextDaBit<C::BaseField>>,
    /// Pow pairs over the base field, one queue per exponent (the key is the
    /// exponent looked up by `next_base_field_powpair`).
    pub base_field_pow_preprocessing: HashMap<BoxedUint, IntoIter<NextPowPair<C::BaseField>>>,
    /// Singlets over the base field.
    pub base_field_singlets: IntoIter<NextSinglet<C::BaseField>>,
    /// Triples over the base field.
    pub base_field_triples: IntoIter<NextTriple<C::BaseField>>,
    /// Singlets over `Gf2_128` (reported as `bit_singlets` by `len`).
    pub binary_singlets: IntoIter<NextSinglet<Gf2_128>>,
    /// Triples over `Gf2_128` (reported as `bit_triples` by `len`).
    pub binary_triples: IntoIter<NextTriple<Gf2_128>>,
    /// DaBits over `Mersenne107`.
    pub mersenne107_dabits: IntoIter<NextDaBit<Mersenne107>>,
    /// Singlets over `Mersenne107`.
    pub mersenne107_singlets: IntoIter<NextSinglet<Mersenne107>>,
    /// Triples over `Mersenne107`.
    pub mersenne107_triples: IntoIter<NextTriple<Mersenne107>>,
    /// DaBits over the curve's scalar field.
    pub scalar_dabits: IntoIter<NextDaBit<C::Scalar>>,
    /// Singlets over the curve's scalar field.
    pub scalar_singlets: IntoIter<NextSinglet<C::Scalar>>,
    /// Triples over the curve's scalar field.
    pub scalar_triples: IntoIter<NextTriple<C::Scalar>>,
}
impl<C: Curve> CerberusPreprocessingIterator<C> {
    /// Returns the number of preprocessing elements still available, per kind.
    ///
    /// Unlike a conventional `len`, the result is not a single `usize` but a
    /// [`CircuitPreprocessing`] with one count per queue. Pow pairs are counted
    /// separately for every exponent present in `base_field_pow_preprocessing`.
    pub fn len(&self) -> CircuitPreprocessing {
        CircuitPreprocessing {
            base_field_pow_pairs: self
                .base_field_pow_preprocessing
                .iter()
                .map(|(exp, pairs)| (exp.clone(), pairs.len()))
                .collect(),
            bit_singlets: self.binary_singlets.len(),
            bit_triples: self.binary_triples.len(),
            base_field: FieldCircuitPreprocessing {
                singlets: self.base_field_singlets.len(),
                triples: self.base_field_triples.len(),
                dabits: self.base_field_dabits.len(),
            },
            scalar: FieldCircuitPreprocessing {
                singlets: self.scalar_singlets.len(),
                triples: self.scalar_triples.len(),
                dabits: self.scalar_dabits.len(),
            },
            mersenne107: FieldCircuitPreprocessing {
                singlets: self.mersenne107_singlets.len(),
                triples: self.mersenne107_triples.len(),
                dabits: self.mersenne107_dabits.len(),
            },
        }
    }

    /// Returns `true` if every queue has been fully consumed.
    ///
    /// An exponent that is still a key of `base_field_pow_preprocessing` but
    /// whose queue is exhausted counts as empty.
    pub fn is_empty(&self) -> bool {
        // `ExactSizeIterator::is_empty` is unstable, hence the `len() == 0` checks.
        self.base_field_pow_preprocessing
            .values()
            .all(|pairs| pairs.len() == 0)
            && self.base_field_dabits.len() == 0
            && self.base_field_singlets.len() == 0
            && self.base_field_triples.len() == 0
            && self.binary_singlets.len() == 0
            && self.binary_triples.len() == 0
            && self.mersenne107_dabits.len() == 0
            && self.mersenne107_singlets.len() == 0
            && self.mersenne107_triples.len() == 0
            && self.scalar_dabits.len() == 0
            && self.scalar_singlets.len() == 0
            && self.scalar_triples.len() == 0
    }

    /// Takes the next pending base-field daBit.
    ///
    /// # Errors
    ///
    /// [`PreprocessingBundlerError::InsufficientDaBits`] if the queue is exhausted.
    pub fn next_base_field_dabit(
        &mut self,
    ) -> Result<NextDaBit<C::BaseField>, PreprocessingBundlerError> {
        self.base_field_dabits
            .next()
            .ok_or_else(|| insufficient_dabits(std::any::type_name::<C::BaseField>()))
    }

    /// Takes the next pending base-field pow pair for exponent `exp`.
    ///
    /// # Errors
    ///
    /// [`PreprocessingBundlerError::InsufficientPowPreprocessing`] carrying `exp`
    /// if there is no queue for `exp` or that queue is exhausted. The two cases
    /// are deliberately reported with the same error.
    pub fn next_base_field_powpair(
        &mut self,
        exp: &BoxedUint,
    ) -> Result<NextPowPair<C::BaseField>, PreprocessingBundlerError> {
        self.base_field_pow_preprocessing
            .get_mut(exp)
            .and_then(|pairs| pairs.next())
            .ok_or_else(|| PreprocessingBundlerError::InsufficientPowPreprocessing(exp.clone()))
    }

    /// Takes the next pending base-field singlet.
    ///
    /// # Errors
    ///
    /// [`PreprocessingBundlerError::InsufficientSinglets`] if the queue is exhausted.
    pub fn next_base_field_singlet(
        &mut self,
    ) -> Result<NextSinglet<C::BaseField>, PreprocessingBundlerError> {
        self.base_field_singlets
            .next()
            .ok_or_else(|| insufficient_singlets(std::any::type_name::<C::BaseField>()))
    }

    /// Takes the next pending base-field triple.
    ///
    /// # Errors
    ///
    /// [`PreprocessingBundlerError::InsufficientTriples`] if the queue is exhausted.
    pub fn next_base_field_triple(
        &mut self,
    ) -> Result<NextTriple<C::BaseField>, PreprocessingBundlerError> {
        self.base_field_triples
            .next()
            .ok_or_else(|| insufficient_triples(std::any::type_name::<C::BaseField>()))
    }

    /// Takes the next pending `Gf2_128` singlet.
    ///
    /// # Errors
    ///
    /// [`PreprocessingBundlerError::InsufficientSinglets`] if the queue is exhausted.
    pub fn next_bit_singlet(&mut self) -> Result<NextSinglet<Gf2_128>, PreprocessingBundlerError> {
        self.binary_singlets
            .next()
            .ok_or_else(|| insufficient_singlets(std::any::type_name::<Gf2_128>()))
    }

    /// Takes the next pending `Gf2_128` triple.
    ///
    /// # Errors
    ///
    /// [`PreprocessingBundlerError::InsufficientTriples`] if the queue is exhausted.
    pub fn next_bit_triple(&mut self) -> Result<NextTriple<Gf2_128>, PreprocessingBundlerError> {
        self.binary_triples
            .next()
            .ok_or_else(|| insufficient_triples(std::any::type_name::<Gf2_128>()))
    }

    /// Takes the next pending `Mersenne107` daBit.
    ///
    /// # Errors
    ///
    /// [`PreprocessingBundlerError::InsufficientDaBits`] if the queue is exhausted.
    pub fn next_mersenne107_dabit(
        &mut self,
    ) -> Result<NextDaBit<Mersenne107>, PreprocessingBundlerError> {
        // The Mersenne107 errors use the short literal name rather than
        // `type_name`, unlike the other fields.
        self.mersenne107_dabits
            .next()
            .ok_or_else(|| insufficient_dabits("Mersenne107"))
    }

    /// Takes the next pending `Mersenne107` singlet.
    ///
    /// # Errors
    ///
    /// [`PreprocessingBundlerError::InsufficientSinglets`] if the queue is exhausted.
    pub fn next_mersenne107_singlet(
        &mut self,
    ) -> Result<NextSinglet<Mersenne107>, PreprocessingBundlerError> {
        self.mersenne107_singlets
            .next()
            .ok_or_else(|| insufficient_singlets("Mersenne107"))
    }

    /// Takes the next pending `Mersenne107` triple.
    ///
    /// # Errors
    ///
    /// [`PreprocessingBundlerError::InsufficientTriples`] if the queue is exhausted.
    pub fn next_mersenne107_triple(
        &mut self,
    ) -> Result<NextTriple<Mersenne107>, PreprocessingBundlerError> {
        self.mersenne107_triples
            .next()
            .ok_or_else(|| insufficient_triples("Mersenne107"))
    }

    /// Takes the next pending scalar-field daBit.
    ///
    /// # Errors
    ///
    /// [`PreprocessingBundlerError::InsufficientDaBits`] if the queue is exhausted.
    pub fn next_scalar_dabit(&mut self) -> Result<NextDaBit<C::Scalar>, PreprocessingBundlerError> {
        self.scalar_dabits
            .next()
            .ok_or_else(|| insufficient_dabits(std::any::type_name::<C::Scalar>()))
    }

    /// Takes the next pending scalar-field singlet.
    ///
    /// # Errors
    ///
    /// [`PreprocessingBundlerError::InsufficientSinglets`] if the queue is exhausted.
    pub fn next_scalar_singlet(
        &mut self,
    ) -> Result<NextSinglet<C::Scalar>, PreprocessingBundlerError> {
        self.scalar_singlets
            .next()
            .ok_or_else(|| insufficient_singlets(std::any::type_name::<C::Scalar>()))
    }

    /// Takes the next pending scalar-field triple.
    ///
    /// # Errors
    ///
    /// [`PreprocessingBundlerError::InsufficientTriples`] if the queue is exhausted.
    pub fn next_scalar_triple(
        &mut self,
    ) -> Result<NextTriple<C::Scalar>, PreprocessingBundlerError> {
        self.scalar_triples
            .next()
            .ok_or_else(|| insufficient_triples(std::any::type_name::<C::Scalar>()))
    }
}

/// Builds the error reported when no daBit remains for the field named `field`.
fn insufficient_dabits(field: &str) -> PreprocessingBundlerError {
    PreprocessingBundlerError::InsufficientDaBits(field.to_string())
}

/// Builds the error reported when no singlet remains for the field named `field`.
fn insufficient_singlets(field: &str) -> PreprocessingBundlerError {
    PreprocessingBundlerError::InsufficientSinglets(field.to_string())
}

/// Builds the error reported when no triple remains for the field named `field`.
fn insufficient_triples(field: &str) -> PreprocessingBundlerError {
    PreprocessingBundlerError::InsufficientTriples(field.to_string())
}
impl<C: Curve> Debug for CerberusPreprocessingIterator<C> {
    /// Writes a fixed label followed by the `Debug` rendering of the remaining
    /// per-kind counts (from `len`), rather than the pending futures themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("CerberusPreprocessingIterator with len: ")?;
        let remaining = self.len();
        Debug::fmt(&remaining, f)
    }
}