// iroh_http_core/server.rs
1//! Incoming HTTP request — `serve()` implementation.
2//!
3//! Each accepted QUIC bidirectional stream is driven by hyper's HTTP/1.1
4//! server connection.  A `tower::Service` (`RequestService`) bridges between
5//! hyper and the existing body-channel + slab infrastructure.
6
7use std::{
8    collections::HashMap,
9    future::Future,
10    pin::Pin,
11    sync::{
12        atomic::{AtomicUsize, Ordering},
13        Arc, Mutex,
14    },
15    task::{Context, Poll},
16    time::Duration,
17};
18
19use bytes::Bytes;
20use http::{HeaderName, HeaderValue, StatusCode};
21use hyper::body::Incoming;
22use hyper_util::rt::TokioIo;
23use hyper_util::service::TowerToHyperService;
24use tower::Service;
25
26use crate::{
27    base32_encode,
28    client::{body_from_reader, pump_hyper_body_to_channel_limited},
29    io::IrohStream,
30    stream::{HandleStore, ResponseHeadEntry},
31    ConnectionEvent, CoreError, IrohEndpoint, RequestPayload,
32};
33
34// ── Type aliases ──────────────────────────────────────────────────────────────
35
/// Boxed response-body type shared with the rest of the crate.
type BoxBody = crate::BoxBody;
/// Boxed, thread-safe error type required by hyper's service plumbing.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
38
39// ── ServerLimits ──────────────────────────────────────────────────────────────
40
/// Server-side limits shared between [`NodeOptions`](crate::NodeOptions) and
/// the serve path.
///
/// Embedding this struct in both `NodeOptions` and `EndpointInner` guarantees
/// that adding a new limit field produces a compile error if only one side is
/// updated.
///
/// Every field is optional; `None` selects the built-in default (see the
/// `DEFAULT_*` constants below).
#[derive(Debug, Clone, Default)]
pub struct ServerLimits {
    /// Maximum number of concurrently-served requests (default 64).
    pub max_concurrency: Option<usize>,
    /// Consecutive accept errors tolerated before the serve loop gives up
    /// and shuts down (default 5).
    pub max_consecutive_errors: Option<usize>,
    /// Per-request timeout in milliseconds; `0` disables the timeout
    /// (default 60 000).
    pub request_timeout_ms: Option<u64>,
    /// Maximum simultaneous connections accepted from a single peer
    /// (default 8); excess connections are closed at accept time.
    pub max_connections_per_peer: Option<usize>,
    /// Maximum request body size in bytes; oversized bodies produce a
    /// `413 Payload Too Large` response.
    pub max_request_body_bytes: Option<usize>,
    /// Grace period in seconds for in-flight requests during shutdown
    /// (default 30).
    pub drain_timeout_secs: Option<u64>,
    /// Cap on total simultaneous connections across all peers; when reached,
    /// new connections are closed with "server at capacity".
    pub max_total_connections: Option<usize>,
    /// When `true` (the default), reject new requests immediately with `503
    /// Service Unavailable` when `max_concurrency` is already reached rather
    /// than queuing them.  Prevents thundering-herd on recovery.
    pub load_shed: Option<bool>,
}
61
/// Backward-compatible alias — existing code that names `ServeOptions` keeps
/// compiling without changes.
pub type ServeOptions = ServerLimits;

// Defaults applied when the corresponding `ServerLimits` field is `None`.
const DEFAULT_CONCURRENCY: usize = 64;
const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 60_000;
const DEFAULT_MAX_CONNECTIONS_PER_PEER: usize = 8;
const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 30;
70
71// ── ServeHandle ───────────────────────────────────────────────────────────────
72
/// Handle returned by [`serve`] / [`serve_with_events`] for controlling the
/// background accept-loop task.
pub struct ServeHandle {
    // Join handle of the spawned serve task; awaited by `drain`, killed by `abort`.
    join: tokio::task::JoinHandle<()>,
    // Signalled by `shutdown()`; the accept loop selects on this to break out.
    shutdown_notify: Arc<tokio::sync::Notify>,
    // Configured grace period, exposed via `drain_timeout()`.
    drain_timeout: std::time::Duration,
    /// Resolves to `true` once the serve task has fully exited.
    done_rx: tokio::sync::watch::Receiver<bool>,
}
80
impl ServeHandle {
    /// Request a graceful shutdown (non-blocking).
    ///
    /// Wakes the serve loop, which stops accepting new connections and
    /// begins its drain phase.
    pub fn shutdown(&self) {
        self.shutdown_notify.notify_one();
    }
    /// Request shutdown and wait until the serve task has fully exited.
    pub async fn drain(self) {
        self.shutdown();
        // Join errors (task panic or abort) are deliberately ignored — the
        // caller only cares that the task is gone.
        let _ = self.join.await;
    }
    /// Abort the serve task immediately, skipping any graceful drain.
    pub fn abort(&self) {
        self.join.abort();
    }
    /// The drain timeout this handle was configured with.
    pub fn drain_timeout(&self) -> std::time::Duration {
        self.drain_timeout
    }
    /// Subscribe to the serve-loop-done signal.
    ///
    /// The returned receiver resolves (changes to `true`) once the serve task
    /// has fully exited, including the drain phase.
    pub fn subscribe_done(&self) -> tokio::sync::watch::Receiver<bool> {
        self.done_rx.clone()
    }
}
103
104// ── respond() ────────────────────────────────────────────────────────────────
105
106pub fn respond(
107    handles: &HandleStore,
108    req_handle: u64,
109    status: u16,
110    headers: Vec<(String, String)>,
111) -> Result<(), CoreError> {
112    StatusCode::from_u16(status)
113        .map_err(|_| CoreError::invalid_input(format!("invalid HTTP status code: {status}")))?;
114    for (name, value) in &headers {
115        HeaderName::from_bytes(name.as_bytes()).map_err(|_| {
116            CoreError::invalid_input(format!("invalid response header name {:?}", name))
117        })?;
118        HeaderValue::from_str(value).map_err(|_| {
119            CoreError::invalid_input(format!("invalid response header value for {:?}", name))
120        })?;
121    }
122
123    let sender = handles
124        .take_req_sender(req_handle)
125        .ok_or_else(|| CoreError::invalid_handle(req_handle))?;
126    sender
127        .send(ResponseHeadEntry { status, headers })
128        .map_err(|_| CoreError::internal("serve task dropped before respond"))
129}
130
131// ── PeerConnectionGuard ───────────────────────────────────────────────────────
132
/// Callback invoked on peer connect/disconnect count transitions.
type ConnectionEventFn = Arc<dyn Fn(ConnectionEvent) + Send + Sync>;

