use std::{
pin::Pin,
task::{Context, Poll},
};
/// Future returned by [`FutureExt::with_cancel_signal`].
///
/// Drives a wrapped future together with a cancellation signal. It resolves to
/// `Ok(output)` when the wrapped future completes, or to `Err(signal)` when the
/// cancellation future completes first. If both become ready during the same
/// poll, the wrapped future wins because it is always polled first.
///
/// Both futures are stored as `Pin<Box<_>>`, so `WithCancelSignal` is itself
/// `Unpin` regardless of whether `F` or `C` are.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct WithCancelSignal<F, C> {
    /// The future whose output is wanted.
    future: Pin<Box<F>>,
    /// The future whose completion cancels `future`.
    cancel: Pin<Box<C>>,
}

impl<F, C> Future for WithCancelSignal<F, C>
where
    F: Future,
    C: Future,
{
    type Output = Result<F::Output, C::Output>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Poll the wrapped future first so that a result produced on the same
        // poll as the cancel signal is kept rather than discarded.
        if let Poll::Ready(output) = self.future.as_mut().poll(cx) {
            return Poll::Ready(Ok(output));
        }
        // Not finished yet: report cancellation if the signal has fired,
        // otherwise stay pending (both inner futures were polled with `cx`).
        self.cancel.as_mut().poll(cx).map(Err)
    }
}
/// Extension methods available on every [`Future`].
pub trait FutureExt: Future + Sized {
    /// Pairs `self` with a cancellation signal `cancel`.
    ///
    /// The returned [`WithCancelSignal`] resolves to `Ok` with this future's
    /// output when it completes, or to `Err` with `cancel`'s output when the
    /// signal completes first. Both futures are boxed and pinned here, at
    /// construction time.
    fn with_cancel_signal<C>(self, cancel: C) -> WithCancelSignal<Self, C>
    where
        C: Future,
    {
        let future = Box::pin(self);
        let cancel = Box::pin(cancel);
        WithCancelSignal { future, cancel }
    }
}

// Every future gets `with_cancel_signal` for free.
impl<T> FutureExt for T where T: Future {}
/// Turns a value into a future by supplying it a single argument.
///
/// Anything callable as `FnOnce(A) -> F` with `F: Future` implements this
/// trait through the blanket impl below. One-argument `async fn`s, closures
/// returning futures, and async closures can therefore be used interchangeably.
/// Pass a tuple as `A` to hand over several values at once.
pub trait IntoFutureWithArgs<A, F>
where
    F: Future,
{
    /// Consumes `self` and produces the future for `args`.
    fn into_future_with_args(self, args: A) -> F;
}

impl<T, A, F> IntoFutureWithArgs<A, F> for T
where
    F: Future,
    T: FnOnce(A) -> F,
{
    fn into_future_with_args(self, args: A) -> F {
        let make_future = self;
        make_future(args)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Races two timers: whichever `sleep` has the shorter deadline decides
    /// whether the combinator yields `Ok` or `Err`. The deadlines differ by at
    /// least 50 ms, presumably so the outcome is robust to timer jitter.
    #[tokio::test]
    async fn with_cancel_signal() {
        use std::time::Duration;
        use tokio::time::sleep;
        let cancel = async move { sleep(Duration::from_millis(100)).await };
        let future = async move { sleep(Duration::from_millis(200)).await };
        assert!(future.with_cancel_signal(cancel).await.is_err());
        let cancel = async move { sleep(Duration::from_millis(100)).await };
        let future = async move { sleep(Duration::from_millis(50)).await };
        assert!(future.with_cancel_signal(cancel).await.is_ok());
    }

    /// Timer-free checks of the tie-break and never-completing cases.
    #[tokio::test]
    async fn with_cancel_signal_without_timers() {
        use std::future::{pending, ready};
        // Both sides are ready on the first poll: the wrapped future is polled
        // first, so its output wins.
        assert_eq!(ready(1).with_cancel_signal(ready("cancelled")).await, Ok(1));
        // A future that never completes is resolved only by the cancel signal.
        assert_eq!(
            pending::<i32>()
                .with_cancel_signal(ready("cancelled"))
                .await,
            Err("cancelled")
        );
    }

    /// Exercises each callable shape accepted by `IntoFutureWithArgs`: called
    /// directly and through a generic helper. The shapes are `async fn`
    /// items, closures returning `async` blocks, and async closures, each with
    /// single and tuple arguments.
    #[tokio::test]
    async fn into_future_with_args() {
        async fn into_signal(num: i32) -> i32 {
            num
        }
        async fn add((a, b): (i32, i32)) -> i32 {
            a + b
        }
        async fn wait_signal<A, F: Future>(
            signal: A,
            into: impl IntoFutureWithArgs<A, F>,
        ) -> F::Output {
            into.into_future_with_args(signal).await
        }
        assert_eq!(into_signal.into_future_with_args(42).await, 42);
        assert_eq!(add.into_future_with_args((40, 2)).await, 42);
        assert_eq!(
            (|num| async move { num }).into_future_with_args(42).await,
            42
        );
        assert_eq!(
            (|(a, b)| async move { a + b })
                .into_future_with_args((40, 2))
                .await,
            42
        );
        assert_eq!((async |num| num).into_future_with_args(42).await, 42);
        assert_eq!(
            (async |(a, b)| a + b).into_future_with_args((40, 2)).await,
            42
        );
        assert_eq!(wait_signal(42, into_signal).await, 42);
        assert_eq!(wait_signal((40, 2), add).await, 42);
        assert_eq!(wait_signal(42, |num| async move { num }).await, 42);
        assert_eq!(
            wait_signal((40, 2), |(a, b)| async move { a + b }).await,
            42
        );
        assert_eq!(wait_signal(42, async |num| num).await, 42);
        assert_eq!(wait_signal((40, 2), async |(a, b)| a + b).await, 42);
    }
}