// iroh_http_core/server.rs
1//! Incoming HTTP request — `serve()` implementation.
2//!
3//! Each accepted QUIC bidirectional stream is driven by hyper's HTTP/1.1
4//! server connection.  A `tower::Service` (`RequestService`) bridges between
5//! hyper and the existing body-channel + slab infrastructure.
6
7use std::{
8    collections::HashMap,
9    future::Future,
10    pin::Pin,
11    sync::{
12        atomic::{AtomicUsize, Ordering},
13        Arc, Mutex,
14    },
15    task::{Context, Poll},
16    time::Duration,
17};
18
19use bytes::Bytes;
20use http::{HeaderName, HeaderValue, StatusCode};
21use hyper::body::Incoming;
22use hyper_util::rt::TokioIo;
23use hyper_util::service::TowerToHyperService;
24use tower::Service;
25
26use crate::{
27    base32_encode,
28    client::{body_from_reader, pump_hyper_body_to_channel_limited},
29    io::IrohStream,
30    stream::{HandleStore, ResponseHeadEntry},
31    ConnectionEvent, CoreError, IrohEndpoint, RequestPayload,
32};
33
34// ── Type aliases ──────────────────────────────────────────────────────────────
35
36type BoxBody = crate::BoxBody;
37type BoxError = Box<dyn std::error::Error + Send + Sync>;
38
39// ── ServerLimits ──────────────────────────────────────────────────────────────
40
/// Server-side limits shared between [`NodeOptions`](crate::NodeOptions) and
/// the serve path.
///
/// Embedding this struct in both `NodeOptions` and `EndpointInner` guarantees
/// that adding a new limit field produces a compile error if only one side is
/// updated.
#[derive(Debug, Clone, Default)]
pub struct ServerLimits {
    /// Global cap on concurrently processed requests across all connections.
    /// Default: 1024 (`DEFAULT_CONCURRENCY`).
    pub max_concurrency: Option<usize>,
    /// Consecutive accept errors tolerated before the serve loop shuts down.
    /// Default: 5.
    pub max_consecutive_errors: Option<usize>,
    /// Per-request timeout in milliseconds; `0` disables the timeout.
    /// Default: 60 000 ms.
    pub request_timeout_ms: Option<u64>,
    /// Maximum simultaneous connections accepted from a single peer.
    /// Default: 8.
    pub max_connections_per_peer: Option<usize>,
    /// Maximum request body bytes accepted per request; exceeding it yields a
    /// `413 Payload Too Large`.  Default: 16 MiB.
    pub max_request_body_bytes: Option<usize>,
    /// Maximum decompressed response body bytes the client will accept per fetch.
    /// Applies to the *client* side (outgoing fetch calls).  Default: 256 MiB.
    pub max_response_body_bytes: Option<usize>,
    /// Seconds to wait for in-flight requests during graceful drain.
    /// Default: 30 s.
    pub drain_timeout_secs: Option<u64>,
    /// Global cap on open connections; unlimited when `None`.
    pub max_total_connections: Option<usize>,
    /// When `true` (the default), reject new requests immediately with `503
    /// Service Unavailable` when `max_concurrency` is already reached rather
    /// than queuing them.  Prevents thundering-herd on recovery.
    pub load_shed: Option<bool>,
}
64
/// Backward-compatible alias — existing code that names `ServeOptions` keeps
/// compiling without changes.
pub type ServeOptions = ServerLimits;

