Skip to main content

edgeguard/
proxy.rs

//! Request path: reserved `/__edgeguard/*` namespace check -> header-size limit -> rate limit
//! (per-IP / per-route) -> auth -> per-key rate limit -> method allowlist -> body-size limit ->
//! WAF input inspection -> forward to upstream.
//! Response path: hop-by-hop header stripping -> security-header injection (incl. CSP /
//! CSP-report-only) -> strip leaky headers -> cookie hardening.
6//!
7//! All policy lives in [`Runtime`], held behind an [`ArcSwap`] so a config hot-reload swaps
8//! it atomically without blocking the request path or dropping in-flight connections. The
9//! upstream client and the metric registry sit *outside* the swap so the connection pool and
10//! counters survive a reload.
11
12use std::future::Future;
13use std::net::{IpAddr, SocketAddr};
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16
17use arc_swap::ArcSwap;
18use axum::{
19    body::{Body, Bytes},
20    extract::{ConnectInfo, State},
21    http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode},
22};
23use governor::{clock::DefaultClock, state::keyed::DefaultKeyedStateStore, RateLimiter};
24use http_body_util::{BodyExt, Full, Limited};
25use hyper_util::client::legacy::{connect::HttpConnector, Client};
26use tokio::net::TcpStream;
27use tracing::{debug, info, warn};
28
29use crate::auth::{AuthEngine, Challenge, Decision};
30use crate::config::{Config, HeadersCfg};
31use crate::limiter::{Admit, DistributedLimiter};
32use crate::metrics::Metrics;
33use crate::waf::{WafEngine, WafMode};
34
35pub type KeyedLimiter = RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>;
36/// Rate limiter keyed by the authenticated principal (per-key limiting).
37pub type StrLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
38pub type UpstreamClient = Client<HttpConnector, Full<Bytes>>;
39
40/// Shared, cheaply-cloned handle the router hands to every request. Only the hot-swappable
41/// [`Runtime`] changes on reload; the client and metrics are stable.
42#[derive(Clone)]
43pub struct AppState {
44    pub client: UpstreamClient,
45    pub metrics: Arc<Metrics>,
46    pub runtime: Arc<ArcSwap<Runtime>>,
47    /// Managed-mode control-plane client (`Some` only when `[control_plane]` is enabled). Used to
48    /// forward CSP reports; policy pull + usage reporting run as background tasks in `main`.
49    pub cp: Option<Arc<crate::cp::CpClient>>,
50}
51
52/// A per-route rate-limit override: requests whose path starts with `prefix` use `limiter`.
53pub struct RouteLimiter {
54    pub prefix: String,
55    pub limiter: Arc<KeyedLimiter>,
56}
57
58/// All request-handling policy derived from a [`Config`]. Rebuilt from scratch on reload and
59/// swapped in atomically.
60pub struct Runtime {
61    pub cfg: Arc<Config>,
62    pub upstream_base: Arc<String>,
63    pub auth: AuthEngine,
64    /// WAF-lite input screener. Inert (`evaluate` returns `None`) when `waf.mode = "off"`.
65    pub waf: WafEngine,
66    /// Shared-store (distributed) limiter, `Some` when `ratelimit.store` is `memory`/`redis`.
67    /// When present it replaces the three `governor` limiters below (which are then `None`).
68    pub distributed: Option<DistributedLimiter>,
69    /// Global per-client-IP limiter (`None` when rate limiting is disabled or distributed).
70    pub ip_limiter: Option<Arc<KeyedLimiter>>,
71    /// Per-route limiters (also keyed per IP), checked instead of `ip_limiter` on a match.
72    pub route_limiters: Vec<RouteLimiter>,
73    /// Per-principal limiter (`None` when per-key limiting is disabled or distributed).
74    pub key_limiter: Option<Arc<StrLimiter>>,
75    pub max_body: usize,
76    /// Cap on the buffered upstream response body; `0` means unbounded.
77    pub max_response_body: usize,
78    /// Cap on total request header bytes; `0` means disabled.
79    pub max_header_bytes: usize,
80    /// Max time for the upstream request + body read; `None` disables the timeout.
81    pub upstream_timeout: Option<Duration>,
82}
83
/// Connection-scoped ("hop-by-hop") header names, in lowercase, that a proxy must not relay
/// (RFC 7230 §6.1). Headers additionally *named* by a `Connection` header are handled at
/// runtime by [`strip_hop_by_hop`].
const HOP_BY_HOP: &[&str] = &[
    // Connection management.
    "connection",
    "keep-alive",
    "upgrade",
    // Proxy authentication.
    "proxy-authenticate",
    "proxy-authorization",
    // Transfer coding / trailers.
    "te",
    "trailer",
    "transfer-encoding",
];
95
96pub async fn handle(
97    State(state): State<AppState>,
98    ConnectInfo(peer): ConnectInfo<SocketAddr>,
99    req: Request<Body>,
100) -> Response<Body> {
101    let started = Instant::now();
102    // One atomic load pins a consistent policy snapshot for the whole request, even if a
103    // reload swaps in a new Runtime mid-flight.
104    let rt = state.runtime.load();
105    let m = &state.metrics;
106
107    let method = req.method().clone();
108    let path = req
109        .uri()
110        .path_and_query()
111        .map(|p| p.as_str().to_string())
112        .unwrap_or_else(|| req.uri().path().to_string());
113
114    let ip = client_ip(req.headers(), peer, rt.cfg.server.trust_forwarded_for);
115
116    // Reserve the internal namespace: never forward `/__edgeguard/*` upstream. Registered
117    // internal routes are matched before this fallback, so anything reaching here under that
118    // prefix is an unknown internal path — a `404` from EdgeGuard, not a request leaked to the
119    // app. This is also what keeps the ops endpoints (health/ready/metrics) unserved on the
120    // public listener in public/private split mode, rather than proxying them to the upstream.
121    if req.uri().path().starts_with("/__edgeguard/") {
122        return finish(
123            m,
124            &method,
125            &path,
126            ip,
127            started,
128            "not_found",
129            text(StatusCode::NOT_FOUND, "Not Found"),
130        );
131    }
132
133    // 0) Total request-header-size limit.
134    if rt.max_header_bytes > 0 && header_bytes(req.headers()) > rt.max_header_bytes {
135        return finish(
136            m,
137            &method,
138            &path,
139            ip,
140            started,
141            "header_too_large",
142            text(
143                StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
144                "Request Header Fields Too Large",
145            ),
146        );
147    }
148
149    // 1) Rate limit. A matching per-route override replaces the global per-IP limit. A shared
150    //    store (distributed) limiter, when configured, replaces the in-process limiters; on a
151    //    store error it fails closed (`503`) unless `ratelimit.fail_open` is set.
152    if rt.cfg.ratelimit.enabled {
153        if let Some(d) = &rt.distributed {
154            match d.check_ip_route(ip, &path).await {
155                Admit::Allowed => {}
156                Admit::Limited(scope) => {
157                    m.record_ratelimit_hit(scope);
158                    return finish(
159                        m,
160                        &method,
161                        &path,
162                        ip,
163                        started,
164                        "rate_limited",
165                        text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
166                    );
167                }
168                Admit::Error => {
169                    return finish(
170                        m,
171                        &method,
172                        &path,
173                        ip,
174                        started,
175                        "limiter_error",
176                        text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
177                    );
178                }
179            }
180        } else {
181            let (limiter, scope) = match longest_route(&rt.route_limiters, &path) {
182                Some(r) => (Some(r.limiter.as_ref()), "route"),
183                None => (rt.ip_limiter.as_deref(), "ip"),
184            };
185            if let Some(limiter) = limiter {
186                if limiter.check_key(&ip).is_err() {
187                    m.record_ratelimit_hit(scope);
188                    return finish(
189                        m,
190                        &method,
191                        &path,
192                        ip,
193                        started,
194                        "rate_limited",
195                        text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
196                    );
197                }
198            }
199        }
200    }
201
202    // 2) Authentication. On success we learn the principal for per-key limiting.
203    let principal = match rt.auth.authorize(&rt.cfg.auth, req.headers()).await {
204        Decision::Allow(principal) => principal,
205        Decision::Deny(challenge) => {
206            let mut resp = text(StatusCode::UNAUTHORIZED, "Unauthorized");
207            let challenge_value = match challenge {
208                Challenge::Basic(c) => Some(c),
209                Challenge::Bearer => Some("Bearer".to_string()),
210                Challenge::None => None,
211            };
212            if let Some(c) = challenge_value {
213                if let Ok(v) = HeaderValue::from_str(&c) {
214                    resp.headers_mut().insert(header::WWW_AUTHENTICATE, v);
215                }
216            }
217            return finish(m, &method, &path, ip, started, "unauthorized", resp);
218        }
219    };
220
221    // 3) Per-key rate limit (only for authenticated principals). Routed to the distributed
222    //    limiter when configured, else the in-process per-key limiter.
223    if let Some(principal) = &principal {
224        let key_admit = if let Some(d) = &rt.distributed {
225            Some(d.check_key(principal).await)
226        } else {
227            rt.key_limiter.as_ref().map(|limiter| {
228                if limiter.check_key(principal).is_err() {
229                    Admit::Limited("key")
230                } else {
231                    Admit::Allowed
232                }
233            })
234        };
235        match key_admit {
236            Some(Admit::Limited(scope)) => {
237                m.record_ratelimit_hit(scope);
238                return finish(
239                    m,
240                    &method,
241                    &path,
242                    ip,
243                    started,
244                    "rate_limited",
245                    text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
246                );
247            }
248            Some(Admit::Error) => {
249                return finish(
250                    m,
251                    &method,
252                    &path,
253                    ip,
254                    started,
255                    "limiter_error",
256                    text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
257                );
258            }
259            Some(Admit::Allowed) | None => {}
260        }
261    }
262
263    // 4) Method allowlist.
264    let allow = &rt.cfg.validation.allow_methods;
265    if !allow.is_empty()
266        && !allow
267            .iter()
268            .any(|x| x.eq_ignore_ascii_case(method.as_str()))
269    {
270        return finish(
271            m,
272            &method,
273            &path,
274            ip,
275            started,
276            "method_not_allowed",
277            text(StatusCode::METHOD_NOT_ALLOWED, "Method Not Allowed"),
278        );
279    }
280
281    // 5) Buffer the body up to the configured limit.
282    let (parts, body) = req.into_parts();
283    let body_bytes = match axum::body::to_bytes(body, rt.max_body).await {
284        Ok(b) => b,
285        Err(_) => {
286            return finish(
287                m,
288                &method,
289                &path,
290                ip,
291                started,
292                "payload_too_large",
293                text(StatusCode::PAYLOAD_TOO_LARGE, "Payload Too Large"),
294            )
295        }
296    };
297    // Request (ingress) size for managed-mode usage, captured before the body is forwarded upstream.
298    let ingress_bytes = header_bytes(&parts.headers).saturating_add(body_bytes.len());
299
300    // 6) WAF-lite input inspection. A no-op unless `waf.mode` is report/block. The body is
301    //    already buffered above, so inspecting it adds no extra read. On a match: `block` mode
302    //    returns 403; `report` mode logs + counts and forwards. Both record the hit so a
303    //    report-only rollout shows up in `edgeguard_waf_hits_total`.
304    if let Some(hit) = rt.waf.evaluate(&path, &parts.headers, &body_bytes) {
305        m.record_waf_hit(hit.class);
306        match rt.waf.mode() {
307            WafMode::Block => {
308                warn!(
309                    rule = %hit.rule_id,
310                    class = hit.class,
311                    location = hit.location,
312                    client_ip = %ip,
313                    path = %path,
314                    "WAF blocked request"
315                );
316                return finish(
317                    m,
318                    &method,
319                    &path,
320                    ip,
321                    started,
322                    "forbidden",
323                    text(StatusCode::FORBIDDEN, "Forbidden"),
324                );
325            }
326            WafMode::Report => warn!(
327                rule = %hit.rule_id,
328                class = hit.class,
329                location = hit.location,
330                client_ip = %ip,
331                path = %path,
332                "WAF rule matched (report-only)"
333            ),
334            // `evaluate` returns `None` when off, so this arm is unreachable; kept for
335            // exhaustiveness.
336            WafMode::Off => {}
337        }
338    }
339
340    // 7) Build the upstream request.
341    let uri = format!("{}{}", rt.upstream_base, path);
342    let mut up = Request::builder().method(parts.method.clone()).uri(&uri);
343    {
344        let headers = up.headers_mut().expect("builder headers");
345        // Drop hop-by-hop headers (the fixed set plus any named by `Connection`) before
346        // forwarding, so they don't leak across the proxy boundary.
347        let mut forwarded = parts.headers.clone();
348        strip_hop_by_hop(&mut forwarded);
349        for (name, value) in forwarded.iter() {
350            if name == header::HOST {
351                continue; // let the client set Host for the upstream
352            }
353            headers.insert(name.clone(), value.clone());
354        }
355        // Standard forwarding headers.
356        if let Ok(v) = HeaderValue::from_str(&ip.to_string()) {
357            headers.insert(HeaderName::from_static("x-forwarded-for"), v);
358        }
359        headers.insert(
360            HeaderName::from_static("x-forwarded-proto"),
361            HeaderValue::from_static(forwarded_proto(&rt.cfg, &parts.headers)),
362        );
363    }
364
365    let upstream_req = match up.body(Full::new(body_bytes)) {
366        Ok(r) => r,
367        Err(e) => {
368            warn!(error = %e, "failed to build upstream request");
369            return finish(
370                m,
371                &method,
372                &path,
373                ip,
374                started,
375                "bad_gateway",
376                text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
377            );
378        }
379    };
380
381    // 8) Forward and collect the response under a single deadline, so a stalled upstream
382    //    can't pin this task. `None` => no timeout (validation.upstream_timeout = "0").
383    let deadline = rt.upstream_timeout.map(|d| tokio::time::Instant::now() + d);
384    let timed_out = || {
385        warn!(upstream = %uri, "upstream timed out");
386        text(StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout")
387    };
388
389    let upstream_resp = match within(deadline, state.client.request(upstream_req)).await {
390        Ok(Ok(r)) => r,
391        Ok(Err(e)) => {
392            warn!(error = %e, upstream = %uri, "upstream unreachable");
393            return finish(
394                m,
395                &method,
396                &path,
397                ip,
398                started,
399                "upstream_error",
400                text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
401            );
402        }
403        Err(_) => {
404            return finish(
405                m,
406                &method,
407                &path,
408                ip,
409                started,
410                "upstream_timeout",
411                timed_out(),
412            )
413        }
414    };
415
416    let (mut resp_parts, resp_body) = upstream_resp.into_parts();
417    // Buffer the upstream body, optionally capped so a huge response can't OOM the proxy.
418    let resp_bytes = if rt.max_response_body > 0 {
419        match within(
420            deadline,
421            Limited::new(resp_body, rt.max_response_body).collect(),
422        )
423        .await
424        {
425            Ok(Ok(c)) => c.to_bytes(),
426            Ok(Err(_)) => {
427                warn!(
428                    limit = rt.max_response_body,
429                    "upstream response exceeded max_response_body"
430                );
431                return finish(
432                    m,
433                    &method,
434                    &path,
435                    ip,
436                    started,
437                    "upstream_body_too_large",
438                    text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
439                );
440            }
441            Err(_) => {
442                return finish(
443                    m,
444                    &method,
445                    &path,
446                    ip,
447                    started,
448                    "upstream_timeout",
449                    timed_out(),
450                )
451            }
452        }
453    } else {
454        match within(deadline, resp_body.collect()).await {
455            Ok(Ok(c)) => c.to_bytes(),
456            Ok(Err(e)) => {
457                warn!(error = %e, "failed reading upstream body");
458                return finish(
459                    m,
460                    &method,
461                    &path,
462                    ip,
463                    started,
464                    "upstream_body_error",
465                    text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
466                );
467            }
468            Err(_) => {
469                return finish(
470                    m,
471                    &method,
472                    &path,
473                    ip,
474                    started,
475                    "upstream_timeout",
476                    timed_out(),
477                )
478            }
479        }
480    };
481
482    // The body was rebuffered, so let the server recompute framing; strip hop-by-hop headers
483    // (incl. any named by `Connection`) so they don't leak downstream.
484    strip_hop_by_hop(&mut resp_parts.headers);
485    resp_parts.headers.remove(header::CONTENT_LENGTH);
486
487    // Managed-mode usage: this is the proxied path, where both bodies are buffered, so the byte
488    // counts are exact. (`add_usage_request` is recorded for every request in `finish`.)
489    m.add_usage_bytes(
490        ingress_bytes,
491        header_bytes(&resp_parts.headers).saturating_add(resp_bytes.len()),
492    );
493
494    let mut response = Response::from_parts(resp_parts, Body::from(resp_bytes));
495    harden_response(&rt.cfg, &mut response);
496
497    finish(m, &method, &path, ip, started, "ok", response)
498}
499
500/// Readiness probe. Returns `200` only if the upstream accepts a TCP connection, so a
501/// platform's readiness check reflects whether EdgeGuard can actually serve traffic — not
502/// merely that the process booted. `503` while the upstream is unreachable. (Liveness, i.e.
503/// "is EdgeGuard itself up", is the separate unconditional `/__edgeguard/health`.)
504pub async fn ready(State(state): State<AppState>) -> StatusCode {
505    let rt = state.runtime.load();
506    let Some((host, port)) = rt.cfg.upstream_probe_addr() else {
507        return StatusCode::SERVICE_UNAVAILABLE;
508    };
509    match tokio::time::timeout(
510        Duration::from_secs(2),
511        TcpStream::connect((host.as_str(), port)),
512    )
513    .await
514    {
515        Ok(Ok(_)) => StatusCode::OK,
516        _ => StatusCode::SERVICE_UNAVAILABLE,
517    }
518}
519
520/// Prometheus scrape endpoint (`GET /__edgeguard/metrics`). Like health/ready, it is a
521/// dedicated route outside the proxy fallback, so it is not subject to auth or rate limits —
522/// restrict access to `/__edgeguard/*` at the network layer if that matters in your setup.
523pub async fn metrics_handler(State(state): State<AppState>) -> Response<Body> {
524    let body = state.metrics.render();
525    let mut resp = Response::new(Body::from(body));
526    resp.headers_mut().insert(
527        header::CONTENT_TYPE,
528        HeaderValue::from_static("text/plain; version=0.0.4; charset=utf-8"),
529    );
530    resp
531}
532
533/// CSP violation report sink (`POST /__edgeguard/csp-report`). Browsers POST a JSON report
534/// here when `headers.csp_report_uri` points at it; we count and log it, then `204`.
535pub async fn csp_report(State(state): State<AppState>, body: Bytes) -> StatusCode {
536    state.metrics.record_csp_report();
537    // Managed mode: forward the raw report to the control plane (fire-and-forget, so the browser's
538    // 204 is never delayed by an outbound call). Only when a control plane is configured and
539    // `forward_csp` is on.
540    if let Some(cp) = &state.cp {
541        if state.runtime.load().cfg.control_plane.forward_csp {
542            let cp = cp.clone();
543            let raw = body.clone();
544            tokio::spawn(async move { cp.forward_csp(&raw).await });
545        }
546    }
547    // This endpoint is unauthenticated and a report can carry the full document URL,
548    // referrer, and query strings — logging the whole blob at `info` is both a privacy leak
549    // and a log-flood vector. Record only the directive that fired, at `debug`.
550    match serde_json::from_slice::<serde_json::Value>(&body) {
551        Ok(report) => {
552            let directive = report
553                .get("csp-report")
554                .and_then(|r| {
555                    r.get("violated-directive")
556                        .or_else(|| r.get("effective-directive"))
557                })
558                .and_then(|v| v.as_str())
559                .unwrap_or("unknown");
560            debug!(target: "edgeguard::csp", directive, "CSP violation report");
561        }
562        Err(_) => warn!(
563            bytes = body.len(),
564            "CSP violation report with an unparseable body"
565        ),
566    }
567    StatusCode::NO_CONTENT
568}
569
570/// Resolve the client IP. The peer socket address is authoritative; `X-Forwarded-For`
571/// (first hop) is honored only when `trust_forwarded` is set, because a directly
572/// reachable client can otherwise spoof it to forge their identity.
573fn client_ip(headers: &HeaderMap, peer: SocketAddr, trust_forwarded: bool) -> IpAddr {
574    if trust_forwarded {
575        if let Some(xff) = headers.get("x-forwarded-for") {
576            if let Ok(s) = xff.to_str() {
577                if let Some(first) = s.split(',').next() {
578                    if let Ok(ip) = first.trim().parse::<IpAddr>() {
579                        return ip;
580                    }
581                }
582            }
583        }
584    }
585    peer.ip()
586}
587
588/// Total size of the request headers (sum of name + value bytes), used for the header-size
589/// policy limit. This is an application-layer approximation of the on-wire header size.
590fn header_bytes(headers: &HeaderMap) -> usize {
591    headers
592        .iter()
593        .map(|(name, value)| name.as_str().len() + value.as_bytes().len())
594        .sum()
595}
596
597/// Remove hop-by-hop headers so they don't leak across the proxy boundary (RFC 7230 §6.1):
598/// the fixed [`HOP_BY_HOP`] set plus any header *named* in a `Connection` header. Applied in
599/// both directions (request to upstream, response to client).
600fn strip_hop_by_hop(headers: &mut HeaderMap) {
601    // Header names listed in any `Connection` header are connection-specific; collect them
602    // before mutating (the borrow of `headers` must end before we remove).
603    let connection_named: Vec<HeaderName> = headers
604        .get_all(header::CONNECTION)
605        .iter()
606        .filter_map(|v| v.to_str().ok())
607        .flat_map(|v| v.split(','))
608        .filter_map(|token| HeaderName::from_bytes(token.trim().as_bytes()).ok())
609        .collect();
610    for name in HOP_BY_HOP {
611        headers.remove(*name);
612    }
613    for name in connection_named {
614        headers.remove(name);
615    }
616}
617
618/// Decide the `X-Forwarded-Proto` to send upstream. If EdgeGuard terminates TLS, the client
619/// hop is HTTPS. Otherwise, behind a trusted edge (`trust_forwarded_for`) we preserve the
620/// proto the edge reported (falling back to `http`); an untrusted client's `X-Forwarded-Proto`
621/// is never honored, mirroring the client-IP trust model. Returns a `'static` token so the
622/// caller can build a `HeaderValue` without fallible parsing.
623fn forwarded_proto(cfg: &Config, headers: &HeaderMap) -> &'static str {
624    if cfg.tls.enabled {
625        return "https";
626    }
627    if cfg.server.trust_forwarded_for {
628        if let Some(value) = headers
629            .get("x-forwarded-proto")
630            .and_then(|v| v.to_str().ok())
631        {
632            match value.split(',').next().map(str::trim) {
633                Some(p) if p.eq_ignore_ascii_case("https") => return "https",
634                Some(p) if p.eq_ignore_ascii_case("http") => return "http",
635                _ => {}
636            }
637        }
638    }
639    "http"
640}
641
642/// Pick the most specific (longest-prefix) per-route limiter matching `path`, if any.
643fn longest_route<'a>(routes: &'a [RouteLimiter], path: &str) -> Option<&'a RouteLimiter> {
644    routes
645        .iter()
646        .filter(|r| path.starts_with(&r.prefix))
647        .max_by_key(|r| r.prefix.len())
648}
649
/// `Strict-Transport-Security` value emitted when `headers.hsts` is on: a two-year
/// `max-age` (63,072,000 s = 2 × 365 × 86,400 s) that also covers subdomains.
///
/// Shared by the live proxy and the static-host config generator ([`crate::generate`]) so
/// the two cannot drift.
pub const HSTS_VALUE: &str = concat!("max-age=63072000", "; ", "includeSubDomains");
654
655/// The constant security response headers EdgeGuard injects, derived from the `[headers]`
656/// policy. This is the **single source of truth** shared by the live response-hardening path
657/// ([`harden_response`]) and the static-host config generator ([`crate::generate`]), so a
658/// generated `_headers` file / edge-middleware snippet matches exactly what the proxy would add
659/// at runtime. Returns `(name, value)` pairs with canonically-cased names (for readable
660/// generated output); the proxy normalizes the case when it inserts them.
661///
662/// Cookie hardening and leaky-header *stripping* are deliberately **not** here: both rewrite the
663/// upstream's actual response (`Set-Cookie`, `Server`/`X-Powered-By`), which a static file that
664/// can only "always add this header" cannot express. The generator documents that gap; the
665/// WASM worker, which sees the real response, applies them too.
666pub fn security_headers(cfg: &HeadersCfg) -> Vec<(&'static str, String)> {
667    let mut out: Vec<(&'static str, String)> = Vec::with_capacity(6);
668    out.push(("X-Content-Type-Options", "nosniff".to_string()));
669    if !cfg.frame_options.is_empty() {
670        out.push(("X-Frame-Options", cfg.frame_options.clone()));
671    }
672    if !cfg.referrer_policy.is_empty() {
673        out.push(("Referrer-Policy", cfg.referrer_policy.clone()));
674    }
675    if !cfg.permissions_policy.is_empty() {
676        out.push(("Permissions-Policy", cfg.permissions_policy.clone()));
677    }
678    if !cfg.csp.is_empty() {
679        // Append a report-uri directive if configured, and choose enforce vs. report-only.
680        let mut value = cfg.csp.clone();
681        if !cfg.csp_report_uri.is_empty() {
682            value.push_str("; report-uri ");
683            value.push_str(&cfg.csp_report_uri);
684        }
685        let name = if cfg.csp_report_only {
686            "Content-Security-Policy-Report-Only"
687        } else {
688            "Content-Security-Policy"
689        };
690        out.push((name, value));
691    }
692    if cfg.hsts {
693        out.push(("Strict-Transport-Security", HSTS_VALUE.to_string()));
694    }
695    out
696}
697
698/// Inject security headers, harden Set-Cookie, and strip leaky headers.
699fn harden_response(cfg: &Config, resp: &mut Response<Body>) {
700    let h = resp.headers_mut();
701
702    // Inject the constant security headers (shared with the static-host generator via
703    // `security_headers`, so the two never diverge). `from_bytes` normalizes the canonical
704    // casing to lowercase; these names/values are all valid, so the inserts don't fail.
705    for (name, value) in security_headers(&cfg.headers) {
706        if let (Ok(n), Ok(v)) = (
707            HeaderName::from_bytes(name.as_bytes()),
708            HeaderValue::from_str(&value),
709        ) {
710            h.insert(n, v);
711        }
712    }
713
714    // Strip leaky headers.
715    for name in &cfg.headers.strip {
716        if let Ok(hn) = HeaderName::from_bytes(name.as_bytes()) {
717            h.remove(hn);
718        }
719    }
720
721    // Harden cookies: ensure Secure, HttpOnly, and a SameSite default.
722    if cfg.headers.force_secure_cookies {
723        let cookies: Vec<HeaderValue> = h.get_all(header::SET_COOKIE).iter().cloned().collect();
724        if !cookies.is_empty() {
725            h.remove(header::SET_COOKIE);
726            for c in cookies {
727                if let Ok(s) = c.to_str() {
728                    let hardened = harden_cookie(s);
729                    if let Ok(v) = HeaderValue::from_str(&hardened) {
730                        h.append(header::SET_COOKIE, v);
731                    }
732                } else {
733                    h.append(header::SET_COOKIE, c);
734                }
735            }
736        }
737    }
738}
739
/// Append `Secure`, `HttpOnly` and `SameSite=Lax` to a `Set-Cookie` value, each only when
/// the cookie does not already carry that attribute (matched case-insensitively).
///
/// Only attribute *names* after the leading `name=value` pair are inspected. Otherwise a value
/// like `session=securetoken` would look as if it already carried `Secure` and would skip
/// hardening. Trailing `;` and whitespace are trimmed before appending, so an input such as
/// `a=b; ` does not grow an empty attribute (`a=b; ; Secure`).
fn harden_cookie(cookie: &str) -> String {
    let has_attr = |wanted: &str| {
        cookie
            .split(';')
            .skip(1)
            .filter_map(|part| part.split('=').next())
            .any(|name| name.trim().eq_ignore_ascii_case(wanted))
    };

    let mut out = cookie
        .trim_end_matches(|c: char| c == ';' || c.is_ascii_whitespace())
        .to_string();
    for (attr, rendered) in [
        ("secure", "; Secure"),
        ("httponly", "; HttpOnly"),
        ("samesite", "; SameSite=Lax"),
    ] {
        if !has_attr(attr) {
            out.push_str(rendered);
        }
    }
    out
}
763
764/// Run `fut` bounded by an optional deadline. `None` means no timeout. On success returns
765/// the future's own output; `Err(Elapsed)` if the deadline passed first.
766async fn within<F: Future>(
767    deadline: Option<tokio::time::Instant>,
768    fut: F,
769) -> Result<F::Output, tokio::time::error::Elapsed> {
770    match deadline {
771        Some(dl) => tokio::time::timeout_at(dl, fut).await,
772        None => Ok(fut.await),
773    }
774}
775
776fn text(status: StatusCode, msg: &str) -> Response<Body> {
777    let mut resp = Response::new(Body::from(msg.to_string()));
778    *resp.status_mut() = status;
779    resp.headers_mut().insert(
780        header::CONTENT_TYPE,
781        HeaderValue::from_static("text/plain; charset=utf-8"),
782    );
783    resp
784}
785
786/// Emit a structured access-log line, record metrics, and return the response.
787fn finish(
788    metrics: &Metrics,
789    method: &Method,
790    path: &str,
791    ip: IpAddr,
792    started: Instant,
793    outcome: &str,
794    resp: Response<Body>,
795) -> Response<Body> {
796    let elapsed = started.elapsed();
797    info!(
798        %method,
799        path = %path,
800        client_ip = %ip,
801        status = resp.status().as_u16(),
802        outcome,
803        latency_ms = elapsed.as_millis() as u64,
804        "request"
805    );
806    metrics.record_request(outcome);
807    metrics.observe_latency(elapsed);
808    // Managed mode: count every finished request (proxied or rejected) toward the usage delta.
809    // Cheap (two relaxed atomic adds) and inert unless a control plane drains it for reporting.
810    metrics.add_usage_request();
811    resp
812}
813
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers in this module: client-IP extraction, header sizing,
    //! hop-by-hop stripping, forwarded-proto derivation, route-limiter selection, cookie
    //! hardening, and security-header generation.
    use super::*;

    /// Build a `HeaderMap` holding a single `name: value` header.
    ///
    /// Panics (via `unwrap`) if `value` is not a valid header value; inputs here are fixed
    /// test literals.
    fn headers_with(name: &'static str, value: &str) -> HeaderMap {
        let mut h = HeaderMap::new();
        h.insert(name, HeaderValue::from_str(value).unwrap());
        h
    }

    #[test]
    fn client_ip_ignores_xff_when_untrusted() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        let h = headers_with("x-forwarded-for", "1.2.3.4");
        // Untrusted: a directly reachable client must not be able to spoof its IP.
        assert_eq!(client_ip(&h, peer, false), peer.ip());
    }

    #[test]
    fn client_ip_uses_first_xff_hop_when_trusted() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        // Trusted: the first X-Forwarded-For entry (`1.2.3.4`) is reported, not the later
        // `5.6.7.8` and not the socket peer.
        let h = headers_with("x-forwarded-for", "1.2.3.4, 5.6.7.8");
        assert_eq!(client_ip(&h, peer, true).to_string(), "1.2.3.4");
    }

    #[test]
    fn client_ip_falls_back_to_peer_on_missing_or_garbage_xff() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        // Even when trusted, an absent or unparseable XFF falls back to the socket peer.
        assert_eq!(client_ip(&HeaderMap::new(), peer, true), peer.ip());
        let garbage = headers_with("x-forwarded-for", "not-an-ip");
        assert_eq!(client_ip(&garbage, peer, true), peer.ip());
    }

    #[test]
    fn header_bytes_sums_names_and_values() {
        // Expected total is just name + value lengths, with no per-header separator overhead.
        let mut h = HeaderMap::new();
        h.insert("a", HeaderValue::from_static("bb")); // 1 + 2
        h.insert("ccc", HeaderValue::from_static("dddd")); // 3 + 4
        assert_eq!(header_bytes(&h), 1 + 2 + 3 + 4);
    }

    #[test]
    fn strip_hop_by_hop_removes_fixed_and_connection_named() {
        // `Connection` lists `X-Custom-Hop` in mixed case; the lowercase `x-custom-hop`
        // header must still be recognised and removed.
        let mut h = HeaderMap::new();
        h.insert(
            "connection",
            HeaderValue::from_static("keep-alive, X-Custom-Hop"),
        );
        h.insert("keep-alive", HeaderValue::from_static("timeout=5"));
        h.insert("x-custom-hop", HeaderValue::from_static("secret"));
        h.insert("content-type", HeaderValue::from_static("text/plain"));
        strip_hop_by_hop(&mut h);
        assert!(!h.contains_key("connection"));
        assert!(!h.contains_key("keep-alive"));
        // A header named by Connection is connection-specific and must be dropped.
        assert!(!h.contains_key("x-custom-hop"));
        // An end-to-end header is preserved.
        assert!(h.contains_key("content-type"));
    }

    #[test]
    fn forwarded_proto_reflects_tls_and_trust() {
        let mut cfg = Config::default();

        // We terminate TLS -> always https, regardless of any incoming header.
        cfg.tls.enabled = true;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "http")),
            "https"
        );

        // Plain HTTP, untrusted: http, and an incoming XFP is NOT trusted.
        cfg.tls.enabled = false;
        cfg.server.trust_forwarded_for = false;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "https")),
            "http"
        );

        // Plain HTTP behind a trusted edge: preserve the edge's reported proto.
        cfg.server.trust_forwarded_for = true;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "https")),
            "https"
        );
        // A multi-valued XFP resolves to its first entry.
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "http, https")),
            "http"
        );
        // Missing or unrecognized -> http.
        assert_eq!(forwarded_proto(&cfg, &HeaderMap::new()), "http");
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "garbage")),
            "http"
        );
    }

    #[test]
    fn longest_route_picks_most_specific_prefix() {
        // The limiter quota is arbitrary here; the assertions only look at which prefix
        // was selected.
        let mk = |p: &str| RouteLimiter {
            prefix: p.to_string(),
            limiter: Arc::new(RateLimiter::keyed(governor::Quota::per_second(
                std::num::NonZeroU32::new(1).unwrap(),
            ))),
        };
        let routes = vec![mk("/api/"), mk("/api/admin/")];
        // When both prefixes match, the longer one wins.
        assert_eq!(
            longest_route(&routes, "/api/admin/users").map(|r| r.prefix.as_str()),
            Some("/api/admin/")
        );
        assert_eq!(
            longest_route(&routes, "/api/things").map(|r| r.prefix.as_str()),
            Some("/api/")
        );
        // No matching prefix -> no route limiter.
        assert!(longest_route(&routes, "/public").is_none());
    }

    #[test]
    fn harden_cookie_adds_missing_flags() {
        // A bare `name=value` cookie gains all three hardening attributes.
        let out = harden_cookie("sid=abc");
        assert!(out.contains("; Secure"), "{out}");
        assert!(out.contains("; HttpOnly"), "{out}");
        assert!(out.contains("; SameSite=Lax"), "{out}");
    }

    #[test]
    fn harden_cookie_preserves_existing_attributes() {
        let out = harden_cookie("sid=abc; HttpOnly; SameSite=Strict");
        assert!(out.contains("; Secure"), "{out}");
        assert!(out.contains("SameSite=Strict"), "{out}");
        // existing SameSite isn't overridden, HttpOnly isn't duplicated
        assert!(!out.contains("SameSite=Lax"), "{out}");
        assert_eq!(out.matches("HttpOnly").count(), 1, "{out}");
    }

    #[test]
    fn harden_cookie_value_resembling_an_attr_is_not_skipped() {
        // The value contains the substring "secure" but there is no Secure *attribute*;
        // it must still be added (regression guard for the token-vs-substring fix).
        let out = harden_cookie("session=securetoken");
        assert!(out.contains("; Secure"), "{out}");
    }

    #[test]
    fn security_headers_reflects_config_toggles() {
        // Defaults: every header present, CSP enforced (not report-only).
        let cfg = HeadersCfg::default();
        let got = security_headers(&cfg);
        let names: Vec<&str> = got.iter().map(|(n, _)| *n).collect();
        assert!(names.contains(&"X-Content-Type-Options"));
        assert!(names.contains(&"X-Frame-Options"));
        assert!(names.contains(&"Referrer-Policy"));
        assert!(names.contains(&"Permissions-Policy"));
        assert!(names.contains(&"Content-Security-Policy"));
        assert!(names.contains(&"Strict-Transport-Security"));
        assert!(!names.contains(&"Content-Security-Policy-Report-Only"));

        // Disabling HSTS and clearing frame_options drops exactly those; report-only flips the
        // CSP header name and report_uri is appended to the value.
        let cfg = HeadersCfg {
            hsts: false,
            frame_options: String::new(),
            csp: "default-src 'self'".into(),
            csp_report_only: true,
            csp_report_uri: "/__edgeguard/csp-report".into(),
            ..HeadersCfg::default()
        };
        let got = security_headers(&cfg);
        let map: std::collections::HashMap<&str, String> =
            got.iter().map(|(n, v)| (*n, v.clone())).collect();
        assert!(!map.contains_key("Strict-Transport-Security"));
        assert!(!map.contains_key("X-Frame-Options"));
        assert!(!map.contains_key("Content-Security-Policy"));
        assert_eq!(
            map.get("Content-Security-Policy-Report-Only")
                .map(|s| s.as_str()),
            Some("default-src 'self'; report-uri /__edgeguard/csp-report")
        );
    }
}