use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use bytesbuf::BytesView;
use http_body::{Body, Frame, SizeHint};
use pin_project::pin_project;
use tick::{Clock, Delay};
use crate::{HttpError, Result};
/// `http_body::Body` adapter that fails the wrapped body with a timeout error when it stays
/// pending for longer than `timeout`.
///
/// Created with [`TimeoutBody::new`] and driven through its `Body` implementation.
#[pin_project]
pub(crate) struct TimeoutBody<B> {
    /// The wrapped body (structurally pinned). Replaced with `None` once the timeout has
    /// fired; from then on the wrapper reports end of stream.
    #[pin]
    inner: Option<B>,
    /// How long the inner body may remain pending before a timeout error is produced.
    timeout: Duration,
    /// Clock on which the timeout [`Delay`]s are created.
    clock: Clock,
    /// Timer for the current pending period; `None` while no timer is running.
    current_delay: Option<Delay>,
}
impl<B> TimeoutBody<B> {
    /// Wraps `inner` so that it is subject to `timeout`, measured on `clock`.
    ///
    /// The clock handle is cloned so the wrapper owns its own copy. No timer is started
    /// here: the first [`Delay`] is only created once `inner` reports `Pending`, so time that
    /// passes before the body is first polled does not count against the timeout.
    pub(crate) fn new(inner: B, timeout: Duration, clock: &Clock) -> Self {
        let owned_clock = clock.clone();
        Self {
            current_delay: None,
            clock: owned_clock,
            timeout,
            inner: Some(inner),
        }
    }
}
impl<B> Body for TimeoutBody<B>
where
B: Body<Data = BytesView, Error = HttpError>,
{
type Data = BytesView;
type Error = HttpError;
fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>>>> {
let mut this = self.project();
let Some(inner) = this.inner.as_mut().as_pin_mut() else {
return Poll::Ready(None);
};
if let Poll::Ready(result) = inner.poll_frame(cx) {
*this.current_delay = None;
return Poll::Ready(result);
}
let delay = this.current_delay.get_or_insert_with(|| Delay::new(this.clock, *this.timeout));
if Pin::new(delay).poll(cx).is_ready() {
*this.current_delay = None;
this.inner.set(None);
return Poll::Ready(Some(Err(HttpError::timeout_for_body(*this.timeout))));
}
Poll::Pending
}
fn size_hint(&self) -> SizeHint {
self.inner.as_ref().map(http_body::Body::size_hint).unwrap_or_default()
}
fn is_end_stream(&self) -> bool {
self.inner.as_ref().is_none_or(http_body::Body::is_end_stream)
}
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    // Tests for `TimeoutBody`. Some go through `HttpBodyBuilder::body` / `stream` with
    // `HttpBodyOptions::timeout`; presumably the builder wraps the body in a `TimeoutBody`
    // when a timeout is set (TODO confirm in the builder). Others construct it directly
    // via `super::TimeoutBody::new`.
    use std::pin::Pin;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::task::{Context, Poll};
    use std::time::Duration;
    use bytesbuf::BytesView;
    use bytesbuf::mem::GlobalPool;
    use futures::executor::block_on;
    use http_body::{Body, Frame};
    use tick::ClockControl;
    use crate::testing::create_stream_body;
    use crate::{HttpBodyBuilder, HttpBodyOptions, HttpError, Result};

    /// A stream body whose single chunk is immediately available is read in full while a
    /// 30 s timeout is configured (the clock is never advanced).
    #[test]
    fn stream_body_returns_data_before_timeout() {
        let clock = ClockControl::new().to_clock();
        let builder = HttpBodyBuilder::new(GlobalPool::new(), &clock);
        let chunks: Vec<Result<BytesView>> = vec![Ok(BytesView::copied_from_slice(b"streamed data", &builder))];
        let options = HttpBodyOptions::default().timeout(Duration::from_secs(30));
        let body = builder.stream(futures::stream::iter(chunks), &options);
        let bytes = block_on(body.into_bytes()).unwrap();
        assert_eq!(bytes, b"streamed data");
    }

    /// A body that is always pending produces the body-timeout error.
    ///
    /// NOTE(review): relies on `auto_advance_timers(true)` moving the clock to pending timer
    /// deadlines so no manual `advance` is needed; confirm against the `tick` docs.
    #[test]
    fn stream_body_times_out_when_pending() {
        let clock = ClockControl::new().auto_advance_timers(true).to_clock();
        let builder = HttpBodyBuilder::new(GlobalPool::new(), &clock);
        let options = HttpBodyOptions::default().timeout(Duration::from_millis(100));
        let body = builder.body(PendingBody, &options);
        let err = block_on(body.into_bytes()).unwrap_err();
        assert!(
            err.to_string().contains("body data was not fully received"),
            "expected body timeout error, got: {err}"
        );
    }

    /// The builder keeps its own `buffer_limit(1024)` options (checked through the
    /// crate-internal `options` field), and a body created with a per-call `timeout` still
    /// times out when pending.
    #[test]
    fn body_timeout_chains_with_buffer_limit() {
        let clock = ClockControl::new().auto_advance_timers(true).to_clock();
        let builder = HttpBodyBuilder::new(GlobalPool::new(), &clock).with_options(HttpBodyOptions::default().buffer_limit(1024));
        assert_eq!(builder.options, HttpBodyOptions::default().buffer_limit(1024));
        let options = HttpBodyOptions::default().timeout(Duration::from_secs(30));
        let body = builder.body(PendingBody, &options);
        let err = block_on(body.into_bytes()).unwrap_err();
        assert!(err.to_string().contains("body data was not fully received"));
    }

    /// Size hint of a stream body built with default options (no timeout configured).
    ///
    /// NOTE(review): `lower() == 0` holds for almost any body, and no timeout wrapper is
    /// involved here, so this assertion is weak. Consider also checking `upper()`.
    #[test]
    fn size_hint_delegates_to_inner() {
        let builder = HttpBodyBuilder::new_fake();
        let body = create_stream_body(&builder, b"hello", &HttpBodyOptions::default());
        let hint = body.size_hint();
        assert_eq!(hint.lower(), 0);
    }

    /// A 5-byte `Full` body reports an exact size hint of 5 through the timeout wrapper.
    #[test]
    fn size_hint_delegates_through_timeout_body() {
        let clock = ClockControl::new().to_clock();
        let builder = HttpBodyBuilder::new(GlobalPool::new(), &clock);
        let options = HttpBodyOptions::default().timeout(Duration::from_secs(30));
        let body = builder.body(
            http_body_util::Full::new(BytesView::copied_from_slice(b"hello", &builder)),
            &options,
        );
        let hint = body.size_hint();
        assert_eq!(hint.lower(), 5);
        assert_eq!(hint.upper(), Some(5));
    }

    /// An `Empty` inner body is reported as end-of-stream through the timeout wrapper.
    #[test]
    fn is_end_stream_true_when_inner_is_empty() {
        let clock = ClockControl::new().to_clock();
        let builder = HttpBodyBuilder::new(GlobalPool::new(), &clock);
        let options = HttpBodyOptions::default().timeout(Duration::from_secs(1));
        let body = builder.body(http_body_util::Empty::new(), &options);
        assert!(body.is_end_stream());
    }

    /// A `Full` inner body with unread data is not end-of-stream.
    #[test]
    fn is_end_stream_false_when_inner_has_data() {
        let clock = ClockControl::new().to_clock();
        let builder = HttpBodyBuilder::new(GlobalPool::new(), &clock);
        let options = HttpBodyOptions::default().timeout(Duration::from_secs(1));
        let body = builder.body(http_body_util::Full::new(BytesView::copied_from_slice(b"data", &builder)), &options);
        assert!(!body.is_end_stream());
    }

    /// A ready `Full` body is read in full through the timeout wrapper.
    #[test]
    fn poll_frame_returns_data_through_timeout_body() {
        let clock = ClockControl::new().to_clock();
        let builder = HttpBodyBuilder::new(GlobalPool::new(), &clock);
        let options = HttpBodyOptions::default().timeout(Duration::from_secs(30));
        let body = builder.body(
            http_body_util::Full::new(BytesView::copied_from_slice(b"payload", &builder)),
            &options,
        );
        let bytes = block_on(body.into_bytes()).unwrap();
        assert_eq!(bytes, b"payload");
    }

    /// Same as `stream_body_times_out_when_pending`, but with a 1 ms timeout.
    #[test]
    fn poll_frame_times_out_when_pending_with_short_timeout() {
        let clock = ClockControl::new().auto_advance_timers(true).to_clock();
        let builder = HttpBodyBuilder::new(GlobalPool::new(), &clock);
        let options = HttpBodyOptions::default().timeout(Duration::from_millis(1));
        let body = builder.body(PendingBody, &options);
        let err = block_on(body.into_bytes()).unwrap_err();
        assert!(
            err.to_string().contains("body data was not fully received"),
            "expected body timeout error, got: {err}"
        );
    }

    /// Advancing the clock far past the timeout before the first poll does not cause a
    /// timeout: the inner body is polled first and is ready, so no timer is ever started.
    #[test]
    fn poll_frame_returns_data_even_when_clock_advanced_past_timeout() {
        let control = ClockControl::new();
        let clock = control.to_clock();
        let builder = HttpBodyBuilder::new(GlobalPool::new(), &clock);
        let options = HttpBodyOptions::default().timeout(Duration::from_millis(1));
        let body = builder.body(
            http_body_util::Full::new(BytesView::copied_from_slice(b"ready data", &builder)),
            &options,
        );
        control.advance(Duration::from_mins(1));
        let bytes = block_on(body.into_bytes()).unwrap();
        assert_eq!(bytes, b"ready data");
    }

    /// The timer started on the first pending poll fires once the inner body itself advances
    /// the clock past the timeout (on its second poll). The `TimeoutBody` is built directly
    /// and then handed to a fake builder with default options.
    #[test]
    fn poll_frame_times_out_via_delay_when_inner_body_advances_clock() {
        let control = ClockControl::new();
        let clock = control.to_clock();
        let timeout = Duration::from_millis(100);
        let body = ClockAdvancingBody {
            control,
            advance_by: Duration::from_mins(1),
            poll_count: AtomicU32::new(0),
        };
        let timeout_body = super::TimeoutBody::new(body, timeout, &clock);
        let http_body = HttpBodyBuilder::new_fake().body(timeout_body, &HttpBodyOptions::default());
        let err = block_on(http_body.into_bytes()).unwrap_err();
        assert!(
            err.to_string().contains("body data was not fully received"),
            "expected body timeout error, got: {err}"
        );
    }

    /// Drives `TimeoutBody` by hand with a no-op waker:
    /// 1. the first poll is `Pending` (timer started);
    /// 2. the second poll yields the timeout error;
    /// 3. the third poll reports end of stream because the inner body was dropped.
    ///
    /// NOTE(review): step 2 relies on `auto_advance_timers(true)` having advanced the clock
    /// to the delay's deadline between polls; confirm against the `tick` docs.
    #[test]
    fn poll_frame_returns_error_after_timeout() {
        let clock = ClockControl::new().auto_advance_timers(true).to_clock();
        let timeout = Duration::from_millis(50);
        let mut timeout_body = super::TimeoutBody::new(PendingBody, timeout, &clock);
        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        assert!(Pin::new(&mut timeout_body).poll_frame(&mut cx).is_pending());
        let result = Pin::new(&mut timeout_body).poll_frame(&mut cx);
        assert!(matches!(result, Poll::Ready(Some(Err(_)))));
        let result = Pin::new(&mut timeout_body).poll_frame(&mut cx);
        assert!(matches!(result, Poll::Ready(None)));
    }

    /// Body that never produces a frame and never wakes its task.
    struct PendingBody;

    impl Body for PendingBody {
        type Data = BytesView;
        type Error = HttpError;

        fn poll_frame(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>>>> {
            Poll::Pending
        }
    }

    /// Body that is always pending, but advances the controlled clock by `advance_by` on
    /// every poll after the first and immediately wakes its task so the executor re-polls.
    struct ClockAdvancingBody {
        control: ClockControl,
        advance_by: Duration,
        // Counted atomically so it can be updated through the shared reference that
        // `Pin<&mut Self>` derefs to.
        poll_count: AtomicU32,
    }

    impl Body for ClockAdvancingBody {
        type Data = BytesView;
        type Error = HttpError;

        fn poll_frame(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Result<Frame<Self::Data>>>> {
            let count = self.poll_count.fetch_add(1, Ordering::Relaxed);
            // Leave the clock alone on the first poll so the timeout timer gets created first.
            if count >= 1 {
                self.control.advance(self.advance_by);
            }
            // Self-wake so `block_on` polls again instead of parking forever.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }
}