// arcium-primitives 0.6.0
//
// Arcium primitives
// Documentation
use std::{
    future::Future,
    pin::Pin,
    sync::{Arc, Mutex},
    task::{Context, Poll},
};

use futures::{
    future::{BoxFuture, Shared},
    FutureExt,
};
use tokio::sync::oneshot;

use crate::correlated_randomness::stream::CorrelatedStreamError;

/// Batch driver shared by all per-element futures; resolves once the slots are populated.
type SharedBatch<E> = Shared<BoxFuture<'static, Result<(), E>>>;
/// Per-element storage, filled in when the batch resolves successfully. Each slot is moved out
/// with `Option::take`, so it can be taken at most once.
type Slots<P> = std::sync::Arc<std::sync::Mutex<Vec<Option<P>>>>;

enum NextInner<P, E> {
    Channel(oneshot::Receiver<Result<P, E>>),
    Boxed(BoxFuture<'static, Result<P, E>>),
}

/// A future resolving to a single preprocessing item.
///
/// Build one with the `Next(rx)` constructor function (backed by a oneshot channel) or with
/// `Next::from_future` (wrapping any `Send + 'static` future).
pub struct Next<P, E> {
    /// Where the value comes from: a channel receiver or a boxed future.
    inner: NextInner<P, E>,
}

/// Builds a channel-backed `Next` from a oneshot receiver.
///
/// This function deliberately shares the type's name. It mimics a tuple-struct constructor so
/// that existing `Next(rx)` call sites keep compiling.
#[allow(non_snake_case)] // intentionally CamelCase: it stands in for a tuple-struct constructor
pub fn Next<P, E>(receiver: oneshot::Receiver<Result<P, E>>) -> Next<P, E> {
    let inner = NextInner::Channel(receiver);
    Next { inner }
}

impl<P, E> Next<P, E> {
    /// Wraps an arbitrary future, e.g. to map another `Next`'s error type
    /// without spawning a task.
    pub fn from_future(future: impl Future<Output = Result<P, E>> + Send + 'static) -> Self {
        Next {
            inner: NextInner::Boxed(Box::pin(future)),
        }
    }
}

impl<P, E: From<CorrelatedStreamError>> Future for Next<P, E> {
    type Output = Result<P, E>;

    /// Polls the underlying channel receiver or boxed future.
    ///
    /// For the channel variant, if the receive fails (the receiver yields an error), the error's
    /// text is wrapped in `CorrelatedStreamError::RecvError` and converted into `E`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        match this.inner {
            NextInner::Boxed(ref mut fut) => fut.as_mut().poll(cx),
            NextInner::Channel(ref mut receiver) => match Pin::new(receiver).poll(cx) {
                Poll::Pending => Poll::Pending,
                Poll::Ready(Ok(item)) => Poll::Ready(item),
                Poll::Ready(Err(recv_error)) => Poll::Ready(Err(E::from(
                    CorrelatedStreamError::RecvError(recv_error.to_string()),
                ))),
            },
        }
    }
}

/// A future resolving to a vector of preprocessing items, paired with the request size.
pub struct NextVec<P, E> {
    /// Resolves to the batch of items, or to an error.
    pub future: Next<Vec<P>, E>,
    /// Number of items requested. It is available without awaiting `future`.
    pub size: usize,
}

impl<P, E> NextVec<P, E> {
    /// Returns the number of items in this future vec, which is known at construction time.
    pub fn len(&self) -> usize {
        self.size
    }

    /// Returns true if this future vec is empty, which is known at construction time.
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }
}

impl<P, E> Default for NextVec<P, E> {
    fn default() -> Self {
        let (tx, rx) = oneshot::channel();
        // Send a default value immediately to avoid hanging if this future is accidentally awaited.
        let _ = tx.send(Ok(vec![]));
        Self {
            future: Next(rx),
            size: 0,
        }
    }
}

impl<P, E: From<CorrelatedStreamError>> Future for NextVec<P, E> {
    type Output = Result<Vec<P>, E>;

    /// Delegates to the inner `Next` future. `size` plays no part in polling.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        this.future.poll_unpin(cx)
    }
}

/// Fans a [`NextVec`] out into one `Next` future per element.
///
/// All element futures become ready together because they share one batch. Callers can still
/// `await` them individually, as if each item were produced on its own.
///
/// No task is spawned and no async runtime is required. The batch lives in a [`Shared`] future
/// that is driven inline by whichever element future gets polled first. Each element is stored
/// as an `Option<P>` slot and moved out with [`Option::take`], so `P` is never cloned. Only `E`
/// must be `Clone`, so that a single batch error can be reported to every element future.
pub struct NextVecIterator<P, E> {
    /// Shared batch driver; each handed-out future holds its own clone.
    batch: SharedBatch<E>,
    /// Element slots, filled by `batch` when it succeeds.
    slots: Slots<P>,
    /// Index of the next element to hand out.
    index: usize,
    /// Number of elements requested from the batch.
    size: usize,
}

