Skip to main content

edgeguard/
proxy.rs

//! Request path: reserved `/__edgeguard/*` namespace check -> header-size limit -> rate limit
//! (per-IP / per-route) -> auth -> per-key rate limit -> method allowlist -> body-size limit ->
//! WAF input inspection -> forward to upstream.
//! Response path: hop-by-hop header stripping -> security-header injection (incl. CSP /
//! CSP-report-only) -> leaky-header stripping -> cookie hardening.
6//!
7//! All policy lives in [`Runtime`], held behind an [`ArcSwap`] so a config hot-reload swaps
8//! it atomically without blocking the request path or dropping in-flight connections. The
9//! upstream client and the metric registry sit *outside* the swap so the connection pool and
10//! counters survive a reload.
11
12use std::future::Future;
13use std::net::{IpAddr, SocketAddr};
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16
17use arc_swap::ArcSwap;
18use axum::{
19    body::{Body, Bytes},
20    extract::{ConnectInfo, State},
21    http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode},
22};
23use governor::{clock::DefaultClock, state::keyed::DefaultKeyedStateStore, RateLimiter};
24use http_body_util::{BodyExt, Full, Limited};
25use hyper_util::client::legacy::{connect::HttpConnector, Client};
26use tokio::net::TcpStream;
27use tracing::{debug, info, warn};
28
29use crate::auth::{AuthEngine, Challenge, Decision};
30use crate::config::{Config, HeadersCfg};
31use crate::limiter::{Admit, DistributedLimiter};
32use crate::metrics::Metrics;
33use crate::waf::{WafEngine, WafMode};
34
35pub type KeyedLimiter = RateLimiter<IpAddr, DefaultKeyedStateStore<IpAddr>, DefaultClock>;
36/// Rate limiter keyed by the authenticated principal (per-key limiting).
37pub type StrLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>;
38pub type UpstreamClient = Client<HttpConnector, Full<Bytes>>;
39
40/// Shared, cheaply-cloned handle the router hands to every request. Only the hot-swappable
41/// [`Runtime`] changes on reload; the client and metrics are stable.
42#[derive(Clone)]
43pub struct AppState {
44    pub client: UpstreamClient,
45    pub metrics: Arc<Metrics>,
46    pub runtime: Arc<ArcSwap<Runtime>>,
47}
48
49/// A per-route rate-limit override: requests whose path starts with `prefix` use `limiter`.
50pub struct RouteLimiter {
51    pub prefix: String,
52    pub limiter: Arc<KeyedLimiter>,
53}
54
55/// All request-handling policy derived from a [`Config`]. Rebuilt from scratch on reload and
56/// swapped in atomically.
57pub struct Runtime {
58    pub cfg: Arc<Config>,
59    pub upstream_base: Arc<String>,
60    pub auth: AuthEngine,
61    /// WAF-lite input screener. Inert (`evaluate` returns `None`) when `waf.mode = "off"`.
62    pub waf: WafEngine,
63    /// Shared-store (distributed) limiter, `Some` when `ratelimit.store` is `memory`/`redis`.
64    /// When present it replaces the three `governor` limiters below (which are then `None`).
65    pub distributed: Option<DistributedLimiter>,
66    /// Global per-client-IP limiter (`None` when rate limiting is disabled or distributed).
67    pub ip_limiter: Option<Arc<KeyedLimiter>>,
68    /// Per-route limiters (also keyed per IP), checked instead of `ip_limiter` on a match.
69    pub route_limiters: Vec<RouteLimiter>,
70    /// Per-principal limiter (`None` when per-key limiting is disabled or distributed).
71    pub key_limiter: Option<Arc<StrLimiter>>,
72    pub max_body: usize,
73    /// Cap on the buffered upstream response body; `0` means unbounded.
74    pub max_response_body: usize,
75    /// Cap on total request header bytes; `0` means disabled.
76    pub max_header_bytes: usize,
77    /// Max time for the upstream request + body read; `None` disables the timeout.
78    pub upstream_timeout: Option<Duration>,
79}
80
/// Headers that describe a single connection hop rather than the message, and so must never
/// be forwarded across the proxy (RFC 7230 §6.1). Names are lowercase. Headers named inside a
/// message's own `Connection` header are removed separately by [`strip_hop_by_hop`].
const HOP_BY_HOP: &[&str] = &[
    // Connection management.
    "connection",
    "keep-alive",
    "upgrade",
    // Proxy authentication is between the client and the hop that demanded it.
    "proxy-authenticate",
    "proxy-authorization",
    // Transfer codings and trailers are negotiated per hop.
    "te",
    "trailer",
    "transfer-encoding",
];
92
93pub async fn handle(
94    State(state): State<AppState>,
95    ConnectInfo(peer): ConnectInfo<SocketAddr>,
96    req: Request<Body>,
97) -> Response<Body> {
98    let started = Instant::now();
99    // One atomic load pins a consistent policy snapshot for the whole request, even if a
100    // reload swaps in a new Runtime mid-flight.
101    let rt = state.runtime.load();
102    let m = &state.metrics;
103
104    let method = req.method().clone();
105    let path = req
106        .uri()
107        .path_and_query()
108        .map(|p| p.as_str().to_string())
109        .unwrap_or_else(|| req.uri().path().to_string());
110
111    let ip = client_ip(req.headers(), peer, rt.cfg.server.trust_forwarded_for);
112
113    // Reserve the internal namespace: never forward `/__edgeguard/*` upstream. Registered
114    // internal routes are matched before this fallback, so anything reaching here under that
115    // prefix is an unknown internal path — a `404` from EdgeGuard, not a request leaked to the
116    // app. This is also what keeps the ops endpoints (health/ready/metrics) unserved on the
117    // public listener in public/private split mode, rather than proxying them to the upstream.
118    if req.uri().path().starts_with("/__edgeguard/") {
119        return finish(
120            m,
121            &method,
122            &path,
123            ip,
124            started,
125            "not_found",
126            text(StatusCode::NOT_FOUND, "Not Found"),
127        );
128    }
129
130    // 0) Total request-header-size limit.
131    if rt.max_header_bytes > 0 && header_bytes(req.headers()) > rt.max_header_bytes {
132        return finish(
133            m,
134            &method,
135            &path,
136            ip,
137            started,
138            "header_too_large",
139            text(
140                StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE,
141                "Request Header Fields Too Large",
142            ),
143        );
144    }
145
146    // 1) Rate limit. A matching per-route override replaces the global per-IP limit. A shared
147    //    store (distributed) limiter, when configured, replaces the in-process limiters; on a
148    //    store error it fails closed (`503`) unless `ratelimit.fail_open` is set.
149    if rt.cfg.ratelimit.enabled {
150        if let Some(d) = &rt.distributed {
151            match d.check_ip_route(ip, &path).await {
152                Admit::Allowed => {}
153                Admit::Limited(scope) => {
154                    m.record_ratelimit_hit(scope);
155                    return finish(
156                        m,
157                        &method,
158                        &path,
159                        ip,
160                        started,
161                        "rate_limited",
162                        text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
163                    );
164                }
165                Admit::Error => {
166                    return finish(
167                        m,
168                        &method,
169                        &path,
170                        ip,
171                        started,
172                        "limiter_error",
173                        text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
174                    );
175                }
176            }
177        } else {
178            let (limiter, scope) = match longest_route(&rt.route_limiters, &path) {
179                Some(r) => (Some(r.limiter.as_ref()), "route"),
180                None => (rt.ip_limiter.as_deref(), "ip"),
181            };
182            if let Some(limiter) = limiter {
183                if limiter.check_key(&ip).is_err() {
184                    m.record_ratelimit_hit(scope);
185                    return finish(
186                        m,
187                        &method,
188                        &path,
189                        ip,
190                        started,
191                        "rate_limited",
192                        text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
193                    );
194                }
195            }
196        }
197    }
198
199    // 2) Authentication. On success we learn the principal for per-key limiting.
200    let principal = match rt.auth.authorize(&rt.cfg.auth, req.headers()).await {
201        Decision::Allow(principal) => principal,
202        Decision::Deny(challenge) => {
203            let mut resp = text(StatusCode::UNAUTHORIZED, "Unauthorized");
204            let challenge_value = match challenge {
205                Challenge::Basic(c) => Some(c),
206                Challenge::Bearer => Some("Bearer".to_string()),
207                Challenge::None => None,
208            };
209            if let Some(c) = challenge_value {
210                if let Ok(v) = HeaderValue::from_str(&c) {
211                    resp.headers_mut().insert(header::WWW_AUTHENTICATE, v);
212                }
213            }
214            return finish(m, &method, &path, ip, started, "unauthorized", resp);
215        }
216    };
217
218    // 3) Per-key rate limit (only for authenticated principals). Routed to the distributed
219    //    limiter when configured, else the in-process per-key limiter.
220    if let Some(principal) = &principal {
221        let key_admit = if let Some(d) = &rt.distributed {
222            Some(d.check_key(principal).await)
223        } else {
224            rt.key_limiter.as_ref().map(|limiter| {
225                if limiter.check_key(principal).is_err() {
226                    Admit::Limited("key")
227                } else {
228                    Admit::Allowed
229                }
230            })
231        };
232        match key_admit {
233            Some(Admit::Limited(scope)) => {
234                m.record_ratelimit_hit(scope);
235                return finish(
236                    m,
237                    &method,
238                    &path,
239                    ip,
240                    started,
241                    "rate_limited",
242                    text(StatusCode::TOO_MANY_REQUESTS, "Too Many Requests"),
243                );
244            }
245            Some(Admit::Error) => {
246                return finish(
247                    m,
248                    &method,
249                    &path,
250                    ip,
251                    started,
252                    "limiter_error",
253                    text(StatusCode::SERVICE_UNAVAILABLE, "Service Unavailable"),
254                );
255            }
256            Some(Admit::Allowed) | None => {}
257        }
258    }
259
260    // 4) Method allowlist.
261    let allow = &rt.cfg.validation.allow_methods;
262    if !allow.is_empty()
263        && !allow
264            .iter()
265            .any(|x| x.eq_ignore_ascii_case(method.as_str()))
266    {
267        return finish(
268            m,
269            &method,
270            &path,
271            ip,
272            started,
273            "method_not_allowed",
274            text(StatusCode::METHOD_NOT_ALLOWED, "Method Not Allowed"),
275        );
276    }
277
278    // 5) Buffer the body up to the configured limit.
279    let (parts, body) = req.into_parts();
280    let body_bytes = match axum::body::to_bytes(body, rt.max_body).await {
281        Ok(b) => b,
282        Err(_) => {
283            return finish(
284                m,
285                &method,
286                &path,
287                ip,
288                started,
289                "payload_too_large",
290                text(StatusCode::PAYLOAD_TOO_LARGE, "Payload Too Large"),
291            )
292        }
293    };
294
295    // 6) WAF-lite input inspection. A no-op unless `waf.mode` is report/block. The body is
296    //    already buffered above, so inspecting it adds no extra read. On a match: `block` mode
297    //    returns 403; `report` mode logs + counts and forwards. Both record the hit so a
298    //    report-only rollout shows up in `edgeguard_waf_hits_total`.
299    if let Some(hit) = rt.waf.evaluate(&path, &parts.headers, &body_bytes) {
300        m.record_waf_hit(hit.class);
301        match rt.waf.mode() {
302            WafMode::Block => {
303                warn!(
304                    rule = %hit.rule_id,
305                    class = hit.class,
306                    location = hit.location,
307                    client_ip = %ip,
308                    path = %path,
309                    "WAF blocked request"
310                );
311                return finish(
312                    m,
313                    &method,
314                    &path,
315                    ip,
316                    started,
317                    "forbidden",
318                    text(StatusCode::FORBIDDEN, "Forbidden"),
319                );
320            }
321            WafMode::Report => warn!(
322                rule = %hit.rule_id,
323                class = hit.class,
324                location = hit.location,
325                client_ip = %ip,
326                path = %path,
327                "WAF rule matched (report-only)"
328            ),
329            // `evaluate` returns `None` when off, so this arm is unreachable; kept for
330            // exhaustiveness.
331            WafMode::Off => {}
332        }
333    }
334
335    // 7) Build the upstream request.
336    let uri = format!("{}{}", rt.upstream_base, path);
337    let mut up = Request::builder().method(parts.method.clone()).uri(&uri);
338    {
339        let headers = up.headers_mut().expect("builder headers");
340        // Drop hop-by-hop headers (the fixed set plus any named by `Connection`) before
341        // forwarding, so they don't leak across the proxy boundary.
342        let mut forwarded = parts.headers.clone();
343        strip_hop_by_hop(&mut forwarded);
344        for (name, value) in forwarded.iter() {
345            if name == header::HOST {
346                continue; // let the client set Host for the upstream
347            }
348            headers.insert(name.clone(), value.clone());
349        }
350        // Standard forwarding headers.
351        if let Ok(v) = HeaderValue::from_str(&ip.to_string()) {
352            headers.insert(HeaderName::from_static("x-forwarded-for"), v);
353        }
354        headers.insert(
355            HeaderName::from_static("x-forwarded-proto"),
356            HeaderValue::from_static(forwarded_proto(&rt.cfg, &parts.headers)),
357        );
358    }
359
360    let upstream_req = match up.body(Full::new(body_bytes)) {
361        Ok(r) => r,
362        Err(e) => {
363            warn!(error = %e, "failed to build upstream request");
364            return finish(
365                m,
366                &method,
367                &path,
368                ip,
369                started,
370                "bad_gateway",
371                text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
372            );
373        }
374    };
375
376    // 8) Forward and collect the response under a single deadline, so a stalled upstream
377    //    can't pin this task. `None` => no timeout (validation.upstream_timeout = "0").
378    let deadline = rt.upstream_timeout.map(|d| tokio::time::Instant::now() + d);
379    let timed_out = || {
380        warn!(upstream = %uri, "upstream timed out");
381        text(StatusCode::GATEWAY_TIMEOUT, "Gateway Timeout")
382    };
383
384    let upstream_resp = match within(deadline, state.client.request(upstream_req)).await {
385        Ok(Ok(r)) => r,
386        Ok(Err(e)) => {
387            warn!(error = %e, upstream = %uri, "upstream unreachable");
388            return finish(
389                m,
390                &method,
391                &path,
392                ip,
393                started,
394                "upstream_error",
395                text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
396            );
397        }
398        Err(_) => {
399            return finish(
400                m,
401                &method,
402                &path,
403                ip,
404                started,
405                "upstream_timeout",
406                timed_out(),
407            )
408        }
409    };
410
411    let (mut resp_parts, resp_body) = upstream_resp.into_parts();
412    // Buffer the upstream body, optionally capped so a huge response can't OOM the proxy.
413    let resp_bytes = if rt.max_response_body > 0 {
414        match within(
415            deadline,
416            Limited::new(resp_body, rt.max_response_body).collect(),
417        )
418        .await
419        {
420            Ok(Ok(c)) => c.to_bytes(),
421            Ok(Err(_)) => {
422                warn!(
423                    limit = rt.max_response_body,
424                    "upstream response exceeded max_response_body"
425                );
426                return finish(
427                    m,
428                    &method,
429                    &path,
430                    ip,
431                    started,
432                    "upstream_body_too_large",
433                    text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
434                );
435            }
436            Err(_) => {
437                return finish(
438                    m,
439                    &method,
440                    &path,
441                    ip,
442                    started,
443                    "upstream_timeout",
444                    timed_out(),
445                )
446            }
447        }
448    } else {
449        match within(deadline, resp_body.collect()).await {
450            Ok(Ok(c)) => c.to_bytes(),
451            Ok(Err(e)) => {
452                warn!(error = %e, "failed reading upstream body");
453                return finish(
454                    m,
455                    &method,
456                    &path,
457                    ip,
458                    started,
459                    "upstream_body_error",
460                    text(StatusCode::BAD_GATEWAY, "Bad Gateway"),
461                );
462            }
463            Err(_) => {
464                return finish(
465                    m,
466                    &method,
467                    &path,
468                    ip,
469                    started,
470                    "upstream_timeout",
471                    timed_out(),
472                )
473            }
474        }
475    };
476
477    // The body was rebuffered, so let the server recompute framing; strip hop-by-hop headers
478    // (incl. any named by `Connection`) so they don't leak downstream.
479    strip_hop_by_hop(&mut resp_parts.headers);
480    resp_parts.headers.remove(header::CONTENT_LENGTH);
481
482    let mut response = Response::from_parts(resp_parts, Body::from(resp_bytes));
483    harden_response(&rt.cfg, &mut response);
484
485    finish(m, &method, &path, ip, started, "ok", response)
486}
487
488/// Readiness probe. Returns `200` only if the upstream accepts a TCP connection, so a
489/// platform's readiness check reflects whether EdgeGuard can actually serve traffic — not
490/// merely that the process booted. `503` while the upstream is unreachable. (Liveness, i.e.
491/// "is EdgeGuard itself up", is the separate unconditional `/__edgeguard/health`.)
492pub async fn ready(State(state): State<AppState>) -> StatusCode {
493    let rt = state.runtime.load();
494    let Some((host, port)) = rt.cfg.upstream_probe_addr() else {
495        return StatusCode::SERVICE_UNAVAILABLE;
496    };
497    match tokio::time::timeout(
498        Duration::from_secs(2),
499        TcpStream::connect((host.as_str(), port)),
500    )
501    .await
502    {
503        Ok(Ok(_)) => StatusCode::OK,
504        _ => StatusCode::SERVICE_UNAVAILABLE,
505    }
506}
507
508/// Prometheus scrape endpoint (`GET /__edgeguard/metrics`). Like health/ready, it is a
509/// dedicated route outside the proxy fallback, so it is not subject to auth or rate limits —
510/// restrict access to `/__edgeguard/*` at the network layer if that matters in your setup.
511pub async fn metrics_handler(State(state): State<AppState>) -> Response<Body> {
512    let body = state.metrics.render();
513    let mut resp = Response::new(Body::from(body));
514    resp.headers_mut().insert(
515        header::CONTENT_TYPE,
516        HeaderValue::from_static("text/plain; version=0.0.4; charset=utf-8"),
517    );
518    resp
519}
520
521/// CSP violation report sink (`POST /__edgeguard/csp-report`). Browsers POST a JSON report
522/// here when `headers.csp_report_uri` points at it; we count and log it, then `204`.
523pub async fn csp_report(State(state): State<AppState>, body: Bytes) -> StatusCode {
524    state.metrics.record_csp_report();
525    // This endpoint is unauthenticated and a report can carry the full document URL,
526    // referrer, and query strings — logging the whole blob at `info` is both a privacy leak
527    // and a log-flood vector. Record only the directive that fired, at `debug`.
528    match serde_json::from_slice::<serde_json::Value>(&body) {
529        Ok(report) => {
530            let directive = report
531                .get("csp-report")
532                .and_then(|r| {
533                    r.get("violated-directive")
534                        .or_else(|| r.get("effective-directive"))
535                })
536                .and_then(|v| v.as_str())
537                .unwrap_or("unknown");
538            debug!(target: "edgeguard::csp", directive, "CSP violation report");
539        }
540        Err(_) => warn!(
541            bytes = body.len(),
542            "CSP violation report with an unparseable body"
543        ),
544    }
545    StatusCode::NO_CONTENT
546}
547
548/// Resolve the client IP. The peer socket address is authoritative; `X-Forwarded-For`
549/// (first hop) is honored only when `trust_forwarded` is set, because a directly
550/// reachable client can otherwise spoof it to forge their identity.
551fn client_ip(headers: &HeaderMap, peer: SocketAddr, trust_forwarded: bool) -> IpAddr {
552    if trust_forwarded {
553        if let Some(xff) = headers.get("x-forwarded-for") {
554            if let Ok(s) = xff.to_str() {
555                if let Some(first) = s.split(',').next() {
556                    if let Ok(ip) = first.trim().parse::<IpAddr>() {
557                        return ip;
558                    }
559                }
560            }
561        }
562    }
563    peer.ip()
564}
565
566/// Total size of the request headers (sum of name + value bytes), used for the header-size
567/// policy limit. This is an application-layer approximation of the on-wire header size.
568fn header_bytes(headers: &HeaderMap) -> usize {
569    headers
570        .iter()
571        .map(|(name, value)| name.as_str().len() + value.as_bytes().len())
572        .sum()
573}
574
575/// Remove hop-by-hop headers so they don't leak across the proxy boundary (RFC 7230 §6.1):
576/// the fixed [`HOP_BY_HOP`] set plus any header *named* in a `Connection` header. Applied in
577/// both directions (request to upstream, response to client).
578fn strip_hop_by_hop(headers: &mut HeaderMap) {
579    // Header names listed in any `Connection` header are connection-specific; collect them
580    // before mutating (the borrow of `headers` must end before we remove).
581    let connection_named: Vec<HeaderName> = headers
582        .get_all(header::CONNECTION)
583        .iter()
584        .filter_map(|v| v.to_str().ok())
585        .flat_map(|v| v.split(','))
586        .filter_map(|token| HeaderName::from_bytes(token.trim().as_bytes()).ok())
587        .collect();
588    for name in HOP_BY_HOP {
589        headers.remove(*name);
590    }
591    for name in connection_named {
592        headers.remove(name);
593    }
594}
595
596/// Decide the `X-Forwarded-Proto` to send upstream. If EdgeGuard terminates TLS, the client
597/// hop is HTTPS. Otherwise, behind a trusted edge (`trust_forwarded_for`) we preserve the
598/// proto the edge reported (falling back to `http`); an untrusted client's `X-Forwarded-Proto`
599/// is never honored, mirroring the client-IP trust model. Returns a `'static` token so the
600/// caller can build a `HeaderValue` without fallible parsing.
601fn forwarded_proto(cfg: &Config, headers: &HeaderMap) -> &'static str {
602    if cfg.tls.enabled {
603        return "https";
604    }
605    if cfg.server.trust_forwarded_for {
606        if let Some(value) = headers
607            .get("x-forwarded-proto")
608            .and_then(|v| v.to_str().ok())
609        {
610            match value.split(',').next().map(str::trim) {
611                Some(p) if p.eq_ignore_ascii_case("https") => return "https",
612                Some(p) if p.eq_ignore_ascii_case("http") => return "http",
613                _ => {}
614            }
615        }
616    }
617    "http"
618}
619
620/// Pick the most specific (longest-prefix) per-route limiter matching `path`, if any.
621fn longest_route<'a>(routes: &'a [RouteLimiter], path: &str) -> Option<&'a RouteLimiter> {
622    routes
623        .iter()
624        .filter(|r| path.starts_with(&r.prefix))
625        .max_by_key(|r| r.prefix.len())
626}
627
/// `Strict-Transport-Security` value emitted when `headers.hsts` is on:
/// `max-age` of 63 072 000 s (730 days, i.e. two years), covering subdomains.
///
/// It is kept as a single named constant so the live proxy and the static-host config
/// generator ([`crate::generate`]) cannot drift apart.
pub const HSTS_VALUE: &str = "max-age=63072000; includeSubDomains";
632
633/// The constant security response headers EdgeGuard injects, derived from the `[headers]`
634/// policy. This is the **single source of truth** shared by the live response-hardening path
635/// ([`harden_response`]) and the static-host config generator ([`crate::generate`]), so a
636/// generated `_headers` file / edge-middleware snippet matches exactly what the proxy would add
637/// at runtime. Returns `(name, value)` pairs with canonically-cased names (for readable
638/// generated output); the proxy normalizes the case when it inserts them.
639///
640/// Cookie hardening and leaky-header *stripping* are deliberately **not** here: both rewrite the
641/// upstream's actual response (`Set-Cookie`, `Server`/`X-Powered-By`), which a static file that
642/// can only "always add this header" cannot express. The generator documents that gap; the
643/// WASM worker, which sees the real response, applies them too.
644pub fn security_headers(cfg: &HeadersCfg) -> Vec<(&'static str, String)> {
645    let mut out: Vec<(&'static str, String)> = Vec::with_capacity(6);
646    out.push(("X-Content-Type-Options", "nosniff".to_string()));
647    if !cfg.frame_options.is_empty() {
648        out.push(("X-Frame-Options", cfg.frame_options.clone()));
649    }
650    if !cfg.referrer_policy.is_empty() {
651        out.push(("Referrer-Policy", cfg.referrer_policy.clone()));
652    }
653    if !cfg.permissions_policy.is_empty() {
654        out.push(("Permissions-Policy", cfg.permissions_policy.clone()));
655    }
656    if !cfg.csp.is_empty() {
657        // Append a report-uri directive if configured, and choose enforce vs. report-only.
658        let mut value = cfg.csp.clone();
659        if !cfg.csp_report_uri.is_empty() {
660            value.push_str("; report-uri ");
661            value.push_str(&cfg.csp_report_uri);
662        }
663        let name = if cfg.csp_report_only {
664            "Content-Security-Policy-Report-Only"
665        } else {
666            "Content-Security-Policy"
667        };
668        out.push((name, value));
669    }
670    if cfg.hsts {
671        out.push(("Strict-Transport-Security", HSTS_VALUE.to_string()));
672    }
673    out
674}
675
676/// Inject security headers, harden Set-Cookie, and strip leaky headers.
677fn harden_response(cfg: &Config, resp: &mut Response<Body>) {
678    let h = resp.headers_mut();
679
680    // Inject the constant security headers (shared with the static-host generator via
681    // `security_headers`, so the two never diverge). `from_bytes` normalizes the canonical
682    // casing to lowercase; these names/values are all valid, so the inserts don't fail.
683    for (name, value) in security_headers(&cfg.headers) {
684        if let (Ok(n), Ok(v)) = (
685            HeaderName::from_bytes(name.as_bytes()),
686            HeaderValue::from_str(&value),
687        ) {
688            h.insert(n, v);
689        }
690    }
691
692    // Strip leaky headers.
693    for name in &cfg.headers.strip {
694        if let Ok(hn) = HeaderName::from_bytes(name.as_bytes()) {
695            h.remove(hn);
696        }
697    }
698
699    // Harden cookies: ensure Secure, HttpOnly, and a SameSite default.
700    if cfg.headers.force_secure_cookies {
701        let cookies: Vec<HeaderValue> = h.get_all(header::SET_COOKIE).iter().cloned().collect();
702        if !cookies.is_empty() {
703            h.remove(header::SET_COOKIE);
704            for c in cookies {
705                if let Ok(s) = c.to_str() {
706                    let hardened = harden_cookie(s);
707                    if let Ok(v) = HeaderValue::from_str(&hardened) {
708                        h.append(header::SET_COOKIE, v);
709                    }
710                } else {
711                    h.append(header::SET_COOKIE, c);
712                }
713            }
714        }
715    }
716}
717
/// Harden one `Set-Cookie` value. Appends whichever of `; Secure`, `; HttpOnly` and
/// `; SameSite=Lax` it does not already carry. Existing attributes, including an explicit
/// `SameSite=Strict`/`None`, are kept as-is.
///
/// Only attribute *names* are inspected: the tokens after the leading `name=value` pair,
/// compared case-insensitively. Matching against the whole string would make a value like
/// `session=securetoken` look as if it already carried `Secure`. Trailing `;` and whitespace
/// are trimmed before appending, so input such as `"a=b; "` cannot produce an empty
/// attribute (`a=b; ; Secure`).
fn harden_cookie(cookie: &str) -> String {
    let attrs: std::collections::HashSet<String> = cookie
        .split(';')
        .skip(1)
        .filter_map(|p| p.trim().split('=').next())
        .map(|k| k.trim().to_ascii_lowercase())
        .collect();

    let mut out = cookie
        .trim_end_matches(|c: char| c == ';' || c.is_ascii_whitespace())
        .to_string();
    if !attrs.contains("secure") {
        out.push_str("; Secure");
    }
    if !attrs.contains("httponly") {
        out.push_str("; HttpOnly");
    }
    if !attrs.contains("samesite") {
        out.push_str("; SameSite=Lax");
    }
    out
}
741
742/// Run `fut` bounded by an optional deadline. `None` means no timeout. On success returns
743/// the future's own output; `Err(Elapsed)` if the deadline passed first.
744async fn within<F: Future>(
745    deadline: Option<tokio::time::Instant>,
746    fut: F,
747) -> Result<F::Output, tokio::time::error::Elapsed> {
748    match deadline {
749        Some(dl) => tokio::time::timeout_at(dl, fut).await,
750        None => Ok(fut.await),
751    }
752}
753
754fn text(status: StatusCode, msg: &str) -> Response<Body> {
755    let mut resp = Response::new(Body::from(msg.to_string()));
756    *resp.status_mut() = status;
757    resp.headers_mut().insert(
758        header::CONTENT_TYPE,
759        HeaderValue::from_static("text/plain; charset=utf-8"),
760    );
761    resp
762}
763
764/// Emit a structured access-log line, record metrics, and return the response.
765fn finish(
766    metrics: &Metrics,
767    method: &Method,
768    path: &str,
769    ip: IpAddr,
770    started: Instant,
771    outcome: &str,
772    resp: Response<Body>,
773) -> Response<Body> {
774    let elapsed = started.elapsed();
775    info!(
776        %method,
777        path = %path,
778        client_ip = %ip,
779        status = resp.status().as_u16(),
780        outcome,
781        latency_ms = elapsed.as_millis() as u64,
782        "request"
783    );
784    metrics.record_request(outcome);
785    metrics.observe_latency(elapsed);
786    resp
787}
788
// Unit tests for the pure helpers in this file. They cover client-IP resolution, header
// size accounting, hop-by-hop stripping, X-Forwarded-Proto derivation, per-route limiter
// selection, Set-Cookie hardening, and security-header generation.
#[cfg(test)]
mod tests {
    use super::*;

    // Build a HeaderMap containing exactly one header. `name` is converted as a static
    // header name, so it should be lowercase. An invalid `value` panics via `unwrap`.
    fn headers_with(name: &'static str, value: &str) -> HeaderMap {
        let mut h = HeaderMap::new();
        h.insert(name, HeaderValue::from_str(value).unwrap());
        h
    }

    #[test]
    fn client_ip_ignores_xff_when_untrusted() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        let h = headers_with("x-forwarded-for", "1.2.3.4");
        // Untrusted: a directly reachable client must not be able to spoof its IP.
        assert_eq!(client_ip(&h, peer, false), peer.ip());
    }

    // With forwarding trusted, the left-most X-Forwarded-For entry is taken as the client.
    #[test]
    fn client_ip_uses_first_xff_hop_when_trusted() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        let h = headers_with("x-forwarded-for", "1.2.3.4, 5.6.7.8");
        assert_eq!(client_ip(&h, peer, true).to_string(), "1.2.3.4");
    }

    // Even when trusted, an absent or unparseable X-Forwarded-For falls back to the socket peer.
    #[test]
    fn client_ip_falls_back_to_peer_on_missing_or_garbage_xff() {
        let peer: SocketAddr = "203.0.113.9:55000".parse().unwrap();
        assert_eq!(client_ip(&HeaderMap::new(), peer, true), peer.ip());
        let garbage = headers_with("x-forwarded-for", "not-an-ip");
        assert_eq!(client_ip(&garbage, peer, true), peer.ip());
    }

    // The header-size figure counts name bytes plus value bytes, with no separators.
    #[test]
    fn header_bytes_sums_names_and_values() {
        let mut h = HeaderMap::new();
        h.insert("a", HeaderValue::from_static("bb")); // 1 + 2
        h.insert("ccc", HeaderValue::from_static("dddd")); // 3 + 4
        assert_eq!(header_bytes(&h), 1 + 2 + 3 + 4);
    }

    // Fixed hop-by-hop headers and any header listed in `Connection` are removed.
    #[test]
    fn strip_hop_by_hop_removes_fixed_and_connection_named() {
        let mut h = HeaderMap::new();
        h.insert(
            "connection",
            HeaderValue::from_static("keep-alive, X-Custom-Hop"),
        );
        h.insert("keep-alive", HeaderValue::from_static("timeout=5"));
        h.insert("x-custom-hop", HeaderValue::from_static("secret"));
        h.insert("content-type", HeaderValue::from_static("text/plain"));
        strip_hop_by_hop(&mut h);
        assert!(!h.contains_key("connection"));
        assert!(!h.contains_key("keep-alive"));
        // A header named by Connection is connection-specific and must be dropped.
        assert!(!h.contains_key("x-custom-hop"));
        // An end-to-end header is preserved.
        assert!(h.contains_key("content-type"));
    }

    // The outgoing proto depends on local TLS termination first, then on whether the
    // incoming X-Forwarded-Proto is trusted.
    #[test]
    fn forwarded_proto_reflects_tls_and_trust() {
        let mut cfg = Config::default();

        // We terminate TLS -> always https, regardless of any incoming header.
        cfg.tls.enabled = true;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "http")),
            "https"
        );

        // Plain HTTP, untrusted: http, and an incoming XFP is NOT trusted.
        cfg.tls.enabled = false;
        cfg.server.trust_forwarded_for = false;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "https")),
            "http"
        );

        // Plain HTTP behind a trusted edge: preserve the edge's reported proto.
        cfg.server.trust_forwarded_for = true;
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "https")),
            "https"
        );
        // A comma-separated list is resolved from its first entry.
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "http, https")),
            "http"
        );
        // Missing or unrecognized -> http.
        assert_eq!(forwarded_proto(&cfg, &HeaderMap::new()), "http");
        assert_eq!(
            forwarded_proto(&cfg, &headers_with("x-forwarded-proto", "garbage")),
            "http"
        );
    }

    // Route selection is longest-prefix-wins; a path matching no prefix yields None.
    #[test]
    fn longest_route_picks_most_specific_prefix() {
        // The quota is arbitrary; only the prefix matters for selection.
        let mk = |p: &str| RouteLimiter {
            prefix: p.to_string(),
            limiter: Arc::new(RateLimiter::keyed(governor::Quota::per_second(
                std::num::NonZeroU32::new(1).unwrap(),
            ))),
        };
        let routes = vec![mk("/api/"), mk("/api/admin/")];
        assert_eq!(
            longest_route(&routes, "/api/admin/users").map(|r| r.prefix.as_str()),
            Some("/api/admin/")
        );
        assert_eq!(
            longest_route(&routes, "/api/things").map(|r| r.prefix.as_str()),
            Some("/api/")
        );
        assert!(longest_route(&routes, "/public").is_none());
    }

    // A bare cookie gains all three hardening attributes.
    #[test]
    fn harden_cookie_adds_missing_flags() {
        let out = harden_cookie("sid=abc");
        assert!(out.contains("; Secure"), "{out}");
        assert!(out.contains("; HttpOnly"), "{out}");
        assert!(out.contains("; SameSite=Lax"), "{out}");
    }

    // Attributes already present are kept as-is; only the missing ones are appended.
    #[test]
    fn harden_cookie_preserves_existing_attributes() {
        let out = harden_cookie("sid=abc; HttpOnly; SameSite=Strict");
        assert!(out.contains("; Secure"), "{out}");
        assert!(out.contains("SameSite=Strict"), "{out}");
        // existing SameSite isn't overridden, HttpOnly isn't duplicated
        assert!(!out.contains("SameSite=Lax"), "{out}");
        assert_eq!(out.matches("HttpOnly").count(), 1, "{out}");
    }

    #[test]
    fn harden_cookie_value_resembling_an_attr_is_not_skipped() {
        // The value contains the substring "secure" but there is no Secure *attribute*;
        // it must still be added (regression guard for the token-vs-substring fix).
        let out = harden_cookie("session=securetoken");
        assert!(out.contains("; Secure"), "{out}");
    }

    // security_headers() emits (name, value) pairs whose presence and names follow the
    // HeadersCfg toggles.
    #[test]
    fn security_headers_reflects_config_toggles() {
        // Defaults: every header present, CSP enforced (not report-only).
        let cfg = HeadersCfg::default();
        let got = security_headers(&cfg);
        let names: Vec<&str> = got.iter().map(|(n, _)| *n).collect();
        assert!(names.contains(&"X-Content-Type-Options"));
        assert!(names.contains(&"X-Frame-Options"));
        assert!(names.contains(&"Referrer-Policy"));
        assert!(names.contains(&"Permissions-Policy"));
        assert!(names.contains(&"Content-Security-Policy"));
        assert!(names.contains(&"Strict-Transport-Security"));
        assert!(!names.contains(&"Content-Security-Policy-Report-Only"));

        // Disabling HSTS and clearing frame_options drops exactly those; report-only flips the
        // CSP header name and report_uri is appended to the value.
        let cfg = HeadersCfg {
            hsts: false,
            frame_options: String::new(),
            csp: "default-src 'self'".into(),
            csp_report_only: true,
            csp_report_uri: "/__edgeguard/csp-report".into(),
            ..HeadersCfg::default()
        };
        let got = security_headers(&cfg);
        let map: std::collections::HashMap<&str, String> =
            got.iter().map(|(n, v)| (*n, v.clone())).collect();
        assert!(!map.contains_key("Strict-Transport-Security"));
        assert!(!map.contains_key("X-Frame-Options"));
        assert!(!map.contains_key("Content-Security-Policy"));
        assert_eq!(
            map.get("Content-Security-Policy-Report-Only")
                .map(|s| s.as_str()),
            Some("default-src 'self'; report-uri /__edgeguard/csp-report")
        );
    }
}