/// RAII guard holding one slot in the per-peer connection count.
///
/// Created by `PeerConnectionGuard::acquire`; decrements the count — and on
/// the last connection fires the disconnect event — when dropped.
struct PeerConnectionGuard {
    // Shared per-peer connection counter map.
    counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
    // Peer this guard holds a slot for.
    peer: iroh::PublicKey,
    // String form of the peer id (base32-encoded at the accept site), used in events.
    peer_id_str: String,
    // Optional connect/disconnect callback, fired on 0↔1 transitions.
    on_event: Option<ConnectionEventFn>,
}
141
142impl PeerConnectionGuard {
143    fn acquire(
144        counts: &Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
145        peer: iroh::PublicKey,
146        peer_id_str: String,
147        max: usize,
148        on_event: Option<ConnectionEventFn>,
149    ) -> Option<Self> {
150        let mut map = counts.lock().unwrap_or_else(|e| e.into_inner());
151        let count = map.entry(peer).or_insert(0);
152        if *count >= max {
153            return None;
154        }
155        let was_zero = *count == 0;
156        *count += 1;
157        let guard = PeerConnectionGuard {
158            counts: counts.clone(),
159            peer,
160            peer_id_str: peer_id_str.clone(),
161            on_event: on_event.clone(),
162        };
163        // Fire connected event on 0 → 1 transition (first connection from this peer).
164        if was_zero {
165            if let Some(cb) = &on_event {
166                cb(ConnectionEvent {
167                    peer_id: peer_id_str,
168                    connected: true,
169                });
170            }
171        }
172        Some(guard)
173    }
174}
175
176impl Drop for PeerConnectionGuard {
177    fn drop(&mut self) {
178        let mut map = self.counts.lock().unwrap_or_else(|e| e.into_inner());
179        if let Some(c) = map.get_mut(&self.peer) {
180            *c = c.saturating_sub(1);
181            if *c == 0 {
182                map.remove(&self.peer);
183                // Fire disconnected event on 1 → 0 transition (last connection from this peer closed).
184                if let Some(cb) = &self.on_event {
185                    cb(ConnectionEvent {
186                        peer_id: self.peer_id_str.clone(),
187                        connected: false,
188                    });
189                }
190            }
191        }
192    }
193}
194
195// ── RequestService ────────────────────────────────────────────────────────────
196
/// Per-connection Tower service bridging hyper requests to the JS-facing
/// `on_request` callback and the handle-slab infrastructure.
///
/// Cloned once per connection (with `remote_node_id` filled in at accept
/// time) and again per request inside `call`.
#[derive(Clone)]
struct RequestService {
    // Invoked once per request with the allocated handles.
    on_request: Arc<dyn Fn(RequestPayload) + Send + Sync>,
    // Owning endpoint; provides the handle store and cleanup access.
    endpoint: IrohEndpoint,
    // This node's id, used to synthesize the request URL.
    own_node_id: Arc<String>,
    // Authenticated peer id; `None` in the connection-agnostic template,
    // populated per accepted connection.
    remote_node_id: Option<String>,
    // Optional request-body cap (413 on overflow).
    max_request_body_bytes: Option<usize>,
    // Optional raw-byte header-size cap (431 on overflow); `None` = unlimited.
    max_header_size: Option<usize>,
    #[cfg(feature = "compression")]
    compression: Option<crate::endpoint::CompressionOptions>,
}
208
209impl Service<hyper::Request<Incoming>> for RequestService {
210    type Response = hyper::Response<BoxBody>;
211    type Error = BoxError;
212    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
213
214    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
215        Poll::Ready(Ok(()))
216    }
217
218    fn call(&mut self, req: hyper::Request<Incoming>) -> Self::Future {
219        let svc = self.clone();
220        Box::pin(async move { svc.handle(req).await })
221    }
222}
223
impl RequestService {
    /// Drive one HTTP request end-to-end: validate headers, allocate the
    /// body/trailer/head channels, fire the `on_request` callback, await the
    /// response head from the JS side, and build the hyper response (or
    /// perform the duplex 101 upgrade).
    ///
    /// Early-return responses produced here (431/400/413 and the duplex
    /// rejection path) bypass the JS handler entirely.
    async fn handle(
        self,
        mut req: hyper::Request<Incoming>,
    ) -> Result<hyper::Response<BoxBody>, BoxError> {
        let handles = self.endpoint.handles();
        let own_node_id = &*self.own_node_id;
        // Template services carry `None`; per-connection clones carry the
        // authenticated peer id.
        let remote_node_id = self.remote_node_id.clone().unwrap_or_default();
        let max_request_body_bytes = self.max_request_body_bytes;
        let max_header_size = self.max_header_size;

        let method = req.method().to_string();
        let path_and_query = req
            .uri()
            .path_and_query()
            .map(|p| p.as_str())
            .unwrap_or("/")
            .to_string();

        tracing::debug!(
            method = %method,
            path = %path_and_query,
            peer = %remote_node_id,
            "iroh-http: incoming request",
        );
        // Strip any client-supplied peer-id to prevent spoofing,
        // then inject the authenticated identity from the QUIC connection.
        //
        // ISS-011: Use raw byte length for header-size accounting to prevent
        // bypass via non-UTF8 values.  Reject non-UTF8 header values with 400
        // instead of silently converting them to empty strings.

        // First pass: measure header bytes using raw values (before lossy conversion).
        if let Some(limit) = max_header_size {
            let header_bytes: usize = req
                .headers()
                .iter()
                .filter(|(k, _)| !k.as_str().eq_ignore_ascii_case("peer-id"))
                .map(|(k, v)| k.as_str().len() + v.as_bytes().len() + 4) // ": " + "\r\n"
                .sum::<usize>()
                + "peer-id".len()
                + remote_node_id.len()
                + 4
                + req.uri().to_string().len()
                + method.len()
                + 12; // "HTTP/1.1 \r\n\r\n" overhead
            if header_bytes > limit {
                // 431 Request Header Fields Too Large — body intentionally empty.
                let resp = hyper::Response::builder()
                    .status(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE)
                    .body(crate::box_body(http_body_util::Empty::new()))
                    .unwrap();
                return Ok(resp);
            }
        }

        // Build header list — reject non-UTF8 values instead of silently dropping.
        let mut req_headers: Vec<(String, String)> = Vec::new();
        for (k, v) in req.headers().iter() {
            if k.as_str().eq_ignore_ascii_case("peer-id") {
                continue;
            }
            match v.to_str() {
                Ok(s) => req_headers.push((k.as_str().to_string(), s.to_string())),
                Err(_) => {
                    let resp = hyper::Response::builder()
                        .status(StatusCode::BAD_REQUEST)
                        .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
                            b"non-UTF8 header value",
                        ))))
                        .unwrap();
                    return Ok(resp);
                }
            }
        }
        req_headers.push(("peer-id".to_string(), remote_node_id.clone()));

        let url = format!("httpi://{own_node_id}{path_and_query}");

        // ISS-015: strict duplex upgrade validation — require CONNECT method +
        // Upgrade: iroh-duplex + Connection: upgrade headers.
        let has_upgrade_header = req_headers.iter().any(|(k, v)| {
            k.eq_ignore_ascii_case("upgrade") && v.eq_ignore_ascii_case("iroh-duplex")
        });
        let has_connection_upgrade = req_headers.iter().any(|(k, v)| {
            k.eq_ignore_ascii_case("connection")
                && v.split(',')
                    .any(|tok| tok.trim().eq_ignore_ascii_case("upgrade"))
        });
        let is_connect = req.method() == http::Method::CONNECT;

        let is_bidi = if has_upgrade_header {
            if !has_connection_upgrade || !is_connect {
                let resp = hyper::Response::builder()
                    .status(StatusCode::BAD_REQUEST)
                    .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
                        b"duplex upgrade requires CONNECT method with Connection: upgrade header",
                    ))))
                    .unwrap();
                return Ok(resp);
            }
            true
        } else {
            false
        };

        // For duplex: capture the upgrade future BEFORE consuming the request.
        let upgrade_future = if is_bidi {
            Some(hyper::upgrade::on(&mut req))
        } else {
            None
        };

        // ── Allocate channels ────────────────────────────────────────────────

        // Request body: writer pumped from hyper; reader given to JS.
        let mut guard = handles.insert_guard();
        let (req_body_writer, req_body_reader) = handles.make_body_channel();
        let req_body_handle = guard
            .insert_reader(req_body_reader)
            .map_err(|e| -> BoxError { e.into() })?;

        // Response body: writer given to JS (sendChunk); reader feeds hyper response.
        let (res_body_writer, res_body_reader) = handles.make_body_channel();
        let res_body_handle = guard
            .insert_writer(res_body_writer)
            .map_err(|e| -> BoxError { e.into() })?;

        // ── Trailer channels (non-duplex only) ───────────────────────────────

        let (req_trailers_handle, res_trailers_handle, req_trailer_tx, opt_res_trailer_rx) =
            if !is_bidi {
                // Request trailers: pump delivers them; JS reads via nextTrailer.
                let (rq_tx, rq_rx) = tokio::sync::oneshot::channel::<Vec<(String, String)>>();
                let rq_h = guard
                    .insert_trailer_receiver(rq_rx)
                    .map_err(|e| -> BoxError { e.into() })?;
                // Response trailers: JS delivers via sendTrailers; pump appends to body.
                let (rs_tx, rs_rx) = tokio::sync::oneshot::channel::<Vec<(String, String)>>();
                let rs_h = guard
                    .insert_trailer_sender(rs_tx)
                    .map_err(|e| -> BoxError { e.into() })?;
                (rq_h, rs_h, Some(rq_tx), Some(rs_rx))
            } else {
                // NOTE(review): 0 appears to be the sentinel for "no trailer
                // handle" on the duplex path — confirm the JS side never
                // treats handle 0 as live.
                (0u64, 0u64, None, None)
            };

        // ── Allocate response-head rendezvous ────────────────────────────────

        let (head_tx, head_rx) = tokio::sync::oneshot::channel::<ResponseHeadEntry>();
        let req_handle = guard
            .allocate_req_handle(head_tx)
            .map_err(|e| -> BoxError { e.into() })?;

        // All insertions succeeded — make the handles visible atomically.
        guard.commit();

        // RAII guard: remove the req_handle slab entry on all exit paths
        // (413 early-return, timeout drop, "JS handler dropped", normal completion).
        // If respond() already consumed the entry, take_req_sender returns None — safe no-op.
        struct ReqHeadCleanup {
            endpoint: IrohEndpoint,
            req_handle: u64,
        }
        impl Drop for ReqHeadCleanup {
            fn drop(&mut self) {
                self.endpoint.handles().take_req_sender(self.req_handle);
            }
        }
        let _req_head_cleanup = ReqHeadCleanup {
            endpoint: self.endpoint.clone(),
            req_handle,
        };

        // ── Pump request body ────────────────────────────────────────────────

        // For duplex: keep req_body_writer to move into the upgrade spawn below.
        // For regular: consume it immediately into the pump task.
        // ISS-004: create an overflow channel so the serve path can return 413.
        let (body_overflow_tx, body_overflow_rx) = if !is_bidi && max_request_body_bytes.is_some() {
            let (tx, rx) = tokio::sync::oneshot::channel::<()>();
            (Some(tx), Some(rx))
        } else {
            (None, None)
        };

        let duplex_req_body_writer = if !is_bidi {
            let body = req.into_body();
            let trailer_tx = req_trailer_tx.expect("non-duplex has req_trailer_tx");
            let frame_timeout = handles.drain_timeout();
            tokio::spawn(pump_hyper_body_to_channel_limited(
                body,
                req_body_writer,
                trailer_tx,
                max_request_body_bytes,
                frame_timeout,
                body_overflow_tx,
            ));
            None
        } else {
            // Duplex: discard the HTTP preamble body (empty before 101).
            drop(req.into_body());
            Some(req_body_writer)
        };

        // ── Fire on_request callback ─────────────────────────────────────────

        on_request_fire(
            &self.on_request,
            req_handle,
            req_body_handle,
            res_body_handle,
            req_trailers_handle,
            res_trailers_handle,
            method,
            url,
            req_headers,
            remote_node_id,
            is_bidi,
        );

        // ── Await response head from JS (race against body overflow) ─────────
        //
        // ISS-004: if the request body exceeds maxRequestBodyBytes, return 413
        // immediately without waiting for the JS handler to respond.

        let response_head = if let Some(overflow_rx) = body_overflow_rx {
            tokio::select! {
                biased;
                _ = overflow_rx => {
                    // Body too large: ReqHeadCleanup RAII guard will remove the slab
                    // entry when this function exits (issue-7 fix).
                    let resp = hyper::Response::builder()
                        .status(StatusCode::PAYLOAD_TOO_LARGE)
                        .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
                            b"request body too large",
                        ))))
                        .expect("valid 413 response");
                    return Ok(resp);
                }
                head = head_rx => {
                    head.map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
                }
            }
        } else {
            head_rx
                .await
                .map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
        };

        // ── Duplex path: honor handler status, upgrade only on 101 ──────────────
        //
        // ISS-002: the handler may reject the duplex request by returning any
        // non-101 status.  Only perform the QUIC stream pump when the handler
        // explicitly returns 101 Switching Protocols.

        if let Some(upgrade_fut) = upgrade_future {
            let req_body_writer =
                duplex_req_body_writer.expect("duplex path always has req_body_writer");

            // If the handler returned a non-101 status, send that response and
            // do NOT perform the upgrade.  Drop the upgrade future and writer.
            if response_head.status != StatusCode::SWITCHING_PROTOCOLS.as_u16() {
                drop(upgrade_fut);
                drop(req_body_writer);
                let mut resp_builder = hyper::Response::builder().status(response_head.status);
                for (k, v) in &response_head.headers {
                    resp_builder = resp_builder.header(k.as_str(), v.as_str());
                }
                let resp = resp_builder
                    .body(crate::box_body(http_body_util::Empty::new()))
                    .map_err(|e| -> BoxError { e.into() })?;
                return Ok(resp);
            }

            // Spawn the upgrade pump after hyper delivers the 101.
            //
            // Both directions are wired to the channels already sent to JS:
            //   recv_io → req_body_writer  (JS reads via req_body_handle)
            //   res_body_reader → send_io  (JS writes via res_body_handle)
            tokio::spawn(async move {
                match upgrade_fut.await {
                    Err(e) => tracing::warn!("iroh-http: duplex upgrade error: {e}"),
                    Ok(upgraded) => {
                        let io = TokioIo::new(upgraded);
                        crate::stream::pump_duplex(io, req_body_writer, res_body_reader).await;
                    }
                }
            });

            // ISS-015: emit both Connection and Upgrade headers in 101 response.
            let resp = hyper::Response::builder()
                .status(StatusCode::SWITCHING_PROTOCOLS)
                .header(hyper::header::CONNECTION, "Upgrade")
                .header(hyper::header::UPGRADE, "iroh-duplex")
                .body(crate::box_body(http_body_util::Empty::new()))
                .unwrap();
            return Ok(resp);
        }

        // ── Regular HTTP response ─────────────────────────────────────────────

        // Only wire response trailers into the body when the handler declared
        // them via a `Trailer` header; otherwise release the slab entry now.
        let has_trailer_hdr = response_head
            .headers
            .iter()
            .any(|(k, _)| k.eq_ignore_ascii_case("trailer"));
        let trailer_rx_for_body = if has_trailer_hdr {
            opt_res_trailer_rx
        } else {
            handles.remove_trailer_sender(res_trailers_handle);
            None
        };

        let body_stream = body_from_reader(res_body_reader, trailer_rx_for_body);

        let mut resp_builder = hyper::Response::builder().status(response_head.status);
        for (k, v) in &response_head.headers {
            resp_builder = resp_builder.header(k.as_str(), v.as_str());
        }

        #[cfg(feature = "compression")]
        let resp_builder = resp_builder; // CompressionLayer in ServiceBuilder handles this

        let resp = resp_builder
            .body(crate::box_body(body_stream))
            .map_err(|e| -> BoxError { e.into() })?;

        Ok(resp)
    }
}
552
553#[inline]
554#[allow(clippy::too_many_arguments)]
555fn on_request_fire(
556    cb: &Arc<dyn Fn(RequestPayload) + Send + Sync>,
557    req_handle: u64,
558    req_body_handle: u64,
559    res_body_handle: u64,
560    req_trailers_handle: u64,
561    res_trailers_handle: u64,
562    method: String,
563    url: String,
564    headers: Vec<(String, String)>,
565    remote_node_id: String,
566    is_bidi: bool,
567) {
568    cb(RequestPayload {
569        req_handle,
570        req_body_handle,
571        res_body_handle,
572        req_trailers_handle,
573        res_trailers_handle,
574        method,
575        url,
576        headers,
577        remote_node_id,
578        is_bidi,
579    });
580}
581
582// ── serve() ───────────────────────────────────────────────────────────────────
583
/// Start the serve accept loop.
///
/// This is the 3-argument form for backward compatibility.
/// Use `serve_with_events` to also receive peer connect/disconnect callbacks.
///
/// The returned [`ServeHandle`] controls the loop: `shutdown()` requests a
/// graceful stop, `drain()` additionally awaits task exit, `abort()` kills it.
pub fn serve<F>(endpoint: IrohEndpoint, options: ServeOptions, on_request: F) -> ServeHandle
where
    F: Fn(RequestPayload) + Send + Sync + 'static,
{
    // Delegate with no connection-event callback.
    serve_with_events(endpoint, options, on_request, None)
}
594
595/// Start the serve accept loop with an optional peer connection event callback.
596///
597/// `on_connection_event` is called on 0→1 (first connection from a peer) and
598/// 1→0 (last connection from a peer closed) count transitions.
599pub fn serve_with_events<F>(
600    endpoint: IrohEndpoint,
601    options: ServeOptions,
602    on_request: F,
603    on_connection_event: Option<ConnectionEventFn>,
604) -> ServeHandle
605where
606    F: Fn(RequestPayload) + Send + Sync + 'static,
607{
608    let max = options.max_concurrency.unwrap_or(DEFAULT_CONCURRENCY);
609    let max_errors = options.max_consecutive_errors.unwrap_or(5);
610    let request_timeout = options
611        .request_timeout_ms
612        .map(Duration::from_millis)
613        .unwrap_or(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS));
614    let max_conns_per_peer = options
615        .max_connections_per_peer
616        .unwrap_or(DEFAULT_MAX_CONNECTIONS_PER_PEER);
617    let max_request_body_bytes = options.max_request_body_bytes;
618    let max_total_connections = options.max_total_connections;
619    let drain_timeout = Duration::from_secs(
620        options
621            .drain_timeout_secs
622            .unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS),
623    );
624    // Load-shed is opt-out — default `true` (reject immediately when at capacity).
625    let load_shed_enabled = options.load_shed.unwrap_or(true);
626    let max_header_size = endpoint.max_header_size();
627    #[cfg(feature = "compression")]
628    let compression = endpoint.compression().cloned();
629    let own_node_id = Arc::new(endpoint.node_id().to_string());
630    let on_request = Arc::new(on_request) as Arc<dyn Fn(RequestPayload) + Send + Sync>;
631
632    let peer_counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>> =
633        Arc::new(Mutex::new(HashMap::new()));
634    let conn_event_fn: Option<ConnectionEventFn> = on_connection_event;
635
636    // In-flight request counter: incremented on accept, decremented on drop.
637    // Used for graceful drain (wait until zero or timeout).
638    let in_flight: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
639    let drain_notify: Arc<tokio::sync::Notify> = Arc::new(tokio::sync::Notify::new());
640
641    let base_svc = RequestService {
642        on_request,
643        endpoint: endpoint.clone(),
644        own_node_id,
645        remote_node_id: None,
646        max_request_body_bytes,
647        max_header_size: if max_header_size == 0 {
648            None
649        } else {
650            Some(max_header_size)
651        },
652        #[cfg(feature = "compression")]
653        compression,
654    };
655
656    let shutdown_notify = Arc::new(tokio::sync::Notify::new());
657    let shutdown_listen = shutdown_notify.clone();
658    let drain_dur = drain_timeout;
659    // Re-use the endpoint's shared counters so that endpoint_stats() reflects
660    // the live connection and request counts at all times.
661    let total_connections = endpoint.inner.active_connections.clone();
662    let total_requests = endpoint.inner.active_requests.clone();
663    let (done_tx, done_rx) = tokio::sync::watch::channel(false);
664    let endpoint_closed_tx = endpoint.inner.closed_tx.clone();
665
666    let in_flight_drain = in_flight.clone();
667    let drain_notify_drain = drain_notify.clone();
668
669    let join = tokio::spawn(async move {
670        let ep = endpoint.raw().clone();
671        let mut consecutive_errors: usize = 0;
672
673        loop {
674            let incoming = tokio::select! {
675                biased;
676                _ = shutdown_listen.notified() => {
677                    tracing::info!("iroh-http: serve loop shutting down");
678                    break;
679                }
680                inc = ep.accept() => match inc {
681                    Some(i) => i,
682                    None => {
683                        tracing::info!("iroh-http: endpoint closed (accept returned None)");
684                        let _ = endpoint_closed_tx.send(true);
685                        break;
686                    }
687                }
688            };
689
690            let conn = match incoming.await {
691                Ok(c) => {
692                    consecutive_errors = 0;
693                    c
694                }
695                Err(e) => {
696                    consecutive_errors += 1;
697                    tracing::warn!(
698                        "iroh-http: accept error ({consecutive_errors}/{max_errors}): {e}"
699                    );
700                    if consecutive_errors >= max_errors {
701                        tracing::error!("iroh-http: too many accept errors — shutting down");
702                        break;
703                    }
704                    continue;
705                }
706            };
707
708            let remote_pk = conn.remote_id();
709
710            // Enforce total connection limit.
711            if let Some(max_total) = max_total_connections {
712                let current = total_connections.load(Ordering::Relaxed);
713                if current >= max_total {
714                    tracing::warn!(
715                        "iroh-http: total connection limit reached ({current}/{max_total})"
716                    );
717                    conn.close(0u32.into(), b"server at capacity");
718                    continue;
719                }
720            }
721
722            let remote_id = base32_encode(remote_pk.as_bytes());
723
724            let guard =
725                match PeerConnectionGuard::acquire(&peer_counts, remote_pk, remote_id.clone(), max_conns_per_peer, conn_event_fn.clone()) {
726                    Some(g) => g,
727                    None => {
728                        tracing::warn!(
729                            "iroh-http: peer {remote_id} exceeded connection limit"
730                        );
731                        conn.close(0u32.into(), b"too many connections");
732                        continue;
733                    }
734                };
735
736            let mut peer_svc = base_svc.clone();
737            peer_svc.remote_node_id = Some(remote_id);
738
739            let timeout_dur = if request_timeout.is_zero() {
740                Duration::MAX
741            } else {
742                request_timeout
743            };
744
745            let conn_total = total_connections.clone();
746            let conn_requests = total_requests.clone();
747            let in_flight_conn = in_flight.clone();
748            let drain_notify_conn = drain_notify.clone();
749            conn_total.fetch_add(1, Ordering::Relaxed);
750            tokio::spawn(async move {
751                let _guard = guard;
752                // Decrement total connection count when this task exits.
753                struct TotalGuard(Arc<AtomicUsize>);
754                impl Drop for TotalGuard {
755                    fn drop(&mut self) {
756                        self.0.fetch_sub(1, Ordering::Relaxed);
757                    }
758                }
759                let _total_guard = TotalGuard(conn_total);
760
761                loop {
762                    let (send, recv) = match conn.accept_bi().await {
763                        Ok(pair) => pair,
764                        Err(_) => break,
765                    };
766
767                    let io = TokioIo::new(IrohStream::new(send, recv));
768                    let svc = peer_svc.clone();
769                    let req_counter = conn_requests.clone();
770                    req_counter.fetch_add(1, Ordering::Relaxed);
771                    in_flight_conn.fetch_add(1, Ordering::Relaxed);
772
773                    let in_flight_req = in_flight_conn.clone();
774                    let drain_notify_req = drain_notify_conn.clone();
775
776                    tokio::spawn(async move {
777                        // Decrement request count when this task exits.
778                        struct ReqGuard {
779                            counter: Arc<AtomicUsize>,
780                            in_flight: Arc<AtomicUsize>,
781                            drain_notify: Arc<tokio::sync::Notify>,
782                        }
783                        impl Drop for ReqGuard {
784                            fn drop(&mut self) {
785                                self.counter.fetch_sub(1, Ordering::Relaxed);
786                                if self.in_flight.fetch_sub(1, Ordering::AcqRel) == 1 {
787                                    // Last in-flight request completed — signal drain.
788                                    self.drain_notify.notify_waiters();
789                                }
790                            }
791                        }
792                        let _req_guard = ReqGuard {
793                            counter: req_counter,
794                            in_flight: in_flight_req,
795                            drain_notify: drain_notify_req,
796                        };
797                        // ISS-001: clamp to hyper's minimum safe buffer size of 8192.
798                        // ISS-020: a stored value of 0 means "use the default" (64 KB).
799                        let effective_header_limit = if max_header_size == 0 {
800                            64 * 1024
801                        } else {
802                            max_header_size.max(8192)
803                        };
804
805                        // Build the Tower reliability service stack and serve the connection.
806                        //
807                        // Layer ordering (outermost first):
808                        //   [CompressionLayer →] TowerErrorHandler → LoadShed → ConcurrencyLimit → Timeout → RequestService
809                        //
810                        // Tower layers are applied around `RequestService` (which returns
811                        // `Response<BoxBody>`) so that `TowerErrorHandler` can convert
812                        // `Elapsed` and `Overloaded` into 408/503 HTTP responses.
813                        // CompressionLayer (if enabled) sits outside the error handler.
814                        //
815                        // Each `if load_shed_enabled` branch produces a concrete type;
816                        // both branches `.await` to `Result<(), hyper::Error>` so the
817                        // `if` expression is well-typed without boxing.
818
819                        use tower::{ServiceBuilder, limit::ConcurrencyLimitLayer, timeout::TimeoutLayer};
820
821                        #[cfg(feature = "compression")]
822                        let result = {
823                            use http::{Extensions, HeaderMap, Version};
824                            use tower_http::compression::{predicate::{Predicate, SizeAbove}, CompressionLayer};
825
826                            let compression_config = svc.compression.clone();
827                            if let Some(comp) = &compression_config {
828                                let min_bytes = comp.min_body_bytes;
829                                let mut layer = CompressionLayer::new().zstd(true);
830                                if let Some(level) = comp.level {
831                                    use tower_http::compression::CompressionLevel;
832                                    layer = layer.quality(CompressionLevel::Precise(level as i32));
833                                }
834                                let not_pre_compressed =
835                                    |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
836                                        !h.contains_key(http::header::CONTENT_ENCODING)
837                                    };
838                                let not_no_transform =
839                                    |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
840                                        h.get(http::header::CACHE_CONTROL)
841                                            .and_then(|v| v.to_str().ok())
842                                            .map(|v| {
843                                                !v.split(',').any(|d| {
844                                                    d.trim().eq_ignore_ascii_case("no-transform")
845                                                })
846                                            })
847                                            .unwrap_or(true)
848                                    };
849                                let predicate =
850                                    SizeAbove::new(min_bytes.min(u16::MAX as usize) as u16)
851                                        .and(not_pre_compressed)
852                                        .and(not_no_transform);
853                                if load_shed_enabled {
854                                    use tower::load_shed::LoadShedLayer;
855                                    let stk = TowerErrorHandler(ServiceBuilder::new()
856                                        .layer(LoadShedLayer::new())
857                                        .layer(ConcurrencyLimitLayer::new(max))
858                                        .layer(TimeoutLayer::new(timeout_dur))
859                                        .service(svc));
860                                    hyper::server::conn::http1::Builder::new()
861                                        .max_buf_size(effective_header_limit)
862                                        .max_headers(128)
863                                        .serve_connection(io, TowerToHyperService::new(
864                                            ServiceBuilder::new()
865                                                .layer(layer.compress_when(predicate))
866                                                .service(stk),
867                                        ))
868                                        .with_upgrades()
869                                        .await
870                                } else {
871                                    let stk = TowerErrorHandler(ServiceBuilder::new()
872                                        .layer(ConcurrencyLimitLayer::new(max))
873                                        .layer(TimeoutLayer::new(timeout_dur))
874                                        .service(svc));
875                                    hyper::server::conn::http1::Builder::new()
876                                        .max_buf_size(effective_header_limit)
877                                        .max_headers(128)
878                                        .serve_connection(io, TowerToHyperService::new(
879                                            ServiceBuilder::new()
880                                                .layer(layer.compress_when(predicate))
881                                                .service(stk),
882                                        ))
883                                        .with_upgrades()
884                                        .await
885                                }
886                            } else if load_shed_enabled {
887                                use tower::load_shed::LoadShedLayer;
888                                let stk = TowerErrorHandler(ServiceBuilder::new()
889                                    .layer(LoadShedLayer::new())
890                                    .layer(ConcurrencyLimitLayer::new(max))
891                                    .layer(TimeoutLayer::new(timeout_dur))
892                                    .service(svc));
893                                hyper::server::conn::http1::Builder::new()
894                                    .max_buf_size(effective_header_limit)
895                                    .max_headers(128)
896                                    .serve_connection(io, TowerToHyperService::new(stk))
897                                    .with_upgrades()
898                                    .await
899                            } else {
900                                let stk = TowerErrorHandler(ServiceBuilder::new()
901                                    .layer(ConcurrencyLimitLayer::new(max))
902                                    .layer(TimeoutLayer::new(timeout_dur))
903                                    .service(svc));
904                                hyper::server::conn::http1::Builder::new()
905                                    .max_buf_size(effective_header_limit)
906                                    .max_headers(128)
907                                    .serve_connection(io, TowerToHyperService::new(stk))
908                                    .with_upgrades()
909                                    .await
910                            }
911                        };
912                        #[cfg(not(feature = "compression"))]
913                        let result = if load_shed_enabled {
914                            use tower::load_shed::LoadShedLayer;
915                            let stk = TowerErrorHandler(ServiceBuilder::new()
916                                .layer(LoadShedLayer::new())
917                                .layer(ConcurrencyLimitLayer::new(max))
918                                .layer(TimeoutLayer::new(timeout_dur))
919                                .service(svc));
920                            hyper::server::conn::http1::Builder::new()
921                                .max_buf_size(effective_header_limit)
922                                .max_headers(128)
923                                .serve_connection(io, TowerToHyperService::new(stk))
924                                .with_upgrades()
925                                .await
926                        } else {
927                            let stk = TowerErrorHandler(ServiceBuilder::new()
928                                .layer(ConcurrencyLimitLayer::new(max))
929                                .layer(TimeoutLayer::new(timeout_dur))
930                                .service(svc));
931                            hyper::server::conn::http1::Builder::new()
932                                .max_buf_size(effective_header_limit)
933                                .max_headers(128)
934                                .serve_connection(io, TowerToHyperService::new(stk))
935                                .with_upgrades()
936                                .await
937                        };
938
939                        if let Err(e) = result {
940                            tracing::debug!("iroh-http: http1 connection error: {e}");
941                        }
942                    });
943                }
944            });
945        }
946
947        // Graceful drain: wait for all in-flight requests to finish,
948        // or give up after `drain_timeout`.
949        //
950        // Loop avoids the race between `in_flight == 0` check and `notified()`:
951        // if the last request finishes between the load and the await, the loop
952        // re-checks immediately after the timeout wakes it.
953        let deadline = tokio::time::Instant::now() + drain_dur;
954        loop {
955            if in_flight_drain.load(Ordering::Acquire) == 0 {
956                tracing::info!("iroh-http: all in-flight requests drained");
957                break;
958            }
959            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
960            if remaining.is_zero() {
961                tracing::warn!("iroh-http: drain timed out after {}s", drain_dur.as_secs());
962                break;
963            }
964            tokio::select! {
965                _ = drain_notify_drain.notified() => {}
966                _ = tokio::time::sleep(remaining) => {}
967            }
968        }
969        let _ = done_tx.send(true);
970    });
971
972    ServeHandle {
973        join,
974        shutdown_notify,
975        drain_timeout: drain_dur,
976        done_rx,
977    }
978}
979
980// ── TowerErrorHandler — maps Tower layer errors to HTTP responses ─────────────
981//
982// `ConcurrencyLimitLayer`, `TimeoutLayer`, and `LoadShedLayer` return errors
983// rather than `Response` values when they reject a request.  `TowerErrorHandler`
984// wraps the composed service and converts those errors to proper HTTP responses:
985//
986//   tower::timeout::error::Elapsed     → 408 Request Timeout
987//   tower::load_shed::error::Overloaded → 503 Service Unavailable
988//   anything else                       → 500 Internal Server Error
989//
// This is what lets the whole stack satisfy hyper's requirement that the
// service return `Ok(Response)` — without this wrapper, a layer error would
// tear down the connection instead of producing an HTTP status code.
993
/// Newtype wrapper around a composed Tower service stack that converts
/// layer errors (`Elapsed`, `Overloaded`) into HTTP error responses in
/// its `Service::call` implementation, rather than surfacing them to
/// hyper as connection-fatal errors.
#[derive(Clone)]
struct TowerErrorHandler<S>(S);
996
997impl<S, Req> Service<Req> for TowerErrorHandler<S>
998where
999    S: Service<Req, Response = hyper::Response<BoxBody>>,
1000    S::Error: Into<BoxError>,
1001    S::Future: Send + 'static,
1002{
1003    type Response = hyper::Response<BoxBody>;
1004    type Error = BoxError;
1005    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1006
1007    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1008        // If ConcurrencyLimitLayer is saturated AND LoadShed is present, it
1009        // returns Pending from poll_ready — LoadShed converts that to an
1010        // immediate Err(Overloaded).  If LoadShed is absent, poll_ready blocks
1011        // until a slot opens.  In both cases we propagate the result here.
1012        self.0.poll_ready(cx).map_err(Into::into)
1013    }
1014
1015    fn call(&mut self, req: Req) -> Self::Future {
1016        let fut = self.0.call(req);
1017        Box::pin(async move {
1018            match fut.await {
1019                Ok(r) => Ok(r),
1020                Err(e) => {
1021                    let e = e.into();
1022                    let status = if e.is::<tower::timeout::error::Elapsed>() {
1023                        StatusCode::REQUEST_TIMEOUT
1024                    } else if e.is::<tower::load_shed::error::Overloaded>() {
1025                        StatusCode::SERVICE_UNAVAILABLE
1026                    } else {
1027                        tracing::warn!("iroh-http: unexpected tower error: {e}");
1028                        StatusCode::INTERNAL_SERVER_ERROR
1029                    };
1030                    let body_bytes: &'static [u8] = match status {
1031                        StatusCode::REQUEST_TIMEOUT => b"request timed out",
1032                        StatusCode::SERVICE_UNAVAILABLE => b"server at capacity",
1033                        _ => b"internal server error",
1034                    };
1035                    Ok(hyper::Response::builder()
1036                        .status(status)
1037                        .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
1038                            body_bytes,
1039                        ))))
1040                        .expect("valid error response"))
1041                }
1042            }
1043        })
1044    }
1045}