impl<P, E> IntoIterator for NextVec<P, E>
where
    P: Send + 'static,
    E: Clone + Send + Sync + 'static + From<CorrelatedStreamError>,
{
    type Item = Next<P, E>;
    type IntoIter = NextVecIterator<P, E>;

    fn into_iter(self) -> Self::IntoIter {
        let size = self.size;
        let slots: Slots<P> = Arc::new(Mutex::new(Vec::new()));
        let batch_slots = slots.clone();
        let batch = self
            .map(move |items| {
                *batch_slots.lock().unwrap() = items?.into_iter().map(Some).collect();
                Ok(())
            })
            .boxed()
            .shared();
        NextVecIterator {
            batch,
            slots,
            index: 0,
            size,
        }
    }
}

impl<P, E> Iterator for NextVecIterator<P, E>
where
    P: Send + 'static,
    E: Clone + Send + Sync + 'static + From<CorrelatedStreamError>,
{
    type Item = Next<P, E>;

    /// Returns the next `Next<P, E>` immediately, or `None` once all `size` elements have been
    /// handed out. When awaited, it polls the shared batch future inline and then consumes (not
    /// clones) the element at this index.
    ///
    /// A batch error is cloned into this element's result. If this index has no filled slot
    /// (e.g. the batch resolved with fewer than `size` items), the element resolves to a
    /// `CorrelatedStreamError::RecvError` converted into `E`.
    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.size {
            return None;
        }
        let index = self.index;
        self.index += 1;
        let batch = self.batch.clone();
        let slots = self.slots.clone();
        let size = self.size;
        let future = batch
            .map(move |e| {
                e?;
                slots
                    .lock()
                    .unwrap()
                    .get_mut(index)
                    .and_then(Option::take)
                    .ok_or_else(|| {
                        E::from(CorrelatedStreamError::RecvError(format!(
                            "batch index {index} out of bounds (len {size})",
                        )))
                    })
            })
            .boxed();
        Some(Next {
            inner: NextInner::Boxed(future),
        })
    }

    /// Reports the exact number of elements still to be handed out.
    ///
    /// This type implements `ExactSizeIterator`, whose contract requires `size_hint` to be
    /// exact. The default `(0, None)` would violate that contract and prevent `collect` from
    /// pre-allocating.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.size.saturating_sub(self.index);
        (remaining, Some(remaining))
    }
}

impl<P, E> NextVecIterator<P, E>
where
    P: Send + 'static,
    E: Clone + Send + Sync + 'static + From<CorrelatedStreamError>,
{
    /// Hands out the next `n` elements as one `NextVec<P, E>` of size `n`. It returns `None`,
    /// consuming nothing, if fewer than `n` elements remain.
    ///
    /// When awaited, the returned future polls the shared batch future inline. It then consumes
    /// (not clones) the `n` slots starting at the current index, and resolves to:
    /// - the batch error, if the batch failed;
    /// - `Ok(vec![])` when `n == 0`;
    /// - a `CorrelatedStreamError::RecvError` (converted into `E`) if any slot in the window is
    ///   missing, e.g. because the batch resolved with too few items. A successful result
    ///   therefore always has exactly `n` items.
    pub fn next_n(&mut self, n: usize) -> Option<NextVec<P, E>> {
        // `checked_add` guards against `usize` overflow for a huge `n`. A plain `+` would panic
        // in debug builds and wrap around (passing the bounds check) in release builds.
        let end = match self.index.checked_add(n) {
            Some(end) if end <= self.size => end,
            _ => return None,
        };
        let index = self.index;
        self.index = end;
        let batch = self.batch.clone();
        let slots = self.slots.clone();
        let size = self.size;
        let future = batch
            .map(move |batch_result: Result<(), E>| -> Result<Vec<P>, E> {
                batch_result?;
                if n == 0 {
                    // An empty request succeeds even if the batch came back short.
                    return Ok(Vec::new());
                }
                // Range indexing (rather than `skip`/`take`) turns a short batch into an error.
                // Otherwise it would silently yield a truncated vector whose length disagrees
                // with `size: n`.
                let taken = slots
                    .lock()
                    .unwrap()
                    .get_mut(index..end)
                    .and_then(|window| {
                        window.iter_mut().map(Option::take).collect::<Option<Vec<P>>>()
                    });
                taken.ok_or_else(|| {
                    E::from(CorrelatedStreamError::RecvError(format!(
                        "Request (index {index}, n {n}) out of bounds (len {size})",
                    )))
                })
            })
            .boxed();
        Some(NextVec {
            future: Next {
                inner: NextInner::Boxed(future),
            },
            size: n,
        })
    }
}

