use std::{
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use super::sleep::Sleep;
/// Error produced when a [`Timeout`] future's timer fires before the wrapped
/// future has completed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimeoutError;

impl fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A fixed message needs no formatting machinery; write it verbatim.
        f.write_str("operation timed out")
    }
}

/// `TimeoutError` carries no source error; the default trait methods suffice.
impl std::error::Error for TimeoutError {}
/// Future that races a wrapped future against a [`Sleep`] timer.
///
/// It resolves to `Ok(output)` when the wrapped future completes, or to
/// `Err(TimeoutError)` once the timer is observed to have fired while the
/// wrapped future is still pending. Like every future it is lazy: nothing
/// happens until it is polled or `.await`ed.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Timeout<F> {
    /// The wrapped future; pinned structurally (re-pinned in place by `poll`).
    future: F,
    /// Timer bounding how long `future` may stay pending; pinned structurally.
    sleep: Sleep,
    /// Set once `sleep` has fired so later polls short-circuit with an error.
    timed_out: bool,
}
impl<F> Timeout<F> {
#[inline]
pub fn new(future: F, duration: Duration) -> Self {
Self {
future,
sleep: Sleep::new(duration),
timed_out: false,
}
}
}
impl<F> Future for Timeout<F>
where
    F: Future,
{
    type Output = Result<F::Output, TimeoutError>;

    /// Polls the wrapped future first, then the timer.
    ///
    /// The wrapped future is always given the chance to finish before the
    /// timer is consulted. A future that becomes ready on the same poll in
    /// which the timer would fire therefore still yields `Ok`. Once the timer
    /// has fired, every later poll returns `Err(TimeoutError)` immediately,
    /// without touching the wrapped future again.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: nothing below moves `future` or `sleep` out of `this`. Both
        // fields are only re-pinned in place (structural pinning), and
        // `Timeout` has no `Drop` impl that could move them. `timed_out` is a
        // plain `bool` that is never pinned, so mutating it through `&mut`
        // is fine.
        let this = unsafe { self.get_unchecked_mut() };

        if this.timed_out {
            return Poll::Ready(Err(TimeoutError));
        }

        // SAFETY: `this.future` lives inside a pinned `Timeout` and is never
        // moved out of it, so projecting the pin onto this field is sound.
        let future = unsafe { Pin::new_unchecked(&mut this.future) };
        if let Poll::Ready(output) = future.poll(cx) {
            return Poll::Ready(Ok(output));
        }

        // SAFETY: same reasoning as for `future`; `sleep` is structurally
        // pinned and never moved out of the pinned `Timeout`.
        let sleep = unsafe { Pin::new_unchecked(&mut this.sleep) };
        match sleep.poll(cx) {
            Poll::Ready(()) => {
                // Remember the expiry so subsequent polls do not re-poll
                // either field.
                this.timed_out = true;
                Poll::Ready(Err(TimeoutError))
            }
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Awaits `future`, resolving to `Err(TimeoutError)` if the timer created for
/// `duration` fires before `future` completes.
///
/// Because this is an `async fn`, neither the [`Timeout`] wrapper nor its
/// `Sleep` timer is constructed until the returned future is first polled.
#[inline]
pub async fn timeout<T>(
    duration: Duration,
    future: impl Future<Output = T>,
) -> Result<T, TimeoutError> {
    let guarded = Timeout::new(future, duration);
    guarded.await
}