use std::{
future::Future,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
};
use futures::{
future::{BoxFuture, Shared},
FutureExt,
};
use tokio::sync::oneshot;
use crate::correlated_randomness::stream::CorrelatedStreamError;
type SharedBatch<E> = Shared<BoxFuture<'static, Result<(), E>>>;
/// Shared storage for the items of a resolved batch.
///
/// Once the batch resolves successfully every element is stored as `Some(item)`;
/// the element future assigned a given index later `take`s it, leaving `None`.
type Slots<P> = std::sync::Arc<std::sync::Mutex<Vec<Option<P>>>>;
/// The source that a [`Next`] future ultimately polls.
enum NextInner<P, E> {
    /// Any boxed future, e.g. one produced by `Next::from_future` or by the
    /// batch iterator.
    Boxed(BoxFuture<'static, Result<P, E>>),
    /// The receiving half of a oneshot channel that carries the result.
    Channel(oneshot::Receiver<Result<P, E>>),
}

/// A future resolving to a single `Result<P, E>`.
///
/// It is backed either by a oneshot receiver (built with the `Next` function)
/// or by a boxed future (built with `Next::from_future`).
pub struct Next<P, E> {
    inner: NextInner<P, E>,
}
/// Wraps a oneshot receiver as a [`Next`] future.
///
/// If the sender is dropped without sending, polling the returned future
/// yields `CorrelatedStreamError::RecvError` converted into `E`.
// The upper-case name mirrors the `Next` type so call sites can write
// `Next(rx)` as if it were a tuple-struct constructor; hence the lint allowance.
#[allow(non_snake_case)]
pub fn Next<P, E>(receiver: oneshot::Receiver<Result<P, E>>) -> Next<P, E> {
    let inner = NextInner::Channel(receiver);
    Next { inner }
}
impl<P, E> Next<P, E> {
pub fn from_future(future: impl Future<Output = Result<P, E>> + Send + 'static) -> Self {
Next {
inner: NextInner::Boxed(Box::pin(future)),
}
}
}
impl<P, E: From<CorrelatedStreamError>> Future for Next<P, E> {
    type Output = Result<P, E>;

    /// Polls whichever source backs this `Next`.
    ///
    /// For the channel variant, a receive failure (the sender was dropped) is
    /// reported as `CorrelatedStreamError::RecvError` converted into `E`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        match &mut this.inner {
            NextInner::Boxed(boxed) => boxed.as_mut().poll(cx),
            NextInner::Channel(receiver) => match Pin::new(receiver).poll(cx) {
                Poll::Pending => Poll::Pending,
                Poll::Ready(Ok(value)) => Poll::Ready(value),
                Poll::Ready(Err(recv_err)) => {
                    let err = CorrelatedStreamError::RecvError(recv_err.to_string());
                    Poll::Ready(Err(E::from(err)))
                }
            },
        }
    }
}
/// A pending batch of items delivered together by one [`Next`] future.
pub struct NextVec<P, E> {
    /// Future resolving to the whole batch.
    pub future: Next<Vec<P>, E>,
    /// Number of items the batch is declared to contain.
    ///
    /// The resolved `Vec` is not checked against this value here.
    pub size: usize,
}

impl<P, E> NextVec<P, E> {
    /// Declared number of items in the batch.
    pub fn len(&self) -> usize {
        let NextVec { size, .. } = self;
        *size
    }

    /// Returns `true` when the batch is declared to hold no items.
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }
}
impl<P, E> Default for NextVec<P, E> {
fn default() -> Self {
let (tx, rx) = oneshot::channel();
let _ = tx.send(Ok(vec![]));
Self {
future: Next(rx),
size: 0,
}
}
}
impl<P, E: From<CorrelatedStreamError>> Future for NextVec<P, E> {
    type Output = Result<Vec<P>, E>;

