// primitives/correlated_randomness/stream/direct.rs

1use tokio::sync::{
2    mpsc::{self, UnboundedSender},
3    oneshot,
4};
5
6use crate::correlated_randomness::{
7    generator::CorrelationGenerator,
8    stream::{errors::CorrelatedStreamError, futures::Next, CorrelatedStream, NextVec},
9    CorrelatedBatch,
10};
11
12/// A request for `n` elements, paired with the channel resolving the caller's [`NextVec`].
13type Request<PB, E> = (
14    usize,
15    oneshot::Sender<Result<Vec<<PB as IntoIterator>::Item>, E>>,
16);
17
18/// A minimal wrapper around a [`CorrelationGenerator`] that implements [`CorrelatedStream`].
19///
20/// Unlike [`super::buffered::BufferedStream`], `DirectStream` has no internal buffer or
21/// prefetch logic — every [`next_n`](CorrelatedStream::next_n) call triggers a fresh generation
22/// run via the underlying generator.
23///
24/// `next_n` enqueues requests synchronously; a single owner task serves them strictly FIFO, so
25/// dispatch order always matches call order (as required by the [`CorrelatedStream`] contract).
26pub struct DirectStream<PB: CorrelatedBatch, G: CorrelationGenerator<PB>> {
27    request_tx: UnboundedSender<Request<PB, G::Error>>,
28}
29
30impl<PB, G> DirectStream<PB, G>
31where
32    PB: CorrelatedBatch,
33    G: CorrelationGenerator<PB> + Send + 'static,
34    G::Net: Send + 'static,
35{
36    /// Creates a `DirectStream` from a generator and its network.
37    ///
38    /// Spawns the owner task serving requests, so this must be called within a tokio runtime.
39    /// The task exits once the stream is dropped and all queued requests have been served.
40    pub fn new(mut generator: G, mut net: G::Net) -> Self {
41        let (request_tx, mut request_rx) = mpsc::unbounded_channel::<Request<PB, G::Error>>();
42        tokio::spawn(async move {
43            while let Some((n_elements, completion)) = request_rx.recv().await {
44                let result = generator.run_for(n_elements, &mut net).await;
45                let _ = completion.send(result);
46            }
47        });
48        Self { request_tx }
49    }
50}
51
52impl<PB, G> CorrelatedStream<PB::Item> for DirectStream<PB, G>
53where
54    PB: CorrelatedBatch,
55    G: CorrelationGenerator<PB> + Send + 'static,
56    G::Net: Send + 'static,
57    G::Error: From<CorrelatedStreamError>,
58{
59    type Error = G::Error;
60
61    fn next_n(
62        &self,
63        n_elements: usize,
64    ) -> Result<NextVec<PB::Item, G::Error>, CorrelatedStreamError> {
65        if n_elements == 0 {
66            return Ok(NextVec::default());
67        }
68        let (tx, rx) = oneshot::channel();
69        self.request_tx
70            .send((n_elements, tx))
71            .map_err(|_| CorrelatedStreamError::StreamClosed)?;
72        Ok(NextVec {
73            future: Next(rx),
74            size: n_elements,
75        })
76    }
77}
78
#[cfg(test)]
mod tests {
    use std::{
        sync::{Arc, Mutex},
        time::Duration,
    };

    use rand::{rngs::StdRng, SeedableRng};
    use typenum::U2;

    use crate::{
        algebra::elliptic_curve::{Curve25519Ristretto, ScalarField},
        correlated_randomness::{
            generator::CorrelationGenerator,
            singlets::{Singlet, Singlets},
            stream::{errors::CorrelatedStreamError, CorrelatedStream, DirectStream},
        },
        random::Random,
        utils::TryFuture,
    };

    type Fq = ScalarField<Curve25519Ristretto>;
    type TestPB = Singlets<Fq, U2>;
    type TestItem = Singlet<Fq>;

    /// Generator stub that appends the size of every `run_for` request to a shared log. It then
    /// sleeps briefly so that any out-of-order dispatch has time to show up.
    struct MockGen {
        rng: StdRng,
        calls: Arc<Mutex<Vec<usize>>>,
    }

    impl CorrelationGenerator<TestPB> for MockGen {
        type Net = ();
        type Error = CorrelatedStreamError;

        /// Not called by `DirectStream`; always fails.
        fn run(&mut self, _net: &mut ()) -> impl TryFuture<Ok = TestPB, Error = Self::Error> {
            async move { Err(CorrelatedStreamError::StreamClosed) }
        }

        fn run_for(
            &mut self,
            len: usize,
            _net: &mut (),
        ) -> impl TryFuture<Ok = Vec<TestItem>, Error = Self::Error> {
            async move {
                // The guard is a temporary, so the lock is released before the `.await` below.
                self.calls.lock().unwrap().push(len);
                tokio::time::sleep(Duration::from_millis(10)).await;
                Ok(Singlet::<Fq>::random_n::<Vec<_>>(&mut self.rng, len))
            }
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn next_n_dispatches_in_call_order() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let generator = MockGen {
            rng: StdRng::from_seed([0u8; 32]),
            calls: Arc::clone(&log),
        };
        let stream = DirectStream::new(generator, ());

        // All three requests are enqueued (in order) before any of them is awaited; the owner
        // task must still serve them first-in, first-out.
        let [first, second, third] = [1, 2, 3].map(|len| stream.next_n(len).unwrap());
        let (a, b, c) = tokio::join!(first, second, third);

        for (expected_len, outcome) in (1usize..).zip([a, b, c]) {
            assert_eq!(outcome.unwrap().len(), expected_len);
        }
        assert_eq!(*log.lock().unwrap(), vec![1, 2, 3]);
    }
}