Skip to main content

iroh_http_core/
server.rs

1//! Incoming HTTP request — `serve()` implementation.
2//!
3//! Each accepted QUIC bidirectional stream is driven by hyper's HTTP/1.1
4//! server connection.  A `tower::Service` (`RequestService`) bridges between
5//! hyper and the existing body-channel + slab infrastructure.
6
7use std::{
8    collections::HashMap,
9    future::Future,
10    pin::Pin,
11    sync::{
12        atomic::{AtomicUsize, Ordering},
13        Arc, Mutex,
14    },
15    task::{Context, Poll},
16    time::Duration,
17};
18
19use bytes::Bytes;
20use http::{HeaderName, HeaderValue, StatusCode};
21use hyper::body::Incoming;
22use hyper_util::rt::TokioIo;
23use hyper_util::service::TowerToHyperService;
24use tower::Service;
25
26use crate::{
27    base32_encode,
28    client::{body_from_reader, pump_hyper_body_to_channel_limited},
29    io::IrohStream,
30    stream::{HandleStore, ResponseHeadEntry},
31    ConnectionEvent, CoreError, IrohEndpoint, RequestPayload,
32};
33
34// ── Type aliases ──────────────────────────────────────────────────────────────
35
/// Boxed response-body type, re-exported from the crate root for brevity.
type BoxBody = crate::BoxBody;
/// Boxed error type used throughout the hyper/tower service plumbing.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
38
39// ── ServerLimits ──────────────────────────────────────────────────────────────
40
/// Server-side limits shared between [`NodeOptions`](crate::NodeOptions) and
/// the serve path.
///
/// Embedding this struct in both `NodeOptions` and `EndpointInner` guarantees
/// that adding a new limit field produces a compile error if only one side is
/// updated.
///
/// Every field is optional; `None` falls back to the default applied in
/// `serve_with_events` (see the `DEFAULT_*` constants below).
#[derive(Debug, Clone, Default)]
pub struct ServerLimits {
    /// Maximum concurrently processed requests (default 64).
    pub max_concurrency: Option<usize>,
    /// Consecutive accept errors tolerated before the serve loop shuts
    /// itself down (default 5).
    pub max_consecutive_errors: Option<usize>,
    /// Per-request timeout in milliseconds (default 60 000; 0 disables it).
    pub request_timeout_ms: Option<u64>,
    /// Simultaneous connections allowed per peer (default 8).
    pub max_connections_per_peer: Option<usize>,
    /// Reject request bodies larger than this; `None` means unlimited.
    pub max_request_body_bytes: Option<usize>,
    /// Graceful-drain budget in seconds (default 30).
    pub drain_timeout_secs: Option<u64>,
    /// Cap on total open connections across all peers; `None` means unlimited.
    pub max_total_connections: Option<usize>,
    /// When `true` (the default), reject new requests immediately with `503
    /// Service Unavailable` when `max_concurrency` is already reached rather
    /// than queuing them.  Prevents thundering-herd on recovery.
    pub load_shed: Option<bool>,
}
61
/// Backward-compatible alias — existing code that names `ServeOptions` keeps
/// compiling without changes.
pub type ServeOptions = ServerLimits;

