// iroh_http_core/server.rs
1//! Incoming HTTP request — `serve()` implementation.
2//!
3//! Each accepted QUIC bidirectional stream is driven by hyper's HTTP/1.1
4//! server connection.  A `tower::Service` (`RequestService`) bridges between
5//! hyper and the existing body-channel + slab infrastructure.
6
7use std::{
8    collections::HashMap,
9    future::Future,
10    pin::Pin,
11    sync::{
12        atomic::{AtomicUsize, Ordering},
13        Arc, Mutex,
14    },
15    task::{Context, Poll},
16    time::Duration,
17};
18
19use bytes::Bytes;
20use http::{HeaderName, HeaderValue, StatusCode};
21use hyper::body::Incoming;
22use hyper_util::rt::TokioIo;
23use hyper_util::service::TowerToHyperService;
24use tower::Service;
25
26use crate::{
27    base32_encode,
28    client::{body_from_reader, pump_hyper_body_to_channel_limited},
29    io::IrohStream,
30    stream::{HandleStore, ResponseHeadEntry},
31    ConnectionEvent, CoreError, IrohEndpoint, RequestPayload,
32};
33
34// ── Type aliases ──────────────────────────────────────────────────────────────
35
/// Boxed HTTP body type, aliased from the crate root so this module's
/// signatures stay short.
type BoxBody = crate::BoxBody;
/// Erased error type used throughout the tower/hyper service plumbing.
type BoxError = Box<dyn std::error::Error + Send + Sync>;
38
39// ── ServeOptions ──────────────────────────────────────────────────────────────
40
/// Options for the HTTP serve loop.
///
/// Passed directly to [`serve()`] or [`serve_with_events()`].  These govern
/// per-request middleware (Tower layers), inbound connection caps, and
/// serve-loop lifecycle — they do **not** affect outgoing fetch calls.
///
/// Every field is optional; `None` selects the documented default.
#[derive(Debug, Clone, Default)]
pub struct ServeOptions {
    /// Maximum simultaneous in-flight requests.  Default: 1024.
    pub max_concurrency: Option<usize>,
    /// Consecutive accept-loop errors before the serve loop terminates.  Default: 5.
    pub max_serve_errors: Option<usize>,
    /// Per-request timeout in milliseconds.  Default: 60 000.
    /// A value of 0 disables the per-request timeout entirely.
    pub request_timeout_ms: Option<u64>,
    /// Maximum connections from a single peer.  Default: 8.
    pub max_connections_per_peer: Option<usize>,
    /// Reject request bodies larger than this many bytes.  Default: 16 MiB.
    pub max_request_body_bytes: Option<usize>,
    /// Graceful shutdown drain window in milliseconds.  Default: 30 000.
    pub drain_timeout_ms: Option<u64>,
    /// Maximum total QUIC connections the server will accept.  Default: unlimited.
    pub max_total_connections: Option<usize>,
    /// When `true` (the default), reject new requests immediately with `503
    /// Service Unavailable` when `max_concurrency` is already reached rather
    /// than queuing them.  Prevents thundering-herd on recovery.
    pub load_shed: Option<bool>,
}
67
/// Default for `ServeOptions::max_concurrency` (in-flight request cap).
const DEFAULT_CONCURRENCY: usize = 1024;
/// Default for `ServeOptions::request_timeout_ms` (60 s).
const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 60_000;
/// Default for `ServeOptions::max_connections_per_peer`.
const DEFAULT_MAX_CONNECTIONS_PER_PEER: usize = 8;
/// Default for `ServeOptions::drain_timeout_ms` (30 s graceful drain window).
const DEFAULT_DRAIN_TIMEOUT_MS: u64 = 30_000;
/// 16 MiB — applied when `max_request_body_bytes` is not explicitly set.
/// Prevents memory exhaustion from unbounded request bodies.
const DEFAULT_MAX_REQUEST_BODY_BYTES: usize = 16 * 1024 * 1024;
/// 256 MiB — applied when `max_response_body_bytes` is not explicitly set.
/// Prevents memory exhaustion from a malicious server sending a compressed
/// response that expands to an unbounded size (compression bomb).
pub(crate) const DEFAULT_MAX_RESPONSE_BODY_BYTES: usize = 256 * 1024 * 1024;
79
80// ── ServeHandle ───────────────────────────────────────────────────────────────
81
/// Handle to a running serve loop, returned by [`serve()`] /
/// [`serve_with_events()`].  Used to shut down, drain, or abort the loop.
pub struct ServeHandle {
    /// Join handle of the spawned serve task.
    join: tokio::task::JoinHandle<()>,
    /// Signalled by [`ServeHandle::shutdown`] to stop the accept loop.
    shutdown_notify: Arc<tokio::sync::Notify>,
    /// Configured graceful-shutdown drain window.
    drain_timeout: std::time::Duration,
    /// Resolves to `true` once the serve task has fully exited.
    done_rx: tokio::sync::watch::Receiver<bool>,
}
89
impl ServeHandle {
    /// Request a graceful shutdown of the accept loop.  Non-blocking; the
    /// notification is buffered, so calling before the loop is listening is safe.
    pub fn shutdown(&self) {
        self.shutdown_notify.notify_one();
    }
    /// Request shutdown and wait for the serve task to exit.
    ///
    /// NOTE(review): no extra timeout is applied here — the drain window is
    /// presumably enforced inside the serve task itself; confirm against the
    /// task body.
    pub async fn drain(self) {
        self.shutdown();
        let _ = self.join.await;
    }
    /// Abort the serve task immediately, skipping any graceful drain.
    pub fn abort(&self) {
        self.join.abort();
    }
    /// The configured graceful-shutdown drain window.
    pub fn drain_timeout(&self) -> std::time::Duration {
        self.drain_timeout
    }
    /// Subscribe to the serve-loop-done signal.
    ///
    /// The returned receiver resolves (changes to `true`) once the serve task
    /// has fully exited, including the drain phase.
    pub fn subscribe_done(&self) -> tokio::sync::watch::Receiver<bool> {
        self.done_rx.clone()
    }
}
112
113// ── respond() ────────────────────────────────────────────────────────────────
114
115pub fn respond(
116    handles: &HandleStore,
117    req_handle: u64,
118    status: u16,
119    headers: Vec<(String, String)>,
120) -> Result<(), CoreError> {
121    StatusCode::from_u16(status)
122        .map_err(|_| CoreError::invalid_input(format!("invalid HTTP status code: {status}")))?;
123    for (name, value) in &headers {
124        HeaderName::from_bytes(name.as_bytes()).map_err(|_| {
125            CoreError::invalid_input(format!("invalid response header name {:?}", name))
126        })?;
127        HeaderValue::from_str(value).map_err(|_| {
128            CoreError::invalid_input(format!("invalid response header value for {:?}", name))
129        })?;
130    }
131
132    let sender = handles
133        .take_req_sender(req_handle)
134        .ok_or_else(|| CoreError::invalid_handle(req_handle))?;
135    sender
136        .send(ResponseHeadEntry { status, headers })
137        .map_err(|_| CoreError::internal("serve task dropped before respond"))
138}
139
140// ── PeerConnectionGuard ───────────────────────────────────────────────────────
141
/// Callback fired on peer connect (0 → 1) and disconnect (1 → 0) transitions.
type ConnectionEventFn = Arc<dyn Fn(ConnectionEvent) + Send + Sync>;

