use core::{
future::Future,
mem,
pin::Pin,
task::{Context, Poll, ready},
};
use futures_core::{FusedFuture, FusedStream, TryStream};
use pin_project_lite::pin_project;
use crate::Report;
pin_project! {
/// Future returned by [`TryReportStreamExt::try_collect_reports`] and
/// [`TryReportStreamExt::try_collect_reports_bounded`].
///
/// Polls the wrapped [`TryStream`], extending `A` with each `Ok` item. Error items are
/// converted into [`Report<[C]>`](Report) and merged into a single report instead of
/// ending collection at the first error; once an error has been seen, later `Ok` items
/// are discarded.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct TryCollectReports<S, A, C> {
// Source stream; structurally pinned so it can be polled through `Pin`.
#[pin]
stream: S,
// Accumulator: `Ok(A)` until the first error, then `Err` holding the merged report.
output: Result<A, Report<[C]>>,
// Number of error items merged into `output` so far.
context_len: usize,
// Error count at which the future stops polling; `usize::MAX` when no bound was given.
context_bound: usize
}
}
impl<S, A, C> TryCollectReports<S, A, C>
where
    S: TryStream,
    A: Default + Extend<S::Ok>,
{
    /// Builds a collector that starts with an empty `Ok(A)` accumulator.
    ///
    /// `bound` is the number of error items after which collection stops; `None`
    /// is stored as `usize::MAX`, i.e. effectively unbounded.
    fn new(stream: S, bound: Option<usize>) -> Self {
        let context_bound = bound.unwrap_or(usize::MAX);

        Self {
            stream,
            output: Ok(A::default()),
            context_len: 0,
            context_bound,
        }
    }
}
impl<S, A, C> FusedFuture for TryCollectReports<S, A, C>
where
    S: TryStream<Error: Into<Report<[C]>>> + FusedStream,
    A: Default + Extend<S::Ok>,
{
    /// Returns `true` once this future has produced its output and should no
    /// longer be polled.
    ///
    /// The future completes in one of two ways: the stream is exhausted (observable
    /// through the `FusedStream` bound), or the configured number of errors has been
    /// collected while the stream is still live. Both are reported here, so callers
    /// that rely on `FusedFuture` (e.g. `select!`-style combinators) do not poll a
    /// future that has already completed.
    fn is_terminated(&self) -> bool {
        // `poll` returns `Ready` in the same call in which `context_len` reaches
        // `context_bound`, so reaching the bound implies completion. With a bound of 0,
        // `context_len` is never incremented; `max(1)` keeps that degenerate case
        // reporting "not terminated" so the future is still polled for its value.
        self.stream.is_terminated() || self.context_len >= self.context_bound.max(1)
    }
}
impl<S, A, C> Future for TryCollectReports<S, A, C>
where
S: TryStream<Error: Into<Report<[C]>>>,
A: Default + Extend<S::Ok>,
{
type Output = Result<A, Report<[C]>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut this = self.project();
let value = loop {
if *this.context_len >= *this.context_bound {
break mem::replace(this.output, Ok(A::default()));
}
let next = ready!(this.stream.as_mut().try_poll_next(cx));
match (next, &mut *this.output) {
(Some(Ok(value)), Ok(output)) => {
output.extend(core::iter::once(value));
}
(Some(Ok(_)), Err(_)) => {
}
(Some(Err(error)), output @ Ok(_)) => {
*output = Err(error.into());
*this.context_len += 1;
}
(Some(Err(error)), Err(report)) => {
report.append(error.into());
*this.context_len += 1;
}
(None, output) => {
break mem::replace(output, Ok(A::default()));
}
}
};
Poll::Ready(value)
}
}
/// Adaptors for [`TryStream`]s whose error type converts into [`Report<[C]>`](Report).
///
/// Instead of stopping at the first error, these adaptors keep draining the stream so
/// that subsequent errors are merged into the same report.
pub trait TryReportStreamExt<C>: TryStream<Error: Into<Report<[C]>>> {
    /// Collects every `Ok` item into `A`, or every error into one report.
    ///
    /// The returned future resolves to `Ok(A)` if the stream yields no errors.
    /// Otherwise it resolves to `Err` holding all errors the stream produced, and
    /// `Ok` items seen after the first error are discarded.
    fn try_collect_reports<A>(self) -> TryCollectReports<Self, A, C>
    where
        Self: Sized,
        A: Default + Extend<Self::Ok>,
    {
        TryCollectReports::<Self, A, C>::new(self, None)
    }

    /// Like [`try_collect_reports`](Self::try_collect_reports), but stops once `bound`
    /// errors have been collected.
    ///
    /// When the bound is reached the future resolves without polling the rest of the
    /// stream. With `bound == 0` it resolves to `Ok(A::default())` on the first poll.
    fn try_collect_reports_bounded<A>(self, bound: usize) -> TryCollectReports<Self, A, C>
    where
        Self: Sized,
        A: Default + Extend<Self::Ok>,
    {
        TryCollectReports::<Self, A, C>::new(self, Some(bound))
    }
}
impl<S, C> TryReportStreamExt<C> for S where S: TryStream<Error: Into<Report<[C]>>> {}