1use std::{
8 collections::HashMap,
9 future::Future,
10 pin::Pin,
11 sync::{
12 atomic::{AtomicUsize, Ordering},
13 Arc, Mutex,
14 },
15 task::{Context, Poll},
16 time::Duration,
17};
18
19use bytes::Bytes;
20use http::{HeaderName, HeaderValue, StatusCode};
21use hyper::body::Incoming;
22use hyper_util::rt::TokioIo;
23use hyper_util::service::TowerToHyperService;
24use tower::Service;
25
26use crate::{
27 base32_encode,
28 client::{body_from_reader, pump_hyper_body_to_channel_limited},
29 io::IrohStream,
30 stream::{HandleStore, ResponseHeadEntry},
31 ConnectionEvent, CoreError, IrohEndpoint, RequestPayload,
32};
33
34type BoxBody = crate::BoxBody;
37type BoxError = Box<dyn std::error::Error + Send + Sync>;
38
/// Tunable limits for the serve loop.
///
/// Every field is optional; `None` selects the built-in default that
/// `serve_with_events` applies for that option.
#[derive(Clone, Debug, Default)]
pub struct ServerLimits {
    /// Maximum number of requests handled concurrently across all
    /// connections (default 1024).
    pub max_concurrency: Option<usize>,
    /// Number of consecutive failed connection handshakes after which the
    /// accept loop gives up and shuts down (default 5).
    pub max_consecutive_errors: Option<usize>,
    /// Per-request timeout in milliseconds (default 60 000). `Some(0)`
    /// disables the timeout.
    pub request_timeout_ms: Option<u64>,
    /// Maximum number of simultaneous connections from a single peer
    /// (default 8).
    pub max_connections_per_peer: Option<usize>,
    /// Maximum request body size in bytes for non-duplex requests; larger
    /// bodies are answered with 413 (default 16 MiB).
    pub max_request_body_bytes: Option<usize>,
    /// Maximum response body size in bytes. Not read by the serve loop
    /// itself — NOTE(review): confirm where this is enforced.
    pub max_response_body_bytes: Option<usize>,
    /// How long shutdown waits for in-flight requests, in seconds
    /// (default 30).
    pub drain_timeout_secs: Option<u64>,
    /// Cap on the endpoint's total active connections; `None` means no cap.
    pub max_total_connections: Option<usize>,
    /// When the concurrency limit is saturated, reject new requests with 503
    /// instead of waiting for capacity (default `true`).
    pub load_shed: Option<bool>,
}