impl<P, E> ExactSizeIterator for NextVecIterator<P, E>
where
    P: Send + 'static,
    E: Clone + Send + Sync + 'static + From<CorrelatedStreamError>,
{
    /// Number of elements not yet handed out by `next` / `next_n`. It is zero once `index`
    /// reaches or passes `size`.
    fn len(&self) -> usize {
        // An empty (or reversed) range reports length 0.
        (self.index..self.size).len()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `NextVec` whose channel is already resolved with `result`.
    fn next_vec(
        result: Result<Vec<u32>, CorrelatedStreamError>,
        size: usize,
    ) -> NextVec<u32, CorrelatedStreamError> {
        let (tx, rx) = oneshot::channel();
        let _ = tx.send(result);
        NextVec {
            future: Next(rx),
            size,
        }
    }

    #[tokio::test]
    async fn items_resolve_by_index() {
        let mut items: Vec<_> = next_vec(Ok(vec![10, 20, 30]), 3).into_iter().collect();
        assert_eq!(items.len(), 3);
        // Await in reverse order: any element future drives the shared batch.
        assert_eq!(items.pop().unwrap().await.unwrap(), 30);
        assert_eq!(items.pop().unwrap().await.unwrap(), 20);
        assert_eq!(items.pop().unwrap().await.unwrap(), 10);
    }

    #[tokio::test]
    async fn error_fans_out_to_all_elements() {
        let mut iter = next_vec(Err(CorrelatedStreamError::StreamClosed), 2).into_iter();
        for _ in 0..2 {
            let result = iter.next().unwrap().await;
            assert_eq!(result, Err(CorrelatedStreamError::StreamClosed));
        }
        assert!(iter.next().is_none());
    }

    #[tokio::test]
    async fn dropped_sender_yields_error() {
        let (tx, rx) = oneshot::channel::<Result<Vec<u32>, CorrelatedStreamError>>();
        drop(tx);
        let nv = NextVec {
            future: Next(rx),
            size: 1,
        };
        let result = nv.into_iter().next().unwrap().await;
        assert!(matches!(result, Err(CorrelatedStreamError::RecvError(_))));
    }

    #[tokio::test]
    async fn short_batch_yields_out_of_bounds() {
        let mut iter = next_vec(Ok(vec![1]), 3).into_iter();
        assert_eq!(iter.next().unwrap().await.unwrap(), 1);
        for _ in 0..2 {
            let result = iter.next().unwrap().await;
            assert!(matches!(result, Err(CorrelatedStreamError::RecvError(_))));
        }
    }

    #[tokio::test]
    async fn next_n_splits_batch_into_chunks() {
        let mut iter = next_vec(Ok(vec![1, 2, 3, 4, 5]), 5).into_iter();
        let head = iter.next_n(2).unwrap();
        assert_eq!(head.len(), 2);
        let single = iter.next().unwrap();
        let tail = iter.next_n(2).unwrap();
        assert!(iter.next_n(1).is_none());
        assert!(iter.next().is_none());
        // Await out of order: whichever future is polled first drives the shared batch.
        assert_eq!(tail.await.unwrap(), vec![4, 5]);
        assert_eq!(single.await.unwrap(), 3);
        assert_eq!(head.await.unwrap(), vec![1, 2]);
    }

    #[test]
    fn next_n_rejects_requests_past_the_end() {
        let mut iter = next_vec(Ok(vec![1, 2]), 2).into_iter();
        assert!(iter.next_n(3).is_none());
        // A rejected request must not consume anything.
        assert_eq!(iter.len(), 2);
        let all = iter.next_n(2).unwrap();
        assert_eq!(iter.len(), 0);
        assert_eq!(futures::executor::block_on(all).unwrap(), vec![1, 2]);
    }

    #[test]
    fn next_n_zero_yields_empty_vec() {
        let mut iter = next_vec(Ok(vec![9]), 1).into_iter();
        let empty = iter.next_n(0).unwrap();
        assert!(empty.is_empty());
        assert_eq!(iter.len(), 1);
        assert_eq!(futures::executor::block_on(empty), Ok(vec![]));
        assert_eq!(futures::executor::block_on(iter.next().unwrap()), Ok(9));
    }

    /// No tokio runtime anywhere: construction, iteration, and awaiting must not spawn tasks or
    /// touch a reactor.
    #[test]
    fn works_outside_tokio_runtime() {
        let mut iter = next_vec(Ok(vec![7, 8]), 2).into_iter();
        assert_eq!(iter.len(), 2);
        let first = iter.next().unwrap();
        assert_eq!(iter.len(), 1);
        let second = iter.next().unwrap();
        assert!(iter.next().is_none());
        assert_eq!(futures::executor::block_on(second).unwrap(), 8);
        assert_eq!(futures::executor::block_on(first).unwrap(), 7);
    }

    #[test]
    fn default_next_vec_is_empty_and_ready() {
        let nv = NextVec::<u32, CorrelatedStreamError>::default();
        assert_eq!(nv.size, 0);
        assert_eq!(nv.len(), 0);
        assert!(nv.is_empty());
        assert_eq!(futures::executor::block_on(nv), Ok(vec![]));
        let mut iter = NextVec::<u32, CorrelatedStreamError>::default().into_iter();
        assert!(iter.next().is_none());
    }
}