use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use bytes::Bytes;
use http::StatusCode;
use tower::Service;
use crate::{Body, BoxError};
/// Tower layer that wraps a service in [`HandleLayerError`], so that errors produced by
/// the wrapped service are turned into HTTP error responses instead of being returned.
#[derive(Default, Clone)]
pub(crate) struct HandleLayerErrorLayer;

impl<S> tower::Layer<S> for HandleLayerErrorLayer {
    type Service = HandleLayerError<S>;

    /// Wraps `service` without any additional configuration.
    fn layer(&self, service: S) -> Self::Service {
        HandleLayerError(service)
    }
}
#[derive(Clone)]
pub(crate) struct HandleLayerError<S>(S);
impl<S, Req> Service<Req> for HandleLayerError<S>
where
S: Service<Req, Response = hyper::Response<Body>>,
S::Error: Into<BoxError>,
S::Future: Send + 'static,
{
type Response = hyper::Response<Body>;
type Error = std::convert::Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match self.0.poll_ready(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
Poll::Ready(Err(e)) => {
let e: BoxError = e.into();
tracing::error!(
"iroh-http: inner service poll_ready failed ({e}); \
treating as not-ready so the request timeout can close the connection"
);
cx.waker().wake_by_ref();
Poll::Pending
}
}
}
fn call(&mut self, req: Req) -> Self::Future {
let fut = self.0.call(req);
Box::pin(async move {
match fut.await {
Ok(r) => Ok(r),
Err(e) => {
let e = e.into();
let status = if e.is::<tower::timeout::error::Elapsed>() {
StatusCode::REQUEST_TIMEOUT
} else if e.is::<tower::load_shed::error::Overloaded>() {
StatusCode::SERVICE_UNAVAILABLE
} else {
tracing::warn!("iroh-http: unexpected tower error: {e}");
StatusCode::INTERNAL_SERVER_ERROR
};
let body_bytes: &'static [u8] = match status {
StatusCode::REQUEST_TIMEOUT => b"request timed out",
StatusCode::SERVICE_UNAVAILABLE => b"server at capacity",
_ => b"internal server error",
};
Ok(hyper::Response::builder()
.status(status)
.body(Body::full(Bytes::from_static(body_bytes)))
.expect("valid error response"))
}
}
})
}
}
#[cfg(test)]
mod tests {
    use std::convert::Infallible;
    use std::future::{ready, Ready};
    use std::task::{Context, Poll};

    use super::*;

    /// Inner-service stub: `poll_ready` always fails, and `call` returns an empty response.
    struct AlwaysErrorReady;

    impl Service<hyper::Request<Body>> for AlwaysErrorReady {
        type Response = hyper::Response<Body>;
        type Error = BoxError;
        type Future = Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            let failure: BoxError = "simulated inner poll_ready failure".into();
            Poll::Ready(Err(failure))
        }

        fn call(&mut self, _: hyper::Request<Body>) -> Self::Future {
            let response = hyper::Response::new(Body::empty());
            ready(Ok(response))
        }
    }

    /// An inner `poll_ready` failure must surface as `Pending`, never as readiness.
    #[test]
    fn poll_ready_error_is_not_silently_swallowed() {
        let noop = futures::task::noop_waker();
        let mut context = Context::from_waker(&noop);
        let mut wrapped = HandleLayerError(AlwaysErrorReady);

        let outcome: Poll<Result<(), Infallible>> = wrapped.poll_ready(&mut context);

        assert!(
            outcome.is_pending(),
            "poll_ready must return Poll::Pending (not Poll::Ready(Ok(()))) \
             when the inner service errors — regression for #179"
        );
    }
}