use futures::future::{BoxFuture, Shared};
#[cfg(any(test, feature = "dev"))]
use futures::FutureExt;
#[cfg(any(test, feature = "dev"))]
use primitives::types::Positive;
use primitives::{
algebra::{field::FieldExtension, BoxedUint},
correlated_randomness::{dabits::DaBit, pow::PowPair, singlets::Singlet, triples::Triple},
};
#[cfg(any(test, feature = "dev"))]
use crate::preprocessing::dealer::{global::GlobalDealer, mock::MockDealer, TrustedGenerator};
use crate::{
errors::AbortError,
preprocessing::{
iterator::{
NextBatch,
NextDaBit,
NextElement,
NextPowPair,
NextSinglet,
NextSingletBatch,
NextTriple,
NextTripleBatch,
},
Preprocessing,
},
};
/// A producer of preprocessing material of type `P`.
///
/// Implementors only have to provide [`request_n_elements_batch`]; the per-element interface
/// [`request_n_elements`] is derived from it by default.
///
/// [`request_n_elements_batch`]: PreprocessingSource::request_n_elements_batch
/// [`request_n_elements`]: PreprocessingSource::request_n_elements
pub trait PreprocessingSource<P: Preprocessing + 'static>: 'static + Send {
    /// Requests `n_elements` elements of `P`, delivered together as one batch future.
    ///
    /// `associated_data` is handed to the implementation unchanged. The convenience traits
    /// built on top of this one pass `()` for singlets, triples and daBits, and an exponent
    /// for pow pairs.
    fn request_n_elements_batch(
        &mut self,
        n_elements: usize,
        associated_data: <P as Preprocessing>::AssociatedData,
    ) -> NextBatch<P>;

    /// Requests `n_elements` elements of `P`, returning one future per element.
    ///
    /// The default implementation issues a single [`request_n_elements_batch`] call and splits
    /// its result: the `i`-th returned future resolves to the `i`-th element of the batch. The
    /// batch is wrapped in a shared future, so it is awaited once no matter how many element
    /// futures are polled.
    ///
    /// # Errors
    ///
    /// * If the batch fails, every element future fails with (a clone of) that error.
    /// * If the batch contains fewer than `n_elements` items, the futures for the missing
    ///   indices fail with an internal [`AbortError`].
    /// * A poisoned slot mutex is reported as an internal [`AbortError`] rather than a panic.
    ///
    /// [`request_n_elements_batch`]: PreprocessingSource::request_n_elements_batch
    fn request_n_elements(
        &mut self,
        n_elements: usize,
        associated_data: <P as Preprocessing>::AssociatedData,
    ) -> Vec<NextElement<P>> {
        use std::sync::{Arc, Mutex};

        use futures::FutureExt;

        // Each slot is `take`n exactly once by the future owning that index. This hands out
        // owned elements without cloning the whole batch for every element.
        type Slots<T> = Arc<Mutex<Vec<Option<T>>>>;

        let batch = self.request_n_elements_batch(n_elements, associated_data);

        let slots: Shared<BoxFuture<'static, Result<Slots<P>, AbortError>>> = async move {
            let elements = batch.await?;
            let wrapped: Vec<Option<P>> = elements.into_iter().map(Some).collect();
            Ok(Arc::new(Mutex::new(wrapped)))
        }
        .boxed()
        .shared();

        let mut handles = Vec::with_capacity(n_elements);
        for i in 0..n_elements {
            let pending_slots = slots.clone();
            let handle: NextElement<P> = Box::pin(async move {
                let storage = pending_slots.await?;
                // The guard is created after the only `.await`, so it never lives across a
                // suspension point.
                let mut guard = storage.lock().map_err(|_| {
                    AbortError::internal_error("preprocessing slots mutex poisoned")
                })?;
                guard.get_mut(i).and_then(Option::take).ok_or_else(|| {
                    AbortError::internal_error(&format!(
                        "Preprocessing batch too small: index {i} out of bounds"
                    ))
                })
            });
            handles.push(handle);
        }
        handles
    }
}
/// Convenience accessors for sources of [`Singlet`] preprocessing.
///
/// Blanket-implemented for every type that is a `PreprocessingSource<Singlet<F>>`. Singlet
/// requests carry no associated data, so `()` is forwarded.
pub trait SingletSource<F: FieldExtension>: PreprocessingSource<Singlet<F>> {
    /// Requests `n_singlets` singlets, returning one future per element.
    fn request_n_singlets(&mut self, n_singlets: usize) -> Vec<NextSinglet<F>> {
        <Self as PreprocessingSource<Singlet<F>>>::request_n_elements(self, n_singlets, ())
    }