// Defaults applied in `serve_with_events` when the corresponding
// `ServerLimits` field is `None`.
const DEFAULT_CONCURRENCY: usize = 64;
const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 60_000;
const DEFAULT_MAX_CONNECTIONS_PER_PEER: usize = 8;
const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 30;
70
71// ── ServeHandle ───────────────────────────────────────────────────────────────
72
/// Handle to a running serve loop: supports graceful shutdown, hard abort,
/// and observing completion.
pub struct ServeHandle {
    /// Join handle of the spawned serve/accept task.
    join: tokio::task::JoinHandle<()>,
    /// Signalled by [`ServeHandle::shutdown`]; the serve loop selects on it.
    shutdown_notify: Arc<tokio::sync::Notify>,
    /// Drain budget configured via `ServerLimits::drain_timeout_secs`.
    drain_timeout: std::time::Duration,
    /// Resolves to `true` once the serve task has fully exited.
    done_rx: tokio::sync::watch::Receiver<bool>,
}
80
81impl ServeHandle {
82    pub fn shutdown(&self) {
83        self.shutdown_notify.notify_one();
84    }
85    pub async fn drain(self) {
86        self.shutdown();
87        let _ = self.join.await;
88    }
89    pub fn abort(&self) {
90        self.join.abort();
91    }
92    pub fn drain_timeout(&self) -> std::time::Duration {
93        self.drain_timeout
94    }
95    /// Subscribe to the serve-loop-done signal.
96    ///
97    /// The returned receiver resolves (changes to `true`) once the serve task
98    /// has fully exited, including the drain phase.
99    pub fn subscribe_done(&self) -> tokio::sync::watch::Receiver<bool> {
100        self.done_rx.clone()
101    }
102}
103
104// ── respond() ────────────────────────────────────────────────────────────────
105
106pub fn respond(
107    handles: &HandleStore,
108    req_handle: u64,
109    status: u16,
110    headers: Vec<(String, String)>,
111) -> Result<(), CoreError> {
112    StatusCode::from_u16(status)
113        .map_err(|_| CoreError::invalid_input(format!("invalid HTTP status code: {status}")))?;
114    for (name, value) in &headers {
115        HeaderName::from_bytes(name.as_bytes()).map_err(|_| {
116            CoreError::invalid_input(format!("invalid response header name {:?}", name))
117        })?;
118        HeaderValue::from_str(value).map_err(|_| {
119            CoreError::invalid_input(format!("invalid response header value for {:?}", name))
120        })?;
121    }
122
123    let sender = handles
124        .take_req_sender(req_handle)
125        .ok_or_else(|| CoreError::invalid_handle(req_handle))?;
126    sender
127        .send(ResponseHeadEntry { status, headers })
128        .map_err(|_| CoreError::internal("serve task dropped before respond"))
129}
130
131// ── PeerConnectionGuard ───────────────────────────────────────────────────────
132
/// Callback fired on peer connection-count transitions: 0 → 1 (connected)
/// and 1 → 0 (disconnected).
type ConnectionEventFn = Arc<dyn Fn(ConnectionEvent) + Send + Sync>;
134
/// RAII guard for one accepted connection: increments the per-peer count on
/// `acquire` and decrements it on `Drop`, firing connect/disconnect events
/// at the 0↔1 boundaries.
struct PeerConnectionGuard {
    /// Shared per-peer connection counters (keyed by public key).
    counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
    peer: iroh::PublicKey,
    /// Base32-encoded peer id used in emitted `ConnectionEvent`s.
    peer_id_str: String,
    on_event: Option<ConnectionEventFn>,
}
141
impl PeerConnectionGuard {
    /// Try to register one additional connection from `peer`.
    ///
    /// Returns `None` (caller should close the connection) when the peer is
    /// already at `max` simultaneous connections; otherwise increments the
    /// count and returns a guard whose `Drop` decrements it.
    fn acquire(
        counts: &Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
        peer: iroh::PublicKey,
        peer_id_str: String,
        max: usize,
        on_event: Option<ConnectionEventFn>,
    ) -> Option<Self> {
        // Recover from a poisoned lock: the map only holds counters, so the
        // data stays usable even if another thread panicked mid-update.
        let mut map = counts.lock().unwrap_or_else(|e| e.into_inner());
        let count = map.entry(peer).or_insert(0);
        if *count >= max {
            return None;
        }
        let was_zero = *count == 0;
        *count = count.saturating_add(1);
        let guard = PeerConnectionGuard {
            counts: counts.clone(),
            peer,
            peer_id_str: peer_id_str.clone(),
            on_event: on_event.clone(),
        };
        // Fire connected event on 0 → 1 transition (first connection from this peer).
        // NOTE(review): `map` is still locked while the callback runs, which
        // keeps connect/disconnect events strictly ordered — but a callback
        // that re-enters acquire/Drop would deadlock; confirm callbacks don't.
        if was_zero {
            if let Some(cb) = &on_event {
                cb(ConnectionEvent {
                    peer_id: peer_id_str,
                    connected: true,
                });
            }
        }
        Some(guard)
    }
}
175
impl Drop for PeerConnectionGuard {
    /// Decrement the per-peer count; on the 1 → 0 transition remove the map
    /// entry and fire the disconnect callback.
    fn drop(&mut self) {
        // Ignore poisoning — counters remain meaningful (same as `acquire`).
        let mut map = self.counts.lock().unwrap_or_else(|e| e.into_inner());
        if let Some(c) = map.get_mut(&self.peer) {
            *c = c.saturating_sub(1);
            if *c == 0 {
                map.remove(&self.peer);
                // Fire disconnected event on 1 → 0 transition (last connection from this peer closed).
                // NOTE(review): runs with the counts mutex held — same
                // ordering/reentrancy tradeoff as in `acquire`.
                if let Some(cb) = &self.on_event {
                    cb(ConnectionEvent {
                        peer_id: self.peer_id_str.clone(),
                        connected: false,
                    });
                }
            }
        }
    }
}
194
195// ── RequestService ────────────────────────────────────────────────────────────
196
/// Per-connection `tower::Service` bridging hyper requests to the
/// `on_request` callback and the body-channel/slab infrastructure.
#[derive(Clone)]
struct RequestService {
    /// Invoked once per request with the allocated handles (see `handle`).
    on_request: Arc<dyn Fn(RequestPayload) + Send + Sync>,
    endpoint: IrohEndpoint,
    /// This node's id — used to synthesize the `httpi://` request URL.
    own_node_id: Arc<String>,
    /// Authenticated peer id; `None` on the prototype service, set per
    /// connection by the accept loop.
    remote_node_id: Option<String>,
    /// Request bodies larger than this are rejected with 413.
    max_request_body_bytes: Option<usize>,
    /// Header sections larger than this are rejected with 431.
    max_header_size: Option<usize>,
    #[cfg(feature = "compression")]
    compression: Option<crate::endpoint::CompressionOptions>,
}
208
209impl Service<hyper::Request<Incoming>> for RequestService {
210    type Response = hyper::Response<BoxBody>;
211    type Error = BoxError;
212    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
213
214    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
215        Poll::Ready(Ok(()))
216    }
217
218    fn call(&mut self, req: hyper::Request<Incoming>) -> Self::Future {
219        let svc = self.clone();
220        Box::pin(async move { svc.handle(req).await })
221    }
222}
223
impl RequestService {
    /// Drive one hyper request end-to-end: enforce limits, allocate the
    /// body/trailer/head handles, fire `on_request` toward the JS handler,
    /// await its response head, and build the hyper response (including the
    /// duplex-upgrade path).
    async fn handle(
        self,
        mut req: hyper::Request<Incoming>,
    ) -> Result<hyper::Response<BoxBody>, BoxError> {
        let handles = self.endpoint.handles();
        let own_node_id = &*self.own_node_id;
        let remote_node_id = self.remote_node_id.clone().unwrap_or_default();
        let max_request_body_bytes = self.max_request_body_bytes;
        let max_header_size = self.max_header_size;

        let method = req.method().to_string();
        let path_and_query = req
            .uri()
            .path_and_query()
            .map(|p| p.as_str())
            .unwrap_or("/")
            .to_string();

        tracing::debug!(
            method = %method,
            path = %path_and_query,
            peer = %remote_node_id,
            "iroh-http: incoming request",
        );
        // Strip any client-supplied peer-id to prevent spoofing,
        // then inject the authenticated identity from the QUIC connection.
        //
        // ISS-011: Use raw byte length for header-size accounting to prevent
        // bypass via non-UTF8 values.  Reject non-UTF8 header values with 400
        // instead of silently converting them to empty strings.

        // First pass: measure header bytes using raw values (before lossy conversion).
        if let Some(limit) = max_header_size {
            // Accounting mirrors the wire format: name + value + ": " + "\r\n"
            // per header, plus the injected peer-id header and request line.
            let header_bytes: usize = req
                .headers()
                .iter()
                .filter(|(k, _)| !k.as_str().eq_ignore_ascii_case("peer-id"))
                .map(|(k, v)| {
                    k.as_str()
                        .len()
                        .saturating_add(v.as_bytes().len())
                        .saturating_add(4)
                }) // ": " + "\r\n"
                .fold(0usize, |acc, x| acc.saturating_add(x))
                .saturating_add("peer-id".len())
                .saturating_add(remote_node_id.len())
                .saturating_add(4)
                .saturating_add(req.uri().to_string().len())
                .saturating_add(method.len())
                .saturating_add(12); // "HTTP/1.1 \r\n\r\n" overhead
            if header_bytes > limit {
                let resp = hyper::Response::builder()
                    .status(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE)
                    .body(crate::box_body(http_body_util::Empty::new()))
                    .expect("static response args are valid");
                return Ok(resp);
            }
        }

        // Build header list — reject non-UTF8 values instead of silently dropping.
        let mut req_headers: Vec<(String, String)> = Vec::new();
        for (k, v) in req.headers().iter() {
            if k.as_str().eq_ignore_ascii_case("peer-id") {
                continue;
            }
            match v.to_str() {
                Ok(s) => req_headers.push((k.as_str().to_string(), s.to_string())),
                Err(_) => {
                    let resp = hyper::Response::builder()
                        .status(StatusCode::BAD_REQUEST)
                        .body(crate::box_body(http_body_util::Full::new(
                            Bytes::from_static(b"non-UTF8 header value"),
                        )))
                        .expect("static response args are valid");
                    return Ok(resp);
                }
            }
        }
        req_headers.push(("peer-id".to_string(), remote_node_id.clone()));

        // Absolute URL synthesized for the handler from our own node id.
        let url = format!("httpi://{own_node_id}{path_and_query}");

        // ISS-015: strict duplex upgrade validation — require CONNECT method +
        // Upgrade: iroh-duplex + Connection: upgrade headers.
        let has_upgrade_header = req_headers.iter().any(|(k, v)| {
            k.eq_ignore_ascii_case("upgrade") && v.eq_ignore_ascii_case("iroh-duplex")
        });
        let has_connection_upgrade = req_headers.iter().any(|(k, v)| {
            k.eq_ignore_ascii_case("connection")
                && v.split(',')
                    .any(|tok| tok.trim().eq_ignore_ascii_case("upgrade"))
        });
        let is_connect = req.method() == http::Method::CONNECT;

        let is_bidi = if has_upgrade_header {
            if !has_connection_upgrade || !is_connect {
                let resp = hyper::Response::builder()
                    .status(StatusCode::BAD_REQUEST)
                    .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
                        b"duplex upgrade requires CONNECT method with Connection: upgrade header",
                    ))))
                    .expect("static response args are valid");
                return Ok(resp);
            }
            true
        } else {
            false
        };

        // For duplex: capture the upgrade future BEFORE consuming the request.
        let upgrade_future = if is_bidi {
            Some(hyper::upgrade::on(&mut req))
        } else {
            None
        };

        // ── Allocate channels ────────────────────────────────────────────────

        // Request body: writer pumped from hyper; reader given to JS.
        let mut guard = handles.insert_guard();
        let (req_body_writer, req_body_reader) = handles.make_body_channel();
        let req_body_handle = guard
            .insert_reader(req_body_reader)
            .map_err(|e| -> BoxError { e.into() })?;

        // Response body: writer given to JS (sendChunk); reader feeds hyper response.
        let (res_body_writer, res_body_reader) = handles.make_body_channel();
        let res_body_handle = guard
            .insert_writer(res_body_writer)
            .map_err(|e| -> BoxError { e.into() })?;

        // ── Trailer channels (non-duplex only) ───────────────────────────────

        let (req_trailers_handle, res_trailers_handle, req_trailer_tx, opt_res_trailer_rx) =
            if !is_bidi {
                // Request trailers: pump delivers them; JS reads via nextTrailer.
                let (rq_tx, rq_rx) = tokio::sync::oneshot::channel::<Vec<(String, String)>>();
                let rq_h = guard
                    .insert_trailer_receiver(rq_rx)
                    .map_err(|e| -> BoxError { e.into() })?;
                // Response trailers: JS delivers via sendTrailers; pump appends to body.
                let (rs_tx, rs_rx) = tokio::sync::oneshot::channel::<Vec<(String, String)>>();
                let rs_h = guard
                    .insert_trailer_sender(rs_tx)
                    .map_err(|e| -> BoxError { e.into() })?;
                (rq_h, rs_h, Some(rq_tx), Some(rs_rx))
            } else {
                // Duplex requests carry no trailers; 0 is a sentinel handle.
                (0u64, 0u64, None, None)
            };

        // ── Allocate response-head rendezvous ────────────────────────────────

        let (head_tx, head_rx) = tokio::sync::oneshot::channel::<ResponseHeadEntry>();
        let req_handle = guard
            .allocate_req_handle(head_tx)
            .map_err(|e| -> BoxError { e.into() })?;

        // All slab insertions succeeded — publish the handles.
        guard.commit();

        // RAII guard: remove the req_handle slab entry on all exit paths
        // (413 early-return, timeout drop, "JS handler dropped", normal completion).
        // If respond() already consumed the entry, take_req_sender returns None — safe no-op.
        struct ReqHeadCleanup {
            endpoint: IrohEndpoint,
            req_handle: u64,
        }
        impl Drop for ReqHeadCleanup {
            fn drop(&mut self) {
                self.endpoint.handles().take_req_sender(self.req_handle);
            }
        }
        let _req_head_cleanup = ReqHeadCleanup {
            endpoint: self.endpoint.clone(),
            req_handle,
        };

        // ── Pump request body ────────────────────────────────────────────────

        // For duplex: keep req_body_writer to move into the upgrade spawn below.
        // For regular: consume it immediately into the pump task.
        // ISS-004: create an overflow channel so the serve path can return 413.
        let (body_overflow_tx, body_overflow_rx) = if !is_bidi && max_request_body_bytes.is_some() {
            let (tx, rx) = tokio::sync::oneshot::channel::<()>();
            (Some(tx), Some(rx))
        } else {
            (None, None)
        };

        let duplex_req_body_writer = if !is_bidi {
            let body = req.into_body();
            let trailer_tx = req_trailer_tx.expect("non-duplex has req_trailer_tx");
            let frame_timeout = handles.drain_timeout();
            tokio::spawn(pump_hyper_body_to_channel_limited(
                body,
                req_body_writer,
                trailer_tx,
                max_request_body_bytes,
                frame_timeout,
                body_overflow_tx,
            ));
            None
        } else {
            // Duplex: discard the HTTP preamble body (empty before 101).
            drop(req.into_body());
            Some(req_body_writer)
        };

        // ── Fire on_request callback ─────────────────────────────────────────

        on_request_fire(
            &self.on_request,
            req_handle,
            req_body_handle,
            res_body_handle,
            req_trailers_handle,
            res_trailers_handle,
            method,
            url,
            req_headers,
            remote_node_id,
            is_bidi,
        );

        // ── Await response head from JS (race against body overflow) ─────────
        //
        // ISS-004: if the request body exceeds maxRequestBodyBytes, return 413
        // immediately without waiting for the JS handler to respond.

        let response_head = if let Some(overflow_rx) = body_overflow_rx {
            // `biased` so an already-signalled overflow wins over the head.
            tokio::select! {
                biased;
                _ = overflow_rx => {
                    // Body too large: ReqHeadCleanup RAII guard will remove the slab
                    // entry when this function exits (issue-7 fix).
                    let resp = hyper::Response::builder()
                        .status(StatusCode::PAYLOAD_TOO_LARGE)
                        .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
                            b"request body too large",
                        ))))
                        .expect("valid 413 response");
                    return Ok(resp);
                }
                head = head_rx => {
                    head.map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
                }
            }
        } else {
            head_rx
                .await
                .map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
        };

        // ── Duplex path: honor handler status, upgrade only on 101 ──────────────
        //
        // ISS-002: the handler may reject the duplex request by returning any
        // non-101 status.  Only perform the QUIC stream pump when the handler
        // explicitly returns 101 Switching Protocols.

        if let Some(upgrade_fut) = upgrade_future {
            let req_body_writer =
                duplex_req_body_writer.expect("duplex path always has req_body_writer");

            // If the handler returned a non-101 status, send that response and
            // do NOT perform the upgrade.  Drop the upgrade future and writer.
            if response_head.status != StatusCode::SWITCHING_PROTOCOLS.as_u16() {
                drop(upgrade_fut);
                drop(req_body_writer);
                let mut resp_builder = hyper::Response::builder().status(response_head.status);
                for (k, v) in &response_head.headers {
                    resp_builder = resp_builder.header(k.as_str(), v.as_str());
                }
                let resp = resp_builder
                    .body(crate::box_body(http_body_util::Empty::new()))
                    .map_err(|e| -> BoxError { e.into() })?;
                return Ok(resp);
            }

            // Spawn the upgrade pump after hyper delivers the 101.
            //
            // Both directions are wired to the channels already sent to JS:
            //   recv_io → req_body_writer  (JS reads via req_body_handle)
            //   res_body_reader → send_io  (JS writes via res_body_handle)
            tokio::spawn(async move {
                match upgrade_fut.await {
                    Err(e) => tracing::warn!("iroh-http: duplex upgrade error: {e}"),
                    Ok(upgraded) => {
                        let io = TokioIo::new(upgraded);
                        crate::stream::pump_duplex(io, req_body_writer, res_body_reader).await;
                    }
                }
            });

            // ISS-015: emit both Connection and Upgrade headers in 101 response.
            let resp = hyper::Response::builder()
                .status(StatusCode::SWITCHING_PROTOCOLS)
                .header(hyper::header::CONNECTION, "Upgrade")
                .header(hyper::header::UPGRADE, "iroh-duplex")
                .body(crate::box_body(http_body_util::Empty::new()))
                .expect("static response args are valid");
            return Ok(resp);
        }

        // ── Regular HTTP response ─────────────────────────────────────────────

        // Only stream response trailers when the handler declared them via a
        // `Trailer` header; otherwise drop the sender so JS sees it closed.
        let has_trailer_hdr = response_head
            .headers
            .iter()
            .any(|(k, _)| k.eq_ignore_ascii_case("trailer"));
        let trailer_rx_for_body = if has_trailer_hdr {
            opt_res_trailer_rx
        } else {
            handles.remove_trailer_sender(res_trailers_handle);
            None
        };

        let body_stream = body_from_reader(res_body_reader, trailer_rx_for_body);

        let mut resp_builder = hyper::Response::builder().status(response_head.status);
        for (k, v) in &response_head.headers {
            resp_builder = resp_builder.header(k.as_str(), v.as_str());
        }

        #[cfg(feature = "compression")]
        let resp_builder = resp_builder; // CompressionLayer in ServiceBuilder handles this

        let resp = resp_builder
            .body(crate::box_body(body_stream))
            .map_err(|e| -> BoxError { e.into() })?;

        Ok(resp)
    }
}
557
558#[inline]
559#[allow(clippy::too_many_arguments)]
560fn on_request_fire(
561    cb: &Arc<dyn Fn(RequestPayload) + Send + Sync>,
562    req_handle: u64,
563    req_body_handle: u64,
564    res_body_handle: u64,
565    req_trailers_handle: u64,
566    res_trailers_handle: u64,
567    method: String,
568    url: String,
569    headers: Vec<(String, String)>,
570    remote_node_id: String,
571    is_bidi: bool,
572) {
573    cb(RequestPayload {
574        req_handle,
575        req_body_handle,
576        res_body_handle,
577        req_trailers_handle,
578        res_trailers_handle,
579        method,
580        url,
581        headers,
582        remote_node_id,
583        is_bidi,
584    });
585}
586
587// ── serve() ───────────────────────────────────────────────────────────────────
588
589/// Start the serve accept loop.
590///
591/// This is the 3-argument form for backward compatibility.
592/// Use `serve_with_events` to also receive peer connect/disconnect callbacks.
593pub fn serve<F>(endpoint: IrohEndpoint, options: ServeOptions, on_request: F) -> ServeHandle
594where
595    F: Fn(RequestPayload) + Send + Sync + 'static,
596{
597    serve_with_events(endpoint, options, on_request, None)
598}
599
600/// Start the serve accept loop with an optional peer connection event callback.
601///
602/// `on_connection_event` is called on 0→1 (first connection from a peer) and
603/// 1→0 (last connection from a peer closed) count transitions.
604pub fn serve_with_events<F>(
605    endpoint: IrohEndpoint,
606    options: ServeOptions,
607    on_request: F,
608    on_connection_event: Option<ConnectionEventFn>,
609) -> ServeHandle
610where
611    F: Fn(RequestPayload) + Send + Sync + 'static,
612{
613    let max = options.max_concurrency.unwrap_or(DEFAULT_CONCURRENCY);
614    let max_errors = options.max_consecutive_errors.unwrap_or(5);
615    let request_timeout = options
616        .request_timeout_ms
617        .map(Duration::from_millis)
618        .unwrap_or(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS));
619    let max_conns_per_peer = options
620        .max_connections_per_peer
621        .unwrap_or(DEFAULT_MAX_CONNECTIONS_PER_PEER);
622    let max_request_body_bytes = options.max_request_body_bytes;
623    let max_total_connections = options.max_total_connections;
624    let drain_timeout = Duration::from_secs(
625        options
626            .drain_timeout_secs
627            .unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS),
628    );
629    // Load-shed is opt-out — default `true` (reject immediately when at capacity).
630    let load_shed_enabled = options.load_shed.unwrap_or(true);
631    let max_header_size = endpoint.max_header_size();
632    #[cfg(feature = "compression")]
633    let compression = endpoint.compression().cloned();
634    let own_node_id = Arc::new(endpoint.node_id().to_string());
635    let on_request = Arc::new(on_request) as Arc<dyn Fn(RequestPayload) + Send + Sync>;
636
637    let peer_counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>> =
638        Arc::new(Mutex::new(HashMap::new()));
639    let conn_event_fn: Option<ConnectionEventFn> = on_connection_event;
640
641    // In-flight request counter: incremented on accept, decremented on drop.
642    // Used for graceful drain (wait until zero or timeout).
643    let in_flight: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
644    let drain_notify: Arc<tokio::sync::Notify> = Arc::new(tokio::sync::Notify::new());
645
646    let base_svc = RequestService {
647        on_request,
648        endpoint: endpoint.clone(),
649        own_node_id,
650        remote_node_id: None,
651        max_request_body_bytes,
652        max_header_size: if max_header_size == 0 {
653            None
654        } else {
655            Some(max_header_size)
656        },
657        #[cfg(feature = "compression")]
658        compression,
659    };
660
661    let shutdown_notify = Arc::new(tokio::sync::Notify::new());
662    let shutdown_listen = shutdown_notify.clone();
663    let drain_dur = drain_timeout;
664    // Re-use the endpoint's shared counters so that endpoint_stats() reflects
665    // the live connection and request counts at all times.
666    let total_connections = endpoint.inner.active_connections.clone();
667    let total_requests = endpoint.inner.active_requests.clone();
668    let (done_tx, done_rx) = tokio::sync::watch::channel(false);
669    let endpoint_closed_tx = endpoint.inner.closed_tx.clone();
670
671    let in_flight_drain = in_flight.clone();
672    let drain_notify_drain = drain_notify.clone();
673
674    let join = tokio::spawn(async move {
675        let ep = endpoint.raw().clone();
676        let mut consecutive_errors: usize = 0;
677
678        loop {
679            let incoming = tokio::select! {
680                biased;
681                _ = shutdown_listen.notified() => {
682                    tracing::info!("iroh-http: serve loop shutting down");
683                    break;
684                }
685                inc = ep.accept() => match inc {
686                    Some(i) => i,
687                    None => {
688                        tracing::info!("iroh-http: endpoint closed (accept returned None)");
689                        let _ = endpoint_closed_tx.send(true);
690                        break;
691                    }
692                }
693            };
694
695            let conn = match incoming.await {
696                Ok(c) => {
697                    consecutive_errors = 0;
698                    c
699                }
700                Err(e) => {
701                    consecutive_errors = consecutive_errors.saturating_add(1);
702                    tracing::warn!(
703                        "iroh-http: accept error ({consecutive_errors}/{max_errors}): {e}"
704                    );
705                    if consecutive_errors >= max_errors {
706                        tracing::error!("iroh-http: too many accept errors — shutting down");
707                        break;
708                    }
709                    continue;
710                }
711            };
712
713            let remote_pk = conn.remote_id();
714
715            // Enforce total connection limit.
716            if let Some(max_total) = max_total_connections {
717                let current = total_connections.load(Ordering::Relaxed);
718                if current >= max_total {
719                    tracing::warn!(
720                        "iroh-http: total connection limit reached ({current}/{max_total})"
721                    );
722                    conn.close(0u32.into(), b"server at capacity");
723                    continue;
724                }
725            }
726
727            let remote_id = base32_encode(remote_pk.as_bytes());
728
729            let guard = match PeerConnectionGuard::acquire(
730                &peer_counts,
731                remote_pk,
732                remote_id.clone(),
733                max_conns_per_peer,
734                conn_event_fn.clone(),
735            ) {
736                Some(g) => g,
737                None => {
738                    tracing::warn!("iroh-http: peer {remote_id} exceeded connection limit");
739                    conn.close(0u32.into(), b"too many connections");
740                    continue;
741                }
742            };
743
744            let mut peer_svc = base_svc.clone();
745            peer_svc.remote_node_id = Some(remote_id);
746
747            let timeout_dur = if request_timeout.is_zero() {
748                Duration::MAX
749            } else {
750                request_timeout
751            };
752
753            let conn_total = total_connections.clone();
754            let conn_requests = total_requests.clone();
755            let in_flight_conn = in_flight.clone();
756            let drain_notify_conn = drain_notify.clone();
757            conn_total.fetch_add(1, Ordering::Relaxed);
758            tokio::spawn(async move {
759                let _guard = guard;
760                // Decrement total connection count when this task exits.
761                struct TotalGuard(Arc<AtomicUsize>);
762                impl Drop for TotalGuard {
763                    fn drop(&mut self) {
764                        self.0.fetch_sub(1, Ordering::Relaxed);
765                    }
766                }
767                let _total_guard = TotalGuard(conn_total);
768
769                loop {
770                    let (send, recv) = match conn.accept_bi().await {
771                        Ok(pair) => pair,
772                        Err(_) => break,
773                    };
774
775                    let io = TokioIo::new(IrohStream::new(send, recv));
776                    let svc = peer_svc.clone();
777                    let req_counter = conn_requests.clone();
778                    req_counter.fetch_add(1, Ordering::Relaxed);
779                    in_flight_conn.fetch_add(1, Ordering::Relaxed);
780
781                    let in_flight_req = in_flight_conn.clone();
782                    let drain_notify_req = drain_notify_conn.clone();
783
784                    tokio::spawn(async move {
785                        // Decrement request count when this task exits.
786                        struct ReqGuard {
787                            counter: Arc<AtomicUsize>,
788                            in_flight: Arc<AtomicUsize>,
789                            drain_notify: Arc<tokio::sync::Notify>,
790                        }
791                        impl Drop for ReqGuard {
792                            fn drop(&mut self) {
793                                self.counter.fetch_sub(1, Ordering::Relaxed);
794                                if self.in_flight.fetch_sub(1, Ordering::AcqRel) == 1 {
795                                    // Last in-flight request completed — signal drain.
796                                    self.drain_notify.notify_waiters();
797                                }
798                            }
799                        }
800                        let _req_guard = ReqGuard {
801                            counter: req_counter,
802                            in_flight: in_flight_req,
803                            drain_notify: drain_notify_req,
804                        };
805                        // ISS-001: clamp to hyper's minimum safe buffer size of 8192.
806                        // ISS-020: a stored value of 0 means "use the default" (64 KB).
807                        let effective_header_limit = if max_header_size == 0 {
808                            64 * 1024
809                        } else {
810                            max_header_size.max(8192)
811                        };
812
813                        // Build the Tower reliability service stack and serve the connection.
814                        //
815                        // Layer ordering (outermost first):
816                        //   [CompressionLayer →] TowerErrorHandler → LoadShed → ConcurrencyLimit → Timeout → RequestService
817                        //
818                        // Tower layers are applied around `RequestService` (which returns
819                        // `Response<BoxBody>`) so that `TowerErrorHandler` can convert
820                        // `Elapsed` and `Overloaded` into 408/503 HTTP responses.
821                        // CompressionLayer (if enabled) sits outside the error handler.
822                        //
823                        // Each `if load_shed_enabled` branch produces a concrete type;
824                        // both branches `.await` to `Result<(), hyper::Error>` so the
825                        // `if` expression is well-typed without boxing.
826
827                        use tower::{
828                            limit::ConcurrencyLimitLayer, timeout::TimeoutLayer, ServiceBuilder,
829                        };
830
831                        #[cfg(feature = "compression")]
832                        let result = {
833                            use http::{Extensions, HeaderMap, Version};
834                            use tower_http::compression::{
835                                predicate::{Predicate, SizeAbove},
836                                CompressionLayer,
837                            };
838
839                            let compression_config = svc.compression.clone();
840                            if let Some(comp) = &compression_config {
841                                let min_bytes = comp.min_body_bytes;
842                                let mut layer = CompressionLayer::new().zstd(true);
843                                if let Some(level) = comp.level {
844                                    use tower_http::compression::CompressionLevel;
845                                    layer = layer.quality(CompressionLevel::Precise(level as i32));
846                                }
847                                let not_pre_compressed =
848                                    |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
849                                        !h.contains_key(http::header::CONTENT_ENCODING)
850                                    };
851                                let not_no_transform =
852                                    |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
853                                        h.get(http::header::CACHE_CONTROL)
854                                            .and_then(|v| v.to_str().ok())
855                                            .map(|v| {
856                                                !v.split(',').any(|d| {
857                                                    d.trim().eq_ignore_ascii_case("no-transform")
858                                                })
859                                            })
860                                            .unwrap_or(true)
861                                    };
862                                let predicate =
863                                    SizeAbove::new(min_bytes.min(u16::MAX as usize) as u16)
864                                        .and(not_pre_compressed)
865                                        .and(not_no_transform);
866                                if load_shed_enabled {
867                                    use tower::load_shed::LoadShedLayer;
868                                    let stk = TowerErrorHandler(
869                                        ServiceBuilder::new()
870                                            .layer(LoadShedLayer::new())
871                                            .layer(ConcurrencyLimitLayer::new(max))
872                                            .layer(TimeoutLayer::new(timeout_dur))
873                                            .service(svc),
874                                    );
875                                    hyper::server::conn::http1::Builder::new()
876                                        .max_buf_size(effective_header_limit)
877                                        .max_headers(128)
878                                        .serve_connection(
879                                            io,
880                                            TowerToHyperService::new(
881                                                ServiceBuilder::new()
882                                                    .layer(layer.compress_when(predicate))
883                                                    .service(stk),
884                                            ),
885                                        )
886                                        .with_upgrades()
887                                        .await
888                                } else {
889                                    let stk = TowerErrorHandler(
890                                        ServiceBuilder::new()
891                                            .layer(ConcurrencyLimitLayer::new(max))
892                                            .layer(TimeoutLayer::new(timeout_dur))
893                                            .service(svc),
894                                    );
895                                    hyper::server::conn::http1::Builder::new()
896                                        .max_buf_size(effective_header_limit)
897                                        .max_headers(128)
898                                        .serve_connection(
899                                            io,
900                                            TowerToHyperService::new(
901                                                ServiceBuilder::new()
902                                                    .layer(layer.compress_when(predicate))
903                                                    .service(stk),
904                                            ),
905                                        )
906                                        .with_upgrades()
907                                        .await
908                                }
909                            } else if load_shed_enabled {
910                                use tower::load_shed::LoadShedLayer;
911                                let stk = TowerErrorHandler(
912                                    ServiceBuilder::new()
913                                        .layer(LoadShedLayer::new())
914                                        .layer(ConcurrencyLimitLayer::new(max))
915                                        .layer(TimeoutLayer::new(timeout_dur))
916                                        .service(svc),
917                                );
918                                hyper::server::conn::http1::Builder::new()
919                                    .max_buf_size(effective_header_limit)
920                                    .max_headers(128)
921                                    .serve_connection(io, TowerToHyperService::new(stk))
922                                    .with_upgrades()
923                                    .await
924                            } else {
925                                let stk = TowerErrorHandler(
926                                    ServiceBuilder::new()
927                                        .layer(ConcurrencyLimitLayer::new(max))
928                                        .layer(TimeoutLayer::new(timeout_dur))
929                                        .service(svc),
930                                );
931                                hyper::server::conn::http1::Builder::new()
932                                    .max_buf_size(effective_header_limit)
933                                    .max_headers(128)
934                                    .serve_connection(io, TowerToHyperService::new(stk))
935                                    .with_upgrades()
936                                    .await
937                            }
938                        };
939                        #[cfg(not(feature = "compression"))]
940                        let result = if load_shed_enabled {
941                            use tower::load_shed::LoadShedLayer;
942                            let stk = TowerErrorHandler(
943                                ServiceBuilder::new()
944                                    .layer(LoadShedLayer::new())
945                                    .layer(ConcurrencyLimitLayer::new(max))
946                                    .layer(TimeoutLayer::new(timeout_dur))
947                                    .service(svc),
948                            );
949                            hyper::server::conn::http1::Builder::new()
950                                .max_buf_size(effective_header_limit)
951                                .max_headers(128)
952                                .serve_connection(io, TowerToHyperService::new(stk))
953                                .with_upgrades()
954                                .await
955                        } else {
956                            let stk = TowerErrorHandler(
957                                ServiceBuilder::new()
958                                    .layer(ConcurrencyLimitLayer::new(max))
959                                    .layer(TimeoutLayer::new(timeout_dur))
960                                    .service(svc),
961                            );
962                            hyper::server::conn::http1::Builder::new()
963                                .max_buf_size(effective_header_limit)
964                                .max_headers(128)
965                                .serve_connection(io, TowerToHyperService::new(stk))
966                                .with_upgrades()
967                                .await
968                        };
969
970                        if let Err(e) = result {
971                            tracing::debug!("iroh-http: http1 connection error: {e}");
972                        }
973                    });
974                }
975            });
976        }
977
978        // Graceful drain: wait for all in-flight requests to finish,
979        // or give up after `drain_timeout`.
980        //
981        // Loop avoids the race between `in_flight == 0` check and `notified()`:
982        // if the last request finishes between the load and the await, the loop
983        // re-checks immediately after the timeout wakes it.
984        let deadline = tokio::time::Instant::now()
985            .checked_add(drain_dur)
986            .expect("drain duration overflow");
987        loop {
988            if in_flight_drain.load(Ordering::Acquire) == 0 {
989                tracing::info!("iroh-http: all in-flight requests drained");
990                break;
991            }
992            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
993            if remaining.is_zero() {
994                tracing::warn!("iroh-http: drain timed out after {}s", drain_dur.as_secs());
995                break;
996            }
997            tokio::select! {
998                _ = drain_notify_drain.notified() => {}
999                _ = tokio::time::sleep(remaining) => {}
1000            }
1001        }
1002        let _ = done_tx.send(true);
1003    });
1004
1005    ServeHandle {
1006        join,
1007        shutdown_notify,
1008        drain_timeout: drain_dur,
1009        done_rx,
1010    }
1011}
1012
1013// ── TowerErrorHandler — maps Tower layer errors to HTTP responses ─────────────
1014//
1015// `ConcurrencyLimitLayer`, `TimeoutLayer`, and `LoadShedLayer` return errors
1016// rather than `Response` values when they reject a request.  `TowerErrorHandler`
1017// wraps the composed service and converts those errors to proper HTTP responses:
1018//
1019//   tower::timeout::error::Elapsed     → 408 Request Timeout
1020//   tower::load_shed::error::Overloaded → 503 Service Unavailable
1021//   anything else                       → 500 Internal Server Error
1022//
// This allows the whole stack to satisfy hyper's requirement that the service
// returns `Ok(Response)` — without this wrapper, a layer error would tear down
// the connection instead of producing a status code.
1026
/// Newtype wrapper around a composed Tower service stack.
///
/// Its [`Service`] impl converts errors emitted by the wrapped stack's
/// reliability layers (timeout, load-shed, concurrency limit) into HTTP
/// error responses, so hyper always receives `Ok(Response)`.
#[derive(Clone)]
struct TowerErrorHandler<S>(S);
1029
1030impl<S, Req> Service<Req> for TowerErrorHandler<S>
1031where
1032    S: Service<Req, Response = hyper::Response<BoxBody>>,
1033    S::Error: Into<BoxError>,
1034    S::Future: Send + 'static,
1035{
1036    type Response = hyper::Response<BoxBody>;
1037    type Error = BoxError;
1038    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1039
1040    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1041        // If ConcurrencyLimitLayer is saturated AND LoadShed is present, it
1042        // returns Pending from poll_ready — LoadShed converts that to an
1043        // immediate Err(Overloaded).  If LoadShed is absent, poll_ready blocks
1044        // until a slot opens.  In both cases we propagate the result here.
1045        self.0.poll_ready(cx).map_err(Into::into)
1046    }
1047
1048    fn call(&mut self, req: Req) -> Self::Future {
1049        let fut = self.0.call(req);
1050        Box::pin(async move {
1051            match fut.await {
1052                Ok(r) => Ok(r),
1053                Err(e) => {
1054                    let e = e.into();
1055                    let status = if e.is::<tower::timeout::error::Elapsed>() {
1056                        StatusCode::REQUEST_TIMEOUT
1057                    } else if e.is::<tower::load_shed::error::Overloaded>() {
1058                        StatusCode::SERVICE_UNAVAILABLE
1059                    } else {
1060                        tracing::warn!("iroh-http: unexpected tower error: {e}");
1061                        StatusCode::INTERNAL_SERVER_ERROR
1062                    };
1063                    let body_bytes: &'static [u8] = match status {
1064                        StatusCode::REQUEST_TIMEOUT => b"request timed out",
1065                        StatusCode::SERVICE_UNAVAILABLE => b"server at capacity",
1066                        _ => b"internal server error",
1067                    };
1068                    Ok(hyper::Response::builder()
1069                        .status(status)
1070                        .body(crate::box_body(http_body_util::Full::new(
1071                            Bytes::from_static(body_bytes),
1072                        )))
1073                        .expect("valid error response"))
1074                }
1075            }
1076        })
1077    }
1078}