use std::{
collections::HashMap,
future::Future,
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
task::{Context, Poll},
time::Duration,
};
use bytes::Bytes;
use http::{HeaderName, HeaderValue, StatusCode};
use hyper::body::Incoming;
use hyper_util::rt::TokioIo;
use hyper_util::service::TowerToHyperService;
use tower::Service;
use crate::{
base32_encode,
client::{body_from_reader, pump_hyper_body_to_channel_limited},
io::IrohStream,
stream::{HandleStore, ResponseHeadEntry},
ConnectionEvent, CoreError, IrohEndpoint, RequestPayload,
};
type BoxBody = crate::BoxBody;
type BoxError = Box<dyn std::error::Error + Send + Sync>;
/// Tunable limits applied by [`serve`] / [`serve_with_events`].
///
/// Every field is optional; `None` selects the built-in default (see the
/// `DEFAULT_*` constants below) or disables the limit where noted.
#[derive(Debug, Clone, Default)]
pub struct ServerLimits {
    /// Maximum concurrently processed requests per connection stack
    /// (default: `DEFAULT_CONCURRENCY` = 64).
    pub max_concurrency: Option<usize>,
    /// Consecutive accept errors tolerated before the serve loop shuts
    /// itself down (default: 5).
    pub max_consecutive_errors: Option<usize>,
    /// Per-request timeout in milliseconds; `0` disables the timeout
    /// (default: `DEFAULT_REQUEST_TIMEOUT_MS` = 60_000).
    pub request_timeout_ms: Option<u64>,
    /// Cap on simultaneous connections from a single peer
    /// (default: `DEFAULT_MAX_CONNECTIONS_PER_PEER` = 8).
    pub max_connections_per_peer: Option<usize>,
    /// Request-body size cap in bytes; `None` means no limit is enforced.
    pub max_request_body_bytes: Option<usize>,
    /// Graceful-shutdown drain window in seconds
    /// (default: `DEFAULT_DRAIN_TIMEOUT_SECS` = 30).
    pub drain_timeout_secs: Option<u64>,
    /// Global connection cap across all peers; `None` means unlimited.
    pub max_total_connections: Option<usize>,
    /// Enable tower load shedding in front of the request service
    /// (default: true).
    pub load_shed: Option<bool>,
}
/// Alias kept for API compatibility: serve options are exactly the limits.
pub type ServeOptions = ServerLimits;
// Fallback values used when the corresponding `ServerLimits` field is `None`.
const DEFAULT_CONCURRENCY: usize = 64;
const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 60_000;
const DEFAULT_MAX_CONNECTIONS_PER_PEER: usize = 8;
const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 30;
/// Handle to a running serve task, returned by [`serve`] / [`serve_with_events`].
///
/// Supports graceful shutdown (signal + drain), hard abort, and observing
/// task completion through a watch channel.
pub struct ServeHandle {
    // Join handle for the spawned accept/drain task.
    join: tokio::task::JoinHandle<()>,
    // Notified to ask the serve loop to stop accepting new connections.
    shutdown_notify: Arc<tokio::sync::Notify>,
    // Drain window the task will wait for in-flight requests on shutdown.
    drain_timeout: std::time::Duration,
    // Flips to `true` once the serve task has fully completed.
    done_rx: tokio::sync::watch::Receiver<bool>,
}
impl ServeHandle {
    /// Ask the serve loop to stop accepting new connections.
    ///
    /// Returns immediately; the task then drains in-flight requests for up
    /// to [`ServeHandle::drain_timeout`].
    pub fn shutdown(&self) {
        self.shutdown_notify.notify_one();
    }

    /// Signal shutdown and wait until the serve task (including its drain
    /// phase) has finished.
    pub async fn drain(self) {
        self.shutdown_notify.notify_one();
        let _ = self.join.await;
    }

    /// Hard-kill the serve task without draining in-flight requests.
    pub fn abort(&self) {
        self.join.abort();
    }

    /// The drain window configured for graceful shutdown.
    pub fn drain_timeout(&self) -> std::time::Duration {
        self.drain_timeout
    }

