use crate::app::{App, Policy, apply_security_headers};
use crate::error::{Error, Result};
use crate::extract::BodyLane;
use crate::response::IntoResponse;
use bytes::Bytes;
use std::sync::Arc;
pub(crate) async fn run_with_shutdown(
mut app: App,
listener: tokio::net::TcpListener,
shutdown: impl std::future::Future<Output = ()> + Send,
) -> Result<()> {
const DRAIN_CAP: std::time::Duration = std::time::Duration::from_secs(10);
const HEADER_READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
let background = app.take_background();
let built = Arc::new(app.build()?);
let mut connections = tokio::task::JoinSet::new();
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
for (_name, factory) in background {
let fut = factory(built.task_context(), shutdown_rx.clone());
connections.spawn(fut);
}
tokio::pin!(shutdown);
loop {
tokio::select! {
() = &mut shutdown => break,
accepted = listener.accept() => {
let (stream, peer_addr) = match accepted {
Ok(pair) => pair,
Err(e) if is_transient_accept_error(&e) => {
eprintln!("jerrycan: transient accept error ({e}); backing off 50ms");
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
continue;
}
Err(e) => return Err(Error::internal(format!("accept failed fatally: {e}"))),
};
let app = built.clone();
let write_stall_timeout = built.write_stall_timeout;
let mut shutdown_rx = shutdown_rx.clone();
connections.spawn(async move {
let io = hyper_util::rt::TokioIo::new(TimedIo::new(stream, write_stall_timeout));
let service = hyper::service::service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
let app = app.clone();
async move {
let (mut parts, body) = req.into_parts();
parts.extensions.insert(crate::extract::ClientAddr(peer_addr));
let cors_origin = parts.headers.get(http::header::ORIGIN).cloned();
let (limit, stream) = match app.route_policy(&parts) {
Policy::Reject(response) => {
return Ok::<_, std::convert::Infallible>(response);
}
Policy::Route { limit, stream } => (limit, stream),
};
let response = if stream {
use http_body_util::combinators::UnsyncBoxBody;
let lane = BodyLane::Stream(Some(UnsyncBoxBody::new(TimedRecvBody::new(
http_body_util::Limited::new(body, limit),
app.body_read_timeout,
))));
dispatch_isolated(&app, parts, lane, cors_origin.as_ref()).await
} else {
use http_body_util::BodyExt;
let limited = http_body_util::Limited::new(body, limit);
let collected =
tokio::time::timeout(app.body_read_timeout, limited.collect()).await;
match collected {
Ok(Ok(collected)) => {
let lane = BodyLane::Buffered(collected.to_bytes());
dispatch_isolated(&app, parts, lane, cors_origin.as_ref()).await
}
Ok(Err(_)) => finish_error(
&app,
Error::payload_too_large(),
cors_origin.as_ref(),
),
Err(_) => finish_error(
&app,
Error::new(
http::StatusCode::REQUEST_TIMEOUT,
"JC0408",
"timed out reading the request body",
),
cors_origin.as_ref(),
),
}
};
Ok::<_, std::convert::Infallible>(response)
}
});
let conn = hyper::server::conn::http1::Builder::new()
.timer(hyper_util::rt::TokioTimer::new())
.header_read_timeout(HEADER_READ_TIMEOUT)
.serve_connection(io, service);
tokio::pin!(conn);
loop {
tokio::select! {
result = conn.as_mut() => {
let _ = result;
break;
}
_ = shutdown_rx.changed() => {
conn.as_mut().graceful_shutdown();
}
}
}
});
}
}
}
let _ = shutdown_tx.send(true);
drop(listener); let drain = async { while connections.join_next().await.is_some() {} };
if tokio::time::timeout(DRAIN_CAP, drain).await.is_err() {
eprintln!("jerrycan: drain cap reached — aborting remaining connections");
connections.abort_all();
}
Ok(())
}
/// Converts `error` into a response and decorates it with the app's configured headers.
///
/// Security headers are added when `app.security_headers` is set. CORS headers are then added
/// for `cors_origin` when the app has a CORS configuration.
fn finish_error(
    app: &Arc<crate::app::BuiltApp>,
    error: Error,
    cors_origin: Option<&http::HeaderValue>,
) -> crate::response::Response {
    let wants_security_headers = app.security_headers;
    let mut rendered = error.into_response();
    if wants_security_headers {
        apply_security_headers(&mut rendered);
    }
    if let Some(cors_config) = app.cors.as_ref() {
        crate::cors::apply_cors(&mut rendered, cors_origin, cors_config);
    }
    rendered
}
async fn dispatch_isolated(
app: &Arc<crate::app::BuiltApp>,
parts: http::request::Parts,
lane: BodyLane,
cors_origin: Option<&http::HeaderValue>,
) -> crate::response::Response {
let app2 = app.clone();
match tokio::spawn(async move { app2.dispatch(parts, lane).await }).await {
Ok(response) => response,
Err(_join_error) => finish_error(app, Error::internal("handler panicked"), cors_origin),
}
}
/// Resolves when the process is asked to stop: Ctrl-C on every platform, plus SIGTERM on Unix.
///
/// If a signal listener cannot be installed, that failure is logged and the source is ignored.
/// A registration error therefore never counts as a shutdown request. Previously a failed
/// `ctrl_c()` resolved immediately and stopped the server at startup.
pub(crate) async fn shutdown_signal() {
    let ctrl_c = async {
        if let Err(e) = tokio::signal::ctrl_c().await {
            eprintln!("jerrycan: cannot listen for Ctrl-C ({e}); ignoring that signal");
            std::future::pending::<()>().await;
        }
    };

    #[cfg(unix)]
    let terminate = async {
        match tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) {
            // `recv` returning `None` (the signal stream closed) also ends the wait.
            Ok(mut sigterm) => {
                sigterm.recv().await;
            }
            Err(e) => {
                eprintln!("jerrycan: cannot listen for SIGTERM ({e}); ignoring that signal");
                std::future::pending::<()>().await;
            }
        }
    };
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        () = ctrl_c => {}
        () = terminate => {}
    }
    eprintln!("jerrycan: shutdown signal received — draining");
}
/// Reports whether an `accept` failure is worth retrying rather than fatal for the listener.
///
/// The following are treated as transient:
/// - a handshake that was aborted or reset by the peer,
/// - an `Interrupted` or `WouldBlock` syscall result,
/// - file-descriptor exhaustion, which may clear as other connections close.
///
/// Descriptor-exhaustion error codes differ by OS, so they are matched per target. The raw
/// values 23/24 only mean ENFILE/EMFILE on Unix; on Windows they are unrelated errors.
pub(crate) fn is_transient_accept_error(e: &std::io::Error) -> bool {
    use std::io::ErrorKind;

    // ENFILE (system-wide table full) and EMFILE (per-process limit) on Linux, macOS and BSDs.
    #[cfg(unix)]
    const FD_EXHAUSTION: &[i32] = &[23, 24];
    // WSAEMFILE: Winsock's "too many open sockets".
    #[cfg(windows)]
    const FD_EXHAUSTION: &[i32] = &[10024];
    #[cfg(not(any(unix, windows)))]
    const FD_EXHAUSTION: &[i32] = &[];

    let transient_kind = matches!(
        e.kind(),
        ErrorKind::ConnectionAborted
            | ErrorKind::ConnectionReset
            | ErrorKind::Interrupted
            | ErrorKind::WouldBlock
    );
    transient_kind || matches!(e.raw_os_error(), Some(code) if FD_EXHAUSTION.contains(&code))
}
/// An I/O wrapper that fails write-side operations which stay blocked for longer than `cap`.
///
/// `stall` holds the timer armed by the first write-side poll that returned `Pending`. It is
/// dropped again when the inner stream makes progress or when the cap fires.
pub(crate) struct TimedIo<T> {
    inner: T,
    cap: std::time::Duration,
    stall: Option<std::pin::Pin<Box<tokio::time::Sleep>>>,
}