    /// Requests `n_singlets` singlets as a single batch future.
    fn request_n_singlets_batch(&mut self, n_singlets: usize) -> NextSingletBatch<F> {
        <Self as PreprocessingSource<Singlet<F>>>::request_n_elements_batch(self, n_singlets, ())
    }
}

impl<F, T> SingletSource<F> for T
where
    F: FieldExtension,
    T: PreprocessingSource<Singlet<F>>,
{
}
/// Convenience accessors for sources of [`Triple`] preprocessing.
///
/// Blanket-implemented for every type that is a `PreprocessingSource<Triple<F>>`. Triple
/// requests carry no associated data, so `()` is forwarded.
pub trait TripleSource<F: FieldExtension>: PreprocessingSource<Triple<F>> {
    /// Requests `n_triples` triples, returning one future per element.
    fn request_n_triples(&mut self, n_triples: usize) -> Vec<NextTriple<F>> {
        <Self as PreprocessingSource<Triple<F>>>::request_n_elements(self, n_triples, ())
    }

    /// Requests `n_triples` triples as a single batch future.
    fn request_n_triples_batch(&mut self, n_triples: usize) -> NextTripleBatch<F> {
        <Self as PreprocessingSource<Triple<F>>>::request_n_elements_batch(self, n_triples, ())
    }
}

impl<F, T> TripleSource<F> for T
where
    F: FieldExtension,
    T: PreprocessingSource<Triple<F>>,
{
}
/// Convenience accessors for sources of [`DaBit`] preprocessing.
///
/// Blanket-implemented for every type that is a `PreprocessingSource<DaBit<F>>`. DaBit
/// requests carry no associated data, so `()` is forwarded.
pub trait DaBitSource<F: FieldExtension>: PreprocessingSource<DaBit<F>> {
    /// Requests `n_dabits` daBits, returning one future per element.
    fn request_n_dabits(&mut self, n_dabits: usize) -> Vec<NextDaBit<F>> {
        <Self as PreprocessingSource<DaBit<F>>>::request_n_elements(self, n_dabits, ())
    }

    /// Requests `n_dabits` daBits as a single batch future.
    fn request_n_dabits_batch(&mut self, n_dabits: usize) -> NextBatch<DaBit<F>> {
        <Self as PreprocessingSource<DaBit<F>>>::request_n_elements_batch(self, n_dabits, ())
    }
}

impl<F, T> DaBitSource<F> for T
where
    F: FieldExtension,
    T: PreprocessingSource<DaBit<F>>,
{
}
/// Convenience accessors for sources of [`PowPair`] preprocessing.
///
/// Blanket-implemented for every type that is a `PreprocessingSource<PowPair<F>>`. The
/// requested `exponent` is forwarded as the request's associated data.
pub trait PowPairSource<F: FieldExtension>: PreprocessingSource<PowPair<F>> {
    /// Requests `n_pow_pairs` pow pairs for `exponent`, returning one future per element.
    fn request_n_pow_pairs(
        &mut self,
        n_pow_pairs: usize,
        exponent: BoxedUint,
    ) -> Vec<NextPowPair<F>> {
        <Self as PreprocessingSource<PowPair<F>>>::request_n_elements(self, n_pow_pairs, exponent)
    }

    /// Requests `n_pow_pairs` pow pairs for `exponent` as a single batch future.
    fn request_n_pow_pairs_batch(
        &mut self,
        n_pow_pairs: usize,
        exponent: BoxedUint,
    ) -> NextBatch<PowPair<F>> {
        <Self as PreprocessingSource<PowPair<F>>>::request_n_elements_batch(
            self,
            n_pow_pairs,
            exponent,
        )
    }
}

impl<F, T> PowPairSource<F> for T
where
    F: FieldExtension,
    T: PreprocessingSource<PowPair<F>>,
{
}
/// Lets a [`MockDealer`] serve as a [`PreprocessingSource`] for every preprocessing type
/// that its [`GlobalDealer`] can generate.
#[cfg(any(test, feature = "dev"))]
impl<P, M> PreprocessingSource<P> for MockDealer<M>
where
    P: Preprocessing + 'static,
    M: Positive,
    GlobalDealer<M>: TrustedGenerator<P>,
{
    /// Joins the dealer's per-element futures into a single shared batch future.
    ///
    /// The batch resolves to the elements in request order. If any element future fails, it
    /// resolves to the first error reported (`try_join_all` semantics).
    fn request_n_elements_batch(
        &mut self,
        n_elements: usize,
        associated_data: <P as Preprocessing>::AssociatedData,
    ) -> NextBatch<P> {
        let pending = self.request_n(n_elements, associated_data);
        let joined = futures::future::try_join_all(pending).boxed();
        joined.shared()
    }
}