/// Default global request-concurrency cap when `max_concurrency` is unset.
const DEFAULT_CONCURRENCY: usize = 1024;
/// Default per-request timeout (60 s) when `request_timeout_ms` is unset.
const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 60_000;
/// Default per-peer connection cap when `max_connections_per_peer` is unset.
const DEFAULT_MAX_CONNECTIONS_PER_PEER: usize = 8;
/// Default graceful-drain window when `drain_timeout_secs` is unset.
const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 30;
/// 16 MiB — applied when `max_request_body_bytes` is not explicitly set.
/// Prevents memory exhaustion from unbounded request bodies.
const DEFAULT_MAX_REQUEST_BODY_BYTES: usize = 16 * 1024 * 1024;
/// 256 MiB — applied when `max_response_body_bytes` is not explicitly set.
/// Prevents memory exhaustion from a malicious server sending a compressed
/// response that expands to an unbounded size (compression bomb).
pub(crate) const DEFAULT_MAX_RESPONSE_BODY_BYTES: usize = 256 * 1024 * 1024;
80
81// ── ServeHandle ───────────────────────────────────────────────────────────────
82
/// Handle to a running serve loop, returned by [`serve`] /
/// [`serve_with_events`].  Lets the owner request shutdown, abort the task
/// outright, or observe completion of the background accept loop.
pub struct ServeHandle {
    // Task running the accept loop; awaited by `drain()`, cancelled by `abort()`.
    join: tokio::task::JoinHandle<()>,
    // Signalled by `shutdown()`; the serve loop selects on this notify.
    shutdown_notify: Arc<tokio::sync::Notify>,
    // Configured graceful-drain window (see `ServerLimits::drain_timeout_secs`).
    drain_timeout: std::time::Duration,
    /// Resolves to `true` once the serve task has fully exited.
    done_rx: tokio::sync::watch::Receiver<bool>,
}
90
impl ServeHandle {
    /// Ask the serve loop to stop accepting new connections.  Non-blocking.
    /// `notify_one` stores a permit, so calling this before the loop reaches
    /// its `notified().await` still takes effect.
    pub fn shutdown(&self) {
        self.shutdown_notify.notify_one();
    }
    /// Signal shutdown and wait for the serve task to exit.
    ///
    /// NOTE(review): no independent deadline is applied here — presumably the
    /// drain timeout is enforced inside the serve task itself; confirm.
    pub async fn drain(self) {
        self.shutdown();
        let _ = self.join.await;
    }
    /// Hard-cancel the serve task without draining in-flight requests.
    pub fn abort(&self) {
        self.join.abort();
    }
    /// The configured graceful-drain window.
    pub fn drain_timeout(&self) -> std::time::Duration {
        self.drain_timeout
    }
    /// Subscribe to the serve-loop-done signal.
    ///
    /// The returned receiver resolves (changes to `true`) once the serve task
    /// has fully exited, including the drain phase.
    pub fn subscribe_done(&self) -> tokio::sync::watch::Receiver<bool> {
        self.done_rx.clone()
    }
}
113
114// ── respond() ────────────────────────────────────────────────────────────────
115
116pub fn respond(
117    handles: &HandleStore,
118    req_handle: u64,
119    status: u16,
120    headers: Vec<(String, String)>,
121) -> Result<(), CoreError> {
122    StatusCode::from_u16(status)
123        .map_err(|_| CoreError::invalid_input(format!("invalid HTTP status code: {status}")))?;
124    for (name, value) in &headers {
125        HeaderName::from_bytes(name.as_bytes()).map_err(|_| {
126            CoreError::invalid_input(format!("invalid response header name {:?}", name))
127        })?;
128        HeaderValue::from_str(value).map_err(|_| {
129            CoreError::invalid_input(format!("invalid response header value for {:?}", name))
130        })?;
131    }
132
133    let sender = handles
134        .take_req_sender(req_handle)
135        .ok_or_else(|| CoreError::invalid_handle(req_handle))?;
136    sender
137        .send(ResponseHeadEntry { status, headers })
138        .map_err(|_| CoreError::internal("serve task dropped before respond"))
139}
140
141// ── PeerConnectionGuard ───────────────────────────────────────────────────────
142
/// Callback invoked on peer connect (0→1) / disconnect (1→0) transitions.
type ConnectionEventFn = Arc<dyn Fn(ConnectionEvent) + Send + Sync>;

