Skip to main content

anvil_core/
server.rs

1//! Production HTTP serving: applies the `ServerConfig` to an `axum::Router` and
2//! starts it on the configured bind addr, with optional TLS via `axum-server`.
3
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use axum::body::Body;
9use axum::http::{HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode};
10use axum::middleware::Next;
11use axum::Router as AxumRouter;
12use tower_http::compression::predicate::SizeAbove;
13use tower_http::compression::CompressionLayer;
14use tower_http::limit::RequestBodyLimitLayer;
15use tower_http::services::ServeDir;
16use tower_http::set_header::SetResponseHeaderLayer;
17use tower_http::timeout::TimeoutLayer;
18use tower_http::trace::TraceLayer;
19
20use crate::container::Container;
21use crate::server_config::{
22    AccessLogFormat, BasicAuthRule, CorsConfig, HstsConfig, IpAction, IpRule, ProxyRule,
23    RateLimitConfig, RewriteRule, ServerConfig, StaticMount, TlsConfig, TrailingSlashAction,
24    TrailingSlashConfig, TrailingSlashMode,
25};
26use crate::Error;
27
28/// Apply every layer the server config calls for to the user's web router,
29/// then merge any static-file mounts. Returns a ready-to-serve `axum::Router`.
30pub fn apply_layers(web: AxumRouter<Container>, cfg: &ServerConfig) -> AxumRouter<Container> {
31    let mut router = web;
32
33    // Static file mounts run BEFORE wrapping with body/timeout/compression — they
34    // serve from disk and don't need request body parsing.
35    for (prefix, mount) in &cfg.static_files {
36        router = mount_static(router, prefix, mount);
37    }
38
39    // Compose request-side layers.
40    let body_max = cfg.limits.body_max as usize;
41    router = router
42        .layer(RequestBodyLimitLayer::new(body_max))
43        .layer(TraceLayer::new_for_http());
44
45    if let Some(timeout) = cfg.limits.request_timeout {
46        router = router.layer(TimeoutLayer::new(timeout));
47    }
48
49    // Virtual-host gating: only accept requests whose Host header matches a
50    // configured `server_name`. Empty `server_name` = match-all.
51    if !cfg.server_name.is_empty() {
52        let allowed = cfg.server_name.clone();
53        router = router.layer(axum::middleware::from_fn(
54            move |req: Request<Body>, next: Next| {
55                let allowed = allowed.clone();
56                async move { host_match_mw(allowed, req, next).await }
57            },
58        ));
59    }
60
61    // IP allow/deny + basic auth — apply first so unauthorized requests don't
62    // touch any other layer.
63    if !cfg.ip_rules.is_empty() {
64        let rules = Arc::new(cfg.ip_rules.clone());
65        let rules_clone = rules.clone();
66        router = router.layer(axum::middleware::from_fn(
67            move |req: Request<Body>, next: Next| {
68                let rules = rules_clone.clone();
69                async move { ip_rules_mw(rules, req, next).await }
70            },
71        ));
72    }
73    if !cfg.basic_auth.is_empty() {
74        let rules = Arc::new(compile_basic_auth(&cfg.basic_auth));
75        let rules_clone = rules.clone();
76        router = router.layer(axum::middleware::from_fn(
77            move |req: Request<Body>, next: Next| {
78                let rules = rules_clone.clone();
79                async move { basic_auth_mw(rules, req, next).await }
80            },
81        ));
82    }
83
84    // CORS — apply early. tower-http's CorsLayer would be cleaner, but we want
85    // full TOML control without depending on tower-http's CORS feature spec.
86    if cfg.cors.enabled {
87        let cors = Arc::new(cfg.cors.clone());
88        let cors_clone = cors.clone();
89        router = router.layer(axum::middleware::from_fn(
90            move |req: Request<Body>, next: Next| {
91                let cors = cors_clone.clone();
92                async move { cors_mw(cors, req, next).await }
93            },
94        ));
95    }
96
97    // Reverse-proxy rules — apply BEFORE rewrites so the user can rewrite
98    // upstream-bound requests too.
99    if !cfg.proxies.is_empty() {
100        let proxies = Arc::new(CompiledProxies::compile(&cfg.proxies));
101        let proxies_clone = proxies.clone();
102        router = router.layer(axum::middleware::from_fn(
103            move |req: Request<Body>, next: Next| {
104                let proxies = proxies_clone.clone();
105                async move { proxy_mw(proxies, req, next).await }
106            },
107        ));
108    }
109
110    // Rewrites — apply early so they see the request before other layers.
111    if !cfg.rewrites.is_empty() {
112        let compiled = Arc::new(CompiledRewrites::compile(&cfg.rewrites));
113        let compiled_clone = compiled.clone();
114        router = router.layer(axum::middleware::from_fn(
115            move |req: Request<Body>, next: Next| {
116                let rules = compiled_clone.clone();
117                async move { rewrite_mw(rules, req, next).await }
118            },
119        ));
120    }
121
122    // Trailing-slash policy.
123    if cfg.trailing_slash.mode != TrailingSlashMode::Ignore {
124        let ts = cfg.trailing_slash.clone();
125        router = router.layer(axum::middleware::from_fn(
126            move |req: Request<Body>, next: Next| {
127                let ts = ts.clone();
128                async move { trailing_slash_mw(ts, req, next).await }
129            },
130        ));
131    }
132
133    // Custom error pages: intercept responses with matching status codes and
134    // substitute the configured file contents.
135    if !cfg.error_pages.is_empty() {
136        let pages = Arc::new(load_error_pages(&cfg.error_pages));
137        let pages_clone = pages.clone();
138        router = router.layer(axum::middleware::from_fn(
139            move |req: Request<Body>, next: Next| {
140                let pages = pages_clone.clone();
141                async move { error_pages_mw(pages, req, next).await }
142            },
143        ));
144    }
145
146    // HSTS header for HTTPS responses.
147    if cfg.tls.is_some() && cfg.hsts.enabled {
148        if let Some(value) = build_hsts_header(&cfg.hsts) {
149            router = router.layer(SetResponseHeaderLayer::if_not_present(
150                HeaderName::from_static("strict-transport-security"),
151                value,
152            ));
153        }
154    }
155
156    if cfg.compression.enabled {
157        // tower-http's `CompressionLayer` selects the encoding based on the
158        // client's `Accept-Encoding` header; we just toggle the layer on and
159        // gate via the min-size predicate. Per-algorithm disable lives on the
160        // un-parameterized layer, so we apply it before `compress_when`.
161        let min_size = u16::try_from(cfg.compression.min_size).unwrap_or(u16::MAX);
162        let mut layer = CompressionLayer::new();
163        if !cfg
164            .compression
165            .algorithms
166            .iter()
167            .any(|a| a.eq_ignore_ascii_case("gzip"))
168        {
169            layer = layer.no_gzip();
170        }
171        if !cfg
172            .compression
173            .algorithms
174            .iter()
175            .any(|a| a.eq_ignore_ascii_case("br") || a.eq_ignore_ascii_case("brotli"))
176        {
177            layer = layer.no_br();
178        }
179        if !cfg
180            .compression
181            .algorithms
182            .iter()
183            .any(|a| a.eq_ignore_ascii_case("deflate"))
184        {
185            layer = layer.no_deflate();
186        }
187        let layer = layer.compress_when(SizeAbove::new(min_size));
188        router = router.layer(layer);
189    }
190
191    if cfg.rate_limit.per_ip.is_some() || !cfg.rate_limit.routes.is_empty() {
192        let limiter = Arc::new(RateLimiter::from_config(&cfg.rate_limit));
193        let limiter_clone = limiter.clone();
194        router = router.layer(axum::middleware::from_fn(
195            move |req: Request<Body>, next: Next| {
196                let limiter = limiter_clone.clone();
197                async move { rate_limit_mw(limiter, req, next).await }
198            },
199        ));
200    }
201
202    if matches!(
203        cfg.access_log.format,
204        AccessLogFormat::Combined | AccessLogFormat::Json
205    ) {
206        let format = cfg.access_log.format;
207        router =
208            router.layer(axum::middleware::from_fn(
209                move |req: Request<Body>, next: Next| async move {
210                    access_log_mw(format, req, next).await
211                },
212            ));
213    }
214
215    router
216}
217
218fn mount_static(
219    router: AxumRouter<Container>,
220    prefix: &str,
221    mount: &StaticMount,
222) -> AxumRouter<Container> {
223    // Note: `ranges` is reserved for a future version of tower-http that exposes
224    // per-instance range toggling. For now ranges are always enabled.
225    let _ = mount.ranges;
226    let svc = ServeDir::new(&mount.dir);
227
228    let nested = AxumRouter::<Container>::new().nest_service(prefix, svc);
229    let nested = if let Some(cache) = mount.cache {
230        let value = HeaderValue::from_str(&format!("public, max-age={}", cache.as_secs()))
231            .unwrap_or_else(|_| HeaderValue::from_static("public"));
232        nested.layer(SetResponseHeaderLayer::if_not_present(
233            HeaderName::from_static("cache-control"),
234            value,
235        ))
236    } else {
237        nested
238    };
239    router.merge(nested)
240}
241
242// ─── Serve entry points ─────────────────────────────────────────────────────
243
244pub async fn serve(
245    router: AxumRouter,
246    cfg: &ServerConfig,
247    shutdown: tokio::sync::oneshot::Receiver<()>,
248) -> Result<(), Error> {
249    let addr: SocketAddr = cfg
250        .bind
251        .parse()
252        .map_err(|e| Error::Config(format!("invalid bind addr `{}`: {e}", cfg.bind)))?;
253
254    tracing::info!(%addr, tls = cfg.tls.is_some(), server_name = ?cfg.server_name, "anvil server starting");
255
256    // If a redirect-HTTP listener is configured, spawn it alongside the main listener.
257    let (shutdown_main_tx, shutdown_main_rx) = tokio::sync::oneshot::channel::<()>();
258    let (shutdown_redir_tx, shutdown_redir_rx) = tokio::sync::oneshot::channel::<()>();
259    tokio::spawn(async move {
260        let _ = shutdown.await;
261        let _ = shutdown_main_tx.send(());
262        let _ = shutdown_redir_tx.send(());
263    });
264
265    let redirect_task = cfg.redirect_http.clone().map(|redir| {
266        let target_host = redir
267            .target_host
268            .clone()
269            .or_else(|| cfg.server_name.first().cloned());
270        let permanent = redir.permanent;
271        let bind = redir.bind.clone();
272        tokio::spawn(async move {
273            if let Err(e) =
274                serve_redirect_http(&bind, target_host, permanent, shutdown_redir_rx).await
275            {
276                tracing::warn!(?e, "redirect_http listener exited with error");
277            }
278        })
279    });
280
281    let main_result = if let Some(tls) = &cfg.tls {
282        serve_tls(router, addr, tls, shutdown_main_rx).await
283    } else {
284        serve_plain(router, addr, shutdown_main_rx).await
285    };
286
287    if let Some(task) = redirect_task {
288        task.abort();
289    }
290
291    main_result
292}
293
294/// Plain-HTTP listener that 30x-redirects every request to its `https://`
295/// equivalent. Used when TLS is on and `redirect_http` is configured.
296async fn serve_redirect_http(
297    bind: &str,
298    target_host: Option<String>,
299    permanent: bool,
300    shutdown: tokio::sync::oneshot::Receiver<()>,
301) -> Result<(), Error> {
302    let addr: SocketAddr = bind
303        .parse()
304        .map_err(|e| Error::Config(format!("invalid redirect_http bind `{bind}`: {e}")))?;
305    tracing::info!(%addr, target_host = ?target_host, permanent, "http→https redirect listener");
306
307    let target_host = Arc::new(target_host);
308    let router: AxumRouter = AxumRouter::new().fallback(axum::routing::any({
309        let target_host = target_host.clone();
310        move |req: Request<Body>| {
311            let target_host = target_host.clone();
312            async move { http_redirect_handler(req, target_host, permanent).await }
313        }
314    }));
315
316    let listener = tokio::net::TcpListener::bind(addr).await?;
317    axum::serve(listener, router)
318        .with_graceful_shutdown(async move {
319            let _ = shutdown.await;
320        })
321        .await?;
322    Ok(())
323}
324
325async fn http_redirect_handler(
326    req: Request<Body>,
327    target_host: Arc<Option<String>>,
328    permanent: bool,
329) -> Response<Body> {
330    let host = target_host.as_ref().clone().unwrap_or_else(|| {
331        req.headers()
332            .get("host")
333            .and_then(|v| v.to_str().ok())
334            .map(String::from)
335            .unwrap_or_default()
336    });
337    let path_and_query = req
338        .uri()
339        .path_and_query()
340        .map(|p| p.as_str().to_string())
341        .unwrap_or_else(|| "/".to_string());
342    let location = format!("https://{host}{path_and_query}");
343
344    let status = if permanent {
345        StatusCode::MOVED_PERMANENTLY
346    } else {
347        StatusCode::FOUND
348    };
349    let mut resp = Response::new(Body::from(format!("Redirecting to {location}\n")));
350    *resp.status_mut() = status;
351    if let Ok(loc) = HeaderValue::from_str(&location) {
352        resp.headers_mut().insert("location", loc);
353    }
354    resp
355}
356
357fn build_hsts_header(cfg: &HstsConfig) -> Option<HeaderValue> {
358    let max_age = cfg.max_age.unwrap_or(Duration::from_secs(86400 * 365));
359    let mut value = format!("max-age={}", max_age.as_secs());
360    if cfg.include_subdomains {
361        value.push_str("; includeSubDomains");
362    }
363    if cfg.preload {
364        value.push_str("; preload");
365    }
366    HeaderValue::from_str(&value).ok()
367}
368
369/// Reject requests whose Host header doesn't match any configured server_name.
370async fn host_match_mw(allowed: Vec<String>, req: Request<Body>, next: Next) -> Response<Body> {
371    let host = req
372        .headers()
373        .get("host")
374        .and_then(|v| v.to_str().ok())
375        .unwrap_or("")
376        .to_string();
377
378    // Strip port for matching: "example.com:8080" → "example.com".
379    let host_no_port = host.split(':').next().unwrap_or("").to_ascii_lowercase();
380
381    if matches_any(&host_no_port, &allowed) {
382        return next.run(req).await;
383    }
384
385    tracing::debug!(host, allowed = ?allowed, "rejected host: no server_name match");
386    let mut resp = Response::new(Body::from(format!(
387        "404 not found (unknown host: {host})\n"
388    )));
389    *resp.status_mut() = StatusCode::NOT_FOUND;
390    resp
391}
392
/// Returns `true` if `host` matches at least one entry of `patterns`
/// (see `matches_pattern` for the pattern syntax).
fn matches_any(host: &str, patterns: &[String]) -> bool {
    for pattern in patterns {
        if matches_pattern(host, pattern) {
            return true;
        }
    }
    false
}

