arcium-primitives 0.6.0

Arcium primitives
Documentation
use tokio::sync::{
    mpsc::{self, UnboundedSender},
    oneshot,
};

use crate::correlated_randomness::{
    generator::CorrelationGenerator,
    stream::{errors::CorrelatedStreamError, futures::Next, CorrelatedStream, NextVec},
    CorrelatedBatch,
};

/// Work item for the owner task: the number of elements to generate, together with the
/// one-shot sender that completes the caller's [`NextVec`].
type Request<PB, E> = (usize, oneshot::Sender<Result<Vec<<PB as IntoIterator>::Item>, E>>);

/// A thin [`CorrelatedStream`] adapter over a [`CorrelationGenerator`].
///
/// There is no buffering and no prefetching here (contrast with
/// [`super::buffered::BufferedStream`]). Each non-empty [`next_n`](CorrelatedStream::next_n)
/// request becomes exactly one fresh generation run on the wrapped generator.
///
/// Calls to `next_n` push their request onto a queue without awaiting anything. A single
/// background owner task drains that queue in FIFO order. Generation runs therefore happen in
/// the same order as the `next_n` calls, as the [`CorrelatedStream`] contract requires.
pub struct DirectStream<PB, G>
where
    PB: CorrelatedBatch,
    G: CorrelationGenerator<PB>,
{
    /// Sending half of the request queue consumed by the owner task.
    request_tx: UnboundedSender<Request<PB, G::Error>>,
}

impl<PB, G> DirectStream<PB, G>
where
    PB: CorrelatedBatch,
    G: CorrelationGenerator<PB> + Send + 'static,
    G::Net: Send + 'static,
{
    /// Creates a `DirectStream` from a generator and its network.
    ///
    /// Spawns the owner task serving requests, so this must be called within a tokio runtime.
    /// The task exits once the stream is dropped and all queued requests have been served.
    pub fn new(mut generator: G, mut net: G::Net) -> Self {
        let (request_tx, mut request_rx) = mpsc::unbounded_channel::<Request<PB, G::Error>>();
        tokio::spawn(async move {
            while let Some((n_elements, completion)) = request_rx.recv().await {
                let result = generator.run_for(n_elements, &mut net).await;
                let _ = completion.send(result);
            }
        });
        Self { request_tx }
    }
}

impl<PB, G> CorrelatedStream<PB::Item> for DirectStream<PB, G>
where
    PB: CorrelatedBatch,
    G: CorrelationGenerator<PB> + Send + 'static,
    G::Net: Send + 'static,
    G::Error: From<CorrelatedStreamError>,
{
    type Error = G::Error;

    /// Enqueues a request for `n_elements` items and returns a future for the result.
    ///
    /// Enqueuing is synchronous, so requests are served in the order of the `next_n` calls.
    /// A request for zero elements is answered with `NextVec::default()` and never reaches the
    /// owner task.
    ///
    /// # Errors
    ///
    /// Returns [`CorrelatedStreamError::StreamClosed`] when the owner task is no longer
    /// receiving requests (e.g. it panicked or its runtime shut down).
    fn next_n(
        &self,
        n_elements: usize,
    ) -> Result<NextVec<PB::Item, G::Error>, CorrelatedStreamError> {
        if n_elements == 0 {
            return Ok(NextVec::default());
        }

        let (completion_tx, completion_rx) = oneshot::channel();
        match self.request_tx.send((n_elements, completion_tx)) {
            Ok(()) => Ok(NextVec {
                future: Next(completion_rx),
                size: n_elements,
            }),
            Err(_) => Err(CorrelatedStreamError::StreamClosed),
        }
    }
}

#[cfg(test)]
mod tests {
    use std::{
        sync::{Arc, Mutex},
        time::Duration,
    };

    use rand::{rngs::StdRng, SeedableRng};
    use typenum::U2;

    use crate::{
        algebra::elliptic_curve::{Curve25519Ristretto, ScalarField},
        correlated_randomness::{
            generator::CorrelationGenerator,
            singlets::{Singlet, Singlets},
            stream::{errors::CorrelatedStreamError, CorrelatedStream, DirectStream},
        },
        random::Random,
        utils::TryFuture,
    };

    type Fq = ScalarField<Curve25519Ristretto>;
    type TestPB = Singlets<Fq, U2>;
    type TestItem = Singlet<Fq>;

    /// Test generator that logs the `n` of each `run_for` call into a shared vector, then sleeps
    /// briefly. The sleep means any out-of-order dispatch would show up in the log.
    struct MockGen {
        rng: StdRng,
        calls: Arc<Mutex<Vec<usize>>>,
    }

    impl CorrelationGenerator<TestPB> for MockGen {
        type Net = ();
        type Error = CorrelatedStreamError;

        // `DirectStream` only drives `run_for`, so this path always fails.
        fn run(&mut self, _net: &mut ()) -> impl TryFuture<Ok = TestPB, Error = Self::Error> {
            async move { Err(CorrelatedStreamError::StreamClosed) }
        }

        fn run_for(
            &mut self,
            n: usize,
            _net: &mut (),
        ) -> impl TryFuture<Ok = Vec<TestItem>, Error = Self::Error> {
            async move {
                // Log before yielding so the record reflects dispatch order, not completion order.
                self.calls.lock().unwrap().push(n);
                tokio::time::sleep(Duration::from_millis(10)).await;
                let items = Singlet::<Fq>::random_n::<Vec<_>>(&mut self.rng, n);
                Ok(items)
            }
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn next_n_dispatches_in_call_order() {
        let calls = Arc::new(Mutex::new(Vec::new()));
        let generator = MockGen {
            rng: StdRng::from_seed([0u8; 32]),
            calls: Arc::clone(&calls),
        };
        let stream = DirectStream::new(generator, ());

        // Fire all three requests before awaiting any of them; dispatch must still be FIFO.
        let first = stream.next_n(1).unwrap();
        let second = stream.next_n(2).unwrap();
        let third = stream.next_n(3).unwrap();
        let (one, two, three) = tokio::join!(first, second, third);

        for (outcome, expected_len) in [(one, 1), (two, 2), (three, 3)] {
            assert_eq!(outcome.unwrap().len(), expected_len);
        }
        assert_eq!(*calls.lock().unwrap(), [1, 2, 3]);
    }
}