/// Alias kept so callers can name the options passed to `serve`.
pub type ServeOptions = ServerLimits;
68
/// Default cap on concurrently handled requests (1024).
const DEFAULT_CONCURRENCY: usize = 1 << 10;
/// Default per-request timeout: 60 seconds, expressed in milliseconds.
const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 60 * 1000;
/// Default cap on simultaneous connections from one peer.
const DEFAULT_MAX_CONNECTIONS_PER_PEER: usize = 8;
/// Default time shutdown waits for in-flight requests, in seconds.
const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 30;
/// Default request body limit: 16 MiB.
const DEFAULT_MAX_REQUEST_BODY_BYTES: usize = 16 << 20;
/// Default response body limit: 256 MiB (used outside this module).
pub(crate) const DEFAULT_MAX_RESPONSE_BODY_BYTES: usize = 256 << 20;
80
81pub struct ServeHandle {
84 join: tokio::task::JoinHandle<()>,
85 shutdown_notify: Arc<tokio::sync::Notify>,
86 drain_timeout: std::time::Duration,
87 done_rx: tokio::sync::watch::Receiver<bool>,
89}
90
91impl ServeHandle {
92 pub fn shutdown(&self) {
93 self.shutdown_notify.notify_one();
94 }
95 pub async fn drain(self) {
96 self.shutdown();
97 let _ = self.join.await;
98 }
99 pub fn abort(&self) {
100 self.join.abort();
101 }
102 pub fn drain_timeout(&self) -> std::time::Duration {
103 self.drain_timeout
104 }
105 pub fn subscribe_done(&self) -> tokio::sync::watch::Receiver<bool> {
110 self.done_rx.clone()
111 }
112}
113
114pub fn respond(
117 handles: &HandleStore,
118 req_handle: u64,
119 status: u16,
120 headers: Vec<(String, String)>,
121) -> Result<(), CoreError> {
122 StatusCode::from_u16(status)
123 .map_err(|_| CoreError::invalid_input(format!("invalid HTTP status code: {status}")))?;
124 for (name, value) in &headers {
125 HeaderName::from_bytes(name.as_bytes()).map_err(|_| {
126 CoreError::invalid_input(format!("invalid response header name {:?}", name))
127 })?;
128 HeaderValue::from_str(value).map_err(|_| {
129 CoreError::invalid_input(format!("invalid response header value for {:?}", name))
130 })?;
131 }
132
133 let sender = handles
134 .take_req_sender(req_handle)
135 .ok_or_else(|| CoreError::invalid_handle(req_handle))?;
136 sender
137 .send(ResponseHeadEntry { status, headers })
138 .map_err(|_| CoreError::internal("serve task dropped before respond"))
139}
140
141type ConnectionEventFn = Arc<dyn Fn(ConnectionEvent) + Send + Sync>;
144
145struct PeerConnectionGuard {
146 counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
147 peer: iroh::PublicKey,
148 peer_id_str: String,
149 on_event: Option<ConnectionEventFn>,
150}
151
152impl PeerConnectionGuard {
153 fn acquire(
154 counts: &Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
155 peer: iroh::PublicKey,
156 peer_id_str: String,
157 max: usize,
158 on_event: Option<ConnectionEventFn>,
159 ) -> Option<Self> {
160 let mut map = counts.lock().unwrap_or_else(|e| e.into_inner());
161 let count = map.entry(peer).or_insert(0);
162 if *count >= max {
163 return None;
164 }
165 let was_zero = *count == 0;
166 *count = count.saturating_add(1);
167 let guard = PeerConnectionGuard {
168 counts: counts.clone(),
169 peer,
170 peer_id_str: peer_id_str.clone(),
171 on_event: on_event.clone(),
172 };
173 if was_zero {
175 if let Some(cb) = &on_event {
176 cb(ConnectionEvent {
177 peer_id: peer_id_str,
178 connected: true,
179 });
180 }
181 }
182 Some(guard)
183 }
184}
185
186impl Drop for PeerConnectionGuard {
187 fn drop(&mut self) {
188 let mut map = self.counts.lock().unwrap_or_else(|e| e.into_inner());
189 if let Some(c) = map.get_mut(&self.peer) {
190 *c = c.saturating_sub(1);
191 if *c == 0 {
192 map.remove(&self.peer);
193 if let Some(cb) = &self.on_event {
195 cb(ConnectionEvent {
196 peer_id: self.peer_id_str.clone(),
197 connected: false,
198 });
199 }
200 }
201 }
202 }
203}
204
205#[derive(Clone)]
208struct RequestService {
209 on_request: Arc<dyn Fn(RequestPayload) + Send + Sync>,
210 endpoint: IrohEndpoint,
211 own_node_id: Arc<String>,
212 remote_node_id: Option<String>,
213 max_request_body_bytes: Option<usize>,
214 max_header_size: Option<usize>,
215 #[cfg(feature = "compression")]
216 compression: Option<crate::endpoint::CompressionOptions>,
217}
218
219impl Service<hyper::Request<Incoming>> for RequestService {
220 type Response = hyper::Response<BoxBody>;
221 type Error = BoxError;
222 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
223
224 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
225 Poll::Ready(Ok(()))
226 }
227
228 fn call(&mut self, req: hyper::Request<Incoming>) -> Self::Future {
229 let svc = self.clone();
230 Box::pin(async move { svc.handle(req).await })
231 }
232}
233
impl RequestService {
    /// Serve one HTTP request by handing it to the `on_request` callback.
    ///
    /// Flow:
    /// 1. Optionally reject an oversized request head (431) and reject
    ///    non-UTF-8 header values (400).
    /// 2. Drop any client-supplied `peer-id` header and append one carrying
    ///    `self.remote_node_id` (empty string if unset).
    /// 3. Detect an `Upgrade: iroh-duplex` request. It must be `CONNECT` with
    ///    `Connection: upgrade`, otherwise 400.
    /// 4. Allocate request-body, response-body and response-head handles in
    ///    the endpoint's handle store and fire the callback.
    /// 5. Wait for the response head submitted through `respond`. Non-duplex
    ///    requests whose body overflows `max_request_body_bytes` get 413.
    /// 6. Duplex requests answered with 101 are switched to a background pump
    ///    over the upgraded I/O. All other requests stream the response body
    ///    from the response-body channel.
    async fn handle(
        self,
        mut req: hyper::Request<Incoming>,
    ) -> Result<hyper::Response<BoxBody>, BoxError> {
        let handles = self.endpoint.handles();
        let own_node_id = &*self.own_node_id;
        let remote_node_id = self.remote_node_id.clone().unwrap_or_default();
        let max_request_body_bytes = self.max_request_body_bytes;
        let max_header_size = self.max_header_size;

        let method = req.method().to_string();
        // Requests without a path (e.g. CONNECT in authority form) map to "/".
        let path_and_query = req
            .uri()
            .path_and_query()
            .map(|p| p.as_str())
            .unwrap_or("/")
            .to_string();

        tracing::debug!(
            method = %method,
            path = %path_and_query,
            peer = %remote_node_id,
            "iroh-http: incoming request",
        );
        if let Some(limit) = max_header_size {
            // Approximate size of the request head as forwarded to the callback.
            // Client `peer-id` headers are excluded and the server-injected one
            // is counted instead. The `+ 4` per header and `+ 12` presumably
            // approximate `: `/CRLF separators and request-line overhead — TODO
            // confirm.
            let header_bytes: usize = req
                .headers()
                .iter()
                .filter(|(k, _)| !k.as_str().eq_ignore_ascii_case("peer-id"))
                .map(|(k, v)| {
                    k.as_str()
                        .len()
                        .saturating_add(v.as_bytes().len())
                        .saturating_add(4)
                })
                .fold(0usize, |acc, x| acc.saturating_add(x))
                .saturating_add("peer-id".len())
                .saturating_add(remote_node_id.len())
                .saturating_add(4)
                .saturating_add(req.uri().to_string().len())
                .saturating_add(method.len())
                .saturating_add(12);
            if header_bytes > limit {
                let resp = hyper::Response::builder()
                    .status(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE)
                    .body(crate::box_body(http_body_util::Empty::new()))
                    .expect("static response args are valid");
                return Ok(resp);
            }
        }

        // Copy headers into owned strings for the callback. A client-supplied
        // `peer-id` is discarded so it cannot spoof the remote identity.
        let mut req_headers: Vec<(String, String)> = Vec::new();
        for (k, v) in req.headers().iter() {
            if k.as_str().eq_ignore_ascii_case("peer-id") {
                continue;
            }
            match v.to_str() {
                Ok(s) => req_headers.push((k.as_str().to_string(), s.to_string())),
                Err(_) => {
                    let resp = hyper::Response::builder()
                        .status(StatusCode::BAD_REQUEST)
                        .body(crate::box_body(http_body_util::Full::new(
                            Bytes::from_static(b"non-UTF8 header value"),
                        )))
                        .expect("static response args are valid");
                    return Ok(resp);
                }
            }
        }
        req_headers.push(("peer-id".to_string(), remote_node_id.clone()));

        let url = format!("httpi://{own_node_id}{path_and_query}");

        // Duplex streams are requested with `Upgrade: iroh-duplex`; that is
        // only accepted together with `Connection: upgrade` on a CONNECT.
        let has_upgrade_header = req_headers.iter().any(|(k, v)| {
            k.eq_ignore_ascii_case("upgrade") && v.eq_ignore_ascii_case("iroh-duplex")
        });
        let has_connection_upgrade = req_headers.iter().any(|(k, v)| {
            k.eq_ignore_ascii_case("connection")
                && v.split(',')
                    .any(|tok| tok.trim().eq_ignore_ascii_case("upgrade"))
        });
        let is_connect = req.method() == http::Method::CONNECT;

        let is_bidi = if has_upgrade_header {
            if !has_connection_upgrade || !is_connect {
                let resp = hyper::Response::builder()
                    .status(StatusCode::BAD_REQUEST)
                    .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
                        b"duplex upgrade requires CONNECT method with Connection: upgrade header",
                    ))))
                    .expect("static response args are valid");
                return Ok(resp);
            }
            true
        } else {
            false
        };

        // Must be taken from `&mut req` before the body is consumed below.
        let upgrade_future = if is_bidi {
            Some(hyper::upgrade::on(&mut req))
        } else {
            None
        };

        // Allocate all handles through one insert guard and commit only after
        // all three allocations succeed. Presumably an uncommitted guard rolls
        // back on drop (early `?` returns) — confirm in `HandleStore`.
        let mut guard = handles.insert_guard();
        let (req_body_writer, req_body_reader) = handles.make_body_channel();
        let req_body_handle = guard
            .insert_reader(req_body_reader)
            .map_err(|e| -> BoxError { e.into() })?;

        let (res_body_writer, res_body_reader) = handles.make_body_channel();
        let res_body_handle = guard
            .insert_writer(res_body_writer)
            .map_err(|e| -> BoxError { e.into() })?;

        let (head_tx, head_rx) = tokio::sync::oneshot::channel::<ResponseHeadEntry>();
        let req_handle = guard
            .allocate_req_handle(head_tx)
            .map_err(|e| -> BoxError { e.into() })?;

        guard.commit();

        // When this future completes or is dropped (e.g. cancelled by an outer
        // timeout), remove the response-head sender from the store so a late
        // `respond` fails with an invalid handle.
        struct ReqHeadCleanup {
            endpoint: IrohEndpoint,
            req_handle: u64,
        }
        impl Drop for ReqHeadCleanup {
            fn drop(&mut self) {
                self.endpoint.handles().take_req_sender(self.req_handle);
            }
        }
        let _req_head_cleanup = ReqHeadCleanup {
            endpoint: self.endpoint.clone(),
            req_handle,
        };

        // Overflow signal for the request-body limit (non-duplex only).
        // NOTE(review): assumes the body pump sends `()` on overflow and just
        // drops the sender otherwise — confirm in `pump_hyper_body_to_channel_limited`.
        let (body_overflow_tx, body_overflow_rx) = if !is_bidi && max_request_body_bytes.is_some() {
            let (tx, rx) = tokio::sync::oneshot::channel::<()>();
            (Some(tx), Some(rx))
        } else {
            (None, None)
        };

        // Normal requests: pump the hyper body into the request-body channel in
        // the background. Duplex: discard the (pre-upgrade) body and keep the
        // writer for the post-upgrade pump.
        let duplex_req_body_writer = if !is_bidi {
            let body = req.into_body();
            let frame_timeout = handles.drain_timeout();
            tokio::spawn(pump_hyper_body_to_channel_limited(
                body,
                req_body_writer,
                max_request_body_bytes,
                frame_timeout,
                body_overflow_tx,
            ));
            None
        } else {
            drop(req.into_body());
            Some(req_body_writer)
        };

        on_request_fire(
            &self.on_request,
            req_handle,
            req_body_handle,
            res_body_handle,
            method,
            url,
            req_headers,
            remote_node_id,
            is_bidi,
        );

        let response_head = if let Some(overflow_rx) = body_overflow_rx {
            // `biased`: an overflow wins even if a response head is already
            // available. If the overflow sender is dropped without sending,
            // `Ok(())` does not match, that branch is disabled and we keep
            // waiting for the head.
            tokio::select! {
                biased;
                Ok(()) = overflow_rx => {
                    let resp = hyper::Response::builder()
                        .status(StatusCode::PAYLOAD_TOO_LARGE)
                        .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
                            b"request body too large",
                        ))))
                        .expect("valid 413 response");
                    return Ok(resp);
                }
                head = head_rx => {
                    head.map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
                }
            }
        } else {
            head_rx
                .await
                .map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
        };

        if let Some(upgrade_fut) = upgrade_future {
            let req_body_writer =
                duplex_req_body_writer.expect("duplex path always has req_body_writer");

            // The handler declined the upgrade: abandon it and return its
            // status and headers with an empty body.
            if response_head.status != StatusCode::SWITCHING_PROTOCOLS.as_u16() {
                drop(upgrade_fut);
                drop(req_body_writer);
                let mut resp_builder = hyper::Response::builder().status(response_head.status);
                for (k, v) in &response_head.headers {
                    resp_builder = resp_builder.header(k.as_str(), v.as_str());
                }
                let resp = resp_builder
                    .body(crate::box_body(http_body_util::Empty::new()))
                    .map_err(|e| -> BoxError { e.into() })?;
                return Ok(resp);
            }

            // The upgrade resolves only after the 101 below has been sent, so
            // the duplex pump runs in its own task.
            tokio::spawn(async move {
                match upgrade_fut.await {
                    Err(e) => tracing::warn!("iroh-http: duplex upgrade error: {e}"),
                    Ok(upgraded) => {
                        let io = TokioIo::new(upgraded);
                        crate::stream::pump_duplex(io, req_body_writer, res_body_reader).await;
                    }
                }
            });

            // NOTE(review): the handler's own 101 headers are not forwarded;
            // only Connection/Upgrade are sent.
            let resp = hyper::Response::builder()
                .status(StatusCode::SWITCHING_PROTOCOLS)
                .header(hyper::header::CONNECTION, "Upgrade")
                .header(hyper::header::UPGRADE, "iroh-duplex")
                .body(crate::box_body(http_body_util::Empty::new()))
                .expect("static response args are valid");
            return Ok(resp);
        }

        // Regular response: stream the body from the response-body channel.
        let body_stream = body_from_reader(res_body_reader);

        let mut resp_builder = hyper::Response::builder().status(response_head.status);
        for (k, v) in &response_head.headers {
            resp_builder = resp_builder.header(k.as_str(), v.as_str());
        }

        // NOTE(review): no-op rebinding under the `compression` feature;
        // compression is applied by the tower layer built per connection in
        // `serve_with_events`. Candidate for removal.
        #[cfg(feature = "compression")]
        let resp_builder = resp_builder;
        let resp = resp_builder
            .body(crate::box_body(body_stream))
            .map_err(|e| -> BoxError { e.into() })?;

        Ok(resp)
    }
}
532}
533
534#[inline]
535#[allow(clippy::too_many_arguments)]
536fn on_request_fire(
537 cb: &Arc<dyn Fn(RequestPayload) + Send + Sync>,
538 req_handle: u64,
539 req_body_handle: u64,
540 res_body_handle: u64,
541 method: String,
542 url: String,
543 headers: Vec<(String, String)>,
544 remote_node_id: String,
545 is_bidi: bool,
546) {
547 cb(RequestPayload {
548 req_handle,
549 req_body_handle,
550 res_body_handle,
551 method,
552 url,
553 headers,
554 remote_node_id,
555 is_bidi,
556 });
557}
558
559pub fn serve<F>(endpoint: IrohEndpoint, options: ServeOptions, on_request: F) -> ServeHandle
587where
588 F: Fn(RequestPayload) + Send + Sync + 'static,
589{
590 serve_with_events(endpoint, options, on_request, None)
591}
592
593pub fn serve_with_events<F>(
598 endpoint: IrohEndpoint,
599 options: ServeOptions,
600 on_request: F,
601 on_connection_event: Option<ConnectionEventFn>,
602) -> ServeHandle
603where
604 F: Fn(RequestPayload) + Send + Sync + 'static,
605{
606 let max = options.max_concurrency.unwrap_or(DEFAULT_CONCURRENCY);
607 let max_errors = options.max_consecutive_errors.unwrap_or(5);
608 let request_timeout = options
609 .request_timeout_ms
610 .map(Duration::from_millis)
611 .unwrap_or(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS));
612 let max_conns_per_peer = options
613 .max_connections_per_peer
614 .unwrap_or(DEFAULT_MAX_CONNECTIONS_PER_PEER);
615 let max_request_body_bytes = options
616 .max_request_body_bytes
617 .or(Some(DEFAULT_MAX_REQUEST_BODY_BYTES));
618 let max_total_connections = options.max_total_connections;
619 let drain_timeout = Duration::from_secs(
620 options
621 .drain_timeout_secs
622 .unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS),
623 );
624 let load_shed_enabled = options.load_shed.unwrap_or(true);
626 let max_header_size = endpoint.max_header_size();
627 #[cfg(feature = "compression")]
628 let compression = endpoint.compression().cloned();
629 let own_node_id = Arc::new(endpoint.node_id().to_string());
630 let on_request = Arc::new(on_request) as Arc<dyn Fn(RequestPayload) + Send + Sync>;
631
632 let peer_counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>> =
633 Arc::new(Mutex::new(HashMap::new()));
634 let conn_event_fn: Option<ConnectionEventFn> = on_connection_event;
635
636 let in_flight: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
639 let drain_notify: Arc<tokio::sync::Notify> = Arc::new(tokio::sync::Notify::new());
640
641 let base_svc = RequestService {
642 on_request,
643 endpoint: endpoint.clone(),
644 own_node_id,
645 remote_node_id: None,
646 max_request_body_bytes,
647 max_header_size: if max_header_size == 0 {
648 None
649 } else {
650 Some(max_header_size)
651 },
652 #[cfg(feature = "compression")]
653 compression,
654 };
655
656 use tower::{limit::ConcurrencyLimitLayer, Layer};
657 let shared_conc = ConcurrencyLimitLayer::new(max).layer(base_svc);
661
662 let shutdown_notify = Arc::new(tokio::sync::Notify::new());
663 let shutdown_listen = shutdown_notify.clone();
664 let drain_dur = drain_timeout;
665 let total_connections = endpoint.inner.active_connections.clone();
668 let total_requests = endpoint.inner.active_requests.clone();
669 let (done_tx, done_rx) = tokio::sync::watch::channel(false);
670 let endpoint_closed_tx = endpoint.inner.closed_tx.clone();
671
672 let in_flight_drain = in_flight.clone();
673 let drain_notify_drain = drain_notify.clone();
674
675 let join = tokio::spawn(async move {
676 let ep = endpoint.raw().clone();
677 let mut consecutive_errors: usize = 0;
678
679 loop {
680 let incoming = tokio::select! {
681 biased;
682 _ = shutdown_listen.notified() => {
683 tracing::info!("iroh-http: serve loop shutting down");
684 break;
685 }
686 inc = ep.accept() => match inc {
687 Some(i) => i,
688 None => {
689 tracing::info!("iroh-http: endpoint closed (accept returned None)");
690 let _ = endpoint_closed_tx.send(true);
691 break;
692 }
693 }
694 };
695
696 let conn = match incoming.await {
697 Ok(c) => {
698 consecutive_errors = 0;
699 c
700 }
701 Err(e) => {
702 consecutive_errors = consecutive_errors.saturating_add(1);
703 tracing::warn!(
704 "iroh-http: accept error ({consecutive_errors}/{max_errors}): {e}"
705 );
706 if consecutive_errors >= max_errors {
707 tracing::error!("iroh-http: too many accept errors — shutting down");
708 break;
709 }
710 continue;
711 }
712 };
713
714 let remote_pk = conn.remote_id();
715
716 if let Some(max_total) = max_total_connections {
718 let current = total_connections.load(Ordering::Relaxed);
719 if current >= max_total {
720 tracing::warn!(
721 "iroh-http: total connection limit reached ({current}/{max_total})"
722 );
723 conn.close(0u32.into(), b"server at capacity");
724 continue;
725 }
726 }
727
728 let remote_id = base32_encode(remote_pk.as_bytes());
729
730 let guard = match PeerConnectionGuard::acquire(
731 &peer_counts,
732 remote_pk,
733 remote_id.clone(),
734 max_conns_per_peer,
735 conn_event_fn.clone(),
736 ) {
737 Some(g) => g,
738 None => {
739 tracing::warn!("iroh-http: peer {remote_id} exceeded connection limit");
740 conn.close(0u32.into(), b"too many connections");
741 continue;
742 }
743 };
744
745 let mut conn_conc = shared_conc.clone();
746 conn_conc.get_mut().remote_node_id = Some(remote_id);
747
748 let timeout_dur = if request_timeout.is_zero() {
749 Duration::MAX
750 } else {
751 request_timeout
752 };
753
754 let conn_total = total_connections.clone();
755 let conn_requests = total_requests.clone();
756 let in_flight_conn = in_flight.clone();
757 let drain_notify_conn = drain_notify.clone();
758 conn_total.fetch_add(1, Ordering::Relaxed);
759 tokio::spawn(async move {
760 let _guard = guard;
761 struct TotalGuard(Arc<AtomicUsize>);
763 impl Drop for TotalGuard {
764 fn drop(&mut self) {
765 self.0.fetch_sub(1, Ordering::Relaxed);
766 }
767 }
768 let _total_guard = TotalGuard(conn_total);
769
770 loop {
771 let (send, recv) = match conn.accept_bi().await {
772 Ok(pair) => pair,
773 Err(_) => break,
774 };
775
776 let io = TokioIo::new(IrohStream::new(send, recv));
777 let svc = conn_conc.clone();
778 let req_counter = conn_requests.clone();
779 req_counter.fetch_add(1, Ordering::Relaxed);
780 in_flight_conn.fetch_add(1, Ordering::Relaxed);
781
782 let in_flight_req = in_flight_conn.clone();
783 let drain_notify_req = drain_notify_conn.clone();
784
785 tokio::spawn(async move {
786 struct ReqGuard {
788 counter: Arc<AtomicUsize>,
789 in_flight: Arc<AtomicUsize>,
790 drain_notify: Arc<tokio::sync::Notify>,
791 }
792 impl Drop for ReqGuard {
793 fn drop(&mut self) {
794 self.counter.fetch_sub(1, Ordering::Relaxed);
795 if self.in_flight.fetch_sub(1, Ordering::AcqRel) == 1 {
796 self.drain_notify.notify_waiters();
798 }
799 }
800 }
801 let _req_guard = ReqGuard {
802 counter: req_counter,
803 in_flight: in_flight_req,
804 drain_notify: drain_notify_req,
805 };
806 let effective_header_limit = if max_header_size == 0 {
809 64 * 1024
810 } else {
811 max_header_size.max(8192)
812 };
813
814 use tower::{timeout::TimeoutLayer, ServiceBuilder};
829
830 #[cfg(feature = "compression")]
831 let result = {
832 use http::{Extensions, HeaderMap, Version};
833 use tower_http::compression::{
834 predicate::{Predicate, SizeAbove},
835 CompressionLayer,
836 };
837
838 let compression_config = svc.get_ref().compression.clone();
839 if let Some(comp) = &compression_config {
840 let min_bytes = comp.min_body_bytes;
841 let mut layer = CompressionLayer::new().zstd(true);
842 if let Some(level) = comp.level {
843 use tower_http::compression::CompressionLevel;
844 layer = layer.quality(CompressionLevel::Precise(level as i32));
845 }
846 let not_pre_compressed =
847 |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
848 !h.contains_key(http::header::CONTENT_ENCODING)
849 };
850 let not_no_transform =
851 |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
852 h.get(http::header::CACHE_CONTROL)
853 .and_then(|v| v.to_str().ok())
854 .map(|v| {
855 !v.split(',').any(|d| {
856 d.trim().eq_ignore_ascii_case("no-transform")
857 })
858 })
859 .unwrap_or(true)
860 };
861 let predicate =
862 SizeAbove::new(min_bytes.min(u16::MAX as usize) as u16)
863 .and(not_pre_compressed)
864 .and(not_no_transform);
865 if load_shed_enabled {
866 use tower::load_shed::LoadShedLayer;
867 let stk = TowerErrorHandler(
868 ServiceBuilder::new()
869 .layer(LoadShedLayer::new())
870 .layer(TimeoutLayer::new(timeout_dur))
871 .service(svc),
872 );
873 hyper::server::conn::http1::Builder::new()
874 .max_buf_size(effective_header_limit)
875 .max_headers(128)
876 .serve_connection(
877 io,
878 TowerToHyperService::new(
879 ServiceBuilder::new()
880 .layer(layer.compress_when(predicate))
881 .service(stk),
882 ),
883 )
884 .with_upgrades()
885 .await
886 } else {
887 let stk = TowerErrorHandler(
888 ServiceBuilder::new()
889 .layer(TimeoutLayer::new(timeout_dur))
890 .service(svc),
891 );
892 hyper::server::conn::http1::Builder::new()
893 .max_buf_size(effective_header_limit)
894 .max_headers(128)
895 .serve_connection(
896 io,
897 TowerToHyperService::new(
898 ServiceBuilder::new()
899 .layer(layer.compress_when(predicate))
900 .service(stk),
901 ),
902 )
903 .with_upgrades()
904 .await
905 }
906 } else if load_shed_enabled {
907 use tower::load_shed::LoadShedLayer;
908 let stk = TowerErrorHandler(
909 ServiceBuilder::new()
910 .layer(LoadShedLayer::new())
911 .layer(TimeoutLayer::new(timeout_dur))
912 .service(svc),
913 );
914 hyper::server::conn::http1::Builder::new()
915 .max_buf_size(effective_header_limit)
916 .max_headers(128)
917 .serve_connection(io, TowerToHyperService::new(stk))
918 .with_upgrades()
919 .await
920 } else {
921 let stk = TowerErrorHandler(
922 ServiceBuilder::new()
923 .layer(TimeoutLayer::new(timeout_dur))
924 .service(svc),
925 );
926 hyper::server::conn::http1::Builder::new()
927 .max_buf_size(effective_header_limit)
928 .max_headers(128)
929 .serve_connection(io, TowerToHyperService::new(stk))
930 .with_upgrades()
931 .await
932 }
933 };
934 #[cfg(not(feature = "compression"))]
935 let result = if load_shed_enabled {
936 use tower::load_shed::LoadShedLayer;
937 let stk = TowerErrorHandler(
938 ServiceBuilder::new()
939 .layer(LoadShedLayer::new())
940 .layer(TimeoutLayer::new(timeout_dur))
941 .service(svc),
942 );
943 hyper::server::conn::http1::Builder::new()
944 .max_buf_size(effective_header_limit)
945 .max_headers(128)
946 .serve_connection(io, TowerToHyperService::new(stk))
947 .with_upgrades()
948 .await
949 } else {
950 let stk = TowerErrorHandler(
951 ServiceBuilder::new()
952 .layer(TimeoutLayer::new(timeout_dur))
953 .service(svc),
954 );
955 hyper::server::conn::http1::Builder::new()
956 .max_buf_size(effective_header_limit)
957 .max_headers(128)
958 .serve_connection(io, TowerToHyperService::new(stk))
959 .with_upgrades()
960 .await
961 };
962
963 if let Err(e) = result {
964 tracing::debug!("iroh-http: http1 connection error: {e}");
965 }
966 });
967 }
968 });
969 }
970
971 let deadline = tokio::time::Instant::now()
978 .checked_add(drain_dur)
979 .expect("drain duration overflow");
980 loop {
981 if in_flight_drain.load(Ordering::Acquire) == 0 {
982 tracing::info!("iroh-http: all in-flight requests drained");
983 break;
984 }
985 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
986 if remaining.is_zero() {
987 tracing::warn!("iroh-http: drain timed out after {}s", drain_dur.as_secs());
988 break;
989 }
990 tokio::select! {
991 _ = drain_notify_drain.notified() => {}
992 _ = tokio::time::sleep(remaining) => {}
993 }
994 }
995 let _ = done_tx.send(true);
996 });
997
998 ServeHandle {
999 join,
1000 shutdown_notify,
1001 drain_timeout: drain_dur,
1002 done_rx,
1003 }
1004}
1005
1006#[derive(Clone)]
1021struct TowerErrorHandler<S>(S);
1022
1023impl<S, Req> Service<Req> for TowerErrorHandler<S>
1024where
1025 S: Service<Req, Response = hyper::Response<BoxBody>>,
1026 S::Error: Into<BoxError>,
1027 S::Future: Send + 'static,
1028{
1029 type Response = hyper::Response<BoxBody>;
1030 type Error = BoxError;
1031 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1032
1033 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1034 self.0.poll_ready(cx).map_err(Into::into)
1039 }
1040
1041 fn call(&mut self, req: Req) -> Self::Future {
1042 let fut = self.0.call(req);
1043 Box::pin(async move {
1044 match fut.await {
1045 Ok(r) => Ok(r),
1046 Err(e) => {
1047 let e = e.into();
1048 let status = if e.is::<tower::timeout::error::Elapsed>() {
1049 StatusCode::REQUEST_TIMEOUT
1050 } else if e.is::<tower::load_shed::error::Overloaded>() {
1051 StatusCode::SERVICE_UNAVAILABLE
1052 } else {
1053 tracing::warn!("iroh-http: unexpected tower error: {e}");
1054 StatusCode::INTERNAL_SERVER_ERROR
1055 };
1056 let body_bytes: &'static [u8] = match status {
1057 StatusCode::REQUEST_TIMEOUT => b"request timed out",
1058 StatusCode::SERVICE_UNAVAILABLE => b"server at capacity",
1059 _ => b"internal server error",
1060 };
1061 Ok(hyper::Response::builder()
1062 .status(status)
1063 .body(crate::box_body(http_body_util::Full::new(
1064 Bytes::from_static(body_bytes),
1065 )))
1066 .expect("valid error response"))
1067 }
1068 }
1069 })
1070 }
1071}