/// RAII guard representing one accepted connection from `peer`; the per-peer
/// count in `counts` is decremented on drop.
struct PeerConnectionGuard {
    counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
    peer: iroh::PublicKey,
    // Pre-encoded peer id, reused verbatim in disconnect events.
    peer_id_str: String,
    on_event: Option<ConnectionEventFn>,
}
151
152impl PeerConnectionGuard {
153    fn acquire(
154        counts: &Arc<Mutex<HashMap<iroh::PublicKey, usize>>>,
155        peer: iroh::PublicKey,
156        peer_id_str: String,
157        max: usize,
158        on_event: Option<ConnectionEventFn>,
159    ) -> Option<Self> {
160        let mut map = counts.lock().unwrap_or_else(|e| e.into_inner());
161        let count = map.entry(peer).or_insert(0);
162        if *count >= max {
163            return None;
164        }
165        let was_zero = *count == 0;
166        *count = count.saturating_add(1);
167        let guard = PeerConnectionGuard {
168            counts: counts.clone(),
169            peer,
170            peer_id_str: peer_id_str.clone(),
171            on_event: on_event.clone(),
172        };
173        // Fire connected event on 0 → 1 transition (first connection from this peer).
174        if was_zero {
175            if let Some(cb) = &on_event {
176                cb(ConnectionEvent {
177                    peer_id: peer_id_str,
178                    connected: true,
179                });
180            }
181        }
182        Some(guard)
183    }
184}
185
186impl Drop for PeerConnectionGuard {
187    fn drop(&mut self) {
188        let mut map = self.counts.lock().unwrap_or_else(|e| e.into_inner());
189        if let Some(c) = map.get_mut(&self.peer) {
190            *c = c.saturating_sub(1);
191            if *c == 0 {
192                map.remove(&self.peer);
193                // Fire disconnected event on 1 → 0 transition (last connection from this peer closed).
194                if let Some(cb) = &self.on_event {
195                    cb(ConnectionEvent {
196                        peer_id: self.peer_id_str.clone(),
197                        connected: false,
198                    });
199                }
200            }
201        }
202    }
203}
204
205// ── RequestService ────────────────────────────────────────────────────────────
206
/// Per-connection tower service bridging hyper to the `on_request` callback.
///
/// Cheap to clone: the serve loop builds one template (with
/// `remote_node_id: None`) and stamps each accepted connection's clone with
/// the authenticated QUIC peer id.
#[derive(Clone)]
struct RequestService {
    // Fired once per request with slab handles + request metadata.
    on_request: Arc<dyn Fn(RequestPayload) + Send + Sync>,
    endpoint: IrohEndpoint,
    // This node's id, used to synthesize the `httpi://…` request URL.
    own_node_id: Arc<String>,
    // Authenticated peer id; `None` only on the template service.
    remote_node_id: Option<String>,
    // Request-body cap (413 on overflow); `None` = unlimited.
    max_request_body_bytes: Option<usize>,
    // Raw header-byte budget; `None` = unlimited.
    max_header_size: Option<usize>,
    #[cfg(feature = "compression")]
    compression: Option<crate::endpoint::CompressionOptions>,
}
218
219impl Service<hyper::Request<Incoming>> for RequestService {
220    type Response = hyper::Response<BoxBody>;
221    type Error = BoxError;
222    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
223
224    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
225        Poll::Ready(Ok(()))
226    }
227
228    fn call(&mut self, req: hyper::Request<Incoming>) -> Self::Future {
229        let svc = self.clone();
230        Box::pin(async move { svc.handle(req).await })
231    }
232}
233
impl RequestService {
    /// Drive one HTTP/1.1 request that arrived over a QUIC bidirectional stream.
    ///
    /// Flow: measure/validate headers → allocate request/response body
    /// channels plus a response-head rendezvous in the handle slab → fire the
    /// `on_request` callback → await the response head from the handler
    /// (racing a request-body overflow so `413` can be returned early) →
    /// either perform the duplex upgrade (only when the handler returns 101)
    /// or stream the regular response body back through hyper.
    async fn handle(
        self,
        mut req: hyper::Request<Incoming>,
    ) -> Result<hyper::Response<BoxBody>, BoxError> {
        let handles = self.endpoint.handles();
        let own_node_id = &*self.own_node_id;
        let remote_node_id = self.remote_node_id.clone().unwrap_or_default();
        let max_request_body_bytes = self.max_request_body_bytes;
        let max_header_size = self.max_header_size;

        let method = req.method().to_string();
        let path_and_query = req
            .uri()
            .path_and_query()
            .map(|p| p.as_str())
            .unwrap_or("/")
            .to_string();

        tracing::debug!(
            method = %method,
            path = %path_and_query,
            peer = %remote_node_id,
            "iroh-http: incoming request",
        );
        // Strip any client-supplied peer-id to prevent spoofing,
        // then inject the authenticated identity from the QUIC connection.
        //
        // ISS-011: Use raw byte length for header-size accounting to prevent
        // bypass via non-UTF8 values.  Reject non-UTF8 header values with 400
        // instead of silently converting them to empty strings.

        // First pass: measure header bytes using raw values (before lossy conversion).
        if let Some(limit) = max_header_size {
            let header_bytes: usize = req
                .headers()
                .iter()
                .filter(|(k, _)| !k.as_str().eq_ignore_ascii_case("peer-id"))
                .map(|(k, v)| {
                    k.as_str()
                        .len()
                        .saturating_add(v.as_bytes().len())
                        .saturating_add(4)
                }) // ": " + "\r\n"
                .fold(0usize, |acc, x| acc.saturating_add(x))
                .saturating_add("peer-id".len())
                .saturating_add(remote_node_id.len())
                .saturating_add(4)
                .saturating_add(req.uri().to_string().len())
                .saturating_add(method.len())
                .saturating_add(12); // "HTTP/1.1 \r\n\r\n" overhead
            if header_bytes > limit {
                let resp = hyper::Response::builder()
                    .status(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE)
                    .body(crate::box_body(http_body_util::Empty::new()))
                    .expect("static response args are valid");
                return Ok(resp);
            }
        }

        // Build header list — reject non-UTF8 values instead of silently dropping.
        let mut req_headers: Vec<(String, String)> = Vec::new();
        for (k, v) in req.headers().iter() {
            if k.as_str().eq_ignore_ascii_case("peer-id") {
                continue;
            }
            match v.to_str() {
                Ok(s) => req_headers.push((k.as_str().to_string(), s.to_string())),
                Err(_) => {
                    let resp = hyper::Response::builder()
                        .status(StatusCode::BAD_REQUEST)
                        .body(crate::box_body(http_body_util::Full::new(
                            Bytes::from_static(b"non-UTF8 header value"),
                        )))
                        .expect("static response args are valid");
                    return Ok(resp);
                }
            }
        }
        req_headers.push(("peer-id".to_string(), remote_node_id.clone()));

        let url = format!("httpi://{own_node_id}{path_and_query}");

        // ISS-015: strict duplex upgrade validation — require CONNECT method +
        // Upgrade: iroh-duplex + Connection: upgrade headers.
        let has_upgrade_header = req_headers.iter().any(|(k, v)| {
            k.eq_ignore_ascii_case("upgrade") && v.eq_ignore_ascii_case("iroh-duplex")
        });
        let has_connection_upgrade = req_headers.iter().any(|(k, v)| {
            k.eq_ignore_ascii_case("connection")
                && v.split(',')
                    .any(|tok| tok.trim().eq_ignore_ascii_case("upgrade"))
        });
        let is_connect = req.method() == http::Method::CONNECT;

        let is_bidi = if has_upgrade_header {
            if !has_connection_upgrade || !is_connect {
                let resp = hyper::Response::builder()
                    .status(StatusCode::BAD_REQUEST)
                    .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
                        b"duplex upgrade requires CONNECT method with Connection: upgrade header",
                    ))))
                    .expect("static response args are valid");
                return Ok(resp);
            }
            true
        } else {
            false
        };

        // For duplex: capture the upgrade future BEFORE consuming the request.
        let upgrade_future = if is_bidi {
            Some(hyper::upgrade::on(&mut req))
        } else {
            None
        };

        // ── Allocate channels ────────────────────────────────────────────────

        // Request body: writer pumped from hyper; reader given to JS.
        let mut guard = handles.insert_guard();
        let (req_body_writer, req_body_reader) = handles.make_body_channel();
        let req_body_handle = guard
            .insert_reader(req_body_reader)
            .map_err(|e| -> BoxError { e.into() })?;

        // Response body: writer given to JS (sendChunk); reader feeds hyper response.
        let (res_body_writer, res_body_reader) = handles.make_body_channel();
        let res_body_handle = guard
            .insert_writer(res_body_writer)
            .map_err(|e| -> BoxError { e.into() })?;

        // ── Allocate response-head rendezvous ────────────────────────────────

        let (head_tx, head_rx) = tokio::sync::oneshot::channel::<ResponseHeadEntry>();
        let req_handle = guard
            .allocate_req_handle(head_tx)
            .map_err(|e| -> BoxError { e.into() })?;

        guard.commit();

        // RAII guard: remove the req_handle slab entry on all exit paths
        // (413 early-return, timeout drop, "JS handler dropped", normal completion).
        // If respond() already consumed the entry, take_req_sender returns None — safe no-op.
        struct ReqHeadCleanup {
            endpoint: IrohEndpoint,
            req_handle: u64,
        }
        impl Drop for ReqHeadCleanup {
            fn drop(&mut self) {
                self.endpoint.handles().take_req_sender(self.req_handle);
            }
        }
        let _req_head_cleanup = ReqHeadCleanup {
            endpoint: self.endpoint.clone(),
            req_handle,
        };

        // ── Pump request body ────────────────────────────────────────────────

        // For duplex: keep req_body_writer to move into the upgrade spawn below.
        // For regular: consume it immediately into the pump task.
        // ISS-004: create an overflow channel so the serve path can return 413.
        let (body_overflow_tx, body_overflow_rx) = if !is_bidi && max_request_body_bytes.is_some() {
            let (tx, rx) = tokio::sync::oneshot::channel::<()>();
            (Some(tx), Some(rx))
        } else {
            (None, None)
        };

        let duplex_req_body_writer = if !is_bidi {
            let body = req.into_body();
            let frame_timeout = handles.drain_timeout();
            tokio::spawn(pump_hyper_body_to_channel_limited(
                body,
                req_body_writer,
                max_request_body_bytes,
                frame_timeout,
                body_overflow_tx,
            ));
            None
        } else {
            // Duplex: discard the HTTP preamble body (empty before 101).
            drop(req.into_body());
            Some(req_body_writer)
        };

        // ── Fire on_request callback ─────────────────────────────────────────

        on_request_fire(
            &self.on_request,
            req_handle,
            req_body_handle,
            res_body_handle,
            method,
            url,
            req_headers,
            remote_node_id,
            is_bidi,
        );

        // ── Await response head from JS (race against body overflow) ─────────
        //
        // ISS-004: if the request body exceeds maxRequestBodyBytes, return 413
        // immediately without waiting for the JS handler to respond.

        let response_head = if let Some(overflow_rx) = body_overflow_rx {
            tokio::select! {
                biased;
                Ok(()) = overflow_rx => {
                    // Body too large: ReqHeadCleanup RAII guard will remove the slab
                    // entry when this function exits (issue-7 fix).
                    let resp = hyper::Response::builder()
                        .status(StatusCode::PAYLOAD_TOO_LARGE)
                        .body(crate::box_body(http_body_util::Full::new(Bytes::from_static(
                            b"request body too large",
                        ))))
                        .expect("valid 413 response");
                    return Ok(resp);
                }
                head = head_rx => {
                    head.map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
                }
            }
        } else {
            head_rx
                .await
                .map_err(|_| -> BoxError { "JS handler dropped without responding".into() })?
        };

        // ── Duplex path: honor handler status, upgrade only on 101 ──────────────
        //
        // ISS-002: the handler may reject the duplex request by returning any
        // non-101 status.  Only perform the QUIC stream pump when the handler
        // explicitly returns 101 Switching Protocols.

        if let Some(upgrade_fut) = upgrade_future {
            let req_body_writer =
                duplex_req_body_writer.expect("duplex path always has req_body_writer");

            // If the handler returned a non-101 status, send that response and
            // do NOT perform the upgrade.  Drop the upgrade future and writer.
            if response_head.status != StatusCode::SWITCHING_PROTOCOLS.as_u16() {
                drop(upgrade_fut);
                drop(req_body_writer);
                let mut resp_builder = hyper::Response::builder().status(response_head.status);
                for (k, v) in &response_head.headers {
                    resp_builder = resp_builder.header(k.as_str(), v.as_str());
                }
                let resp = resp_builder
                    .body(crate::box_body(http_body_util::Empty::new()))
                    .map_err(|e| -> BoxError { e.into() })?;
                return Ok(resp);
            }

            // Spawn the upgrade pump after hyper delivers the 101.
            //
            // Both directions are wired to the channels already sent to JS:
            //   recv_io → req_body_writer  (JS reads via req_body_handle)
            //   res_body_reader → send_io  (JS writes via res_body_handle)
            tokio::spawn(async move {
                match upgrade_fut.await {
                    Err(e) => tracing::warn!("iroh-http: duplex upgrade error: {e}"),
                    Ok(upgraded) => {
                        let io = TokioIo::new(upgraded);
                        crate::stream::pump_duplex(io, req_body_writer, res_body_reader).await;
                    }
                }
            });

            // ISS-015: emit both Connection and Upgrade headers in 101 response.
            let resp = hyper::Response::builder()
                .status(StatusCode::SWITCHING_PROTOCOLS)
                .header(hyper::header::CONNECTION, "Upgrade")
                .header(hyper::header::UPGRADE, "iroh-duplex")
                .body(crate::box_body(http_body_util::Empty::new()))
                .expect("static response args are valid");
            return Ok(resp);
        }

        // ── Regular HTTP response ─────────────────────────────────────────────

        let body_stream = body_from_reader(res_body_reader);

        let mut resp_builder = hyper::Response::builder().status(response_head.status);
        for (k, v) in &response_head.headers {
            resp_builder = resp_builder.header(k.as_str(), v.as_str());
        }

        #[cfg(feature = "compression")]
        let resp_builder = resp_builder; // CompressionLayer in ServiceBuilder handles this

        let resp = resp_builder
            .body(crate::box_body(body_stream))
            .map_err(|e| -> BoxError { e.into() })?;

        Ok(resp)
    }
}
533
534#[inline]
535#[allow(clippy::too_many_arguments)]
536fn on_request_fire(
537    cb: &Arc<dyn Fn(RequestPayload) + Send + Sync>,
538    req_handle: u64,
539    req_body_handle: u64,
540    res_body_handle: u64,
541    method: String,
542    url: String,
543    headers: Vec<(String, String)>,
544    remote_node_id: String,
545    is_bidi: bool,
546) {
547    cb(RequestPayload {
548        req_handle,
549        req_body_handle,
550        res_body_handle,
551        method,
552        url,
553        headers,
554        remote_node_id,
555        is_bidi,
556    });
557}
558
559// ── serve() ───────────────────────────────────────────────────────────────────
560
561/// Start the serve accept loop.
562///
563/// This is the 3-argument form for backward compatibility.
564/// Use `serve_with_events` to also receive peer connect/disconnect callbacks.
565///
566/// # Security
567///
568/// Calling `serve()` opens a **public endpoint** on the Iroh overlay network.
569/// Unlike regular HTTP (where you choose whether to bind on `0.0.0.0`), any
570/// peer that knows or discovers your node's public key can connect and send
571/// requests. Iroh QUIC authenticates the peer's *identity* cryptographically,
572/// but does not enforce *authorization*.
573///
574/// Always inspect `RequestPayload::peer_id` (exposed as the `Peer-Id` request
575/// header at the FFI layer) and reject requests from untrusted peers:
576///
577/// ```ignore
578/// serve(endpoint, ServeOptions::default(), |payload| {
579///     if !ALLOWED_PEERS.contains(&payload.peer_id) {
580///         respond(handles, payload.req_handle, 403, vec![]).ok();
581///         return;
582///     }
583///     // ... handle request
584/// });
585/// ```
586pub fn serve<F>(endpoint: IrohEndpoint, options: ServeOptions, on_request: F) -> ServeHandle
587where
588    F: Fn(RequestPayload) + Send + Sync + 'static,
589{
590    serve_with_events(endpoint, options, on_request, None)
591}
592
593/// Start the serve accept loop with an optional peer connection event callback.
594///
595/// `on_connection_event` is called on 0→1 (first connection from a peer) and
596/// 1→0 (last connection from a peer closed) count transitions.
597pub fn serve_with_events<F>(
598    endpoint: IrohEndpoint,
599    options: ServeOptions,
600    on_request: F,
601    on_connection_event: Option<ConnectionEventFn>,
602) -> ServeHandle
603where
604    F: Fn(RequestPayload) + Send + Sync + 'static,
605{
606    let max = options.max_concurrency.unwrap_or(DEFAULT_CONCURRENCY);
607    let max_errors = options.max_consecutive_errors.unwrap_or(5);
608    let request_timeout = options
609        .request_timeout_ms
610        .map(Duration::from_millis)
611        .unwrap_or(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS));
612    let max_conns_per_peer = options
613        .max_connections_per_peer
614        .unwrap_or(DEFAULT_MAX_CONNECTIONS_PER_PEER);
615    let max_request_body_bytes = options
616        .max_request_body_bytes
617        .or(Some(DEFAULT_MAX_REQUEST_BODY_BYTES));
618    let max_total_connections = options.max_total_connections;
619    let drain_timeout = Duration::from_secs(
620        options
621            .drain_timeout_secs
622            .unwrap_or(DEFAULT_DRAIN_TIMEOUT_SECS),
623    );
624    // Load-shed is opt-out — default `true` (reject immediately when at capacity).
625    let load_shed_enabled = options.load_shed.unwrap_or(true);
626    let max_header_size = endpoint.max_header_size();
627    #[cfg(feature = "compression")]
628    let compression = endpoint.compression().cloned();
629    let own_node_id = Arc::new(endpoint.node_id().to_string());
630    let on_request = Arc::new(on_request) as Arc<dyn Fn(RequestPayload) + Send + Sync>;
631
632    let peer_counts: Arc<Mutex<HashMap<iroh::PublicKey, usize>>> =
633        Arc::new(Mutex::new(HashMap::new()));
634    let conn_event_fn: Option<ConnectionEventFn> = on_connection_event;
635
636    // In-flight request counter: incremented on accept, decremented on drop.
637    // Used for graceful drain (wait until zero or timeout).
638    let in_flight: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));
639    let drain_notify: Arc<tokio::sync::Notify> = Arc::new(tokio::sync::Notify::new());
640
641    let base_svc = RequestService {
642        on_request,
643        endpoint: endpoint.clone(),
644        own_node_id,
645        remote_node_id: None,
646        max_request_body_bytes,
647        max_header_size: if max_header_size == 0 {
648            None
649        } else {
650            Some(max_header_size)
651        },
652        #[cfg(feature = "compression")]
653        compression,
654    };
655
656    use tower::{limit::ConcurrencyLimitLayer, Layer};
657    // SEC-002: build the concurrency limiter once so all clones share one
658    // Arc<Semaphore>, enforcing a true global request cap across every
659    // connection and request task.
660    let shared_conc = ConcurrencyLimitLayer::new(max).layer(base_svc);
661
662    let shutdown_notify = Arc::new(tokio::sync::Notify::new());
663    let shutdown_listen = shutdown_notify.clone();
664    let drain_dur = drain_timeout;
665    // Re-use the endpoint's shared counters so that endpoint_stats() reflects
666    // the live connection and request counts at all times.
667    let total_connections = endpoint.inner.active_connections.clone();
668    let total_requests = endpoint.inner.active_requests.clone();
669    let (done_tx, done_rx) = tokio::sync::watch::channel(false);
670    let endpoint_closed_tx = endpoint.inner.closed_tx.clone();
671
672    let in_flight_drain = in_flight.clone();
673    let drain_notify_drain = drain_notify.clone();
674
675    let join = tokio::spawn(async move {
676        let ep = endpoint.raw().clone();
677        let mut consecutive_errors: usize = 0;
678
679        loop {
680            let incoming = tokio::select! {
681                biased;
682                _ = shutdown_listen.notified() => {
683                    tracing::info!("iroh-http: serve loop shutting down");
684                    break;
685                }
686                inc = ep.accept() => match inc {
687                    Some(i) => i,
688                    None => {
689                        tracing::info!("iroh-http: endpoint closed (accept returned None)");
690                        let _ = endpoint_closed_tx.send(true);
691                        break;
692                    }
693                }
694            };
695
696            let conn = match incoming.await {
697                Ok(c) => {
698                    consecutive_errors = 0;
699                    c
700                }
701                Err(e) => {
702                    consecutive_errors = consecutive_errors.saturating_add(1);
703                    tracing::warn!(
704                        "iroh-http: accept error ({consecutive_errors}/{max_errors}): {e}"
705                    );
706                    if consecutive_errors >= max_errors {
707                        tracing::error!("iroh-http: too many accept errors — shutting down");
708                        break;
709                    }
710                    continue;
711                }
712            };
713
714            let remote_pk = conn.remote_id();
715
716            // Enforce total connection limit.
717            if let Some(max_total) = max_total_connections {
718                let current = total_connections.load(Ordering::Relaxed);
719                if current >= max_total {
720                    tracing::warn!(
721                        "iroh-http: total connection limit reached ({current}/{max_total})"
722                    );
723                    conn.close(0u32.into(), b"server at capacity");
724                    continue;
725                }
726            }
727
728            let remote_id = base32_encode(remote_pk.as_bytes());
729
730            let guard = match PeerConnectionGuard::acquire(
731                &peer_counts,
732                remote_pk,
733                remote_id.clone(),
734                max_conns_per_peer,
735                conn_event_fn.clone(),
736            ) {
737                Some(g) => g,
738                None => {
739                    tracing::warn!("iroh-http: peer {remote_id} exceeded connection limit");
740                    conn.close(0u32.into(), b"too many connections");
741                    continue;
742                }
743            };
744
745            let mut conn_conc = shared_conc.clone();
746            conn_conc.get_mut().remote_node_id = Some(remote_id);
747
748            let timeout_dur = if request_timeout.is_zero() {
749                Duration::MAX
750            } else {
751                request_timeout
752            };
753
754            let conn_total = total_connections.clone();
755            let conn_requests = total_requests.clone();
756            let in_flight_conn = in_flight.clone();
757            let drain_notify_conn = drain_notify.clone();
758            conn_total.fetch_add(1, Ordering::Relaxed);
759            tokio::spawn(async move {
760                let _guard = guard;
761                // Decrement total connection count when this task exits.
762                struct TotalGuard(Arc<AtomicUsize>);
763                impl Drop for TotalGuard {
764                    fn drop(&mut self) {
765                        self.0.fetch_sub(1, Ordering::Relaxed);
766                    }
767                }
768                let _total_guard = TotalGuard(conn_total);
769
770                loop {
771                    let (send, recv) = match conn.accept_bi().await {
772                        Ok(pair) => pair,
773                        Err(_) => break,
774                    };
775
776                    let io = TokioIo::new(IrohStream::new(send, recv));
777                    let svc = conn_conc.clone();
778                    let req_counter = conn_requests.clone();
779                    req_counter.fetch_add(1, Ordering::Relaxed);
780                    in_flight_conn.fetch_add(1, Ordering::Relaxed);
781
782                    let in_flight_req = in_flight_conn.clone();
783                    let drain_notify_req = drain_notify_conn.clone();
784
785                    tokio::spawn(async move {
786                        // Decrement request count when this task exits.
787                        struct ReqGuard {
788                            counter: Arc<AtomicUsize>,
789                            in_flight: Arc<AtomicUsize>,
790                            drain_notify: Arc<tokio::sync::Notify>,
791                        }
792                        impl Drop for ReqGuard {
793                            fn drop(&mut self) {
794                                self.counter.fetch_sub(1, Ordering::Relaxed);
795                                if self.in_flight.fetch_sub(1, Ordering::AcqRel) == 1 {
796                                    // Last in-flight request completed — signal drain.
797                                    self.drain_notify.notify_waiters();
798                                }
799                            }
800                        }
801                        let _req_guard = ReqGuard {
802                            counter: req_counter,
803                            in_flight: in_flight_req,
804                            drain_notify: drain_notify_req,
805                        };
806                        // ISS-001: clamp to hyper's minimum safe buffer size of 8192.
807                        // ISS-020: a stored value of 0 means "use the default" (64 KB).
808                        let effective_header_limit = if max_header_size == 0 {
809                            64 * 1024
810                        } else {
811                            max_header_size.max(8192)
812                        };
813
814                        // Build the Tower reliability service stack and serve the connection.
815                        //
816                        // Layer ordering (outermost first):
817                        //   [CompressionLayer →] TowerErrorHandler → LoadShed → ConcurrencyLimit → Timeout → RequestService
818                        //
819                        // Tower layers are applied around `RequestService` (which returns
820                        // `Response<BoxBody>`) so that `TowerErrorHandler` can convert
821                        // `Elapsed` and `Overloaded` into 408/503 HTTP responses.
822                        // CompressionLayer (if enabled) sits outside the error handler.
823                        //
824                        // Each `if load_shed_enabled` branch produces a concrete type;
825                        // both branches `.await` to `Result<(), hyper::Error>` so the
826                        // `if` expression is well-typed without boxing.
827
828                        use tower::{timeout::TimeoutLayer, ServiceBuilder};
829
830                        #[cfg(feature = "compression")]
831                        let result = {
832                            use http::{Extensions, HeaderMap, Version};
833                            use tower_http::compression::{
834                                predicate::{Predicate, SizeAbove},
835                                CompressionLayer,
836                            };
837
838                            let compression_config = svc.get_ref().compression.clone();
839                            if let Some(comp) = &compression_config {
840                                let min_bytes = comp.min_body_bytes;
841                                let mut layer = CompressionLayer::new().zstd(true);
842                                if let Some(level) = comp.level {
843                                    use tower_http::compression::CompressionLevel;
844                                    layer = layer.quality(CompressionLevel::Precise(level as i32));
845                                }
846                                let not_pre_compressed =
847                                    |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
848                                        !h.contains_key(http::header::CONTENT_ENCODING)
849                                    };
850                                let not_no_transform =
851                                    |_: StatusCode, _: Version, h: &HeaderMap, _: &Extensions| {
852                                        h.get(http::header::CACHE_CONTROL)
853                                            .and_then(|v| v.to_str().ok())
854                                            .map(|v| {
855                                                !v.split(',').any(|d| {
856                                                    d.trim().eq_ignore_ascii_case("no-transform")
857                                                })
858                                            })
859                                            .unwrap_or(true)
860                                    };
861                                let predicate =
862                                    SizeAbove::new(min_bytes.min(u16::MAX as usize) as u16)
863                                        .and(not_pre_compressed)
864                                        .and(not_no_transform);
865                                if load_shed_enabled {
866                                    use tower::load_shed::LoadShedLayer;
867                                    let stk = TowerErrorHandler(
868                                        ServiceBuilder::new()
869                                            .layer(LoadShedLayer::new())
870                                            .layer(TimeoutLayer::new(timeout_dur))
871                                            .service(svc),
872                                    );
873                                    hyper::server::conn::http1::Builder::new()
874                                        .max_buf_size(effective_header_limit)
875                                        .max_headers(128)
876                                        .serve_connection(
877                                            io,
878                                            TowerToHyperService::new(
879                                                ServiceBuilder::new()
880                                                    .layer(layer.compress_when(predicate))
881                                                    .service(stk),
882                                            ),
883                                        )
884                                        .with_upgrades()
885                                        .await
886                                } else {
887                                    let stk = TowerErrorHandler(
888                                        ServiceBuilder::new()
889                                            .layer(TimeoutLayer::new(timeout_dur))
890                                            .service(svc),
891                                    );
892                                    hyper::server::conn::http1::Builder::new()
893                                        .max_buf_size(effective_header_limit)
894                                        .max_headers(128)
895                                        .serve_connection(
896                                            io,
897                                            TowerToHyperService::new(
898                                                ServiceBuilder::new()
899                                                    .layer(layer.compress_when(predicate))
900                                                    .service(stk),
901                                            ),
902                                        )
903                                        .with_upgrades()
904                                        .await
905                                }
906                            } else if load_shed_enabled {
907                                use tower::load_shed::LoadShedLayer;
908                                let stk = TowerErrorHandler(
909                                    ServiceBuilder::new()
910                                        .layer(LoadShedLayer::new())
911                                        .layer(TimeoutLayer::new(timeout_dur))
912                                        .service(svc),
913                                );
914                                hyper::server::conn::http1::Builder::new()
915                                    .max_buf_size(effective_header_limit)
916                                    .max_headers(128)
917                                    .serve_connection(io, TowerToHyperService::new(stk))
918                                    .with_upgrades()
919                                    .await
920                            } else {
921                                let stk = TowerErrorHandler(
922                                    ServiceBuilder::new()
923                                        .layer(TimeoutLayer::new(timeout_dur))
924                                        .service(svc),
925                                );
926                                hyper::server::conn::http1::Builder::new()
927                                    .max_buf_size(effective_header_limit)
928                                    .max_headers(128)
929                                    .serve_connection(io, TowerToHyperService::new(stk))
930                                    .with_upgrades()
931                                    .await
932                            }
933                        };
934                        #[cfg(not(feature = "compression"))]
935                        let result = if load_shed_enabled {
936                            use tower::load_shed::LoadShedLayer;
937                            let stk = TowerErrorHandler(
938                                ServiceBuilder::new()
939                                    .layer(LoadShedLayer::new())
940                                    .layer(TimeoutLayer::new(timeout_dur))
941                                    .service(svc),
942                            );
943                            hyper::server::conn::http1::Builder::new()
944                                .max_buf_size(effective_header_limit)
945                                .max_headers(128)
946                                .serve_connection(io, TowerToHyperService::new(stk))
947                                .with_upgrades()
948                                .await
949                        } else {
950                            let stk = TowerErrorHandler(
951                                ServiceBuilder::new()
952                                    .layer(TimeoutLayer::new(timeout_dur))
953                                    .service(svc),
954                            );
955                            hyper::server::conn::http1::Builder::new()
956                                .max_buf_size(effective_header_limit)
957                                .max_headers(128)
958                                .serve_connection(io, TowerToHyperService::new(stk))
959                                .with_upgrades()
960                                .await
961                        };
962
963                        if let Err(e) = result {
964                            tracing::debug!("iroh-http: http1 connection error: {e}");
965                        }
966                    });
967                }
968            });
969        }
970
971        // Graceful drain: wait for all in-flight requests to finish,
972        // or give up after `drain_timeout`.
973        //
974        // Loop avoids the race between `in_flight == 0` check and `notified()`:
975        // if the last request finishes between the load and the await, the loop
976        // re-checks immediately after the timeout wakes it.
977        let deadline = tokio::time::Instant::now()
978            .checked_add(drain_dur)
979            .expect("drain duration overflow");
980        loop {
981            if in_flight_drain.load(Ordering::Acquire) == 0 {
982                tracing::info!("iroh-http: all in-flight requests drained");
983                break;
984            }
985            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
986            if remaining.is_zero() {
987                tracing::warn!("iroh-http: drain timed out after {}s", drain_dur.as_secs());
988                break;
989            }
990            tokio::select! {
991                _ = drain_notify_drain.notified() => {}
992                _ = tokio::time::sleep(remaining) => {}
993            }
994        }
995        let _ = done_tx.send(true);
996    });
997
998    ServeHandle {
999        join,
1000        shutdown_notify,
1001        drain_timeout: drain_dur,
1002        done_rx,
1003    }
1004}
1005
1006// ── TowerErrorHandler — maps Tower layer errors to HTTP responses ─────────────
1007//
1008// `ConcurrencyLimitLayer`, `TimeoutLayer`, and `LoadShedLayer` return errors
1009// rather than `Response` values when they reject a request.  `TowerErrorHandler`
1010// wraps the composed service and converts those errors to proper HTTP responses:
1011//
1012//   tower::timeout::error::Elapsed     → 408 Request Timeout
1013//   tower::load_shed::error::Overloaded → 503 Service Unavailable
1014//   anything else                       → 500 Internal Server Error
1015//
// This allows the whole stack to satisfy hyper's requirement that the service
// returns `Ok(Response)` — without this wrapper, a layer error would tear down
// the connection instead of producing a status code.
1019
/// Wrapper around the composed Tower service stack that converts layer
/// rejection errors (`Elapsed`, `Overloaded`) into HTTP error responses
/// rather than surfacing them to hyper as connection failures.
#[derive(Clone)]
struct TowerErrorHandler<S>(S);
1022
1023impl<S, Req> Service<Req> for TowerErrorHandler<S>
1024where
1025    S: Service<Req, Response = hyper::Response<BoxBody>>,
1026    S::Error: Into<BoxError>,
1027    S::Future: Send + 'static,
1028{
1029    type Response = hyper::Response<BoxBody>;
1030    type Error = BoxError;
1031    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1032
1033    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1034        // If ConcurrencyLimitLayer is saturated AND LoadShed is present, it
1035        // returns Pending from poll_ready — LoadShed converts that to an
1036        // immediate Err(Overloaded).  If LoadShed is absent, poll_ready blocks
1037        // until a slot opens.  In both cases we propagate the result here.
1038        self.0.poll_ready(cx).map_err(Into::into)
1039    }
1040
1041    fn call(&mut self, req: Req) -> Self::Future {
1042        let fut = self.0.call(req);
1043        Box::pin(async move {
1044            match fut.await {
1045                Ok(r) => Ok(r),
1046                Err(e) => {
1047                    let e = e.into();
1048                    let status = if e.is::<tower::timeout::error::Elapsed>() {
1049                        StatusCode::REQUEST_TIMEOUT
1050                    } else if e.is::<tower::load_shed::error::Overloaded>() {
1051                        StatusCode::SERVICE_UNAVAILABLE
1052                    } else {
1053                        tracing::warn!("iroh-http: unexpected tower error: {e}");
1054                        StatusCode::INTERNAL_SERVER_ERROR
1055                    };
1056                    let body_bytes: &'static [u8] = match status {
1057                        StatusCode::REQUEST_TIMEOUT => b"request timed out",
1058                        StatusCode::SERVICE_UNAVAILABLE => b"server at capacity",
1059                        _ => b"internal server error",
1060                    };
1061                    Ok(hyper::Response::builder()
1062                        .status(status)
1063                        .body(crate::box_body(http_body_util::Full::new(
1064                            Bytes::from_static(body_bytes),
1065                        )))
1066                        .expect("valid error response"))
1067                }
1068            }
1069        })
1070    }
1071}