1use std::{
8 collections::HashMap,
9 future::Future,
10 pin::Pin,
11 sync::{
12 atomic::{AtomicUsize, Ordering},
13 Arc, Mutex,
14 },
15 task::{Context, Poll},
16 time::Duration,
17};
18
19use bytes::Bytes;
20use http::{HeaderName, HeaderValue, StatusCode};
21use hyper::body::Incoming;
22use hyper_util::rt::TokioIo;
23use hyper_util::service::TowerToHyperService;
24use tower::Service;
25
26use crate::{
27 base32_encode,
28 client::{body_from_reader, pump_hyper_body_to_channel_limited},
29 io::IrohStream,
30 stream::{HandleStore, ResponseHeadEntry},
31 ConnectionEvent, CoreError, IrohEndpoint, RequestPayload,
32};
33
// Boxed HTTP body type shared with the rest of the crate.
type BoxBody = crate::BoxBody;
// Catch-all error type used by the tower/hyper plumbing below.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
38
/// Tunables for [`serve`] / [`serve_with_events`].
///
/// Every field is optional; `None` falls back to the crate default
/// (see the `DEFAULT_*` constants below), so `ServeOptions::default()`
/// yields a fully defaulted configuration.
#[derive(Debug, Clone, Default)]
pub struct ServeOptions {
    /// Maximum number of concurrently processed requests (default 1024).
    pub max_concurrency: Option<usize>,
    /// Consecutive accept errors tolerated before the serve loop shuts
    /// itself down (default 5).
    pub max_serve_errors: Option<usize>,
    /// Per-request timeout in milliseconds; `0` disables the timeout
    /// (default 60_000).
    pub request_timeout_ms: Option<u64>,
    /// Maximum simultaneous connections accepted from a single peer
    /// (default 8).
    pub max_connections_per_peer: Option<usize>,
    /// Maximum accepted request body size in bytes (default 16 MiB).
    pub max_request_body_bytes: Option<usize>,
    /// How long shutdown waits for in-flight requests (default 30_000 ms).
    pub drain_timeout_ms: Option<u64>,
    /// Cap on total simultaneous connections; `None` means unlimited.
    pub max_total_connections: Option<usize>,
    /// Whether overload is rejected with 503 via tower's load-shed layer
    /// (default true).
    pub load_shed: Option<bool>,
}
67
// Fallback values used when the corresponding `ServeOptions` field is `None`.
const DEFAULT_CONCURRENCY: usize = 1024;
const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 60_000;
const DEFAULT_MAX_CONNECTIONS_PER_PEER: usize = 8;
const DEFAULT_DRAIN_TIMEOUT_MS: u64 = 30_000;
const DEFAULT_MAX_REQUEST_BODY_BYTES: usize = 16 * 1024 * 1024;
/// Default cap on response body size; shared with the client side of the crate.
pub(crate) const DEFAULT_MAX_RESPONSE_BODY_BYTES: usize = 256 * 1024 * 1024;
79
/// Handle to a running serve loop spawned by [`serve`] / [`serve_with_events`].
pub struct ServeHandle {
    // Task running the accept loop; awaited by `drain`, aborted by `abort`.
    join: tokio::task::JoinHandle<()>,
    // Signals the accept loop to stop accepting new connections.
    shutdown_notify: Arc<tokio::sync::Notify>,
    // How long the serve task waits for in-flight requests after shutdown.
    drain_timeout: std::time::Duration,
    // Flips to `true` once the serve task has finished (including drain).
    done_rx: tokio::sync::watch::Receiver<bool>,
}
89
impl ServeHandle {
    /// Asks the serve loop to stop accepting new connections.
    /// In-flight requests keep running until drained or timed out.
    pub fn shutdown(&self) {
        self.shutdown_notify.notify_one();
    }

    /// Triggers shutdown and waits for the serve task to finish
    /// (including its drain phase). Consumes the handle.
    pub async fn drain(self) {
        self.shutdown();
        let _ = self.join.await;
    }

    /// Hard-aborts the serve task without draining.
    pub fn abort(&self) {
        self.join.abort();
    }

    /// The drain timeout the serve loop was configured with.
    pub fn drain_timeout(&self) -> std::time::Duration {
        self.drain_timeout
    }

