1use std::{
8 collections::HashMap,
9 future::Future,
10 pin::Pin,
11 sync::{
12 atomic::{AtomicUsize, Ordering},
13 Arc, Mutex,
14 },
15 task::{Context, Poll},
16 time::Duration,
17};
18
19use bytes::Bytes;
20use http::{HeaderName, HeaderValue, StatusCode};
21use hyper::body::Incoming;
22use hyper_util::rt::TokioIo;
23use hyper_util::service::TowerToHyperService;
24use tower::Service;
25
26use crate::{
27 base32_encode,
28 client::{body_from_reader, pump_hyper_body_to_channel_limited},
29 io::IrohStream,
30 stream::{HandleStore, ResponseHeadEntry},
31 ConnectionEvent, CoreError, IrohEndpoint, RequestPayload,
32};
33
/// Response body type produced by the request service; alias of the crate-wide
/// [`crate::BoxBody`].
type BoxBody = crate::BoxBody;
/// Type-erased error used as `Service::Error` by the request service and by the
/// tower middleware stack assembled in `serve_with_events`.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
38
/// Tunable limits for the HTTP server started by [`serve`] / [`serve_with_events`].
///
/// Every field is optional; `None` selects the default noted on the field.
#[derive(Debug, Clone, Default)]
pub struct ServerLimits {
    /// Limit handed to the tower concurrency-limit layer placed in front of the
    /// request handler. Defaults to `DEFAULT_CONCURRENCY` (64).
    pub max_concurrency: Option<usize>,
    /// Number of consecutive failed incoming-connection handshakes after which the
    /// accept loop gives up and stops. Defaults to 5.
    pub max_consecutive_errors: Option<usize>,
    /// Per-request timeout in milliseconds, enforced by a tower `TimeoutLayer`.
    /// `Some(0)` disables the timeout. Defaults to `DEFAULT_REQUEST_TIMEOUT_MS` (60 s).
    pub request_timeout_ms: Option<u64>,
    /// Maximum simultaneous connections accepted from a single peer; extra
    /// connections are closed with "too many connections".
    /// Defaults to `DEFAULT_MAX_CONNECTIONS_PER_PEER` (8).
    pub max_connections_per_peer: Option<usize>,
    /// Request-body size limit forwarded to the body pump; when the pump reports an
    /// overflow the handler answers `413 Payload Too Large`.
    /// NOTE(review): `None` presumably means "unlimited" — confirm against
    /// `pump_hyper_body_to_channel_limited`.
    pub max_request_body_bytes: Option<usize>,
    /// Seconds to wait, once the accept loop has stopped, for in-flight requests to
    /// finish before the serve task completes. Defaults to `DEFAULT_DRAIN_TIMEOUT_SECS` (30).
    pub drain_timeout_secs: Option<u64>,
    /// Cap on the endpoint-wide active connection count; connections beyond it are
    /// closed with "server at capacity". `None` disables the cap.
    pub max_total_connections: Option<usize>,
    /// When `true` (the default), a `LoadShedLayer` rejects requests with
    /// `503 Service Unavailable` instead of waiting while the concurrency limit is
    /// saturated.
    pub load_shed: Option<bool>,
}