impl<T> TimedIo<T> {
    /// Wraps `inner`; no stall timer is armed until a write-side poll returns `Pending`.
    pub(crate) fn new(inner: T, cap: std::time::Duration) -> Self {
        Self {
            stall: None,
            cap,
            inner,
        }
    }

    /// Polls the stall timer, arming it with `cap` on first use.
    ///
    /// Returns `Pending` while `cap` has not elapsed. Once it has, it disarms the timer and
    /// returns `Ready(Err)` with `ErrorKind::TimedOut`. It never returns `Ready(Ok(()))`.
    fn poll_stall(
        stall: &mut Option<std::pin::Pin<Box<tokio::time::Sleep>>>,
        cap: std::time::Duration,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::future::Future as _;

        let timer = stall.get_or_insert_with(|| Box::pin(tokio::time::sleep(cap)));
        if timer.as_mut().poll(cx).is_pending() {
            return std::task::Poll::Pending;
        }
        // Cap elapsed: disarm so a later retry starts a fresh window.
        *stall = None;
        std::task::Poll::Ready(Err(std::io::Error::new(
            std::io::ErrorKind::TimedOut,
            "connection write stalled past the cap",
        )))
    }
}
/// Reads are forwarded to the inner stream unchanged; no timeout applies to them here.
impl<T: tokio::io::AsyncRead + Unpin> tokio::io::AsyncRead for TimedIo<T> {
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let this = self.get_mut();
        std::pin::Pin::new(&mut this.inner).poll_read(cx, buf)
    }
}
/// Write-side operations are forwarded to the inner stream and bounded by the stall cap.
///
/// A write, vectored write, or flush that keeps returning `Pending` for longer than `cap`
/// fails with `ErrorKind::TimedOut`. Any `Ready` result counts as progress and resets the cap.
/// Vectored writes and `is_write_vectored` are forwarded, so callers that check
/// `is_write_vectored` can use scatter/gather I/O instead of being told it is unsupported.
/// Shutdown passes straight through without a cap.
impl<T: tokio::io::AsyncWrite + Unpin> tokio::io::AsyncWrite for TimedIo<T> {
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        let polled = std::pin::Pin::new(&mut self.inner).poll_write(cx, buf);
        self.guard_progress(polled, cx)
    }

    fn poll_write_vectored(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> std::task::Poll<std::io::Result<usize>> {
        let polled = std::pin::Pin::new(&mut self.inner).poll_write_vectored(cx, bufs);
        self.guard_progress(polled, cx)
    }

    fn is_write_vectored(&self) -> bool {
        self.inner.is_write_vectored()
    }

    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        let polled = std::pin::Pin::new(&mut self.inner).poll_flush(cx);
        self.guard_progress(polled, cx)
    }

    fn poll_shutdown(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        std::pin::Pin::new(&mut self.inner).poll_shutdown(cx)
    }
}