    /// Obtain a watch receiver that flips to `true` once the serve task has
    /// fully completed.
    pub fn subscribe_done(&self) -> tokio::sync::watch::Receiver<bool> {
        self.done_rx.clone()
    }
}
/// Validate and deliver a response head for a pending request handle.
///
/// All inputs are checked *before* the one-shot sender is consumed, so a
/// validation failure leaves the request handle usable for a retry.
///
/// # Errors
/// - invalid status code or header name/value → `CoreError::invalid_input`
/// - unknown/already-used `req_handle` → `CoreError::invalid_handle`
/// - serve task already gone → `CoreError::internal`
pub fn respond(
    handles: &HandleStore,
    req_handle: u64,
    status: u16,
    headers: Vec<(String, String)>,
) -> Result<(), CoreError> {
    // Reject malformed status codes up front.
    if StatusCode::from_u16(status).is_err() {
        return Err(CoreError::invalid_input(format!(
            "invalid HTTP status code: {status}"
        )));
    }
    // Validate every header pair; nothing is sent if any is malformed.
    for (name, value) in &headers {
        if HeaderName::from_bytes(name.as_bytes()).is_err() {
            return Err(CoreError::invalid_input(format!(
                "invalid response header name {:?}",
                name
            )));
        }
        if HeaderValue::from_str(value).is_err() {
            return Err(CoreError::invalid_input(format!(
                "invalid response header value for {:?}",
                name
            )));
        }
    }
    // Taking the sender consumes the handle; a missing entry means the
    // handle was never allocated or was already responded to.
    let Some(sender) = handles.take_req_sender(req_handle) else {
        return Err(CoreError::invalid_handle(req_handle));
    };
    match sender.send(ResponseHeadEntry { status, headers }) {
        Ok(()) => Ok(()),
        Err(_) => Err(CoreError::internal("serve task dropped before respond")),
    }
}
/// Callback invoked on peer connect/disconnect transitions.
type ConnectionEventFn = Arc<dyn Fn(ConnectionEvent) + Send + Sync>;
/// RAII slot in the per-peer connection counter table.
///
/// Created via `PeerConnectionGuard::acquire`; dropping it decrements the
/// peer's count and emits a `connected: false` event when the count reaches
/// zero.
struct PeerConnectionGuard {
    // Shared counter table: peer -> number of live connections.
    counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
    peer: iroh::PublicKey,
    // Pre-encoded peer id used in emitted events (avoids re-encoding on drop).
    peer_id_str: String,
    on_event: Option<ConnectionEventFn>,
}
impl PeerConnectionGuard {
    /// Try to reserve a connection slot for `peer`, enforcing the per-peer cap.
    ///
    /// Returns `None` when the peer already holds `max` connections. On the
    /// 0 -> 1 transition a `connected: true` event is fired — after the
    /// counts mutex has been released, so the (arbitrary, user-supplied)
    /// callback can never block or deadlock connection accounting.
    fn acquire(
        counts: &Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
        peer: iroh::PublicKey,
        peer_id_str: String,
        max: usize,
        on_event: Option<ConnectionEventFn>,
    ) -> Option<Self> {
        let was_zero;
        {
            // Recover from poisoning: the map is a plain counter table, so
            // its data stays consistent even if a previous holder panicked.
            let mut map = counts.lock().unwrap_or_else(|e| e.into_inner());
            let count = map.entry(peer).or_insert(0);
            if *count >= max {
                return None;
            }
            was_zero = *count == 0;
            *count += 1;
        } // lock released here, before any user callback runs
        let guard = PeerConnectionGuard {
            counts: counts.clone(),
            peer,
            peer_id_str: peer_id_str.clone(),
            on_event: on_event.clone(),
        };
        if was_zero {
            if let Some(cb) = &on_event {
                cb(ConnectionEvent {
                    peer_id: peer_id_str,
                    connected: true,
                });
            }
        }
        Some(guard)
    }
}
impl Drop for PeerConnectionGuard {
    /// Release this connection's slot.
    ///
    /// On the 1 -> 0 transition the peer's entry is removed and a
    /// `connected: false` event is fired — after the counts mutex has been
    /// released, so user callback code cannot block or deadlock connection
    /// accounting while the lock is held.
    fn drop(&mut self) {
        // Decide under the lock whether this drop is the last connection.
        let fire_disconnect = {
            let mut map = self.counts.lock().unwrap_or_else(|e| e.into_inner());
            match map.get_mut(&self.peer) {
                Some(c) => {
                    *c = c.saturating_sub(1);
                    if *c == 0 {
                        map.remove(&self.peer);
                        true
                    } else {
                        false
                    }
                }
                // Entry already gone (shouldn't happen, but stay defensive).
                None => false,
            }
        };
        if fire_disconnect {
            if let Some(cb) = &self.on_event {
                cb(ConnectionEvent {
                    peer_id: self.peer_id_str.clone(),
                    connected: false,
                });
            }
        }
    }
}
/// Per-connection hyper service state; cloned for every request.
///
/// Built once per server (`base_svc` in `serve_with_events`) and then cloned
/// per peer connection with `remote_node_id` filled in.
#[derive(Clone)]
struct RequestService {
    // Embedder callback fired for each incoming request.
    on_request: Arc<dyn Fn(RequestPayload) + Send + Sync>,
    endpoint: IrohEndpoint,
    // This server's node id, used to synthesize the request URL.
    own_node_id: Arc<String>,
    // Connected peer's encoded node id; `None` only on the base template.
    remote_node_id: Option<String>,
    // Request-body cap; `None` disables the limit.
    max_request_body_bytes: Option<usize>,
    // Header-section size cap; `None` disables the check.
    max_header_size: Option<usize>,
    #[cfg(feature = "compression")]
    compression: Option<crate::endpoint::CompressionOptions>,
}
impl Service<hyper::Request<Incoming>> for RequestService {
    type Response = hyper::Response<BoxBody>;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
    /// Always ready: backpressure is applied by the tower layers wrapped
    /// around this service (concurrency limit, timeout, load shed), not here.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }
    /// Clone the service state into a boxed future so the request can be
    /// handled without borrowing `self` across the await.
    fn call(&mut self, req: hyper::Request<Incoming>) -> Self::Future {
        let svc = self.clone();
        Box::pin(async move { svc.handle(req).await })
    }
}
impl RequestService {
    /// Bridge one inbound HTTP request to the embedding-side handler.
    ///
    /// Flow: enforce the header-size limit, copy headers (rejecting non-UTF8
    /// values and injecting a trusted `peer-id`), allocate body/trailer
    /// handles, fire `on_request`, then wait for the handler to deliver a
    /// response head. Three response shapes are supported: a streamed body,
    /// an early 413 on request-body overflow, and a CONNECT-based
    /// `iroh-duplex` upgrade (101).
    async fn handle(
        self,
        mut req: hyper::Request<Incoming>,
    ) -> Result<hyper::Response<BoxBody>, BoxError> {
        let handles = self.endpoint.handles();
        let own_node_id = &*self.own_node_id;
        let remote_node_id = self.remote_node_id.clone().unwrap_or_default();
        let max_request_body_bytes = self.max_request_body_bytes;
        let max_header_size = self.max_header_size;
        let method = req.method().to_string();
        let path_and_query = req
            .uri()
            .path_and_query()
            .map(|p| p.as_str())
            .unwrap_or("/")
            .to_string();
        tracing::debug!(
            method = %method,
            path = %path_and_query,
            peer = %remote_node_id,
            "iroh-http: incoming request",
        );
        if let Some(limit) = max_header_size {
            // Approximate the serialized size of the header section. Any
            // client-sent `peer-id` is excluded (it is replaced below) and
            // the injected one counted instead. The +4 per header presumably
            // accounts for ": " + CRLF, and the +12 for request-line
            // overhead — TODO confirm against the client-side accounting.
            let header_bytes: usize = req
                .headers()
                .iter()
                .filter(|(k, _)| !k.as_str().eq_ignore_ascii_case("peer-id"))
                .map(|(k, v)| k.as_str().len() + v.as_bytes().len() + 4)
                .sum::<usize>()
                + "peer-id".len()
                + remote_node_id.len()
                + 4
                + req.uri().to_string().len()
                + method.len()
                + 12;
            if header_bytes > limit {
                let resp = hyper::Response::builder()
                    .status(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE)
                    .body(crate::box_body(http_body_util::Empty::new()))
                    .unwrap();
                return Ok(resp);
            }
        }
        // Copy headers into plain strings for the embedder. A client-sent
        // `peer-id` is dropped so it cannot spoof the trusted value appended
        // below; non-UTF8 values are rejected with 400.
        let mut req_headers: Vec<(String, String)> = Vec::new();
        for (k, v) in req.headers().iter() {
            if k.as_str().eq_ignore_ascii_case("peer-id") {
                continue;
            }
            match v.to_str() {
                Ok(s) => req_headers.push((k.as_str().to_string(), s.to_string())),
                Err(_) => {
                    let resp = hyper::Response::builder()
                        .status(StatusCode::BAD_REQUEST)
                        .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
                            b"non-UTF8 header value",
                        ))))
                        .unwrap();
                    return Ok(resp);
                }
            }
        }
        req_headers.push(("peer-id".to_string(), remote_node_id.clone()));
        // Synthesize the URL the handler sees, using the custom `httpi`
        // scheme and this server's node id as the authority.
        let url = format!("httpi://{own_node_id}{path_and_query}");
        // Duplex upgrade detection: requires `Upgrade: iroh-duplex`,
        // `Connection: upgrade`, and the CONNECT method — anything partial
        // is rejected with 400.
        let has_upgrade_header = req_headers.iter().any(|(k, v)| {
            k.eq_ignore_ascii_case("upgrade") && v.eq_ignore_ascii_case("iroh-duplex")
        });
        let has_connection_upgrade = req_headers.iter().any(|(k, v)| {
            k.eq_ignore_ascii_case("connection")
                && v.split(',')
                    .any(|tok| tok.trim().eq_ignore_ascii_case("upgrade"))
        });
        let is_connect = req.method() == http::Method::CONNECT;
        let is_bidi = if has_upgrade_header {
            if !has_connection_upgrade || !is_connect {
                let resp = hyper::Response::builder()
                    .status(StatusCode::BAD_REQUEST)
                    .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
                        b"duplex upgrade requires CONNECT method with Connection: upgrade header",
                    ))))
                    .unwrap();
                return Ok(resp);
            }
            true
        } else {
            false
        };
        // The upgrade future must be taken before the body is consumed.
        let upgrade_future = if is_bidi {
            Some(hyper::upgrade::on(&mut req))
        } else {
            None
        };
        // Allocate all handles through one guard so a mid-setup failure
        // rolls everything back; `commit()` below makes them permanent.
        let mut guard = handles.insert_guard();
        let (req_body_writer, req_body_reader) = handles.make_body_channel();
        let req_body_handle = guard
            .insert_reader(req_body_reader)
            .map_err(|e| -> BoxError { e.into() })?;
        let (res_body_writer, res_body_reader) = handles.make_body_channel();
        let res_body_handle = guard
            .insert_writer(res_body_writer)
            .map_err(|e| -> BoxError { e.into() })?;
        // Trailers only exist on the non-duplex path; duplex uses handle 0
        // as a sentinel for "no trailers".
        let (req_trailers_handle, res_trailers_handle, req_trailer_tx, opt_res_trailer_rx) =
            if !is_bidi {
                let (rq_tx, rq_rx) = tokio::sync::oneshot::channel::<Vec<(String, String)>>();
                let rq_h = guard
                    .insert_trailer_receiver(rq_rx)
                    .map_err(|e| -> BoxError { e.into() })?;
                let (rs_tx, rs_rx) = tokio::sync::oneshot::channel::<Vec<(String, String)>>();
                let rs_h = guard
                    .insert_trailer_sender(rs_tx)
                    .map_err(|e| -> BoxError { e.into() })?;
                (rq_h, rs_h, Some(rq_tx), Some(rs_rx))
            } else {
                (0u64, 0u64, None, None)
            };
        let (head_tx, head_rx) = tokio::sync::oneshot::channel::<ResponseHeadEntry>();
        let req_handle = guard
            .allocate_req_handle(head_tx)
            .map_err(|e| -> BoxError { e.into() })?;
        guard.commit();
        // Ensure the response-head sender is removed from the store even if
        // this future is cancelled or errors before the handler responds.
        struct ReqHeadCleanup {
            endpoint: IrohEndpoint,
            req_handle: u64,
        }
        impl Drop for ReqHeadCleanup {
            fn drop(&mut self) {
                self.endpoint.handles().take_req_sender(self.req_handle);
            }
        }
        let _req_head_cleanup = ReqHeadCleanup {
            endpoint: self.endpoint.clone(),
            req_handle,
        };
        // Overflow channel lets the body pump signal "request body too
        // large" so we can answer 413 instead of waiting on the handler.
        let (body_overflow_tx, body_overflow_rx) = if !is_bidi && max_request_body_bytes.is_some() {
            let (tx, rx) = tokio::sync::oneshot::channel::<()>();
            (Some(tx), Some(rx))
        } else {
            (None, None)
        };
        // Non-duplex: pump the hyper body into the handle channel in the
        // background. Duplex: drop the body (CONNECT carries none) and keep
        // the writer for the upgraded stream.
        let duplex_req_body_writer = if !is_bidi {
            let body = req.into_body();
            let trailer_tx = req_trailer_tx.expect("non-duplex has req_trailer_tx");
            let frame_timeout = handles.drain_timeout();
            tokio::spawn(pump_hyper_body_to_channel_limited(
                body,
                req_body_writer,
                trailer_tx,
                max_request_body_bytes,
                frame_timeout,
                body_overflow_tx,
            ));
            None
        } else {
            drop(req.into_body());
            Some(req_body_writer)
        };
        // Hand the request to the embedder; it responds via `respond()`.
        on_request_fire(
            &self.on_request,
            req_handle,
            req_body_handle,
            res_body_handle,
            req_trailers_handle,
            res_trailers_handle,
            method,
            url,
            req_headers,
            remote_node_id,
            is_bidi,
        );
        // Wait for the response head, racing against body overflow when a
        // body limit is configured (biased so overflow wins a tie).
        let response_head = if let Some(overflow_rx) = body_overflow_rx {
            tokio::select! {
                biased;
                _ = overflow_rx => {
                    let resp = hyper::Response::builder()
                        .status(StatusCode::PAYLOAD_TOO_LARGE)
                        .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
                            b"request body too large",
                        ))))
                        .expect("valid 413 response");
                    return Ok(resp);
                }
                head = head_rx => {
                    head.map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
                }
            }
        } else {
            head_rx
                .await
                .map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
        };
        // Duplex path: only a 101 from the handler completes the upgrade;
        // any other status is relayed with an empty body.
        if let Some(upgrade_fut) = upgrade_future {
            let req_body_writer =
                duplex_req_body_writer.expect("duplex path always has req_body_writer");
            if response_head.status != StatusCode::SWITCHING_PROTOCOLS.as_u16() {
                drop(upgrade_fut);
                drop(req_body_writer);
                let mut resp_builder = hyper::Response::builder().status(response_head.status);
                for (k, v) in &response_head.headers {
                    resp_builder = resp_builder.header(k.as_str(), v.as_str());
                }
                let resp = resp_builder
                    .body(crate::box_body(http_body_util::Empty::new()))
                    .map_err(|e| -> BoxError { e.into() })?;
                return Ok(resp);
            }
            // Pump the upgraded raw stream in the background; the 101 must
            // be returned first for hyper to complete the upgrade.
            tokio::spawn(async move {
                match upgrade_fut.await {
                    Err(e) => tracing::warn!("iroh-http: duplex upgrade error: {e}"),
                    Ok(upgraded) => {
                        let io = TokioIo::new(upgraded);
                        crate::stream::pump_duplex(io, req_body_writer, res_body_reader).await;
                    }
                }
            });
            let resp = hyper::Response::builder()
                .status(StatusCode::SWITCHING_PROTOCOLS)
                .header(hyper::header::CONNECTION, "Upgrade")
                .header(hyper::header::UPGRADE, "iroh-duplex")
                .body(crate::box_body(http_body_util::Empty::new()))
                .unwrap();
            return Ok(resp);
        }
        // Only wire response trailers through if the handler declared a
        // `Trailer` header; otherwise release the sender so the handler's
        // trailer write doesn't block forever.
        let has_trailer_hdr = response_head
            .headers
            .iter()
            .any(|(k, _)| k.eq_ignore_ascii_case("trailer"));
        let trailer_rx_for_body = if has_trailer_hdr {
            opt_res_trailer_rx
        } else {
            handles.remove_trailer_sender(res_trailers_handle);
            None
        };
        let body_stream = body_from_reader(res_body_reader, trailer_rx_for_body);
        let mut resp_builder = hyper::Response::builder().status(response_head.status);
        for (k, v) in &response_head.headers {
            resp_builder = resp_builder.header(k.as_str(), v.as_str());
        }
        // No-op rebinding kept so both cfg variants see the same name;
        // compression itself is applied by the tower layer in the server.
        #[cfg(feature = "compression")]
        let resp_builder = resp_builder;
        let resp = resp_builder
            .body(crate::box_body(body_stream))
            .map_err(|e| -> BoxError { e.into() })?;
        Ok(resp)
    }
}
/// Forward a fully-assembled request notification to the embedder callback.
///
/// Exists only to keep the argument-to-payload mapping in a single place;
/// the callback takes ownership of all strings and handles.
#[inline]
#[allow(clippy::too_many_arguments)]
fn on_request_fire(
    cb: &Arc<dyn Fn(RequestPayload) + Send + Sync>,
    req_handle: u64,
    req_body_handle: u64,
    res_body_handle: u64,
    req_trailers_handle: u64,
    res_trailers_handle: u64,
    method: String,
    url: String,
    headers: Vec<(String, String)>,
    remote_node_id: String,
    is_bidi: bool,
) {
    // Assemble first, then invoke — keeps the call site a single expression.
    let payload = RequestPayload {
        req_handle,
        req_body_handle,
        res_body_handle,
        req_trailers_handle,
        res_trailers_handle,
        method,
        url,
        headers,
        remote_node_id,
        is_bidi,
    };
    cb(payload);
}
/// Start serving requests on `endpoint` with no connection-event callback.
///
/// Convenience wrapper around [`serve_with_events`]; see there for the
/// semantics of `options` and the returned [`ServeHandle`].
pub fn serve<F>(endpoint: IrohEndpoint, options: ServeOptions, on_request: F) -> ServeHandle
where
    F: Fn(RequestPayload) + Send + Sync + 'static,
{
    serve_with_events(endpoint, options, on_request, None)
}
/// Spawn the accept loop for `endpoint` and return a handle controlling it.
///
/// For each accepted connection a per-peer guard is taken, then every bidi
/// stream becomes one HTTP/1 connection served through a tower stack
/// (optional load shed -> concurrency limit -> timeout -> `RequestService`),
/// optionally wrapped in a compression layer. Shutdown waits for in-flight
/// requests up to the drain timeout before signalling `done`.
pub fn serve_with_events<F>(
    endpoint: IrohEndpoint,
    options: ServeOptions,
    on_request: F,
    on_connection_event: Option<ConnectionEventFn>,
) -> ServeHandle
where
    F: Fn(RequestPayload) + Send + Sync + 'static,
{
    // Resolve options to effective values (see DEFAULT_* constants).
    let max = options.max_concurrency.unwrap_or(DEFAULT_CONCURRENCY);
    let max_errors = options.max_consecutive_errors.unwrap_or(5);
    let request_timeout = options
        .request_timeout_ms
        .map(Duration::from_millis)
        .unwrap_or(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS));
    let max_conns_per_peer = options
        .max_connections_per_peer
        .unwrap_or(DEFAULT_MAX_CONNECTIONS_PER_PEER);
    let max_request_body_bytes = options.max_request_body_bytes;
    let max_total_connections = options.max_total_connections;
    let drain_timeout = Duration::from_secs(
        options
            .drain_timeout_secs
            .unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS),
    );
    let load_shed_enabled = options.load_shed.unwrap_or(true);
    let max_header_size = endpoint.max_header_size();
    #[cfg(feature = "compression")]
    let compression = endpoint.compression().cloned();
    let own_node_id = Arc::new(endpoint.node_id().to_string());
    let on_request = Arc::new(on_request) as Arc<dyn Fn(RequestPayload) + Send + Sync>;
    // Per-peer live-connection counts, shared with every guard.
    let peer_counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>> =
        Arc::new(Mutex::new(HashMap::new()));
    let conn_event_fn: Option<ConnectionEventFn> = on_connection_event;
    // Global in-flight request count, used by the drain phase below.
    let in_flight: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
    let drain_notify: Arc<tokio::sync::Notify> = Arc::new(tokio::sync::Notify::new());
    // Template service; `remote_node_id` is filled in per connection.
    let base_svc = RequestService {
        on_request,
        endpoint: endpoint.clone(),
        own_node_id,
        remote_node_id: None,
        max_request_body_bytes,
        // 0 from the endpoint means "no header-size limit".
        max_header_size: if max_header_size == 0 {
            None
        } else {
            Some(max_header_size)
        },
        #[cfg(feature = "compression")]
        compression,
    };
    let shutdown_notify = Arc::new(tokio::sync::Notify::new());
    let shutdown_listen = shutdown_notify.clone();
    let drain_dur = drain_timeout;
    // Shared endpoint-wide counters and the "endpoint closed" signal.
    let total_connections = endpoint.inner.active_connections.clone();
    let total_requests = endpoint.inner.active_requests.clone();
    let (done_tx, done_rx) = tokio::sync::watch::channel(false);
    let endpoint_closed_tx = endpoint.inner.closed_tx.clone();
    let in_flight_drain = in_flight.clone();
    let drain_notify_drain = drain_notify.clone();
    let join = tokio::spawn(async move {
        let ep = endpoint.raw().clone();
        let mut consecutive_errors: usize = 0;
        // ---- Accept loop -------------------------------------------------
        loop {
            // Biased select: a pending shutdown wins over new work.
            let incoming = tokio::select! {
                biased;
                _ = shutdown_listen.notified() => {
                    tracing::info!("iroh-http: serve loop shutting down");
                    break;
                }
                inc = ep.accept() => match inc {
                    Some(i) => i,
                    None => {
                        tracing::info!("iroh-http: endpoint closed (accept returned None)");
                        let _ = endpoint_closed_tx.send(true);
                        break;
                    }
                }
            };
            // A successful handshake resets the error streak; too many
            // consecutive failures shuts the loop down to avoid spinning.
            let conn = match incoming.await {
                Ok(c) => {
                    consecutive_errors = 0;
                    c
                }
                Err(e) => {
                    consecutive_errors += 1;
                    tracing::warn!(
                        "iroh-http: accept error ({consecutive_errors}/{max_errors}): {e}"
                    );
                    if consecutive_errors >= max_errors {
                        tracing::error!("iroh-http: too many accept errors — shutting down");
                        break;
                    }
                    continue;
                }
            };
            let remote_pk = conn.remote_id();
            // Global connection cap: refuse before spending any more work.
            if let Some(max_total) = max_total_connections {
                let current = total_connections.load(Ordering::Relaxed);
                if current >= max_total {
                    tracing::warn!(
                        "iroh-http: total connection limit reached ({current}/{max_total})"
                    );
                    conn.close(0u32.into(), b"server at capacity");
                    continue;
                }
            }
            let remote_id = base32_encode(remote_pk.as_bytes());
            // Per-peer cap via RAII guard; also fires connect/disconnect
            // events on 0<->1 transitions.
            let guard = match PeerConnectionGuard::acquire(
                &peer_counts,
                remote_pk,
                remote_id.clone(),
                max_conns_per_peer,
                conn_event_fn.clone(),
            ) {
                Some(g) => g,
                None => {
                    tracing::warn!(
                        "iroh-http: peer {remote_id} exceeded connection limit"
                    );
                    conn.close(0u32.into(), b"too many connections");
                    continue;
                }
            };
            let mut peer_svc = base_svc.clone();
            peer_svc.remote_node_id = Some(remote_id);
            // A zero timeout means "disabled" — model it as effectively infinite.
            let timeout_dur = if request_timeout.is_zero() {
                Duration::MAX
            } else {
                request_timeout
            };
            let conn_total = total_connections.clone();
            let conn_requests = total_requests.clone();
            let in_flight_conn = in_flight.clone();
            let drain_notify_conn = drain_notify.clone();
            conn_total.fetch_add(1, Ordering::Relaxed);
            // ---- Per-connection task --------------------------------------
            tokio::spawn(async move {
                let _guard = guard;
                // Decrements the global connection count on any exit path.
                struct TotalGuard(Arc<AtomicUsize>);
                impl Drop for TotalGuard {
                    fn drop(&mut self) {
                        self.0.fetch_sub(1, Ordering::Relaxed);
                    }
                }
                let _total_guard = TotalGuard(conn_total);
                // Each accepted bidi stream is one HTTP/1 connection.
                loop {
                    let (send, recv) = match conn.accept_bi().await {
                        Ok(pair) => pair,
                        Err(_) => break,
                    };
                    let io = TokioIo::new(IrohStream::new(send, recv));
                    let svc = peer_svc.clone();
                    let req_counter = conn_requests.clone();
                    req_counter.fetch_add(1, Ordering::Relaxed);
                    in_flight_conn.fetch_add(1, Ordering::Relaxed);
                    let in_flight_req = in_flight_conn.clone();
                    let drain_notify_req = drain_notify_conn.clone();
                    // ---- Per-request task ---------------------------------
                    tokio::spawn(async move {
                        // Decrements counters on any exit path; wakes the
                        // drain phase when the last in-flight request ends.
                        struct ReqGuard {
                            counter: Arc<AtomicUsize>,
                            in_flight: Arc<AtomicUsize>,
                            drain_notify: Arc<tokio::sync::Notify>,
                        }
                        impl Drop for ReqGuard {
                            fn drop(&mut self) {
                                self.counter.fetch_sub(1, Ordering::Relaxed);
                                if self.in_flight.fetch_sub(1, Ordering::AcqRel) == 1 {
                                    self.drain_notify.notify_waiters();
                                }
                            }
                        }
                        let _req_guard = ReqGuard {
                            counter: req_counter,
                            in_flight: in_flight_req,
                            drain_notify: drain_notify_req,
                        };
                        // hyper read-buffer size: 64 KiB when unlimited,
                        // otherwise at least 8 KiB.
                        let effective_header_limit = if max_header_size == 0 {
                            64 * 1024
                        } else {
                            max_header_size.max(8192)
                        };
                        use tower::{ServiceBuilder, limit::ConcurrencyLimitLayer, timeout::TimeoutLayer};
                        // With compression enabled: wrap the error-handled
                        // stack in a CompressionLayer gated by a predicate
                        // (min size, not already encoded, no `no-transform`).
                        #[cfg(feature = "compression")]
                        let result = {
                            use http::{Extensions, HeaderMap, Version};
                            use tower_http::compression::{predicate::{Predicate, SizeAbove}, CompressionLayer};
                            let compression_config = svc.compression.clone();
                            if let Some(comp) = &compression_config {
                                let min_bytes = comp.min_body_bytes;
                                let mut layer = CompressionLayer::new().zstd(true);
                                if let Some(level) = comp.level {
                                    use tower_http::compression::CompressionLevel;
                                    layer = layer.quality(CompressionLevel::Precise(level as i32));
                                }
                                // Skip responses that already carry a
                                // Content-Encoding.
                                let not_pre_compressed =
                                    |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
                                        !h.contains_key(http::header::CONTENT_ENCODING)
                                    };
                                // Honor `Cache-Control: no-transform`.
                                let not_no_transform =
                                    |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
                                        h.get(http::header::CACHE_CONTROL)
                                            .and_then(|v| v.to_str().ok())
                                            .map(|v| {
                                                !v.split(',').any(|d| {
                                                    d.trim().eq_ignore_ascii_case("no-transform")
                                                })
                                            })
                                            .unwrap_or(true)
                                    };
                                let predicate =
                                    SizeAbove::new(min_bytes.min(u16::MAX as usize) as u16)
                                        .and(not_pre_compressed)
                                        .and(not_no_transform);
                                if load_shed_enabled {
                                    use tower::load_shed::LoadShedLayer;
                                    let stk = TowerErrorHandler(ServiceBuilder::new()
                                        .layer(LoadShedLayer::new())
                                        .layer(ConcurrencyLimitLayer::new(max))
                                        .layer(TimeoutLayer::new(timeout_dur))
                                        .service(svc));
                                    hyper::server::conn::http1::Builder::new()
                                        .max_buf_size(effective_header_limit)
                                        .max_headers(128)
                                        .serve_connection(io, TowerToHyperService::new(
                                            ServiceBuilder::new()
                                                .layer(layer.compress_when(predicate))
                                                .service(stk),
                                        ))
                                        .with_upgrades()
                                        .await
                                } else {
                                    let stk = TowerErrorHandler(ServiceBuilder::new()
                                        .layer(ConcurrencyLimitLayer::new(max))
                                        .layer(TimeoutLayer::new(timeout_dur))
                                        .service(svc));
                                    hyper::server::conn::http1::Builder::new()
                                        .max_buf_size(effective_header_limit)
                                        .max_headers(128)
                                        .serve_connection(io, TowerToHyperService::new(
                                            ServiceBuilder::new()
                                                .layer(layer.compress_when(predicate))
                                                .service(stk),
                                        ))
                                        .with_upgrades()
                                        .await
                                }
                            } else if load_shed_enabled {
                                use tower::load_shed::LoadShedLayer;
                                let stk = TowerErrorHandler(ServiceBuilder::new()
                                    .layer(LoadShedLayer::new())
                                    .layer(ConcurrencyLimitLayer::new(max))
                                    .layer(TimeoutLayer::new(timeout_dur))
                                    .service(svc));
                                hyper::server::conn::http1::Builder::new()
                                    .max_buf_size(effective_header_limit)
                                    .max_headers(128)
                                    .serve_connection(io, TowerToHyperService::new(stk))
                                    .with_upgrades()
                                    .await
                            } else {
                                let stk = TowerErrorHandler(ServiceBuilder::new()
                                    .layer(ConcurrencyLimitLayer::new(max))
                                    .layer(TimeoutLayer::new(timeout_dur))
                                    .service(svc));
                                hyper::server::conn::http1::Builder::new()
                                    .max_buf_size(effective_header_limit)
                                    .max_headers(128)
                                    .serve_connection(io, TowerToHyperService::new(stk))
                                    .with_upgrades()
                                    .await
                            }
                        };
                        // Without compression: same stacks minus the layer.
                        #[cfg(not(feature = "compression"))]
                        let result = if load_shed_enabled {
                            use tower::load_shed::LoadShedLayer;
                            let stk = TowerErrorHandler(ServiceBuilder::new()
                                .layer(LoadShedLayer::new())
                                .layer(ConcurrencyLimitLayer::new(max))
                                .layer(TimeoutLayer::new(timeout_dur))
                                .service(svc));
                            hyper::server::conn::http1::Builder::new()
                                .max_buf_size(effective_header_limit)
                                .max_headers(128)
                                .serve_connection(io, TowerToHyperService::new(stk))
                                .with_upgrades()
                                .await
                        } else {
                            let stk = TowerErrorHandler(ServiceBuilder::new()
                                .layer(ConcurrencyLimitLayer::new(max))
                                .layer(TimeoutLayer::new(timeout_dur))
                                .service(svc));
                            hyper::server::conn::http1::Builder::new()
                                .max_buf_size(effective_header_limit)
                                .max_headers(128)
                                .serve_connection(io, TowerToHyperService::new(stk))
                                .with_upgrades()
                                .await
                        };
                        if let Err(e) = result {
                            tracing::debug!("iroh-http: http1 connection error: {e}");
                        }
                    });
                }
            });
        }
        // ---- Drain phase: wait for in-flight requests, bounded by the
        // drain timeout. Woken by the last ReqGuard or by the deadline.
        let deadline = tokio::time::Instant::now() + drain_dur;
        loop {
            if in_flight_drain.load(Ordering::Acquire) == 0 {
                tracing::info!("iroh-http: all in-flight requests drained");
                break;
            }
            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
            if remaining.is_zero() {
                tracing::warn!("iroh-http: drain timed out after {}s", drain_dur.as_secs());
                break;
            }
            tokio::select! {
                _ = drain_notify_drain.notified() => {}
                _ = tokio::time::sleep(remaining) => {}
            }
        }
        let _ = done_tx.send(true);
    });
    ServeHandle {
        join,
        shutdown_notify,
        drain_timeout: drain_dur,
        done_rx,
    }
}
/// Tower middleware that converts stack errors (timeout, load shed) into
/// plain HTTP error responses instead of surfacing them to hyper as
/// connection-level failures.
#[derive(Clone)]
struct TowerErrorHandler<S>(S);
impl<S, Req> Service<Req> for TowerErrorHandler<S>
where
    S: Service<Req, Response = hyper::Response<BoxBody>>,
    S::Error: Into<BoxError>,
    S::Future: Send + 'static,
{
    type Response = hyper::Response<BoxBody>;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
    // Readiness is delegated to the inner stack (concurrency limit etc.).
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.0.poll_ready(cx).map_err(Into::into)
    }
    fn call(&mut self, req: Req) -> Self::Future {
        let fut = self.0.call(req);
        Box::pin(async move {
            match fut.await {
                Ok(r) => Ok(r),
                Err(e) => {
                    let e = e.into();
                    // Map known tower errors to status codes.
                    // NOTE(review): a server-side handler timeout is reported
                    // as 408 (Request Timeout); 504 is arguably more accurate
                    // — confirm no clients depend on 408 before changing.
                    let status = if e.is::<tower::timeout::error::Elapsed>() {
                        StatusCode::REQUEST_TIMEOUT
                    } else if e.is::<tower::load_shed::error::Overloaded>() {
                        StatusCode::SERVICE_UNAVAILABLE
                    } else {
                        tracing::warn!("iroh-http: unexpected tower error: {e}");
                        StatusCode::INTERNAL_SERVER_ERROR
                    };
                    let body_bytes: &'static [u8] = match status {
                        StatusCode::REQUEST_TIMEOUT => b"request timed out",
                        StatusCode::SERVICE_UNAVAILABLE => b"server at capacity",
                        _ => b"internal server error",
                    };
                    Ok(hyper::Response::builder()
                        .status(status)
                        .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
                            body_bytes,
                        ))))
                        .expect("valid error response"))
                }
            }
        })
    }
}