/// Alias of [`ServerLimits`] accepted by the `serve*` entry points.
pub type ServeOptions = ServerLimits;
65
/// Default for [`ServerLimits::max_concurrency`].
const DEFAULT_CONCURRENCY: usize = 64;
/// Default for [`ServerLimits::request_timeout_ms`] (60 seconds).
const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 60_000;
/// Default for [`ServerLimits::max_connections_per_peer`].
const DEFAULT_MAX_CONNECTIONS_PER_PEER: usize = 8;
/// Default for [`ServerLimits::drain_timeout_secs`].
const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 30;
70
/// Handle to a running serve loop, returned by [`serve`] / [`serve_with_events`].
pub struct ServeHandle {
    /// The spawned accept-then-drain task.
    join: tokio::task::JoinHandle<()>,
    /// Notified to ask the accept loop to stop taking new connections.
    shutdown_notify: Arc<tokio::sync::Notify>,
    /// Drain budget this server was configured with.
    drain_timeout: std::time::Duration,
    /// Becomes `true` once the serve task has finished its drain phase.
    done_rx: tokio::sync::watch::Receiver<bool>,
}
80
81impl ServeHandle {
82 pub fn shutdown(&self) {
83 self.shutdown_notify.notify_one();
84 }
85 pub async fn drain(self) {
86 self.shutdown();
87 let _ = self.join.await;
88 }
89 pub fn abort(&self) {
90 self.join.abort();
91 }
92 pub fn drain_timeout(&self) -> std::time::Duration {
93 self.drain_timeout
94 }
95 pub fn subscribe_done(&self) -> tokio::sync::watch::Receiver<bool> {
100 self.done_rx.clone()
101 }
102}
103
104pub fn respond(
107 handles: &HandleStore,
108 req_handle: u64,
109 status: u16,
110 headers: Vec<(String, String)>,
111) -> Result<(), CoreError> {
112 StatusCode::from_u16(status)
113 .map_err(|_| CoreError::invalid_input(format!("invalid HTTP status code: {status}")))?;
114 for (name, value) in &headers {
115 HeaderName::from_bytes(name.as_bytes()).map_err(|_| {
116 CoreError::invalid_input(format!("invalid response header name {:?}", name))
117 })?;
118 HeaderValue::from_str(value).map_err(|_| {
119 CoreError::invalid_input(format!("invalid response header value for {:?}", name))
120 })?;
121 }
122
123 let sender = handles
124 .take_req_sender(req_handle)
125 .ok_or_else(|| CoreError::invalid_handle(req_handle))?;
126 sender
127 .send(ResponseHeadEntry { status, headers })
128 .map_err(|_| CoreError::internal("serve task dropped before respond"))
129}
130
/// Shared callback that receives peer connect/disconnect [`ConnectionEvent`]s.
type ConnectionEventFn = Arc<dyn Fn(ConnectionEvent) + Send + Sync>;
134
/// RAII guard for one accepted connection from `peer`.
///
/// Acquiring increments the peer's entry in the shared `counts` map; dropping the
/// guard decrements it. `connected: true` / `connected: false` events are emitted
/// on the 0→1 and 1→0 transitions of that count.
struct PeerConnectionGuard {
    /// Shared live-connection count per peer.
    counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
    /// Peer this guard accounts for.
    peer: iroh::PublicKey,
    /// Encoded peer id reported in emitted `ConnectionEvent`s.
    peer_id_str: String,
    /// Optional connection-event callback.
    on_event: Option<ConnectionEventFn>,
}
141
142impl PeerConnectionGuard {
143 fn acquire(
144 counts: &Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
145 peer: iroh::PublicKey,
146 peer_id_str: String,
147 max: usize,
148 on_event: Option<ConnectionEventFn>,
149 ) -> Option<Self> {
150 let mut map = counts.lock().unwrap_or_else(|e| e.into_inner());
151 let count = map.entry(peer).or_insert(0);
152 if *count >= max {
153 return None;
154 }
155 let was_zero = *count == 0;
156 *count += 1;
157 let guard = PeerConnectionGuard {
158 counts: counts.clone(),
159 peer,
160 peer_id_str: peer_id_str.clone(),
161 on_event: on_event.clone(),
162 };
163 if was_zero {
165 if let Some(cb) = &on_event {
166 cb(ConnectionEvent {
167 peer_id: peer_id_str,
168 connected: true,
169 });
170 }
171 }
172 Some(guard)
173 }
174}
175
176impl Drop for PeerConnectionGuard {
177 fn drop(&mut self) {
178 let mut map = self.counts.lock().unwrap_or_else(|e| e.into_inner());
179 if let Some(c) = map.get_mut(&self.peer) {
180 *c = c.saturating_sub(1);
181 if *c == 0 {
182 map.remove(&self.peer);
183 if let Some(cb) = &self.on_event {
185 cb(ConnectionEvent {
186 peer_id: self.peer_id_str.clone(),
187 connected: false,
188 });
189 }
190 }
191 }
192 }
193}
194
/// Tower service that forwards each incoming hyper request to the `on_request`
/// callback as a `RequestPayload`, then waits for the matching `respond` call.
#[derive(Clone)]
struct RequestService {
    /// Callback that receives every accepted request.
    on_request: Arc<dyn Fn(RequestPayload) + Send + Sync>,
    /// Endpoint whose handle store backs the request/response channels.
    endpoint: IrohEndpoint,
    /// This node's id; used as the authority of the `httpi://` request URL.
    own_node_id: Arc<String>,
    /// Encoded id of the remote peer. `None` on the template service; the accept
    /// loop sets it per connection.
    remote_node_id: Option<String>,
    /// Optional request-body size limit forwarded to the body pump.
    max_request_body_bytes: Option<usize>,
    /// Optional cap on the estimated request-head size; `None` disables the check.
    max_header_size: Option<usize>,
    /// Response compression settings, when configured on the endpoint.
    #[cfg(feature = "compression")]
    compression: Option<crate::endpoint::CompressionOptions>,
}
208
209impl Service<hyper::Request<Incoming>> for RequestService {
210 type Response = hyper::Response<BoxBody>;
211 type Error = BoxError;
212 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
213
214 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
215 Poll::Ready(Ok(()))
216 }
217
218 fn call(&mut self, req: hyper::Request<Incoming>) -> Self::Future {
219 let svc = self.clone();
220 Box::pin(async move { svc.handle(req).await })
221 }
222}
223
224impl RequestService {
225 async fn handle(
226 self,
227 mut req: hyper::Request<Incoming>,
228 ) -> Result<hyper::Response<BoxBody>, BoxError> {
229 let handles = self.endpoint.handles();
230 let own_node_id = &*self.own_node_id;
231 let remote_node_id = self.remote_node_id.clone().unwrap_or_default();
232 let max_request_body_bytes = self.max_request_body_bytes;
233 let max_header_size = self.max_header_size;
234
235 let method = req.method().to_string();
236 let path_and_query = req
237 .uri()
238 .path_and_query()
239 .map(|p| p.as_str())
240 .unwrap_or("/")
241 .to_string();
242
243 tracing::debug!(
244 method = %method,
245 path = %path_and_query,
246 peer = %remote_node_id,
247 "iroh-http: incoming request",
248 );
249 if let Some(limit) = max_header_size {
258 let header_bytes: usize = req
259 .headers()
260 .iter()
261 .filter(|(k, _)| !k.as_str().eq_ignore_ascii_case("peer-id"))
262 .map(|(k, v)| k.as_str().len() + v.as_bytes().len() + 4) .sum::<usize>()
264 + "peer-id".len()
265 + remote_node_id.len()
266 + 4
267 + req.uri().to_string().len()
268 + method.len()
269 + 12; if header_bytes > limit {
271 let resp = hyper::Response::builder()
272 .status(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE)
273 .body(crate::box_body(http_body_util::Empty::new()))
274 .unwrap();
275 return Ok(resp);
276 }
277 }
278
279 let mut req_headers: Vec<(String, String)> = Vec::new();
281 for (k, v) in req.headers().iter() {
282 if k.as_str().eq_ignore_ascii_case("peer-id") {
283 continue;
284 }
285 match v.to_str() {
286 Ok(s) => req_headers.push((k.as_str().to_string(), s.to_string())),
287 Err(_) => {
288 let resp = hyper::Response::builder()
289 .status(StatusCode::BAD_REQUEST)
290 .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
291 b"non-UTF8 header value",
292 ))))
293 .unwrap();
294 return Ok(resp);
295 }
296 }
297 }
298 req_headers.push(("peer-id".to_string(), remote_node_id.clone()));
299
300 let url = format!("httpi://{own_node_id}{path_and_query}");
301
302 let has_upgrade_header = req_headers.iter().any(|(k, v)| {
305 k.eq_ignore_ascii_case("upgrade") && v.eq_ignore_ascii_case("iroh-duplex")
306 });
307 let has_connection_upgrade = req_headers.iter().any(|(k, v)| {
308 k.eq_ignore_ascii_case("connection")
309 && v.split(',')
310 .any(|tok| tok.trim().eq_ignore_ascii_case("upgrade"))
311 });
312 let is_connect = req.method() == http::Method::CONNECT;
313
314 let is_bidi = if has_upgrade_header {
315 if !has_connection_upgrade || !is_connect {
316 let resp = hyper::Response::builder()
317 .status(StatusCode::BAD_REQUEST)
318 .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
319 b"duplex upgrade requires CONNECT method with Connection: upgrade header",
320 ))))
321 .unwrap();
322 return Ok(resp);
323 }
324 true
325 } else {
326 false
327 };
328
329 let upgrade_future = if is_bidi {
331 Some(hyper::upgrade::on(&mut req))
332 } else {
333 None
334 };
335
336 let mut guard = handles.insert_guard();
340 let (req_body_writer, req_body_reader) = handles.make_body_channel();
341 let req_body_handle = guard
342 .insert_reader(req_body_reader)
343 .map_err(|e| -> BoxError { e.into() })?;
344
345 let (res_body_writer, res_body_reader) = handles.make_body_channel();
347 let res_body_handle = guard
348 .insert_writer(res_body_writer)
349 .map_err(|e| -> BoxError { e.into() })?;
350
351 let (req_trailers_handle, res_trailers_handle, req_trailer_tx, opt_res_trailer_rx) =
354 if !is_bidi {
355 let (rq_tx, rq_rx) = tokio::sync::oneshot::channel::<Vec<(String, String)>>();
357 let rq_h = guard
358 .insert_trailer_receiver(rq_rx)
359 .map_err(|e| -> BoxError { e.into() })?;
360 let (rs_tx, rs_rx) = tokio::sync::oneshot::channel::<Vec<(String, String)>>();
362 let rs_h = guard
363 .insert_trailer_sender(rs_tx)
364 .map_err(|e| -> BoxError { e.into() })?;
365 (rq_h, rs_h, Some(rq_tx), Some(rs_rx))
366 } else {
367 (0u64, 0u64, None, None)
368 };
369
370 let (head_tx, head_rx) = tokio::sync::oneshot::channel::<ResponseHeadEntry>();
373 let req_handle = guard
374 .allocate_req_handle(head_tx)
375 .map_err(|e| -> BoxError { e.into() })?;
376
377 guard.commit();
378
379 struct ReqHeadCleanup {
383 endpoint: IrohEndpoint,
384 req_handle: u64,
385 }
386 impl Drop for ReqHeadCleanup {
387 fn drop(&mut self) {
388 self.endpoint.handles().take_req_sender(self.req_handle);
389 }
390 }
391 let _req_head_cleanup = ReqHeadCleanup {
392 endpoint: self.endpoint.clone(),
393 req_handle,
394 };
395
396 let (body_overflow_tx, body_overflow_rx) = if !is_bidi && max_request_body_bytes.is_some() {
402 let (tx, rx) = tokio::sync::oneshot::channel::<()>();
403 (Some(tx), Some(rx))
404 } else {
405 (None, None)
406 };
407
408 let duplex_req_body_writer = if !is_bidi {
409 let body = req.into_body();
410 let trailer_tx = req_trailer_tx.expect("non-duplex has req_trailer_tx");
411 let frame_timeout = handles.drain_timeout();
412 tokio::spawn(pump_hyper_body_to_channel_limited(
413 body,
414 req_body_writer,
415 trailer_tx,
416 max_request_body_bytes,
417 frame_timeout,
418 body_overflow_tx,
419 ));
420 None
421 } else {
422 drop(req.into_body());
424 Some(req_body_writer)
425 };
426
427 on_request_fire(
430 &self.on_request,
431 req_handle,
432 req_body_handle,
433 res_body_handle,
434 req_trailers_handle,
435 res_trailers_handle,
436 method,
437 url,
438 req_headers,
439 remote_node_id,
440 is_bidi,
441 );
442
443 let response_head = if let Some(overflow_rx) = body_overflow_rx {
449 tokio::select! {
450 biased;
451 _ = overflow_rx => {
452 let resp = hyper::Response::builder()
455 .status(StatusCode::PAYLOAD_TOO_LARGE)
456 .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
457 b"request body too large",
458 ))))
459 .expect("valid 413 response");
460 return Ok(resp);
461 }
462 head = head_rx => {
463 head.map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
464 }
465 }
466 } else {
467 head_rx
468 .await
469 .map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
470 };
471
472 if let Some(upgrade_fut) = upgrade_future {
479 let req_body_writer =
480 duplex_req_body_writer.expect("duplex path always has req_body_writer");
481
482 if response_head.status != StatusCode::SWITCHING_PROTOCOLS.as_u16() {
485 drop(upgrade_fut);
486 drop(req_body_writer);
487 let mut resp_builder = hyper::Response::builder().status(response_head.status);
488 for (k, v) in &response_head.headers {
489 resp_builder = resp_builder.header(k.as_str(), v.as_str());
490 }
491 let resp = resp_builder
492 .body(crate::box_body(http_body_util::Empty::new()))
493 .map_err(|e| -> BoxError { e.into() })?;
494 return Ok(resp);
495 }
496
497 tokio::spawn(async move {
503 match upgrade_fut.await {
504 Err(e) => tracing::warn!("iroh-http: duplex upgrade error: {e}"),
505 Ok(upgraded) => {
506 let io = TokioIo::new(upgraded);
507 crate::stream::pump_duplex(io, req_body_writer, res_body_reader).await;
508 }
509 }
510 });
511
512 let resp = hyper::Response::builder()
514 .status(StatusCode::SWITCHING_PROTOCOLS)
515 .header(hyper::header::CONNECTION, "Upgrade")
516 .header(hyper::header::UPGRADE, "iroh-duplex")
517 .body(crate::box_body(http_body_util::Empty::new()))
518 .unwrap();
519 return Ok(resp);
520 }
521
522 let has_trailer_hdr = response_head
525 .headers
526 .iter()
527 .any(|(k, _)| k.eq_ignore_ascii_case("trailer"));
528 let trailer_rx_for_body = if has_trailer_hdr {
529 opt_res_trailer_rx
530 } else {
531 handles.remove_trailer_sender(res_trailers_handle);
532 None
533 };
534
535 let body_stream = body_from_reader(res_body_reader, trailer_rx_for_body);
536
537 let mut resp_builder = hyper::Response::builder().status(response_head.status);
538 for (k, v) in &response_head.headers {
539 resp_builder = resp_builder.header(k.as_str(), v.as_str());
540 }
541
542 #[cfg(feature = "compression")]
543 let resp_builder = resp_builder; let resp = resp_builder
546 .body(crate::box_body(body_stream))
547 .map_err(|e| -> BoxError { e.into() })?;
548
549 Ok(resp)
550 }
551}
552
553#[inline]
554#[allow(clippy::too_many_arguments)]
555fn on_request_fire(
556 cb: &Arc<dyn Fn(RequestPayload) + Send + Sync>,
557 req_handle: u64,
558 req_body_handle: u64,
559 res_body_handle: u64,
560 req_trailers_handle: u64,
561 res_trailers_handle: u64,
562 method: String,
563 url: String,
564 headers: Vec<(String, String)>,
565 remote_node_id: String,
566 is_bidi: bool,
567) {
568 cb(RequestPayload {
569 req_handle,
570 req_body_handle,
571 res_body_handle,
572 req_trailers_handle,
573 res_trailers_handle,
574 method,
575 url,
576 headers,
577 remote_node_id,
578 is_bidi,
579 });
580}
581
/// Starts serving HTTP over `endpoint`, invoking `on_request` for every request.
///
/// Equivalent to [`serve_with_events`] with no connection-event callback.
pub fn serve<F>(endpoint: IrohEndpoint, options: ServeOptions, on_request: F) -> ServeHandle
where
    F: Fn(RequestPayload) + Send + Sync + 'static,
{
    serve_with_events(endpoint, options, on_request, None)
}
594
595pub fn serve_with_events<F>(
600 endpoint: IrohEndpoint,
601 options: ServeOptions,
602 on_request: F,
603 on_connection_event: Option<ConnectionEventFn>,
604) -> ServeHandle
605where
606 F: Fn(RequestPayload) + Send + Sync + 'static,
607{
608 let max = options.max_concurrency.unwrap_or(DEFAULT_CONCURRENCY);
609 let max_errors = options.max_consecutive_errors.unwrap_or(5);
610 let request_timeout = options
611 .request_timeout_ms
612 .map(Duration::from_millis)
613 .unwrap_or(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS));
614 let max_conns_per_peer = options
615 .max_connections_per_peer
616 .unwrap_or(DEFAULT_MAX_CONNECTIONS_PER_PEER);
617 let max_request_body_bytes = options.max_request_body_bytes;
618 let max_total_connections = options.max_total_connections;
619 let drain_timeout = Duration::from_secs(
620 options
621 .drain_timeout_secs
622 .unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS),
623 );
624 let load_shed_enabled = options.load_shed.unwrap_or(true);
626 let max_header_size = endpoint.max_header_size();
627 #[cfg(feature = "compression")]
628 let compression = endpoint.compression().cloned();
629 let own_node_id = Arc::new(endpoint.node_id().to_string());
630 let on_request = Arc::new(on_request) as Arc<dyn Fn(RequestPayload) + Send + Sync>;
631
632 let peer_counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>> =
633 Arc::new(Mutex::new(HashMap::new()));
634 let conn_event_fn: Option<ConnectionEventFn> = on_connection_event;
635
636 let in_flight: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
639 let drain_notify: Arc<tokio::sync::Notify> = Arc::new(tokio::sync::Notify::new());
640
641 let base_svc = RequestService {
642 on_request,
643 endpoint: endpoint.clone(),
644 own_node_id,
645 remote_node_id: None,
646 max_request_body_bytes,
647 max_header_size: if max_header_size == 0 {
648 None
649 } else {
650 Some(max_header_size)
651 },
652 #[cfg(feature = "compression")]
653 compression,
654 };
655
656 let shutdown_notify = Arc::new(tokio::sync::Notify::new());
657 let shutdown_listen = shutdown_notify.clone();
658 let drain_dur = drain_timeout;
659 let total_connections = endpoint.inner.active_connections.clone();
662 let total_requests = endpoint.inner.active_requests.clone();
663 let (done_tx, done_rx) = tokio::sync::watch::channel(false);
664 let endpoint_closed_tx = endpoint.inner.closed_tx.clone();
665
666 let in_flight_drain = in_flight.clone();
667 let drain_notify_drain = drain_notify.clone();
668
669 let join = tokio::spawn(async move {
670 let ep = endpoint.raw().clone();
671 let mut consecutive_errors: usize = 0;
672
673 loop {
674 let incoming = tokio::select! {
675 biased;
676 _ = shutdown_listen.notified() => {
677 tracing::info!("iroh-http: serve loop shutting down");
678 break;
679 }
680 inc = ep.accept() => match inc {
681 Some(i) => i,
682 None => {
683 tracing::info!("iroh-http: endpoint closed (accept returned None)");
684 let _ = endpoint_closed_tx.send(true);
685 break;
686 }
687 }
688 };
689
690 let conn = match incoming.await {
691 Ok(c) => {
692 consecutive_errors = 0;
693 c
694 }
695 Err(e) => {
696 consecutive_errors += 1;
697 tracing::warn!(
698 "iroh-http: accept error ({consecutive_errors}/{max_errors}): {e}"
699 );
700 if consecutive_errors >= max_errors {
701 tracing::error!("iroh-http: too many accept errors — shutting down");
702 break;
703 }
704 continue;
705 }
706 };
707
708 let remote_pk = conn.remote_id();
709
710 if let Some(max_total) = max_total_connections {
712 let current = total_connections.load(Ordering::Relaxed);
713 if current >= max_total {
714 tracing::warn!(
715 "iroh-http: total connection limit reached ({current}/{max_total})"
716 );
717 conn.close(0u32.into(), b"server at capacity");
718 continue;
719 }
720 }
721
722 let remote_id = base32_encode(remote_pk.as_bytes());
723
724 let guard =
725 match PeerConnectionGuard::acquire(&peer_counts, remote_pk, remote_id.clone(), max_conns_per_peer, conn_event_fn.clone()) {
726 Some(g) => g,
727 None => {
728 tracing::warn!(
729 "iroh-http: peer {remote_id} exceeded connection limit"
730 );
731 conn.close(0u32.into(), b"too many connections");
732 continue;
733 }
734 };
735
736 let mut peer_svc = base_svc.clone();
737 peer_svc.remote_node_id = Some(remote_id);
738
739 let timeout_dur = if request_timeout.is_zero() {
740 Duration::MAX
741 } else {
742 request_timeout
743 };
744
745 let conn_total = total_connections.clone();
746 let conn_requests = total_requests.clone();
747 let in_flight_conn = in_flight.clone();
748 let drain_notify_conn = drain_notify.clone();
749 conn_total.fetch_add(1, Ordering::Relaxed);
750 tokio::spawn(async move {
751 let _guard = guard;
752 struct TotalGuard(Arc<AtomicUsize>);
754 impl Drop for TotalGuard {
755 fn drop(&mut self) {
756 self.0.fetch_sub(1, Ordering::Relaxed);
757 }
758 }
759 let _total_guard = TotalGuard(conn_total);
760
761 loop {
762 let (send, recv) = match conn.accept_bi().await {
763 Ok(pair) => pair,
764 Err(_) => break,
765 };
766
767 let io = TokioIo::new(IrohStream::new(send, recv));
768 let svc = peer_svc.clone();
769 let req_counter = conn_requests.clone();
770 req_counter.fetch_add(1, Ordering::Relaxed);
771 in_flight_conn.fetch_add(1, Ordering::Relaxed);
772
773 let in_flight_req = in_flight_conn.clone();
774 let drain_notify_req = drain_notify_conn.clone();
775
776 tokio::spawn(async move {
777 struct ReqGuard {
779 counter: Arc<AtomicUsize>,
780 in_flight: Arc<AtomicUsize>,
781 drain_notify: Arc<tokio::sync::Notify>,
782 }
783 impl Drop for ReqGuard {
784 fn drop(&mut self) {
785 self.counter.fetch_sub(1, Ordering::Relaxed);
786 if self.in_flight.fetch_sub(1, Ordering::AcqRel) == 1 {
787 self.drain_notify.notify_waiters();
789 }
790 }
791 }
792 let _req_guard = ReqGuard {
793 counter: req_counter,
794 in_flight: in_flight_req,
795 drain_notify: drain_notify_req,
796 };
797 let effective_header_limit = if max_header_size == 0 {
800 64 * 1024
801 } else {
802 max_header_size.max(8192)
803 };
804
805 use tower::{ServiceBuilder, limit::ConcurrencyLimitLayer, timeout::TimeoutLayer};
820
821 #[cfg(feature = "compression")]
822 let result = {
823 use http::{Extensions, HeaderMap, Version};
824 use tower_http::compression::{predicate::{Predicate, SizeAbove}, CompressionLayer};
825
826 let compression_config = svc.compression.clone();
827 if let Some(comp) = &compression_config {
828 let min_bytes = comp.min_body_bytes;
829 let mut layer = CompressionLayer::new().zstd(true);
830 if let Some(level) = comp.level {
831 use tower_http::compression::CompressionLevel;
832 layer = layer.quality(CompressionLevel::Precise(level as i32));
833 }
834 let not_pre_compressed =
835 |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
836 !h.contains_key(http::header::CONTENT_ENCODING)
837 };
838 let not_no_transform =
839 |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
840 h.get(http::header::CACHE_CONTROL)
841 .and_then(|v| v.to_str().ok())
842 .map(|v| {
843 !v.split(',').any(|d| {
844 d.trim().eq_ignore_ascii_case("no-transform")
845 })
846 })
847 .unwrap_or(true)
848 };
849 let predicate =
850 SizeAbove::new(min_bytes.min(u16::MAX as usize) as u16)
851 .and(not_pre_compressed)
852 .and(not_no_transform);
853 if load_shed_enabled {
854 use tower::load_shed::LoadShedLayer;
855 let stk = TowerErrorHandler(ServiceBuilder::new()
856 .layer(LoadShedLayer::new())
857 .layer(ConcurrencyLimitLayer::new(max))
858 .layer(TimeoutLayer::new(timeout_dur))
859 .service(svc));
860 hyper::server::conn::http1::Builder::new()
861 .max_buf_size(effective_header_limit)
862 .max_headers(128)
863 .serve_connection(io, TowerToHyperService::new(
864 ServiceBuilder::new()
865 .layer(layer.compress_when(predicate))
866 .service(stk),
867 ))
868 .with_upgrades()
869 .await
870 } else {
871 let stk = TowerErrorHandler(ServiceBuilder::new()
872 .layer(ConcurrencyLimitLayer::new(max))
873 .layer(TimeoutLayer::new(timeout_dur))
874 .service(svc));
875 hyper::server::conn::http1::Builder::new()
876 .max_buf_size(effective_header_limit)
877 .max_headers(128)
878 .serve_connection(io, TowerToHyperService::new(
879 ServiceBuilder::new()
880 .layer(layer.compress_when(predicate))
881 .service(stk),
882 ))
883 .with_upgrades()
884 .await
885 }
886 } else if load_shed_enabled {
887 use tower::load_shed::LoadShedLayer;
888 let stk = TowerErrorHandler(ServiceBuilder::new()
889 .layer(LoadShedLayer::new())
890 .layer(ConcurrencyLimitLayer::new(max))
891 .layer(TimeoutLayer::new(timeout_dur))
892 .service(svc));
893 hyper::server::conn::http1::Builder::new()
894 .max_buf_size(effective_header_limit)
895 .max_headers(128)
896 .serve_connection(io, TowerToHyperService::new(stk))
897 .with_upgrades()
898 .await
899 } else {
900 let stk = TowerErrorHandler(ServiceBuilder::new()
901 .layer(ConcurrencyLimitLayer::new(max))
902 .layer(TimeoutLayer::new(timeout_dur))
903 .service(svc));
904 hyper::server::conn::http1::Builder::new()
905 .max_buf_size(effective_header_limit)
906 .max_headers(128)
907 .serve_connection(io, TowerToHyperService::new(stk))
908 .with_upgrades()
909 .await
910 }
911 };
912 #[cfg(not(feature = "compression"))]
913 let result = if load_shed_enabled {
914 use tower::load_shed::LoadShedLayer;
915 let stk = TowerErrorHandler(ServiceBuilder::new()
916 .layer(LoadShedLayer::new())
917 .layer(ConcurrencyLimitLayer::new(max))
918 .layer(TimeoutLayer::new(timeout_dur))
919 .service(svc));
920 hyper::server::conn::http1::Builder::new()
921 .max_buf_size(effective_header_limit)
922 .max_headers(128)
923 .serve_connection(io, TowerToHyperService::new(stk))
924 .with_upgrades()
925 .await
926 } else {
927 let stk = TowerErrorHandler(ServiceBuilder::new()
928 .layer(ConcurrencyLimitLayer::new(max))
929 .layer(TimeoutLayer::new(timeout_dur))
930 .service(svc));
931 hyper::server::conn::http1::Builder::new()
932 .max_buf_size(effective_header_limit)
933 .max_headers(128)
934 .serve_connection(io, TowerToHyperService::new(stk))
935 .with_upgrades()
936 .await
937 };
938
939 if let Err(e) = result {
940 tracing::debug!("iroh-http: http1 connection error: {e}");
941 }
942 });
943 }
944 });
945 }
946
947 let deadline = tokio::time::Instant::now() + drain_dur;
954 loop {
955 if in_flight_drain.load(Ordering::Acquire) == 0 {
956 tracing::info!("iroh-http: all in-flight requests drained");
957 break;
958 }
959 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
960 if remaining.is_zero() {
961 tracing::warn!("iroh-http: drain timed out after {}s", drain_dur.as_secs());
962 break;
963 }
964 tokio::select! {
965 _ = drain_notify_drain.notified() => {}
966 _ = tokio::time::sleep(remaining) => {}
967 }
968 }
969 let _ = done_tx.send(true);
970 });
971
972 ServeHandle {
973 join,
974 shutdown_notify,
975 drain_timeout: drain_dur,
976 done_rx,
977 }
978}
979
980#[derive(Clone)]
995struct TowerErrorHandler<S>(S);
996
997impl<S, Req> Service<Req> for TowerErrorHandler<S>
998where
999 S: Service<Req, Response = hyper::Response<BoxBody>>,
1000 S::Error: Into<BoxError>,
1001 S::Future: Send + 'static,
1002{
1003 type Response = hyper::Response<BoxBody>;
1004 type Error = BoxError;
1005 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1006
1007 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1008 self.0.poll_ready(cx).map_err(Into::into)
1013 }
1014
1015 fn call(&mut self, req: Req) -> Self::Future {
1016 let fut = self.0.call(req);
1017 Box::pin(async move {
1018 match fut.await {
1019 Ok(r) => Ok(r),
1020 Err(e) => {
1021 let e = e.into();
1022 let status = if e.is::<tower::timeout::error::Elapsed>() {
1023 StatusCode::REQUEST_TIMEOUT
1024 } else if e.is::<tower::load_shed::error::Overloaded>() {
1025 StatusCode::SERVICE_UNAVAILABLE
1026 } else {
1027 tracing::warn!("iroh-http: unexpected tower error: {e}");
1028 StatusCode::INTERNAL_SERVER_ERROR
1029 };
1030 let body_bytes: &'static [u8] = match status {
1031 StatusCode::REQUEST_TIMEOUT => b"request timed out",
1032 StatusCode::SERVICE_UNAVAILABLE => b"server at capacity",
1033 _ => b"internal server error",
1034 };
1035 Ok(hyper::Response::builder()
1036 .status(status)
1037 .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
1038 body_bytes,
1039 ))))
1040 .expect("valid error response"))
1041 }
1042 }
1043 })
1044 }
1045}