/// RAII guard for one accepted connection from a peer.
///
/// Holds one slot in the shared per-peer connection count; dropping the guard
/// releases the slot and fires the disconnect event when it was the last one.
struct PeerConnectionGuard {
    /// Shared per-peer connection counts, keyed by the peer's public key.
    counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
    /// Peer this guard accounts for.
    peer: iroh::PublicKey,
    /// String rendering of `peer` used in emitted [`ConnectionEvent`]s.
    peer_id_str: String,
    /// Optional connect/disconnect event callback.
    on_event: Option<ConnectionEventFn>,
}
150
151impl PeerConnectionGuard {
152    fn acquire(
153        counts: &Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
154        peer: iroh::PublicKey,
155        peer_id_str: String,
156        max: usize,
157        on_event: Option<ConnectionEventFn>,
158    ) -> Option<Self> {
159        let mut map = counts.lock().unwrap_or_else(|e| e.into_inner());
160        let count = map.entry(peer).or_insert(0);
161        if *count >= max {
162            return None;
163        }
164        let was_zero = *count == 0;
165        *count = count.saturating_add(1);
166        let guard = PeerConnectionGuard {
167            counts: counts.clone(),
168            peer,
169            peer_id_str: peer_id_str.clone(),
170            on_event: on_event.clone(),
171        };
172        // Fire connected event on 0 → 1 transition (first connection from this peer).
173        if was_zero {
174            if let Some(cb) = &on_event {
175                cb(ConnectionEvent {
176                    peer_id: peer_id_str,
177                    connected: true,
178                });
179            }
180        }
181        Some(guard)
182    }
183}
184
185impl Drop for PeerConnectionGuard {
186    fn drop(&mut self) {
187        let mut map = self.counts.lock().unwrap_or_else(|e| e.into_inner());
188        if let Some(c) = map.get_mut(&self.peer) {
189            *c = c.saturating_sub(1);
190            if *c == 0 {
191                map.remove(&self.peer);
192                // Fire disconnected event on 1 → 0 transition (last connection from this peer closed).
193                if let Some(cb) = &self.on_event {
194                    cb(ConnectionEvent {
195                        peer_id: self.peer_id_str.clone(),
196                        connected: false,
197                    });
198                }
199            }
200        }
201    }
202}
203
204// ── RequestService ────────────────────────────────────────────────────────────
205
/// Per-connection `tower::Service` bridging hyper's HTTP/1.1 server to the
/// JS-facing handle/slab infrastructure.
///
/// Cloned once per connection (with `remote_node_id` filled in) and again per
/// request inside `call()`.
#[derive(Clone)]
struct RequestService {
    /// Callback invoked with a `RequestPayload` for every accepted request.
    on_request: Arc<dyn Fn(RequestPayload) + Send + Sync>,
    /// Endpoint that owns the handle store and body channels.
    endpoint: IrohEndpoint,
    /// This node's id, used to build the `httpi://` request URL.
    own_node_id: Arc<String>,
    /// Authenticated peer id of the QUIC connection; `None` on the base
    /// service, filled in by the per-connection clone.
    remote_node_id: Option<String>,
    /// Reject request bodies larger than this many bytes (`None` = unlimited).
    max_request_body_bytes: Option<usize>,
    /// Reject requests whose estimated serialized header size exceeds this
    /// many bytes (`None` = unlimited).
    max_header_size: Option<usize>,
    /// Compression settings forwarded from the endpoint configuration.
    #[cfg(feature = "compression")]
    compression: Option<crate::endpoint::CompressionOptions>,
}
217
impl Service<hyper::Request<Incoming>> for RequestService {
    type Response = hyper::Response<BoxBody>;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Always ready — backpressure is applied by the `ConcurrencyLimitLayer`
    /// wrapped around this service, not here.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Clone the service state into a boxed future driving
    /// `RequestService::handle` for this request.
    fn call(&mut self, req: hyper::Request<Incoming>) -> Self::Future {
        let svc = self.clone();
        Box::pin(async move { svc.handle(req).await })
    }
}
232
impl RequestService {
    /// Handle one inbound HTTP request end-to-end.
    ///
    /// Pipeline: validate headers (size limit, UTF-8) → allocate request/response
    /// body channels plus a response-head oneshot in the handle store → fire the
    /// `on_request` callback → await the handler's response head (racing it
    /// against request-body overflow) → build the hyper response.  Duplex
    /// requests (`CONNECT` + `Upgrade: iroh-duplex`) are switched to a raw byte
    /// pump only when the handler returns 101.
    async fn handle(
        self,
        mut req: hyper::Request<Incoming>,
    ) -> Result<hyper::Response<BoxBody>, BoxError> {
        let handles = self.endpoint.handles();
        let own_node_id = &*self.own_node_id;
        // Empty string only if the base service was used directly without a
        // per-connection clone filling in the authenticated peer id.
        let remote_node_id = self.remote_node_id.clone().unwrap_or_default();
        let max_request_body_bytes = self.max_request_body_bytes;
        let max_header_size = self.max_header_size;

        let method = req.method().to_string();
        let path_and_query = req
            .uri()
            .path_and_query()
            .map(|p| p.as_str())
            .unwrap_or("/")
            .to_string();

        tracing::debug!(
            method = %method,
            path = %path_and_query,
            peer = %remote_node_id,
            "iroh-http: incoming request",
        );
        // Strip any client-supplied peer-id to prevent spoofing,
        // then inject the authenticated identity from the QUIC connection.
        //
        // ISS-011: Use raw byte length for header-size accounting to prevent
        // bypass via non-UTF8 values.  Reject non-UTF8 header values with 400
        // instead of silently converting them to empty strings.

        // First pass: measure header bytes using raw values (before lossy conversion).
        // The estimate counts the headers as they would appear on the wire,
        // with the injected peer-id in place of any client-supplied one.
        if let Some(limit) = max_header_size {
            let header_bytes: usize = req
                .headers()
                .iter()
                .filter(|(k, _)| !k.as_str().eq_ignore_ascii_case("peer-id"))
                .map(|(k, v)| {
                    k.as_str()
                        .len()
                        .saturating_add(v.as_bytes().len())
                        .saturating_add(4)
                }) // ": " + "\r\n"
                .fold(0usize, |acc, x| acc.saturating_add(x))
                .saturating_add("peer-id".len())
                .saturating_add(remote_node_id.len())
                .saturating_add(4)
                .saturating_add(req.uri().to_string().len())
                .saturating_add(method.len())
                .saturating_add(12); // "HTTP/1.1 \r\n\r\n" overhead
            if header_bytes > limit {
                let resp = hyper::Response::builder()
                    .status(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE)
                    .body(crate::box_body(http_body_util::Empty::new()))
                    .expect("static response args are valid");
                return Ok(resp);
            }
        }

        // Build header list — reject non-UTF8 values instead of silently dropping.
        let mut req_headers: Vec<(String, String)> = Vec::new();
        for (k, v) in req.headers().iter() {
            if k.as_str().eq_ignore_ascii_case("peer-id") {
                continue;
            }
            match v.to_str() {
                Ok(s) => req_headers.push((k.as_str().to_string(), s.to_string())),
                Err(_) => {
                    let resp = hyper::Response::builder()
                        .status(StatusCode::BAD_REQUEST)
                        .body(crate::box_body(http_body_util::Full::new(
                            Bytes::from_static(b"non-UTF8 header value"),
                        )))
                        .expect("static response args are valid");
                    return Ok(resp);
                }
            }
        }
        // Inject the cryptographically authenticated identity.
        req_headers.push(("peer-id".to_string(), remote_node_id.clone()));

        let url = format!("httpi://{own_node_id}{path_and_query}");

        // ISS-015: strict duplex upgrade validation — require CONNECT method +
        // Upgrade: iroh-duplex + Connection: upgrade headers.
        let has_upgrade_header = req_headers.iter().any(|(k, v)| {
            k.eq_ignore_ascii_case("upgrade") && v.eq_ignore_ascii_case("iroh-duplex")
        });
        let has_connection_upgrade = req_headers.iter().any(|(k, v)| {
            k.eq_ignore_ascii_case("connection")
                && v.split(',')
                    .any(|tok| tok.trim().eq_ignore_ascii_case("upgrade"))
        });
        let is_connect = req.method() == http::Method::CONNECT;

        let is_bidi = if has_upgrade_header {
            if !has_connection_upgrade || !is_connect {
                let resp = hyper::Response::builder()
                    .status(StatusCode::BAD_REQUEST)
                    .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
                        b"duplex upgrade requires CONNECT method with Connection: upgrade header",
                    ))))
                    .expect("static response args are valid");
                return Ok(resp);
            }
            true
        } else {
            false
        };

        // For duplex: capture the upgrade future BEFORE consuming the request.
        let upgrade_future = if is_bidi {
            Some(hyper::upgrade::on(&mut req))
        } else {
            None
        };

        // ── Allocate channels ────────────────────────────────────────────────

        // Request body: writer pumped from hyper; reader given to JS.
        let mut guard = handles.insert_guard();
        let (req_body_writer, req_body_reader) = handles.make_body_channel();
        let req_body_handle = guard
            .insert_reader(req_body_reader)
            .map_err(|e| -> BoxError { e.into() })?;

        // Response body: writer given to JS (sendChunk); reader feeds hyper response.
        let (res_body_writer, res_body_reader) = handles.make_body_channel();
        let res_body_handle = guard
            .insert_writer(res_body_writer)
            .map_err(|e| -> BoxError { e.into() })?;

        // ── Allocate response-head rendezvous ────────────────────────────────

        let (head_tx, head_rx) = tokio::sync::oneshot::channel::<ResponseHeadEntry>();
        let req_handle = guard
            .allocate_req_handle(head_tx)
            .map_err(|e| -> BoxError { e.into() })?;

        // Make all three slab entries visible atomically (presumably the
        // guard rolls insertions back if dropped uncommitted — confirm
        // against HandleStore::insert_guard).
        guard.commit();

        // RAII guard: remove the req_handle slab entry on all exit paths
        // (413 early-return, timeout drop, "JS handler dropped", normal completion).
        // If respond() already consumed the entry, take_req_sender returns None — safe no-op.
        struct ReqHeadCleanup {
            endpoint: IrohEndpoint,
            req_handle: u64,
        }
        impl Drop for ReqHeadCleanup {
            fn drop(&mut self) {
                self.endpoint.handles().take_req_sender(self.req_handle);
            }
        }
        let _req_head_cleanup = ReqHeadCleanup {
            endpoint: self.endpoint.clone(),
            req_handle,
        };

        // ── Pump request body ────────────────────────────────────────────────

        // For duplex: keep req_body_writer to move into the upgrade spawn below.
        // For regular: consume it immediately into the pump task.
        // ISS-004: create an overflow channel so the serve path can return 413.
        let (body_overflow_tx, body_overflow_rx) = if !is_bidi && max_request_body_bytes.is_some() {
            let (tx, rx) = tokio::sync::oneshot::channel::<()>();
            (Some(tx), Some(rx))
        } else {
            (None, None)
        };

        let duplex_req_body_writer = if !is_bidi {
            let body = req.into_body();
            let frame_timeout = handles.drain_timeout();
            tokio::spawn(pump_hyper_body_to_channel_limited(
                body,
                req_body_writer,
                max_request_body_bytes,
                frame_timeout,
                body_overflow_tx,
            ));
            None
        } else {
            // Duplex: discard the HTTP preamble body (empty before 101).
            drop(req.into_body());
            Some(req_body_writer)
        };

        // ── Fire on_request callback ─────────────────────────────────────────

        on_request_fire(
            &self.on_request,
            req_handle,
            req_body_handle,
            res_body_handle,
            method,
            url,
            req_headers,
            remote_node_id,
            is_bidi,
        );

        // ── Await response head from JS (race against body overflow) ─────────
        //
        // ISS-004: if the request body exceeds maxRequestBodyBytes, return 413
        // immediately without waiting for the JS handler to respond.

        let response_head = if let Some(overflow_rx) = body_overflow_rx {
            tokio::select! {
                biased;
                Ok(()) = overflow_rx => {
                    // Body too large: ReqHeadCleanup RAII guard will remove the slab
                    // entry when this function exits (issue-7 fix).
                    let resp = hyper::Response::builder()
                        .status(StatusCode::PAYLOAD_TOO_LARGE)
                        .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
                            b"request body too large",
                        ))))
                        .expect("valid 413 response");
                    return Ok(resp);
                }
                head = head_rx => {
                    head.map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
                }
            }
        } else {
            head_rx
                .await
                .map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
        };

        // ── Duplex path: honor handler status, upgrade only on 101 ──────────────
        //
        // ISS-002: the handler may reject the duplex request by returning any
        // non-101 status.  Only perform the QUIC stream pump when the handler
        // explicitly returns 101 Switching Protocols.

        if let Some(upgrade_fut) = upgrade_future {
            let req_body_writer =
                duplex_req_body_writer.expect("duplex path always has req_body_writer");

            // If the handler returned a non-101 status, send that response and
            // do NOT perform the upgrade.  Drop the upgrade future and writer.
            // Note: the rejection response carries the handler's headers but an
            // empty body.
            if response_head.status != StatusCode::SWITCHING_PROTOCOLS.as_u16() {
                drop(upgrade_fut);
                drop(req_body_writer);
                let mut resp_builder = hyper::Response::builder().status(response_head.status);
                for (k, v) in &response_head.headers {
                    resp_builder = resp_builder.header(k.as_str(), v.as_str());
                }
                let resp = resp_builder
                    .body(crate::box_body(http_body_util::Empty::new()))
                    .map_err(|e| -> BoxError { e.into() })?;
                return Ok(resp);
            }

            // Spawn the upgrade pump after hyper delivers the 101.
            //
            // Both directions are wired to the channels already sent to JS:
            //   recv_io → req_body_writer  (JS reads via req_body_handle)
            //   res_body_reader → send_io  (JS writes via res_body_handle)
            tokio::spawn(async move {
                match upgrade_fut.await {
                    Err(e) => tracing::warn!("iroh-http: duplex upgrade error: {e}"),
                    Ok(upgraded) => {
                        let io = TokioIo::new(upgraded);
                        crate::stream::pump_duplex(io, req_body_writer, res_body_reader).await;
                    }
                }
            });

            // ISS-015: emit both Connection and Upgrade headers in 101 response.
            let resp = hyper::Response::builder()
                .status(StatusCode::SWITCHING_PROTOCOLS)
                .header(hyper::header::CONNECTION, "Upgrade")
                .header(hyper::header::UPGRADE, "iroh-duplex")
                .body(crate::box_body(http_body_util::Empty::new()))
                .expect("static response args are valid");
            return Ok(resp);
        }

        // ── Regular HTTP response ─────────────────────────────────────────────

        // Stream whatever JS writes through res_body_handle out as the body.
        let body_stream = body_from_reader(res_body_reader);

        let mut resp_builder = hyper::Response::builder().status(response_head.status);
        for (k, v) in &response_head.headers {
            resp_builder = resp_builder.header(k.as_str(), v.as_str());
        }

        #[cfg(feature = "compression")]
        let resp_builder = resp_builder; // CompressionLayer in ServiceBuilder handles this

        let resp = resp_builder
            .body(crate::box_body(body_stream))
            .map_err(|e| -> BoxError { e.into() })?;

        Ok(resp)
    }
}
532
533#[inline]
534#[allow(clippy::too_many_arguments)]
535fn on_request_fire(
536    cb: &Arc<dyn Fn(RequestPayload) + Send + Sync>,
537    req_handle: u64,
538    req_body_handle: u64,
539    res_body_handle: u64,
540    method: String,
541    url: String,
542    headers: Vec<(String, String)>,
543    remote_node_id: String,
544    is_bidi: bool,
545) {
546    cb(RequestPayload {
547        req_handle,
548        req_body_handle,
549        res_body_handle,
550        method,
551        url,
552        headers,
553        remote_node_id,
554        is_bidi,
555    });
556}
557
558// ── serve() ───────────────────────────────────────────────────────────────────
559
/// Start the serve accept loop.
///
/// This is the 3-argument form for backward compatibility.
/// Use `serve_with_events` to also receive peer connect/disconnect callbacks.
///
/// Returns a [`ServeHandle`] for shutting down, draining, or aborting the loop.
///
/// # Security
///
/// Calling `serve()` opens a **public endpoint** on the Iroh overlay network.
/// Unlike regular HTTP (where you choose whether to bind on `0.0.0.0`), any
/// peer that knows or discovers your node's public key can connect and send
/// requests. Iroh QUIC authenticates the peer's *identity* cryptographically,
/// but does not enforce *authorization*.
///
/// Always inspect `RequestPayload::peer_id` (exposed as the `Peer-Id` request
/// header at the FFI layer) and reject requests from untrusted peers:
///
/// ```ignore
/// serve(endpoint, ServeOptions::default(), |payload| {
///     if !ALLOWED_PEERS.contains(&payload.peer_id) {
///         respond(handles, payload.req_handle, 403, vec![]).ok();
///         return;
///     }
///     // ... handle request
/// });
/// ```
pub fn serve<F>(endpoint: IrohEndpoint, options: ServeOptions, on_request: F) -> ServeHandle
where
    F: Fn(RequestPayload) + Send + Sync + 'static,
{
    serve_with_events(endpoint, options, on_request, None)
}
591
592/// Start the serve accept loop with an optional peer connection event callback.
593///
594/// `on_connection_event` is called on 0→1 (first connection from a peer) and
595/// 1→0 (last connection from a peer closed) count transitions.
596pub fn serve_with_events<F>(
597    endpoint: IrohEndpoint,
598    options: ServeOptions,
599    on_request: F,
600    on_connection_event: Option<ConnectionEventFn>,
601) -> ServeHandle
602where
603    F: Fn(RequestPayload) + Send + Sync + 'static,
604{
605    let max = options.max_concurrency.unwrap_or(DEFAULT_CONCURRENCY);
606    let max_errors = options.max_serve_errors.unwrap_or(5);
607    let request_timeout = options
608        .request_timeout_ms
609        .map(Duration::from_millis)
610        .unwrap_or(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS));
611    let max_conns_per_peer = options
612        .max_connections_per_peer
613        .unwrap_or(DEFAULT_MAX_CONNECTIONS_PER_PEER);
614    let max_request_body_bytes = options
615        .max_request_body_bytes
616        .or(Some(DEFAULT_MAX_REQUEST_BODY_BYTES));
617    let max_total_connections = options.max_total_connections;
618    let drain_timeout =
619        Duration::from_millis(options.drain_timeout_ms.unwrap_or(DEFAULT_DRAIN_TIMEOUT_MS));
620    // Load-shed is opt-out — default `true` (reject immediately when at capacity).
621    let load_shed_enabled = options.load_shed.unwrap_or(true);
622    let max_header_size = endpoint.max_header_size();
623    #[cfg(feature = "compression")]
624    let compression = endpoint.compression().cloned();
625    let own_node_id = Arc::new(endpoint.node_id().to_string());
626    let on_request = Arc::new(on_request) as Arc<dyn Fn(RequestPayload) + Send + Sync>;
627
628    let peer_counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>> =
629        Arc::new(Mutex::new(HashMap::new()));
630    let conn_event_fn: Option<ConnectionEventFn> = on_connection_event;
631
632    // In-flight request counter: incremented on accept, decremented on drop.
633    // Used for graceful drain (wait until zero or timeout).
634    let in_flight: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
635    let drain_notify: Arc<tokio::sync::Notify> = Arc::new(tokio::sync::Notify::new());
636
637    let base_svc = RequestService {
638        on_request,
639        endpoint: endpoint.clone(),
640        own_node_id,
641        remote_node_id: None,
642        max_request_body_bytes,
643        max_header_size: if max_header_size == 0 {
644            None
645        } else {
646            Some(max_header_size)
647        },
648        #[cfg(feature = "compression")]
649        compression,
650    };
651
652    use tower::{limit::ConcurrencyLimitLayer, Layer};
653    // SEC-002: build the concurrency limiter once so all clones share one
654    // Arc<Semaphore>, enforcing a true global request cap across every
655    // connection and request task.
656    let shared_conc = ConcurrencyLimitLayer::new(max).layer(base_svc);
657
658    let shutdown_notify = Arc::new(tokio::sync::Notify::new());
659    let shutdown_listen = shutdown_notify.clone();
660    let drain_dur = drain_timeout;
661    // Re-use the endpoint's shared counters so that endpoint_stats() reflects
662    // the live connection and request counts at all times.
663    let total_connections = endpoint.inner.active_connections.clone();
664    let total_requests = endpoint.inner.active_requests.clone();
665    let (done_tx, done_rx) = tokio::sync::watch::channel(false);
666    let endpoint_closed_tx = endpoint.inner.closed_tx.clone();
667
668    let in_flight_drain = in_flight.clone();
669    let drain_notify_drain = drain_notify.clone();
670
671    let join = tokio::spawn(async move {
672        let ep = endpoint.raw().clone();
673        let mut consecutive_errors: usize = 0;
674
675        loop {
676            let incoming = tokio::select! {
677                biased;
678                _ = shutdown_listen.notified() => {
679                    tracing::info!("iroh-http: serve loop shutting down");
680                    break;
681                }
682                inc = ep.accept() => match inc {
683                    Some(i) => i,
684                    None => {
685                        tracing::info!("iroh-http: endpoint closed (accept returned None)");
686                        let _ = endpoint_closed_tx.send(true);
687                        break;
688                    }
689                }
690            };
691
692            let conn = match incoming.await {
693                Ok(c) => {
694                    consecutive_errors = 0;
695                    c
696                }
697                Err(e) => {
698                    consecutive_errors = consecutive_errors.saturating_add(1);
699                    tracing::warn!(
700                        "iroh-http: accept error ({consecutive_errors}/{max_errors}): {e}"
701                    );
702                    if consecutive_errors >= max_errors {
703                        tracing::error!("iroh-http: too many accept errors — shutting down");
704                        break;
705                    }
706                    continue;
707                }
708            };
709
710            let remote_pk = conn.remote_id();
711
712            // Enforce total connection limit.
713            if let Some(max_total) = max_total_connections {
714                let current = total_connections.load(Ordering::Relaxed);
715                if current >= max_total {
716                    tracing::warn!(
717                        "iroh-http: total connection limit reached ({current}/{max_total})"
718                    );
719                    conn.close(0u32.into(), b"server at capacity");
720                    continue;
721                }
722            }
723
724            let remote_id = base32_encode(remote_pk.as_bytes());
725
726            let guard = match PeerConnectionGuard::acquire(
727                &peer_counts,
728                remote_pk,
729                remote_id.clone(),
730                max_conns_per_peer,
731                conn_event_fn.clone(),
732            ) {
733                Some(g) => g,
734                None => {
735                    tracing::warn!("iroh-http: peer {remote_id} exceeded connection limit");
736                    conn.close(0u32.into(), b"too many connections");
737                    continue;
738                }
739            };
740
741            let mut conn_conc = shared_conc.clone();
742            conn_conc.get_mut().remote_node_id = Some(remote_id);
743
744            let timeout_dur = if request_timeout.is_zero() {
745                Duration::MAX
746            } else {
747                request_timeout
748            };
749
750            let conn_total = total_connections.clone();
751            let conn_requests = total_requests.clone();
752            let in_flight_conn = in_flight.clone();
753            let drain_notify_conn = drain_notify.clone();
754            conn_total.fetch_add(1, Ordering::Relaxed);
755            tokio::spawn(async move {
756                let _guard = guard;
757                // Decrement total connection count when this task exits.
758                struct TotalGuard(Arc<AtomicUsize>);
759                impl Drop for TotalGuard {
760                    fn drop(&mut self) {
761                        self.0.fetch_sub(1, Ordering::Relaxed);
762                    }
763                }
764                let _total_guard = TotalGuard(conn_total);
765
766                loop {
767                    let (send, recv) = match conn.accept_bi().await {
768                        Ok(pair) => pair,
769                        Err(_) => break,
770                    };
771
772                    let io = TokioIo::new(IrohStream::new(send, recv));
773                    let svc = conn_conc.clone();
774                    let req_counter = conn_requests.clone();
775                    req_counter.fetch_add(1, Ordering::Relaxed);
776                    in_flight_conn.fetch_add(1, Ordering::Relaxed);
777
778                    let in_flight_req = in_flight_conn.clone();
779                    let drain_notify_req = drain_notify_conn.clone();
780
781                    tokio::spawn(async move {
782                        // Decrement request count when this task exits.
783                        struct ReqGuard {
784                            counter: Arc<AtomicUsize>,
785                            in_flight: Arc<AtomicUsize>,
786                            drain_notify: Arc<tokio::sync::Notify>,
787                        }
788                        impl Drop for ReqGuard {
789                            fn drop(&mut self) {
790                                self.counter.fetch_sub(1, Ordering::Relaxed);
791                                if self.in_flight.fetch_sub(1, Ordering::AcqRel) == 1 {
792                                    // Last in-flight request completed — signal drain.
793                                    self.drain_notify.notify_waiters();
794                                }
795                            }
796                        }
797                        let _req_guard = ReqGuard {
798                            counter: req_counter,
799                            in_flight: in_flight_req,
800                            drain_notify: drain_notify_req,
801                        };
802                        // ISS-001: clamp to hyper's minimum safe buffer size of 8192.
803                        // ISS-020: a stored value of 0 means "use the default" (64 KB).
804                        let effective_header_limit = if max_header_size == 0 {
805                            64 * 1024
806                        } else {
807                            max_header_size.max(8192)
808                        };
809
810                        // Build the Tower reliability service stack and serve the connection.
811                        //
812                        // Layer ordering (outermost first):
813                        //   [CompressionLayer →] TowerErrorHandler → LoadShed → ConcurrencyLimit → Timeout → RequestService
814                        //
815                        // Tower layers are applied around `RequestService` (which returns
816                        // `Response<BoxBody>`) so that `TowerErrorHandler` can convert
817                        // `Elapsed` and `Overloaded` into 408/503 HTTP responses.
818                        // CompressionLayer (if enabled) sits outside the error handler.
819                        //
820                        // Each `if load_shed_enabled` branch produces a concrete type;
821                        // both branches `.await` to `Result<(), hyper::Error>` so the
822                        // `if` expression is well-typed without boxing.
823
824                        use tower::{timeout::TimeoutLayer, ServiceBuilder};
825
826                        #[cfg(feature = "compression")]
827                        let result = {
828                            use http::{Extensions, HeaderMap, Version};
829                            use tower_http::compression::{
830                                predicate::{Predicate, SizeAbove},
831                                CompressionLayer,
832                            };
833
834                            let compression_config = svc.get_ref().compression.clone();
835                            if let Some(comp) = &compression_config {
836                                let min_bytes = comp.min_body_bytes;
837                                let mut layer = CompressionLayer::new().zstd(true);
838                                if let Some(level) = comp.level {
839                                    use tower_http::compression::CompressionLevel;
840                                    layer = layer.quality(CompressionLevel::Precise(level as i32));
841                                }
842                                let not_pre_compressed =
843                                    |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
844                                        !h.contains_key(http::header::CONTENT_ENCODING)
845                                    };
846                                let not_no_transform =
847                                    |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
848                                        h.get(http::header::CACHE_CONTROL)
849                                            .and_then(|v| v.to_str().ok())
850                                            .map(|v| {
851                                                !v.split(',').any(|d| {
852                                                    d.trim().eq_ignore_ascii_case("no-transform")
853                                                })
854                                            })
855                                            .unwrap_or(true)
856                                    };
857                                let predicate =
858                                    SizeAbove::new(min_bytes.min(u16::MAX as usize) as u16)
859                                        .and(not_pre_compressed)
860                                        .and(not_no_transform);
861                                if load_shed_enabled {
862                                    use tower::load_shed::LoadShedLayer;
863                                    let stk = TowerErrorHandler(
864                                        ServiceBuilder::new()
865                                            .layer(LoadShedLayer::new())
866                                            .layer(TimeoutLayer::new(timeout_dur))
867                                            .service(svc),
868                                    );
869                                    hyper::server::conn::http1::Builder::new()
870                                        .max_buf_size(effective_header_limit)
871                                        .max_headers(128)
872                                        .serve_connection(
873                                            io,
874                                            TowerToHyperService::new(
875                                                ServiceBuilder::new()
876                                                    .layer(layer.compress_when(predicate))
877                                                    .service(stk),
878                                            ),
879                                        )
880                                        .with_upgrades()
881                                        .await
882                                } else {
883                                    let stk = TowerErrorHandler(
884                                        ServiceBuilder::new()
885                                            .layer(TimeoutLayer::new(timeout_dur))
886                                            .service(svc),
887                                    );
888                                    hyper::server::conn::http1::Builder::new()
889                                        .max_buf_size(effective_header_limit)
890                                        .max_headers(128)
891                                        .serve_connection(
892                                            io,
893                                            TowerToHyperService::new(
894                                                ServiceBuilder::new()
895                                                    .layer(layer.compress_when(predicate))
896                                                    .service(stk),
897                                            ),
898                                        )
899                                        .with_upgrades()
900                                        .await
901                                }
902                            } else if load_shed_enabled {
903                                use tower::load_shed::LoadShedLayer;
904                                let stk = TowerErrorHandler(
905                                    ServiceBuilder::new()
906                                        .layer(LoadShedLayer::new())
907                                        .layer(TimeoutLayer::new(timeout_dur))
908                                        .service(svc),
909                                );
910                                hyper::server::conn::http1::Builder::new()
911                                    .max_buf_size(effective_header_limit)
912                                    .max_headers(128)
913                                    .serve_connection(io, TowerToHyperService::new(stk))
914                                    .with_upgrades()
915                                    .await
916                            } else {
917                                let stk = TowerErrorHandler(
918                                    ServiceBuilder::new()
919                                        .layer(TimeoutLayer::new(timeout_dur))
920                                        .service(svc),
921                                );
922                                hyper::server::conn::http1::Builder::new()
923                                    .max_buf_size(effective_header_limit)
924                                    .max_headers(128)
925                                    .serve_connection(io, TowerToHyperService::new(stk))
926                                    .with_upgrades()
927                                    .await
928                            }
929                        };
930                        #[cfg(not(feature = "compression"))]
931                        let result = if load_shed_enabled {
932                            use tower::load_shed::LoadShedLayer;
933                            let stk = TowerErrorHandler(
934                                ServiceBuilder::new()
935                                    .layer(LoadShedLayer::new())
936                                    .layer(TimeoutLayer::new(timeout_dur))
937                                    .service(svc),
938                            );
939                            hyper::server::conn::http1::Builder::new()
940                                .max_buf_size(effective_header_limit)
941                                .max_headers(128)
942                                .serve_connection(io, TowerToHyperService::new(stk))
943                                .with_upgrades()
944                                .await
945                        } else {
946                            let stk = TowerErrorHandler(
947                                ServiceBuilder::new()
948                                    .layer(TimeoutLayer::new(timeout_dur))
949                                    .service(svc),
950                            );
951                            hyper::server::conn::http1::Builder::new()
952                                .max_buf_size(effective_header_limit)
953                                .max_headers(128)
954                                .serve_connection(io, TowerToHyperService::new(stk))
955                                .with_upgrades()
956                                .await
957                        };
958
959                        if let Err(e) = result {
960                            tracing::debug!("iroh-http: http1 connection error: {e}");
961                        }
962                    });
963                }
964            });
965        }
966
967        // Graceful drain: wait for all in-flight requests to finish,
968        // or give up after `drain_timeout`.
969        //
970        // Loop avoids the race between `in_flight == 0` check and `notified()`:
971        // if the last request finishes between the load and the await, the loop
972        // re-checks immediately after the timeout wakes it.
973        let deadline = tokio::time::Instant::now()
974            .checked_add(drain_dur)
975            .expect("drain duration overflow");
976        loop {
977            if in_flight_drain.load(Ordering::Acquire) == 0 {
978                tracing::info!("iroh-http: all in-flight requests drained");
979                break;
980            }
981            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
982            if remaining.is_zero() {
983                tracing::warn!("iroh-http: drain timed out after {}s", drain_dur.as_secs());
984                break;
985            }
986            tokio::select! {
987                _ = drain_notify_drain.notified() => {}
988                _ = tokio::time::sleep(remaining) => {}
989            }
990        }
991        let _ = done_tx.send(true);
992    });
993
994    ServeHandle {
995        join,
996        shutdown_notify,
997        drain_timeout: drain_dur,
998        done_rx,
999    }
1000}
1001
1002// ── TowerErrorHandler — maps Tower layer errors to HTTP responses ─────────────
1003//
1004// `ConcurrencyLimitLayer`, `TimeoutLayer`, and `LoadShedLayer` return errors
1005// rather than `Response` values when they reject a request.  `TowerErrorHandler`
1006// wraps the composed service and converts those errors to proper HTTP responses:
1007//
1008//   tower::timeout::error::Elapsed     → 408 Request Timeout
1009//   tower::load_shed::error::Overloaded → 503 Service Unavailable
1010//   anything else                       → 500 Internal Server Error
1011//
// This allows the whole stack to satisfy hyper's expectation that the service
// returns `Ok(Response)`: without this wrapper, a layer error would tear down
// the connection instead of producing a status code.
1015
/// Newtype wrapper around the composed Tower middleware stack.
///
/// Its [`Service`] impl (below) converts the errors surfaced by the
/// reliability layers (`Elapsed`, `Overloaded`) into HTTP error responses,
/// so hyper always receives `Ok(Response)` rather than a connection-fatal
/// error.  `Clone` is derived because hyper serves each connection from a
/// cloned service instance.
#[derive(Clone)]
struct TowerErrorHandler<S>(S);
1018
1019impl<S, Req> Service<Req> for TowerErrorHandler<S>
1020where
1021    S: Service<Req, Response = hyper::Response<BoxBody>>,
1022    S::Error: Into<BoxError>,
1023    S::Future: Send + 'static,
1024{
1025    type Response = hyper::Response<BoxBody>;
1026    type Error = BoxError;
1027    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1028
1029    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1030        // If ConcurrencyLimitLayer is saturated AND LoadShed is present, it
1031        // returns Pending from poll_ready — LoadShed converts that to an
1032        // immediate Err(Overloaded).  If LoadShed is absent, poll_ready blocks
1033        // until a slot opens.  In both cases we propagate the result here.
1034        self.0.poll_ready(cx).map_err(Into::into)
1035    }
1036
1037    fn call(&mut self, req: Req) -> Self::Future {
1038        let fut = self.0.call(req);
1039        Box::pin(async move {
1040            match fut.await {
1041                Ok(r) => Ok(r),
1042                Err(e) => {
1043                    let e = e.into();
1044                    let status = if e.is::<tower::timeout::error::Elapsed>() {
1045                        StatusCode::REQUEST_TIMEOUT
1046                    } else if e.is::<tower::load_shed::error::Overloaded>() {
1047                        StatusCode::SERVICE_UNAVAILABLE
1048                    } else {
1049                        tracing::warn!("iroh-http: unexpected tower error: {e}");
1050                        StatusCode::INTERNAL_SERVER_ERROR
1051                    };
1052                    let body_bytes: &'static [u8] = match status {
1053                        StatusCode::REQUEST_TIMEOUT => b"request timed out",
1054                        StatusCode::SERVICE_UNAVAILABLE => b"server at capacity",
1055                        _ => b"internal server error",
1056                    };
1057                    Ok(hyper::Response::builder()
1058                        .status(status)
1059                        .body(crate::box_body(http_body_util::Full::new(
1060                            Bytes::from_static(body_bytes),
1061                        )))
1062                        .expect("valid error response"))
1063                }
1064            }
1065        })
1066    }
1067}