    /// Watch channel that flips to `true` when the serve task has completed.
    pub fn subscribe_done(&self) -> tokio::sync::watch::Receiver<bool> {
        self.done_rx.clone()
    }
}
112
113pub fn respond(
116 handles: &HandleStore,
117 req_handle: u64,
118 status: u16,
119 headers: Vec<(String, String)>,
120) -> Result<(), CoreError> {
121 StatusCode::from_u16(status)
122 .map_err(|_| CoreError::invalid_input(format!("invalid HTTP status code: {status}")))?;
123 for (name, value) in &headers {
124 HeaderName::from_bytes(name.as_bytes()).map_err(|_| {
125 CoreError::invalid_input(format!("invalid response header name {:?}", name))
126 })?;
127 HeaderValue::from_str(value).map_err(|_| {
128 CoreError::invalid_input(format!("invalid response header value for {:?}", name))
129 })?;
130 }
131
132 let sender = handles
133 .take_req_sender(req_handle)
134 .ok_or_else(|| CoreError::invalid_handle(req_handle))?;
135 sender
136 .send(ResponseHeadEntry { status, headers })
137 .map_err(|_| CoreError::internal("serve task dropped before respond"))
138}
139
/// Callback fired when a peer's first connection opens (connected: true)
/// and when its last connection closes (connected: false).
type ConnectionEventFn = Arc<dyn Fn(ConnectionEvent) + Send + Sync>;
143
/// RAII guard for one accepted connection from `peer`.
///
/// `acquire` increments the per-peer connection count; dropping the guard
/// decrements it. Connection events are emitted on the 0→1 and 1→0
/// transitions only.
struct PeerConnectionGuard {
    // Shared per-peer live-connection counts.
    counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
    peer: iroh::PublicKey,
    // Pre-encoded peer id carried into the emitted events.
    peer_id_str: String,
    on_event: Option<ConnectionEventFn>,
}
150
impl PeerConnectionGuard {
    /// Tries to register one more connection for `peer`.
    ///
    /// Returns `None` (and changes nothing) when the peer already has `max`
    /// live connections. On the peer's first connection (count 0 → 1) the
    /// `on_event` callback is fired with `connected: true`.
    fn acquire(
        counts: &Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
        peer: iroh::PublicKey,
        peer_id_str: String,
        max: usize,
        on_event: Option<ConnectionEventFn>,
    ) -> Option<Self> {
        // A poisoned mutex is recovered rather than propagated, so the
        // counter map stays usable even if some holder panicked.
        let mut map = counts.lock().unwrap_or_else(|e| e.into_inner());
        let count = map.entry(peer).or_insert(0);
        if *count >= max {
            return None;
        }
        let was_zero = *count == 0;
        *count = count.saturating_add(1);
        let guard = PeerConnectionGuard {
            counts: counts.clone(),
            peer,
            peer_id_str: peer_id_str.clone(),
            on_event: on_event.clone(),
        };
        if was_zero {
            // NOTE(review): the callback runs while `map` (the counts mutex)
            // is still locked. That serializes connect/disconnect events, but
            // the callback must never re-enter this map or it will deadlock —
            // confirm callers' callbacks are lock-free.
            if let Some(cb) = &on_event {
                cb(ConnectionEvent {
                    peer_id: peer_id_str,
                    connected: true,
                });
            }
        }
        Some(guard)
    }
}
184
impl Drop for PeerConnectionGuard {
    // Decrements the peer's connection count; on the last connection
    // (count 1 → 0) removes the map entry and fires `connected: false`.
    fn drop(&mut self) {
        // Recover from poisoning, matching `acquire`.
        let mut map = self.counts.lock().unwrap_or_else(|e| e.into_inner());
        if let Some(c) = map.get_mut(&self.peer) {
            *c = c.saturating_sub(1);
            if *c == 0 {
                map.remove(&self.peer);
                // NOTE(review): as in `acquire`, the callback runs under the
                // counts lock — keeps event ordering strict, but the callback
                // must not re-enter the map.
                if let Some(cb) = &self.on_event {
                    cb(ConnectionEvent {
                        peer_id: self.peer_id_str.clone(),
                        connected: false,
                    });
                }
            }
        }
    }
}
203
/// Per-connection tower service that bridges hyper requests into the
/// handle-based callback API (`on_request`).
///
/// A template instance (with `remote_node_id: None`) is cloned once per
/// accepted connection, at which point the peer id is stamped in.
#[derive(Clone)]
struct RequestService {
    // Callback fired once per request with the allocated handles.
    on_request: Arc<dyn Fn(RequestPayload) + Send + Sync>,
    endpoint: IrohEndpoint,
    // This node's id, used to synthesize the `httpi://` request URL.
    own_node_id: Arc<String>,
    // Base32 id of the connected peer; `None` only on the template instance.
    remote_node_id: Option<String>,
    // `None` disables the request-body size cap.
    max_request_body_bytes: Option<usize>,
    // `None` disables the request-header size check.
    max_header_size: Option<usize>,
    #[cfg(feature = "compression")]
    compression: Option<crate::endpoint::CompressionOptions>,
}
217
impl Service<hyper::Request<Incoming>> for RequestService {
    type Response = hyper::Response<BoxBody>;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    // Always ready: backpressure is applied by the surrounding
    // concurrency-limit / load-shed layers, not here.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    // Clone the service into the future so it can be `'static` and boxed.
    fn call(&mut self, req: hyper::Request<Incoming>) -> Self::Future {
        let svc = self.clone();
        Box::pin(async move { svc.handle(req).await })
    }
}
232
impl RequestService {
    /// Bridges one hyper request into the handle/callback API and produces
    /// the response.
    ///
    /// Flow: enforce the optional header-size limit, collect headers (adding
    /// a synthetic `peer-id` header), detect a duplex upgrade (CONNECT +
    /// `Upgrade: iroh-duplex` + `Connection: upgrade`), allocate body and
    /// request handles, fire `on_request`, wait for the response head from
    /// the handler, then either complete the upgrade or stream the response
    /// body back.
    async fn handle(
        self,
        mut req: hyper::Request<Incoming>,
    ) -> Result<hyper::Response<BoxBody>, BoxError> {
        let handles = self.endpoint.handles();
        let own_node_id = &*self.own_node_id;
        let remote_node_id = self.remote_node_id.clone().unwrap_or_default();
        let max_request_body_bytes = self.max_request_body_bytes;
        let max_header_size = self.max_header_size;

        let method = req.method().to_string();
        let path_and_query = req
            .uri()
            .path_and_query()
            .map(|p| p.as_str())
            .unwrap_or("/")
            .to_string();

        tracing::debug!(
            method = %method,
            path = %path_and_query,
            peer = %remote_node_id,
            "iroh-http: incoming request",
        );
        // Optional header-size limit: approximate the serialized size of all
        // headers, substituting our synthetic `peer-id` header for any
        // client-sent one, plus the request line. All arithmetic saturates.
        if let Some(limit) = max_header_size {
            let header_bytes: usize = req
                .headers()
                .iter()
                .filter(|(k, _)| !k.as_str().eq_ignore_ascii_case("peer-id"))
                .map(|(k, v)| {
                    // name + value + ": " + "\r\n" per header line.
                    k.as_str()
                        .len()
                        .saturating_add(v.as_bytes().len())
                        .saturating_add(4)
                })
                .fold(0usize, |acc, x| acc.saturating_add(x))
                .saturating_add("peer-id".len())
                .saturating_add(remote_node_id.len())
                .saturating_add(4)
                .saturating_add(req.uri().to_string().len())
                .saturating_add(method.len())
                .saturating_add(12);
            if header_bytes > limit {
                let resp = hyper::Response::builder()
                    .status(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE)
                    .body(crate::box_body(http_body_util::Empty::new()))
                    .expect("static response args are valid");
                return Ok(resp);
            }
        }

        // Copy headers into (String, String) pairs, rejecting non-UTF-8
        // values, and dropping any client-supplied `peer-id` so it cannot be
        // spoofed — we append the authenticated one below.
        let mut req_headers: Vec<(String, String)> = Vec::new();
        for (k, v) in req.headers().iter() {
            if k.as_str().eq_ignore_ascii_case("peer-id") {
                continue;
            }
            match v.to_str() {
                Ok(s) => req_headers.push((k.as_str().to_string(), s.to_string())),
                Err(_) => {
                    let resp = hyper::Response::builder()
                        .status(StatusCode::BAD_REQUEST)
                        .body(crate::box_body(http_body_util::Full::new(
                            Bytes::from_static(b"non-UTF8 header value"),
                        )))
                        .expect("static response args are valid");
                    return Ok(resp);
                }
            }
        }
        req_headers.push(("peer-id".to_string(), remote_node_id.clone()));

        let url = format!("httpi://{own_node_id}{path_and_query}");

        // Duplex detection: an `Upgrade: iroh-duplex` header is only valid on
        // a CONNECT request that also carries `Connection: upgrade`.
        let has_upgrade_header = req_headers.iter().any(|(k, v)| {
            k.eq_ignore_ascii_case("upgrade") && v.eq_ignore_ascii_case("iroh-duplex")
        });
        let has_connection_upgrade = req_headers.iter().any(|(k, v)| {
            k.eq_ignore_ascii_case("connection")
                && v.split(',')
                    .any(|tok| tok.trim().eq_ignore_ascii_case("upgrade"))
        });
        let is_connect = req.method() == http::Method::CONNECT;

        let is_bidi = if has_upgrade_header {
            if !has_connection_upgrade || !is_connect {
                let resp = hyper::Response::builder()
                    .status(StatusCode::BAD_REQUEST)
                    .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
                        b"duplex upgrade requires CONNECT method with Connection: upgrade header",
                    ))))
                    .expect("static response args are valid");
                return Ok(resp);
            }
            true
        } else {
            false
        };

        // Must be grabbed before the body is consumed below.
        let upgrade_future = if is_bidi {
            Some(hyper::upgrade::on(&mut req))
        } else {
            None
        };

        // Allocate the three handles atomically: the insert guard rolls back
        // on error (via `?`) and only `commit` makes them visible.
        let mut guard = handles.insert_guard();
        let (req_body_writer, req_body_reader) = handles.make_body_channel();
        let req_body_handle = guard
            .insert_reader(req_body_reader)
            .map_err(|e| -> BoxError { e.into() })?;

        let (res_body_writer, res_body_reader) = handles.make_body_channel();
        let res_body_handle = guard
            .insert_writer(res_body_writer)
            .map_err(|e| -> BoxError { e.into() })?;

        let (head_tx, head_rx) = tokio::sync::oneshot::channel::<ResponseHeadEntry>();
        let req_handle = guard
            .allocate_req_handle(head_tx)
            .map_err(|e| -> BoxError { e.into() })?;

        guard.commit();

        // Ensures the response-head sender is removed from the store on every
        // exit path (early return, error, or normal completion).
        struct ReqHeadCleanup {
            endpoint: IrohEndpoint,
            req_handle: u64,
        }
        impl Drop for ReqHeadCleanup {
            fn drop(&mut self) {
                self.endpoint.handles().take_req_sender(self.req_handle);
            }
        }
        let _req_head_cleanup = ReqHeadCleanup {
            endpoint: self.endpoint.clone(),
            req_handle,
        };

        // Overflow signal is only wired up for non-duplex requests with a
        // configured body cap; the pump fires it when the cap is exceeded.
        let (body_overflow_tx, body_overflow_rx) = if !is_bidi && max_request_body_bytes.is_some() {
            let (tx, rx) = tokio::sync::oneshot::channel::<()>();
            (Some(tx), Some(rx))
        } else {
            (None, None)
        };

        // Non-duplex: pump the hyper request body into the channel in the
        // background. Duplex: keep the writer for the upgraded stream instead.
        let duplex_req_body_writer = if !is_bidi {
            let body = req.into_body();
            let frame_timeout = handles.drain_timeout();
            tokio::spawn(pump_hyper_body_to_channel_limited(
                body,
                req_body_writer,
                max_request_body_bytes,
                frame_timeout,
                body_overflow_tx,
            ));
            None
        } else {
            // CONNECT carries no body; data flows over the upgraded stream.
            drop(req.into_body());
            Some(req_body_writer)
        };

        // Hand the request (with its freshly minted handles) to the handler.
        on_request_fire(
            &self.on_request,
            req_handle,
            req_body_handle,
            res_body_handle,
            method,
            url,
            req_headers,
            remote_node_id,
            is_bidi,
        );

        // Wait for the handler to provide the response head. With an overflow
        // receiver wired up, race it against the head (`biased` polls the
        // overflow first); a dropped overflow sender (= body pumped fine)
        // fails the `Ok(())` pattern and disables that branch.
        let response_head = if let Some(overflow_rx) = body_overflow_rx {
            tokio::select! {
                biased;
                Ok(()) = overflow_rx => {
                    let resp = hyper::Response::builder()
                        .status(StatusCode::PAYLOAD_TOO_LARGE)
                        .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
                            b"request body too large",
                        ))))
                        .expect("valid 413 response");
                    return Ok(resp);
                }
                head = head_rx => {
                    head.map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
                }
            }
        } else {
            head_rx
                .await
                .map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
        };

        // Duplex path: the handler must answer 101 for the upgrade to proceed.
        if let Some(upgrade_fut) = upgrade_future {
            let req_body_writer =
                duplex_req_body_writer.expect("duplex path always has req_body_writer");

            if response_head.status != StatusCode::SWITCHING_PROTOCOLS.as_u16() {
                // Handler declined: relay its status/headers with empty body.
                drop(upgrade_fut);
                drop(req_body_writer);
                let mut resp_builder = hyper::Response::builder().status(response_head.status);
                for (k, v) in &response_head.headers {
                    resp_builder = resp_builder.header(k.as_str(), v.as_str());
                }
                let resp = resp_builder
                    .body(crate::box_body(http_body_util::Empty::new()))
                    .map_err(|e| -> BoxError { e.into() })?;
                return Ok(resp);
            }

            // Pump the upgraded stream in the background; the 101 response
            // below is what lets hyper complete the upgrade.
            tokio::spawn(async move {
                match upgrade_fut.await {
                    Err(e) => tracing::warn!("iroh-http: duplex upgrade error: {e}"),
                    Ok(upgraded) => {
                        let io = TokioIo::new(upgraded);
                        crate::stream::pump_duplex(io, req_body_writer, res_body_reader).await;
                    }
                }
            });

            let resp = hyper::Response::builder()
                .status(StatusCode::SWITCHING_PROTOCOLS)
                .header(hyper::header::CONNECTION, "Upgrade")
                .header(hyper::header::UPGRADE, "iroh-duplex")
                .body(crate::box_body(http_body_util::Empty::new()))
                .expect("static response args are valid");
            return Ok(resp);
        }

        // Normal path: stream the handler's response body from the channel.
        let body_stream = body_from_reader(res_body_reader);

        let mut resp_builder = hyper::Response::builder().status(response_head.status);
        for (k, v) in &response_head.headers {
            resp_builder = resp_builder.header(k.as_str(), v.as_str());
        }

        #[cfg(feature = "compression")]
        let resp_builder = resp_builder;
        let resp = resp_builder
            .body(crate::box_body(body_stream))
            .map_err(|e| -> BoxError { e.into() })?;

        Ok(resp)
    }
}
532
533#[inline]
534#[allow(clippy::too_many_arguments)]
535fn on_request_fire(
536 cb: &Arc<dyn Fn(RequestPayload) + Send + Sync>,
537 req_handle: u64,
538 req_body_handle: u64,
539 res_body_handle: u64,
540 method: String,
541 url: String,
542 headers: Vec<(String, String)>,
543 remote_node_id: String,
544 is_bidi: bool,
545) {
546 cb(RequestPayload {
547 req_handle,
548 req_body_handle,
549 res_body_handle,
550 method,
551 url,
552 headers,
553 remote_node_id,
554 is_bidi,
555 });
556}
557
/// Starts serving requests on `endpoint` without connection-event callbacks.
/// Convenience wrapper around [`serve_with_events`].
pub fn serve<F>(endpoint: IrohEndpoint, options: ServeOptions, on_request: F) -> ServeHandle
where
    F: Fn(RequestPayload) + Send + Sync + 'static,
{
    serve_with_events(endpoint, options, on_request, None)
}
591
/// Spawns the serve loop on `endpoint`, invoking `on_request` for every
/// request and, if provided, `on_connection_event` when a peer's first
/// connection opens or its last one closes.
///
/// Returns a [`ServeHandle`] that can shut down, drain, or abort the loop.
/// The loop also exits when the endpoint closes or after too many
/// consecutive accept errors.
pub fn serve_with_events<F>(
    endpoint: IrohEndpoint,
    options: ServeOptions,
    on_request: F,
    on_connection_event: Option<ConnectionEventFn>,
) -> ServeHandle
where
    F: Fn(RequestPayload) + Send + Sync + 'static,
{
    // Resolve options against the crate defaults.
    let max = options.max_concurrency.unwrap_or(DEFAULT_CONCURRENCY);
    let max_errors = options.max_serve_errors.unwrap_or(5);
    let request_timeout = options
        .request_timeout_ms
        .map(Duration::from_millis)
        .unwrap_or(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS));
    let max_conns_per_peer = options
        .max_connections_per_peer
        .unwrap_or(DEFAULT_MAX_CONNECTIONS_PER_PEER);
    // NOTE: this always resolves to `Some(..)` — the body cap cannot be
    // disabled through `ServeOptions`.
    let max_request_body_bytes = options
        .max_request_body_bytes
        .or(Some(DEFAULT_MAX_REQUEST_BODY_BYTES));
    let max_total_connections = options.max_total_connections;
    let drain_timeout =
        Duration::from_millis(options.drain_timeout_ms.unwrap_or(DEFAULT_DRAIN_TIMEOUT_MS));
    let load_shed_enabled = options.load_shed.unwrap_or(true);
    let max_header_size = endpoint.max_header_size();
    #[cfg(feature = "compression")]
    let compression = endpoint.compression().cloned();
    let own_node_id = Arc::new(endpoint.node_id().to_string());
    let on_request = Arc::new(on_request) as Arc<dyn Fn(RequestPayload) + Send + Sync>;

    // Per-peer live-connection counts, shared with every PeerConnectionGuard.
    let peer_counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>> =
        Arc::new(Mutex::new(HashMap::new()));
    let conn_event_fn: Option<ConnectionEventFn> = on_connection_event;

    // In-flight request counter + notifier used by the drain phase below.
    let in_flight: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
    let drain_notify: Arc<tokio::sync::Notify> = Arc::new(tokio::sync::Notify::new());

    // Template service; `remote_node_id` is stamped in per accepted connection.
    let base_svc = RequestService {
        on_request,
        endpoint: endpoint.clone(),
        own_node_id,
        remote_node_id: None,
        max_request_body_bytes,
        // An endpoint-level limit of 0 means "no header limit".
        max_header_size: if max_header_size == 0 {
            None
        } else {
            Some(max_header_size)
        },
        #[cfg(feature = "compression")]
        compression,
    };

    use tower::{limit::ConcurrencyLimitLayer, Layer};
    // One concurrency limit shared by every connection's clone of the stack.
    let shared_conc = ConcurrencyLimitLayer::new(max).layer(base_svc);

    let shutdown_notify = Arc::new(tokio::sync::Notify::new());
    let shutdown_listen = shutdown_notify.clone();
    let drain_dur = drain_timeout;
    let total_connections = endpoint.inner.active_connections.clone();
    let total_requests = endpoint.inner.active_requests.clone();
    let (done_tx, done_rx) = tokio::sync::watch::channel(false);
    let endpoint_closed_tx = endpoint.inner.closed_tx.clone();

    let in_flight_drain = in_flight.clone();
    let drain_notify_drain = drain_notify.clone();

    let join = tokio::spawn(async move {
        let ep = endpoint.raw().clone();
        let mut consecutive_errors: usize = 0;

        // Accept loop.
        loop {
            let incoming = tokio::select! {
                biased;
                _ = shutdown_listen.notified() => {
                    tracing::info!("iroh-http: serve loop shutting down");
                    break;
                }
                inc = ep.accept() => match inc {
                    Some(i) => i,
                    None => {
                        tracing::info!("iroh-http: endpoint closed (accept returned None)");
                        let _ = endpoint_closed_tx.send(true);
                        break;
                    }
                }
            };

            let conn = match incoming.await {
                Ok(c) => {
                    // Any successful accept resets the error streak.
                    consecutive_errors = 0;
                    c
                }
                Err(e) => {
                    consecutive_errors = consecutive_errors.saturating_add(1);
                    tracing::warn!(
                        "iroh-http: accept error ({consecutive_errors}/{max_errors}): {e}"
                    );
                    if consecutive_errors >= max_errors {
                        tracing::error!("iroh-http: too many accept errors — shutting down");
                        break;
                    }
                    continue;
                }
            };

            let remote_pk = conn.remote_id();

            // Global connection cap (best-effort check; load is Relaxed).
            if let Some(max_total) = max_total_connections {
                let current = total_connections.load(Ordering::Relaxed);
                if current >= max_total {
                    tracing::warn!(
                        "iroh-http: total connection limit reached ({current}/{max_total})"
                    );
                    conn.close(0u32.into(), b"server at capacity");
                    continue;
                }
            }

            let remote_id = base32_encode(remote_pk.as_bytes());

            // Per-peer connection cap; also emits connect/disconnect events.
            let guard = match PeerConnectionGuard::acquire(
                &peer_counts,
                remote_pk,
                remote_id.clone(),
                max_conns_per_peer,
                conn_event_fn.clone(),
            ) {
                Some(g) => g,
                None => {
                    tracing::warn!("iroh-http: peer {remote_id} exceeded connection limit");
                    conn.close(0u32.into(), b"too many connections");
                    continue;
                }
            };

            // Clone the shared limited stack and stamp in the peer id.
            let mut conn_conc = shared_conc.clone();
            conn_conc.get_mut().remote_node_id = Some(remote_id);

            // A configured timeout of 0 means "no timeout".
            let timeout_dur = if request_timeout.is_zero() {
                Duration::MAX
            } else {
                request_timeout
            };

            let conn_total = total_connections.clone();
            let conn_requests = total_requests.clone();
            let in_flight_conn = in_flight.clone();
            let drain_notify_conn = drain_notify.clone();
            conn_total.fetch_add(1, Ordering::Relaxed);
            // Per-connection task: accepts bidi streams, serving one HTTP/1
            // connection over each.
            tokio::spawn(async move {
                let _guard = guard;
                // Decrements the global connection counter when this task ends.
                struct TotalGuard(Arc<AtomicUsize>);
                impl Drop for TotalGuard {
                    fn drop(&mut self) {
                        self.0.fetch_sub(1, Ordering::Relaxed);
                    }
                }
                let _total_guard = TotalGuard(conn_total);

                loop {
                    let (send, recv) = match conn.accept_bi().await {
                        Ok(pair) => pair,
                        Err(_) => break,
                    };

                    let io = TokioIo::new(IrohStream::new(send, recv));
                    let svc = conn_conc.clone();
                    let req_counter = conn_requests.clone();
                    req_counter.fetch_add(1, Ordering::Relaxed);
                    in_flight_conn.fetch_add(1, Ordering::Relaxed);

                    let in_flight_req = in_flight_conn.clone();
                    let drain_notify_req = drain_notify_conn.clone();

                    // Per-stream task: runs the tower stack under hyper.
                    tokio::spawn(async move {
                        // Decrements request counters and wakes the drain loop
                        // when the last in-flight request finishes.
                        struct ReqGuard {
                            counter: Arc<AtomicUsize>,
                            in_flight: Arc<AtomicUsize>,
                            drain_notify: Arc<tokio::sync::Notify>,
                        }
                        impl Drop for ReqGuard {
                            fn drop(&mut self) {
                                self.counter.fetch_sub(1, Ordering::Relaxed);
                                // fetch_sub returns the previous value, so 1
                                // means this was the last in-flight request.
                                if self.in_flight.fetch_sub(1, Ordering::AcqRel) == 1 {
                                    self.drain_notify.notify_waiters();
                                }
                            }
                        }
                        let _req_guard = ReqGuard {
                            counter: req_counter,
                            in_flight: in_flight_req,
                            drain_notify: drain_notify_req,
                        };
                        // hyper's read buffer also bounds header size: fall
                        // back to 64 KiB when unlimited, floor at 8 KiB.
                        let effective_header_limit = if max_header_size == 0 {
                            64 * 1024
                        } else {
                            max_header_size.max(8192)
                        };

                        use tower::{timeout::TimeoutLayer, ServiceBuilder};

                        // The four stacks below are the cross product of
                        // {compression on/off} x {load-shed on/off}; each wraps
                        // the service in TowerErrorHandler so middleware errors
                        // become HTTP responses.
                        #[cfg(feature = "compression")]
                        let result = {
                            use http::{Extensions, HeaderMap, Version};
                            use tower_http::compression::{
                                predicate::{Predicate, SizeAbove},
                                CompressionLayer,
                            };

                            let compression_config = svc.get_ref().compression.clone();
                            if let Some(comp) = &compression_config {
                                let min_bytes = comp.min_body_bytes;
                                let mut layer = CompressionLayer::new().zstd(true);
                                if let Some(level) = comp.level {
                                    use tower_http::compression::CompressionLevel;
                                    layer = layer.quality(CompressionLevel::Precise(level as i32));
                                }
                                // Skip responses that are already encoded.
                                let not_pre_compressed =
                                    |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
                                        !h.contains_key(http::header::CONTENT_ENCODING)
                                    };
                                // Honour `Cache-Control: no-transform`.
                                let not_no_transform =
                                    |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
                                        h.get(http::header::CACHE_CONTROL)
                                            .and_then(|v| v.to_str().ok())
                                            .map(|v| {
                                                !v.split(',').any(|d| {
                                                    d.trim().eq_ignore_ascii_case("no-transform")
                                                })
                                            })
                                            .unwrap_or(true)
                                    };
                                let predicate =
                                    SizeAbove::new(min_bytes.min(u16::MAX as usize) as u16)
                                        .and(not_pre_compressed)
                                        .and(not_no_transform);
                                if load_shed_enabled {
                                    use tower::load_shed::LoadShedLayer;
                                    let stk = TowerErrorHandler(
                                        ServiceBuilder::new()
                                            .layer(LoadShedLayer::new())
                                            .layer(TimeoutLayer::new(timeout_dur))
                                            .service(svc),
                                    );
                                    hyper::server::conn::http1::Builder::new()
                                        .max_buf_size(effective_header_limit)
                                        .max_headers(128)
                                        .serve_connection(
                                            io,
                                            TowerToHyperService::new(
                                                ServiceBuilder::new()
                                                    .layer(layer.compress_when(predicate))
                                                    .service(stk),
                                            ),
                                        )
                                        .with_upgrades()
                                        .await
                                } else {
                                    let stk = TowerErrorHandler(
                                        ServiceBuilder::new()
                                            .layer(TimeoutLayer::new(timeout_dur))
                                            .service(svc),
                                    );
                                    hyper::server::conn::http1::Builder::new()
                                        .max_buf_size(effective_header_limit)
                                        .max_headers(128)
                                        .serve_connection(
                                            io,
                                            TowerToHyperService::new(
                                                ServiceBuilder::new()
                                                    .layer(layer.compress_when(predicate))
                                                    .service(stk),
                                            ),
                                        )
                                        .with_upgrades()
                                        .await
                                }
                            } else if load_shed_enabled {
                                use tower::load_shed::LoadShedLayer;
                                let stk = TowerErrorHandler(
                                    ServiceBuilder::new()
                                        .layer(LoadShedLayer::new())
                                        .layer(TimeoutLayer::new(timeout_dur))
                                        .service(svc),
                                );
                                hyper::server::conn::http1::Builder::new()
                                    .max_buf_size(effective_header_limit)
                                    .max_headers(128)
                                    .serve_connection(io, TowerToHyperService::new(stk))
                                    .with_upgrades()
                                    .await
                            } else {
                                let stk = TowerErrorHandler(
                                    ServiceBuilder::new()
                                        .layer(TimeoutLayer::new(timeout_dur))
                                        .service(svc),
                                );
                                hyper::server::conn::http1::Builder::new()
                                    .max_buf_size(effective_header_limit)
                                    .max_headers(128)
                                    .serve_connection(io, TowerToHyperService::new(stk))
                                    .with_upgrades()
                                    .await
                            }
                        };
                        #[cfg(not(feature = "compression"))]
                        let result = if load_shed_enabled {
                            use tower::load_shed::LoadShedLayer;
                            let stk = TowerErrorHandler(
                                ServiceBuilder::new()
                                    .layer(LoadShedLayer::new())
                                    .layer(TimeoutLayer::new(timeout_dur))
                                    .service(svc),
                            );
                            hyper::server::conn::http1::Builder::new()
                                .max_buf_size(effective_header_limit)
                                .max_headers(128)
                                .serve_connection(io, TowerToHyperService::new(stk))
                                .with_upgrades()
                                .await
                        } else {
                            let stk = TowerErrorHandler(
                                ServiceBuilder::new()
                                    .layer(TimeoutLayer::new(timeout_dur))
                                    .service(svc),
                            );
                            hyper::server::conn::http1::Builder::new()
                                .max_buf_size(effective_header_limit)
                                .max_headers(128)
                                .serve_connection(io, TowerToHyperService::new(stk))
                                .with_upgrades()
                                .await
                        };

                        if let Err(e) = result {
                            tracing::debug!("iroh-http: http1 connection error: {e}");
                        }
                    });
                }
            });
        }

        // Drain phase: wait (up to drain_dur) for in-flight requests to end.
        let deadline = tokio::time::Instant::now()
            .checked_add(drain_dur)
            .expect("drain duration overflow");
        loop {
            if in_flight_drain.load(Ordering::Acquire) == 0 {
                tracing::info!("iroh-http: all in-flight requests drained");
                break;
            }
            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
            if remaining.is_zero() {
                tracing::warn!("iroh-http: drain timed out after {}s", drain_dur.as_secs());
                break;
            }
            // Wake on either the "last request finished" notification or the
            // remaining time elapsing, then re-check the counter.
            tokio::select! {
                _ = drain_notify_drain.notified() => {}
                _ = tokio::time::sleep(remaining) => {}
            }
        }
        let _ = done_tx.send(true);
    });

    ServeHandle {
        join,
        shutdown_notify,
        drain_timeout: drain_dur,
        done_rx,
    }
}
1001
/// Wrapper service that converts tower middleware errors (timeout,
/// load-shed) into HTTP error responses instead of surfacing them as
/// transport-level failures.
#[derive(Clone)]
struct TowerErrorHandler<S>(S);
1018
1019impl<S, Req> Service<Req> for TowerErrorHandler<S>
1020where
1021 S: Service<Req, Response = hyper::Response<BoxBody>>,
1022 S::Error: Into<BoxError>,
1023 S::Future: Send + 'static,
1024{
1025 type Response = hyper::Response<BoxBody>;
1026 type Error = BoxError;
1027 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1028
1029 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1030 self.0.poll_ready(cx).map_err(Into::into)
1035 }
1036
1037 fn call(&mut self, req: Req) -> Self::Future {
1038 let fut = self.0.call(req);
1039 Box::pin(async move {
1040 match fut.await {
1041 Ok(r) => Ok(r),
1042 Err(e) => {
1043 let e = e.into();
1044 let status = if e.is::<tower::timeout::error::Elapsed>() {
1045 StatusCode::REQUEST_TIMEOUT
1046 } else if e.is::<tower::load_shed::error::Overloaded>() {
1047 StatusCode::SERVICE_UNAVAILABLE
1048 } else {
1049 tracing::warn!("iroh-http: unexpected tower error: {e}");
1050 StatusCode::INTERNAL_SERVER_ERROR
1051 };
1052 let body_bytes: &'static [u8] = match status {
1053 StatusCode::REQUEST_TIMEOUT => b"request timed out",
1054 StatusCode::SERVICE_UNAVAILABLE => b"server at capacity",
1055 _ => b"internal server error",
1056 };
1057 Ok(hyper::Response::builder()
1058 .status(status)
1059 .body(crate::box_body(http_body_util::Full::new(
1060 Bytes::from_static(body_bytes),
1061 )))
1062 .expect("valid error response"))
1063 }
1064 }
1065 })
1066 }
1067}