use futures_core::future::TryFuture;
use pin_project::unsafe_project;
use std::{
marker::PhantomData,
pin::Pin,
task::{Context as TaskContext, Poll},
};
use {ErrorCompat, IntoError};
pub trait TryFutureExt: TryFuture + Sized {
fn context<C, E>(self, context: C) -> Context<Self, C, E>
where
C: IntoError<E, Source = Self::Error>,
E: std::error::Error + ErrorCompat;
fn with_context<F, C, E>(self, context: F) -> WithContext<Self, F, E>
where
F: FnOnce() -> C,
C: IntoError<E, Source = Self::Error>,
E: std::error::Error + ErrorCompat;
}
impl<Fut> TryFutureExt for Fut
where
Fut: TryFuture,
{
fn context<C, E>(self, context: C) -> Context<Self, C, E>
where
C: IntoError<E, Source = Self::Error>,
E: std::error::Error + ErrorCompat,
{
Context {
inner: self,
context: Some(context),
_e: PhantomData,
}
}
fn with_context<F, C, E>(self, context: F) -> WithContext<Self, F, E>
where
F: FnOnce() -> C,
C: IntoError<E, Source = Self::Error>,
E: std::error::Error + ErrorCompat,
{
WithContext {
inner: self,
context: Some(context),
_e: PhantomData,
}
}
}
/// Future returned by [`TryFutureExt::context`].
///
/// Polls the wrapped future and, if it resolves to an error, converts that
/// error into `E` using the stored context selector.
#[unsafe_project(Unpin)]
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct Context<Fut, C, E> {
    /// The wrapped future (structurally pinned).
    #[pin]
    inner: Fut,
    /// Context selector. It is taken, leaving `None`, when an error is converted.
    context: Option<C>,
    /// An `E` is only ever produced, never stored. Marking it as `fn() -> E`
    /// keeps this future's `Send`/`Sync` independent of `E` while preserving
    /// covariance in `E`.
    _e: PhantomData<fn() -> E>,
}
impl<Fut, C, E> TryFuture for Context<Fut, C, E>
where
    Fut: TryFuture,
    C: IntoError<E, Source = Fut::Error>,
    E: std::error::Error + ErrorCompat,
{
    type Ok = Fut::Ok;
    type Error = E;

    /// Polls the wrapped future. A successful output is forwarded unchanged.
    /// An error is converted into `E` with the stored context selector.
    ///
    /// # Panics
    ///
    /// Panics if polled again after the wrapped future has already resolved
    /// to an error, because the context selector was consumed at that point.
    fn try_poll(
        self: Pin<&mut Self>,
        ctx: &mut TaskContext,
    ) -> Poll<Result<Self::Ok, Self::Error>> {
        let projected = self.project();
        let selector = projected.context;
        match projected.inner.try_poll(ctx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(value)) => Poll::Ready(Ok(value)),
            Poll::Ready(Err(error)) => {
                let selector = selector
                    .take()
                    .expect("Cannot poll Context after it resolves");
                Poll::Ready(Err(selector.into_error(error)))
            }
        }
    }
}
/// Future returned by [`TryFutureExt::with_context`].
///
/// Polls the wrapped future and, if it resolves to an error, calls the
/// stored closure to build a context selector. It then converts the error
/// into `E`.
#[unsafe_project(Unpin)]
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct WithContext<Fut, F, E> {
    /// The wrapped future (structurally pinned).
    #[pin]
    inner: Fut,
    /// Closure that builds the context selector. It is taken, leaving `None`,
    /// when an error is converted.
    context: Option<F>,
    /// An `E` is only ever produced, never stored. Marking it as `fn() -> E`
    /// keeps this future's `Send`/`Sync` independent of `E` while preserving
    /// covariance in `E`.
    _e: PhantomData<fn() -> E>,
}
impl<Fut, F, C, E> TryFuture for WithContext<Fut, F, E>
where
    Fut: TryFuture,
    F: FnOnce() -> C,
    C: IntoError<E, Source = Fut::Error>,
    E: std::error::Error + ErrorCompat,
{
    type Ok = Fut::Ok;
    type Error = E;

    /// Polls the wrapped future. A successful output is forwarded unchanged.
    /// An error causes the stored closure to be called, and its context
    /// selector converts the error into `E`.
    ///
    /// # Panics
    ///
    /// Panics if polled again after the wrapped future has already resolved
    /// to an error, because the closure was consumed at that point.
    fn try_poll(
        self: Pin<&mut Self>,
        ctx: &mut TaskContext,
    ) -> Poll<Result<Self::Ok, Self::Error>> {
        let projected = self.project();
        let make_selector = projected.context;
        let outcome = match projected.inner.try_poll(ctx) {
            Poll::Ready(outcome) => outcome,
            Poll::Pending => return Poll::Pending,
        };
        Poll::Ready(outcome.map_err(|error| {
            let build = make_selector
                .take()
                .expect("Cannot poll WithContext after it resolves");
            build().into_error(error)
        }))
    }
}