1use std::{
8 collections::HashMap,
9 future::Future,
10 pin::Pin,
11 sync::{
12 atomic::{AtomicUsize, Ordering},
13 Arc, Mutex,
14 },
15 task::{Context, Poll},
16 time::Duration,
17};
18
19use bytes::Bytes;
20use http::{HeaderName, HeaderValue, StatusCode};
21use hyper::body::Incoming;
22use hyper_util::rt::TokioIo;
23use hyper_util::service::TowerToHyperService;
24use tower::Service;
25
26use crate::{
27 base32_encode,
28 client::{body_from_reader, pump_hyper_body_to_channel_limited},
29 io::IrohStream,
30 stream::{HandleStore, ResponseHeadEntry},
31 ConnectionEvent, CoreError, IrohEndpoint, RequestPayload,
32};
33
/// Response body type produced by every service in this module (alias of the crate-level `BoxBody`).
type BoxBody = crate::BoxBody;
/// Type-erased, thread-safe error used as the service error type throughout the tower stack.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
38
/// Tunable limits for [`serve`] / [`serve_with_events`].
///
/// Every field is optional; `None` selects the default behaviour described on
/// the individual field.
#[derive(Debug, Clone, Default)]
pub struct ServerLimits {
    /// Limit on concurrently executing requests, enforced by a tower
    /// concurrency-limit layer in the request service stack.
    /// `None` uses `DEFAULT_CONCURRENCY` (64).
    pub max_concurrency: Option<usize>,
    /// Number of consecutive failed incoming-connection handshakes after which
    /// the accept loop stops. `None` uses 5.
    pub max_consecutive_errors: Option<usize>,
    /// Timeout, in milliseconds, for each request's service call (i.e. until the
    /// handler has produced a response head). Expiry yields `408`.
    /// `None` uses `DEFAULT_REQUEST_TIMEOUT_MS` (60 s); `Some(0)` disables it.
    pub request_timeout_ms: Option<u64>,
    /// Maximum simultaneous connections from one peer; further connections are
    /// closed with reason "too many connections".
    /// `None` uses `DEFAULT_MAX_CONNECTIONS_PER_PEER` (8).
    pub max_connections_per_peer: Option<usize>,
    /// Maximum request body size in bytes for non-duplex requests; exceeding it
    /// yields `413`. `None` means no limit.
    pub max_request_body_bytes: Option<usize>,
    /// How long, in seconds, the server waits for in-flight requests after the
    /// accept loop stops. `None` uses `DEFAULT_DRAIN_TIMEOUT_SECS` (30).
    pub drain_timeout_secs: Option<u64>,
    /// Cap on the endpoint's total active connections; new connections beyond it
    /// are closed with reason "server at capacity". `None` means no cap.
    pub max_total_connections: Option<usize>,
    /// When enabled, requests arriving while the concurrency limit is saturated
    /// are rejected with `503` instead of waiting. `None` means enabled.
    pub load_shed: Option<bool>,
}

/// Alternative name for [`ServerLimits`] used by the `serve*` entry points.
pub type ServeOptions = ServerLimits;

