Skip to main content

anvil_core/
server.rs

1//! Production HTTP serving: applies the `ServerConfig` to an `axum::Router` and
2//! starts it on the configured bind addr, with optional TLS via `axum-server`.
3
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use axum::body::Body;
9use axum::http::{HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode};
10use axum::middleware::Next;
11use axum::Router as AxumRouter;
12use tower_http::compression::predicate::SizeAbove;
13use tower_http::compression::CompressionLayer;
14use tower_http::limit::RequestBodyLimitLayer;
15use tower_http::services::ServeDir;
16use tower_http::set_header::SetResponseHeaderLayer;
17use tower_http::timeout::TimeoutLayer;
18use tower_http::trace::TraceLayer;
19
20use crate::container::Container;
21use crate::server_config::{
22    AccessLogFormat, BasicAuthRule, CorsConfig, HstsConfig, IpAction, IpRule, ProxyRule,
23    RateLimitConfig, RewriteRule, RouteTimeoutRule, ServerConfig, StaticMount, TlsConfig,
24    TrailingSlashAction, TrailingSlashConfig, TrailingSlashMode,
25};
26use crate::Error;
27use axum::extract::ConnectInfo;
28
29/// Apply every layer the server config calls for to the user's web router,
30/// then merge any static-file mounts. Returns a ready-to-serve `axum::Router`.
31pub fn apply_layers(web: AxumRouter<Container>, cfg: &ServerConfig) -> AxumRouter<Container> {
32    let mut router = web;
33
34    // Static file mounts run BEFORE wrapping with body/timeout/compression — they
35    // serve from disk and don't need request body parsing.
36    for (prefix, mount) in &cfg.static_files {
37        router = mount_static(router, prefix, mount);
38    }
39
40    // Compose request-side layers.
41    let body_max = cfg.limits.body_max as usize;
42    router = router
43        .layer(RequestBodyLimitLayer::new(body_max))
44        .layer(TraceLayer::new_for_http());
45
46    if let Some(timeout) = cfg.limits.request_timeout {
47        router = router.layer(TimeoutLayer::new(timeout));
48    }
49
50    // L4-ish concurrency cap. Tower's `GlobalConcurrencyLimitLayer` holds a
51    // semaphore shared across every request; clones of the inner service
52    // share the permit pool. Above the cap, new requests are answered with
53    // 503 Service Unavailable (mapped from the `Overloaded` error) instead
54    // of queueing indefinitely — protects against thundering-herd overload
55    // and lets an upstream LB steer traffic to healthy peers.
56    if let Some(max) = cfg.limits.max_concurrency {
57        let max = max.max(1) as usize;
58        let limiter = Arc::new(tokio::sync::Semaphore::new(max));
59        let limiter_clone = limiter.clone();
60        router = router.layer(axum::middleware::from_fn(
61            move |req: Request<Body>, next: Next| {
62                let limiter = limiter_clone.clone();
63                async move {
64                    match limiter.try_acquire() {
65                        Ok(_permit) => next.run(req).await,
66                        Err(_) => {
67                            let mut resp = Response::new(Body::from(
68                                "service overloaded — too many concurrent requests",
69                            ));
70                            *resp.status_mut() = StatusCode::SERVICE_UNAVAILABLE;
71                            resp
72                        }
73                    }
74                }
75            },
76        ));
77    }
78
79    // Per-route timeout overrides. Walk the request path against each
80    // configured prefix; first match wins. Applied BEFORE the global
81    // timeout layer above so that a slow upload endpoint can extend the
82    // window rather than fighting it.
83    if !cfg.route_timeouts.is_empty() {
84        let rules: Arc<Vec<RouteTimeoutRule>> = Arc::new(cfg.route_timeouts.clone());
85        let rules_clone = rules.clone();
86        router = router.layer(axum::middleware::from_fn(
87            move |req: Request<Body>, next: Next| {
88                let rules = rules_clone.clone();
89                async move { route_timeout_mw(rules, req, next).await }
90            },
91        ));
92    }
93
94    // Virtual-host gating: only accept requests whose Host header matches a
95    // configured `server_name`. Empty `server_name` = match-all.
96    if !cfg.server_name.is_empty() {
97        let allowed = cfg.server_name.clone();
98        router = router.layer(axum::middleware::from_fn(
99            move |req: Request<Body>, next: Next| {
100                let allowed = allowed.clone();
101                async move { host_match_mw(allowed, req, next).await }
102            },
103        ));
104    }
105
106    // Trusted-proxy CIDR list — captured once per `apply_layers` call so each
107    // middleware that reads the client IP can decide whether to honor
108    // X-Forwarded-For. Empty = ignore XFF entirely (safe default for direct-
109    // listen deployments).
110    let trusted: Arc<Vec<ipnet::IpNet>> = Arc::new(cfg.trusted_proxies.ranges.clone());
111
112    // IP allow/deny + basic auth — apply first so unauthorized requests don't
113    // touch any other layer.
114    if !cfg.ip_rules.is_empty() {
115        let rules = Arc::new(cfg.ip_rules.clone());
116        let rules_clone = rules.clone();
117        let trusted_clone = trusted.clone();
118        router = router.layer(axum::middleware::from_fn(
119            move |req: Request<Body>, next: Next| {
120                let rules = rules_clone.clone();
121                let trusted = trusted_clone.clone();
122                async move { ip_rules_mw(rules, trusted, req, next).await }
123            },
124        ));
125    }
126    if !cfg.basic_auth.is_empty() {
127        let rules = Arc::new(compile_basic_auth(&cfg.basic_auth));
128        let rules_clone = rules.clone();
129        router = router.layer(axum::middleware::from_fn(
130            move |req: Request<Body>, next: Next| {
131                let rules = rules_clone.clone();
132                async move { basic_auth_mw(rules, req, next).await }
133            },
134        ));
135    }
136
137    // CORS — apply early. tower-http's CorsLayer would be cleaner, but we want
138    // full TOML control without depending on tower-http's CORS feature spec.
139    if cfg.cors.enabled {
140        let cors = Arc::new(cfg.cors.clone());
141        let cors_clone = cors.clone();
142        router = router.layer(axum::middleware::from_fn(
143            move |req: Request<Body>, next: Next| {
144                let cors = cors_clone.clone();
145                async move { cors_mw(cors, req, next).await }
146            },
147        ));
148    }
149
150    // Reverse-proxy rules — apply BEFORE rewrites so the user can rewrite
151    // upstream-bound requests too.
152    if !cfg.proxies.is_empty() {
153        let proxies = Arc::new(CompiledProxies::compile(&cfg.proxies));
154        let proxies_clone = proxies.clone();
155        router = router.layer(axum::middleware::from_fn(
156            move |req: Request<Body>, next: Next| {
157                let proxies = proxies_clone.clone();
158                async move { proxy_mw(proxies, req, next).await }
159            },
160        ));
161    }
162
163    // Rewrites — apply early so they see the request before other layers.
164    if !cfg.rewrites.is_empty() {
165        let compiled = Arc::new(CompiledRewrites::compile(&cfg.rewrites));
166        let compiled_clone = compiled.clone();
167        router = router.layer(axum::middleware::from_fn(
168            move |req: Request<Body>, next: Next| {
169                let rules = compiled_clone.clone();
170                async move { rewrite_mw(rules, req, next).await }
171            },
172        ));
173    }
174
175    // Trailing-slash policy.
176    if cfg.trailing_slash.mode != TrailingSlashMode::Ignore {
177        let ts = cfg.trailing_slash.clone();
178        router = router.layer(axum::middleware::from_fn(
179            move |req: Request<Body>, next: Next| {
180                let ts = ts.clone();
181                async move { trailing_slash_mw(ts, req, next).await }
182            },
183        ));
184    }
185
186    // Custom error pages: intercept responses with matching status codes and
187    // substitute the configured file contents.
188    if !cfg.error_pages.is_empty() {
189        let pages = Arc::new(load_error_pages(&cfg.error_pages));
190        let pages_clone = pages.clone();
191        router = router.layer(axum::middleware::from_fn(
192            move |req: Request<Body>, next: Next| {
193                let pages = pages_clone.clone();
194                async move { error_pages_mw(pages, req, next).await }
195            },
196        ));
197    }
198
199    // HSTS header for HTTPS responses.
200    if cfg.tls.is_some() && cfg.hsts.enabled {
201        if let Some(value) = build_hsts_header(&cfg.hsts) {
202            router = router.layer(SetResponseHeaderLayer::if_not_present(
203                HeaderName::from_static("strict-transport-security"),
204                value,
205            ));
206        }
207    }
208
209    if cfg.compression.enabled {
210        // tower-http's `CompressionLayer` selects the encoding based on the
211        // client's `Accept-Encoding` header; we just toggle the layer on and
212        // gate via the min-size predicate. Per-algorithm disable lives on the
213        // un-parameterized layer, so we apply it before `compress_when`.
214        let min_size = u16::try_from(cfg.compression.min_size).unwrap_or(u16::MAX);
215        let mut layer = CompressionLayer::new();
216        if !cfg
217            .compression
218            .algorithms
219            .iter()
220            .any(|a| a.eq_ignore_ascii_case("gzip"))
221        {
222            layer = layer.no_gzip();
223        }
224        if !cfg
225            .compression
226            .algorithms
227            .iter()
228            .any(|a| a.eq_ignore_ascii_case("br") || a.eq_ignore_ascii_case("brotli"))
229        {
230            layer = layer.no_br();
231        }
232        if !cfg
233            .compression
234            .algorithms
235            .iter()
236            .any(|a| a.eq_ignore_ascii_case("deflate"))
237        {
238            layer = layer.no_deflate();
239        }
240        let layer = layer.compress_when(SizeAbove::new(min_size));
241        router = router.layer(layer);
242    }
243
244    if cfg.rate_limit.per_ip.is_some() || !cfg.rate_limit.routes.is_empty() {
245        let limiter = Arc::new(RateLimiter::from_config(&cfg.rate_limit));
246        let limiter_clone = limiter.clone();
247        let trusted_clone = trusted.clone();
248        router = router.layer(axum::middleware::from_fn(
249            move |req: Request<Body>, next: Next| {
250                let limiter = limiter_clone.clone();
251                let trusted = trusted_clone.clone();
252                async move { rate_limit_mw(limiter, trusted, req, next).await }
253            },
254        ));
255    }
256
257    if matches!(
258        cfg.access_log.format,
259        AccessLogFormat::Combined | AccessLogFormat::Json
260    ) {
261        let format = cfg.access_log.format;
262        let trusted_clone = trusted.clone();
263        router = router.layer(axum::middleware::from_fn(
264            move |req: Request<Body>, next: Next| {
265                let trusted = trusted_clone.clone();
266                async move { access_log_mw(format, trusted, req, next).await }
267            },
268        ));
269    }
270
271    // Request ID — generated if the inbound request didn't carry one, echoed
272    // back on the response. Threads through `tracing` so each log line for
273    // a given request shares a `request_id` field, which is the basic
274    // building block of trace correlation across services.
275    router = router.layer(axum::middleware::from_fn(request_id_mw));
276
277    router
278}
279
280fn mount_static(
281    router: AxumRouter<Container>,
282    prefix: &str,
283    mount: &StaticMount,
284) -> AxumRouter<Container> {
285    // Note: `ranges` is reserved for a future version of tower-http that exposes
286    // per-instance range toggling. For now ranges are always enabled.
287    let _ = mount.ranges;
288
289    // If the app registered an embedded-asset fetcher for this prefix (the
290    // single-binary distribution path), serve from memory instead of disk.
291    if let Some(fetcher) = crate::embedded::lookup(prefix) {
292        let cache = mount.cache;
293        let route_pat = format!("{}/*path", prefix.trim_end_matches('/'));
294        let nested = AxumRouter::<Container>::new().route(
295            &route_pat,
296            axum::routing::get(
297                move |axum::extract::Path(path): axum::extract::Path<String>,
298                      headers: HeaderMap| async move {
299                    serve_embedded(fetcher, cache, &path, &headers)
300                },
301            ),
302        );
303        return router.merge(nested);
304    }
305
306    let svc = ServeDir::new(&mount.dir);
307
308    let nested = AxumRouter::<Container>::new().nest_service(prefix, svc);
309    let nested = if let Some(cache) = mount.cache {
310        let value = HeaderValue::from_str(&format!("public, max-age={}", cache.as_secs()))
311            .unwrap_or_else(|_| HeaderValue::from_static("public"));
312        nested.layer(SetResponseHeaderLayer::if_not_present(
313            HeaderName::from_static("cache-control"),
314            value,
315        ))
316    } else {
317        nested
318    };
319    router.merge(nested)
320}
321
322/// Embedded-asset request handler: looks up the wildcard `path` in the
323/// registered fetcher, honors `If-None-Match` against the file's ETag, and
324/// returns 200/304/404.
325fn serve_embedded(
326    fetcher: crate::embedded::EmbeddedAssetFetcher,
327    cache: Option<Duration>,
328    path: &str,
329    headers: &HeaderMap,
330) -> Response<Body> {
331    let asset = match fetcher(path) {
332        Some(a) => a,
333        None => return not_found(),
334    };
335
336    if let (Some(client_tag), Some(asset_tag)) = (
337        headers
338            .get(axum::http::header::IF_NONE_MATCH)
339            .and_then(|v| v.to_str().ok()),
340        asset.etag.as_deref(),
341    ) {
342        if etag_matches(client_tag, asset_tag) {
343            let mut resp = Response::builder()
344                .status(StatusCode::NOT_MODIFIED)
345                .body(Body::empty())
346                .expect("304 body");
347            if let Some(d) = cache {
348                if let Ok(v) = HeaderValue::from_str(&format!("public, max-age={}", d.as_secs())) {
349                    resp.headers_mut().insert("cache-control", v);
350                }
351            }
352            return resp;
353        }
354    }
355
356    let mut builder = Response::builder()
357        .status(StatusCode::OK)
358        .header("content-type", asset.content_type.as_str())
359        .header("content-length", asset.data.len());
360    if let Some(tag) = asset.etag.as_deref() {
361        builder = builder.header("etag", quote_etag(tag));
362    }
363    if let Some(d) = cache {
364        builder = builder.header("cache-control", format!("public, max-age={}", d.as_secs()));
365    }
366    builder
367        .body(Body::from(asset.data.into_owned()))
368        .unwrap_or_else(|_| not_found())
369}
370
371fn not_found() -> Response<Body> {
372    Response::builder()
373        .status(StatusCode::NOT_FOUND)
374        .body(Body::from("not found"))
375        .expect("404 body")
376}
377
/// Render a raw ETag value as an HTTP entity-tag.
///
/// Values that are already entity-tags are returned unchanged. That covers
/// strong (`"abc"`) and weak (`W/"abc"`) tags. Anything else is wrapped in
/// double quotes as a strong tag.
fn quote_etag(raw: &str) -> String {
    if raw.starts_with('"') || raw.starts_with("W/\"") {
        raw.to_string()
    } else {
        format!("\"{raw}\"")
    }
}
385
/// Weak comparison of an `If-None-Match` header value against the server's
/// ETag, as RFC 9110 specifies for `If-None-Match`.
///
/// `client` is the raw header value: either `*` (matches any current
/// representation) or a comma-separated list of entity-tags. `server` may be
/// quoted or bare, strong or weak. Two tags match when their opaque values
/// are equal, ignoring the `W/` weakness prefix.
fn etag_matches(client: &str, server: &str) -> bool {
    // Trim whitespace, then the `W/` prefix, then the quotes — in that
    // order, so `W/"abc"` reduces to `abc`.
    fn opaque(tag: &str) -> &str {
        let tag = tag.trim();
        tag.strip_prefix("W/").unwrap_or(tag).trim_matches('"')
    }

    let server = opaque(server);
    client
        .split(',')
        .map(str::trim)
        .filter(|tag| !tag.is_empty())
        .any(|tag| tag == "*" || opaque(tag) == server)
}
403
404// ─── Serve entry points ─────────────────────────────────────────────────────
405
406pub async fn serve(
407    router: AxumRouter,
408    cfg: &ServerConfig,
409    shutdown: tokio::sync::oneshot::Receiver<()>,
410) -> Result<(), Error> {
411    let addr: SocketAddr = cfg
412        .bind
413        .parse()
414        .map_err(|e| Error::Config(format!("invalid bind addr `{}`: {e}", cfg.bind)))?;
415
416    tracing::info!(%addr, tls = cfg.tls.is_some(), server_name = ?cfg.server_name, "anvil server starting");
417
418    // If a redirect-HTTP listener is configured, spawn it alongside the main listener.
419    let (shutdown_main_tx, shutdown_main_rx) = tokio::sync::oneshot::channel::<()>();
420    let (shutdown_redir_tx, shutdown_redir_rx) = tokio::sync::oneshot::channel::<()>();
421    tokio::spawn(async move {
422        let _ = shutdown.await;
423        let _ = shutdown_main_tx.send(());
424        let _ = shutdown_redir_tx.send(());
425    });
426
427    let redirect_task = cfg.redirect_http.clone().map(|redir| {
428        let target_host = redir
429            .target_host
430            .clone()
431            .or_else(|| cfg.server_name.first().cloned());
432        let permanent = redir.permanent;
433        let bind = redir.bind.clone();
434        tokio::spawn(async move {
435            if let Err(e) =
436                serve_redirect_http(&bind, target_host, permanent, shutdown_redir_rx).await
437            {
438                tracing::warn!(?e, "redirect_http listener exited with error");
439            }
440        })
441    });
442
443    let main_result = if let Some(tls) = &cfg.tls {
444        if tls.acme.is_some() {
445            serve_acme(
446                router,
447                addr,
448                tls,
449                cfg.limits.drain_timeout,
450                shutdown_main_rx,
451            )
452            .await
453        } else {
454            serve_tls(
455                router,
456                addr,
457                tls,
458                cfg.limits.drain_timeout,
459                shutdown_main_rx,
460            )
461            .await
462        }
463    } else {
464        serve_plain(router, addr, shutdown_main_rx).await
465    };
466
467    if let Some(task) = redirect_task {
468        task.abort();
469    }
470
471    main_result
472}
473
474/// Plain-HTTP listener that 30x-redirects every request to its `https://`
475/// equivalent. Used when TLS is on and `redirect_http` is configured.
476async fn serve_redirect_http(
477    bind: &str,
478    target_host: Option<String>,
479    permanent: bool,
480    shutdown: tokio::sync::oneshot::Receiver<()>,
481) -> Result<(), Error> {
482    let addr: SocketAddr = bind
483        .parse()
484        .map_err(|e| Error::Config(format!("invalid redirect_http bind `{bind}`: {e}")))?;
485    tracing::info!(%addr, target_host = ?target_host, permanent, "http→https redirect listener");
486
487    let target_host = Arc::new(target_host);
488    let router: AxumRouter = AxumRouter::new().fallback(axum::routing::any({
489        let target_host = target_host.clone();
490        move |req: Request<Body>| {
491            let target_host = target_host.clone();
492            async move { http_redirect_handler(req, target_host, permanent).await }
493        }
494    }));
495
496    let listener = tokio::net::TcpListener::bind(addr).await?;
497    axum::serve(listener, router)
498        .with_graceful_shutdown(async move {
499            let _ = shutdown.await;
500        })
501        .await?;
502    Ok(())
503}
504
505async fn http_redirect_handler(
506    req: Request<Body>,
507    target_host: Arc<Option<String>>,
508    permanent: bool,
509) -> Response<Body> {
510    let host = target_host.as_ref().clone().unwrap_or_else(|| {
511        req.headers()
512            .get("host")
513            .and_then(|v| v.to_str().ok())
514            .map(String::from)
515            .unwrap_or_default()
516    });
517    let path_and_query = req
518        .uri()
519        .path_and_query()
520        .map(|p| p.as_str().to_string())
521        .unwrap_or_else(|| "/".to_string());
522    let location = format!("https://{host}{path_and_query}");
523
524    let status = if permanent {
525        StatusCode::MOVED_PERMANENTLY
526    } else {
527        StatusCode::FOUND
528    };
529    let mut resp = Response::new(Body::from(format!("Redirecting to {location}\n")));
530    *resp.status_mut() = status;
531    if let Ok(loc) = HeaderValue::from_str(&location) {
532        resp.headers_mut().insert("location", loc);
533    }
534    resp
535}
536
537fn build_hsts_header(cfg: &HstsConfig) -> Option<HeaderValue> {
538    let max_age = cfg.max_age.unwrap_or(Duration::from_secs(86400 * 365));
539    let mut value = format!("max-age={}", max_age.as_secs());
540    if cfg.include_subdomains {
541        value.push_str("; includeSubDomains");
542    }
543    if cfg.preload {
544        value.push_str("; preload");
545    }
546    HeaderValue::from_str(&value).ok()
547}
548
549/// Reject requests whose Host header doesn't match any configured server_name.
550async fn host_match_mw(allowed: Vec<String>, req: Request<Body>, next: Next) -> Response<Body> {
551    let host = req
552        .headers()
553        .get("host")
554        .and_then(|v| v.to_str().ok())
555        .unwrap_or("")
556        .to_string();
557
558    // Strip port for matching: "example.com:8080" → "example.com".
559    let host_no_port = host.split(':').next().unwrap_or("").to_ascii_lowercase();
560
561    if matches_any(&host_no_port, &allowed) {
562        return next.run(req).await;
563    }
564
565    tracing::debug!(host, allowed = ?allowed, "rejected host: no server_name match");
566    let mut resp = Response::new(Body::from(format!(
567        "404 not found (unknown host: {host})\n"
568    )));
569    *resp.status_mut() = StatusCode::NOT_FOUND;
570    resp
571}
572
573fn matches_any(host: &str, patterns: &[String]) -> bool {
574    patterns.iter().any(|pat| matches_pattern(host, pat))
575}
576
/// Match a host against a `server_name` pattern.
///
/// The pattern is lowercased before comparison. `host` is compared as
/// given, so callers pass an already-lowercased host. Supported forms:
/// - `*` matches any host, including the empty string;
/// - `*.example.com` matches any host ending in `.example.com`, such as
///   `a.example.com` or `a.b.example.com`, but not `example.com` itself;
/// - anything else must equal the lowercased pattern exactly.
fn matches_pattern(host: &str, pattern: &str) -> bool {
    let lowered = pattern.to_ascii_lowercase();
    match lowered.as_str() {
        "*" => true,
        exact_or_wild => match exact_or_wild.strip_prefix('*') {
            Some(dot_suffix) if dot_suffix.starts_with('.') => host.ends_with(dot_suffix),
            _ => host == exact_or_wild,
        },
    }
}
590
591async fn serve_plain(
592    router: AxumRouter,
593    addr: SocketAddr,
594    shutdown: tokio::sync::oneshot::Receiver<()>,
595) -> Result<(), Error> {
596    let listener = tokio::net::TcpListener::bind(addr).await?;
597    axum::serve(
598        listener,
599        router.into_make_service_with_connect_info::<SocketAddr>(),
600    )
601    .with_graceful_shutdown(async move {
602        let _ = shutdown.await;
603    })
604    .await?;
605    Ok(())
606}
607
608async fn serve_tls(
609    router: AxumRouter,
610    addr: SocketAddr,
611    tls: &TlsConfig,
612    drain: Duration,
613    shutdown: tokio::sync::oneshot::Receiver<()>,
614) -> Result<(), Error> {
615    // Path A: single cert pair → fast happy path, no resolver indirection.
616    // Path B: `[[tls.certs]]` entries present → build a custom
617    // `ResolvesServerCert` that picks the cert by ClientHello SNI hostname,
618    // with the top-level cert as the default for unmatched names.
619    let config = if tls.additional_certs.is_empty() {
620        axum_server::tls_rustls::RustlsConfig::from_pem_file(&tls.cert, &tls.key)
621            .await
622            .map_err(|e| Error::Config(format!("tls load: {e}")))?
623    } else {
624        let resolver = build_sni_resolver(tls)
625            .map_err(|e| Error::Config(format!("tls multi-cert load: {e}")))?;
626        let server_config = rustls::ServerConfig::builder()
627            .with_no_client_auth()
628            .with_cert_resolver(Arc::new(resolver));
629        axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(server_config))
630    };
631
632    // Cert hot-reload: spawn a notify watcher on the cert + key paths. On
633    // any change, re-read the PEM files and `reload_from_pem_file` on the
634    // shared RustlsConfig — new TLS handshakes pick up the new cert without
635    // a process restart, dropping the "swap cert → restart server" runbook
636    // ops normally need.
637    let watch_paths = [tls.cert.clone(), tls.key.clone()];
638    let config_for_watch = config.clone();
639    let cert_path = tls.cert.clone();
640    let key_path = tls.key.clone();
641    tokio::task::spawn_blocking(move || {
642        if let Err(e) = watch_tls_certs(config_for_watch, cert_path, key_path, watch_paths) {
643            tracing::warn!(error = %e, "cert hot-reload watcher exited");
644        }
645    });
646
647    let handle = axum_server::Handle::new();
648    let handle_for_shutdown = handle.clone();
649    tokio::spawn(async move {
650        let _ = shutdown.await;
651        handle_for_shutdown.graceful_shutdown(Some(drain));
652    });
653
654    axum_server::bind_rustls(addr, config)
655        .handle(handle)
656        .serve(router.into_make_service_with_connect_info::<SocketAddr>())
657        .await
658        .map_err(|e| Error::Internal(format!("tls serve: {e}")))?;
659    Ok(())
660}
661
662/// ACME-managed TLS serve path. Auto-obtains and rotates certs via Let's
663/// Encrypt (or any other ACME directory) using TLS-ALPN-01 in-process — no
664/// external certbot run required.
665///
666/// **Status:** the `[tls.acme]` schema parses today and apps written
667/// against it are forward-compatible. The runtime implementation is held
668/// back pending a focused PR that pins compatible `rustls-acme` /
669/// `rustls` / `axum-server` versions; the upstream `rustls-acme` 0.13
670/// release has build errors against the rustls version this workspace
671/// pins for the rest of TLS. Until that PR lands, ACME configs surface
672/// as a clear startup error rather than silently no-op'ing.
673async fn serve_acme(
674    _router: AxumRouter,
675    _addr: SocketAddr,
676    tls: &TlsConfig,
677    _drain: Duration,
678    _shutdown: tokio::sync::oneshot::Receiver<()>,
679) -> Result<(), Error> {
680    let acme = tls
681        .acme
682        .as_ref()
683        .expect("serve_acme called without [tls.acme]");
684    Err(Error::Config(format!(
685        "[tls.acme] is configured for {n} domain(s) but ACME runtime support \
686         is pending a follow-up PR (rustls-acme version pin). For now, use \
687         certbot in TLS-ALPN-01 mode and `[tls] cert`/`key` pointing at the \
688         certbot output; cert hot-reload handles renewals without restart.",
689        n = acme.domains.len(),
690    )))
691}
692
/// SNI cert resolver. Picks a `CertifiedKey` based on the ClientHello's SNI
/// hostname, lowercased, and falls back to the default cert when no entry
/// matches. A client that sends no SNI is matched as the empty hostname, so
/// it gets the default unless a `*` entry exists.
///
/// Patterns use `matches_pattern`, the same matcher as the host-gating
/// middleware: exact match, `*` for any host, or `*.example.com`. The
/// wildcard form matches any host ending in `.example.com` at ANY depth
/// (`a.b.example.com` included).
/// NOTE(review): a wildcard certificate itself only covers one label, so a
/// deeper name selected here may still fail client-side verification.
#[derive(Debug)]
struct SniResolver {
    /// `(server_name, CertifiedKey)` pairs in declaration order, with
    /// `server_name` lowercased at build time. First match wins, so put the
    /// most specific patterns first.
    entries: Vec<(String, Arc<rustls::sign::CertifiedKey>)>,
    /// Certificate presented when no entry matches.
    default_key: Arc<rustls::sign::CertifiedKey>,
}
705
706impl rustls::server::ResolvesServerCert for SniResolver {
707    fn resolve(
708        &self,
709        client_hello: rustls::server::ClientHello<'_>,
710    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
711        let sni = client_hello
712            .server_name()
713            .unwrap_or("")
714            .to_ascii_lowercase();
715        for (pattern, key) in &self.entries {
716            if matches_pattern(&sni, pattern) {
717                return Some(key.clone());
718            }
719        }
720        Some(self.default_key.clone())
721    }
722}
723
724fn build_sni_resolver(tls: &TlsConfig) -> std::io::Result<SniResolver> {
725    let default_key = load_certified_key(&tls.cert, &tls.key)?;
726    let mut entries = Vec::with_capacity(tls.additional_certs.len());
727    for entry in &tls.additional_certs {
728        let key = load_certified_key(&entry.cert, &entry.key)?;
729        entries.push((entry.server_name.to_ascii_lowercase(), key));
730    }
731    tracing::info!(
732        default_cert = %tls.cert.display(),
733        additional = tls.additional_certs.len(),
734        "tls: SNI resolver active"
735    );
736    Ok(SniResolver {
737        entries,
738        default_key,
739    })
740}
741
742fn load_certified_key(
743    cert_path: &std::path::Path,
744    key_path: &std::path::Path,
745) -> std::io::Result<Arc<rustls::sign::CertifiedKey>> {
746    use std::io::BufReader;
747
748    let cert_file = std::fs::File::open(cert_path).map_err(|e| {
749        std::io::Error::new(
750            e.kind(),
751            format!("opening cert {}: {e}", cert_path.display()),
752        )
753    })?;
754    let mut cert_reader = BufReader::new(cert_file);
755    let certs: Vec<rustls::pki_types::CertificateDer<'static>> =
756        rustls_pemfile::certs(&mut cert_reader).collect::<std::io::Result<_>>()?;
757    if certs.is_empty() {
758        return Err(std::io::Error::new(
759            std::io::ErrorKind::InvalidData,
760            format!("no certificates in {}", cert_path.display()),
761        ));
762    }
763
764    let key_file = std::fs::File::open(key_path).map_err(|e| {
765        std::io::Error::new(e.kind(), format!("opening key {}: {e}", key_path.display()))
766    })?;
767    let mut key_reader = BufReader::new(key_file);
768    let key = rustls_pemfile::private_key(&mut key_reader)?.ok_or_else(|| {
769        std::io::Error::new(
770            std::io::ErrorKind::InvalidData,
771            format!("no private key in {}", key_path.display()),
772        )
773    })?;
774
775    let signing_key = rustls::crypto::ring::sign::any_supported_type(&key)
776        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("sign: {e}")))?;
777
778    Ok(Arc::new(rustls::sign::CertifiedKey::new(
779        certs,
780        signing_key,
781    )))
782}
783
784/// Blocking-thread watcher that re-loads the rustls config whenever the cert
785/// or key file on disk changes. Uses `notify` (already a workspace dep for
786/// the dev file watcher) so we don't add anything new.
787///
788/// Coalescing: many editors write atomically by rename, which produces
789/// `Modify`+`Create` events in quick succession; we ignore that and just
790/// reload on any non-error event. The reload is itself cheap and idempotent.
791fn watch_tls_certs(
792    config: axum_server::tls_rustls::RustlsConfig,
793    cert: std::path::PathBuf,
794    key: std::path::PathBuf,
795    watch_paths: [std::path::PathBuf; 2],
796) -> std::io::Result<()> {
797    use notify::{RecursiveMode, Watcher};
798    use std::sync::mpsc::channel;
799
800    let (tx, rx) = channel::<notify::Result<notify::Event>>();
801    let mut watcher = notify::recommended_watcher(move |res| {
802        let _ = tx.send(res);
803    })
804    .map_err(|e| std::io::Error::other(format!("notify init: {e}")))?;
805
806    // Watch the parent directories — file-level watches don't survive
807    // editors that rename-on-write (vim, cargo, etc.), but directory watches
808    // catch the rename plus the new file's creation.
809    for p in &watch_paths {
810        if let Some(parent) = p.parent() {
811            watcher
812                .watch(parent, RecursiveMode::NonRecursive)
813                .map_err(|e| std::io::Error::other(format!("notify watch: {e}")))?;
814        }
815    }
816
817    let runtime = tokio::runtime::Handle::try_current().ok();
818    while let Ok(event) = rx.recv() {
819        let Ok(event) = event else { continue };
820        // Only react to events touching our cert/key files specifically — the
821        // directory watcher fires for any sibling file too.
822        let touches_us = event.paths.iter().any(|p| p == &cert || p == &key);
823        if !touches_us {
824            continue;
825        }
826        tracing::info!(
827            cert = %cert.display(),
828            key = %key.display(),
829            "tls cert change detected — reloading"
830        );
831        let cert = cert.clone();
832        let key = key.clone();
833        let config = config.clone();
834        let reload = async move {
835            if let Err(e) = config.reload_from_pem_file(&cert, &key).await {
836                tracing::warn!(error = %e, "tls reload failed");
837            } else {
838                tracing::info!("tls cert reloaded successfully");
839            }
840        };
841        if let Some(rt) = &runtime {
842            rt.spawn(reload);
843        } else {
844            // No current tokio runtime (e.g. unit test contexts) — best-effort
845            // fire-and-forget via a fresh single-threaded runtime.
846            std::thread::spawn(|| {
847                let rt = tokio::runtime::Builder::new_current_thread()
848                    .enable_all()
849                    .build();
850                if let Ok(rt) = rt {
851                    rt.block_on(reload);
852                }
853            });
854        }
855    }
856    Ok(())
857}
858
// ─── Rate limiter (Moka-backed fixed-window counter) ────────────────────────
860
861pub struct RateLimiter {
862    /// `bucket key` → `(window_end_instant, count_remaining)`.
863    state: moka::sync::Cache<String, (Instant, u32)>,
864    default_rule: Option<RateRule>,
865    route_rules: Vec<(MatchKey, RateRule)>,
866}
867
868#[derive(Clone, Copy)]
869struct RateRule {
870    count: u32,
871    window: Duration,
872}
873
874#[derive(Clone)]
875struct MatchKey {
876    method: Option<Method>,
877    path: String,
878}
879
880impl RateLimiter {
881    pub fn from_config(cfg: &RateLimitConfig) -> Self {
882        let default_rule = cfg.per_ip.as_deref().and_then(|s| {
883            crate::server_config::parse_rate(s)
884                .map(|(count, window)| RateRule { count, window })
885                .ok()
886        });
887        let route_rules = cfg
888            .routes
889            .iter()
890            .filter_map(|(spec, rate)| {
891                let (count, window) = crate::server_config::parse_rate(rate).ok()?;
892                let (method, path) = parse_route_spec(spec);
893                Some((MatchKey { method, path }, RateRule { count, window }))
894            })
895            .collect();
896
897        Self {
898            state: moka::sync::Cache::builder()
899                .max_capacity(10_000)
900                .time_to_idle(Duration::from_secs(600))
901                .build(),
902            default_rule,
903            route_rules,
904        }
905    }
906
907    fn rule_for(&self, method: &Method, path: &str) -> Option<RateRule> {
908        for (key, rule) in &self.route_rules {
909            if key.path == path && key.method.as_ref().is_none_or(|m| m == method) {
910                return Some(*rule);
911            }
912        }
913        self.default_rule
914    }
915
916    fn check(&self, bucket: &str, rule: RateRule) -> bool {
917        let now = Instant::now();
918        let mut allowed = true;
919        self.state
920            .entry(bucket.to_string())
921            .and_compute_with(|existing| match existing {
922                Some(entry) => {
923                    let (window_end, count) = entry.into_value();
924                    if now >= window_end {
925                        moka::ops::compute::Op::Put((
926                            now + rule.window,
927                            rule.count.saturating_sub(1),
928                        ))
929                    } else if count > 0 {
930                        moka::ops::compute::Op::Put((window_end, count - 1))
931                    } else {
932                        allowed = false;
933                        moka::ops::compute::Op::Put((window_end, 0))
934                    }
935                }
936                None => {
937                    moka::ops::compute::Op::Put((now + rule.window, rule.count.saturating_sub(1)))
938                }
939            });
940        allowed
941    }
942}
943
944fn parse_route_spec(spec: &str) -> (Option<Method>, String) {
945    let trimmed = spec.trim();
946    if let Some((m, p)) = trimmed.split_once(char::is_whitespace) {
947        let method = m.parse::<Method>().ok();
948        (method, p.trim().to_string())
949    } else {
950        (None, trimmed.to_string())
951    }
952}
953
954async fn rate_limit_mw(
955    limiter: Arc<RateLimiter>,
956    trusted: Arc<Vec<ipnet::IpNet>>,
957    req: Request<Body>,
958    next: Next,
959) -> Response<Body> {
960    let method = req.method().clone();
961    let path = req.uri().path().to_string();
962    let bucket = format!("{}|{}|{}", client_ip(&req, &trusted), method, path);
963
964    if let Some(rule) = limiter.rule_for(&method, &path) {
965        if !limiter.check(&bucket, rule) {
966            tracing::debug!(%method, %path, %bucket, "rate limited");
967            let mut resp = Response::new(Body::from("rate limit exceeded"));
968            *resp.status_mut() = StatusCode::TOO_MANY_REQUESTS;
969            return resp;
970        }
971    }
972    next.run(req).await
973}
974
975/// Resolve the client's IP, honoring `X-Forwarded-For` **only** when the
976/// direct TCP peer is in `trusted` (set via `[trusted_proxies] ranges` in the
977/// server config).
978///
979/// Order of precedence:
980///   1. TCP peer (`ConnectInfo<SocketAddr>`) is in a trusted CIDR ⇒ take the
981///      first hop in XFF (the original client per RFC 7239).
982///   2. Otherwise ⇒ the TCP peer is the client. XFF from an untrusted peer
983///      is ignored, so rate limits and IP rules can't be spoofed by hostile
984///      clients setting their own XFF header.
985///   3. No `ConnectInfo` extension (rare — only if the router was started
986///      without `into_make_service_with_connect_info`) ⇒ `"unknown"`.
987fn client_ip(req: &Request<Body>, trusted: &[ipnet::IpNet]) -> String {
988    let peer: Option<std::net::IpAddr> = req
989        .extensions()
990        .get::<ConnectInfo<SocketAddr>>()
991        .map(|ci| ci.0.ip());
992
993    let peer_trusted = match peer {
994        Some(addr) => !trusted.is_empty() && trusted.iter().any(|net| net.contains(&addr)),
995        None => false,
996    };
997
998    if peer_trusted {
999        if let Some(v) = req.headers().get("x-forwarded-for") {
1000            if let Ok(s) = v.to_str() {
1001                if let Some(first) = s.split(',').next() {
1002                    let candidate = first.trim();
1003                    if !candidate.is_empty() {
1004                        return candidate.to_string();
1005                    }
1006                }
1007            }
1008        }
1009    }
1010
1011    peer.map(|a| a.to_string())
1012        .unwrap_or_else(|| "unknown".into())
1013}
1014
// ─── Route timeouts, request IDs & access log ───────────────────────────────
1016
1017/// Apply a per-route timeout to a request whose path matches one of the
1018/// configured prefixes. First match wins. Requests that don't match any
1019/// rule fall through to the global `limits.request_timeout` (if any).
1020async fn route_timeout_mw(
1021    rules: Arc<Vec<RouteTimeoutRule>>,
1022    req: Request<Body>,
1023    next: Next,
1024) -> Response<Body> {
1025    let path = req.uri().path().to_string();
1026    let matching = rules.iter().find(|r| path.starts_with(&r.prefix));
1027    match matching {
1028        Some(rule) => {
1029            let timeout = rule.timeout;
1030            match tokio::time::timeout(timeout, next.run(req)).await {
1031                Ok(resp) => resp,
1032                Err(_) => {
1033                    tracing::debug!(
1034                        path,
1035                        timeout_ms = timeout.as_millis(),
1036                        "route timeout exceeded"
1037                    );
1038                    let mut resp = Response::new(Body::from("request timed out"));
1039                    *resp.status_mut() = StatusCode::REQUEST_TIMEOUT;
1040                    resp
1041                }
1042            }
1043        }
1044        None => next.run(req).await,
1045    }
1046}
1047
1048/// Generate or pass through `x-request-id`. If the inbound request carries
1049/// one, reuse it (lets upstream LBs / clients drive the trace ID); otherwise
1050/// mint a UUID v7 (time-ordered, sortable). Always echoed on the response so
1051/// log entries can be correlated by the same id from either side.
1052async fn request_id_mw(mut req: Request<Body>, next: Next) -> Response<Body> {
1053    const HEADER: &str = "x-request-id";
1054
1055    let inbound = req
1056        .headers()
1057        .get(HEADER)
1058        .and_then(|v| v.to_str().ok())
1059        .map(|s| s.to_string());
1060
1061    let request_id = inbound.unwrap_or_else(|| uuid::Uuid::now_v7().to_string());
1062
1063    if let Ok(v) = HeaderValue::from_str(&request_id) {
1064        req.headers_mut()
1065            .insert(HeaderName::from_static(HEADER), v.clone());
1066        // Run the handler with the id available in the request extensions so
1067        // downstream code (and the trace layer) can read it without parsing
1068        // headers again.
1069        req.extensions_mut().insert(RequestId(request_id.clone()));
1070
1071        let mut resp = next.run(req).await;
1072        resp.headers_mut()
1073            .insert(HeaderName::from_static(HEADER), v);
1074        resp
1075    } else {
1076        next.run(req).await
1077    }
1078}
1079
/// Typed wrapper for the per-request id, stored in `Request::extensions`
/// by the request-id middleware (`request_id_mw`). Handlers can pull it via
/// `req.extensions().get::<RequestId>()` (or via the axum
/// `Extension<RequestId>` extractor).
///
/// The inner string is the same value the middleware writes into the
/// `x-request-id` request and response headers. It is either a
/// client-supplied id the middleware accepted, or a freshly minted UUID v7.
#[derive(Debug, Clone)]
pub struct RequestId(pub String);
1086
1087async fn access_log_mw(
1088    format: AccessLogFormat,
1089    trusted: Arc<Vec<ipnet::IpNet>>,
1090    req: Request<Body>,
1091    next: Next,
1092) -> Response<Body> {
1093    let started = Instant::now();
1094    let method = req.method().clone();
1095    let path = req.uri().path().to_string();
1096    let host = req
1097        .headers()
1098        .get("host")
1099        .and_then(|v| v.to_str().ok())
1100        .unwrap_or("-")
1101        .to_string();
1102    let referer = req
1103        .headers()
1104        .get("referer")
1105        .and_then(|v| v.to_str().ok())
1106        .map(String::from);
1107    let ua = req
1108        .headers()
1109        .get("user-agent")
1110        .and_then(|v| v.to_str().ok())
1111        .map(String::from);
1112    let ip = client_ip(&req, &trusted);
1113    let request_id = req.extensions().get::<RequestId>().map(|id| id.0.clone());
1114
1115    let resp = next.run(req).await;
1116    let elapsed = started.elapsed();
1117    let status = resp.status().as_u16();
1118    let bytes = response_size(resp.headers()).unwrap_or(0);
1119
1120    match format {
1121        AccessLogFormat::Combined => {
1122            tracing::info!(
1123                target: "access_log",
1124                "{} - - \"{} {} HTTP/1.1\" {} {} \"{}\" \"{}\" {}ms",
1125                ip,
1126                method,
1127                path,
1128                status,
1129                bytes,
1130                referer.as_deref().unwrap_or("-"),
1131                ua.as_deref().unwrap_or("-"),
1132                elapsed.as_millis(),
1133            );
1134        }
1135        AccessLogFormat::Json => {
1136            tracing::info!(
1137                target: "access_log",
1138                json = %serde_json::json!({
1139                    "ip": ip,
1140                    "method": method.as_str(),
1141                    "path": path,
1142                    "host": host,
1143                    "status": status,
1144                    "bytes": bytes,
1145                    "referer": referer,
1146                    "user_agent": ua,
1147                    "duration_ms": elapsed.as_millis(),
1148                    "request_id": request_id,
1149                }),
1150                "request"
1151            );
1152        }
1153        AccessLogFormat::Off => {}
1154    }
1155    resp
1156}
1157
1158fn response_size(headers: &HeaderMap) -> Option<u64> {
1159    headers
1160        .get("content-length")
1161        .and_then(|v| v.to_str().ok())
1162        .and_then(|s| s.parse().ok())
1163}
1164
1165// ─── Rewrites ───────────────────────────────────────────────────────────────
1166
1167#[derive(Clone)]
1168struct CompiledRewrite {
1169    pattern: regex::Regex,
1170    to: String,
1171    status: Option<u16>,
1172    match_query: bool,
1173}
1174
1175struct CompiledRewrites {
1176    rules: Vec<CompiledRewrite>,
1177}
1178
1179impl CompiledRewrites {
1180    fn compile(rules: &[RewriteRule]) -> Self {
1181        let compiled = rules
1182            .iter()
1183            .filter_map(|r| match regex::Regex::new(&r.from) {
1184                Ok(pattern) => Some(CompiledRewrite {
1185                    pattern,
1186                    to: r.to.clone(),
1187                    status: r.status,
1188                    match_query: r.match_query,
1189                }),
1190                Err(e) => {
1191                    tracing::warn!(rule = %r.from, error = %e, "invalid rewrite regex, skipping");
1192                    None
1193                }
1194            })
1195            .collect();
1196        Self { rules: compiled }
1197    }
1198}
1199
1200async fn rewrite_mw(
1201    rules: Arc<CompiledRewrites>,
1202    mut req: Request<Body>,
1203    next: Next,
1204) -> Response<Body> {
1205    let path = req.uri().path().to_string();
1206    let path_and_query = req
1207        .uri()
1208        .path_and_query()
1209        .map(|p| p.as_str().to_string())
1210        .unwrap_or_else(|| path.clone());
1211
1212    let target_str_path = path.clone();
1213    let target_str_full = path_and_query.clone();
1214
1215    // First pass: apply path-only rules.
1216    let mut applied: Option<(String, Option<u16>)> = None;
1217    for rule in &rules.rules {
1218        let subject = if rule.match_query {
1219            &target_str_full
1220        } else {
1221            &target_str_path
1222        };
1223        if rule.pattern.is_match(subject) {
1224            let replaced = rule.pattern.replace(subject, rule.to.as_str()).to_string();
1225            applied = Some((replaced, rule.status));
1226            break;
1227        }
1228    }
1229
1230    let Some((new_target, status)) = applied else {
1231        return next.run(req).await;
1232    };
1233
1234    match status {
1235        Some(code @ (301 | 302 | 303 | 307 | 308)) => {
1236            let mut resp = Response::new(Body::from(format!("Redirecting to {new_target}\n")));
1237            *resp.status_mut() =
1238                StatusCode::from_u16(code).unwrap_or(StatusCode::MOVED_PERMANENTLY);
1239            if let Ok(loc) = HeaderValue::from_str(&new_target) {
1240                resp.headers_mut().insert("location", loc);
1241            }
1242            resp
1243        }
1244        _ => {
1245            // Internal rewrite: replace the URI's path-and-query.
1246            let mut parts = req.uri().clone().into_parts();
1247            if let Ok(new_pq) = new_target.parse::<axum::http::uri::PathAndQuery>() {
1248                parts.path_and_query = Some(new_pq);
1249            }
1250            if let Ok(new_uri) = axum::http::Uri::from_parts(parts) {
1251                *req.uri_mut() = new_uri;
1252            }
1253            next.run(req).await
1254        }
1255    }
1256}
1257
1258// ─── Trailing slash ─────────────────────────────────────────────────────────
1259
1260async fn trailing_slash_mw(
1261    cfg: TrailingSlashConfig,
1262    mut req: Request<Body>,
1263    next: Next,
1264) -> Response<Body> {
1265    let path = req.uri().path().to_string();
1266    if path == "/" {
1267        return next.run(req).await;
1268    }
1269
1270    let want_slash = matches!(cfg.mode, TrailingSlashMode::Always);
1271    let has_slash = path.ends_with('/');
1272
1273    if want_slash == has_slash {
1274        return next.run(req).await;
1275    }
1276
1277    let new_path = if want_slash {
1278        format!("{path}/")
1279    } else {
1280        path.trim_end_matches('/').to_string()
1281    };
1282
1283    let query = req
1284        .uri()
1285        .query()
1286        .map(|q| format!("?{q}"))
1287        .unwrap_or_default();
1288    let new_target = format!("{new_path}{query}");
1289
1290    match cfg.action {
1291        TrailingSlashAction::Redirect => {
1292            let mut resp = Response::new(Body::from(format!("Redirecting to {new_target}\n")));
1293            *resp.status_mut() = StatusCode::MOVED_PERMANENTLY;
1294            if let Ok(loc) = HeaderValue::from_str(&new_target) {
1295                resp.headers_mut().insert("location", loc);
1296            }
1297            resp
1298        }
1299        TrailingSlashAction::Rewrite => {
1300            let mut parts = req.uri().clone().into_parts();
1301            if let Ok(pq) = new_target.parse::<axum::http::uri::PathAndQuery>() {
1302                parts.path_and_query = Some(pq);
1303            }
1304            if let Ok(new_uri) = axum::http::Uri::from_parts(parts) {
1305                *req.uri_mut() = new_uri;
1306            }
1307            next.run(req).await
1308        }
1309    }
1310}
1311
1312// ─── Custom error pages ─────────────────────────────────────────────────────
1313
/// Custom error-page bodies produced by `load_error_pages`, keyed by HTTP
/// status code. Each value holds the file's contents and the `content-type`
/// to serve it with, as chosen by `guess_content_type` from the file
/// extension.
struct LoadedErrorPages {
    by_status: std::collections::HashMap<u16, (String, &'static str)>,
}
1317
1318fn load_error_pages(
1319    raw: &std::collections::BTreeMap<String, std::path::PathBuf>,
1320) -> LoadedErrorPages {
1321    let mut by_status = std::collections::HashMap::new();
1322    for (key, path) in raw {
1323        let Ok(code) = key.parse::<u16>() else {
1324            tracing::warn!(key, "error_pages: invalid status code, skipping");
1325            continue;
1326        };
1327        let body = match std::fs::read_to_string(path) {
1328            Ok(s) => s,
1329            Err(e) => {
1330                tracing::warn!(?path, ?e, "error_pages: failed to read file, skipping");
1331                continue;
1332            }
1333        };
1334        let content_type = guess_content_type(path);
1335        by_status.insert(code, (body, content_type));
1336    }
1337    LoadedErrorPages { by_status }
1338}
1339
/// Map a file's extension to the `content-type` used when serving it as an
/// error page.
///
/// Matching is ASCII case-insensitive, so `404.HTML` is served as HTML.
/// Unknown, missing, or non-UTF-8 extensions fall back to UTF-8 plain text.
fn guess_content_type(path: &std::path::Path) -> &'static str {
    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase);
    match ext.as_deref() {
        Some("html" | "htm") => "text/html; charset=utf-8",
        Some("json") => "application/json",
        _ => "text/plain; charset=utf-8",
    }
}
1348
1349async fn error_pages_mw(
1350    pages: Arc<LoadedErrorPages>,
1351    req: Request<Body>,
1352    next: Next,
1353) -> Response<Body> {
1354    let resp = next.run(req).await;
1355    let status = resp.status().as_u16();
1356
1357    let Some((body, ctype)) = pages.by_status.get(&status) else {
1358        return resp;
1359    };
1360
1361    let mut out = Response::new(Body::from(body.clone()));
1362    *out.status_mut() = resp.status();
1363    if let Ok(ct) = HeaderValue::from_str(ctype) {
1364        out.headers_mut().insert("content-type", ct);
1365    }
1366    // Preserve a couple of useful headers from the original response.
1367    for h in ["cache-control", "x-request-id"] {
1368        if let Some(v) = resp.headers().get(h) {
1369            out.headers_mut().insert(h, v.clone());
1370        }
1371    }
1372    out
1373}
1374
1375// ─── Reverse proxy ──────────────────────────────────────────────────────────
1376
1377#[derive(Clone)]
1378struct CompiledProxy {
1379    prefix: String,
1380    upstream: String,
1381    strip_prefix: bool,
1382    preserve_host: bool,
1383    timeout: Duration,
1384    retries: u8,
1385}
1386
1387struct CompiledProxies {
1388    rules: Vec<CompiledProxy>,
1389    client: reqwest::Client,
1390}
1391
1392impl CompiledProxies {
1393    fn compile(rules: &[ProxyRule]) -> Self {
1394        let mut compiled: Vec<CompiledProxy> = rules
1395            .iter()
1396            .map(|r| CompiledProxy {
1397                prefix: r.prefix.clone(),
1398                upstream: r.upstream.trim_end_matches('/').to_string(),
1399                strip_prefix: r.strip_prefix,
1400                preserve_host: r.preserve_host,
1401                timeout: r.timeout.unwrap_or(Duration::from_secs(30)),
1402                retries: r.retries,
1403            })
1404            .collect();
1405        // Longest prefix first so `/api/v2/users` beats `/api`.
1406        compiled.sort_by_key(|r| std::cmp::Reverse(r.prefix.len()));
1407
1408        let client = reqwest::Client::builder()
1409            .redirect(reqwest::redirect::Policy::none())
1410            .build()
1411            .unwrap_or_else(|_| reqwest::Client::new());
1412
1413        Self {
1414            rules: compiled,
1415            client,
1416        }
1417    }
1418
1419    fn matching(&self, path: &str) -> Option<&CompiledProxy> {
1420        self.rules.iter().find(|r| path.starts_with(&r.prefix))
1421    }
1422}
1423
1424async fn proxy_mw(proxies: Arc<CompiledProxies>, req: Request<Body>, next: Next) -> Response<Body> {
1425    let path = req.uri().path().to_string();
1426    let Some(rule) = proxies.matching(&path) else {
1427        return next.run(req).await;
1428    };
1429    let rule = rule.clone();
1430
1431    match proxy_forward(&proxies.client, &rule, req).await {
1432        Ok(resp) => resp,
1433        Err(e) => {
1434            tracing::warn!(?e, prefix = %rule.prefix, upstream = %rule.upstream, "proxy error");
1435            let mut resp = Response::new(Body::from(format!("upstream error: {e}")));
1436            *resp.status_mut() = StatusCode::BAD_GATEWAY;
1437            resp
1438        }
1439    }
1440}
1441
1442async fn proxy_forward(
1443    client: &reqwest::Client,
1444    rule: &CompiledProxy,
1445    req: Request<Body>,
1446) -> Result<Response<Body>, String> {
1447    let (parts, body) = req.into_parts();
1448    let body_bytes = axum::body::to_bytes(body, usize::MAX)
1449        .await
1450        .map_err(|e| format!("body read: {e}"))?;
1451
1452    let original_path = parts.uri.path();
1453    let upstream_path = if rule.strip_prefix {
1454        original_path
1455            .strip_prefix(&rule.prefix)
1456            .unwrap_or(original_path)
1457    } else {
1458        original_path
1459    };
1460    let upstream_path = if upstream_path.is_empty() {
1461        "/"
1462    } else {
1463        upstream_path
1464    };
1465    let query = parts
1466        .uri
1467        .query()
1468        .map(|q| format!("?{q}"))
1469        .unwrap_or_default();
1470    let upstream_url = format!("{}{}{}", rule.upstream, upstream_path, query);
1471
1472    let method = parts.method.clone();
1473    let mut last_err = String::new();
1474    for attempt in 0..=rule.retries {
1475        let mut request = client
1476            .request(
1477                reqwest::Method::from_bytes(method.as_str().as_bytes())
1478                    .unwrap_or(reqwest::Method::GET),
1479                &upstream_url,
1480            )
1481            .timeout(rule.timeout)
1482            .body(body_bytes.clone());
1483
1484        for (name, value) in parts.headers.iter() {
1485            // Hop-by-hop headers per RFC 7230 §6.1 — skip.
1486            let n = name.as_str().to_ascii_lowercase();
1487            if matches!(
1488                n.as_str(),
1489                "connection"
1490                    | "keep-alive"
1491                    | "proxy-authenticate"
1492                    | "proxy-authorization"
1493                    | "te"
1494                    | "trailers"
1495                    | "transfer-encoding"
1496                    | "upgrade"
1497                    | "content-length"
1498            ) {
1499                continue;
1500            }
1501            if !rule.preserve_host && n == "host" {
1502                continue;
1503            }
1504            if let Ok(v) = value.to_str() {
1505                request = request.header(name.as_str(), v);
1506            }
1507        }
1508
1509        // X-Forwarded-* headers — useful for upstreams.
1510        if let Some(host) = parts.headers.get("host").and_then(|v| v.to_str().ok()) {
1511            request = request.header("x-forwarded-host", host);
1512        }
1513        request = request.header("x-forwarded-proto", "http");
1514
1515        match request.send().await {
1516            Ok(resp) => return upstream_to_axum(resp).await,
1517            Err(e) => {
1518                last_err = format!("attempt {} → {e}", attempt + 1);
1519                tracing::debug!(error = %e, attempt, "proxy retry");
1520                continue;
1521            }
1522        }
1523    }
1524    Err(last_err)
1525}
1526
1527// ─── CORS ───────────────────────────────────────────────────────────────────
1528
1529async fn cors_mw(cfg: Arc<CorsConfig>, req: Request<Body>, next: Next) -> Response<Body> {
1530    let origin = req
1531        .headers()
1532        .get("origin")
1533        .and_then(|v| v.to_str().ok())
1534        .map(String::from);
1535
1536    let is_allowed_origin = origin.as_deref().is_some_and(|o| {
1537        cfg.allow_origins
1538            .iter()
1539            .any(|allowed| allowed == "*" || allowed == o)
1540    });
1541
1542    // Preflight
1543    if req.method() == Method::OPTIONS && origin.is_some() {
1544        let mut resp = Response::new(Body::empty());
1545        *resp.status_mut() = StatusCode::NO_CONTENT;
1546        apply_cors_headers(
1547            resp.headers_mut(),
1548            &cfg,
1549            origin.as_deref(),
1550            is_allowed_origin,
1551        );
1552        return resp;
1553    }
1554
1555    let mut resp = next.run(req).await;
1556    apply_cors_headers(
1557        resp.headers_mut(),
1558        &cfg,
1559        origin.as_deref(),
1560        is_allowed_origin,
1561    );
1562    resp
1563}
1564
1565fn apply_cors_headers(
1566    headers: &mut HeaderMap,
1567    cfg: &CorsConfig,
1568    origin: Option<&str>,
1569    is_allowed_origin: bool,
1570) {
1571    if !is_allowed_origin {
1572        return;
1573    }
1574    if let Some(origin) = origin {
1575        if let Ok(v) = HeaderValue::from_str(origin) {
1576            headers.insert("access-control-allow-origin", v);
1577        }
1578        headers.insert("vary", HeaderValue::from_static("Origin"));
1579    } else if cfg.allow_origins.iter().any(|o| o == "*") {
1580        headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
1581    }
1582
1583    let methods = if cfg.allow_methods.is_empty() {
1584        "GET, POST, PUT, PATCH, DELETE, OPTIONS".to_string()
1585    } else {
1586        cfg.allow_methods.join(", ")
1587    };
1588    if let Ok(v) = HeaderValue::from_str(&methods) {
1589        headers.insert("access-control-allow-methods", v);
1590    }
1591
1592    let allow_headers = if cfg.allow_headers.is_empty() {
1593        "Content-Type, Authorization, X-CSRF-TOKEN, X-Requested-With".to_string()
1594    } else {
1595        cfg.allow_headers.join(", ")
1596    };
1597    if let Ok(v) = HeaderValue::from_str(&allow_headers) {
1598        headers.insert("access-control-allow-headers", v);
1599    }
1600
1601    if !cfg.expose_headers.is_empty() {
1602        if let Ok(v) = HeaderValue::from_str(&cfg.expose_headers.join(", ")) {
1603            headers.insert("access-control-expose-headers", v);
1604        }
1605    }
1606
1607    if cfg.allow_credentials {
1608        headers.insert(
1609            "access-control-allow-credentials",
1610            HeaderValue::from_static("true"),
1611        );
1612    }
1613
1614    if let Some(max_age) = cfg.max_age {
1615        if let Ok(v) = HeaderValue::from_str(&max_age.as_secs().to_string()) {
1616            headers.insert("access-control-max-age", v);
1617        }
1618    }
1619}
1620
1621// ─── IP allow/deny ──────────────────────────────────────────────────────────
1622
1623async fn ip_rules_mw(
1624    rules: Arc<Vec<IpRule>>,
1625    trusted: Arc<Vec<ipnet::IpNet>>,
1626    req: Request<Body>,
1627    next: Next,
1628) -> Response<Body> {
1629    let path = req.uri().path().to_string();
1630    let ip_str = client_ip(&req, &trusted);
1631    let ip = ip_str.parse::<std::net::IpAddr>().ok();
1632
1633    for rule in rules.iter() {
1634        if !path.starts_with(&rule.prefix) {
1635            continue;
1636        }
1637        let matches_range = ip
1638            .map(|addr| rule.ranges.iter().any(|net| net.contains(&addr)))
1639            .unwrap_or(false);
1640        let allowed = match rule.action {
1641            IpAction::Allow => matches_range,
1642            IpAction::Deny => !matches_range,
1643        };
1644        if !allowed {
1645            tracing::debug!(path, ip = %ip_str, "ip rule denied request");
1646            let mut resp = Response::new(Body::from("forbidden"));
1647            *resp.status_mut() = StatusCode::FORBIDDEN;
1648            return resp;
1649        }
1650        // First matching prefix wins.
1651        break;
1652    }
1653
1654    next.run(req).await
1655}
1656
1657// ─── HTTP Basic Auth ────────────────────────────────────────────────────────
1658
1659use base64::engine::general_purpose::STANDARD as B64;
1660use base64::Engine as _;
1661
/// Basic-auth rules, each paired with its credentials pre-split into
/// `(username, password)` pairs by `compile_basic_auth`.
struct CompiledBasicAuth {
    rules: Vec<(BasicAuthRule, Vec<(String, String)>)>,
}
1665
1666fn compile_basic_auth(rules: &[BasicAuthRule]) -> CompiledBasicAuth {
1667    let compiled = rules
1668        .iter()
1669        .map(|r| {
1670            let creds = r
1671                .credentials
1672                .iter()
1673                .filter_map(|c| {
1674                    c.split_once(':')
1675                        .map(|(u, p)| (u.to_string(), p.to_string()))
1676                })
1677                .collect();
1678            (r.clone(), creds)
1679        })
1680        .collect();
1681    CompiledBasicAuth { rules: compiled }
1682}
1683
1684async fn basic_auth_mw(
1685    rules: Arc<CompiledBasicAuth>,
1686    req: Request<Body>,
1687    next: Next,
1688) -> Response<Body> {
1689    let path = req.uri().path().to_string();
1690    for (rule, creds) in &rules.rules {
1691        if !path.starts_with(&rule.prefix) {
1692            continue;
1693        }
1694        let supplied = req
1695            .headers()
1696            .get("authorization")
1697            .and_then(|v| v.to_str().ok())
1698            .and_then(|s| s.strip_prefix("Basic "))
1699            .and_then(|b64| B64.decode(b64).ok())
1700            .and_then(|bytes| String::from_utf8(bytes).ok())
1701            .and_then(|pair| {
1702                pair.split_once(':')
1703                    .map(|(u, p)| (u.to_string(), p.to_string()))
1704            });
1705
1706        let ok = supplied
1707            .as_ref()
1708            .map(|(u, p)| creds.iter().any(|(cu, cp)| cu == u && cp == p))
1709            .unwrap_or(false);
1710
1711        if ok {
1712            return next.run(req).await;
1713        }
1714
1715        let challenge = format!("Basic realm=\"{}\"", rule.realm);
1716        let mut resp = Response::new(Body::from("authentication required"));
1717        *resp.status_mut() = StatusCode::UNAUTHORIZED;
1718        if let Ok(v) = HeaderValue::from_str(&challenge) {
1719            resp.headers_mut().insert("www-authenticate", v);
1720        }
1721        return resp;
1722    }
1723    next.run(req).await
1724}
1725
1726async fn upstream_to_axum(resp: reqwest::Response) -> Result<Response<Body>, String> {
1727    let status = resp.status();
1728    let headers = resp.headers().clone();
1729    let bytes = resp
1730        .bytes()
1731        .await
1732        .map_err(|e| format!("upstream body: {e}"))?;
1733    let mut out = Response::new(Body::from(bytes));
1734    *out.status_mut() =
1735        axum::http::StatusCode::from_u16(status.as_u16()).unwrap_or(axum::http::StatusCode::OK);
1736    for (name, value) in headers.iter() {
1737        let n = name.as_str().to_ascii_lowercase();
1738        if matches!(
1739            n.as_str(),
1740            "connection"
1741                | "keep-alive"
1742                | "proxy-authenticate"
1743                | "proxy-authorization"
1744                | "te"
1745                | "trailers"
1746                | "transfer-encoding"
1747                | "upgrade"
1748        ) {
1749            continue;
1750        }
1751        if let Ok(v) = HeaderValue::from_bytes(value.as_bytes()) {
1752            if let Ok(name) = HeaderName::from_bytes(name.as_str().as_bytes()) {
1753                out.headers_mut().append(name, v);
1754            }
1755        }
1756    }
1757    Ok(out)
1758}
1759
#[cfg(test)]
mod tests {
    // Coverage for `client_ip`: `X-Forwarded-For` is consulted only when the
    // directly-connected peer sits inside a trusted proxy range.
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// IPv4 socket address for the simulated TCP peer.
    fn v4_peer(octets: [u8; 4], port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::from(octets)), port)
    }

    /// Parse CIDR strings into trusted-proxy ranges.
    fn trusted_ranges(cidrs: &[&str]) -> Vec<ipnet::IpNet> {
        cidrs.iter().map(|c| c.parse().unwrap()).collect()
    }

    /// Bare request with an optional `x-forwarded-for` header and an optional
    /// `ConnectInfo` extension describing the peer.
    fn make_req(peer: Option<SocketAddr>, xff: Option<&str>) -> Request<Body> {
        let builder = match xff {
            Some(forwarded) => Request::builder().header("x-forwarded-for", forwarded),
            None => Request::builder(),
        };
        let mut request = builder.body(Body::empty()).unwrap();
        if let Some(addr) = peer {
            request.extensions_mut().insert(ConnectInfo(addr));
        }
        request
    }

    #[test]
    fn xff_ignored_when_peer_is_not_trusted() {
        // A client connecting directly cannot spoof its address through XFF.
        let req = make_req(Some(v4_peer([203, 0, 113, 5], 12345)), Some("198.51.100.1"));
        let trusted = trusted_ranges(&["10.0.0.0/8"]);

        assert_eq!(client_ip(&req, &trusted), "203.0.113.5");
    }

    #[test]
    fn xff_honored_when_peer_is_a_trusted_proxy() {
        // A load balancer inside 10/8 relays the original client address.
        let req = make_req(
            Some(v4_peer([10, 0, 0, 5], 443)),
            Some("198.51.100.1, 10.0.0.5"),
        );
        let trusted = trusted_ranges(&["10.0.0.0/8"]);

        assert_eq!(client_ip(&req, &trusted), "198.51.100.1");
    }

    #[test]
    fn empty_trusted_list_means_xff_is_never_honored() {
        // With no trusted ranges, even a private-range peer's XFF is disregarded.
        let req = make_req(Some(v4_peer([10, 0, 0, 5], 443)), Some("198.51.100.1"));
        let trusted = trusted_ranges(&[]);

        assert_eq!(client_ip(&req, &trusted), "10.0.0.5");
    }

    #[test]
    fn no_xff_falls_back_to_peer() {
        let req = make_req(Some(v4_peer([10, 0, 0, 5], 443)), None);
        let trusted = trusted_ranges(&["10.0.0.0/8"]);

        assert_eq!(client_ip(&req, &trusted), "10.0.0.5");
    }

    #[test]
    fn missing_connect_info_returns_unknown() {
        // A router served without `into_make_service_with_connect_info` lacks
        // the extension: no panic, and XFF must not be believed.
        let req = make_req(None, Some("198.51.100.1"));
        let trusted = trusted_ranges(&["10.0.0.0/8"]);

        assert_eq!(client_ip(&req, &trusted), "unknown");
    }

    #[test]
    fn xff_with_whitespace_and_multiple_hops_picks_first() {
        let req = make_req(
            Some(v4_peer([10, 0, 0, 5], 443)),
            Some("  198.51.100.1 ,10.0.0.5, 10.0.0.7"),
        );
        let trusted = trusted_ranges(&["10.0.0.0/8"]);

        assert_eq!(client_ip(&req, &trusted), "198.51.100.1");
    }
}