Skip to main content

edgeguard/
proxy.rs

1//! Request path: header-size limit -> rate limit (per-IP / per-route) -> auth -> per-key
2//! rate limit -> method allowlist -> body-size limit -> WAF input inspection -> forward to
3//! upstream.
4//! Response path: header injection (incl. CSP / CSP-report-only) -> cookie hardening ->
5//! strip leaky headers.
6//!
7//! All policy lives in [`Runtime`], held behind an [`ArcSwap`] so a config hot-reload swaps
8//! it atomically without blocking the request path or dropping in-flight connections. The
9//! upstream client and the metric registry sit *outside* the swap so the connection pool and
10//! counters survive a reload.
11
12use std::future::Future;
13use std::net::{IpAddr, SocketAddr};
14use std::pin::Pin;
15use std::sync::Arc;
16use std::task::{Context, Poll};
17use std::time::{Duration, Instant};
18
19use arc_swap::ArcSwap;
20use axum::{
21    body::{Body, Bytes},
22    extract::{ConnectInfo, State},
23    http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode},
24};
25use governor::{clock::DefaultClock, state::keyed::DefaultKeyedStateStore, RateLimiter};
26use http_body_util::{BodyExt, Full, Limited};
27use hyper::body::{Body as HttpBody, Frame, SizeHint};
28use hyper_util::client::legacy::{connect::HttpConnector, Client};
29use tokio::net::TcpStream;
30use tracing::{debug, info, warn};
31
32use crate::auth::{AuthEngine, Challenge, Decision};
33use crate::config::{Config, HeadersCfg};
34use crate::limiter::{Admit, DistributedLimiter};
35use crate::metrics::Metrics;
36use crate::waf::{WafEngine, WafMode};
37
/// In-process rate limiter keyed by client IP. Used for the global per-IP limit
/// (`Runtime::ip_limiter`) and for the per-route overrides in [`RouteLimiter`].
pub type KeyedLimiter = RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>;
/// Rate limiter keyed by the authenticated principal (per-key limiting).
pub type StrLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
/// HTTP client used to forward requests upstream. Request bodies are fully buffered into a
/// `Full<Bytes>` before being sent.
pub type UpstreamClient = Client<HttpConnector, Full<Bytes>>;
42
/// Shared, cheaply-cloned handle the router hands to every request. Only the hot-swappable
/// [`Runtime`] changes on reload; the client and metrics are stable.
#[derive(Clone)]
pub struct AppState {
    /// Upstream HTTP client. Kept outside the swap so its connection pool survives a reload.
    pub client: UpstreamClient,
    /// Metric registry. Kept outside the swap so counters survive a reload.
    pub metrics: Arc<Metrics>,
    /// Current policy. Handlers `load()` it once per request to get a consistent snapshot.
    pub runtime: Arc<ArcSwap<Runtime>>,
    /// Managed-mode control-plane client (`Some` only when `[control_plane]` is enabled). Used to
    /// forward CSP reports; policy pull + usage reporting run as background tasks in `main`.
    pub cp: Option<Arc<crate::cp::CpClient>>,
}
54
/// A per-route rate-limit override: requests whose path starts with `prefix` use `limiter`.
/// When several overrides match, the one with the longest `prefix` wins.
pub struct RouteLimiter {
    /// Plain string prefix, compared against the request's path *including* any query string.
    /// There is no segment-boundary check, so `/api` also matches `/apix`.
    pub prefix: String,
    /// Per-client-IP limiter used instead of the global per-IP limiter when `prefix` matches.
    pub limiter: Arc<KeyedLimiter>,
}
60
/// All request-handling policy derived from a [`Config`]. Rebuilt from scratch on reload and
/// swapped in atomically.
pub struct Runtime {
    /// The configuration this runtime was built from.
    pub cfg: Arc<Config>,
    /// Prefix the request's path-and-query is appended to verbatim when building the upstream
    /// URI (presumably scheme + authority such as `http://127.0.0.1:3000` — confirm in config).
    pub upstream_base: Arc<String>,
    /// Authentication engine consulted for every proxied request.
    pub auth: AuthEngine,
    /// WAF-lite input screener. Inert (`evaluate` returns `None`) when `waf.mode = "off"`.
    pub waf: WafEngine,
    /// Shared-store (distributed) limiter, `Some` when `ratelimit.store` is `memory`/`redis`.
    /// When present it replaces the three `governor` limiters below (which are then `None`).
    pub distributed: Option<DistributedLimiter>,
    /// Global per-client-IP limiter (`None` when rate limiting is disabled or distributed).
    pub ip_limiter: Option<Arc<KeyedLimiter>>,
    /// Per-route limiters (also keyed per IP), checked instead of `ip_limiter` on a match.
    pub route_limiters: Vec<RouteLimiter>,
    /// Per-principal limiter (`None` when per-key limiting is disabled or distributed).
    pub key_limiter: Option<Arc<StrLimiter>>,
    /// Limit passed to the request-body buffering step; a body over it is answered with `413`.
    pub max_body: usize,
    /// Cap on the buffered upstream response body; `0` means unbounded.
    pub max_response_body: usize,
    /// Cap on total request header bytes; `0` means disabled.
    pub max_header_bytes: usize,
    /// Max time for the upstream request + body read; `None` disables the timeout.
    pub upstream_timeout: Option<Duration>,
    /// Forward `text/event-stream` responses unbuffered (SSE passthrough). See
    /// [`crate::config::ValidationCfg::stream_passthrough`].
    pub stream_passthrough: bool,
}
89
/// Hop-by-hop headers that must not be forwarded (RFC 7230 §6.1). This is the fixed set;
/// headers *named* in a `Connection` header are also hop-by-hop and are collected at runtime
/// by `strip_hop_by_hop`. Entries are written in lowercase.
const HOP_BY_HOP: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];
101
102pub async fn handle(
103    State(state): State<AppState>,
104    ConnectInfo(peer): ConnectInfo<SocketAddr>,
105    req: Request<Body>,
106) -> Response<Body> {
107    let started = Instant::now();
108    // One atomic load pins a consistent policy snapshot for the whole request, even if a
109    // reload swaps in a new Runtime mid-flight.
110    let rt = state.runtime.load();
111    let m = &state.metrics;
112
113    let method = req.method().clone();
114    let path = req
115        .uri()
116        .path_and_query()
117        .map(|p| p.as_str().to_string())
118        .unwrap_or_else(|| req.uri().path().to_string());
119
120    let ip = client_ip(req.headers(), peer, rt.cfg.server.trust_forwarded_for);
121
122    // Reserve the internal namespace: never forward `/__edgeguard/*` upstream. Registered
123    // internal routes are matched before this fallback, so anything reaching here under that
124    // prefix is an unknown internal path — a `404` from EdgeGuard, not a request leaked to the
125    // app. This is also what keeps the ops endpoints (health/ready/metrics) unserved on the
126    // public listener in public/private split mode, rather than proxying them to the upstream.
127    if req.uri().path().starts_with("/__edgeguard/") {
128        return finish(
129            m,
130            &method,
131            &path,
132            ip,
133            started,
134            "not_found",
135            text(StatusCode::NOT_FOUND, "Not Found"),
136        );
137    }
138
139    // 0) Total request-header-size limit.
140    if rt.max_header_bytes > 0 && header_bytes(req.headers()) > rt.max_header_bytes {
141        return finish(
142            m,
143            &method,
144            &path,
145            ip,
146            started,
147            "header_too_large",
148            text(
149                StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
150                "Request Header Fields Too Large",
151            ),
152        );
153    }
154
155    // 1) Rate limit. A matching per-route override replaces the global per-IP limit. A shared
156    //    store (distributed) limiter, when configured, replaces the in-process limiters; on a
157    //    store error it fails closed (`503`) unless `ratelimit.fail_open` is set.
158    if rt.cfg.ratelimit.enabled {
159        if let Some(d) = &rt.distributed {
160            match d.check_ip_route(ip, &path).await {
161                Admit::Allowed => {}
162                Admit::Limited(scope) => {
163                    m.record_ratelimit_hit(scope);
164                    return finish(
165                        m,
166                        &method,
167                        &path,
168                        ip,
169                        started,
170                        "rate_limited",
171                        text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
172                    );
173                }
174                Admit::Error => {
175                    return finish(
176                        m,
177                        &method,
178                        &path,
179                        ip,
180                        started,
181                        "limiter_error",
182                        text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
183                    );
184                }
185            }
186        } else {
187            let (limiter, scope) = match longest_route(&rt.route_limiters, &path) {
188                Some(r) => (Some(r.limiter.as_ref()), "route"),
189                None => (rt.ip_limiter.as_deref(), "ip"),
190            };
191            if let Some(limiter) = limiter {
192                if limiter.check_key(&ip).is_err() {
193                    m.record_ratelimit_hit(scope);
194                    return finish(
195                        m,
196                        &method,
197                        &path,
198                        ip,
199                        started,
200                        "rate_limited",
201                        text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
202                    );
203                }
204            }
205        }
206    }
207
208    // 2) Authentication. On success we learn the principal for per-key limiting.
209    let principal = match rt.auth.authorize(&rt.cfg.auth, req.headers()).await {
210        Decision::Allow(principal) => principal,
211        Decision::Deny(challenge) => {
212            let mut resp = text(StatusCode::UNAUTHORIZED, "Unauthorized");
213            let challenge_value = match challenge {
214                Challenge::Basic(c) => Some(c),
215                Challenge::Bearer => Some("Bearer".to_string()),
216                Challenge::None => None,
217            };
218            if let Some(c) = challenge_value {
219                if let Ok(v) = HeaderValue::from_str(&c) {
220                    resp.headers_mut().insert(header::WWW_AUTHENTICATE, v);
221                }
222            }
223            return finish(m, &method, &path, ip, started, "unauthorized", resp);
224        }
225    };
226
227    // 3) Per-key rate limit (only for authenticated principals). Routed to the distributed
228    //    limiter when configured, else the in-process per-key limiter.
229    if let Some(principal) = &principal {
230        let key_admit = if let Some(d) = &rt.distributed {
231            Some(d.check_key(principal).await)
232        } else {
233            rt.key_limiter.as_ref().map(|limiter| {
234                if limiter.check_key(principal).is_err() {
235                    Admit::Limited("key")
236                } else {
237                    Admit::Allowed
238                }
239            })
240        };
241        match key_admit {
242            Some(Admit::Limited(scope)) => {
243                m.record_ratelimit_hit(scope);
244                return finish(
245                    m,
246                    &method,
247                    &path,
248                    ip,
249                    started,
250                    "rate_limited",
251                    text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
252                );
253            }
254            Some(Admit::Error) => {
255                return finish(
256                    m,
257                    &method,
258                    &path,
259                    ip,
260                    started,
261                    "limiter_error",
262                    text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
263                );
264            }
265            Some(Admit::Allowed) | None => {}
266        }
267    }
268
269    // 4) Method allowlist.
270    let allow = &rt.cfg.validation.allow_methods;
271    if !allow.is_empty()
272        && !allow
273            .iter()
274            .any(|x| x.eq_ignore_ascii_case(method.as_str()))
275    {
276        return finish(
277            m,
278            &method,
279            &path,
280            ip,
281            started,
282            "method_not_allowed",
283            text(StatusCode::METHOD_NOT_ALLOWED, "Method Not Allowed"),
284        );
285    }
286
287    // 5) Buffer the body up to the configured limit.
288    let (parts, body) = req.into_parts();
289    let body_bytes = match axum::body::to_bytes(body, rt.max_body).await {
290        Ok(b) => b,
291        Err(_) => {
292            return finish(
293                m,
294                &method,
295                &path,
296                ip,
297                started,
298                "payload_too_large",
299                text(StatusCode::PAYLOAD_TOO_LARGE, "Payload Too Large"),
300            )
301        }
302    };
303    // Request (ingress) size for managed-mode usage, captured before the body is forwarded upstream.
304    let ingress_bytes = header_bytes(&parts.headers).saturating_add(body_bytes.len());
305
306    // 6) WAF-lite input inspection. A no-op unless `waf.mode` is report/block. The body is
307    //    already buffered above, so inspecting it adds no extra read. On a match: `block` mode
308    //    returns 403; `report` mode logs + counts and forwards. Both record the hit so a
309    //    report-only rollout shows up in `edgeguard_waf_hits_total`.
310    if let Some(hit) = rt.waf.evaluate(&path, &parts.headers, &body_bytes) {
311        m.record_waf_hit(hit.class);
312        match rt.waf.mode() {
313            WafMode::Block => {
314                warn!(
315                    rule = %hit.rule_id,
316                    class = hit.class,
317                    location = hit.location,
318                    client_ip = %ip,
319                    path = %path,
320                    "WAF blocked request"
321                );
322                return finish(
323                    m,
324                    &method,
325                    &path,
326                    ip,
327                    started,
328                    "forbidden",
329                    text(StatusCode::FORBIDDEN, "Forbidden"),
330                );
331            }
332            WafMode::Report => warn!(
333                rule = %hit.rule_id,
334                class = hit.class,
335                location = hit.location,
336                client_ip = %ip,
337                path = %path,
338                "WAF rule matched (report-only)"
339            ),
340            // `evaluate` returns `None` when off, so this arm is unreachable; kept for
341            // exhaustiveness.
342            WafMode::Off => {}
343        }
344    }
345
346    // 7) Build the upstream request.
347    let uri = format!("{}{}", rt.upstream_base, path);
348    let mut up = Request::builder().method(parts.method.clone()).uri(&uri);
349    {
350        let headers = up.headers_mut().expect("builder headers");
351        // Drop hop-by-hop headers (the fixed set plus any named by `Connection`) before
352        // forwarding, so they don't leak across the proxy boundary.
353        let mut forwarded = parts.headers.clone();
354        strip_hop_by_hop(&mut forwarded);
355        for (name, value) in forwarded.iter() {
356            if name == header::HOST {
357                continue; // let the client set Host for the upstream
358            }
359            headers.insert(name.clone(), value.clone());
360        }
361        // Standard forwarding headers.
362        if let Ok(v) = HeaderValue::from_str(&ip.to_string()) {
363            headers.insert(HeaderName::from_static("x-forwarded-for"), v);
364        }
365        headers.insert(
366            HeaderName::from_static("x-forwarded-proto"),
367            HeaderValue::from_static(forwarded_proto(&rt.cfg, &parts.headers)),
368        );
369    }
370
371    let upstream_req = match up.body(Full::new(body_bytes)) {
372        Ok(r) => r,
373        Err(e) => {
374            warn!(error = %e, "failed to build upstream request");
375            return finish(
376                m,
377                &method,
378                &path,
379                ip,
380                started,
381                "bad_gateway",
382                text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
383            );
384        }
385    };
386
387    // 8) Forward and collect the response under a single deadline, so a stalled upstream
388    //    can't pin this task. `None` => no timeout (validation.upstream_timeout = "0").
389    let deadline = rt.upstream_timeout.map(|d| tokio::time::Instant::now() + d);
390    let timed_out = || {
391        warn!(upstream = %uri, "upstream timed out");
392        text(StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout")
393    };
394
395    let upstream_resp = match within(deadline, state.client.request(upstream_req)).await {
396        Ok(Ok(r)) => r,
397        Ok(Err(e)) => {
398            warn!(error = %e, upstream = %uri, "upstream unreachable");
399            return finish(
400                m,
401                &method,
402                &path,
403                ip,
404                started,
405                "upstream_error",
406                text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
407            );
408        }
409        Err(_) => {
410            return finish(
411                m,
412                &method,
413                &path,
414                ip,
415                started,
416                "upstream_timeout",
417                timed_out(),
418            )
419        }
420    };
421
422    let (mut resp_parts, resp_body) = upstream_resp.into_parts();
423
424    // 8a) SSE passthrough: forward a `text/event-stream` response frame-by-frame instead of
425    //     buffering the whole body, so the client sees events as they arrive (time-to-first-byte
426    //     is preserved). The buffering path below would hold the entire stream until the upstream
427    //     finished, which defeats SSE. On a streamed body the `max_response_body` cap and the
428    //     body-read deadline don't apply — the connect/first-byte `upstream_timeout` already
429    //     bounded time-to-headers — and egress bytes are tallied by `CountingBody` as frames flow.
430    //     Response hardening is headers-only, so it stays correct on a streaming body.
431    if rt.stream_passthrough && is_event_stream(&resp_parts.headers) {
432        strip_hop_by_hop(&mut resp_parts.headers);
433        resp_parts.headers.remove(header::CONTENT_LENGTH);
434        let header_egress = header_bytes(&resp_parts.headers);
435        let body = Body::new(CountingBody::new(
436            resp_body,
437            Arc::clone(m),
438            ingress_bytes,
439            header_egress,
440        ));
441        let mut response = Response::from_parts(resp_parts, body);
442        harden_response(&rt.cfg, &mut response);
443        return finish(m, &method, &path, ip, started, "ok", response);
444    }
445
446    // Buffer the upstream body, optionally capped so a huge response can't OOM the proxy.
447    let resp_bytes = if rt.max_response_body > 0 {
448        match within(
449            deadline,
450            Limited::new(resp_body, rt.max_response_body).collect(),
451        )
452        .await
453        {
454            Ok(Ok(c)) => c.to_bytes(),
455            Ok(Err(_)) => {
456                warn!(
457                    limit = rt.max_response_body,
458                    "upstream response exceeded max_response_body"
459                );
460                return finish(
461                    m,
462                    &method,
463                    &path,
464                    ip,
465                    started,
466                    "upstream_body_too_large",
467                    text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
468                );
469            }
470            Err(_) => {
471                return finish(
472                    m,
473                    &method,
474                    &path,
475                    ip,
476                    started,
477                    "upstream_timeout",
478                    timed_out(),
479                )
480            }
481        }
482    } else {
483        match within(deadline, resp_body.collect()).await {
484            Ok(Ok(c)) => c.to_bytes(),
485            Ok(Err(e)) => {
486                warn!(error = %e, "failed reading upstream body");
487                return finish(
488                    m,
489                    &method,
490                    &path,
491                    ip,
492                    started,
493                    "upstream_body_error",
494                    text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
495                );
496            }
497            Err(_) => {
498                return finish(
499                    m,
500                    &method,
501                    &path,
502                    ip,
503                    started,
504                    "upstream_timeout",
505                    timed_out(),
506                )
507            }
508        }
509    };
510
511    // The body was rebuffered, so let the server recompute framing; strip hop-by-hop headers
512    // (incl. any named by `Connection`) so they don't leak downstream.
513    strip_hop_by_hop(&mut resp_parts.headers);
514    resp_parts.headers.remove(header::CONTENT_LENGTH);
515
516    // Managed-mode usage: this is the proxied path, where both bodies are buffered, so the byte
517    // counts are exact. (`add_usage_request` is recorded for every request in `finish`.)
518    m.add_usage_bytes(
519        ingress_bytes,
520        header_bytes(&resp_parts.headers).saturating_add(resp_bytes.len()),
521    );
522
523    let mut response = Response::from_parts(resp_parts, Body::from(resp_bytes));
524    harden_response(&rt.cfg, &mut response);
525
526    finish(m, &method, &path, ip, started, "ok", response)
527}
528
529/// Readiness probe. Returns `200` only if the upstream accepts a TCP connection, so a
530/// platform's readiness check reflects whether EdgeGuard can actually serve traffic — not
531/// merely that the process booted. `503` while the upstream is unreachable. (Liveness, i.e.
532/// "is EdgeGuard itself up", is the separate unconditional `/__edgeguard/health`.)
533pub async fn ready(State(state): State<AppState>) -> StatusCode {
534    let rt = state.runtime.load();
535    let Some((host, port)) = rt.cfg.upstream_probe_addr() else {
536        return StatusCode::SERVICE_UNAVAILABLE;
537    };
538    match tokio::time::timeout(
539        Duration::from_secs(2),
540        TcpStream::connect((host.as_str(), port)),
541    )
542    .await
543    {
544        Ok(Ok(_)) => StatusCode::OK,
545        _ => StatusCode::SERVICE_UNAVAILABLE,
546    }
547}
548
549/// Prometheus scrape endpoint (`GET /__edgeguard/metrics`). Like health/ready, it is a
550/// dedicated route outside the proxy fallback, so it is not subject to auth or rate limits —
551/// restrict access to `/__edgeguard/*` at the network layer if that matters in your setup.
552pub async fn metrics_handler(State(state): State<AppState>) -> Response<Body> {
553    let body = state.metrics.render();
554    let mut resp = Response::new(Body::from(body));
555    resp.headers_mut().insert(
556        header::CONTENT_TYPE,
557        HeaderValue::from_static("text/plain; version=0.0.4; charset=utf-8"),
558    );
559    resp
560}
561
562/// CSP violation report sink (`POST /__edgeguard/csp-report`). Browsers POST a JSON report
563/// here when `headers.csp_report_uri` points at it; we count and log it, then `204`.
564pub async fn csp_report(State(state): State<AppState>, body: Bytes) -> StatusCode {
565    state.metrics.record_csp_report();
566    // Managed mode: forward the raw report to the control plane (fire-and-forget, so the browser's
567    // 204 is never delayed by an outbound call). Only when a control plane is configured and
568    // `forward_csp` is on.
569    if let Some(cp) = &state.cp {
570        if state.runtime.load().cfg.control_plane.forward_csp {
571            let cp = cp.clone();
572            let raw = body.clone();
573            tokio::spawn(async move { cp.forward_csp(&raw).await });
574        }
575    }
576    // This endpoint is unauthenticated and a report can carry the full document URL,
577    // referrer, and query strings — logging the whole blob at `info` is both a privacy leak
578    // and a log-flood vector. Record only the directive that fired, at `debug`.
579    match serde_json::from_slice::<serde_json::Value>(&body) {
580        Ok(report) => {
581            let directive = report
582                .get("csp-report")
583                .and_then(|r| {
584                    r.get("violated-directive")
585                        .or_else(|| r.get("effective-directive"))
586                })
587                .and_then(|v| v.as_str())
588                .unwrap_or("unknown");
589            debug!(target: "edgeguard::csp", directive, "CSP violation report");
590        }
591        Err(_) => warn!(
592            bytes = body.len(),
593            "CSP violation report with an unparseable body"
594        ),
595    }
596    StatusCode::NO_CONTENT
597}
598
599/// Resolve the client IP. The peer socket address is authoritative; `X-Forwarded-For`
600/// (first hop) is honored only when `trust_forwarded` is set, because a directly
601/// reachable client can otherwise spoof it to forge their identity.
602fn client_ip(headers: &HeaderMap, peer: SocketAddr, trust_forwarded: bool) -> IpAddr {
603    if trust_forwarded {
604        if let Some(xff) = headers.get("x-forwarded-for") {
605            if let Ok(s) = xff.to_str() {
606                if let Some(first) = s.split(',').next() {
607                    if let Ok(ip) = first.trim().parse::<IpAddr>() {
608                        return ip;
609                    }
610                }
611            }
612        }
613    }
614    peer.ip()
615}
616
617/// Total size of the request headers (sum of name + value bytes), used for the header-size
618/// policy limit. This is an application-layer approximation of the on-wire header size.
619fn header_bytes(headers: &HeaderMap) -> usize {
620    headers
621        .iter()
622        .map(|(name, value)| name.as_str().len() + value.as_bytes().len())
623        .sum()
624}
625
626/// True if the response is a Server-Sent Events stream (`Content-Type: text/event-stream`,
627/// ignoring any `; charset=…` parameter and leading whitespace). The signal we use to forward a
628/// response unbuffered when `validation.stream_passthrough` is on.
629fn is_event_stream(headers: &HeaderMap) -> bool {
630    headers
631        .get(header::CONTENT_TYPE)
632        .and_then(|v| v.to_str().ok())
633        .map(|v| {
634            v.split(';')
635                .next()
636                .map(str::trim)
637                .map(|ct| ct.eq_ignore_ascii_case("text/event-stream"))
638                .unwrap_or(false)
639        })
640        .unwrap_or(false)
641}
642
643/// Wraps a streaming upstream body to tally egress bytes (response headers + each data frame)
644/// and report them to managed-mode usage when the body is dropped — i.e. after the final frame
645/// is sent, or earlier if the client disconnects mid-stream (we count what actually went out).
646/// Used for SSE passthrough: the body isn't buffered, so the exact byte count the buffered path
647/// takes up front can only be accumulated as frames flow.
648struct CountingBody<B> {
649    inner: B,
650    metrics: Arc<Metrics>,
651    ingress: usize,
652    /// Running egress total: response header bytes, then each data frame as it passes.
653    egress: usize,
654}
655
656impl<B> CountingBody<B> {
657    fn new(inner: B, metrics: Arc<Metrics>, ingress: usize, header_egress: usize) -> Self {
658        Self {
659            inner,
660            metrics,
661            ingress,
662            egress: header_egress,
663        }
664    }
665}
666
667impl<B> HttpBody for CountingBody<B>
668where
669    B: HttpBody<Data = Bytes> + Unpin,
670{
671    type Data = Bytes;
672    type Error = B::Error;
673
674    fn poll_frame(
675        mut self: Pin<&mut Self>,
676        cx: &mut Context<'_>,
677    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
678        let this = self.as_mut().get_mut();
679        let polled = Pin::new(&mut this.inner).poll_frame(cx);
680        if let Poll::Ready(Some(Ok(frame))) = &polled {
681            if let Some(data) = frame.data_ref() {
682                this.egress = this.egress.saturating_add(data.len());
683            }
684        }
685        polled
686    }
687
688    fn is_end_stream(&self) -> bool {
689        self.inner.is_end_stream()
690    }
691
692    fn size_hint(&self) -> SizeHint {
693        self.inner.size_hint()
694    }
695}
696
697impl<B> Drop for CountingBody<B> {
698    fn drop(&mut self) {
699        self.metrics.add_usage_bytes(self.ingress, self.egress);
700    }
701}
702
703/// Remove hop-by-hop headers so they don't leak across the proxy boundary (RFC 7230 §6.1):
704/// the fixed [`HOP_BY_HOP`] set plus any header *named* in a `Connection` header. Applied in
705/// both directions (request to upstream, response to client).
706fn strip_hop_by_hop(headers: &mut HeaderMap) {
707    // Header names listed in any `Connection` header are connection-specific; collect them
708    // before mutating (the borrow of `headers` must end before we remove).
709    let connection_named: Vec<HeaderName> = headers
710        .get_all(header::CONNECTION)
711        .iter()
712        .filter_map(|v| v.to_str().ok())
713        .flat_map(|v| v.split(','))
714        .filter_map(|token| HeaderName::from_bytes(token.trim().as_bytes()).ok())
715        .collect();
716    for name in HOP_BY_HOP {
717        headers.remove(*name);
718    }
719    for name in connection_named {
720        headers.remove(name);
721    }
722}
723
724/// Decide the `X-Forwarded-Proto` to send upstream. If EdgeGuard terminates TLS, the client
725/// hop is HTTPS. Otherwise, behind a trusted edge (`trust_forwarded_for`) we preserve the
726/// proto the edge reported (falling back to `http`); an untrusted client's `X-Forwarded-Proto`
727/// is never honored, mirroring the client-IP trust model. Returns a `'static` token so the
728/// caller can build a `HeaderValue` without fallible parsing.
729fn forwarded_proto(cfg: &Config, headers: &HeaderMap) -> &'static str {
730    if cfg.tls.enabled {
731        return "https";
732    }
733    if cfg.server.trust_forwarded_for {
734        if let Some(value) = headers
735            .get("x-forwarded-proto")
736            .and_then(|v| v.to_str().ok())
737        {
738            match value.split(',').next().map(str::trim) {
739                Some(p) if p.eq_ignore_ascii_case("https") => return "https",
740                Some(p) if p.eq_ignore_ascii_case("http") => return "http",
741                _ => {}
742            }
743        }
744    }
745    "http"
746}
747
748/// Pick the most specific (longest-prefix) per-route limiter matching `path`, if any.
749fn longest_route<'a>(routes: &'a [RouteLimiter], path: &str) -> Option<&'a RouteLimiter> {
750    routes
751        .iter()
752        .filter(|r| path.starts_with(&r.prefix))
753        .max_by_key(|r| r.prefix.len())
754}
755
/// The HSTS header value EdgeGuard emits when `headers.hsts` is on: a two-year `max-age`
/// (63 072 000 s = 730 days) including subdomains, without the `preload` token. It is a named
/// constant so the live proxy and the static-host config generator ([`crate::generate`])
/// can't drift on it.
pub const HSTS_VALUE: &str = "max-age=63072000; includeSubDomains";
760
761/// The constant security response headers EdgeGuard injects, derived from the `[headers]`
762/// policy. This is the **single source of truth** shared by the live response-hardening path
763/// ([`harden_response`]) and the static-host config generator ([`crate::generate`]), so a
764/// generated `_headers` file / edge-middleware snippet matches exactly what the proxy would add
765/// at runtime. Returns `(name, value)` pairs with canonically-cased names (for readable
766/// generated output); the proxy normalizes the case when it inserts them.
767///
768/// Cookie hardening and leaky-header *stripping* are deliberately **not** here: both rewrite the
769/// upstream's actual response (`Set-Cookie`, `Server`/`X-Powered-By`), which a static file that
770/// can only "always add this header" cannot express. The generator documents that gap; the
771/// WASM worker, which sees the real response, applies them too.
772pub fn security_headers(cfg: &HeadersCfg) -> Vec<(&'static str, String)> {
773    let mut out: Vec<(&'static str, String)> = Vec::with_capacity(6);
774    out.push(("X-Content-Type-Options", "nosniff".to_string()));
775    if !cfg.frame_options.is_empty() {
776        out.push(("X-Frame-Options", cfg.frame_options.clone()));
777    }
778    if !cfg.referrer_policy.is_empty() {
779        out.push(("Referrer-Policy", cfg.referrer_policy.clone()));
780    }
781    if !cfg.permissions_policy.is_empty() {
782        out.push(("Permissions-Policy", cfg.permissions_policy.clone()));
783    }
784    if !cfg.csp.is_empty() {
785        // Append a report-uri directive if configured, and choose enforce vs. report-only.
786        let mut value = cfg.csp.clone();
787        if !cfg.csp_report_uri.is_empty() {
788            value.push_str("; report-uri ");
789            value.push_str(&cfg.csp_report_uri);
790        }
791        let name = if cfg.csp_report_only {
792            "Content-Security-Policy-Report-Only"
793        } else {
794            "Content-Security-Policy"
795        };
796        out.push((name, value));
797    }
798    if cfg.hsts {
799        out.push(("Strict-Transport-Security", HSTS_VALUE.to_string()));
800    }
801    out
802}
803
804/// Inject security headers, harden Set-Cookie, and strip leaky headers.
805fn harden_response(cfg: &Config, resp: &mut Response<Body>) {
806    let h = resp.headers_mut();
807
808    // Inject the constant security headers (shared with the static-host generator via
809    // `security_headers`, so the two never diverge). `from_bytes` normalizes the canonical
810    // casing to lowercase; these names/values are all valid, so the inserts don't fail.
811    for (name, value) in security_headers(&cfg.headers) {
812        if let (Ok(n), Ok(v)) = (
813            HeaderName::from_bytes(name.as_bytes()),
814            HeaderValue::from_str(&value),
815        ) {
816            h.insert(n, v);
817        }
818    }
819
820    // Strip leaky headers.
821    for name in &cfg.headers.strip {
822        if let Ok(hn) = HeaderName::from_bytes(name.as_bytes()) {
823            h.remove(hn);
824        }
825    }
826
827    // Harden cookies: ensure Secure, HttpOnly, and a SameSite default.
828    if cfg.headers.force_secure_cookies {
829        let cookies: Vec<HeaderValue> = h.get_all(header::SET_COOKIE).iter().cloned().collect();
830        if !cookies.is_empty() {
831            h.remove(header::SET_COOKIE);
832            for c in cookies {
833                if let Ok(s) = c.to_str() {
834                    let hardened = harden_cookie(s);
835                    if let Ok(v) = HeaderValue::from_str(&hardened) {
836                        h.append(header::SET_COOKIE, v);
837                    }
838                } else {
839                    h.append(header::SET_COOKIE, c);
840                }
841            }
842        }
843    }
844}
845
/// Return `cookie` (a `Set-Cookie` value) with any missing hardening attributes appended.
///
/// - `Secure` is appended unless the cookie already has it.
/// - `HttpOnly` is appended unless the cookie already has it.
/// - `SameSite=Lax` is appended only when no `SameSite` attribute exists. An existing
///   `SameSite=...` is kept whatever its value.
///
/// Only attribute *names* are inspected: the `;`-separated tokens after the leading
/// `name=value` pair, compared case-insensitively. Inspecting the whole string would let a
/// value like `session=securetoken` look as if it already carried `Secure`, and hardening
/// would be skipped.
///
/// Trailing `;` and whitespace are trimmed before appending, so inputs such as `sid=abc; `
/// don't produce an empty `; ;` attribute.
fn harden_cookie(cookie: &str) -> String {
    let mut has_secure = false;
    let mut has_http_only = false;
    let mut has_same_site = false;
    for attr in cookie.split(';').skip(1) {
        let name = attr.split_once('=').map_or(attr, |(key, _)| key).trim();
        if name.eq_ignore_ascii_case("secure") {
            has_secure = true;
        } else if name.eq_ignore_ascii_case("httponly") {
            has_http_only = true;
        } else if name.eq_ignore_ascii_case("samesite") {
            has_same_site = true;
        }
    }

    let base = cookie.trim_end_matches(|c: char| c == ';' || c.is_ascii_whitespace());
    let mut out = String::with_capacity(base.len() + "; Secure; HttpOnly; SameSite=Lax".len());
    out.push_str(base);
    if !has_secure {
        out.push_str("; Secure");
    }
    if !has_http_only {
        out.push_str("; HttpOnly");
    }
    if !has_same_site {
        out.push_str("; SameSite=Lax");
    }
    out
}
869
870/// Run `fut` bounded by an optional deadline. `None` means no timeout. On success returns
871/// the future's own output; `Err(Elapsed)` if the deadline passed first.
872async fn within<F: Future>(
873    deadline: Option<tokio::time::Instant>,
874    fut: F,
875) -> Result<F::Output, tokio::time::error::Elapsed> {
876    match deadline {
877        Some(dl) => tokio::time::timeout_at(dl, fut).await,
878        None => Ok(fut.await),
879    }
880}
881
882fn text(status: StatusCode, msg: &str) -> Response<Body> {
883    let mut resp = Response::new(Body::from(msg.to_string()));
884    *resp.status_mut() = status;
885    resp.headers_mut().insert(
886        header::CONTENT_TYPE,
887        HeaderValue::from_static("text/plain; charset=utf-8"),
888    );
889    resp
890}
891
892/// Emit a structured access-log line, record metrics, and return the response.
893fn finish(
894    metrics: &Metrics,
895    method: &Method,
896    path: &str,
897    ip: IpAddr,
898    started: Instant,
899    outcome: &str,
900    resp: Response<Body>,
901) -> Response<Body> {
902    let elapsed = started.elapsed();
903    info!(
904        %method,
905        path = %path,
906        client_ip = %ip,
907        status = resp.status().as_u16(),
908        outcome,
909        latency_ms = elapsed.as_millis() as u64,
910        "request"
911    );
912    metrics.record_request(outcome);
913    metrics.observe_latency(elapsed);
914    // Managed mode: count every finished request (proxied or rejected) toward the usage delta.
915    // Cheap (two relaxed atomic adds) and inert unless a control plane drains it for reporting.
916    metrics.add_usage_request();
917    resp
918}
919
// Unit tests for this module's request/response helpers. They call the pure helper functions
// directly, building inputs from `HeaderMap`, `Config` and `HeadersCfg` values; no network or
// upstream is involved.
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `HeaderMap` containing exactly one `name: value` header. Panics (via `unwrap`)
    /// if `value` is not a valid header value; that is acceptable in test-only code.
    fn headers_with(name: &'static str, value: &str) -> HeaderMap {
        let mut h = HeaderMap::new();
        h.insert(name, HeaderValue::from_str(value).unwrap());
        h
    }

    // --- client_ip: the third argument is the X-Forwarded-For trust flag ---

    #[test]
    fn client_ip_ignores_xff_when_untrusted() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        let h = headers_with("x-forwarded-for", "1.2.3.4");
        // Untrusted: a directly reachable client must not be able to spoof its IP.
        assert_eq!(client_ip(&h, peer, false), peer.ip());
    }

    #[test]
    fn client_ip_uses_first_xff_hop_when_trusted() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        // Two hops listed; the leftmost (first) entry is the one expected back.
        let h = headers_with("x-forwarded-for", "1.2.3.4, 5.6.7.8");
        assert_eq!(client_ip(&h, peer, true).to_string(), "1.2.3.4");
    }

    #[test]
    fn client_ip_falls_back_to_peer_on_missing_or_garbage_xff() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        // Even when trusted, an absent or unparseable header yields the socket peer's IP.
        assert_eq!(client_ip(&HeaderMap::new(), peer, true), peer.ip());
        let garbage = headers_with("x-forwarded-for", "not-an-ip");
        assert_eq!(client_ip(&garbage, peer, true), peer.ip());
    }

    #[test]
    fn header_bytes_sums_names_and_values() {
        // Only name and value bytes are counted; no separator or line-ending overhead.
        let mut h = HeaderMap::new();
        h.insert("a", HeaderValue::from_static("bb")); // 1 + 2
        h.insert("ccc", HeaderValue::from_static("dddd")); // 3 + 4
        assert_eq!(header_bytes(&h), 1 + 2 + 3 + 4);
    }

    #[test]
    fn strip_hop_by_hop_removes_fixed_and_connection_named() {
        // Covers both the fixed hop-by-hop names and names listed inside `Connection`.
        let mut h = HeaderMap::new();
        h.insert(
            "connection",
            HeaderValue::from_static("keep-alive, X-Custom-Hop"),
        );
        h.insert("keep-alive", HeaderValue::from_static("timeout=5"));
        h.insert("x-custom-hop", HeaderValue::from_static("secret"));
        h.insert("content-type", HeaderValue::from_static("text/plain"));
        strip_hop_by_hop(&mut h);
        assert!(!h.contains_key("connection"));
        assert!(!h.contains_key("keep-alive"));
        // A header named by Connection is connection-specific and must be dropped.
        assert!(!h.contains_key("x-custom-hop"));
        // An end-to-end header is preserved.
        assert!(h.contains_key("content-type"));
    }

    #[test]
    fn forwarded_proto_reflects_tls_and_trust() {
        let mut cfg = Config::default();

        // We terminate TLS -> always https, regardless of any incoming header.
        cfg.tls.enabled = true;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "http")),
            "https"
        );

        // Plain HTTP, untrusted: http, and an incoming XFP is NOT trusted.
        cfg.tls.enabled = false;
        cfg.server.trust_forwarded_for = false;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "https")),
            "http"
        );

        // Plain HTTP behind a trusted edge: preserve the edge's reported proto.
        cfg.server.trust_forwarded_for = true;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "https")),
            "https"
        );
        // Only the first comma-separated entry is consulted.
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "http, https")),
            "http"
        );
        // Missing or unrecognized -> http.
        assert_eq!(forwarded_proto(&cfg, &HeaderMap::new()), "http");
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "garbage")),
            "http"
        );
    }

    #[test]
    fn longest_route_picks_most_specific_prefix() {
        // The limiter quota is arbitrary; route selection only looks at `prefix`.
        let mk = |p: &str| RouteLimiter {
            prefix: p.to_string(),
            limiter: Arc::new(RateLimiter::keyed(governor::Quota::per_second(
                std::num::NonZeroU32::new(1).unwrap(),
            ))),
        };
        let routes = vec![mk("/api/"), mk("/api/admin/")];
        assert_eq!(
            longest_route(&routes, "/api/admin/users").map(|r| r.prefix.as_str()),
            Some("/api/admin/")
        );
        assert_eq!(
            longest_route(&routes, "/api/things").map(|r| r.prefix.as_str()),
            Some("/api/")
        );
        assert!(longest_route(&routes, "/public").is_none());
    }

    // --- harden_cookie: Secure / HttpOnly / SameSite defaults ---

    #[test]
    fn harden_cookie_adds_missing_flags() {
        let out = harden_cookie("sid=abc");
        assert!(out.contains("; Secure"), "{out}");
        assert!(out.contains("; HttpOnly"), "{out}");
        assert!(out.contains("; SameSite=Lax"), "{out}");
    }

    #[test]
    fn harden_cookie_preserves_existing_attributes() {
        let out = harden_cookie("sid=abc; HttpOnly; SameSite=Strict");
        assert!(out.contains("; Secure"), "{out}");
        assert!(out.contains("SameSite=Strict"), "{out}");
        // existing SameSite isn't overridden, HttpOnly isn't duplicated
        assert!(!out.contains("SameSite=Lax"), "{out}");
        assert_eq!(out.matches("HttpOnly").count(), 1, "{out}");
    }

    #[test]
    fn harden_cookie_value_resembling_an_attr_is_not_skipped() {
        // The value contains the substring "secure" but there is no Secure *attribute*;
        // it must still be added (regression guard for the token-vs-substring fix).
        let out = harden_cookie("session=securetoken");
        assert!(out.contains("; Secure"), "{out}");
    }

    #[test]
    fn security_headers_reflects_config_toggles() {
        // Defaults: every header present, CSP enforced (not report-only).
        let cfg = HeadersCfg::default();
        let got = security_headers(&cfg);
        let names: Vec<&str> = got.iter().map(|(n, _)| *n).collect();
        assert!(names.contains(&"X-Content-Type-Options"));
        assert!(names.contains(&"X-Frame-Options"));
        assert!(names.contains(&"Referrer-Policy"));
        assert!(names.contains(&"Permissions-Policy"));
        assert!(names.contains(&"Content-Security-Policy"));
        assert!(names.contains(&"Strict-Transport-Security"));
        assert!(!names.contains(&"Content-Security-Policy-Report-Only"));

        // Disabling HSTS and clearing frame_options drops exactly those; report-only flips the
        // CSP header name and report_uri is appended to the value.
        let cfg = HeadersCfg {
            hsts: false,
            frame_options: String::new(),
            csp: "default-src 'self'".into(),
            csp_report_only: true,
            csp_report_uri: "/__edgeguard/csp-report".into(),
            ..HeadersCfg::default()
        };
        let got = security_headers(&cfg);
        let map: std::collections::HashMap<&str, String> =
            got.iter().map(|(n, v)| (*n, v.clone())).collect();
        assert!(!map.contains_key("Strict-Transport-Security"));
        assert!(!map.contains_key("X-Frame-Options"));
        assert!(!map.contains_key("Content-Security-Policy"));
        assert_eq!(
            map.get("Content-Security-Policy-Report-Only")
                .map(|s| s.as_str()),
            Some("default-src 'self'; report-uri /__edgeguard/csp-report")
        );
    }
}