primitives/correlated_randomness/stream/
direct.rs1use tokio::sync::{
2 mpsc::{self, UnboundedSender},
3 oneshot,
4};
5
6use crate::correlated_randomness::{
7 generator::CorrelationGenerator,
8 stream::{errors::CorrelatedStreamError, futures::Next, CorrelatedStream, NextVec},
9 CorrelatedBatch,
10};
11
12type Request<PB, E> = (
14 usize,
15 oneshot::Sender<Result<Vec<<PB as IntoIterator>::Item>, E>>,
16);
17
18pub struct DirectStream<PB: CorrelatedBatch, G: CorrelationGenerator<PB>> {
27 request_tx: UnboundedSender<Request<PB, G::Error>>,
28}
29
30impl<PB, G> DirectStream<PB, G>
31where
32 PB: CorrelatedBatch,
33 G: CorrelationGenerator<PB> + Send + 'static,
34 G::Net: Send + 'static,
35{
36 pub fn new(mut generator: G, mut net: G::Net) -> Self {
41 let (request_tx, mut request_rx) = mpsc::unbounded_channel::<Request<PB, G::Error>>();
42 tokio::spawn(async move {
43 while let Some((n_elements, completion)) = request_rx.recv().await {
44 let result = generator.run_for(n_elements, &mut net).await;
45 let _ = completion.send(result);
46 }
47 });
48 Self { request_tx }
49 }
50}
51
52impl<PB, G> CorrelatedStream<PB::Item> for DirectStream<PB, G>
53where
54 PB: CorrelatedBatch,
55 G: CorrelationGenerator<PB> + Send + 'static,
56 G::Net: Send + 'static,
57 G::Error: From<CorrelatedStreamError>,
58{
59 type Error = G::Error;
60
61 fn next_n(
62 &self,
63 n_elements: usize,
64 ) -> Result<NextVec<PB::Item, G::Error>, CorrelatedStreamError> {
65 if n_elements == 0 {
66 return Ok(NextVec::default());
67 }
68 let (tx, rx) = oneshot::channel();
69 self.request_tx
70 .send((n_elements, tx))
71 .map_err(|_| CorrelatedStreamError::StreamClosed)?;
72 Ok(NextVec {
73 future: Next(rx),
74 size: n_elements,
75 })
76 }
77}
78
#[cfg(test)]
mod tests {
    use std::{
        sync::{Arc, Mutex},
        time::Duration,
    };

    use rand::{rngs::StdRng, SeedableRng};
    use typenum::U2;

    use crate::{
        algebra::elliptic_curve::{Curve25519Ristretto, ScalarField},
        correlated_randomness::{
            generator::CorrelationGenerator,
            singlets::{Singlet, Singlets},
            stream::{errors::CorrelatedStreamError, CorrelatedStream, DirectStream},
        },
        random::Random,
        utils::TryFuture,
    };

    type Fq = ScalarField<Curve25519Ristretto>;
    type TestPB = Singlets<Fq, U2>;
    type TestItem = Singlet<Fq>;

    /// Generator double that records the size of every `run_for` call.
    struct MockGen {
        /// Source of the random singlets returned by `run_for`.
        rng: StdRng,
        /// Sizes passed to `run_for`, in the order the calls were made.
        calls: Arc<Mutex<Vec<usize>>>,
    }

    impl CorrelationGenerator<TestPB> for MockGen {
        type Net = ();
        type Error = CorrelatedStreamError;

        /// Always fails; the test below only goes through `run_for`.
        fn run(&mut self, _net: &mut ()) -> impl TryFuture<Ok = TestPB, Error = Self::Error> {
            async move { Err(CorrelatedStreamError::StreamClosed) }
        }

        /// Records `n`, sleeps for 10 ms, then returns `n` random singlets.
        fn run_for(
            &mut self,
            n: usize,
            _net: &mut (),
        ) -> impl TryFuture<Ok = Vec<TestItem>, Error = Self::Error> {
            async move {
                // The guard is a temporary, so the lock is released before the
                // `.await` below.
                self.calls.lock().unwrap().push(n);
                tokio::time::sleep(Duration::from_millis(10)).await;
                Ok(Singlet::<Fq>::random_n::<Vec<_>>(&mut self.rng, n))
            }
        }
    }

    /// Requests issued back to back must reach the generator in call order and
    /// each future must resolve to the number of elements it asked for.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn next_n_dispatches_in_call_order() {
        let calls = Arc::new(Mutex::new(Vec::new()));
        // Named `generator` rather than `gen`: `gen` is a reserved keyword in
        // the Rust 2024 edition.
        let generator = MockGen {
            rng: StdRng::from_seed([0u8; 32]),
            calls: calls.clone(),
        };
        let stream = DirectStream::new(generator, ());
        let f1 = stream.next_n(1).unwrap();
        let f2 = stream.next_n(2).unwrap();
        let f3 = stream.next_n(3).unwrap();
        let (a, b, c) = tokio::join!(f1, f2, f3);
        assert_eq!(a.unwrap().len(), 1);
        assert_eq!(b.unwrap().len(), 2);
        assert_eq!(c.unwrap().len(), 3);
        assert_eq!(*calls.lock().unwrap(), vec![1, 2, 3]);
    }
}