Skip to main content

edgeguard/
proxy.rs

1//! Request path: header-size limit -> rate limit (per-IP / per-route) -> auth -> per-key
2//! rate limit -> method allowlist -> body-size limit -> WAF input inspection -> forward to
3//! upstream.
4//! Response path: header injection (incl. CSP / CSP-report-only) -> cookie hardening ->
5//! strip leaky headers.
6//!
7//! All policy lives in [`Runtime`], held behind an [`ArcSwap`] so a config hot-reload swaps
8//! it atomically without blocking the request path or dropping in-flight connections. The
9//! upstream client and the metric registry sit *outside* the swap so the connection pool and
10//! counters survive a reload.
11
12use std::future::Future;
13use std::net::{IpAddr, SocketAddr};
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17use std::time::{Duration, Instant};
18
19use arc_swap::ArcSwap;
20use axum::{
21    body::{Body, Bytes},
22    extract::{ConnectInfo, State},
23    http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode},
24};
25use governor::{clock::DefaultClock, state::keyed::DefaultKeyedStateStore, RateLimiter};
26use http_body_util::{BodyExt, Full, Limited};
27use hyper::body::{Body as HttpBody, Frame, SizeHint};
28use hyper_util::client::legacy::{connect::HttpConnector, Client};
29use hyper_util::rt::TokioIo;
30use tokio::net::TcpStream;
31use tracing::{debug, info, warn};
32
33use crate::auth::{AuthEngine, Challenge, Decision};
34use crate::config::{Config, HeadersCfg};
35use crate::limiter::{Admit, DistributedLimiter};
36use crate::metrics::Metrics;
37use crate::waf::{WafEngine, WafMode};
38
39pub type KeyedLimiter = RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>;
40/// Rate limiter keyed by the authenticated principal (per-key limiting).
41pub type StrLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
42pub type UpstreamClient = Client<HttpConnector, Full<Bytes>>;
43
44/// Shared, cheaply-cloned handle the router hands to every request. Only the hot-swappable
45/// [`Runtime`] changes on reload; the client and metrics are stable.
46#[derive(Clone)]
47pub struct AppState {
48    pub client: UpstreamClient,
49    pub metrics: Arc<Metrics>,
50    pub runtime: Arc<ArcSwap<Runtime>>,
51    /// Managed-mode control-plane client (`Some` only when `[control_plane]` is enabled). Used to
52    /// forward CSP reports; policy pull + usage reporting run as background tasks in `main`.
53    pub cp: Option<Arc<crate::cp::CpClient>>,
54    /// Shared quota verdict, updated by the managed-mode quota poller and read by the
55    /// hard-stop gate below. Lives here (not on the hot-swappable [`Runtime`]) so a policy reload
56    /// never resets enforcement. Inert unless `control_plane.enforce_quota` is set.
57    pub quota: Arc<crate::cp::QuotaState>,
58}
59
60/// A per-route rate-limit override: requests whose path starts with `prefix` use `limiter`.
61pub struct RouteLimiter {
62    pub prefix: String,
63    pub limiter: Arc<KeyedLimiter>,
64}
65
66/// All request-handling policy derived from a [`Config`]. Rebuilt from scratch on reload and
67/// swapped in atomically.
68pub struct Runtime {
69    pub cfg: Arc<Config>,
70    /// Default upstream base URL (the single `server.upstream`/`app_port`), used when no
71    /// `[[upstreams]]` prefix matches.
72    pub upstream_base: Arc<String>,
73    /// Per-path-prefix upstream overrides as `(prefix, base)`; the longest matching prefix wins.
74    /// Empty unless `[[upstreams]]` is configured.
75    pub upstream_routes: Vec<(String, Arc<String>)>,
76    pub auth: AuthEngine,
77    /// WAF-lite input screener. Inert (`evaluate` returns `None`) when `waf.mode = "off"`.
78    pub waf: WafEngine,
79    /// Compiled CORS policy; `None` when `cors.enabled = false` (the proxy then skips CORS).
80    pub cors: Option<crate::cors::CorsPolicy>,
81    /// Compiled IP allow/deny policy; `None` when both lists are empty (no IP gating).
82    pub access: Option<crate::access::AccessPolicy>,
83    /// Shared-store (distributed) limiter, `Some` when `ratelimit.store` is `memory`/`redis`.
84    /// When present it replaces the three `governor` limiters below (which are then `None`).
85    pub distributed: Option<DistributedLimiter>,
86    /// Global per-client-IP limiter (`None` when rate limiting is disabled or distributed).
87    pub ip_limiter: Option<Arc<KeyedLimiter>>,
88    /// Per-route limiters (also keyed per IP), checked instead of `ip_limiter` on a match.
89    pub route_limiters: Vec<RouteLimiter>,
90    /// Per-principal limiter (`None` when per-key limiting is disabled or distributed).
91    pub key_limiter: Option<Arc<StrLimiter>>,
92    pub max_body: usize,
93    /// Cap on the buffered upstream response body; `0` means unbounded.
94    pub max_response_body: usize,
95    /// Cap on total request header bytes; `0` means disabled.
96    pub max_header_bytes: usize,
97    /// Max time for the upstream request + body read; `None` disables the timeout.
98    pub upstream_timeout: Option<Duration>,
99    /// Forward `text/event-stream` responses unbuffered (SSE passthrough). See
100    /// [`crate::config::ValidationCfg::stream_passthrough`].
101    pub stream_passthrough: bool,
102    /// Tunnel WebSocket / `Upgrade` connections to the upstream. See
103    /// [`crate::config::ValidationCfg::websocket_passthrough`].
104    pub websocket_passthrough: bool,
105}
106
107impl Runtime {
108    /// The upstream base URL to forward `path` to: the longest matching `[[upstreams]]` prefix,
109    /// or the default [`Runtime::upstream_base`] when none match.
110    pub fn pick_upstream(&self, path: &str) -> &str {
111        self.upstream_routes
112            .iter()
113            .filter(|(prefix, _)| path_prefix_matches(path, prefix))
114            .max_by_key(|(prefix, _)| prefix.len())
115            .map(|(_, base)| base.as_str())
116            .unwrap_or_else(|| self.upstream_base.as_str())
117    }
118}
119
/// Reports whether `prefix` covers `path` on a path-segment boundary.
///
/// `prefix` is a validated upstream route prefix (it always begins with `/`) and `path` is the
/// request path-and-query. A bare `str::starts_with` is not enough: it would send a sibling
/// such as `/apiary` to the `/api` upstream. The match therefore only counts when whatever
/// follows the prefix is a genuine boundary — nothing at all, a `/`, or the query separator
/// `?`. A prefix that itself ends in `/` already ends on a boundary, and the root prefix `/`
/// matches everything.
fn path_prefix_matches(path: &str, prefix: &str) -> bool {
    // The root prefix is a catch-all.
    if prefix == "/" {
        return true;
    }
    let Some(remainder) = path.strip_prefix(prefix) else {
        return false;
    };
    if prefix.ends_with('/') {
        return true;
    }
    // What comes right after the prefix decides: end of path, `/`, or `?` are boundaries.
    matches!(remainder.chars().next(), None | Some('/' | '?'))
}
139
/// Lowercase names of the hop-by-hop headers (RFC 7230 §6.1). These describe a single
/// connection rather than the message, so a proxy must not pass them on to the next hop.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];
151
152pub async fn handle(
153    State(state): State<AppState>,
154    ConnectInfo(peer): ConnectInfo<SocketAddr>,
155    req: Request<Body>,
156) -> Response<Body> {
157    // One atomic load pins a consistent policy snapshot for the whole request, even if a reload
158    // swaps in a new Runtime mid-flight — routing, auth, *and* the final CORS decoration below all
159    // see the same one (loading again here could decorate with a policy the request never used).
160    let rt = state.runtime.load_full();
161    // Capture the request Origin before the body is consumed, so we can CORS-decorate *every*
162    // response — including EdgeGuard-generated 401/403/429 — not just proxied successes. Without
163    // this, an allowed browser origin sees a generic CORS failure instead of the real status.
164    let origin = req
165        .headers()
166        .get(header::ORIGIN)
167        .and_then(|v| v.to_str().ok())
168        .map(str::to_owned);
169    let mut resp = handle_inner(&state, &rt, peer, req).await;
170    if let Some(origin) = &origin {
171        if let Some(cors) = &rt.cors {
172            cors.decorate_origin(origin, &mut resp);
173        }
174    }
175    resp
176}
177
178async fn handle_inner(
179    state: &AppState,
180    rt: &Runtime,
181    peer: SocketAddr,
182    req: Request<Body>,
183) -> Response<Body> {
184    let started = Instant::now();
185    let m = &state.metrics;
186
187    let method = req.method().clone();
188    let path = req
189        .uri()
190        .path_and_query()
191        .map(|p| p.as_str().to_string())
192        .unwrap_or_else(|| req.uri().path().to_string());
193
194    let ip = client_ip(req.headers(), peer, rt.cfg.server.trust_forwarded_for);
195    // Request id for correlation: reuse a well-formed inbound one, else generate. Echoed on the
196    // response and the access log by `finish`, and forwarded upstream below.
197    let rid = resolve_request_id(req.headers());
198
199    // Reserve the internal namespace: never forward `/__edgeguard/*` upstream. Registered
200    // internal routes are matched before this fallback, so anything reaching here under that
201    // prefix is an unknown internal path — a `404` from EdgeGuard, not a request leaked to the
202    // app. This is also what keeps the ops endpoints (health/ready/metrics) unserved on the
203    // public listener in public/private split mode, rather than proxying them to the upstream.
204    if req.uri().path().starts_with("/__edgeguard/") {
205        return finish(
206            m,
207            &rid,
208            &method,
209            &path,
210            ip,
211            started,
212            "not_found",
213            text(StatusCode::NOT_FOUND, "Not Found"),
214        );
215    }
216
217    // 0) IP access control. A coarse network gate (CIDR allow/deny) evaluated before auth and
218    //    rate limiting, so a denied/non-allowlisted client is dropped with `403` before consuming
219    //    any limiter token or auth work. Keys on the same resolved client IP as rate limiting.
220    if let Some(access) = &rt.access {
221        if !access.allowed(ip) {
222            return finish(
223                m,
224                &rid,
225                &method,
226                &path,
227                ip,
228                started,
229                "ip_denied",
230                text(StatusCode::FORBIDDEN, "Forbidden"),
231            );
232        }
233    }
234
235    // 0.1) Total request-header-size limit.
236    if rt.max_header_bytes > 0 && header_bytes(req.headers()) > rt.max_header_bytes {
237        return finish(
238            m,
239            &rid,
240            &method,
241            &path,
242            ip,
243            started,
244            "header_too_large",
245            text(
246                StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
247                "Request Header Fields Too Large",
248            ),
249        );
250    }
251
252    // 0.5) Quota hard-stop (managed mode, opt-in). When the control plane reports the
253    //      edge over its quota, reject the edge's traffic with `429` and a
254    //      month-scale `Retry-After`, until the next successful poll clears it. Off unless
255    //      `control_plane.enforce_quota` is set; the `/__edgeguard/*` endpoints are excluded above,
256    //      so health/ready/metrics keep serving even while over quota.
257    if rt.cfg.control_plane.enforce_quota && state.quota.blocked() {
258        let mut resp = text(StatusCode::TOO_MANY_REQUESTS, "Quota Exceeded");
259        let reset = state.quota.reset_epoch();
260        if reset > 0 {
261            let now = std::time::SystemTime::now()
262                .duration_since(std::time::UNIX_EPOCH)
263                .map(|d| d.as_secs() as i64)
264                .unwrap_or(0);
265            let retry_after = reset.saturating_sub(now).max(0);
266            if let Ok(v) = HeaderValue::from_str(&retry_after.to_string()) {
267                resp.headers_mut().insert(header::RETRY_AFTER, v);
268            }
269        }
270        return finish(m, &rid, &method, &path, ip, started, "over_quota", resp);
271    }
272
273    // 1) Rate limit. A matching per-route override replaces the global per-IP limit. A shared
274    //    store (distributed) limiter, when configured, replaces the in-process limiters; on a
275    //    store error it fails closed (`503`) unless `ratelimit.fail_open` is set.
276    if rt.cfg.ratelimit.enabled {
277        if let Some(d) = &rt.distributed {
278            match d.check_ip_route(ip, &path).await {
279                Admit::Allowed => {}
280                Admit::Limited(scope) => {
281                    m.record_ratelimit_hit(scope);
282                    return finish(
283                        m,
284                        &rid,
285                        &method,
286                        &path,
287                        ip,
288                        started,
289                        "rate_limited",
290                        text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
291                    );
292                }
293                Admit::Error => {
294                    return finish(
295                        m,
296                        &rid,
297                        &method,
298                        &path,
299                        ip,
300                        started,
301                        "limiter_error",
302                        text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
303                    );
304                }
305            }
306        } else {
307            let (limiter, scope) = match longest_route(&rt.route_limiters, &path) {
308                Some(r) => (Some(r.limiter.as_ref()), "route"),
309                None => (rt.ip_limiter.as_deref(), "ip"),
310            };
311            if let Some(limiter) = limiter {
312                if limiter.check_key(&ip).is_err() {
313                    m.record_ratelimit_hit(scope);
314                    return finish(
315                        m,
316                        &rid,
317                        &method,
318                        &path,
319                        ip,
320                        started,
321                        "rate_limited",
322                        text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
323                    );
324                }
325            }
326        }
327    }
328
329    // 1.5) CORS preflight. Answer a browser preflight (`OPTIONS` + `Origin` +
330    //      `Access-Control-Request-Method`) here, *before* auth: a preflight carries no
331    //      credentials, so gating it behind the auth check would make every cross-origin call
332    //      fail. Only a real preflight is short-circuited; a plain `OPTIONS` falls through.
333    if method == Method::OPTIONS {
334        if let Some(cors) = &rt.cors {
335            if let Some(resp) = cors.preflight_response(req.headers()) {
336                return finish(m, &rid, &method, &path, ip, started, "cors_preflight", resp);
337            }
338        }
339    }
340
341    // 2) Authentication. On success we learn the principal for per-key limiting.
342    let principal = match rt.auth.authorize(&rt.cfg.auth, req.headers()).await {
343        Decision::Allow(principal) => principal,
344        Decision::Deny(challenge) => {
345            let mut resp = text(StatusCode::UNAUTHORIZED, "Unauthorized");
346            let challenge_value = match challenge {
347                Challenge::Basic(c) => Some(c),
348                Challenge::Bearer => Some("Bearer".to_string()),
349                Challenge::None => None,
350            };
351            if let Some(c) = challenge_value {
352                if let Ok(v) = HeaderValue::from_str(&c) {
353                    resp.headers_mut().insert(header::WWW_AUTHENTICATE, v);
354                }
355            }
356            return finish(m, &rid, &method, &path, ip, started, "unauthorized", resp);
357        }
358    };
359
360    // 3) Per-key rate limit (only for authenticated principals). Routed to the distributed
361    //    limiter when configured, else the in-process per-key limiter.
362    if let Some(principal) = &principal {
363        let key_admit = if let Some(d) = &rt.distributed {
364            Some(d.check_key(principal).await)
365        } else {
366            rt.key_limiter.as_ref().map(|limiter| {
367                if limiter.check_key(principal).is_err() {
368                    Admit::Limited("key")
369                } else {
370                    Admit::Allowed
371                }
372            })
373        };
374        match key_admit {
375            Some(Admit::Limited(scope)) => {
376                m.record_ratelimit_hit(scope);
377                return finish(
378                    m,
379                    &rid,
380                    &method,
381                    &path,
382                    ip,
383                    started,
384                    "rate_limited",
385                    text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
386                );
387            }
388            Some(Admit::Error) => {
389                return finish(
390                    m,
391                    &rid,
392                    &method,
393                    &path,
394                    ip,
395                    started,
396                    "limiter_error",
397                    text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
398                );
399            }
400            Some(Admit::Allowed) | None => {}
401        }
402    }
403
404    // 4) Method allowlist.
405    let allow = &rt.cfg.validation.allow_methods;
406    if !allow.is_empty()
407        && !allow
408            .iter()
409            .any(|x| x.eq_ignore_ascii_case(method.as_str()))
410    {
411        return finish(
412            m,
413            &rid,
414            &method,
415            &path,
416            ip,
417            started,
418            "method_not_allowed",
419            text(StatusCode::METHOD_NOT_ALLOWED, "Method Not Allowed"),
420        );
421    }
422
423    // 4.5) WebSocket / `Upgrade` passthrough (opt-in). An upgrade request can't go through the
424    //      buffer-and-forward path below — it needs a raw bidirectional tunnel. When enabled, hand
425    //      off to `proxy_upgrade`, which forwards the request *with* its upgrade headers (the
426    //      normal path strips them) and splices the connections on a `101`. The request is already
427    //      authenticated and rate-limited at this point. When disabled (default), fall through and
428    //      the upgrade headers are stripped like any other hop-by-hop header.
429    if rt.websocket_passthrough && is_upgrade_request(req.headers()) {
430        return proxy_upgrade(state, rt, req, &rid, &method, &path, ip, started).await;
431    }
432
433    // 5) Buffer the body up to the configured limit.
434    let (parts, body) = req.into_parts();
435    let body_bytes = match axum::body::to_bytes(body, rt.max_body).await {
436        Ok(b) => b,
437        Err(_) => {
438            return finish(
439                m,
440                &rid,
441                &method,
442                &path,
443                ip,
444                started,
445                "payload_too_large",
446                text(StatusCode::PAYLOAD_TOO_LARGE, "Payload Too Large"),
447            )
448        }
449    };
450    // Request (ingress) size for managed-mode usage, captured before the body is forwarded upstream.
451    let ingress_bytes = header_bytes(&parts.headers).saturating_add(body_bytes.len());
452
453    // 6) WAF-lite input inspection. A no-op unless `waf.mode` is report/block. The body is
454    //    already buffered above, so inspecting it adds no extra read. On a match: `block` mode
455    //    returns 403; `report` mode logs + counts and forwards. Both record the hit so a
456    //    report-only rollout shows up in `edgeguard_waf_hits_total`.
457    if let Some(hit) = rt.waf.evaluate(&path, &parts.headers, &body_bytes) {
458        m.record_waf_hit(hit.class);
459        match rt.waf.mode() {
460            WafMode::Block => {
461                warn!(
462                    rule = %hit.rule_id,
463                    class = hit.class,
464                    location = hit.location,
465                    client_ip = %ip,
466                    path = %path,
467                    "WAF blocked request"
468                );
469                return finish(
470                    m,
471                    &rid,
472                    &method,
473                    &path,
474                    ip,
475                    started,
476                    "forbidden",
477                    text(StatusCode::FORBIDDEN, "Forbidden"),
478                );
479            }
480            WafMode::Report => warn!(
481                rule = %hit.rule_id,
482                class = hit.class,
483                location = hit.location,
484                client_ip = %ip,
485                path = %path,
486                "WAF rule matched (report-only)"
487            ),
488            // `evaluate` returns `None` when off, so this arm is unreachable; kept for
489            // exhaustiveness.
490            WafMode::Off => {}
491        }
492    }
493
494    // 7) Build the upstream request (the per-path upstream override, or the default).
495    let uri = format!("{}{}", rt.pick_upstream(&path), path);
496    let mut up = Request::builder().method(parts.method.clone()).uri(&uri);
497    {
498        let headers = up.headers_mut().expect("builder headers");
499        // Drop hop-by-hop headers (the fixed set plus any named by `Connection`) before
500        // forwarding, so they don't leak across the proxy boundary.
501        let mut forwarded = parts.headers.clone();
502        strip_hop_by_hop(&mut forwarded);
503        for (name, value) in forwarded.iter() {
504            if name == header::HOST {
505                continue; // let the client set Host for the upstream
506            }
507            headers.insert(name.clone(), value.clone());
508        }
509        // Standard forwarding headers.
510        if let Ok(v) = HeaderValue::from_str(&ip.to_string()) {
511            headers.insert(HeaderName::from_static("x-forwarded-for"), v);
512        }
513        headers.insert(
514            HeaderName::from_static("x-forwarded-proto"),
515            HeaderValue::from_static(forwarded_proto(&rt.cfg, &parts.headers)),
516        );
517        // Forward the (resolved/generated) request id so the upstream logs the same correlation id.
518        if let Ok(v) = HeaderValue::from_str(&rid) {
519            headers.insert(HeaderName::from_static(REQUEST_ID_HEADER), v);
520        }
521    }
522
523    let upstream_req = match up.body(Full::new(body_bytes)) {
524        Ok(r) => r,
525        Err(e) => {
526            warn!(error = %e, "failed to build upstream request");
527            return finish(
528                m,
529                &rid,
530                &method,
531                &path,
532                ip,
533                started,
534                "bad_gateway",
535                text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
536            );
537        }
538    };
539
540    // 8) Forward and collect the response under a single deadline, so a stalled upstream
541    //    can't pin this task. `None` => no timeout (validation.upstream_timeout = "0").
542    let deadline = rt.upstream_timeout.map(|d| tokio::time::Instant::now() + d);
543    let timed_out = || {
544        warn!(upstream = %uri, "upstream timed out");
545        text(StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout")
546    };
547
548    let upstream_resp = match within(deadline, state.client.request(upstream_req)).await {
549        Ok(Ok(r)) => r,
550        Ok(Err(e)) => {
551            warn!(error = %e, upstream = %uri, "upstream unreachable");
552            return finish(
553                m,
554                &rid,
555                &method,
556                &path,
557                ip,
558                started,
559                "upstream_error",
560                text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
561            );
562        }
563        Err(_) => {
564            return finish(
565                m,
566                &rid,
567                &method,
568                &path,
569                ip,
570                started,
571                "upstream_timeout",
572                timed_out(),
573            )
574        }
575    };
576
577    let (mut resp_parts, resp_body) = upstream_resp.into_parts();
578
579    // 8a) SSE passthrough: forward a `text/event-stream` response frame-by-frame instead of
580    //     buffering the whole body, so the client sees events as they arrive (time-to-first-byte
581    //     is preserved). The buffering path below would hold the entire stream until the upstream
582    //     finished, which defeats SSE. On a streamed body the `max_response_body` cap and the
583    //     body-read deadline don't apply — the connect/first-byte `upstream_timeout` already
584    //     bounded time-to-headers — and egress bytes are tallied by `CountingBody` as frames flow.
585    //     Response hardening is headers-only, so it stays correct on a streaming body.
586    if rt.stream_passthrough && is_event_stream(&resp_parts.headers) {
587        strip_hop_by_hop(&mut resp_parts.headers);
588        resp_parts.headers.remove(header::CONTENT_LENGTH);
589        let header_egress = header_bytes(&resp_parts.headers);
590        let body = Body::new(CountingBody::new(
591            resp_body,
592            Arc::clone(m),
593            ingress_bytes,
594            header_egress,
595        ));
596        let mut response = Response::from_parts(resp_parts, body);
597        harden_response(&rt.cfg, &mut response);
598        // CORS decoration happens centrally in `handle` (covers this and every error path).
599        return finish(m, &rid, &method, &path, ip, started, "ok", response);
600    }
601
602    // Buffer the upstream body, optionally capped so a huge response can't OOM the proxy.
603    let resp_bytes = if rt.max_response_body > 0 {
604        match within(
605            deadline,
606            Limited::new(resp_body, rt.max_response_body).collect(),
607        )
608        .await
609        {
610            Ok(Ok(c)) => c.to_bytes(),
611            Ok(Err(_)) => {
612                warn!(
613                    limit = rt.max_response_body,
614                    "upstream response exceeded max_response_body"
615                );
616                return finish(
617                    m,
618                    &rid,
619                    &method,
620                    &path,
621                    ip,
622                    started,
623                    "upstream_body_too_large",
624                    text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
625                );
626            }
627            Err(_) => {
628                return finish(
629                    m,
630                    &rid,
631                    &method,
632                    &path,
633                    ip,
634                    started,
635                    "upstream_timeout",
636                    timed_out(),
637                )
638            }
639        }
640    } else {
641        match within(deadline, resp_body.collect()).await {
642            Ok(Ok(c)) => c.to_bytes(),
643            Ok(Err(e)) => {
644                warn!(error = %e, "failed reading upstream body");
645                return finish(
646                    m,
647                    &rid,
648                    &method,
649                    &path,
650                    ip,
651                    started,
652                    "upstream_body_error",
653                    text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
654                );
655            }
656            Err(_) => {
657                return finish(
658                    m,
659                    &rid,
660                    &method,
661                    &path,
662                    ip,
663                    started,
664                    "upstream_timeout",
665                    timed_out(),
666                )
667            }
668        }
669    };
670
671    // The body was rebuffered, so let the server recompute framing; strip hop-by-hop headers
672    // (incl. any named by `Connection`) so they don't leak downstream.
673    strip_hop_by_hop(&mut resp_parts.headers);
674    resp_parts.headers.remove(header::CONTENT_LENGTH);
675
676    // Managed-mode usage: this is the proxied path, where both bodies are buffered, so the byte
677    // counts are exact. (`add_usage_request` is recorded for every request in `finish`.)
678    m.add_usage_bytes(
679        ingress_bytes,
680        header_bytes(&resp_parts.headers).saturating_add(resp_bytes.len()),
681    );
682
683    let mut response = Response::from_parts(resp_parts, Body::from(resp_bytes));
684    harden_response(&rt.cfg, &mut response);
685    // CORS decoration happens centrally in `handle` (covers this and every error path).
686
687    finish(m, &rid, &method, &path, ip, started, "ok", response)
688}
689
690/// Readiness probe. Returns `200` only if the upstream accepts a TCP connection, so a
691/// platform's readiness check reflects whether EdgeGuard can actually serve traffic — not
692/// merely that the process booted. `503` while the upstream is unreachable. (Liveness, i.e.
693/// "is EdgeGuard itself up", is the separate unconditional `/__edgeguard/health`.)
694pub async fn ready(State(state): State<AppState>) -> StatusCode {
695    let rt = state.runtime.load();
696    let Some((host, port)) = rt.cfg.upstream_probe_addr() else {
697        return StatusCode::SERVICE_UNAVAILABLE;
698    };
699    match tokio::time::timeout(
700        Duration::from_secs(2),
701        TcpStream::connect((host.as_str(), port)),
702    )
703    .await
704    {
705        Ok(Ok(_)) => StatusCode::OK,
706        _ => StatusCode::SERVICE_UNAVAILABLE,
707    }
708}
709
710/// Prometheus scrape endpoint (`GET /__edgeguard/metrics`). Like health/ready, it is a
711/// dedicated route outside the proxy fallback, so it is not subject to auth or rate limits —
712/// restrict access to `/__edgeguard/*` at the network layer if that matters in your setup.
713pub async fn metrics_handler(State(state): State<AppState>) -> Response<Body> {
714    let body = state.metrics.render();
715    let mut resp = Response::new(Body::from(body));
716    resp.headers_mut().insert(
717        header::CONTENT_TYPE,
718        HeaderValue::from_static("text/plain; version=0.0.4; charset=utf-8"),
719    );
720    resp
721}
722
723/// CSP violation report sink (`POST /__edgeguard/csp-report`). Browsers POST a JSON report
724/// here when `headers.csp_report_uri` points at it; we count and log it, then `204`.
725pub async fn csp_report(State(state): State<AppState>, body: Bytes) -> StatusCode {
726    state.metrics.record_csp_report();
727    // Managed mode: forward the raw report to the control plane (fire-and-forget, so the browser's
728    // 204 is never delayed by an outbound call). Only when a control plane is configured and
729    // `forward_csp` is on.
730    if let Some(cp) = &state.cp {
731        if state.runtime.load().cfg.control_plane.forward_csp {
732            let cp = cp.clone();
733            let raw = body.clone();
734            tokio::spawn(async move { cp.forward_csp(&raw).await });
735        }
736    }
737    // This endpoint is unauthenticated and a report can carry the full document URL,
738    // referrer, and query strings — logging the whole blob at `info` is both a privacy leak
739    // and a log-flood vector. Record only the directive that fired, at `debug`.
740    match serde_json::from_slice::<serde_json::Value>(&body) {
741        Ok(report) => {
742            let directive = report
743                .get("csp-report")
744                .and_then(|r| {
745                    r.get("violated-directive")
746                        .or_else(|| r.get("effective-directive"))
747                })
748                .and_then(|v| v.as_str())
749                .unwrap_or("unknown");
750            debug!(target: "edgeguard::csp", directive, "CSP violation report");
751        }
752        Err(_) => warn!(
753            bytes = body.len(),
754            "CSP violation report with an unparseable body"
755        ),
756    }
757    StatusCode::NO_CONTENT
758}
759
/// Name of the header EdgeGuard reads an inbound request id from and echoes on every
/// response. Kept as a plain `&'static str` rather than a `HeaderName` value: `HeaderMap`'s
/// `get`/`insert` accept the string directly, and the places that need a real `HeaderName`
/// convert it with `HeaderName::from_static` — which is why it must stay lowercase.
const REQUEST_ID_HEADER: &str = "x-request-id";
764
765/// Resolve the request id for log correlation: reuse a well-formed inbound `X-Request-Id` (one a
766/// CDN/LB already set), else mint a UUID v4. The inbound value is trusted only when it's a short,
767/// printable-ASCII token, so a hostile client can't inject newlines/control characters into the
768/// access log or the echoed response header.
769fn resolve_request_id(headers: &HeaderMap) -> String {
770    if let Some(v) = headers.get(REQUEST_ID_HEADER).and_then(|v| v.to_str().ok()) {
771        let v = v.trim();
772        if !v.is_empty() && v.len() <= 128 && v.bytes().all(|b| b.is_ascii_graphic()) {
773            return v.to_string();
774        }
775    }
776    uuid::Uuid::new_v4().to_string()
777}
778
779/// Resolve the client IP. The peer socket address is authoritative; `X-Forwarded-For`
780/// (first hop) is honored only when `trust_forwarded` is set, because a directly
781/// reachable client can otherwise spoof it to forge their identity.
782fn client_ip(headers: &HeaderMap, peer: SocketAddr, trust_forwarded: bool) -> IpAddr {
783    if trust_forwarded {
784        if let Some(xff) = headers.get("x-forwarded-for") {
785            if let Ok(s) = xff.to_str() {
786                if let Some(first) = s.split(',').next() {
787                    if let Ok(ip) = first.trim().parse::<IpAddr>() {
788                        return ip;
789                    }
790                }
791            }
792        }
793    }
794    peer.ip()
795}
796
797/// Total size of the request headers (sum of name + value bytes), used for the header-size
798/// policy limit. This is an application-layer approximation of the on-wire header size.
799fn header_bytes(headers: &HeaderMap) -> usize {
800    headers
801        .iter()
802        .map(|(name, value)| name.as_str().len() + value.as_bytes().len())
803        .sum()
804}
805
806/// True if the response is a Server-Sent Events stream (`Content-Type: text/event-stream`,
807/// ignoring any `; charset=…` parameter and leading whitespace). The signal we use to forward a
808/// response unbuffered when `validation.stream_passthrough` is on.
809fn is_event_stream(headers: &HeaderMap) -> bool {
810    headers
811        .get(header::CONTENT_TYPE)
812        .and_then(|v| v.to_str().ok())
813        .map(|v| {
814            v.split(';')
815                .next()
816                .map(str::trim)
817                .map(|ct| ct.eq_ignore_ascii_case("text/event-stream"))
818                .unwrap_or(false)
819        })
820        .unwrap_or(false)
821}
822
823/// True when the request asks to upgrade the protocol — a `Connection: upgrade` token plus an
824/// `Upgrade` header (e.g. a WebSocket handshake). The signal for [`proxy_upgrade`].
825fn is_upgrade_request(headers: &HeaderMap) -> bool {
826    let conn_has_upgrade = headers
827        .get_all(header::CONNECTION)
828        .iter()
829        .filter_map(|v| v.to_str().ok())
830        .flat_map(|v| v.split(','))
831        .any(|t| t.trim().eq_ignore_ascii_case("upgrade"));
832    conn_has_upgrade && headers.contains_key(header::UPGRADE)
833}
834
835/// Tunnel a WebSocket / `Upgrade` request to the upstream. Unlike the normal path (which strips
836/// the hop-by-hop `Upgrade`/`Connection` headers), this forwards the handshake intact; on the
837/// upstream's `101 Switching Protocols` it splices the client and upstream connections into a raw
838/// bidirectional byte tunnel for the lifetime of the socket. Any other upstream status is passed
839/// back to the client unchanged, so a rejected handshake surfaces normally.
840// Mirrors the `handle` forward path's parameters (state/runtime/request + the access-log tuple);
841// see the note on `finish`.
842#[allow(clippy::too_many_arguments)]
843async fn proxy_upgrade(
844    state: &AppState,
845    rt: &Runtime,
846    mut req: Request<Body>,
847    request_id: &str,
848    method: &Method,
849    path: &str,
850    ip: IpAddr,
851    started: Instant,
852) -> Response<Body> {
853    let m = &state.metrics;
854
855    // The client-side upgrade future: once we return a `101`, the server completes it and yields
856    // the raw client connection. Take it (removing the extension from `req`) before forwarding.
857    let client_upgrade = hyper::upgrade::on(&mut req);
858
859    // Build the upstream request: copy end-to-end headers AND the upgrade/connection headers
860    // (the handshake needs them), add the forwarding headers, send an empty body.
861    let uri = format!("{}{}", rt.pick_upstream(path), path);
862    let mut up = Request::builder().method(req.method().clone()).uri(&uri);
863    {
864        let headers = up.headers_mut().expect("builder headers");
865        // Strip hop-by-hop headers (the fixed set + any named by `Connection`) before forwarding,
866        // so a client can't smuggle connection-scoped headers upstream — then re-add the handshake
867        // headers the upgrade itself needs (`Connection: upgrade` + the requested `Upgrade`).
868        let upgrade = req.headers().get(header::UPGRADE).cloned();
869        let mut forwarded = req.headers().clone();
870        strip_hop_by_hop(&mut forwarded);
871        for (name, value) in forwarded.iter() {
872            if name == header::HOST {
873                continue;
874            }
875            headers.insert(name.clone(), value.clone());
876        }
877        headers.insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
878        if let Some(v) = upgrade {
879            headers.insert(header::UPGRADE, v);
880        }
881        if let Ok(v) = HeaderValue::from_str(&ip.to_string()) {
882            headers.insert(HeaderName::from_static("x-forwarded-for"), v);
883        }
884        headers.insert(
885            HeaderName::from_static("x-forwarded-proto"),
886            HeaderValue::from_static(forwarded_proto(&rt.cfg, req.headers())),
887        );
888        if let Ok(v) = HeaderValue::from_str(request_id) {
889            headers.insert(HeaderName::from_static(REQUEST_ID_HEADER), v);
890        }
891    }
892    let upstream_req = match up.body(Full::new(Bytes::new())) {
893        Ok(r) => r,
894        Err(e) => {
895            warn!(error = %e, "failed to build upstream upgrade request");
896            return finish(
897                m,
898                request_id,
899                method,
900                path,
901                ip,
902                started,
903                "bad_gateway",
904                text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
905            );
906        }
907    };
908
909    // Bound the handshake by the same `upstream_timeout` as the buffered path, so a stalled
910    // upstream can't pin this task (a `None` deadline means no timeout).
911    let deadline = rt.upstream_timeout.map(|d| tokio::time::Instant::now() + d);
912    let timed_out = || {
913        warn!(upstream = %uri, "upstream timed out (upgrade)");
914        finish(
915            m,
916            request_id,
917            method,
918            path,
919            ip,
920            started,
921            "upstream_timeout",
922            text(StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout"),
923        )
924    };
925
926    let mut up_resp = match within(deadline, state.client.request(upstream_req)).await {
927        Ok(Ok(r)) => r,
928        Ok(Err(e)) => {
929            warn!(error = %e, upstream = %uri, "upstream unreachable (upgrade)");
930            return finish(
931                m,
932                request_id,
933                method,
934                path,
935                ip,
936                started,
937                "upstream_error",
938                text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
939            );
940        }
941        Err(_) => return timed_out(),
942    };
943
944    // Upstream declined to upgrade: forward its response as-is (the client sees the rejection),
945    // but under the same deadline and `max_response_body` cap as the normal buffered path so a
946    // rejected handshake can't hang or buffer an unbounded body.
947    if up_resp.status() != StatusCode::SWITCHING_PROTOCOLS {
948        let (mut parts, body) = up_resp.into_parts();
949        // Collect the rejection body, capped by `max_response_body` when set. Both arms normalize
950        // any read/limit error to `()` — the distinction doesn't change the `502` we return.
951        let body_fut = async {
952            if rt.max_response_body > 0 {
953                Limited::new(body, rt.max_response_body)
954                    .collect()
955                    .await
956                    .map(|c| c.to_bytes())
957                    .map_err(|_| ())
958            } else {
959                body.collect().await.map(|c| c.to_bytes()).map_err(|_| ())
960            }
961        };
962        let bytes = match within(deadline, body_fut).await {
963            Ok(Ok(b)) => b,
964            Ok(Err(())) => {
965                warn!("upstream upgrade-rejection body failed or exceeded max_response_body");
966                return finish(
967                    m,
968                    request_id,
969                    method,
970                    path,
971                    ip,
972                    started,
973                    "bad_gateway",
974                    text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
975                );
976            }
977            Err(_) => return timed_out(),
978        };
979        strip_hop_by_hop(&mut parts.headers);
980        parts.headers.remove(header::CONTENT_LENGTH);
981        let mut response = Response::from_parts(parts, Body::from(bytes));
982        harden_response(&rt.cfg, &mut response);
983        return finish(m, request_id, method, path, ip, started, "ok", response);
984    }
985
986    // `101`: wire up the upstream-side upgrade and splice the two connections once both complete.
987    let upstream_upgrade = hyper::upgrade::on(&mut up_resp);
988    tokio::spawn(async move {
989        match tokio::join!(client_upgrade, upstream_upgrade) {
990            (Ok(client_io), Ok(up_io)) => {
991                let mut client_io = TokioIo::new(client_io);
992                let mut up_io = TokioIo::new(up_io);
993                if let Err(e) = tokio::io::copy_bidirectional(&mut client_io, &mut up_io).await {
994                    debug!(error = %e, "websocket tunnel closed");
995                }
996            }
997            (c, u) => warn!(
998                client_ok = c.is_ok(),
999                upstream_ok = u.is_ok(),
1000                "websocket upgrade did not complete"
1001            ),
1002        }
1003    });
1004
1005    // Return the upstream's `101` — its headers carry `Sec-WebSocket-Accept` etc., and returning a
1006    // `101` is what makes the server upgrade the client side (completing `client_upgrade` above).
1007    // Strip hop-by-hop headers (the fixed set + any named by `Connection`) so the upstream can't
1008    // leak connection-scoped headers downstream, then re-add the handshake headers the upgrade
1009    // itself needs (`Connection: upgrade` + the negotiated `Upgrade`).
1010    let (mut parts, _body) = up_resp.into_parts();
1011    let upgrade = parts.headers.get(header::UPGRADE).cloned();
1012    strip_hop_by_hop(&mut parts.headers);
1013    parts.headers.remove(header::CONTENT_LENGTH);
1014    parts
1015        .headers
1016        .insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
1017    if let Some(v) = upgrade {
1018        parts.headers.insert(header::UPGRADE, v);
1019    }
1020    let response = Response::from_parts(parts, Body::empty());
1021    finish(
1022        m,
1023        request_id,
1024        method,
1025        path,
1026        ip,
1027        started,
1028        "ws_upgrade",
1029        response,
1030    )
1031}
1032
1033/// Wraps a streaming upstream body to tally egress bytes (response headers + each data frame)
1034/// and report them to managed-mode usage when the body is dropped — i.e. after the final frame
1035/// is sent, or earlier if the client disconnects mid-stream (we count what actually went out).
1036/// Used for SSE passthrough: the body isn't buffered, so the exact byte count the buffered path
1037/// takes up front can only be accumulated as frames flow.
1038struct CountingBody<B> {
1039    inner: B,
1040    metrics: Arc<Metrics>,
1041    ingress: usize,
1042    /// Running egress total: response header bytes, then each data frame as it passes.
1043    egress: usize,
1044}
1045
1046impl<B> CountingBody<B> {
1047    fn new(inner: B, metrics: Arc<Metrics>, ingress: usize, header_egress: usize) -> Self {
1048        Self {
1049            inner,
1050            metrics,
1051            ingress,
1052            egress: header_egress,
1053        }
1054    }
1055}
1056
1057impl<B> HttpBody for CountingBody<B>
1058where
1059    B: HttpBody<Data = Bytes> + Unpin,
1060{
1061    type Data = Bytes;
1062    type Error = B::Error;
1063
1064    fn poll_frame(
1065        mut self: Pin<&mut Self>,
1066        cx: &mut Context<'_>,
1067    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
1068        let this = self.as_mut().get_mut();
1069        let polled = Pin::new(&mut this.inner).poll_frame(cx);
1070        if let Poll::Ready(Some(Ok(frame))) = &polled {
1071            if let Some(data) = frame.data_ref() {
1072                this.egress = this.egress.saturating_add(data.len());
1073            }
1074        }
1075        polled
1076    }
1077
1078    fn is_end_stream(&self) -> bool {
1079        self.inner.is_end_stream()
1080    }
1081
1082    fn size_hint(&self) -> SizeHint {
1083        self.inner.size_hint()
1084    }
1085}
1086
1087impl<B> Drop for CountingBody<B> {
1088    fn drop(&mut self) {
1089        self.metrics.add_usage_bytes(self.ingress, self.egress);
1090    }
1091}
1092
1093/// Remove hop-by-hop headers so they don't leak across the proxy boundary (RFC 7230 §6.1):
1094/// the fixed [`HOP_BY_HOP`] set plus any header *named* in a `Connection` header. Applied in
1095/// both directions (request to upstream, response to client).
1096fn strip_hop_by_hop(headers: &mut HeaderMap) {
1097    // Header names listed in any `Connection` header are connection-specific; collect them
1098    // before mutating (the borrow of `headers` must end before we remove).
1099    let connection_named: Vec<HeaderName> = headers
1100        .get_all(header::CONNECTION)
1101        .iter()
1102        .filter_map(|v| v.to_str().ok())
1103        .flat_map(|v| v.split(','))
1104        .filter_map(|token| HeaderName::from_bytes(token.trim().as_bytes()).ok())
1105        .collect();
1106    for name in HOP_BY_HOP {
1107        headers.remove(*name);
1108    }
1109    for name in connection_named {
1110        headers.remove(name);
1111    }
1112}
1113
1114/// Decide the `X-Forwarded-Proto` to send upstream. If EdgeGuard terminates TLS, the client
1115/// hop is HTTPS. Otherwise, behind a trusted edge (`trust_forwarded_for`) we preserve the
1116/// proto the edge reported (falling back to `http`); an untrusted client's `X-Forwarded-Proto`
1117/// is never honored, mirroring the client-IP trust model. Returns a `'static` token so the
1118/// caller can build a `HeaderValue` without fallible parsing.
1119fn forwarded_proto(cfg: &Config, headers: &HeaderMap) -> &'static str {
1120    if cfg.tls.enabled {
1121        return "https";
1122    }
1123    if cfg.server.trust_forwarded_for {
1124        if let Some(value) = headers
1125            .get("x-forwarded-proto")
1126            .and_then(|v| v.to_str().ok())
1127        {
1128            match value.split(',').next().map(str::trim) {
1129                Some(p) if p.eq_ignore_ascii_case("https") => return "https",
1130                Some(p) if p.eq_ignore_ascii_case("http") => return "http",
1131                _ => {}
1132            }
1133        }
1134    }
1135    "http"
1136}
1137
1138/// Pick the most specific (longest-prefix) per-route limiter matching `path`, if any.
1139fn longest_route<'a>(routes: &'a [RouteLimiter], path: &str) -> Option<&'a RouteLimiter> {
1140    routes
1141        .iter()
1142        .filter(|r| path.starts_with(&r.prefix))
1143        .max_by_key(|r| r.prefix.len())
1144}
1145
/// Value of the `Strict-Transport-Security` header EdgeGuard emits when `headers.hsts` is
/// on: `max-age` of 63,072,000 seconds (730 days, i.e. two years) covering subdomains too.
/// Exposed as one named constant so the live proxy and the static-host config generator
/// ([`crate::generate`]) can't drift on it.
pub const HSTS_VALUE: &str = "max-age=63072000; includeSubDomains";
1150
1151/// The constant security response headers EdgeGuard injects, derived from the `[headers]`
1152/// policy. This is the **single source of truth** shared by the live response-hardening path
1153/// ([`harden_response`]) and the static-host config generator ([`crate::generate`]), so a
1154/// generated `_headers` file / edge-middleware snippet matches exactly what the proxy would add
1155/// at runtime. Returns `(name, value)` pairs with canonically-cased names (for readable
1156/// generated output); the proxy normalizes the case when it inserts them.
1157///
1158/// Cookie hardening and leaky-header *stripping* are deliberately **not** here: both rewrite the
1159/// upstream's actual response (`Set-Cookie`, `Server`/`X-Powered-By`), which a static file that
1160/// can only "always add this header" cannot express. The generator documents that gap; the
1161/// WASM worker, which sees the real response, applies them too.
1162pub fn security_headers(cfg: &HeadersCfg) -> Vec<(&'static str, String)> {
1163    let mut out: Vec<(&'static str, String)> = Vec::with_capacity(6);
1164    out.push(("X-Content-Type-Options", "nosniff".to_string()));
1165    if !cfg.frame_options.is_empty() {
1166        out.push(("X-Frame-Options", cfg.frame_options.clone()));
1167    }
1168    if !cfg.referrer_policy.is_empty() {
1169        out.push(("Referrer-Policy", cfg.referrer_policy.clone()));
1170    }
1171    if !cfg.permissions_policy.is_empty() {
1172        out.push(("Permissions-Policy", cfg.permissions_policy.clone()));
1173    }
1174    if !cfg.csp.is_empty() {
1175        // Append a report-uri directive if configured, and choose enforce vs. report-only.
1176        let mut value = cfg.csp.clone();
1177        if !cfg.csp_report_uri.is_empty() {
1178            value.push_str("; report-uri ");
1179            value.push_str(&cfg.csp_report_uri);
1180        }
1181        let name = if cfg.csp_report_only {
1182            "Content-Security-Policy-Report-Only"
1183        } else {
1184            "Content-Security-Policy"
1185        };
1186        out.push((name, value));
1187    }
1188    if cfg.hsts {
1189        out.push(("Strict-Transport-Security", HSTS_VALUE.to_string()));
1190    }
1191    out
1192}
1193
1194/// Inject security headers, harden Set-Cookie, and strip leaky headers.
1195fn harden_response(cfg: &Config, resp: &mut Response<Body>) {
1196    let h = resp.headers_mut();
1197
1198    // Inject the constant security headers (shared with the static-host generator via
1199    // `security_headers`, so the two never diverge). `from_bytes` normalizes the canonical
1200    // casing to lowercase; these names/values are all valid, so the inserts don't fail.
1201    for (name, value) in security_headers(&cfg.headers) {
1202        if let (Ok(n), Ok(v)) = (
1203            HeaderName::from_bytes(name.as_bytes()),
1204            HeaderValue::from_str(&value),
1205        ) {
1206            h.insert(n, v);
1207        }
1208    }
1209
1210    // Strip leaky headers.
1211    for name in &cfg.headers.strip {
1212        if let Ok(hn) = HeaderName::from_bytes(name.as_bytes()) {
1213            h.remove(hn);
1214        }
1215    }
1216
1217    // Harden cookies: ensure Secure, HttpOnly, and a SameSite default.
1218    if cfg.headers.force_secure_cookies {
1219        let cookies: Vec<HeaderValue> = h.get_all(header::SET_COOKIE).iter().cloned().collect();
1220        if !cookies.is_empty() {
1221            h.remove(header::SET_COOKIE);
1222            for c in cookies {
1223                if let Ok(s) = c.to_str() {
1224                    let hardened = harden_cookie(s);
1225                    if let Ok(v) = HeaderValue::from_str(&hardened) {
1226                        h.append(header::SET_COOKIE, v);
1227                    }
1228                } else {
1229                    h.append(header::SET_COOKIE, c);
1230                }
1231            }
1232        }
1233    }
1234}
1235
/// Return `cookie` (a single `Set-Cookie` value) with `Secure`, `HttpOnly`, and
/// `SameSite=Lax` appended for whichever of those attributes it does not already carry.
/// Existing attributes — including an explicit `SameSite` — are left exactly as written.
///
/// Trailing `;` separators and whitespace are trimmed before appending, so an input such as
/// `"sid=abc; "` does not produce an empty `; ;` attribute in the output.
fn harden_cookie(cookie: &str) -> String {
    // Inspect attribute *names* (the tokens after the first `name=value` pair), not the
    // whole string — otherwise a value like `session=securetoken` would look like it
    // already carries `Secure` and we'd skip hardening it. Names are matched
    // case-insensitively, without allocating a lowercase copy of each one.
    let has_attr = |wanted: &str| {
        cookie.split(';').skip(1).any(|part| {
            let name = part.split('=').next().unwrap_or(part);
            name.trim().eq_ignore_ascii_case(wanted)
        })
    };

    const ALL_FLAGS: &str = "; Secure; HttpOnly; SameSite=Lax";
    let base = cookie.trim_end_matches(|c: char| c == ';' || c.is_whitespace());
    let mut out = String::with_capacity(base.len() + ALL_FLAGS.len());
    out.push_str(base);
    if !has_attr("secure") {
        out.push_str("; Secure");
    }
    if !has_attr("httponly") {
        out.push_str("; HttpOnly");
    }
    if !has_attr("samesite") {
        out.push_str("; SameSite=Lax");
    }
    out
}
1259
1260/// Run `fut` bounded by an optional deadline. `None` means no timeout. On success returns
1261/// the future's own output; `Err(Elapsed)` if the deadline passed first.
1262async fn within<F: Future>(
1263    deadline: Option<tokio::time::Instant>,
1264    fut: F,
1265) -> Result<F::Output, tokio::time::error::Elapsed> {
1266    match deadline {
1267        Some(dl) => tokio::time::timeout_at(dl, fut).await,
1268        None => Ok(fut.await),
1269    }
1270}
1271
1272fn text(status: StatusCode, msg: &str) -> Response<Body> {
1273    let mut resp = Response::new(Body::from(msg.to_string()));
1274    *resp.status_mut() = status;
1275    resp.headers_mut().insert(
1276        header::CONTENT_TYPE,
1277        HeaderValue::from_static("text/plain; charset=utf-8"),
1278    );
1279    resp
1280}
1281
1282/// Emit a structured access-log line, record metrics, stamp the response with `X-Request-Id`,
1283/// and return it.
1284// All args are part of the access-log/identity tuple for one request; bundling them in a struct
1285// would just move the same fields behind another name at every (already terse) call site.
1286#[allow(clippy::too_many_arguments)]
1287fn finish(
1288    metrics: &Metrics,
1289    request_id: &str,
1290    method: &Method,
1291    path: &str,
1292    ip: IpAddr,
1293    started: Instant,
1294    outcome: &str,
1295    mut resp: Response<Body>,
1296) -> Response<Body> {
1297    // Echo the request id on every response (including error responses) so a client / upstream /
1298    // log can be correlated. `resolve_request_id` guarantees it's a valid header value.
1299    if let Ok(v) = HeaderValue::from_str(request_id) {
1300        resp.headers_mut().insert(REQUEST_ID_HEADER, v);
1301    }
1302    let elapsed = started.elapsed();
1303    info!(
1304        request_id,
1305        %method,
1306        path = %path,
1307        client_ip = %ip,
1308        status = resp.status().as_u16(),
1309        outcome,
1310        latency_ms = elapsed.as_millis() as u64,
1311        "request"
1312    );
1313    metrics.record_request(outcome);
1314    metrics.observe_latency(elapsed);
1315    // Managed mode: count every finished request (proxied or rejected) toward the usage delta.
1316    // Cheap (two relaxed atomic adds) and inert unless a control plane drains it for reporting.
1317    metrics.add_usage_request();
1318    resp
1319}
1320
1321#[cfg(test)]
1322mod tests {
1323    use super::*;
1324
1325    fn headers_with(name: &'static str, value: &str) -> HeaderMap {
1326        let mut h = HeaderMap::new();
1327        h.insert(name, HeaderValue::from_str(value).unwrap());
1328        h
1329    }
1330
1331    #[test]
1332    fn path_prefix_matches_on_segment_boundary_only() {
1333        // Exact, sub-path, and query-boundary matches.
1334        assert!(path_prefix_matches("/api", "/api"));
1335        assert!(path_prefix_matches("/api/users", "/api"));
1336        assert!(path_prefix_matches("/api?x=1", "/api"));
1337        // A trailing-slash prefix matches its sub-paths.
1338        assert!(path_prefix_matches("/api/users", "/api/"));
1339        // Sibling paths sharing a textual prefix must NOT match.
1340        assert!(!path_prefix_matches("/apiary", "/api"));
1341        assert!(!path_prefix_matches("/apiary/honey", "/api"));
1342        // `/` matches everything.
1343        assert!(path_prefix_matches("/anything", "/"));
1344    }
1345
1346    #[test]
1347    fn client_ip_ignores_xff_when_untrusted() {
1348        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
1349        let h = headers_with("x-forwarded-for", "1.2.3.4");
1350        // Untrusted: a directly reachable client must not be able to spoof its IP.
1351        assert_eq!(client_ip(&h, peer, false), peer.ip());
1352    }
1353
1354    #[test]
1355    fn client_ip_uses_first_xff_hop_when_trusted() {
1356        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
1357        let h = headers_with("x-forwarded-for", "1.2.3.4, 5.6.7.8");
1358        assert_eq!(client_ip(&h, peer, true).to_string(), "1.2.3.4");
1359    }
1360
1361    #[test]
1362    fn client_ip_falls_back_to_peer_on_missing_or_garbage_xff() {
1363        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
1364        assert_eq!(client_ip(&HeaderMap::new(), peer, true), peer.ip());
1365        let garbage = headers_with("x-forwarded-for", "not-an-ip");
1366        assert_eq!(client_ip(&garbage, peer, true), peer.ip());
1367    }
1368
1369    #[test]
1370    fn header_bytes_sums_names_and_values() {
1371        let mut h = HeaderMap::new();
1372        h.insert("a", HeaderValue::from_static("bb")); // 1 + 2
1373        h.insert("ccc", HeaderValue::from_static("dddd")); // 3 + 4
1374        assert_eq!(header_bytes(&h), 1 + 2 + 3 + 4);
1375    }
1376
1377    #[test]
1378    fn strip_hop_by_hop_removes_fixed_and_connection_named() {
1379        let mut h = HeaderMap::new();
1380        h.insert(
1381            "connection",
1382            HeaderValue::from_static("keep-alive, X-Custom-Hop"),
1383        );
1384        h.insert("keep-alive", HeaderValue::from_static("timeout=5"));
1385        h.insert("x-custom-hop", HeaderValue::from_static("secret"));
1386        h.insert("content-type", HeaderValue::from_static("text/plain"));
1387        strip_hop_by_hop(&mut h);
1388        assert!(!h.contains_key("connection"));
1389        assert!(!h.contains_key("keep-alive"));
1390        // A header named by Connection is connection-specific and must be dropped.
1391        assert!(!h.contains_key("x-custom-hop"));
1392        // An end-to-end header is preserved.
1393        assert!(h.contains_key("content-type"));
1394    }
1395
1396    #[test]
1397    fn forwarded_proto_reflects_tls_and_trust() {
1398        let mut cfg = Config::default();
1399
1400        // We terminate TLS -> always https, regardless of any incoming header.
1401        cfg.tls.enabled = true;
1402        assert_eq!(
1403            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "http")),
1404            "https"
1405        );
1406
1407        // Plain HTTP, untrusted: http, and an incoming XFP is NOT trusted.
1408        cfg.tls.enabled = false;
1409        cfg.server.trust_forwarded_for = false;
1410        assert_eq!(
1411            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "https")),
1412            "http"
1413        );
1414
1415        // Plain HTTP behind a trusted edge: preserve the edge's reported proto.
1416        cfg.server.trust_forwarded_for = true;
1417        assert_eq!(
1418            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "https")),
1419            "https"
1420        );
1421        assert_eq!(
1422            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "http, https")),
1423            "http"
1424        );
1425        // Missing or unrecognized -> http.
1426        assert_eq!(forwarded_proto(&cfg, &HeaderMap::new()), "http");
1427        assert_eq!(
1428            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "garbage")),
1429            "http"
1430        );
1431    }
1432
1433    #[test]
1434    fn longest_route_picks_most_specific_prefix() {
1435        let mk = |p: &str| RouteLimiter {
1436            prefix: p.to_string(),
1437            limiter: Arc::new(RateLimiter::keyed(governor::Quota::per_second(
1438                std::num::NonZeroU32::new(1).unwrap(),
1439            ))),
1440        };
1441        let routes = vec![mk("/api/"), mk("/api/admin/")];
1442        assert_eq!(
1443            longest_route(&routes, "/api/admin/users").map(|r| r.prefix.as_str()),
1444            Some("/api/admin/")
1445        );
1446        assert_eq!(
1447            longest_route(&routes, "/api/things").map(|r| r.prefix.as_str()),
1448            Some("/api/")
1449        );
1450        assert!(longest_route(&routes, "/public").is_none());
1451    }
1452
1453    #[test]
1454    fn path_prefix_matches_on_segment_boundaries() {
1455        // A prefix without a trailing slash must not match a sibling path.
1456        assert!(path_prefix_matches("/api", "/api")); // exact
1457        assert!(path_prefix_matches("/api/users", "/api")); // segment boundary
1458        assert!(path_prefix_matches("/api?q=1", "/api")); // query boundary
1459        assert!(!path_prefix_matches("/apiary", "/api")); // sibling — must NOT match

        // A trailing-slash prefix is a clean boundary by construction.
1461        assert!(path_prefix_matches("/api/users", "/api/"));
1462        assert!(!path_prefix_matches("/apiary", "/api/"));
1463        // "/" matches everything.
1464        assert!(path_prefix_matches("/anything", "/"));
1465    }
1466
1467    #[test]
1468    fn harden_cookie_adds_missing_flags() {
1469        let out = harden_cookie("sid=abc");
1470        assert!(out.contains("; Secure"), "{out}");
1471        assert!(out.contains("; HttpOnly"), "{out}");
1472        assert!(out.contains("; SameSite=Lax"), "{out}");
1473    }
1474
/// Attributes the upstream already set must survive untouched; only the missing one
/// (`Secure`) is appended.
#[test]
fn harden_cookie_preserves_existing_attributes() {
    let hardened = harden_cookie("sid=abc; HttpOnly; SameSite=Strict");

    // The one attribute that was absent gets added.
    assert!(hardened.contains("; Secure"), "{hardened}");

    // The upstream's SameSite choice wins over the default.
    assert!(hardened.contains("SameSite=Strict"), "{hardened}");
    assert!(!hardened.contains("SameSite=Lax"), "{hardened}");

    // HttpOnly was already there, so it must not appear a second time.
    let http_only_hits = hardened.matches("HttpOnly").count();
    assert_eq!(http_only_hits, 1, "{hardened}");
}
1484
/// Regression guard for the token-vs-substring fix: the cookie *value* contains the text
/// "secure", but no `Secure` attribute is present, so the attribute must still be added.
#[test]
fn harden_cookie_value_resembling_an_attr_is_not_skipped() {
    let hardened = harden_cookie("session=securetoken");
    let has_secure_attr = hardened.contains("; Secure");
    assert!(has_secure_attr, "{hardened}");
}
1492
/// `security_headers` must emit the full header set for the default config and honour
/// the HSTS / frame-options / CSP report-only toggles when they are changed.
#[test]
fn security_headers_reflects_config_toggles() {
    // --- Defaults: every header present, CSP enforced (not report-only). ---
    let default_cfg = HeadersCfg::default();
    let emitted = security_headers(&default_cfg);
    let has = |name: &str| emitted.iter().any(|(n, _)| *n == name);

    for expected in [
        "X-Content-Type-Options",
        "X-Frame-Options",
        "Referrer-Policy",
        "Permissions-Policy",
        "Content-Security-Policy",
        "Strict-Transport-Security",
    ] {
        assert!(has(expected), "missing header {expected}");
    }
    assert!(!has("Content-Security-Policy-Report-Only"));

    // --- Toggled: HSTS off, frame_options cleared, CSP switched to report-only. ---
    // Exactly those headers disappear; the CSP header changes name and gets the
    // report_uri appended to its value.
    let toggled_cfg = HeadersCfg {
        hsts: false,
        frame_options: String::new(),
        csp: "default-src 'self'".into(),
        csp_report_only: true,
        csp_report_uri: "/__edgeguard/csp-report".into(),
        ..HeadersCfg::default()
    };
    let emitted = security_headers(&toggled_cfg);

    // Index by header name; a later duplicate overwrites an earlier one.
    let mut by_name = std::collections::HashMap::<&str, String>::new();
    for (name, value) in emitted.iter() {
        by_name.insert(*name, value.clone());
    }

    for absent in [
        "Strict-Transport-Security",
        "X-Frame-Options",
        "Content-Security-Policy",
    ] {
        assert!(!by_name.contains_key(absent), "unexpected header {absent}");
    }

    let report_only = by_name
        .get("Content-Security-Policy-Report-Only")
        .map(String::as_str);
    assert_eq!(
        report_only,
        Some("default-src 'self'; report-uri /__edgeguard/csp-report")
    );
}
1529}