Skip to main content

anvil_core/
server.rs

1//! Production HTTP serving: applies the `ServerConfig` to an `axum::Router` and
2//! starts it on the configured bind addr, with optional TLS via `axum-server`.
3
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use axum::body::Body;
9use axum::http::{HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode};
10use axum::middleware::Next;
11use axum::Router as AxumRouter;
12use tower_http::compression::predicate::SizeAbove;
13use tower_http::compression::CompressionLayer;
14use tower_http::limit::RequestBodyLimitLayer;
15use tower_http::services::ServeDir;
16use tower_http::set_header::SetResponseHeaderLayer;
17use tower_http::timeout::TimeoutLayer;
18use tower_http::trace::TraceLayer;
19
20use crate::container::Container;
21use crate::server_config::{
22    AccessLogFormat, BasicAuthRule, CorsConfig, HstsConfig, IpAction, IpRule, ProxyRule,
23    RateLimitConfig, RewriteRule, RouteTimeoutRule, ServerConfig, StaticMount, TlsConfig,
24    TrailingSlashAction, TrailingSlashConfig, TrailingSlashMode,
25};
26use crate::Error;
27use axum::extract::ConnectInfo;
28
29/// Apply every layer the server config calls for to the user's web router,
30/// then merge any static-file mounts. Returns a ready-to-serve `axum::Router`.
31pub fn apply_layers(web: AxumRouter<Container>, cfg: &ServerConfig) -> AxumRouter<Container> {
32    let mut router = web;
33
34    // Static file mounts run BEFORE wrapping with body/timeout/compression — they
35    // serve from disk and don't need request body parsing.
36    for (prefix, mount) in &cfg.static_files {
37        router = mount_static(router, prefix, mount);
38    }
39
40    // Compose request-side layers.
41    let body_max = cfg.limits.body_max as usize;
42    router = router
43        .layer(RequestBodyLimitLayer::new(body_max))
44        .layer(TraceLayer::new_for_http());
45
46    if let Some(timeout) = cfg.limits.request_timeout {
47        router = router.layer(TimeoutLayer::new(timeout));
48    }
49
50    // L4-ish concurrency cap. Tower's `GlobalConcurrencyLimitLayer` holds a
51    // semaphore shared across every request; clones of the inner service
52    // share the permit pool. Above the cap, new requests are answered with
53    // 503 Service Unavailable (mapped from the `Overloaded` error) instead
54    // of queueing indefinitely — protects against thundering-herd overload
55    // and lets an upstream LB steer traffic to healthy peers.
56    if let Some(max) = cfg.limits.max_concurrency {
57        let max = max.max(1) as usize;
58        let limiter = Arc::new(tokio::sync::Semaphore::new(max));
59        let limiter_clone = limiter.clone();
60        router = router.layer(axum::middleware::from_fn(
61            move |req: Request<Body>, next: Next| {
62                let limiter = limiter_clone.clone();
63                async move {
64                    match limiter.try_acquire() {
65                        Ok(_permit) => next.run(req).await,
66                        Err(_) => {
67                            let mut resp = Response::new(Body::from(
68                                "service overloaded — too many concurrent requests",
69                            ));
70                            *resp.status_mut() = StatusCode::SERVICE_UNAVAILABLE;
71                            resp
72                        }
73                    }
74                }
75            },
76        ));
77    }
78
79    // Per-route timeout overrides. Walk the request path against each
80    // configured prefix; first match wins. Applied BEFORE the global
81    // timeout layer above so that a slow upload endpoint can extend the
82    // window rather than fighting it.
83    if !cfg.route_timeouts.is_empty() {
84        let rules: Arc<Vec<RouteTimeoutRule>> = Arc::new(cfg.route_timeouts.clone());
85        let rules_clone = rules.clone();
86        router = router.layer(axum::middleware::from_fn(
87            move |req: Request<Body>, next: Next| {
88                let rules = rules_clone.clone();
89                async move { route_timeout_mw(rules, req, next).await }
90            },
91        ));
92    }
93
94    // Virtual-host gating: only accept requests whose Host header matches a
95    // configured `server_name`. Empty `server_name` = match-all.
96    if !cfg.server_name.is_empty() {
97        let allowed = cfg.server_name.clone();
98        router = router.layer(axum::middleware::from_fn(
99            move |req: Request<Body>, next: Next| {
100                let allowed = allowed.clone();
101                async move { host_match_mw(allowed, req, next).await }
102            },
103        ));
104    }
105
106    // Trusted-proxy CIDR list — captured once per `apply_layers` call so each
107    // middleware that reads the client IP can decide whether to honor
108    // X-Forwarded-For. Empty = ignore XFF entirely (safe default for direct-
109    // listen deployments).
110    let trusted: Arc<Vec<ipnet::IpNet>> = Arc::new(cfg.trusted_proxies.ranges.clone());
111
112    // IP allow/deny + basic auth — apply first so unauthorized requests don't
113    // touch any other layer.
114    if !cfg.ip_rules.is_empty() {
115        let rules = Arc::new(cfg.ip_rules.clone());
116        let rules_clone = rules.clone();
117        let trusted_clone = trusted.clone();
118        router = router.layer(axum::middleware::from_fn(
119            move |req: Request<Body>, next: Next| {
120                let rules = rules_clone.clone();
121                let trusted = trusted_clone.clone();
122                async move { ip_rules_mw(rules, trusted, req, next).await }
123            },
124        ));
125    }
126    if !cfg.basic_auth.is_empty() {
127        let rules = Arc::new(compile_basic_auth(&cfg.basic_auth));
128        let rules_clone = rules.clone();
129        router = router.layer(axum::middleware::from_fn(
130            move |req: Request<Body>, next: Next| {
131                let rules = rules_clone.clone();
132                async move { basic_auth_mw(rules, req, next).await }
133            },
134        ));
135    }
136
137    // CORS — apply early. tower-http's CorsLayer would be cleaner, but we want
138    // full TOML control without depending on tower-http's CORS feature spec.
139    if cfg.cors.enabled {
140        let cors = Arc::new(cfg.cors.clone());
141        let cors_clone = cors.clone();
142        router = router.layer(axum::middleware::from_fn(
143            move |req: Request<Body>, next: Next| {
144                let cors = cors_clone.clone();
145                async move { cors_mw(cors, req, next).await }
146            },
147        ));
148    }
149
150    // Reverse-proxy rules — apply BEFORE rewrites so the user can rewrite
151    // upstream-bound requests too.
152    if !cfg.proxies.is_empty() {
153        let proxies = Arc::new(CompiledProxies::compile(&cfg.proxies));
154        let proxies_clone = proxies.clone();
155        router = router.layer(axum::middleware::from_fn(
156            move |req: Request<Body>, next: Next| {
157                let proxies = proxies_clone.clone();
158                async move { proxy_mw(proxies, req, next).await }
159            },
160        ));
161    }
162
163    // Rewrites — apply early so they see the request before other layers.
164    if !cfg.rewrites.is_empty() {
165        let compiled = Arc::new(CompiledRewrites::compile(&cfg.rewrites));
166        let compiled_clone = compiled.clone();
167        router = router.layer(axum::middleware::from_fn(
168            move |req: Request<Body>, next: Next| {
169                let rules = compiled_clone.clone();
170                async move { rewrite_mw(rules, req, next).await }
171            },
172        ));
173    }
174
175    // Trailing-slash policy.
176    if cfg.trailing_slash.mode != TrailingSlashMode::Ignore {
177        let ts = cfg.trailing_slash.clone();
178        router = router.layer(axum::middleware::from_fn(
179            move |req: Request<Body>, next: Next| {
180                let ts = ts.clone();
181                async move { trailing_slash_mw(ts, req, next).await }
182            },
183        ));
184    }
185
186    // Custom error pages: intercept responses with matching status codes and
187    // substitute the configured file contents.
188    if !cfg.error_pages.is_empty() {
189        let pages = Arc::new(load_error_pages(&cfg.error_pages));
190        let pages_clone = pages.clone();
191        router = router.layer(axum::middleware::from_fn(
192            move |req: Request<Body>, next: Next| {
193                let pages = pages_clone.clone();
194                async move { error_pages_mw(pages, req, next).await }
195            },
196        ));
197    }
198
199    // HSTS header for HTTPS responses.
200    if cfg.tls.is_some() && cfg.hsts.enabled {
201        if let Some(value) = build_hsts_header(&cfg.hsts) {
202            router = router.layer(SetResponseHeaderLayer::if_not_present(
203                HeaderName::from_static("strict-transport-security"),
204                value,
205            ));
206        }
207    }
208
209    if cfg.compression.enabled {
210        // tower-http's `CompressionLayer` selects the encoding based on the
211        // client's `Accept-Encoding` header; we just toggle the layer on and
212        // gate via the min-size predicate. Per-algorithm disable lives on the
213        // un-parameterized layer, so we apply it before `compress_when`.
214        let min_size = u16::try_from(cfg.compression.min_size).unwrap_or(u16::MAX);
215        let mut layer = CompressionLayer::new();
216        if !cfg
217            .compression
218            .algorithms
219            .iter()
220            .any(|a| a.eq_ignore_ascii_case("gzip"))
221        {
222            layer = layer.no_gzip();
223        }
224        if !cfg
225            .compression
226            .algorithms
227            .iter()
228            .any(|a| a.eq_ignore_ascii_case("br") || a.eq_ignore_ascii_case("brotli"))
229        {
230            layer = layer.no_br();
231        }
232        if !cfg
233            .compression
234            .algorithms
235            .iter()
236            .any(|a| a.eq_ignore_ascii_case("deflate"))
237        {
238            layer = layer.no_deflate();
239        }
240        let layer = layer.compress_when(SizeAbove::new(min_size));
241        router = router.layer(layer);
242    }
243
244    if cfg.rate_limit.per_ip.is_some() || !cfg.rate_limit.routes.is_empty() {
245        let limiter = Arc::new(RateLimiter::from_config(&cfg.rate_limit));
246        let limiter_clone = limiter.clone();
247        let trusted_clone = trusted.clone();
248        router = router.layer(axum::middleware::from_fn(
249            move |req: Request<Body>, next: Next| {
250                let limiter = limiter_clone.clone();
251                let trusted = trusted_clone.clone();
252                async move { rate_limit_mw(limiter, trusted, req, next).await }
253            },
254        ));
255    }
256
257    if matches!(
258        cfg.access_log.format,
259        AccessLogFormat::Combined | AccessLogFormat::Json
260    ) {
261        let format = cfg.access_log.format;
262        let trusted_clone = trusted.clone();
263        router = router.layer(axum::middleware::from_fn(
264            move |req: Request<Body>, next: Next| {
265                let trusted = trusted_clone.clone();
266                async move { access_log_mw(format, trusted, req, next).await }
267            },
268        ));
269    }
270
271    // Request ID — generated if the inbound request didn't carry one, echoed
272    // back on the response. Threads through `tracing` so each log line for
273    // a given request shares a `request_id` field, which is the basic
274    // building block of trace correlation across services.
275    router = router.layer(axum::middleware::from_fn(request_id_mw));
276
277    // Trusted-proxy header strip — added LAST so it wraps everything else as
278    // the OUTERMOST layer (axum: last .layer() runs first on a request).
279    // Removes proxy-supplied security-sensitive headers from any request
280    // whose direct TCP peer ISN'T in `cfg.trusted_proxies.ranges`. Empty
281    // trusted-list = direct-listen mode, headers stripped from every request.
282    //
283    // Application code can then read `X-Forwarded-Proto`, `X-Real-IP`,
284    // `X-TLS-SPKI-SHA256`, etc. without worrying that a hostile client at
285    // the edge spoofed them.
286    let trusted_for_strip = trusted.clone();
287    router = router.layer(axum::middleware::from_fn(
288        move |req: Request<Body>, next: Next| {
289            let trusted = trusted_for_strip.clone();
290            async move { strip_untrusted_headers_mw(trusted, req, next).await }
291        },
292    ));
293
294    router
295}
296
/// Request headers a reverse proxy conventionally sets on the client's behalf.
///
/// `strip_untrusted_headers_mw` removes all of them from any request whose
/// direct TCP peer is outside `cfg.trusted_proxies.ranges` (an empty list
/// means they are stripped from every request). Besides the standard
/// forwarding headers, the list covers client-certificate headers, including
/// Anvil's paired-device TLS fingerprint.
///
/// Apps that legitimately accept these headers straight from end clients
/// (rare) should add `127.0.0.1/32` or the client's IP range to
/// `[trusted_proxies] ranges`.
const SENSITIVE_PROXY_HEADERS: &[&str] = &[
    // Client address, scheme, host and path prefix as seen by the proxy.
    "x-forwarded-for",
    "x-forwarded-proto",
    "x-forwarded-host",
    "x-forwarded-port",
    "x-forwarded-prefix",
    "x-real-ip",
    "forwarded", // RFC 7239 combined form
    // Client-certificate details forwarded by TLS-terminating proxies.
    "x-tls-spki-sha256", // Anvil paired-device fingerprint
    "x-client-cert",
    "x-ssl-client-cert",
    "x-ssl-client-verify",
    "x-ssl-client-s-dn",
];
318
319async fn strip_untrusted_headers_mw(
320    trusted: Arc<Vec<ipnet::IpNet>>,
321    mut req: Request<Body>,
322    next: Next,
323) -> Response<Body> {
324    let peer: Option<std::net::IpAddr> = req
325        .extensions()
326        .get::<ConnectInfo<SocketAddr>>()
327        .map(|ci| ci.0.ip());
328
329    let peer_trusted = match peer {
330        Some(addr) => !trusted.is_empty() && trusted.iter().any(|net| net.contains(&addr)),
331        None => false,
332    };
333
334    if !peer_trusted {
335        let headers = req.headers_mut();
336        for &name in SENSITIVE_PROXY_HEADERS {
337            headers.remove(name);
338        }
339    }
340
341    next.run(req).await
342}
343
344fn mount_static(
345    router: AxumRouter<Container>,
346    prefix: &str,
347    mount: &StaticMount,
348) -> AxumRouter<Container> {
349    // Note: `ranges` is reserved for a future version of tower-http that exposes
350    // per-instance range toggling. For now ranges are always enabled.
351    let _ = mount.ranges;
352
353    // If the app registered an embedded-asset fetcher for this prefix (the
354    // single-binary distribution path), serve from memory instead of disk.
355    if let Some(fetcher) = crate::embedded::lookup(prefix) {
356        let cache = mount.cache;
357        let route_pat = format!("{}/*path", prefix.trim_end_matches('/'));
358        let nested = AxumRouter::<Container>::new().route(
359            &route_pat,
360            axum::routing::get(
361                move |axum::extract::Path(path): axum::extract::Path<String>,
362                      headers: HeaderMap| async move {
363                    serve_embedded(fetcher, cache, &path, &headers)
364                },
365            ),
366        );
367        return router.merge(nested);
368    }
369
370    let svc = ServeDir::new(&mount.dir);
371
372    let nested = AxumRouter::<Container>::new().nest_service(prefix, svc);
373    let nested = if let Some(cache) = mount.cache {
374        let value = HeaderValue::from_str(&format!("public, max-age={}", cache.as_secs()))
375            .unwrap_or_else(|_| HeaderValue::from_static("public"));
376        nested.layer(SetResponseHeaderLayer::if_not_present(
377            HeaderName::from_static("cache-control"),
378            value,
379        ))
380    } else {
381        nested
382    };
383    router.merge(nested)
384}
385
386/// Embedded-asset request handler: looks up the wildcard `path` in the
387/// registered fetcher, honors `If-None-Match` against the file's ETag, and
388/// returns 200/304/404.
389fn serve_embedded(
390    fetcher: crate::embedded::EmbeddedAssetFetcher,
391    cache: Option<Duration>,
392    path: &str,
393    headers: &HeaderMap,
394) -> Response<Body> {
395    let asset = match fetcher(path) {
396        Some(a) => a,
397        None => return not_found(),
398    };
399
400    if let (Some(client_tag), Some(asset_tag)) = (
401        headers
402            .get(axum::http::header::IF_NONE_MATCH)
403            .and_then(|v| v.to_str().ok()),
404        asset.etag.as_deref(),
405    ) {
406        if etag_matches(client_tag, asset_tag) {
407            let mut resp = Response::builder()
408                .status(StatusCode::NOT_MODIFIED)
409                .body(Body::empty())
410                .expect("304 body");
411            if let Some(d) = cache {
412                if let Ok(v) = HeaderValue::from_str(&format!("public, max-age={}", d.as_secs())) {
413                    resp.headers_mut().insert("cache-control", v);
414                }
415            }
416            return resp;
417        }
418    }
419
420    let mut builder = Response::builder()
421        .status(StatusCode::OK)
422        .header("content-type", asset.content_type.as_str())
423        .header("content-length", asset.data.len());
424    if let Some(tag) = asset.etag.as_deref() {
425        builder = builder.header("etag", quote_etag(tag));
426    }
427    if let Some(d) = cache {
428        builder = builder.header("cache-control", format!("public, max-age={}", d.as_secs()));
429    }
430    builder
431        .body(Body::from(asset.data.into_owned()))
432        .unwrap_or_else(|_| not_found())
433}
434
435fn not_found() -> Response<Body> {
436    Response::builder()
437        .status(StatusCode::NOT_FOUND)
438        .body(Body::from("not found"))
439        .expect("404 body")
440}
441
/// Render a raw ETag as an HTTP entity-tag.
///
/// Values already in entity-tag form are returned unchanged. That covers a
/// strong `"abc"` and a weak `W/"abc"`; the weak form was previously
/// re-wrapped into the invalid `"W/"abc""`. Anything else is wrapped in
/// double quotes: `abc` → `"abc"`.
fn quote_etag(raw: &str) -> String {
    if raw.starts_with('"') || raw.starts_with("W/\"") {
        raw.to_string()
    } else {
        format!("\"{raw}\"")
    }
}
449
/// Weak comparison of a client `If-None-Match` value against the server's
/// ETag (RFC 9110 §13.1.2 uses weak comparison for `If-None-Match`).
///
/// `client` may list several tags separated by commas, or be `*`. Each tag is
/// reduced to its opaque value by trimming whitespace, removing the `W/`
/// weakness prefix and THEN the surrounding quotes. Doing it in the other
/// order left a stray quote on weak tags, so `W/"abc"` never matched. The
/// server tag may be raw (`abc`) or quoted.
fn etag_matches(client: &str, server: &str) -> bool {
    /// Opaque value of one entity-tag: `W/"abc"` / `"abc"` / `abc` → `abc`.
    fn opaque(tag: &str) -> &str {
        let tag = tag.trim();
        tag.strip_prefix("W/").unwrap_or(tag).trim_matches('"')
    }

    let server = opaque(server);
    client
        .split(',')
        .map(opaque)
        .any(|tag| tag == "*" || tag == server)
}
467
468// ─── Serve entry points ─────────────────────────────────────────────────────
469
470pub async fn serve(
471    router: AxumRouter,
472    cfg: &ServerConfig,
473    shutdown: tokio::sync::oneshot::Receiver<()>,
474) -> Result<(), Error> {
475    let addr: SocketAddr = cfg
476        .bind
477        .parse()
478        .map_err(|e| Error::Config(format!("invalid bind addr `{}`: {e}", cfg.bind)))?;
479
480    tracing::info!(%addr, tls = cfg.tls.is_some(), server_name = ?cfg.server_name, "anvil server starting");
481
482    // If a redirect-HTTP listener is configured, spawn it alongside the main listener.
483    let (shutdown_main_tx, shutdown_main_rx) = tokio::sync::oneshot::channel::<()>();
484    let (shutdown_redir_tx, shutdown_redir_rx) = tokio::sync::oneshot::channel::<()>();
485    tokio::spawn(async move {
486        let _ = shutdown.await;
487        let _ = shutdown_main_tx.send(());
488        let _ = shutdown_redir_tx.send(());
489    });
490
491    let redirect_task = cfg.redirect_http.clone().map(|redir| {
492        let target_host = redir
493            .target_host
494            .clone()
495            .or_else(|| cfg.server_name.first().cloned());
496        let permanent = redir.permanent;
497        let bind = redir.bind.clone();
498        tokio::spawn(async move {
499            if let Err(e) =
500                serve_redirect_http(&bind, target_host, permanent, shutdown_redir_rx).await
501            {
502                tracing::warn!(?e, "redirect_http listener exited with error");
503            }
504        })
505    });
506
507    let main_result = if let Some(tls) = &cfg.tls {
508        if tls.acme.is_some() {
509            serve_acme(
510                router,
511                addr,
512                tls,
513                cfg.limits.drain_timeout,
514                shutdown_main_rx,
515            )
516            .await
517        } else {
518            serve_tls(
519                router,
520                addr,
521                tls,
522                cfg.limits.drain_timeout,
523                shutdown_main_rx,
524            )
525            .await
526        }
527    } else {
528        serve_plain(router, addr, shutdown_main_rx).await
529    };
530
531    if let Some(task) = redirect_task {
532        task.abort();
533    }
534
535    main_result
536}
537
538/// Plain-HTTP listener that 30x-redirects every request to its `https://`
539/// equivalent. Used when TLS is on and `redirect_http` is configured.
540async fn serve_redirect_http(
541    bind: &str,
542    target_host: Option<String>,
543    permanent: bool,
544    shutdown: tokio::sync::oneshot::Receiver<()>,
545) -> Result<(), Error> {
546    let addr: SocketAddr = bind
547        .parse()
548        .map_err(|e| Error::Config(format!("invalid redirect_http bind `{bind}`: {e}")))?;
549    tracing::info!(%addr, target_host = ?target_host, permanent, "http→https redirect listener");
550
551    let target_host = Arc::new(target_host);
552    let router: AxumRouter = AxumRouter::new().fallback(axum::routing::any({
553        let target_host = target_host.clone();
554        move |req: Request<Body>| {
555            let target_host = target_host.clone();
556            async move { http_redirect_handler(req, target_host, permanent).await }
557        }
558    }));
559
560    let listener = tokio::net::TcpListener::bind(addr).await?;
561    axum::serve(listener, router)
562        .with_graceful_shutdown(async move {
563            let _ = shutdown.await;
564        })
565        .await?;
566    Ok(())
567}
568
569async fn http_redirect_handler(
570    req: Request<Body>,
571    target_host: Arc<Option<String>>,
572    permanent: bool,
573) -> Response<Body> {
574    let host = target_host.as_ref().clone().unwrap_or_else(|| {
575        req.headers()
576            .get("host")
577            .and_then(|v| v.to_str().ok())
578            .map(String::from)
579            .unwrap_or_default()
580    });
581    let path_and_query = req
582        .uri()
583        .path_and_query()
584        .map(|p| p.as_str().to_string())
585        .unwrap_or_else(|| "/".to_string());
586    let location = format!("https://{host}{path_and_query}");
587
588    let status = if permanent {
589        StatusCode::MOVED_PERMANENTLY
590    } else {
591        StatusCode::FOUND
592    };
593    let mut resp = Response::new(Body::from(format!("Redirecting to {location}\n")));
594    *resp.status_mut() = status;
595    if let Ok(loc) = HeaderValue::from_str(&location) {
596        resp.headers_mut().insert("location", loc);
597    }
598    resp
599}
600
601fn build_hsts_header(cfg: &HstsConfig) -> Option<HeaderValue> {
602    let max_age = cfg.max_age.unwrap_or(Duration::from_secs(86400 * 365));
603    let mut value = format!("max-age={}", max_age.as_secs());
604    if cfg.include_subdomains {
605        value.push_str("; includeSubDomains");
606    }
607    if cfg.preload {
608        value.push_str("; preload");
609    }
610    HeaderValue::from_str(&value).ok()
611}
612
613/// Reject requests whose Host header doesn't match any configured server_name.
614async fn host_match_mw(allowed: Vec<String>, req: Request<Body>, next: Next) -> Response<Body> {
615    let host = req
616        .headers()
617        .get("host")
618        .and_then(|v| v.to_str().ok())
619        .unwrap_or("")
620        .to_string();
621
622    // Strip port for matching: "example.com:8080" → "example.com".
623    let host_no_port = host.split(':').next().unwrap_or("").to_ascii_lowercase();
624
625    if matches_any(&host_no_port, &allowed) {
626        return next.run(req).await;
627    }
628
629    tracing::debug!(host, allowed = ?allowed, "rejected host: no server_name match");
630    let mut resp = Response::new(Body::from(format!(
631        "404 not found (unknown host: {host})\n"
632    )));
633    *resp.status_mut() = StatusCode::NOT_FOUND;
634    resp
635}
636
637fn matches_any(host: &str, patterns: &[String]) -> bool {
638    patterns.iter().any(|pat| matches_pattern(host, pat))
639}
640
/// Match `host` against a `server_name` pattern.
///
/// * `*` matches any host.
/// * `*.example.com` matches any host ending in `.example.com`, such as
///   `a.example.com` or `a.b.example.com`, but not `example.com` itself.
/// * Anything else must equal `host` exactly.
///
/// Only the pattern is ASCII-lowercased. `host` is compared as given, and
/// the callers in this file lowercase it first.
fn matches_pattern(host: &str, pattern: &str) -> bool {
    /// Byte-wise equality where only the pattern side is ASCII-lowercased.
    fn eq_lowered_pattern(host: &[u8], pattern: &[u8]) -> bool {
        host.len() == pattern.len()
            && host
                .iter()
                .zip(pattern)
                .all(|(h, p)| *h == p.to_ascii_lowercase())
    }

    if pattern == "*" {
        return true;
    }
    match pattern.strip_prefix("*.") {
        Some(suffix) => {
            let (host, suffix) = (host.as_bytes(), suffix.as_bytes());
            // Need at least the separating dot in front of the suffix.
            host.len() > suffix.len()
                && host[host.len() - suffix.len() - 1] == b'.'
                && eq_lowered_pattern(&host[host.len() - suffix.len()..], suffix)
        }
        None => eq_lowered_pattern(host.as_bytes(), pattern.as_bytes()),
    }
}
654
655async fn serve_plain(
656    router: AxumRouter,
657    addr: SocketAddr,
658    shutdown: tokio::sync::oneshot::Receiver<()>,
659) -> Result<(), Error> {
660    let listener = tokio::net::TcpListener::bind(addr).await?;
661    axum::serve(
662        listener,
663        router.into_make_service_with_connect_info::<SocketAddr>(),
664    )
665    .with_graceful_shutdown(async move {
666        let _ = shutdown.await;
667    })
668    .await?;
669    Ok(())
670}
671
672async fn serve_tls(
673    router: AxumRouter,
674    addr: SocketAddr,
675    tls: &TlsConfig,
676    drain: Duration,
677    shutdown: tokio::sync::oneshot::Receiver<()>,
678) -> Result<(), Error> {
679    // Path A: single cert pair → fast happy path, no resolver indirection.
680    // Path B: `[[tls.certs]]` entries present → build a custom
681    // `ResolvesServerCert` that picks the cert by ClientHello SNI hostname,
682    // with the top-level cert as the default for unmatched names.
683    let config = if tls.additional_certs.is_empty() {
684        axum_server::tls_rustls::RustlsConfig::from_pem_file(&tls.cert, &tls.key)
685            .await
686            .map_err(|e| Error::Config(format!("tls load: {e}")))?
687    } else {
688        let resolver = build_sni_resolver(tls)
689            .map_err(|e| Error::Config(format!("tls multi-cert load: {e}")))?;
690        let server_config = rustls::ServerConfig::builder()
691            .with_no_client_auth()
692            .with_cert_resolver(Arc::new(resolver));
693        axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(server_config))
694    };
695
696    // Cert hot-reload: spawn a notify watcher on the cert + key paths. On
697    // any change, re-read the PEM files and `reload_from_pem_file` on the
698    // shared RustlsConfig — new TLS handshakes pick up the new cert without
699    // a process restart, dropping the "swap cert → restart server" runbook
700    // ops normally need.
701    let watch_paths = [tls.cert.clone(), tls.key.clone()];
702    let config_for_watch = config.clone();
703    let cert_path = tls.cert.clone();
704    let key_path = tls.key.clone();
705    tokio::task::spawn_blocking(move || {
706        if let Err(e) = watch_tls_certs(config_for_watch, cert_path, key_path, watch_paths) {
707            tracing::warn!(error = %e, "cert hot-reload watcher exited");
708        }
709    });
710
711    let handle = axum_server::Handle::new();
712    let handle_for_shutdown = handle.clone();
713    tokio::spawn(async move {
714        let _ = shutdown.await;
715        handle_for_shutdown.graceful_shutdown(Some(drain));
716    });
717
718    axum_server::bind_rustls(addr, config)
719        .handle(handle)
720        .serve(router.into_make_service_with_connect_info::<SocketAddr>())
721        .await
722        .map_err(|e| Error::Internal(format!("tls serve: {e}")))?;
723    Ok(())
724}
725
726/// ACME-managed TLS serve path. Auto-obtains and rotates certs via Let's
727/// Encrypt (or any other ACME directory) using TLS-ALPN-01 in-process — no
728/// external certbot run required.
729///
730/// **Status:** the `[tls.acme]` schema parses today and apps written
731/// against it are forward-compatible. The runtime implementation is held
732/// back pending a focused PR that pins compatible `rustls-acme` /
733/// `rustls` / `axum-server` versions; the upstream `rustls-acme` 0.13
734/// release has build errors against the rustls version this workspace
735/// pins for the rest of TLS. Until that PR lands, ACME configs surface
736/// as a clear startup error rather than silently no-op'ing.
737async fn serve_acme(
738    _router: AxumRouter,
739    _addr: SocketAddr,
740    tls: &TlsConfig,
741    _drain: Duration,
742    _shutdown: tokio::sync::oneshot::Receiver<()>,
743) -> Result<(), Error> {
744    let acme = tls
745        .acme
746        .as_ref()
747        .expect("serve_acme called without [tls.acme]");
748    Err(Error::Config(format!(
749        "[tls.acme] is configured for {n} domain(s) but ACME runtime support \
750         is pending a follow-up PR (rustls-acme version pin). For now, use \
751         certbot in TLS-ALPN-01 mode and `[tls] cert`/`key` pointing at the \
752         certbot output; cert hot-reload handles renewals without restart.",
753        n = acme.domains.len(),
754    )))
755}
756
757/// SNI cert resolver — picks a `CertifiedKey` based on the ClientHello's
758/// SNI hostname. Falls back to the default cert when no entry matches.
759/// Matches `server_name` patterns the same way as the host-gating middleware:
760/// exact match, `*.example.com` wildcard for one-level subdomains, or `*` for
761/// any host.
762#[derive(Debug)]
763struct SniResolver {
764    /// Pre-compiled `(server_name, CertifiedKey)` pairs in declaration order.
765    /// First match wins, so put the most specific patterns first.
766    entries: Vec<(String, Arc<rustls::sign::CertifiedKey>)>,
767    default_key: Arc<rustls::sign::CertifiedKey>,
768}
769
770impl rustls::server::ResolvesServerCert for SniResolver {
771    fn resolve(
772        &self,
773        client_hello: rustls::server::ClientHello<'_>,
774    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
775        let sni = client_hello
776            .server_name()
777            .unwrap_or("")
778            .to_ascii_lowercase();
779        for (pattern, key) in &self.entries {
780            if matches_pattern(&sni, pattern) {
781                return Some(key.clone());
782            }
783        }
784        Some(self.default_key.clone())
785    }
786}
787
788fn build_sni_resolver(tls: &TlsConfig) -> std::io::Result<SniResolver> {
789    let default_key = load_certified_key(&tls.cert, &tls.key)?;
790    let mut entries = Vec::with_capacity(tls.additional_certs.len());
791    for entry in &tls.additional_certs {
792        let key = load_certified_key(&entry.cert, &entry.key)?;
793        entries.push((entry.server_name.to_ascii_lowercase(), key));
794    }
795    tracing::info!(
796        default_cert = %tls.cert.display(),
797        additional = tls.additional_certs.len(),
798        "tls: SNI resolver active"
799    );
800    Ok(SniResolver {
801        entries,
802        default_key,
803    })
804}
805
806fn load_certified_key(
807    cert_path: &std::path::Path,
808    key_path: &std::path::Path,
809) -> std::io::Result<Arc<rustls::sign::CertifiedKey>> {
810    use std::io::BufReader;
811
812    let cert_file = std::fs::File::open(cert_path).map_err(|e| {
813        std::io::Error::new(
814            e.kind(),
815            format!("opening cert {}: {e}", cert_path.display()),
816        )
817    })?;
818    let mut cert_reader = BufReader::new(cert_file);
819    let certs: Vec<rustls::pki_types::CertificateDer<'static>> =
820        rustls_pemfile::certs(&mut cert_reader).collect::<std::io::Result<_>>()?;
821    if certs.is_empty() {
822        return Err(std::io::Error::new(
823            std::io::ErrorKind::InvalidData,
824            format!("no certificates in {}", cert_path.display()),
825        ));
826    }
827
828    let key_file = std::fs::File::open(key_path).map_err(|e| {
829        std::io::Error::new(e.kind(), format!("opening key {}: {e}", key_path.display()))
830    })?;
831    let mut key_reader = BufReader::new(key_file);
832    let key = rustls_pemfile::private_key(&mut key_reader)?.ok_or_else(|| {
833        std::io::Error::new(
834            std::io::ErrorKind::InvalidData,
835            format!("no private key in {}", key_path.display()),
836        )
837    })?;
838
839    let signing_key = rustls::crypto::ring::sign::any_supported_type(&key)
840        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("sign: {e}")))?;
841
842    Ok(Arc::new(rustls::sign::CertifiedKey::new(
843        certs,
844        signing_key,
845    )))
846}
847
848/// Blocking-thread watcher that re-loads the rustls config whenever the cert
849/// or key file on disk changes. Uses `notify` (already a workspace dep for
850/// the dev file watcher) so we don't add anything new.
851///
852/// Coalescing: many editors write atomically by rename, which produces
853/// `Modify`+`Create` events in quick succession; we ignore that and just
854/// reload on any non-error event. The reload is itself cheap and idempotent.
855fn watch_tls_certs(
856    config: axum_server::tls_rustls::RustlsConfig,
857    cert: std::path::PathBuf,
858    key: std::path::PathBuf,
859    watch_paths: [std::path::PathBuf; 2],
860) -> std::io::Result<()> {
861    use notify::{RecursiveMode, Watcher};
862    use std::sync::mpsc::channel;
863
864    let (tx, rx) = channel::<notify::Result<notify::Event>>();
865    let mut watcher = notify::recommended_watcher(move |res| {
866        let _ = tx.send(res);
867    })
868    .map_err(|e| std::io::Error::other(format!("notify init: {e}")))?;
869
870    // Watch the parent directories — file-level watches don't survive
871    // editors that rename-on-write (vim, cargo, etc.), but directory watches
872    // catch the rename plus the new file's creation.
873    for p in &watch_paths {
874        if let Some(parent) = p.parent() {
875            watcher
876                .watch(parent, RecursiveMode::NonRecursive)
877                .map_err(|e| std::io::Error::other(format!("notify watch: {e}")))?;
878        }
879    }
880
881    let runtime = tokio::runtime::Handle::try_current().ok();
882    while let Ok(event) = rx.recv() {
883        let Ok(event) = event else { continue };
884        // Only react to events touching our cert/key files specifically — the
885        // directory watcher fires for any sibling file too.
886        let touches_us = event.paths.iter().any(|p| p == &cert || p == &key);
887        if !touches_us {
888            continue;
889        }
890        tracing::info!(
891            cert = %cert.display(),
892            key = %key.display(),
893            "tls cert change detected — reloading"
894        );
895        let cert = cert.clone();
896        let key = key.clone();
897        let config = config.clone();
898        let reload = async move {
899            if let Err(e) = config.reload_from_pem_file(&cert, &key).await {
900                tracing::warn!(error = %e, "tls reload failed");
901            } else {
902                tracing::info!("tls cert reloaded successfully");
903            }
904        };
905        if let Some(rt) = &runtime {
906            rt.spawn(reload);
907        } else {
908            // No current tokio runtime (e.g. unit test contexts) — best-effort
909            // fire-and-forget via a fresh single-threaded runtime.
910            std::thread::spawn(|| {
911                let rt = tokio::runtime::Builder::new_current_thread()
912                    .enable_all()
913                    .build();
914                if let Ok(rt) = rt {
915                    rt.block_on(reload);
916                }
917            });
918        }
919    }
920    Ok(())
921}
922
// ─── Rate limiter (Moka-backed fixed-window counters) ───────────────────────
924
/// Fixed-window request limiter: each bucket key gets `count` requests per
/// `window`, after which requests are refused until the window ends. See
/// `rate_limit_mw` for how bucket keys are built.
pub struct RateLimiter {
    /// `bucket key` → `(window_end_instant, count_remaining)`. `from_config`
    /// bounds this to 10 000 buckets, each expiring after 10 minutes idle.
    state: moka::sync::Cache<String, (Instant, u32)>,
    /// Rule for requests that match no route rule (from `per_ip`); `None`
    /// means such requests are not limited.
    default_rule: Option<RateRule>,
    /// Route-specific rules, tried in order before `default_rule`.
    route_rules: Vec<(MatchKey, RateRule)>,
}
931
/// A parsed rate: at most `count` requests per `window`.
#[derive(Clone, Copy)]
struct RateRule {
    /// Requests admitted per window.
    count: u32,
    /// Length of each fixed window.
    window: Duration,
}
937
/// Request selector for a route-specific rate rule.
#[derive(Clone)]
struct MatchKey {
    /// Method to match; `None` matches every method.
    method: Option<Method>,
    /// Request path to match exactly (no prefix or pattern matching).
    path: String,
}
943
944impl RateLimiter {
945    pub fn from_config(cfg: &RateLimitConfig) -> Self {
946        let default_rule = cfg.per_ip.as_deref().and_then(|s| {
947            crate::server_config::parse_rate(s)
948                .map(|(count, window)| RateRule { count, window })
949                .ok()
950        });
951        let route_rules = cfg
952            .routes
953            .iter()
954            .filter_map(|(spec, rate)| {
955                let (count, window) = crate::server_config::parse_rate(rate).ok()?;
956                let (method, path) = parse_route_spec(spec);
957                Some((MatchKey { method, path }, RateRule { count, window }))
958            })
959            .collect();
960
961        Self {
962            state: moka::sync::Cache::builder()
963                .max_capacity(10_000)
964                .time_to_idle(Duration::from_secs(600))
965                .build(),
966            default_rule,
967            route_rules,
968        }
969    }
970
971    fn rule_for(&self, method: &Method, path: &str) -> Option<RateRule> {
972        for (key, rule) in &self.route_rules {
973            if key.path == path && key.method.as_ref().is_none_or(|m| m == method) {
974                return Some(*rule);
975            }
976        }
977        self.default_rule
978    }
979
980    fn check(&self, bucket: &str, rule: RateRule) -> bool {
981        let now = Instant::now();
982        let mut allowed = true;
983        self.state
984            .entry(bucket.to_string())
985            .and_compute_with(|existing| match existing {
986                Some(entry) => {
987                    let (window_end, count) = entry.into_value();
988                    if now >= window_end {
989                        moka::ops::compute::Op::Put((
990                            now + rule.window,
991                            rule.count.saturating_sub(1),
992                        ))
993                    } else if count > 0 {
994                        moka::ops::compute::Op::Put((window_end, count - 1))
995                    } else {
996                        allowed = false;
997                        moka::ops::compute::Op::Put((window_end, 0))
998                    }
999                }
1000                None => {
1001                    moka::ops::compute::Op::Put((now + rule.window, rule.count.saturating_sub(1)))
1002                }
1003            });
1004        allowed
1005    }
1006}
1007
1008fn parse_route_spec(spec: &str) -> (Option<Method>, String) {
1009    let trimmed = spec.trim();
1010    if let Some((m, p)) = trimmed.split_once(char::is_whitespace) {
1011        let method = m.parse::<Method>().ok();
1012        (method, p.trim().to_string())
1013    } else {
1014        (None, trimmed.to_string())
1015    }
1016}
1017
1018async fn rate_limit_mw(
1019    limiter: Arc<RateLimiter>,
1020    trusted: Arc<Vec<ipnet::IpNet>>,
1021    req: Request<Body>,
1022    next: Next,
1023) -> Response<Body> {
1024    let method = req.method().clone();
1025    let path = req.uri().path().to_string();
1026    let bucket = format!("{}|{}|{}", client_ip(&req, &trusted), method, path);
1027
1028    if let Some(rule) = limiter.rule_for(&method, &path) {
1029        if !limiter.check(&bucket, rule) {
1030            tracing::debug!(%method, %path, %bucket, "rate limited");
1031            let mut resp = Response::new(Body::from("rate limit exceeded"));
1032            *resp.status_mut() = StatusCode::TOO_MANY_REQUESTS;
1033            return resp;
1034        }
1035    }
1036    next.run(req).await
1037}
1038
1039/// Resolve the client's IP, honoring `X-Forwarded-For` **only** when the
1040/// direct TCP peer is in `trusted` (set via `[trusted_proxies] ranges` in the
1041/// server config).
1042///
1043/// Order of precedence:
1044///   1. TCP peer (`ConnectInfo<SocketAddr>`) is in a trusted CIDR ⇒ take the
1045///      first hop in XFF (the original client per RFC 7239).
1046///   2. Otherwise ⇒ the TCP peer is the client. XFF from an untrusted peer
1047///      is ignored, so rate limits and IP rules can't be spoofed by hostile
1048///      clients setting their own XFF header.
1049///   3. No `ConnectInfo` extension (rare — only if the router was started
1050///      without `into_make_service_with_connect_info`) ⇒ `"unknown"`.
1051fn client_ip(req: &Request<Body>, trusted: &[ipnet::IpNet]) -> String {
1052    let peer: Option<std::net::IpAddr> = req
1053        .extensions()
1054        .get::<ConnectInfo<SocketAddr>>()
1055        .map(|ci| ci.0.ip());
1056
1057    let peer_trusted = match peer {
1058        Some(addr) => !trusted.is_empty() && trusted.iter().any(|net| net.contains(&addr)),
1059        None => false,
1060    };
1061
1062    if peer_trusted {
1063        if let Some(v) = req.headers().get("x-forwarded-for") {
1064            if let Ok(s) = v.to_str() {
1065                if let Some(first) = s.split(',').next() {
1066                    let candidate = first.trim();
1067                    if !candidate.is_empty() {
1068                        return candidate.to_string();
1069                    }
1070                }
1071            }
1072        }
1073    }
1074
1075    peer.map(|a| a.to_string())
1076        .unwrap_or_else(|| "unknown".into())
1077}
1078
// ─── Route timeouts, request ids & access log ───────────────────────────────
1080
1081/// Apply a per-route timeout to a request whose path matches one of the
1082/// configured prefixes. First match wins. Requests that don't match any
1083/// rule fall through to the global `limits.request_timeout` (if any).
1084async fn route_timeout_mw(
1085    rules: Arc<Vec<RouteTimeoutRule>>,
1086    req: Request<Body>,
1087    next: Next,
1088) -> Response<Body> {
1089    let path = req.uri().path().to_string();
1090    let matching = rules.iter().find(|r| path.starts_with(&r.prefix));
1091    match matching {
1092        Some(rule) => {
1093            let timeout = rule.timeout;
1094            match tokio::time::timeout(timeout, next.run(req)).await {
1095                Ok(resp) => resp,
1096                Err(_) => {
1097                    tracing::debug!(
1098                        path,
1099                        timeout_ms = timeout.as_millis(),
1100                        "route timeout exceeded"
1101                    );
1102                    let mut resp = Response::new(Body::from("request timed out"));
1103                    *resp.status_mut() = StatusCode::REQUEST_TIMEOUT;
1104                    resp
1105                }
1106            }
1107        }
1108        None => next.run(req).await,
1109    }
1110}
1111
1112/// Generate or pass through `x-request-id`. If the inbound request carries
1113/// one, reuse it (lets upstream LBs / clients drive the trace ID); otherwise
1114/// mint a UUID v7 (time-ordered, sortable). Always echoed on the response so
1115/// log entries can be correlated by the same id from either side.
1116async fn request_id_mw(mut req: Request<Body>, next: Next) -> Response<Body> {
1117    const HEADER: &str = "x-request-id";
1118
1119    let inbound = req
1120        .headers()
1121        .get(HEADER)
1122        .and_then(|v| v.to_str().ok())
1123        .map(|s| s.to_string());
1124
1125    let request_id = inbound.unwrap_or_else(|| uuid::Uuid::now_v7().to_string());
1126
1127    if let Ok(v) = HeaderValue::from_str(&request_id) {
1128        req.headers_mut()
1129            .insert(HeaderName::from_static(HEADER), v.clone());
1130        // Run the handler with the id available in the request extensions so
1131        // downstream code (and the trace layer) can read it without parsing
1132        // headers again.
1133        req.extensions_mut().insert(RequestId(request_id.clone()));
1134
1135        let mut resp = next.run(req).await;
1136        resp.headers_mut()
1137            .insert(HeaderName::from_static(HEADER), v);
1138        resp
1139    } else {
1140        next.run(req).await
1141    }
1142}
1143
/// Typed wrapper for the per-request id. `request_id_mw` stores it in
/// `Request::extensions`, with the same value as the request's `x-request-id`
/// header. Handlers can read it via `req.extensions().get::<RequestId>()` or
/// the axum `Extension<RequestId>` extractor.
#[derive(Debug, Clone)]
pub struct RequestId(pub String);
1150
1151async fn access_log_mw(
1152    format: AccessLogFormat,
1153    trusted: Arc<Vec<ipnet::IpNet>>,
1154    req: Request<Body>,
1155    next: Next,
1156) -> Response<Body> {
1157    let started = Instant::now();
1158    let method = req.method().clone();
1159    let path = req.uri().path().to_string();
1160    let host = req
1161        .headers()
1162        .get("host")
1163        .and_then(|v| v.to_str().ok())
1164        .unwrap_or("-")
1165        .to_string();
1166    let referer = req
1167        .headers()
1168        .get("referer")
1169        .and_then(|v| v.to_str().ok())
1170        .map(String::from);
1171    let ua = req
1172        .headers()
1173        .get("user-agent")
1174        .and_then(|v| v.to_str().ok())
1175        .map(String::from);
1176    let ip = client_ip(&req, &trusted);
1177    let request_id = req.extensions().get::<RequestId>().map(|id| id.0.clone());
1178
1179    let resp = next.run(req).await;
1180    let elapsed = started.elapsed();
1181    let status = resp.status().as_u16();
1182    let bytes = response_size(resp.headers()).unwrap_or(0);
1183
1184    match format {
1185        AccessLogFormat::Combined => {
1186            tracing::info!(
1187                target: "access_log",
1188                "{} - - \"{} {} HTTP/1.1\" {} {} \"{}\" \"{}\" {}ms",
1189                ip,
1190                method,
1191                path,
1192                status,
1193                bytes,
1194                referer.as_deref().unwrap_or("-"),
1195                ua.as_deref().unwrap_or("-"),
1196                elapsed.as_millis(),
1197            );
1198        }
1199        AccessLogFormat::Json => {
1200            tracing::info!(
1201                target: "access_log",
1202                json = %serde_json::json!({
1203                    "ip": ip,
1204                    "method": method.as_str(),
1205                    "path": path,
1206                    "host": host,
1207                    "status": status,
1208                    "bytes": bytes,
1209                    "referer": referer,
1210                    "user_agent": ua,
1211                    "duration_ms": elapsed.as_millis(),
1212                    "request_id": request_id,
1213                }),
1214                "request"
1215            );
1216        }
1217        AccessLogFormat::Off => {}
1218    }
1219    resp
1220}
1221
1222fn response_size(headers: &HeaderMap) -> Option<u64> {
1223    headers
1224        .get("content-length")
1225        .and_then(|v| v.to_str().ok())
1226        .and_then(|s| s.parse().ok())
1227}
1228
1229// ─── Rewrites ───────────────────────────────────────────────────────────────
1230
/// A `RewriteRule` with its `from` regex compiled.
#[derive(Clone)]
struct CompiledRewrite {
    /// Compiled `from` pattern.
    pattern: regex::Regex,
    /// Replacement template passed to `Regex::replace`, so it may reference
    /// capture groups (`$1`, `${name}`).
    to: String,
    /// 301/302/303/307/308 ⇒ redirect; anything else (or `None`) ⇒ internal
    /// rewrite.
    status: Option<u16>,
    /// Match against path + query instead of the path alone.
    match_query: bool,
}
1238
/// The configured rewrite rules, compiled, in declaration order; the first
/// matching rule wins.
struct CompiledRewrites {
    rules: Vec<CompiledRewrite>,
}
1242
1243impl CompiledRewrites {
1244    fn compile(rules: &[RewriteRule]) -> Self {
1245        let compiled = rules
1246            .iter()
1247            .filter_map(|r| match regex::Regex::new(&r.from) {
1248                Ok(pattern) => Some(CompiledRewrite {
1249                    pattern,
1250                    to: r.to.clone(),
1251                    status: r.status,
1252                    match_query: r.match_query,
1253                }),
1254                Err(e) => {
1255                    tracing::warn!(rule = %r.from, error = %e, "invalid rewrite regex, skipping");
1256                    None
1257                }
1258            })
1259            .collect();
1260        Self { rules: compiled }
1261    }
1262}
1263
1264async fn rewrite_mw(
1265    rules: Arc<CompiledRewrites>,
1266    mut req: Request<Body>,
1267    next: Next,
1268) -> Response<Body> {
1269    let path = req.uri().path().to_string();
1270    let path_and_query = req
1271        .uri()
1272        .path_and_query()
1273        .map(|p| p.as_str().to_string())
1274        .unwrap_or_else(|| path.clone());
1275
1276    let target_str_path = path.clone();
1277    let target_str_full = path_and_query.clone();
1278
1279    // First pass: apply path-only rules.
1280    let mut applied: Option<(String, Option<u16>)> = None;
1281    for rule in &rules.rules {
1282        let subject = if rule.match_query {
1283            &target_str_full
1284        } else {
1285            &target_str_path
1286        };
1287        if rule.pattern.is_match(subject) {
1288            let replaced = rule.pattern.replace(subject, rule.to.as_str()).to_string();
1289            applied = Some((replaced, rule.status));
1290            break;
1291        }
1292    }
1293
1294    let Some((new_target, status)) = applied else {
1295        return next.run(req).await;
1296    };
1297
1298    match status {
1299        Some(code @ (301 | 302 | 303 | 307 | 308)) => {
1300            let mut resp = Response::new(Body::from(format!("Redirecting to {new_target}\n")));
1301            *resp.status_mut() =
1302                StatusCode::from_u16(code).unwrap_or(StatusCode::MOVED_PERMANENTLY);
1303            if let Ok(loc) = HeaderValue::from_str(&new_target) {
1304                resp.headers_mut().insert("location", loc);
1305            }
1306            resp
1307        }
1308        _ => {
1309            // Internal rewrite: replace the URI's path-and-query.
1310            let mut parts = req.uri().clone().into_parts();
1311            if let Ok(new_pq) = new_target.parse::<axum::http::uri::PathAndQuery>() {
1312                parts.path_and_query = Some(new_pq);
1313            }
1314            if let Ok(new_uri) = axum::http::Uri::from_parts(parts) {
1315                *req.uri_mut() = new_uri;
1316            }
1317            next.run(req).await
1318        }
1319    }
1320}
1321
1322// ─── Trailing slash ─────────────────────────────────────────────────────────
1323
1324async fn trailing_slash_mw(
1325    cfg: TrailingSlashConfig,
1326    mut req: Request<Body>,
1327    next: Next,
1328) -> Response<Body> {
1329    let path = req.uri().path().to_string();
1330    if path == "/" {
1331        return next.run(req).await;
1332    }
1333
1334    let want_slash = matches!(cfg.mode, TrailingSlashMode::Always);
1335    let has_slash = path.ends_with('/');
1336
1337    if want_slash == has_slash {
1338        return next.run(req).await;
1339    }
1340
1341    let new_path = if want_slash {
1342        format!("{path}/")
1343    } else {
1344        path.trim_end_matches('/').to_string()
1345    };
1346
1347    let query = req
1348        .uri()
1349        .query()
1350        .map(|q| format!("?{q}"))
1351        .unwrap_or_default();
1352    let new_target = format!("{new_path}{query}");
1353
1354    match cfg.action {
1355        TrailingSlashAction::Redirect => {
1356            let mut resp = Response::new(Body::from(format!("Redirecting to {new_target}\n")));
1357            *resp.status_mut() = StatusCode::MOVED_PERMANENTLY;
1358            if let Ok(loc) = HeaderValue::from_str(&new_target) {
1359                resp.headers_mut().insert("location", loc);
1360            }
1361            resp
1362        }
1363        TrailingSlashAction::Rewrite => {
1364            let mut parts = req.uri().clone().into_parts();
1365            if let Ok(pq) = new_target.parse::<axum::http::uri::PathAndQuery>() {
1366                parts.path_and_query = Some(pq);
1367            }
1368            if let Ok(new_uri) = axum::http::Uri::from_parts(parts) {
1369                *req.uri_mut() = new_uri;
1370            }
1371            next.run(req).await
1372        }
1373    }
1374}
1375
1376// ─── Custom error pages ─────────────────────────────────────────────────────
1377
/// Custom error pages read into memory at startup.
struct LoadedErrorPages {
    /// HTTP status code → (page body, `content-type` to serve it with).
    by_status: std::collections::HashMap<u16, (String, &'static str)>,
}
1381
1382fn load_error_pages(
1383    raw: &std::collections::BTreeMap<String, std::path::PathBuf>,
1384) -> LoadedErrorPages {
1385    let mut by_status = std::collections::HashMap::new();
1386    for (key, path) in raw {
1387        let Ok(code) = key.parse::<u16>() else {
1388            tracing::warn!(key, "error_pages: invalid status code, skipping");
1389            continue;
1390        };
1391        let body = match std::fs::read_to_string(path) {
1392            Ok(s) => s,
1393            Err(e) => {
1394                tracing::warn!(?path, ?e, "error_pages: failed to read file, skipping");
1395                continue;
1396            }
1397        };
1398        let content_type = guess_content_type(path);
1399        by_status.insert(code, (body, content_type));
1400    }
1401    LoadedErrorPages { by_status }
1402}
1403
/// Content type to serve a custom error page with, guessed from its file
/// extension. The extension is compared case-insensitively, so `404.HTML` is
/// served as HTML too. Unknown or missing extensions fall back to UTF-8 plain
/// text.
fn guess_content_type(path: &std::path::Path) -> &'static str {
    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase);
    match ext.as_deref() {
        Some("html" | "htm") => "text/html; charset=utf-8",
        Some("json") => "application/json",
        _ => "text/plain; charset=utf-8",
    }
}
1412
1413async fn error_pages_mw(
1414    pages: Arc<LoadedErrorPages>,
1415    req: Request<Body>,
1416    next: Next,
1417) -> Response<Body> {
1418    let resp = next.run(req).await;
1419    let status = resp.status().as_u16();
1420
1421    let Some((body, ctype)) = pages.by_status.get(&status) else {
1422        return resp;
1423    };
1424
1425    let mut out = Response::new(Body::from(body.clone()));
1426    *out.status_mut() = resp.status();
1427    if let Ok(ct) = HeaderValue::from_str(ctype) {
1428        out.headers_mut().insert("content-type", ct);
1429    }
1430    // Preserve a couple of useful headers from the original response.
1431    for h in ["cache-control", "x-request-id"] {
1432        if let Some(v) = resp.headers().get(h) {
1433            out.headers_mut().insert(h, v.clone());
1434        }
1435    }
1436    out
1437}
1438
1439// ─── Reverse proxy ──────────────────────────────────────────────────────────
1440
/// A `ProxyRule` with defaults applied, ready for request matching.
#[derive(Clone)]
struct CompiledProxy {
    /// Request-path prefix that routes to this upstream.
    prefix: String,
    /// Upstream base URL with any trailing `/` removed.
    upstream: String,
    /// Remove `prefix` from the path before forwarding.
    strip_prefix: bool,
    /// Forward the client's `Host` header (otherwise it isn't copied).
    preserve_host: bool,
    /// Per-attempt upstream timeout (30 s when the rule sets none).
    timeout: Duration,
    /// Extra attempts after a failed send.
    retries: u8,
}
1450
/// All proxy rules plus the HTTP client used to reach the upstreams.
struct CompiledProxies {
    /// Rules sorted longest prefix first.
    rules: Vec<CompiledProxy>,
    /// Shared client built with redirect following disabled.
    client: reqwest::Client,
}
1455
1456impl CompiledProxies {
1457    fn compile(rules: &[ProxyRule]) -> Self {
1458        let mut compiled: Vec<CompiledProxy> = rules
1459            .iter()
1460            .map(|r| CompiledProxy {
1461                prefix: r.prefix.clone(),
1462                upstream: r.upstream.trim_end_matches('/').to_string(),
1463                strip_prefix: r.strip_prefix,
1464                preserve_host: r.preserve_host,
1465                timeout: r.timeout.unwrap_or(Duration::from_secs(30)),
1466                retries: r.retries,
1467            })
1468            .collect();
1469        // Longest prefix first so `/api/v2/users` beats `/api`.
1470        compiled.sort_by_key(|r| std::cmp::Reverse(r.prefix.len()));
1471
1472        let client = reqwest::Client::builder()
1473            .redirect(reqwest::redirect::Policy::none())
1474            .build()
1475            .unwrap_or_else(|_| reqwest::Client::new());
1476
1477        Self {
1478            rules: compiled,
1479            client,
1480        }
1481    }
1482
1483    fn matching(&self, path: &str) -> Option<&CompiledProxy> {
1484        self.rules.iter().find(|r| path.starts_with(&r.prefix))
1485    }
1486}
1487
1488async fn proxy_mw(proxies: Arc<CompiledProxies>, req: Request<Body>, next: Next) -> Response<Body> {
1489    let path = req.uri().path().to_string();
1490    let Some(rule) = proxies.matching(&path) else {
1491        return next.run(req).await;
1492    };
1493    let rule = rule.clone();
1494
1495    match proxy_forward(&proxies.client, &rule, req).await {
1496        Ok(resp) => resp,
1497        Err(e) => {
1498            tracing::warn!(?e, prefix = %rule.prefix, upstream = %rule.upstream, "proxy error");
1499            let mut resp = Response::new(Body::from(format!("upstream error: {e}")));
1500            *resp.status_mut() = StatusCode::BAD_GATEWAY;
1501            resp
1502        }
1503    }
1504}
1505
1506async fn proxy_forward(
1507    client: &reqwest::Client,
1508    rule: &CompiledProxy,
1509    req: Request<Body>,
1510) -> Result<Response<Body>, String> {
1511    let (parts, body) = req.into_parts();
1512    let body_bytes = axum::body::to_bytes(body, usize::MAX)
1513        .await
1514        .map_err(|e| format!("body read: {e}"))?;
1515
1516    let original_path = parts.uri.path();
1517    let upstream_path = if rule.strip_prefix {
1518        original_path
1519            .strip_prefix(&rule.prefix)
1520            .unwrap_or(original_path)
1521    } else {
1522        original_path
1523    };
1524    let upstream_path = if upstream_path.is_empty() {
1525        "/"
1526    } else {
1527        upstream_path
1528    };
1529    let query = parts
1530        .uri
1531        .query()
1532        .map(|q| format!("?{q}"))
1533        .unwrap_or_default();
1534    let upstream_url = format!("{}{}{}", rule.upstream, upstream_path, query);
1535
1536    let method = parts.method.clone();
1537    let mut last_err = String::new();
1538    for attempt in 0..=rule.retries {
1539        let mut request = client
1540            .request(
1541                reqwest::Method::from_bytes(method.as_str().as_bytes())
1542                    .unwrap_or(reqwest::Method::GET),
1543                &upstream_url,
1544            )
1545            .timeout(rule.timeout)
1546            .body(body_bytes.clone());
1547
1548        for (name, value) in parts.headers.iter() {
1549            // Hop-by-hop headers per RFC 7230 §6.1 — skip.
1550            let n = name.as_str().to_ascii_lowercase();
1551            if matches!(
1552                n.as_str(),
1553                "connection"
1554                    | "keep-alive"
1555                    | "proxy-authenticate"
1556                    | "proxy-authorization"
1557                    | "te"
1558                    | "trailers"
1559                    | "transfer-encoding"
1560                    | "upgrade"
1561                    | "content-length"
1562            ) {
1563                continue;
1564            }
1565            if !rule.preserve_host && n == "host" {
1566                continue;
1567            }
1568            if let Ok(v) = value.to_str() {
1569                request = request.header(name.as_str(), v);
1570            }
1571        }
1572
1573        // X-Forwarded-* headers — useful for upstreams.
1574        if let Some(host) = parts.headers.get("host").and_then(|v| v.to_str().ok()) {
1575            request = request.header("x-forwarded-host", host);
1576        }
1577        request = request.header("x-forwarded-proto", "http");
1578
1579        match request.send().await {
1580            Ok(resp) => return upstream_to_axum(resp).await,
1581            Err(e) => {
1582                last_err = format!("attempt {} → {e}", attempt + 1);
1583                tracing::debug!(error = %e, attempt, "proxy retry");
1584                continue;
1585            }
1586        }
1587    }
1588    Err(last_err)
1589}
1590
1591// ─── CORS ───────────────────────────────────────────────────────────────────
1592
1593async fn cors_mw(cfg: Arc<CorsConfig>, req: Request<Body>, next: Next) -> Response<Body> {
1594    let origin = req
1595        .headers()
1596        .get("origin")
1597        .and_then(|v| v.to_str().ok())
1598        .map(String::from);
1599
1600    let is_allowed_origin = origin.as_deref().is_some_and(|o| {
1601        cfg.allow_origins
1602            .iter()
1603            .any(|allowed| allowed == "*" || allowed == o)
1604    });
1605
1606    // Preflight
1607    if req.method() == Method::OPTIONS && origin.is_some() {
1608        let mut resp = Response::new(Body::empty());
1609        *resp.status_mut() = StatusCode::NO_CONTENT;
1610        apply_cors_headers(
1611            resp.headers_mut(),
1612            &cfg,
1613            origin.as_deref(),
1614            is_allowed_origin,
1615        );
1616        return resp;
1617    }
1618
1619    let mut resp = next.run(req).await;
1620    apply_cors_headers(
1621        resp.headers_mut(),
1622        &cfg,
1623        origin.as_deref(),
1624        is_allowed_origin,
1625    );
1626    resp
1627}
1628
1629fn apply_cors_headers(
1630    headers: &mut HeaderMap,
1631    cfg: &CorsConfig,
1632    origin: Option<&str>,
1633    is_allowed_origin: bool,
1634) {
1635    if !is_allowed_origin {
1636        return;
1637    }
1638    if let Some(origin) = origin {
1639        if let Ok(v) = HeaderValue::from_str(origin) {
1640            headers.insert("access-control-allow-origin", v);
1641        }
1642        headers.insert("vary", HeaderValue::from_static("Origin"));
1643    } else if cfg.allow_origins.iter().any(|o| o == "*") {
1644        headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
1645    }
1646
1647    let methods = if cfg.allow_methods.is_empty() {
1648        "GET, POST, PUT, PATCH, DELETE, OPTIONS".to_string()
1649    } else {
1650        cfg.allow_methods.join(", ")
1651    };
1652    if let Ok(v) = HeaderValue::from_str(&methods) {
1653        headers.insert("access-control-allow-methods", v);
1654    }
1655
1656    let allow_headers = if cfg.allow_headers.is_empty() {
1657        "Content-Type, Authorization, X-CSRF-TOKEN, X-Requested-With".to_string()
1658    } else {
1659        cfg.allow_headers.join(", ")
1660    };
1661    if let Ok(v) = HeaderValue::from_str(&allow_headers) {
1662        headers.insert("access-control-allow-headers", v);
1663    }
1664
1665    if !cfg.expose_headers.is_empty() {
1666        if let Ok(v) = HeaderValue::from_str(&cfg.expose_headers.join(", ")) {
1667            headers.insert("access-control-expose-headers", v);
1668        }
1669    }
1670
1671    if cfg.allow_credentials {
1672        headers.insert(
1673            "access-control-allow-credentials",
1674            HeaderValue::from_static("true"),
1675        );
1676    }
1677
1678    if let Some(max_age) = cfg.max_age {
1679        if let Ok(v) = HeaderValue::from_str(&max_age.as_secs().to_string()) {
1680            headers.insert("access-control-max-age", v);
1681        }
1682    }
1683}
1684
1685// ─── IP allow/deny ──────────────────────────────────────────────────────────
1686
1687async fn ip_rules_mw(
1688    rules: Arc<Vec<IpRule>>,
1689    trusted: Arc<Vec<ipnet::IpNet>>,
1690    req: Request<Body>,
1691    next: Next,
1692) -> Response<Body> {
1693    let path = req.uri().path().to_string();
1694    let ip_str = client_ip(&req, &trusted);
1695    let ip = ip_str.parse::<std::net::IpAddr>().ok();
1696
1697    for rule in rules.iter() {
1698        if !path.starts_with(&rule.prefix) {
1699            continue;
1700        }
1701        let matches_range = ip
1702            .map(|addr| rule.ranges.iter().any(|net| net.contains(&addr)))
1703            .unwrap_or(false);
1704        let allowed = match rule.action {
1705            IpAction::Allow => matches_range,
1706            IpAction::Deny => !matches_range,
1707        };
1708        if !allowed {
1709            tracing::debug!(path, ip = %ip_str, "ip rule denied request");
1710            let mut resp = Response::new(Body::from("forbidden"));
1711            *resp.status_mut() = StatusCode::FORBIDDEN;
1712            return resp;
1713        }
1714        // First matching prefix wins.
1715        break;
1716    }
1717
1718    next.run(req).await
1719}
1720
1721// ─── HTTP Basic Auth ────────────────────────────────────────────────────────
1722
1723use base64::engine::general_purpose::STANDARD as B64;
1724use base64::Engine as _;
1725
1726struct CompiledBasicAuth {
1727    rules: Vec<(BasicAuthRule, Vec<(String, String)>)>,
1728}
1729
1730fn compile_basic_auth(rules: &[BasicAuthRule]) -> CompiledBasicAuth {
1731    let compiled = rules
1732        .iter()
1733        .map(|r| {
1734            let creds = r
1735                .credentials
1736                .iter()
1737                .filter_map(|c| {
1738                    c.split_once(':')
1739                        .map(|(u, p)| (u.to_string(), p.to_string()))
1740                })
1741                .collect();
1742            (r.clone(), creds)
1743        })
1744        .collect();
1745    CompiledBasicAuth { rules: compiled }
1746}
1747
1748async fn basic_auth_mw(
1749    rules: Arc<CompiledBasicAuth>,
1750    req: Request<Body>,
1751    next: Next,
1752) -> Response<Body> {
1753    let path = req.uri().path().to_string();
1754    for (rule, creds) in &rules.rules {
1755        if !path.starts_with(&rule.prefix) {
1756            continue;
1757        }
1758        let supplied = req
1759            .headers()
1760            .get("authorization")
1761            .and_then(|v| v.to_str().ok())
1762            .and_then(|s| s.strip_prefix("Basic "))
1763            .and_then(|b64| B64.decode(b64).ok())
1764            .and_then(|bytes| String::from_utf8(bytes).ok())
1765            .and_then(|pair| {
1766                pair.split_once(':')
1767                    .map(|(u, p)| (u.to_string(), p.to_string()))
1768            });
1769
1770        let ok = supplied
1771            .as_ref()
1772            .map(|(u, p)| creds.iter().any(|(cu, cp)| cu == u && cp == p))
1773            .unwrap_or(false);
1774
1775        if ok {
1776            return next.run(req).await;
1777        }
1778
1779        let challenge = format!("Basic realm=\"{}\"", rule.realm);
1780        let mut resp = Response::new(Body::from("authentication required"));
1781        *resp.status_mut() = StatusCode::UNAUTHORIZED;
1782        if let Ok(v) = HeaderValue::from_str(&challenge) {
1783            resp.headers_mut().insert("www-authenticate", v);
1784        }
1785        return resp;
1786    }
1787    next.run(req).await
1788}
1789
1790async fn upstream_to_axum(resp: reqwest::Response) -> Result<Response<Body>, String> {
1791    let status = resp.status();
1792    let headers = resp.headers().clone();
1793    let bytes = resp
1794        .bytes()
1795        .await
1796        .map_err(|e| format!("upstream body: {e}"))?;
1797    let mut out = Response::new(Body::from(bytes));
1798    *out.status_mut() =
1799        axum::http::StatusCode::from_u16(status.as_u16()).unwrap_or(axum::http::StatusCode::OK);
1800    for (name, value) in headers.iter() {
1801        let n = name.as_str().to_ascii_lowercase();
1802        if matches!(
1803            n.as_str(),
1804            "connection"
1805                | "keep-alive"
1806                | "proxy-authenticate"
1807                | "proxy-authorization"
1808                | "te"
1809                | "trailers"
1810                | "transfer-encoding"
1811                | "upgrade"
1812        ) {
1813            continue;
1814        }
1815        if let Ok(v) = HeaderValue::from_bytes(value.as_bytes()) {
1816            if let Ok(name) = HeaderName::from_bytes(name.as_str().as_bytes()) {
1817                out.headers_mut().append(name, v);
1818            }
1819        }
1820    }
1821    Ok(out)
1822}
1823
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// An IPv4 socket address, as the listener would record for the direct peer.
    fn v4_peer(octets: [u8; 4], port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::from(octets)), port)
    }

    /// Build an empty-bodied request with an optional `ConnectInfo` peer
    /// extension and an optional raw `X-Forwarded-For` header value.
    fn make_req(peer: Option<SocketAddr>, xff: Option<&str>) -> Request<Body> {
        let builder = match xff {
            Some(forwarded) => Request::builder().header("x-forwarded-for", forwarded),
            None => Request::builder(),
        };
        let mut req = builder.body(Body::empty()).unwrap();
        if let Some(addr) = peer {
            req.extensions_mut().insert(ConnectInfo(addr));
        }
        req
    }

    /// Parse a CIDR literal; panics on malformed test input.
    fn cidr(s: &str) -> ipnet::IpNet {
        s.parse().unwrap()
    }

    /// The trusted-proxy set shared by most tests: the 10.0.0.0/8 private range.
    fn private_10() -> Vec<ipnet::IpNet> {
        vec![cidr("10.0.0.0/8")]
    }

    #[test]
    fn xff_ignored_when_peer_is_not_trusted() {
        // A hostile client connects directly and forges its own XFF; it must be ignored.
        let req = make_req(Some(v4_peer([203, 0, 113, 5], 12345)), Some("198.51.100.1"));
        assert_eq!(client_ip(&req, &private_10()), "203.0.113.5");
    }

    #[test]
    fn xff_honored_when_peer_is_a_trusted_proxy() {
        // The load balancer at 10.0.0.5 forwards the original client's address.
        let req = make_req(
            Some(v4_peer([10, 0, 0, 5], 443)),
            Some("198.51.100.1, 10.0.0.5"),
        );
        assert_eq!(client_ip(&req, &private_10()), "198.51.100.1");
    }

    #[test]
    fn empty_trusted_list_means_xff_is_never_honored() {
        // With no trusted ranges configured, even a private peer's XFF is ignored.
        let req = make_req(Some(v4_peer([10, 0, 0, 5], 443)), Some("198.51.100.1"));
        let nobody_trusted: Vec<ipnet::IpNet> = Vec::new();
        assert_eq!(client_ip(&req, &nobody_trusted), "10.0.0.5");
    }

    #[test]
    fn no_xff_falls_back_to_peer() {
        let req = make_req(Some(v4_peer([10, 0, 0, 5], 443)), None);
        assert_eq!(client_ip(&req, &private_10()), "10.0.0.5");
    }

    #[test]
    fn missing_connect_info_returns_unknown() {
        // Defensive: a router served without `into_make_service_with_connect_info`
        // lacks the extension; the lookup must not panic and must not trust XFF.
        let req = make_req(None, Some("198.51.100.1"));
        assert_eq!(client_ip(&req, &private_10()), "unknown");
    }

    #[test]
    fn xff_with_whitespace_and_multiple_hops_picks_first() {
        // NOTE(review): if `client_ip` always takes the leftmost entry, a client behind
        // the trusted proxy can prepend a forged address to its own XFF. Consider walking
        // right-to-left past trusted hops instead. That would still yield the expected
        // value here, because every hop after the first is in 10.0.0.0/8.
        let req = make_req(
            Some(v4_peer([10, 0, 0, 5], 443)),
            Some("  198.51.100.1 ,10.0.0.5, 10.0.0.7"),
        );
        assert_eq!(client_ip(&req, &private_10()), "198.51.100.1");
    }
}