/// Default for [`ServerLimits::max_concurrency`].
const DEFAULT_CONCURRENCY: usize = 64;
/// Default for [`ServerLimits::request_timeout_ms`] (60 seconds).
const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 60_000;
/// Default for [`ServerLimits::max_connections_per_peer`].
const DEFAULT_MAX_CONNECTIONS_PER_PEER: usize = 8;
/// Default for [`ServerLimits::drain_timeout_secs`].
const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 30;
70
71pub struct ServeHandle {
74 join: tokio::task::JoinHandle<()>,
75 shutdown_notify: Arc<tokio::sync::Notify>,
76 drain_timeout: std::time::Duration,
77 done_rx: tokio::sync::watch::Receiver<bool>,
79}
80
81impl ServeHandle {
82 pub fn shutdown(&self) {
83 self.shutdown_notify.notify_one();
84 }
85 pub async fn drain(self) {
86 self.shutdown();
87 let _ = self.join.await;
88 }
89 pub fn abort(&self) {
90 self.join.abort();
91 }
92 pub fn drain_timeout(&self) -> std::time::Duration {
93 self.drain_timeout
94 }
95 pub fn subscribe_done(&self) -> tokio::sync::watch::Receiver<bool> {
100 self.done_rx.clone()
101 }
102}
103
104pub fn respond(
107 handles: &HandleStore,
108 req_handle: u64,
109 status: u16,
110 headers: Vec<(String, String)>,
111) -> Result<(), CoreError> {
112 StatusCode::from_u16(status)
113 .map_err(|_| CoreError::invalid_input(format!("invalid HTTP status code: {status}")))?;
114 for (name, value) in &headers {
115 HeaderName::from_bytes(name.as_bytes()).map_err(|_| {
116 CoreError::invalid_input(format!("invalid response header name {:?}", name))
117 })?;
118 HeaderValue::from_str(value).map_err(|_| {
119 CoreError::invalid_input(format!("invalid response header value for {:?}", name))
120 })?;
121 }
122
123 let sender = handles
124 .take_req_sender(req_handle)
125 .ok_or_else(|| CoreError::invalid_handle(req_handle))?;
126 sender
127 .send(ResponseHeadEntry { status, headers })
128 .map_err(|_| CoreError::internal("serve task dropped before respond"))
129}
130
131type ConnectionEventFn = Arc<dyn Fn(ConnectionEvent) + Send + Sync>;
134
135struct PeerConnectionGuard {
136 counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
137 peer: iroh::PublicKey,
138 peer_id_str: String,
139 on_event: Option<ConnectionEventFn>,
140}
141
142impl PeerConnectionGuard {
143 fn acquire(
144 counts: &Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
145 peer: iroh::PublicKey,
146 peer_id_str: String,
147 max: usize,
148 on_event: Option<ConnectionEventFn>,
149 ) -> Option<Self> {
150 let mut map = counts.lock().unwrap_or_else(|e| e.into_inner());
151 let count = map.entry(peer).or_insert(0);
152 if *count >= max {
153 return None;
154 }
155 let was_zero = *count == 0;
156 *count = count.saturating_add(1);
157 let guard = PeerConnectionGuard {
158 counts: counts.clone(),
159 peer,
160 peer_id_str: peer_id_str.clone(),
161 on_event: on_event.clone(),
162 };
163 if was_zero {
165 if let Some(cb) = &on_event {
166 cb(ConnectionEvent {
167 peer_id: peer_id_str,
168 connected: true,
169 });
170 }
171 }
172 Some(guard)
173 }
174}
175
176impl Drop for PeerConnectionGuard {
177 fn drop(&mut self) {
178 let mut map = self.counts.lock().unwrap_or_else(|e| e.into_inner());
179 if let Some(c) = map.get_mut(&self.peer) {
180 *c = c.saturating_sub(1);
181 if *c == 0 {
182 map.remove(&self.peer);
183 if let Some(cb) = &self.on_event {
185 cb(ConnectionEvent {
186 peer_id: self.peer_id_str.clone(),
187 connected: false,
188 });
189 }
190 }
191 }
192 }
193}
194
195#[derive(Clone)]
198struct RequestService {
199 on_request: Arc<dyn Fn(RequestPayload) + Send + Sync>,
200 endpoint: IrohEndpoint,
201 own_node_id: Arc<String>,
202 remote_node_id: Option<String>,
203 max_request_body_bytes: Option<usize>,
204 max_header_size: Option<usize>,
205 #[cfg(feature = "compression")]
206 compression: Option<crate::endpoint::CompressionOptions>,
207}
208
209impl Service<hyper::Request<Incoming>> for RequestService {
210 type Response = hyper::Response<BoxBody>;
211 type Error = BoxError;
212 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
213
214 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
215 Poll::Ready(Ok(()))
216 }
217
218 fn call(&mut self, req: hyper::Request<Incoming>) -> Self::Future {
219 let svc = self.clone();
220 Box::pin(async move { svc.handle(req).await })
221 }
222}
223
224impl RequestService {
225 async fn handle(
226 self,
227 mut req: hyper::Request<Incoming>,
228 ) -> Result<hyper::Response<BoxBody>, BoxError> {
229 let handles = self.endpoint.handles();
230 let own_node_id = &*self.own_node_id;
231 let remote_node_id = self.remote_node_id.clone().unwrap_or_default();
232 let max_request_body_bytes = self.max_request_body_bytes;
233 let max_header_size = self.max_header_size;
234
235 let method = req.method().to_string();
236 let path_and_query = req
237 .uri()
238 .path_and_query()
239 .map(|p| p.as_str())
240 .unwrap_or("/")
241 .to_string();
242
243 tracing::debug!(
244 method = %method,
245 path = %path_and_query,
246 peer = %remote_node_id,
247 "iroh-http: incoming request",
248 );
249 if let Some(limit) = max_header_size {
258 let header_bytes: usize = req
259 .headers()
260 .iter()
261 .filter(|(k, _)| !k.as_str().eq_ignore_ascii_case("peer-id"))
262 .map(|(k, v)| {
263 k.as_str()
264 .len()
265 .saturating_add(v.as_bytes().len())
266 .saturating_add(4)
267 }) .fold(0usize, |acc, x| acc.saturating_add(x))
269 .saturating_add("peer-id".len())
270 .saturating_add(remote_node_id.len())
271 .saturating_add(4)
272 .saturating_add(req.uri().to_string().len())
273 .saturating_add(method.len())
274 .saturating_add(12); if header_bytes > limit {
276 let resp = hyper::Response::builder()
277 .status(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE)
278 .body(crate::box_body(http_body_util::Empty::new()))
279 .expect("static response args are valid");
280 return Ok(resp);
281 }
282 }
283
284 let mut req_headers: Vec<(String, String)> = Vec::new();
286 for (k, v) in req.headers().iter() {
287 if k.as_str().eq_ignore_ascii_case("peer-id") {
288 continue;
289 }
290 match v.to_str() {
291 Ok(s) => req_headers.push((k.as_str().to_string(), s.to_string())),
292 Err(_) => {
293 let resp = hyper::Response::builder()
294 .status(StatusCode::BAD_REQUEST)
295 .body(crate::box_body(http_body_util::Full::new(
296 Bytes::from_static(b"non-UTF8 header value"),
297 )))
298 .expect("static response args are valid");
299 return Ok(resp);
300 }
301 }
302 }
303 req_headers.push(("peer-id".to_string(), remote_node_id.clone()));
304
305 let url = format!("httpi://{own_node_id}{path_and_query}");
306
307 let has_upgrade_header = req_headers.iter().any(|(k, v)| {
310 k.eq_ignore_ascii_case("upgrade") && v.eq_ignore_ascii_case("iroh-duplex")
311 });
312 let has_connection_upgrade = req_headers.iter().any(|(k, v)| {
313 k.eq_ignore_ascii_case("connection")
314 && v.split(',')
315 .any(|tok| tok.trim().eq_ignore_ascii_case("upgrade"))
316 });
317 let is_connect = req.method() == http::Method::CONNECT;
318
319 let is_bidi = if has_upgrade_header {
320 if !has_connection_upgrade || !is_connect {
321 let resp = hyper::Response::builder()
322 .status(StatusCode::BAD_REQUEST)
323 .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
324 b"duplex upgrade requires CONNECT method with Connection: upgrade header",
325 ))))
326 .expect("static response args are valid");
327 return Ok(resp);
328 }
329 true
330 } else {
331 false
332 };
333
334 let upgrade_future = if is_bidi {
336 Some(hyper::upgrade::on(&mut req))
337 } else {
338 None
339 };
340
341 let mut guard = handles.insert_guard();
345 let (req_body_writer, req_body_reader) = handles.make_body_channel();
346 let req_body_handle = guard
347 .insert_reader(req_body_reader)
348 .map_err(|e| -> BoxError { e.into() })?;
349
350 let (res_body_writer, res_body_reader) = handles.make_body_channel();
352 let res_body_handle = guard
353 .insert_writer(res_body_writer)
354 .map_err(|e| -> BoxError { e.into() })?;
355
356 let (req_trailers_handle, res_trailers_handle, req_trailer_tx, opt_res_trailer_rx) =
359 if !is_bidi {
360 let (rq_tx, rq_rx) = tokio::sync::oneshot::channel::<Vec<(String, String)>>();
362 let rq_h = guard
363 .insert_trailer_receiver(rq_rx)
364 .map_err(|e| -> BoxError { e.into() })?;
365 let (rs_tx, rs_rx) = tokio::sync::oneshot::channel::<Vec<(String, String)>>();
367 let rs_h = guard
368 .insert_trailer_sender(rs_tx)
369 .map_err(|e| -> BoxError { e.into() })?;
370 (rq_h, rs_h, Some(rq_tx), Some(rs_rx))
371 } else {
372 (0u64, 0u64, None, None)
373 };
374
375 let (head_tx, head_rx) = tokio::sync::oneshot::channel::<ResponseHeadEntry>();
378 let req_handle = guard
379 .allocate_req_handle(head_tx)
380 .map_err(|e| -> BoxError { e.into() })?;
381
382 guard.commit();
383
384 struct ReqHeadCleanup {
388 endpoint: IrohEndpoint,
389 req_handle: u64,
390 }
391 impl Drop for ReqHeadCleanup {
392 fn drop(&mut self) {
393 self.endpoint.handles().take_req_sender(self.req_handle);
394 }
395 }
396 let _req_head_cleanup = ReqHeadCleanup {
397 endpoint: self.endpoint.clone(),
398 req_handle,
399 };
400
401 let (body_overflow_tx, body_overflow_rx) = if !is_bidi && max_request_body_bytes.is_some() {
407 let (tx, rx) = tokio::sync::oneshot::channel::<()>();
408 (Some(tx), Some(rx))
409 } else {
410 (None, None)
411 };
412
413 let duplex_req_body_writer = if !is_bidi {
414 let body = req.into_body();
415 let trailer_tx = req_trailer_tx.expect("non-duplex has req_trailer_tx");
416 let frame_timeout = handles.drain_timeout();
417 tokio::spawn(pump_hyper_body_to_channel_limited(
418 body,
419 req_body_writer,
420 trailer_tx,
421 max_request_body_bytes,
422 frame_timeout,
423 body_overflow_tx,
424 ));
425 None
426 } else {
427 drop(req.into_body());
429 Some(req_body_writer)
430 };
431
432 on_request_fire(
435 &self.on_request,
436 req_handle,
437 req_body_handle,
438 res_body_handle,
439 req_trailers_handle,
440 res_trailers_handle,
441 method,
442 url,
443 req_headers,
444 remote_node_id,
445 is_bidi,
446 );
447
448 let response_head = if let Some(overflow_rx) = body_overflow_rx {
454 tokio::select! {
455 biased;
456 _ = overflow_rx => {
457 let resp = hyper::Response::builder()
460 .status(StatusCode::PAYLOAD_TOO_LARGE)
461 .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
462 b"request body too large",
463 ))))
464 .expect("valid 413 response");
465 return Ok(resp);
466 }
467 head = head_rx => {
468 head.map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
469 }
470 }
471 } else {
472 head_rx
473 .await
474 .map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
475 };
476
477 if let Some(upgrade_fut) = upgrade_future {
484 let req_body_writer =
485 duplex_req_body_writer.expect("duplex path always has req_body_writer");
486
487 if response_head.status != StatusCode::SWITCHING_PROTOCOLS.as_u16() {
490 drop(upgrade_fut);
491 drop(req_body_writer);
492 let mut resp_builder = hyper::Response::builder().status(response_head.status);
493 for (k, v) in &response_head.headers {
494 resp_builder = resp_builder.header(k.as_str(), v.as_str());
495 }
496 let resp = resp_builder
497 .body(crate::box_body(http_body_util::Empty::new()))
498 .map_err(|e| -> BoxError { e.into() })?;
499 return Ok(resp);
500 }
501
502 tokio::spawn(async move {
508 match upgrade_fut.await {
509 Err(e) => tracing::warn!("iroh-http: duplex upgrade error: {e}"),
510 Ok(upgraded) => {
511 let io = TokioIo::new(upgraded);
512 crate::stream::pump_duplex(io, req_body_writer, res_body_reader).await;
513 }
514 }
515 });
516
517 let resp = hyper::Response::builder()
519 .status(StatusCode::SWITCHING_PROTOCOLS)
520 .header(hyper::header::CONNECTION, "Upgrade")
521 .header(hyper::header::UPGRADE, "iroh-duplex")
522 .body(crate::box_body(http_body_util::Empty::new()))
523 .expect("static response args are valid");
524 return Ok(resp);
525 }
526
527 let has_trailer_hdr = response_head
530 .headers
531 .iter()
532 .any(|(k, _)| k.eq_ignore_ascii_case("trailer"));
533 let trailer_rx_for_body = if has_trailer_hdr {
534 opt_res_trailer_rx
535 } else {
536 handles.remove_trailer_sender(res_trailers_handle);
537 None
538 };
539
540 let body_stream = body_from_reader(res_body_reader, trailer_rx_for_body);
541
542 let mut resp_builder = hyper::Response::builder().status(response_head.status);
543 for (k, v) in &response_head.headers {
544 resp_builder = resp_builder.header(k.as_str(), v.as_str());
545 }
546
547 #[cfg(feature = "compression")]
548 let resp_builder = resp_builder; let resp = resp_builder
551 .body(crate::box_body(body_stream))
552 .map_err(|e| -> BoxError { e.into() })?;
553
554 Ok(resp)
555 }
556}
557
558#[inline]
559#[allow(clippy::too_many_arguments)]
560fn on_request_fire(
561 cb: &Arc<dyn Fn(RequestPayload) + Send + Sync>,
562 req_handle: u64,
563 req_body_handle: u64,
564 res_body_handle: u64,
565 req_trailers_handle: u64,
566 res_trailers_handle: u64,
567 method: String,
568 url: String,
569 headers: Vec<(String, String)>,
570 remote_node_id: String,
571 is_bidi: bool,
572) {
573 cb(RequestPayload {
574 req_handle,
575 req_body_handle,
576 res_body_handle,
577 req_trailers_handle,
578 res_trailers_handle,
579 method,
580 url,
581 headers,
582 remote_node_id,
583 is_bidi,
584 });
585}
586
587pub fn serve<F>(endpoint: IrohEndpoint, options: ServeOptions, on_request: F) -> ServeHandle
594where
595 F: Fn(RequestPayload) + Send + Sync + 'static,
596{
597 serve_with_events(endpoint, options, on_request, None)
598}
599
600pub fn serve_with_events<F>(
605 endpoint: IrohEndpoint,
606 options: ServeOptions,
607 on_request: F,
608 on_connection_event: Option<ConnectionEventFn>,
609) -> ServeHandle
610where
611 F: Fn(RequestPayload) + Send + Sync + 'static,
612{
613 let max = options.max_concurrency.unwrap_or(DEFAULT_CONCURRENCY);
614 let max_errors = options.max_consecutive_errors.unwrap_or(5);
615 let request_timeout = options
616 .request_timeout_ms
617 .map(Duration::from_millis)
618 .unwrap_or(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS));
619 let max_conns_per_peer = options
620 .max_connections_per_peer
621 .unwrap_or(DEFAULT_MAX_CONNECTIONS_PER_PEER);
622 let max_request_body_bytes = options.max_request_body_bytes;
623 let max_total_connections = options.max_total_connections;
624 let drain_timeout = Duration::from_secs(
625 options
626 .drain_timeout_secs
627 .unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS),
628 );
629 let load_shed_enabled = options.load_shed.unwrap_or(true);
631 let max_header_size = endpoint.max_header_size();
632 #[cfg(feature = "compression")]
633 let compression = endpoint.compression().cloned();
634 let own_node_id = Arc::new(endpoint.node_id().to_string());
635 let on_request = Arc::new(on_request) as Arc<dyn Fn(RequestPayload) + Send + Sync>;
636
637 let peer_counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>> =
638 Arc::new(Mutex::new(HashMap::new()));
639 let conn_event_fn: Option<ConnectionEventFn> = on_connection_event;
640
641 let in_flight: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
644 let drain_notify: Arc<tokio::sync::Notify> = Arc::new(tokio::sync::Notify::new());
645
646 let base_svc = RequestService {
647 on_request,
648 endpoint: endpoint.clone(),
649 own_node_id,
650 remote_node_id: None,
651 max_request_body_bytes,
652 max_header_size: if max_header_size == 0 {
653 None
654 } else {
655 Some(max_header_size)
656 },
657 #[cfg(feature = "compression")]
658 compression,
659 };
660
661 let shutdown_notify = Arc::new(tokio::sync::Notify::new());
662 let shutdown_listen = shutdown_notify.clone();
663 let drain_dur = drain_timeout;
664 let total_connections = endpoint.inner.active_connections.clone();
667 let total_requests = endpoint.inner.active_requests.clone();
668 let (done_tx, done_rx) = tokio::sync::watch::channel(false);
669 let endpoint_closed_tx = endpoint.inner.closed_tx.clone();
670
671 let in_flight_drain = in_flight.clone();
672 let drain_notify_drain = drain_notify.clone();
673
674 let join = tokio::spawn(async move {
675 let ep = endpoint.raw().clone();
676 let mut consecutive_errors: usize = 0;
677
678 loop {
679 let incoming = tokio::select! {
680 biased;
681 _ = shutdown_listen.notified() => {
682 tracing::info!("iroh-http: serve loop shutting down");
683 break;
684 }
685 inc = ep.accept() => match inc {
686 Some(i) => i,
687 None => {
688 tracing::info!("iroh-http: endpoint closed (accept returned None)");
689 let _ = endpoint_closed_tx.send(true);
690 break;
691 }
692 }
693 };
694
695 let conn = match incoming.await {
696 Ok(c) => {
697 consecutive_errors = 0;
698 c
699 }
700 Err(e) => {
701 consecutive_errors = consecutive_errors.saturating_add(1);
702 tracing::warn!(
703 "iroh-http: accept error ({consecutive_errors}/{max_errors}): {e}"
704 );
705 if consecutive_errors >= max_errors {
706 tracing::error!("iroh-http: too many accept errors — shutting down");
707 break;
708 }
709 continue;
710 }
711 };
712
713 let remote_pk = conn.remote_id();
714
715 if let Some(max_total) = max_total_connections {
717 let current = total_connections.load(Ordering::Relaxed);
718 if current >= max_total {
719 tracing::warn!(
720 "iroh-http: total connection limit reached ({current}/{max_total})"
721 );
722 conn.close(0u32.into(), b"server at capacity");
723 continue;
724 }
725 }
726
727 let remote_id = base32_encode(remote_pk.as_bytes());
728
729 let guard = match PeerConnectionGuard::acquire(
730 &peer_counts,
731 remote_pk,
732 remote_id.clone(),
733 max_conns_per_peer,
734 conn_event_fn.clone(),
735 ) {
736 Some(g) => g,
737 None => {
738 tracing::warn!("iroh-http: peer {remote_id} exceeded connection limit");
739 conn.close(0u32.into(), b"too many connections");
740 continue;
741 }
742 };
743
744 let mut peer_svc = base_svc.clone();
745 peer_svc.remote_node_id = Some(remote_id);
746
747 let timeout_dur = if request_timeout.is_zero() {
748 Duration::MAX
749 } else {
750 request_timeout
751 };
752
753 let conn_total = total_connections.clone();
754 let conn_requests = total_requests.clone();
755 let in_flight_conn = in_flight.clone();
756 let drain_notify_conn = drain_notify.clone();
757 conn_total.fetch_add(1, Ordering::Relaxed);
758 tokio::spawn(async move {
759 let _guard = guard;
760 struct TotalGuard(Arc<AtomicUsize>);
762 impl Drop for TotalGuard {
763 fn drop(&mut self) {
764 self.0.fetch_sub(1, Ordering::Relaxed);
765 }
766 }
767 let _total_guard = TotalGuard(conn_total);
768
769 loop {
770 let (send, recv) = match conn.accept_bi().await {
771 Ok(pair) => pair,
772 Err(_) => break,
773 };
774
775 let io = TokioIo::new(IrohStream::new(send, recv));
776 let svc = peer_svc.clone();
777 let req_counter = conn_requests.clone();
778 req_counter.fetch_add(1, Ordering::Relaxed);
779 in_flight_conn.fetch_add(1, Ordering::Relaxed);
780
781 let in_flight_req = in_flight_conn.clone();
782 let drain_notify_req = drain_notify_conn.clone();
783
784 tokio::spawn(async move {
785 struct ReqGuard {
787 counter: Arc<AtomicUsize>,
788 in_flight: Arc<AtomicUsize>,
789 drain_notify: Arc<tokio::sync::Notify>,
790 }
791 impl Drop for ReqGuard {
792 fn drop(&mut self) {
793 self.counter.fetch_sub(1, Ordering::Relaxed);
794 if self.in_flight.fetch_sub(1, Ordering::AcqRel) == 1 {
795 self.drain_notify.notify_waiters();
797 }
798 }
799 }
800 let _req_guard = ReqGuard {
801 counter: req_counter,
802 in_flight: in_flight_req,
803 drain_notify: drain_notify_req,
804 };
805 let effective_header_limit = if max_header_size == 0 {
808 64 * 1024
809 } else {
810 max_header_size.max(8192)
811 };
812
813 use tower::{
828 limit::ConcurrencyLimitLayer, timeout::TimeoutLayer, ServiceBuilder,
829 };
830
831 #[cfg(feature = "compression")]
832 let result = {
833 use http::{Extensions, HeaderMap, Version};
834 use tower_http::compression::{
835 predicate::{Predicate, SizeAbove},
836 CompressionLayer,
837 };
838
839 let compression_config = svc.compression.clone();
840 if let Some(comp) = &compression_config {
841 let min_bytes = comp.min_body_bytes;
842 let mut layer = CompressionLayer::new().zstd(true);
843 if let Some(level) = comp.level {
844 use tower_http::compression::CompressionLevel;
845 layer = layer.quality(CompressionLevel::Precise(level as i32));
846 }
847 let not_pre_compressed =
848 |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
849 !h.contains_key(http::header::CONTENT_ENCODING)
850 };
851 let not_no_transform =
852 |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
853 h.get(http::header::CACHE_CONTROL)
854 .and_then(|v| v.to_str().ok())
855 .map(|v| {
856 !v.split(',').any(|d| {
857 d.trim().eq_ignore_ascii_case("no-transform")
858 })
859 })
860 .unwrap_or(true)
861 };
862 let predicate =
863 SizeAbove::new(min_bytes.min(u16::MAX as usize) as u16)
864 .and(not_pre_compressed)
865 .and(not_no_transform);
866 if load_shed_enabled {
867 use tower::load_shed::LoadShedLayer;
868 let stk = TowerErrorHandler(
869 ServiceBuilder::new()
870 .layer(LoadShedLayer::new())
871 .layer(ConcurrencyLimitLayer::new(max))
872 .layer(TimeoutLayer::new(timeout_dur))
873 .service(svc),
874 );
875 hyper::server::conn::http1::Builder::new()
876 .max_buf_size(effective_header_limit)
877 .max_headers(128)
878 .serve_connection(
879 io,
880 TowerToHyperService::new(
881 ServiceBuilder::new()
882 .layer(layer.compress_when(predicate))
883 .service(stk),
884 ),
885 )
886 .with_upgrades()
887 .await
888 } else {
889 let stk = TowerErrorHandler(
890 ServiceBuilder::new()
891 .layer(ConcurrencyLimitLayer::new(max))
892 .layer(TimeoutLayer::new(timeout_dur))
893 .service(svc),
894 );
895 hyper::server::conn::http1::Builder::new()
896 .max_buf_size(effective_header_limit)
897 .max_headers(128)
898 .serve_connection(
899 io,
900 TowerToHyperService::new(
901 ServiceBuilder::new()
902 .layer(layer.compress_when(predicate))
903 .service(stk),
904 ),
905 )
906 .with_upgrades()
907 .await
908 }
909 } else if load_shed_enabled {
910 use tower::load_shed::LoadShedLayer;
911 let stk = TowerErrorHandler(
912 ServiceBuilder::new()
913 .layer(LoadShedLayer::new())
914 .layer(ConcurrencyLimitLayer::new(max))
915 .layer(TimeoutLayer::new(timeout_dur))
916 .service(svc),
917 );
918 hyper::server::conn::http1::Builder::new()
919 .max_buf_size(effective_header_limit)
920 .max_headers(128)
921 .serve_connection(io, TowerToHyperService::new(stk))
922 .with_upgrades()
923 .await
924 } else {
925 let stk = TowerErrorHandler(
926 ServiceBuilder::new()
927 .layer(ConcurrencyLimitLayer::new(max))
928 .layer(TimeoutLayer::new(timeout_dur))
929 .service(svc),
930 );
931 hyper::server::conn::http1::Builder::new()
932 .max_buf_size(effective_header_limit)
933 .max_headers(128)
934 .serve_connection(io, TowerToHyperService::new(stk))
935 .with_upgrades()
936 .await
937 }
938 };
939 #[cfg(not(feature = "compression"))]
940 let result = if load_shed_enabled {
941 use tower::load_shed::LoadShedLayer;
942 let stk = TowerErrorHandler(
943 ServiceBuilder::new()
944 .layer(LoadShedLayer::new())
945 .layer(ConcurrencyLimitLayer::new(max))
946 .layer(TimeoutLayer::new(timeout_dur))
947 .service(svc),
948 );
949 hyper::server::conn::http1::Builder::new()
950 .max_buf_size(effective_header_limit)
951 .max_headers(128)
952 .serve_connection(io, TowerToHyperService::new(stk))
953 .with_upgrades()
954 .await
955 } else {
956 let stk = TowerErrorHandler(
957 ServiceBuilder::new()
958 .layer(ConcurrencyLimitLayer::new(max))
959 .layer(TimeoutLayer::new(timeout_dur))
960 .service(svc),
961 );
962 hyper::server::conn::http1::Builder::new()
963 .max_buf_size(effective_header_limit)
964 .max_headers(128)
965 .serve_connection(io, TowerToHyperService::new(stk))
966 .with_upgrades()
967 .await
968 };
969
970 if let Err(e) = result {
971 tracing::debug!("iroh-http: http1 connection error: {e}");
972 }
973 });
974 }
975 });
976 }
977
978 let deadline = tokio::time::Instant::now()
985 .checked_add(drain_dur)
986 .expect("drain duration overflow");
987 loop {
988 if in_flight_drain.load(Ordering::Acquire) == 0 {
989 tracing::info!("iroh-http: all in-flight requests drained");
990 break;
991 }
992 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
993 if remaining.is_zero() {
994 tracing::warn!("iroh-http: drain timed out after {}s", drain_dur.as_secs());
995 break;
996 }
997 tokio::select! {
998 _ = drain_notify_drain.notified() => {}
999 _ = tokio::time::sleep(remaining) => {}
1000 }
1001 }
1002 let _ = done_tx.send(true);
1003 });
1004
1005 ServeHandle {
1006 join,
1007 shutdown_notify,
1008 drain_timeout: drain_dur,
1009 done_rx,
1010 }
1011}
1012
1013#[derive(Clone)]
1028struct TowerErrorHandler<S>(S);
1029
1030impl<S, Req> Service<Req> for TowerErrorHandler<S>
1031where
1032 S: Service<Req, Response = hyper::Response<BoxBody>>,
1033 S::Error: Into<BoxError>,
1034 S::Future: Send + 'static,
1035{
1036 type Response = hyper::Response<BoxBody>;
1037 type Error = BoxError;
1038 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1039
1040 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1041 self.0.poll_ready(cx).map_err(Into::into)
1046 }
1047
1048 fn call(&mut self, req: Req) -> Self::Future {
1049 let fut = self.0.call(req);
1050 Box::pin(async move {
1051 match fut.await {
1052 Ok(r) => Ok(r),
1053 Err(e) => {
1054 let e = e.into();
1055 let status = if e.is::<tower::timeout::error::Elapsed>() {
1056 StatusCode::REQUEST_TIMEOUT
1057 } else if e.is::<tower::load_shed::error::Overloaded>() {
1058 StatusCode::SERVICE_UNAVAILABLE
1059 } else {
1060 tracing::warn!("iroh-http: unexpected tower error: {e}");
1061 StatusCode::INTERNAL_SERVER_ERROR
1062 };
1063 let body_bytes: &'static [u8] = match status {
1064 StatusCode::REQUEST_TIMEOUT => b"request timed out",
1065 StatusCode::SERVICE_UNAVAILABLE => b"server at capacity",
1066 _ => b"internal server error",
1067 };
1068 Ok(hyper::Response::builder()
1069 .status(status)
1070 .body(crate::box_body(http_body_util::Full::new(
1071 Bytes::from_static(body_bytes),
1072 )))
1073 .expect("valid error response"))
1074 }
1075 }
1076 })
1077 }
1078}