/// Match a host against a single pattern.
///
/// * `*` matches every host.
/// * `*.example.com` matches hosts ending in `.example.com` (e.g.
///   `a.example.com`, `a.b.example.com`) but not `example.com` itself.
/// * Anything else must equal the host exactly.
///
/// Only the pattern is lowercased; `host` is compared as given, so callers
/// are expected to lowercase it.
fn matches_pattern(host: &str, pattern: &str) -> bool {
    let normalized = pattern.to_ascii_lowercase();
    if normalized == "*" {
        return true;
    }
    match normalized.strip_prefix("*.") {
        // Host must be `<something>.` followed by the suffix.
        Some(suffix) => host
            .strip_suffix(suffix)
            .map_or(false, |head| head.ends_with('.')),
        None => host == normalized,
    }
}
410
411async fn serve_plain(
412    router: AxumRouter,
413    addr: SocketAddr,
414    shutdown: tokio::sync::oneshot::Receiver<()>,
415) -> Result<(), Error> {
416    let listener = tokio::net::TcpListener::bind(addr).await?;
417    axum::serve(listener, router)
418        .with_graceful_shutdown(async move {
419            let _ = shutdown.await;
420        })
421        .await?;
422    Ok(())
423}
424
425async fn serve_tls(
426    router: AxumRouter,
427    addr: SocketAddr,
428    tls: &TlsConfig,
429    shutdown: tokio::sync::oneshot::Receiver<()>,
430) -> Result<(), Error> {
431    let config = axum_server::tls_rustls::RustlsConfig::from_pem_file(&tls.cert, &tls.key)
432        .await
433        .map_err(|e| Error::Config(format!("tls load: {e}")))?;
434
435    let handle = axum_server::Handle::new();
436    let handle_for_shutdown = handle.clone();
437    tokio::spawn(async move {
438        let _ = shutdown.await;
439        handle_for_shutdown.graceful_shutdown(Some(Duration::from_secs(10)));
440    });
441
442    axum_server::bind_rustls(addr, config)
443        .handle(handle)
444        .serve(router.into_make_service())
445        .await
446        .map_err(|e| Error::Internal(format!("tls serve: {e}")))?;
447    Ok(())
448}
449
// ─── Rate limiter (Moka-backed fixed-window counter) ────────────────────────
451
452pub struct RateLimiter {
453    /// `bucket key` → `(window_end_instant, count_remaining)`.
454    state: moka::sync::Cache<String, (Instant, u32)>,
455    default_rule: Option<RateRule>,
456    route_rules: Vec<(MatchKey, RateRule)>,
457}
458
459#[derive(Clone, Copy)]
460struct RateRule {
461    count: u32,
462    window: Duration,
463}
464
465#[derive(Clone)]
466struct MatchKey {
467    method: Option<Method>,
468    path: String,
469}
470
471impl RateLimiter {
472    pub fn from_config(cfg: &RateLimitConfig) -> Self {
473        let default_rule = cfg.per_ip.as_deref().and_then(|s| {
474            crate::server_config::parse_rate(s)
475                .map(|(count, window)| RateRule { count, window })
476                .ok()
477        });
478        let route_rules = cfg
479            .routes
480            .iter()
481            .filter_map(|(spec, rate)| {
482                let (count, window) = crate::server_config::parse_rate(rate).ok()?;
483                let (method, path) = parse_route_spec(spec);
484                Some((MatchKey { method, path }, RateRule { count, window }))
485            })
486            .collect();
487
488        Self {
489            state: moka::sync::Cache::builder()
490                .max_capacity(10_000)
491                .time_to_idle(Duration::from_secs(600))
492                .build(),
493            default_rule,
494            route_rules,
495        }
496    }
497
498    fn rule_for(&self, method: &Method, path: &str) -> Option<RateRule> {
499        for (key, rule) in &self.route_rules {
500            if key.path == path && key.method.as_ref().map_or(true, |m| m == method) {
501                return Some(*rule);
502            }
503        }
504        self.default_rule
505    }
506
507    fn check(&self, bucket: &str, rule: RateRule) -> bool {
508        let now = Instant::now();
509        let mut allowed = true;
510        self.state
511            .entry(bucket.to_string())
512            .and_compute_with(|existing| match existing {
513                Some(entry) => {
514                    let (window_end, count) = entry.into_value();
515                    if now >= window_end {
516                        moka::ops::compute::Op::Put((
517                            now + rule.window,
518                            rule.count.saturating_sub(1),
519                        ))
520                    } else if count > 0 {
521                        moka::ops::compute::Op::Put((window_end, count - 1))
522                    } else {
523                        allowed = false;
524                        moka::ops::compute::Op::Put((window_end, 0))
525                    }
526                }
527                None => {
528                    moka::ops::compute::Op::Put((now + rule.window, rule.count.saturating_sub(1)))
529                }
530            });
531        allowed
532    }
533}
534
535fn parse_route_spec(spec: &str) -> (Option<Method>, String) {
536    let trimmed = spec.trim();
537    if let Some((m, p)) = trimmed.split_once(char::is_whitespace) {
538        let method = m.parse::<Method>().ok();
539        (method, p.trim().to_string())
540    } else {
541        (None, trimmed.to_string())
542    }
543}
544
545async fn rate_limit_mw(
546    limiter: Arc<RateLimiter>,
547    req: Request<Body>,
548    next: Next,
549) -> Response<Body> {
550    let method = req.method().clone();
551    let path = req.uri().path().to_string();
552    let bucket = format!("{}|{}|{}", client_ip(&req), method, path);
553
554    if let Some(rule) = limiter.rule_for(&method, &path) {
555        if !limiter.check(&bucket, rule) {
556            tracing::debug!(%method, %path, %bucket, "rate limited");
557            let mut resp = Response::new(Body::from("rate limit exceeded"));
558            *resp.status_mut() = StatusCode::TOO_MANY_REQUESTS;
559            return resp;
560        }
561    }
562    next.run(req).await
563}
564
565fn client_ip(req: &Request<Body>) -> String {
566    // Prefer `X-Forwarded-For` if a value is present — trusted-proxy filtering
567    // is intentionally skipped in v1; apps behind untrusted LBs should disable
568    // rate limiting per-IP and rely on the LB.
569    if let Some(v) = req.headers().get("x-forwarded-for") {
570        if let Ok(s) = v.to_str() {
571            if let Some(first) = s.split(',').next() {
572                return first.trim().to_string();
573            }
574        }
575    }
576    // axum exposes the SocketAddr via ConnectInfo when configured. Without it
577    // we fall back to a single global bucket so the rate limit still applies.
578    "unknown".into()
579}
580
581// ─── Access log ─────────────────────────────────────────────────────────────
582
583async fn access_log_mw(format: AccessLogFormat, req: Request<Body>, next: Next) -> Response<Body> {
584    let started = Instant::now();
585    let method = req.method().clone();
586    let path = req.uri().path().to_string();
587    let host = req
588        .headers()
589        .get("host")
590        .and_then(|v| v.to_str().ok())
591        .unwrap_or("-")
592        .to_string();
593    let referer = req
594        .headers()
595        .get("referer")
596        .and_then(|v| v.to_str().ok())
597        .map(String::from);
598    let ua = req
599        .headers()
600        .get("user-agent")
601        .and_then(|v| v.to_str().ok())
602        .map(String::from);
603    let ip = client_ip(&req);
604
605    let resp = next.run(req).await;
606    let elapsed = started.elapsed();
607    let status = resp.status().as_u16();
608    let bytes = response_size(resp.headers()).unwrap_or(0);
609
610    match format {
611        AccessLogFormat::Combined => {
612            tracing::info!(
613                target: "access_log",
614                "{} - - \"{} {} HTTP/1.1\" {} {} \"{}\" \"{}\" {}ms",
615                ip,
616                method,
617                path,
618                status,
619                bytes,
620                referer.as_deref().unwrap_or("-"),
621                ua.as_deref().unwrap_or("-"),
622                elapsed.as_millis(),
623            );
624        }
625        AccessLogFormat::Json => {
626            tracing::info!(
627                target: "access_log",
628                json = %serde_json::json!({
629                    "ip": ip,
630                    "method": method.as_str(),
631                    "path": path,
632                    "host": host,
633                    "status": status,
634                    "bytes": bytes,
635                    "referer": referer,
636                    "user_agent": ua,
637                    "duration_ms": elapsed.as_millis(),
638                }),
639                "request"
640            );
641        }
642        AccessLogFormat::Off => {}
643    }
644    resp
645}
646
647fn response_size(headers: &HeaderMap) -> Option<u64> {
648    headers
649        .get("content-length")
650        .and_then(|v| v.to_str().ok())
651        .and_then(|s| s.parse().ok())
652}
653
654// ─── Rewrites ───────────────────────────────────────────────────────────────
655
656#[derive(Clone)]
657struct CompiledRewrite {
658    pattern: regex::Regex,
659    to: String,
660    status: Option<u16>,
661    match_query: bool,
662}
663
664struct CompiledRewrites {
665    rules: Vec<CompiledRewrite>,
666}
667
668impl CompiledRewrites {
669    fn compile(rules: &[RewriteRule]) -> Self {
670        let compiled = rules
671            .iter()
672            .filter_map(|r| match regex::Regex::new(&r.from) {
673                Ok(pattern) => Some(CompiledRewrite {
674                    pattern,
675                    to: r.to.clone(),
676                    status: r.status,
677                    match_query: r.match_query,
678                }),
679                Err(e) => {
680                    tracing::warn!(rule = %r.from, error = %e, "invalid rewrite regex, skipping");
681                    None
682                }
683            })
684            .collect();
685        Self { rules: compiled }
686    }
687}
688
689async fn rewrite_mw(
690    rules: Arc<CompiledRewrites>,
691    mut req: Request<Body>,
692    next: Next,
693) -> Response<Body> {
694    let path = req.uri().path().to_string();
695    let path_and_query = req
696        .uri()
697        .path_and_query()
698        .map(|p| p.as_str().to_string())
699        .unwrap_or_else(|| path.clone());
700
701    let target_str_path = path.clone();
702    let target_str_full = path_and_query.clone();
703
704    // First pass: apply path-only rules.
705    let mut applied: Option<(String, Option<u16>)> = None;
706    for rule in &rules.rules {
707        let subject = if rule.match_query {
708            &target_str_full
709        } else {
710            &target_str_path
711        };
712        if rule.pattern.is_match(subject) {
713            let replaced = rule.pattern.replace(subject, rule.to.as_str()).to_string();
714            applied = Some((replaced, rule.status));
715            break;
716        }
717    }
718
719    let Some((new_target, status)) = applied else {
720        return next.run(req).await;
721    };
722
723    match status {
724        Some(code @ (301 | 302 | 303 | 307 | 308)) => {
725            let mut resp = Response::new(Body::from(format!("Redirecting to {new_target}\n")));
726            *resp.status_mut() =
727                StatusCode::from_u16(code).unwrap_or(StatusCode::MOVED_PERMANENTLY);
728            if let Ok(loc) = HeaderValue::from_str(&new_target) {
729                resp.headers_mut().insert("location", loc);
730            }
731            resp
732        }
733        _ => {
734            // Internal rewrite: replace the URI's path-and-query.
735            let mut parts = req.uri().clone().into_parts();
736            if let Ok(new_pq) = new_target.parse::<axum::http::uri::PathAndQuery>() {
737                parts.path_and_query = Some(new_pq);
738            }
739            if let Ok(new_uri) = axum::http::Uri::from_parts(parts) {
740                *req.uri_mut() = new_uri;
741            }
742            next.run(req).await
743        }
744    }
745}
746
747// ─── Trailing slash ─────────────────────────────────────────────────────────
748
749async fn trailing_slash_mw(
750    cfg: TrailingSlashConfig,
751    mut req: Request<Body>,
752    next: Next,
753) -> Response<Body> {
754    let path = req.uri().path().to_string();
755    if path == "/" {
756        return next.run(req).await;
757    }
758
759    let want_slash = matches!(cfg.mode, TrailingSlashMode::Always);
760    let has_slash = path.ends_with('/');
761
762    if want_slash == has_slash {
763        return next.run(req).await;
764    }
765
766    let new_path = if want_slash {
767        format!("{path}/")
768    } else {
769        path.trim_end_matches('/').to_string()
770    };
771
772    let query = req
773        .uri()
774        .query()
775        .map(|q| format!("?{q}"))
776        .unwrap_or_default();
777    let new_target = format!("{new_path}{query}");
778
779    match cfg.action {
780        TrailingSlashAction::Redirect => {
781            let mut resp = Response::new(Body::from(format!("Redirecting to {new_target}\n")));
782            *resp.status_mut() = StatusCode::MOVED_PERMANENTLY;
783            if let Ok(loc) = HeaderValue::from_str(&new_target) {
784                resp.headers_mut().insert("location", loc);
785            }
786            resp
787        }
788        TrailingSlashAction::Rewrite => {
789            let mut parts = req.uri().clone().into_parts();
790            if let Ok(pq) = new_target.parse::<axum::http::uri::PathAndQuery>() {
791                parts.path_and_query = Some(pq);
792            }
793            if let Ok(new_uri) = axum::http::Uri::from_parts(parts) {
794                *req.uri_mut() = new_uri;
795            }
796            next.run(req).await
797        }
798    }
799}
800
801// ─── Custom error pages ─────────────────────────────────────────────────────
802
803struct LoadedErrorPages {
804    by_status: std::collections::HashMap<u16, (String, &'static str)>,
805}
806
807fn load_error_pages(
808    raw: &std::collections::BTreeMap<String, std::path::PathBuf>,
809) -> LoadedErrorPages {
810    let mut by_status = std::collections::HashMap::new();
811    for (key, path) in raw {
812        let Ok(code) = key.parse::<u16>() else {
813            tracing::warn!(key, "error_pages: invalid status code, skipping");
814            continue;
815        };
816        let body = match std::fs::read_to_string(path) {
817            Ok(s) => s,
818            Err(e) => {
819                tracing::warn!(?path, ?e, "error_pages: failed to read file, skipping");
820                continue;
821            }
822        };
823        let content_type = guess_content_type(path);
824        by_status.insert(code, (body, content_type));
825    }
826    LoadedErrorPages { by_status }
827}
828
/// Guess an error page's `Content-Type` from its file extension,
/// case-insensitively.
///
/// - `.html` / `.htm` → HTML
/// - `.json` → JSON
/// - everything else, including `.txt` and no extension → UTF-8 plain text
fn guess_content_type(path: &std::path::Path) -> &'static str {
    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase);
    match ext.as_deref() {
        Some("html" | "htm") => "text/html; charset=utf-8",
        Some("json") => "application/json",
        _ => "text/plain; charset=utf-8",
    }
}
837
838async fn error_pages_mw(
839    pages: Arc<LoadedErrorPages>,
840    req: Request<Body>,
841    next: Next,
842) -> Response<Body> {
843    let resp = next.run(req).await;
844    let status = resp.status().as_u16();
845
846    let Some((body, ctype)) = pages.by_status.get(&status) else {
847        return resp;
848    };
849
850    let mut out = Response::new(Body::from(body.clone()));
851    *out.status_mut() = resp.status();
852    if let Ok(ct) = HeaderValue::from_str(ctype) {
853        out.headers_mut().insert("content-type", ct);
854    }
855    // Preserve a couple of useful headers from the original response.
856    for h in ["cache-control", "x-request-id"] {
857        if let Some(v) = resp.headers().get(h) {
858            out.headers_mut().insert(h, v.clone());
859        }
860    }
861    out
862}
863
864// ─── Reverse proxy ──────────────────────────────────────────────────────────
865
866#[derive(Clone)]
867struct CompiledProxy {
868    prefix: String,
869    upstream: String,
870    strip_prefix: bool,
871    preserve_host: bool,
872    timeout: Duration,
873    retries: u8,
874}
875
876struct CompiledProxies {
877    rules: Vec<CompiledProxy>,
878    client: reqwest::Client,
879}
880
881impl CompiledProxies {
882    fn compile(rules: &[ProxyRule]) -> Self {
883        let mut compiled: Vec<CompiledProxy> = rules
884            .iter()
885            .map(|r| CompiledProxy {
886                prefix: r.prefix.clone(),
887                upstream: r.upstream.trim_end_matches('/').to_string(),
888                strip_prefix: r.strip_prefix,
889                preserve_host: r.preserve_host,
890                timeout: r.timeout.unwrap_or(Duration::from_secs(30)),
891                retries: r.retries,
892            })
893            .collect();
894        // Longest prefix first so `/api/v2/users` beats `/api`.
895        compiled.sort_by(|a, b| b.prefix.len().cmp(&a.prefix.len()));
896
897        let client = reqwest::Client::builder()
898            .redirect(reqwest::redirect::Policy::none())
899            .build()
900            .unwrap_or_else(|_| reqwest::Client::new());
901
902        Self {
903            rules: compiled,
904            client,
905        }
906    }
907
908    fn matching(&self, path: &str) -> Option<&CompiledProxy> {
909        self.rules.iter().find(|r| path.starts_with(&r.prefix))
910    }
911}
912
913async fn proxy_mw(proxies: Arc<CompiledProxies>, req: Request<Body>, next: Next) -> Response<Body> {
914    let path = req.uri().path().to_string();
915    let Some(rule) = proxies.matching(&path) else {
916        return next.run(req).await;
917    };
918    let rule = rule.clone();
919
920    match proxy_forward(&proxies.client, &rule, req).await {
921        Ok(resp) => resp,
922        Err(e) => {
923            tracing::warn!(?e, prefix = %rule.prefix, upstream = %rule.upstream, "proxy error");
924            let mut resp = Response::new(Body::from(format!("upstream error: {e}")));
925            *resp.status_mut() = StatusCode::BAD_GATEWAY;
926            resp
927        }
928    }
929}
930
931async fn proxy_forward(
932    client: &reqwest::Client,
933    rule: &CompiledProxy,
934    req: Request<Body>,
935) -> Result<Response<Body>, String> {
936    let (parts, body) = req.into_parts();
937    let body_bytes = axum::body::to_bytes(body, usize::MAX)
938        .await
939        .map_err(|e| format!("body read: {e}"))?;
940
941    let original_path = parts.uri.path();
942    let upstream_path = if rule.strip_prefix {
943        original_path
944            .strip_prefix(&rule.prefix)
945            .unwrap_or(original_path)
946    } else {
947        original_path
948    };
949    let upstream_path = if upstream_path.is_empty() {
950        "/"
951    } else {
952        upstream_path
953    };
954    let query = parts
955        .uri
956        .query()
957        .map(|q| format!("?{q}"))
958        .unwrap_or_default();
959    let upstream_url = format!("{}{}{}", rule.upstream, upstream_path, query);
960
961    let method = parts.method.clone();
962    let mut last_err = String::new();
963    for attempt in 0..=rule.retries {
964        let mut request = client
965            .request(
966                reqwest::Method::from_bytes(method.as_str().as_bytes())
967                    .unwrap_or(reqwest::Method::GET),
968                &upstream_url,
969            )
970            .timeout(rule.timeout)
971            .body(body_bytes.clone());
972
973        for (name, value) in parts.headers.iter() {
974            // Hop-by-hop headers per RFC 7230 §6.1 — skip.
975            let n = name.as_str().to_ascii_lowercase();
976            if matches!(
977                n.as_str(),
978                "connection"
979                    | "keep-alive"
980                    | "proxy-authenticate"
981                    | "proxy-authorization"
982                    | "te"
983                    | "trailers"
984                    | "transfer-encoding"
985                    | "upgrade"
986                    | "content-length"
987            ) {
988                continue;
989            }
990            if !rule.preserve_host && n == "host" {
991                continue;
992            }
993            if let Ok(v) = value.to_str() {
994                request = request.header(name.as_str(), v);
995            }
996        }
997
998        // X-Forwarded-* headers — useful for upstreams.
999        if let Some(host) = parts.headers.get("host").and_then(|v| v.to_str().ok()) {
1000            request = request.header("x-forwarded-host", host);
1001        }
1002        request = request.header("x-forwarded-proto", "http");
1003
1004        match request.send().await {
1005            Ok(resp) => return upstream_to_axum(resp).await,
1006            Err(e) => {
1007                last_err = format!("attempt {} → {e}", attempt + 1);
1008                tracing::debug!(error = %e, attempt, "proxy retry");
1009                continue;
1010            }
1011        }
1012    }
1013    Err(last_err)
1014}
1015
1016// ─── CORS ───────────────────────────────────────────────────────────────────
1017
1018async fn cors_mw(cfg: Arc<CorsConfig>, req: Request<Body>, next: Next) -> Response<Body> {
1019    let origin = req
1020        .headers()
1021        .get("origin")
1022        .and_then(|v| v.to_str().ok())
1023        .map(String::from);
1024
1025    let is_allowed_origin = origin.as_deref().is_some_and(|o| {
1026        cfg.allow_origins
1027            .iter()
1028            .any(|allowed| allowed == "*" || allowed == o)
1029    });
1030
1031    // Preflight
1032    if req.method() == Method::OPTIONS && origin.is_some() {
1033        let mut resp = Response::new(Body::empty());
1034        *resp.status_mut() = StatusCode::NO_CONTENT;
1035        apply_cors_headers(
1036            resp.headers_mut(),
1037            &cfg,
1038            origin.as_deref(),
1039            is_allowed_origin,
1040        );
1041        return resp;
1042    }
1043
1044    let mut resp = next.run(req).await;
1045    apply_cors_headers(
1046        resp.headers_mut(),
1047        &cfg,
1048        origin.as_deref(),
1049        is_allowed_origin,
1050    );
1051    resp
1052}
1053
1054fn apply_cors_headers(
1055    headers: &mut HeaderMap,
1056    cfg: &CorsConfig,
1057    origin: Option<&str>,
1058    is_allowed_origin: bool,
1059) {
1060    if !is_allowed_origin {
1061        return;
1062    }
1063    if let Some(origin) = origin {
1064        if let Ok(v) = HeaderValue::from_str(origin) {
1065            headers.insert("access-control-allow-origin", v);
1066        }
1067        headers.insert("vary", HeaderValue::from_static("Origin"));
1068    } else if cfg.allow_origins.iter().any(|o| o == "*") {
1069        headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
1070    }
1071
1072    let methods = if cfg.allow_methods.is_empty() {
1073        "GET, POST, PUT, PATCH, DELETE, OPTIONS".to_string()
1074    } else {
1075        cfg.allow_methods.join(", ")
1076    };
1077    if let Ok(v) = HeaderValue::from_str(&methods) {
1078        headers.insert("access-control-allow-methods", v);
1079    }
1080
1081    let allow_headers = if cfg.allow_headers.is_empty() {
1082        "Content-Type, Authorization, X-CSRF-TOKEN, X-Requested-With".to_string()
1083    } else {
1084        cfg.allow_headers.join(", ")
1085    };
1086    if let Ok(v) = HeaderValue::from_str(&allow_headers) {
1087        headers.insert("access-control-allow-headers", v);
1088    }
1089
1090    if !cfg.expose_headers.is_empty() {
1091        if let Ok(v) = HeaderValue::from_str(&cfg.expose_headers.join(", ")) {
1092            headers.insert("access-control-expose-headers", v);
1093        }
1094    }
1095
1096    if cfg.allow_credentials {
1097        headers.insert(
1098            "access-control-allow-credentials",
1099            HeaderValue::from_static("true"),
1100        );
1101    }
1102
1103    if let Some(max_age) = cfg.max_age {
1104        if let Ok(v) = HeaderValue::from_str(&max_age.as_secs().to_string()) {
1105            headers.insert("access-control-max-age", v);
1106        }
1107    }
1108}
1109
1110// ─── IP allow/deny ──────────────────────────────────────────────────────────
1111
1112async fn ip_rules_mw(rules: Arc<Vec<IpRule>>, req: Request<Body>, next: Next) -> Response<Body> {
1113    let path = req.uri().path().to_string();
1114    let ip_str = client_ip(&req);
1115    let ip = ip_str.parse::<std::net::IpAddr>().ok();
1116
1117    for rule in rules.iter() {
1118        if !path.starts_with(&rule.prefix) {
1119            continue;
1120        }
1121        let matches_range = ip
1122            .map(|addr| rule.ranges.iter().any(|net| net.contains(&addr)))
1123            .unwrap_or(false);
1124        let allowed = match rule.action {
1125            IpAction::Allow => matches_range,
1126            IpAction::Deny => !matches_range,
1127        };
1128        if !allowed {
1129            tracing::debug!(path, ip = %ip_str, "ip rule denied request");
1130            let mut resp = Response::new(Body::from("forbidden"));
1131            *resp.status_mut() = StatusCode::FORBIDDEN;
1132            return resp;
1133        }
1134        // First matching prefix wins.
1135        break;
1136    }
1137
1138    next.run(req).await
1139}
1140
1141// ─── HTTP Basic Auth ────────────────────────────────────────────────────────
1142
1143use base64::engine::general_purpose::STANDARD as B64;
1144use base64::Engine as _;
1145
1146struct CompiledBasicAuth {
1147    rules: Vec<(BasicAuthRule, Vec<(String, String)>)>,
1148}
1149
1150fn compile_basic_auth(rules: &[BasicAuthRule]) -> CompiledBasicAuth {
1151    let compiled = rules
1152        .iter()
1153        .map(|r| {
1154            let creds = r
1155                .credentials
1156                .iter()
1157                .filter_map(|c| {
1158                    c.split_once(':')
1159                        .map(|(u, p)| (u.to_string(), p.to_string()))
1160                })
1161                .collect();
1162            (r.clone(), creds)
1163        })
1164        .collect();
1165    CompiledBasicAuth { rules: compiled }
1166}
1167
1168async fn basic_auth_mw(
1169    rules: Arc<CompiledBasicAuth>,
1170    req: Request<Body>,
1171    next: Next,
1172) -> Response<Body> {
1173    let path = req.uri().path().to_string();
1174    for (rule, creds) in &rules.rules {
1175        if !path.starts_with(&rule.prefix) {
1176            continue;
1177        }
1178        let supplied = req
1179            .headers()
1180            .get("authorization")
1181            .and_then(|v| v.to_str().ok())
1182            .and_then(|s| s.strip_prefix("Basic "))
1183            .and_then(|b64| B64.decode(b64).ok())
1184            .and_then(|bytes| String::from_utf8(bytes).ok())
1185            .and_then(|pair| {
1186                pair.split_once(':')
1187                    .map(|(u, p)| (u.to_string(), p.to_string()))
1188            });
1189
1190        let ok = supplied
1191            .as_ref()
1192            .map(|(u, p)| creds.iter().any(|(cu, cp)| cu == u && cp == p))
1193            .unwrap_or(false);
1194
1195        if ok {
1196            return next.run(req).await;
1197        }
1198
1199        let challenge = format!("Basic realm=\"{}\"", rule.realm);
1200        let mut resp = Response::new(Body::from("authentication required"));
1201        *resp.status_mut() = StatusCode::UNAUTHORIZED;
1202        if let Ok(v) = HeaderValue::from_str(&challenge) {
1203            resp.headers_mut().insert("www-authenticate", v);
1204        }
1205        return resp;
1206    }
1207    next.run(req).await
1208}
1209
1210async fn upstream_to_axum(resp: reqwest::Response) -> Result<Response<Body>, String> {
1211    let status = resp.status();
1212    let headers = resp.headers().clone();
1213    let bytes = resp
1214        .bytes()
1215        .await
1216        .map_err(|e| format!("upstream body: {e}"))?;
1217    let mut out = Response::new(Body::from(bytes));
1218    *out.status_mut() =
1219        axum::http::StatusCode::from_u16(status.as_u16()).unwrap_or(axum::http::StatusCode::OK);
1220    for (name, value) in headers.iter() {
1221        let n = name.as_str().to_ascii_lowercase();
1222        if matches!(
1223            n.as_str(),
1224            "connection"
1225                | "keep-alive"
1226                | "proxy-authenticate"
1227                | "proxy-authorization"
1228                | "te"
1229                | "trailers"
1230                | "transfer-encoding"
1231                | "upgrade"
1232        ) {
1233            continue;
1234        }
1235        if let Ok(v) = HeaderValue::from_bytes(value.as_bytes()) {
1236            if let Ok(name) = HeaderName::from_bytes(name.as_str().as_bytes()) {
1237                out.headers_mut().append(name, v);
1238            }
1239        }
1240    }
1241    Ok(out)
1242}