impl<T> TimedIo<T> {
    /// Applies the stall cap to the result of one write-side poll of the inner stream.
    ///
    /// A `Ready` result disarms the stall timer and is returned unchanged. A `Pending` result
    /// arms the timer if needed and becomes `Ready(Err(TimedOut))` once `cap` has elapsed.
    fn guard_progress<R>(
        &mut self,
        polled: std::task::Poll<std::io::Result<R>>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<R>> {
        use std::task::Poll;
        match polled {
            Poll::Ready(result) => {
                self.stall = None;
                Poll::Ready(result)
            }
            Poll::Pending => match Self::poll_stall(&mut self.stall, self.cap, cx) {
                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
                Poll::Ready(Ok(())) => unreachable!("poll_stall never returns Ready(Ok)"),
                Poll::Pending => Poll::Pending,
            },
        }
    }
}
/// Error yielded by `TimedRecvBody` when no request-body frame arrives within its timeout.
#[derive(Debug)]
pub(crate) struct RecvTimeout;

impl std::error::Error for RecvTimeout {}

impl std::fmt::Display for RecvTimeout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "timed out waiting for the next request-body frame")
    }
}
/// A request-body wrapper that errors when the inner body goes quiet for too long.
///
/// `timeout` bounds the wait for each frame separately. `sleep` holds the armed timer while
/// the inner body is pending, and is cleared when a frame or an error arrives.
pub(crate) struct TimedRecvBody<B> {
    inner: B,
    timeout: std::time::Duration,
    sleep: Option<std::pin::Pin<Box<tokio::time::Sleep>>>,
}

impl<B> TimedRecvBody<B> {
    /// Wraps `inner` with a per-frame receive timeout. No timer is armed until the first poll
    /// that finds the inner body pending.
    pub(crate) fn new(inner: B, timeout: std::time::Duration) -> Self {
        let sleep = None;
        Self {
            inner,
            timeout,
            sleep,
        }
    }
}
/// Forwards frames from the inner body and fails with `RecvTimeout` if the next frame takes
/// longer than `timeout`.
///
/// The timer is armed only while the inner body reports `Pending`. Any `Ready` outcome (frame,
/// error, or end of stream) disarms it. Inner errors are boxed into `Self::Error`.
/// `size_hint` and `is_end_stream` delegate to the inner body, so consumers can see when it is
/// finished.
impl<B> http_body::Body for TimedRecvBody<B>
where
    B: http_body::Body<Data = Bytes> + Unpin,
    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    type Data = Bytes;
    type Error = Box<dyn std::error::Error + Send + Sync>;

    fn poll_frame(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<std::result::Result<http_body::Frame<Bytes>, Self::Error>>> {
        use std::future::Future;
        use std::task::Poll;

        match std::pin::Pin::new(&mut self.inner).poll_frame(cx) {
            Poll::Ready(outcome) => {
                // Progress of any kind resets the per-frame window.
                self.sleep = None;
                Poll::Ready(outcome.map(|frame| frame.map_err(Into::into)))
            }
            Poll::Pending => {
                let timeout = self.timeout;
                let sleep = self
                    .sleep
                    .get_or_insert_with(|| Box::pin(tokio::time::sleep(timeout)));
                if sleep.as_mut().poll(cx).is_pending() {
                    return Poll::Pending;
                }
                // A full `timeout` elapsed with no frame: disarm and report a typed error.
                self.sleep = None;
                Poll::Ready(Some(Err(Box::new(RecvTimeout))))
            }
        }
    }

    fn is_end_stream(&self) -> bool {
        self.inner.is_end_stream()
    }

    fn size_hint(&self) -> http_body::SizeHint {
        self.inner.size_hint()
    }
}