use futures_core::stream::{Stream, TryStream};
use pin_project::unsafe_project;
use std::{
    marker::PhantomData,
    pin::Pin,
    task::{Context as TaskContext, Poll},
};
use {ErrorCompat, IntoError};
pub trait TryStreamExt: TryStream + Sized {
fn context<C, E>(self, context: C) -> Context<Self, C, E>
where
C: IntoError<E, Source = Self::Error> + Clone,
E: std::error::Error + ErrorCompat;
fn with_context<F, C, E>(self, context: F) -> WithContext<Self, F, E>
where
F: FnMut() -> C,
C: IntoError<E, Source = Self::Error>,
E: std::error::Error + ErrorCompat;
}
impl<St> TryStreamExt for St
where
St: TryStream,
{
fn context<C, E>(self, context: C) -> Context<Self, C, E>
where
C: IntoError<E, Source = Self::Error> + Clone,
E: std::error::Error + ErrorCompat,
{
Context {
inner: self,
context,
_e: PhantomData,
}
}
fn with_context<F, C, E>(self, context: F) -> WithContext<Self, F, E>
where
F: FnMut() -> C,
C: IntoError<E, Source = Self::Error>,
E: std::error::Error + ErrorCompat,
{
WithContext {
inner: self,
context,
_e: PhantomData,
}
}
}
/// Stream adapter returned by `TryStreamExt::context`.
///
/// Yields the inner stream's successful items unchanged and converts each
/// error into `E` by calling `into_error` on a clone of the stored context
/// selector.
#[unsafe_project(Unpin)]
#[derive(Debug)]
#[must_use = "streams do nothing unless polled"]
pub struct Context<St, C, E> {
    // The wrapped stream; `#[pin]` makes the projection hand out `Pin<&mut St>`
    // so it can be polled in place.
    #[pin]
    inner: St,
    // The context selector, cloned once per error before conversion.
    context: C,
    // Records the target error type `E` without storing a value of it.
    _e: PhantomData<E>,
}
impl<St, C, E> TryStream for Context<St, C, E>
where
St: TryStream,
C: IntoError<E, Source = St::Error> + Clone,
E: std::error::Error + ErrorCompat,
{
type Ok = St::Ok;
type Error = E;
fn try_poll_next(
self: Pin<&mut Self>,
ctx: &mut TaskContext,
) -> Poll<Option<Result<Self::Ok, Self::Error>>> {
let this = self.project();
let inner = this.inner;
let context = this.context;
match inner.try_poll_next(ctx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Ok(v))) => Poll::Ready(Some(Ok(v))),
Poll::Ready(Some(Err(error))) => {
let error = context.clone().into_error(error);
Poll::Ready(Some(Err(error)))
}
}
}
}
/// Stream adapter returned by `TryStreamExt::with_context`.
///
/// Yields the inner stream's successful items unchanged. For each error it
/// calls the stored closure to obtain a context selector, then converts the
/// error into `E` with `into_error`.
#[unsafe_project(Unpin)]
#[derive(Debug)]
#[must_use = "streams do nothing unless polled"]
pub struct WithContext<St, F, E> {
    // The wrapped stream; `#[pin]` makes the projection hand out `Pin<&mut St>`
    // so it can be polled in place.
    #[pin]
    inner: St,
    // Closure producing a fresh context selector; invoked once per error.
    context: F,
    // Records the target error type `E` without storing a value of it.
    _e: PhantomData<E>,
}
impl<St, F, C, E> TryStream for WithContext<St, F, E>
where
St: TryStream,
F: FnMut() -> C,
C: IntoError<E, Source = St::Error>,
E: std::error::Error + ErrorCompat,
{
type Ok = St::Ok;
type Error = E;
fn try_poll_next(
self: Pin<&mut Self>,
ctx: &mut TaskContext,
) -> Poll<Option<Result<Self::Ok, Self::Error>>> {
let this = self.project();
let inner = this.inner;
let context = this.context;
match inner.try_poll_next(ctx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Ok(v))) => Poll::Ready(Some(Ok(v))),
Poll::Ready(Some(Err(error))) => {
let error = context().into_error(error);
Poll::Ready(Some(Err(error)))
}
}
}
}