    /// Delegates directly to the inner [`Next`] future.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let inner = &mut self.future;
        Pin::new(inner).poll(cx)
    }
}
/// Iterator that splits a [`NextVec`] into one [`Next`] future per element.
///
/// Every future it yields holds a clone of the same shared batch future and
/// the same slot storage; each takes only its own slot once the batch resolves.
pub struct NextVecIterator<P, E> {
    /// Index of the next element to hand out.
    index: usize,
    /// Declared number of elements in the batch.
    size: usize,
    /// Batch items, written when the shared batch future succeeds.
    slots: Slots<P>,
    /// Shared future that resolves the batch and fills `slots`.
    batch: SharedBatch<E>,
}
impl<P, E> IntoIterator for NextVec<P, E>
where
P: Send + 'static,
E: Clone + Send + Sync + 'static + From<CorrelatedStreamError>,
{
type Item = Next<P, E>;
type IntoIter = NextVecIterator<P, E>;
fn into_iter(self) -> Self::IntoIter {
let size = self.size;
let slots: Slots<P> = Arc::new(Mutex::new(Vec::new()));
let batch_slots = slots.clone();
let batch = self
.map(move |items| {
*batch_slots.lock().unwrap() = items?.into_iter().map(Some).collect();
Ok(())
})
.boxed()
.shared();
NextVecIterator {
batch,
slots,
index: 0,
size,
}
}
}
impl<P, E> Iterator for NextVecIterator<P, E>
where
P: Send + 'static,
E: Clone + Send + Sync + 'static + From<CorrelatedStreamError>,
{
type Item = Next<P, E>;
fn next(&mut self) -> Option<Self::Item> {
if self.index >= self.size {
return None;
}
let index = self.index;
self.index += 1;
let batch = self.batch.clone();
let slots = self.slots.clone();
let size = self.size;
let future = batch
.map(move |e| {
e?;
slots
.lock()
.unwrap()
.get_mut(index)
.and_then(Option::take)
.ok_or_else(|| {
E::from(CorrelatedStreamError::RecvError(format!(
"batch index {index} out of bounds (len {size})",
)))
})
})
.boxed();
Some(Next {
inner: NextInner::Boxed(future),
})
}
}
impl<P, E> NextVecIterator<P, E>
where
P: Send + 'static,
E: Clone + Send + Sync + 'static + From<CorrelatedStreamError>,
{
pub fn next_n(&mut self, n: usize) -> Option<NextVec<P, E>> {
if self.index + n > self.size {
return None;
}
let index = self.index;
self.index += n;
let batch = self.batch.clone();
let slots = self.slots.clone();
let size = self.size;
let future = batch
.map(move |e| {
e?;
slots
.lock()
.unwrap()
.iter_mut()
.skip(index)
.take(n)
.map(Option::take)
.collect::<Option<Vec<_>>>()
.ok_or_else(|| {
E::from(CorrelatedStreamError::RecvError(format!(
"Request (index {index}, n {n}) out of bounds (len {size})",
)))
})
})
.boxed();
Some(NextVec {
future: Next {
inner: NextInner::Boxed(future),
},
size: n,
})
}
}
impl<P, E> ExactSizeIterator for NextVecIterator<P, E>
where
    P: Send + 'static,
    E: Clone + Send + Sync + 'static + From<CorrelatedStreamError>,
{
    /// Number of elements not yet handed out by `next` or `next_n`.
    fn len(&self) -> usize {
        if self.index < self.size {
            self.size - self.index
        } else {
            0
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `NextVec` declaring `size` items whose channel already holds
    /// `result`.
    fn next_vec(
        result: Result<Vec<u32>, CorrelatedStreamError>,
        size: usize,
    ) -> NextVec<u32, CorrelatedStreamError> {
        let (tx, rx) = oneshot::channel();
        let _ = tx.send(result);
        NextVec {
            future: Next(rx),
            size,
        }
    }

    /// Each element future resolves to the item at its own index, in any
    /// await order.
    #[tokio::test]
    async fn items_resolve_by_index() {
        let mut items: Vec<_> = next_vec(Ok(vec![10, 20, 30]), 3).into_iter().collect();
        assert_eq!(items.len(), 3);
        assert_eq!(items.pop().unwrap().await.unwrap(), 30);
        assert_eq!(items.pop().unwrap().await.unwrap(), 20);
        assert_eq!(items.pop().unwrap().await.unwrap(), 10);
    }

    /// A batch-level error is cloned to every element future.
    #[tokio::test]
    async fn error_fans_out_to_all_elements() {
        let mut iter = next_vec(Err(CorrelatedStreamError::StreamClosed), 2).into_iter();
        for _ in 0..2 {
            let result = iter.next().unwrap().await;
            assert_eq!(result, Err(CorrelatedStreamError::StreamClosed));
        }
        assert!(iter.next().is_none());
    }

    /// Dropping the sender surfaces as a `RecvError`.
    #[tokio::test]
    async fn dropped_sender_yields_error() {
        let (tx, rx) = oneshot::channel::<Result<Vec<u32>, CorrelatedStreamError>>();
        drop(tx);
        let nv = NextVec {
            future: Next(rx),
            size: 1,
        };
        let result = nv.into_iter().next().unwrap().await;
        assert!(matches!(result, Err(CorrelatedStreamError::RecvError(_))));
    }

    /// A batch delivering fewer items than declared errors for missing indices.
    #[tokio::test]
    async fn short_batch_yields_out_of_bounds() {
        let mut iter = next_vec(Ok(vec![1]), 3).into_iter();
        assert_eq!(iter.next().unwrap().await.unwrap(), 1);
        for _ in 0..2 {
            let result = iter.next().unwrap().await;
            assert!(matches!(result, Err(CorrelatedStreamError::RecvError(_))));
        }
    }

    /// Element futures can be driven by a plain executor, without Tokio.
    #[test]
    fn works_outside_tokio_runtime() {
        let mut iter = next_vec(Ok(vec![7, 8]), 2).into_iter();
        assert_eq!(iter.len(), 2);
        let first = iter.next().unwrap();
        assert_eq!(iter.len(), 1);
        let second = iter.next().unwrap();
        assert!(iter.next().is_none());
        assert_eq!(futures::executor::block_on(second).unwrap(), 8);
        assert_eq!(futures::executor::block_on(first).unwrap(), 7);
    }

    /// `NextVec::default()` is empty, already resolved, and iterates to nothing.
    #[test]
    fn default_next_vec_is_empty_and_ready() {
        let nv = NextVec::<u32, CorrelatedStreamError>::default();
        assert_eq!(nv.size, 0);
        assert_eq!(futures::executor::block_on(nv), Ok(vec![]));
        let mut iter = NextVec::<u32, CorrelatedStreamError>::default().into_iter();
        assert!(iter.next().is_none());
    }

    /// `next_n` hands out consecutive runs, refuses requests past the end
    /// without advancing, and interleaves correctly with `next`.
    #[tokio::test]
    async fn next_n_groups_consecutive_items() {
        let mut iter = next_vec(Ok(vec![1, 2, 3, 4]), 4).into_iter();
        let first = iter.next().unwrap();
        let pair = iter.next_n(2).unwrap();
        assert_eq!(pair.len(), 2);
        assert_eq!(iter.len(), 1);
        assert!(iter.next_n(2).is_none());
        assert_eq!(iter.len(), 1);
        let last = iter.next_n(1).unwrap();
        assert!(iter.next().is_none());
        assert_eq!(pair.await.unwrap(), vec![2, 3]);
        assert_eq!(last.await.unwrap(), vec![4]);
        assert_eq!(first.await.unwrap(), 1);
    }
}