use tokio::sync::{
mpsc::{self, UnboundedSender},
oneshot,
};
use crate::correlated_randomness::{
generator::CorrelationGenerator,
stream::{errors::CorrelatedStreamError, futures::Next, CorrelatedStream, NextVec},
CorrelatedBatch,
};
/// A single request handed to the worker task spawned by `DirectStream::new`.
///
/// The tuple holds, in order:
/// 1. the number of elements to generate, and
/// 2. the one-shot sender on which the worker reports the outcome: either the
///    generated items or the generator's error `E`.
type Request<PB, E> = (
usize,
oneshot::Sender<Result<Vec<<PB as IntoIterator>::Item>, E>>,
);
/// Handle to a stream of correlated randomness produced by a background task.
///
/// The handle itself only stores the sending half of a request channel; the
/// generator `G` and its network are moved into the task spawned by `new`.
/// `PB` is the batch type whose items (`PB::Item`) the stream yields.
pub struct DirectStream<PB: CorrelatedBatch, G: CorrelationGenerator<PB>> {
/// Sending half of the unbounded channel used to submit requests to the
/// worker task.
request_tx: UnboundedSender<Request<PB, G::Error>>,
}
impl<PB, G> DirectStream<PB, G>
where
PB: CorrelatedBatch,
G: CorrelationGenerator<PB> + Send + 'static,
G::Net: Send + 'static,
{
pub fn new(mut generator: G, mut net: G::Net) -> Self {
let (request_tx, mut request_rx) = mpsc::unbounded_channel::<Request<PB, G::Error>>();
tokio::spawn(async move {
while let Some((n_elements, completion)) = request_rx.recv().await {
let result = generator.run_for(n_elements, &mut net).await;
let _ = completion.send(result);
}
});
Self { request_tx }
}
}
impl<PB, G> CorrelatedStream<PB::Item> for DirectStream<PB, G>
where
    PB: CorrelatedBatch,
    G: CorrelationGenerator<PB> + Send + 'static,
    G::Net: Send + 'static,
    G::Error: From<CorrelatedStreamError>,
{
    type Error = G::Error;

    /// Requests `n_elements` items from the stream.
    ///
    /// A request for zero elements is answered on the spot with
    /// `NextVec::default()` and never reaches the worker task. Any other
    /// request is queued on the request channel, and the returned `NextVec`
    /// wraps the receiving end of the one-shot channel on which the worker
    /// will deliver the outcome.
    ///
    /// # Errors
    ///
    /// Returns `CorrelatedStreamError::StreamClosed` when the request cannot
    /// be queued because the request channel is closed.
    fn next_n(
        &self,
        n_elements: usize,
    ) -> Result<NextVec<PB::Item, G::Error>, CorrelatedStreamError> {
        if n_elements == 0 {
            return Ok(NextVec::default());
        }

        let (completion, outcome_rx) = oneshot::channel();
        match self.request_tx.send((n_elements, completion)) {
            Ok(()) => Ok(NextVec {
                future: Next(outcome_rx),
                size: n_elements,
            }),
            Err(_) => Err(CorrelatedStreamError::StreamClosed),
        }
    }
}
#[cfg(test)]
mod tests {
    use std::{
        sync::{Arc, Mutex},
        time::Duration,
    };
    use rand::{rngs::StdRng, SeedableRng};
    use typenum::U2;
    use crate::{
        algebra::elliptic_curve::{Curve25519Ristretto, ScalarField},
        correlated_randomness::{
            generator::CorrelationGenerator,
            singlets::{Singlet, Singlets},
            stream::{errors::CorrelatedStreamError, CorrelatedStream, DirectStream},
        },
        random::Random,
        utils::TryFuture,
    };

    type Fq = ScalarField<Curve25519Ristretto>;
    type TestPB = Singlets<Fq, U2>;
    type TestItem = Singlet<Fq>;

    /// Generator stand-in that needs no network and records every
    /// `run_for` request size it receives.
    struct MockGen {
        /// Randomness source used to produce the returned singlets.
        rng: StdRng,
        /// Sizes passed to `run_for`, in the order the calls were made.
        /// Shared with the test body so it can inspect the dispatch order.
        calls: Arc<Mutex<Vec<usize>>>,
    }

    impl CorrelationGenerator<TestPB> for MockGen {
        type Net = ();
        type Error = CorrelatedStreamError;

        /// Always fails with `StreamClosed`; the test below only drives
        /// `run_for`.
        fn run(&mut self, _net: &mut ()) -> impl TryFuture<Ok = TestPB, Error = Self::Error> {
            async move { Err(CorrelatedStreamError::StreamClosed) }
        }

        /// Records `n`, sleeps for 10 ms, then returns `n` random singlets.
        fn run_for(
            &mut self,
            n: usize,
            _net: &mut (),
        ) -> impl TryFuture<Ok = Vec<TestItem>, Error = Self::Error> {
            async move {
                // The mutex guard is a temporary dropped at the end of this
                // statement, so it is not held across the `.await` below.
                self.calls.lock().unwrap().push(n);
                // Keep each request "in flight" for a moment so that the
                // later requests are already queued while this one runs.
                tokio::time::sleep(Duration::from_millis(10)).await;
                Ok(Singlet::<Fq>::random_n::<Vec<_>>(&mut self.rng, n))
            }
        }
    }

    /// Three requests issued back-to-back must reach the generator in the
    /// order they were made, and each must resolve to a vector of the
    /// requested length.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn next_n_dispatches_in_call_order() {
        let calls = Arc::new(Mutex::new(Vec::new()));
        // Named `generator` rather than `gen`: `gen` is a reserved keyword
        // starting with the Rust 2024 edition.
        let generator = MockGen {
            rng: StdRng::from_seed([0u8; 32]),
            calls: Arc::clone(&calls),
        };
        let stream = DirectStream::new(generator, ());

        let f1 = stream.next_n(1).unwrap();
        let f2 = stream.next_n(2).unwrap();
        let f3 = stream.next_n(3).unwrap();

        let (a, b, c) = tokio::join!(f1, f2, f3);
        assert_eq!(a.unwrap().len(), 1);
        assert_eq!(b.unwrap().len(), 2);
        assert_eq!(c.unwrap().len(), 3);
        assert_eq!(*calls.lock().unwrap(), vec![1, 2, 3]);
    }
}