1use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use axum::body::Body;
9use axum::http::{HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode};
10use axum::middleware::Next;
11use axum::Router as AxumRouter;
12use tower_http::compression::predicate::SizeAbove;
13use tower_http::compression::CompressionLayer;
14use tower_http::limit::RequestBodyLimitLayer;
15use tower_http::services::ServeDir;
16use tower_http::set_header::SetResponseHeaderLayer;
17use tower_http::timeout::TimeoutLayer;
18use tower_http::trace::TraceLayer;
19
20use crate::container::Container;
21use crate::server_config::{
22 AccessLogFormat, BasicAuthRule, CorsConfig, HstsConfig, IpAction, IpRule, ProxyRule,
23 RateLimitConfig, RewriteRule, RouteTimeoutRule, ServerConfig, StaticMount, TlsConfig,
24 TrailingSlashAction, TrailingSlashConfig, TrailingSlashMode,
25};
26use crate::Error;
27use axum::extract::ConnectInfo;
28
/// Wraps the application router with every middleware enabled in `cfg`.
///
/// Static mounts are merged first, then layers are stacked. In axum each
/// `.layer()` call wraps the router built so far, so a layer added later in
/// this function sits further out and sees the request earlier. The resulting
/// request order (outermost first) is: `strip_untrusted_headers_mw`,
/// `request_id_mw`, access log, rate limit, compression, HSTS, error pages,
/// trailing slash, rewrites, proxies, CORS, basic auth, IP rules, host match,
/// per-route timeouts, concurrency limit, global timeout, tracing, body limit.
/// Reordering the blocks below changes observable behaviour.
pub fn apply_layers(web: AxumRouter<Container>, cfg: &ServerConfig) -> AxumRouter<Container> {
    let mut router = web;

    // Static mounts are merged as routes before any layer is applied, so all
    // layers below also cover them.
    for (prefix, mount) in &cfg.static_files {
        router = mount_static(router, prefix, mount);
    }

    // NOTE(review): `as usize` truncates if `body_max` is wider than usize
    // (e.g. a u64 on a 32-bit target) — confirm the field's type.
    let body_max = cfg.limits.body_max as usize;
    router = router
        .layer(RequestBodyLimitLayer::new(body_max))
        .layer(TraceLayer::new_for_http());

    if let Some(timeout) = cfg.limits.request_timeout {
        router = router.layer(TimeoutLayer::new(timeout));
    }

    if let Some(max) = cfg.limits.max_concurrency {
        // A configured limit of 0 is clamped to 1 so requests can still run.
        let max = max.max(1) as usize;
        let limiter = Arc::new(tokio::sync::Semaphore::new(max));
        let limiter_clone = limiter.clone();
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let limiter = limiter_clone.clone();
                async move {
                    // Non-blocking: with no free permit the request is rejected
                    // with 503 rather than queued. `_permit` stays alive until
                    // the inner service has produced its response.
                    match limiter.try_acquire() {
                        Ok(_permit) => next.run(req).await,
                        Err(_) => {
                            let mut resp = Response::new(Body::from(
                                "service overloaded — too many concurrent requests",
                            ));
                            *resp.status_mut() = StatusCode::SERVICE_UNAVAILABLE;
                            resp
                        }
                    }
                }
            },
        ));
    }

    // Per-path-prefix timeouts (see `route_timeout_mw`).
    if !cfg.route_timeouts.is_empty() {
        let rules: Arc<Vec<RouteTimeoutRule>> = Arc::new(cfg.route_timeouts.clone());
        let rules_clone = rules.clone();
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let rules = rules_clone.clone();
                async move { route_timeout_mw(rules, req, next).await }
            },
        ));
    }

    // Host allow-list: `host_match_mw` answers 404 for hosts not in
    // `server_name`.
    if !cfg.server_name.is_empty() {
        let allowed = cfg.server_name.clone();
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let allowed = allowed.clone();
                async move { host_match_mw(allowed, req, next).await }
            },
        ));
    }

    // Trusted proxy networks, shared by the IP rules, rate limiter, access
    // log and header stripping to decide whether forwarding headers may be
    // believed.
    let trusted: Arc<Vec<ipnet::IpNet>> = Arc::new(cfg.trusted_proxies.ranges.clone());

    if !cfg.ip_rules.is_empty() {
        let rules = Arc::new(cfg.ip_rules.clone());
        let rules_clone = rules.clone();
        let trusted_clone = trusted.clone();
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let rules = rules_clone.clone();
                let trusted = trusted_clone.clone();
                async move { ip_rules_mw(rules, trusted, req, next).await }
            },
        ));
    }
    // Basic-auth rules are compiled once here, not per request.
    if !cfg.basic_auth.is_empty() {
        let rules = Arc::new(compile_basic_auth(&cfg.basic_auth));
        let rules_clone = rules.clone();
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let rules = rules_clone.clone();
                async move { basic_auth_mw(rules, req, next).await }
            },
        ));
    }

    if cfg.cors.enabled {
        let cors = Arc::new(cfg.cors.clone());
        let cors_clone = cors.clone();
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let cors = cors_clone.clone();
                async move { cors_mw(cors, req, next).await }
            },
        ));
    }

    // Proxy routes are compiled once here.
    if !cfg.proxies.is_empty() {
        let proxies = Arc::new(CompiledProxies::compile(&cfg.proxies));
        let proxies_clone = proxies.clone();
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let proxies = proxies_clone.clone();
                async move { proxy_mw(proxies, req, next).await }
            },
        ));
    }

    // Rewrite regexes are compiled once; invalid ones are skipped with a
    // warning by `CompiledRewrites::compile`.
    if !cfg.rewrites.is_empty() {
        let compiled = Arc::new(CompiledRewrites::compile(&cfg.rewrites));
        let compiled_clone = compiled.clone();
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let rules = compiled_clone.clone();
                async move { rewrite_mw(rules, req, next).await }
            },
        ));
    }

    if cfg.trailing_slash.mode != TrailingSlashMode::Ignore {
        let ts = cfg.trailing_slash.clone();
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let ts = ts.clone();
                async move { trailing_slash_mw(ts, req, next).await }
            },
        ));
    }

    // Error pages are loaded once, when the layer is built.
    if !cfg.error_pages.is_empty() {
        let pages = Arc::new(load_error_pages(&cfg.error_pages));
        let pages_clone = pages.clone();
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let pages = pages_clone.clone();
                async move { error_pages_mw(pages, req, next).await }
            },
        ));
    }

    // HSTS is only emitted when TLS is configured; an existing
    // Strict-Transport-Security header from the handler is left untouched.
    if cfg.tls.is_some() && cfg.hsts.enabled {
        if let Some(value) = build_hsts_header(&cfg.hsts) {
            router = router.layer(SetResponseHeaderLayer::if_not_present(
                HeaderName::from_static("strict-transport-security"),
                value,
            ));
        }
    }

    if cfg.compression.enabled {
        // Saturates to u16::MAX, the widest value this conversion allows.
        let min_size = u16::try_from(cfg.compression.min_size).unwrap_or(u16::MAX);
        // Every algorithm not listed (case-insensitively) in
        // `compression.algorithms` is disabled.
        let mut layer = CompressionLayer::new();
        if !cfg
            .compression
            .algorithms
            .iter()
            .any(|a| a.eq_ignore_ascii_case("gzip"))
        {
            layer = layer.no_gzip();
        }
        if !cfg
            .compression
            .algorithms
            .iter()
            .any(|a| a.eq_ignore_ascii_case("br") || a.eq_ignore_ascii_case("brotli"))
        {
            layer = layer.no_br();
        }
        if !cfg
            .compression
            .algorithms
            .iter()
            .any(|a| a.eq_ignore_ascii_case("deflate"))
        {
            layer = layer.no_deflate();
        }
        let layer = layer.compress_when(SizeAbove::new(min_size));
        router = router.layer(layer);
    }

    if cfg.rate_limit.per_ip.is_some() || !cfg.rate_limit.routes.is_empty() {
        let limiter = Arc::new(RateLimiter::from_config(&cfg.rate_limit));
        let limiter_clone = limiter.clone();
        let trusted_clone = trusted.clone();
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let limiter = limiter_clone.clone();
                let trusted = trusted_clone.clone();
                async move { rate_limit_mw(limiter, trusted, req, next).await }
            },
        ));
    }

    // Added after (outside) nothing but the request-id and header-stripping
    // layers, so the log sees the final response of everything inside it.
    if matches!(
        cfg.access_log.format,
        AccessLogFormat::Combined | AccessLogFormat::Json
    ) {
        let format = cfg.access_log.format;
        let trusted_clone = trusted.clone();
        router = router.layer(axum::middleware::from_fn(
            move |req: Request<Body>, next: Next| {
                let trusted = trusted_clone.clone();
                async move { access_log_mw(format, trusted, req, next).await }
            },
        ));
    }

    // Outside the access log so the `RequestId` extension is already set
    // when the log middleware reads it.
    router = router.layer(axum::middleware::from_fn(request_id_mw));

    // Outermost layer: forged forwarding headers are removed before any
    // other middleware can read them.
    let trusted_for_strip = trusted.clone();
    router = router.layer(axum::middleware::from_fn(
        move |req: Request<Body>, next: Next| {
            let trusted = trusted_for_strip.clone();
            async move { strip_untrusted_headers_mw(trusted, req, next).await }
        },
    ));

    router
}
296
/// Request headers that carry facts a proxy or TLS terminator asserts about
/// the client: forwarded address, scheme, host, port and prefix, plus
/// client-certificate details.
///
/// `strip_untrusted_headers_mw` removes every one of them when the direct peer
/// is not a trusted proxy, so a client cannot forge them.
const SENSITIVE_PROXY_HEADERS: &[&str] = &[
    "x-forwarded-for",
    "x-forwarded-proto",
    "x-forwarded-host",
    "x-forwarded-port",
    "x-forwarded-prefix",
    "x-real-ip",
    "forwarded",
    "x-tls-spki-sha256",
    "x-client-cert",
    "x-ssl-client-cert",
    "x-ssl-client-verify",
    "x-ssl-client-s-dn",
];
318
319async fn strip_untrusted_headers_mw(
320 trusted: Arc<Vec<ipnet::IpNet>>,
321 mut req: Request<Body>,
322 next: Next,
323) -> Response<Body> {
324 let peer: Option<std::net::IpAddr> = req
325 .extensions()
326 .get::<ConnectInfo<SocketAddr>>()
327 .map(|ci| ci.0.ip());
328
329 let peer_trusted = match peer {
330 Some(addr) => !trusted.is_empty() && trusted.iter().any(|net| net.contains(&addr)),
331 None => false,
332 };
333
334 if !peer_trusted {
335 let headers = req.headers_mut();
336 for &name in SENSITIVE_PROXY_HEADERS {
337 headers.remove(name);
338 }
339 }
340
341 next.run(req).await
342}
343
344fn mount_static(
345 router: AxumRouter<Container>,
346 prefix: &str,
347 mount: &StaticMount,
348) -> AxumRouter<Container> {
349 let _ = mount.ranges;
352
353 if let Some(fetcher) = crate::embedded::lookup(prefix) {
356 let cache = mount.cache;
357 let route_pat = format!("{}/*path", prefix.trim_end_matches('/'));
358 let nested = AxumRouter::<Container>::new().route(
359 &route_pat,
360 axum::routing::get(
361 move |axum::extract::Path(path): axum::extract::Path<String>,
362 headers: HeaderMap| async move {
363 serve_embedded(fetcher, cache, &path, &headers)
364 },
365 ),
366 );
367 return router.merge(nested);
368 }
369
370 let svc = ServeDir::new(&mount.dir);
371
372 let nested = AxumRouter::<Container>::new().nest_service(prefix, svc);
373 let nested = if let Some(cache) = mount.cache {
374 let value = HeaderValue::from_str(&format!("public, max-age={}", cache.as_secs()))
375 .unwrap_or_else(|_| HeaderValue::from_static("public"));
376 nested.layer(SetResponseHeaderLayer::if_not_present(
377 HeaderName::from_static("cache-control"),
378 value,
379 ))
380 } else {
381 nested
382 };
383 router.merge(nested)
384}
385
386fn serve_embedded(
390 fetcher: crate::embedded::EmbeddedAssetFetcher,
391 cache: Option<Duration>,
392 path: &str,
393 headers: &HeaderMap,
394) -> Response<Body> {
395 let asset = match fetcher(path) {
396 Some(a) => a,
397 None => return not_found(),
398 };
399
400 if let (Some(client_tag), Some(asset_tag)) = (
401 headers
402 .get(axum::http::header::IF_NONE_MATCH)
403 .and_then(|v| v.to_str().ok()),
404 asset.etag.as_deref(),
405 ) {
406 if etag_matches(client_tag, asset_tag) {
407 let mut resp = Response::builder()
408 .status(StatusCode::NOT_MODIFIED)
409 .body(Body::empty())
410 .expect("304 body");
411 if let Some(d) = cache {
412 if let Ok(v) = HeaderValue::from_str(&format!("public, max-age={}", d.as_secs())) {
413 resp.headers_mut().insert("cache-control", v);
414 }
415 }
416 return resp;
417 }
418 }
419
420 let mut builder = Response::builder()
421 .status(StatusCode::OK)
422 .header("content-type", asset.content_type.as_str())
423 .header("content-length", asset.data.len());
424 if let Some(tag) = asset.etag.as_deref() {
425 builder = builder.header("etag", quote_etag(tag));
426 }
427 if let Some(d) = cache {
428 builder = builder.header("cache-control", format!("public, max-age={}", d.as_secs()));
429 }
430 builder
431 .body(Body::from(asset.data.into_owned()))
432 .unwrap_or_else(|_| not_found())
433}
434
435fn not_found() -> Response<Body> {
436 Response::builder()
437 .status(StatusCode::NOT_FOUND)
438 .body(Body::from("not found"))
439 .expect("404 body")
440}
441
/// Formats a raw entity tag as an HTTP `ETag` header value.
///
/// A value that is already quoted (`"abc"`) or already a weak validator
/// (`W/"abc"`) is returned unchanged. Anything else is wrapped in double
/// quotes. Weak tags previously got an extra pair of quotes.
fn quote_etag(raw: &str) -> String {
    if raw.starts_with('"') || raw.starts_with("W/\"") {
        raw.to_string()
    } else {
        format!("\"{raw}\"")
    }
}
449
/// Returns true if an `If-None-Match` header value matches the server's tag.
///
/// Uses weak comparison (RFC 9110 §8.8.3.2): the `W/` prefix and surrounding
/// quotes are ignored on both sides. `client` may list several comma-separated
/// tags or be `*`, which matches any tag.
///
/// `W/` is stripped *before* the quotes. The previous order left a stray
/// leading quote on weak tags (`W/"x"` became `"x`), so they never matched.
fn etag_matches(client: &str, server: &str) -> bool {
    /// Reduces a tag to its opaque part: `W/"x"` and `"x"` both become `x`.
    fn opaque(tag: &str) -> &str {
        let tag = tag.trim();
        tag.strip_prefix("W/").unwrap_or(tag).trim_matches('"')
    }

    let server_tag = opaque(server);
    client.split(',').any(|candidate| {
        let candidate = opaque(candidate);
        candidate == "*" || candidate == server_tag
    })
}
467
468pub async fn serve(
471 router: AxumRouter,
472 cfg: &ServerConfig,
473 shutdown: tokio::sync::oneshot::Receiver<()>,
474) -> Result<(), Error> {
475 let addr: SocketAddr = cfg
476 .bind
477 .parse()
478 .map_err(|e| Error::Config(format!("invalid bind addr `{}`: {e}", cfg.bind)))?;
479
480 tracing::info!(%addr, tls = cfg.tls.is_some(), server_name = ?cfg.server_name, "anvil server starting");
481
482 let (shutdown_main_tx, shutdown_main_rx) = tokio::sync::oneshot::channel::<()>();
484 let (shutdown_redir_tx, shutdown_redir_rx) = tokio::sync::oneshot::channel::<()>();
485 tokio::spawn(async move {
486 let _ = shutdown.await;
487 let _ = shutdown_main_tx.send(());
488 let _ = shutdown_redir_tx.send(());
489 });
490
491 let redirect_task = cfg.redirect_http.clone().map(|redir| {
492 let target_host = redir
493 .target_host
494 .clone()
495 .or_else(|| cfg.server_name.first().cloned());
496 let permanent = redir.permanent;
497 let bind = redir.bind.clone();
498 tokio::spawn(async move {
499 if let Err(e) =
500 serve_redirect_http(&bind, target_host, permanent, shutdown_redir_rx).await
501 {
502 tracing::warn!(?e, "redirect_http listener exited with error");
503 }
504 })
505 });
506
507 let main_result = if let Some(tls) = &cfg.tls {
508 if tls.acme.is_some() {
509 serve_acme(
510 router,
511 addr,
512 tls,
513 cfg.limits.drain_timeout,
514 shutdown_main_rx,
515 )
516 .await
517 } else {
518 serve_tls(
519 router,
520 addr,
521 tls,
522 cfg.limits.drain_timeout,
523 shutdown_main_rx,
524 )
525 .await
526 }
527 } else {
528 serve_plain(router, addr, shutdown_main_rx).await
529 };
530
531 if let Some(task) = redirect_task {
532 task.abort();
533 }
534
535 main_result
536}
537
538async fn serve_redirect_http(
541 bind: &str,
542 target_host: Option<String>,
543 permanent: bool,
544 shutdown: tokio::sync::oneshot::Receiver<()>,
545) -> Result<(), Error> {
546 let addr: SocketAddr = bind
547 .parse()
548 .map_err(|e| Error::Config(format!("invalid redirect_http bind `{bind}`: {e}")))?;
549 tracing::info!(%addr, target_host = ?target_host, permanent, "http→https redirect listener");
550
551 let target_host = Arc::new(target_host);
552 let router: AxumRouter = AxumRouter::new().fallback(axum::routing::any({
553 let target_host = target_host.clone();
554 move |req: Request<Body>| {
555 let target_host = target_host.clone();
556 async move { http_redirect_handler(req, target_host, permanent).await }
557 }
558 }));
559
560 let listener = tokio::net::TcpListener::bind(addr).await?;
561 axum::serve(listener, router)
562 .with_graceful_shutdown(async move {
563 let _ = shutdown.await;
564 })
565 .await?;
566 Ok(())
567}
568
569async fn http_redirect_handler(
570 req: Request<Body>,
571 target_host: Arc<Option<String>>,
572 permanent: bool,
573) -> Response<Body> {
574 let host = target_host.as_ref().clone().unwrap_or_else(|| {
575 req.headers()
576 .get("host")
577 .and_then(|v| v.to_str().ok())
578 .map(String::from)
579 .unwrap_or_default()
580 });
581 let path_and_query = req
582 .uri()
583 .path_and_query()
584 .map(|p| p.as_str().to_string())
585 .unwrap_or_else(|| "/".to_string());
586 let location = format!("https://{host}{path_and_query}");
587
588 let status = if permanent {
589 StatusCode::MOVED_PERMANENTLY
590 } else {
591 StatusCode::FOUND
592 };
593 let mut resp = Response::new(Body::from(format!("Redirecting to {location}\n")));
594 *resp.status_mut() = status;
595 if let Ok(loc) = HeaderValue::from_str(&location) {
596 resp.headers_mut().insert("location", loc);
597 }
598 resp
599}
600
601fn build_hsts_header(cfg: &HstsConfig) -> Option<HeaderValue> {
602 let max_age = cfg.max_age.unwrap_or(Duration::from_secs(86400 * 365));
603 let mut value = format!("max-age={}", max_age.as_secs());
604 if cfg.include_subdomains {
605 value.push_str("; includeSubDomains");
606 }
607 if cfg.preload {
608 value.push_str("; preload");
609 }
610 HeaderValue::from_str(&value).ok()
611}
612
613async fn host_match_mw(allowed: Vec<String>, req: Request<Body>, next: Next) -> Response<Body> {
615 let host = req
616 .headers()
617 .get("host")
618 .and_then(|v| v.to_str().ok())
619 .unwrap_or("")
620 .to_string();
621
622 let host_no_port = host.split(':').next().unwrap_or("").to_ascii_lowercase();
624
625 if matches_any(&host_no_port, &allowed) {
626 return next.run(req).await;
627 }
628
629 tracing::debug!(host, allowed = ?allowed, "rejected host: no server_name match");
630 let mut resp = Response::new(Body::from(format!(
631 "404 not found (unknown host: {host})\n"
632 )));
633 *resp.status_mut() = StatusCode::NOT_FOUND;
634 resp
635}
636
637fn matches_any(host: &str, patterns: &[String]) -> bool {
638 patterns.iter().any(|pat| matches_pattern(host, pat))
639}
640
/// Matches `host` against a `server_name` pattern.
///
/// * `*` matches any host.
/// * `*.suffix` matches any host ending in `.suffix` (so `a.example.com`
///   matches `*.example.com`, but the bare `example.com` does not).
/// * Anything else must equal `host` exactly.
///
/// The pattern is compared in ASCII-lowercased form; `host` is compared as
/// given, so callers are expected to lower-case it. The comparison is done
/// byte-wise, without allocating.
fn matches_pattern(host: &str, pattern: &str) -> bool {
    /// True if `actual` equals the ASCII-lowercased `expected`.
    fn eq_lowered(actual: &[u8], expected: &[u8]) -> bool {
        actual.len() == expected.len()
            && actual
                .iter()
                .zip(expected)
                .all(|(a, e)| *a == e.to_ascii_lowercase())
    }

    let host = host.as_bytes();
    let pattern = pattern.as_bytes();

    if pattern == b"*" {
        return true;
    }
    match pattern.strip_prefix(b"*.") {
        Some(suffix) => {
            host.len() > suffix.len()
                && host[host.len() - suffix.len() - 1] == b'.'
                && eq_lowered(&host[host.len() - suffix.len()..], suffix)
        }
        None => eq_lowered(host, pattern),
    }
}
654
655async fn serve_plain(
656 router: AxumRouter,
657 addr: SocketAddr,
658 shutdown: tokio::sync::oneshot::Receiver<()>,
659) -> Result<(), Error> {
660 let listener = tokio::net::TcpListener::bind(addr).await?;
661 axum::serve(
662 listener,
663 router.into_make_service_with_connect_info::<SocketAddr>(),
664 )
665 .with_graceful_shutdown(async move {
666 let _ = shutdown.await;
667 })
668 .await?;
669 Ok(())
670}
671
672async fn serve_tls(
673 router: AxumRouter,
674 addr: SocketAddr,
675 tls: &TlsConfig,
676 drain: Duration,
677 shutdown: tokio::sync::oneshot::Receiver<()>,
678) -> Result<(), Error> {
679 let config = if tls.additional_certs.is_empty() {
684 axum_server::tls_rustls::RustlsConfig::from_pem_file(&tls.cert, &tls.key)
685 .await
686 .map_err(|e| Error::Config(format!("tls load: {e}")))?
687 } else {
688 let resolver = build_sni_resolver(tls)
689 .map_err(|e| Error::Config(format!("tls multi-cert load: {e}")))?;
690 let server_config = rustls::ServerConfig::builder()
691 .with_no_client_auth()
692 .with_cert_resolver(Arc::new(resolver));
693 axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(server_config))
694 };
695
696 let watch_paths = [tls.cert.clone(), tls.key.clone()];
702 let config_for_watch = config.clone();
703 let cert_path = tls.cert.clone();
704 let key_path = tls.key.clone();
705 tokio::task::spawn_blocking(move || {
706 if let Err(e) = watch_tls_certs(config_for_watch, cert_path, key_path, watch_paths) {
707 tracing::warn!(error = %e, "cert hot-reload watcher exited");
708 }
709 });
710
711 let handle = axum_server::Handle::new();
712 let handle_for_shutdown = handle.clone();
713 tokio::spawn(async move {
714 let _ = shutdown.await;
715 handle_for_shutdown.graceful_shutdown(Some(drain));
716 });
717
718 axum_server::bind_rustls(addr, config)
719 .handle(handle)
720 .serve(router.into_make_service_with_connect_info::<SocketAddr>())
721 .await
722 .map_err(|e| Error::Internal(format!("tls serve: {e}")))?;
723 Ok(())
724}
725
726async fn serve_acme(
738 _router: AxumRouter,
739 _addr: SocketAddr,
740 tls: &TlsConfig,
741 _drain: Duration,
742 _shutdown: tokio::sync::oneshot::Receiver<()>,
743) -> Result<(), Error> {
744 let acme = tls
745 .acme
746 .as_ref()
747 .expect("serve_acme called without [tls.acme]");
748 Err(Error::Config(format!(
749 "[tls.acme] is configured for {n} domain(s) but ACME runtime support \
750 is pending a follow-up PR (rustls-acme version pin). For now, use \
751 certbot in TLS-ALPN-01 mode and `[tls] cert`/`key` pointing at the \
752 certbot output; cert hot-reload handles renewals without restart.",
753 n = acme.domains.len(),
754 )))
755}
756
757#[derive(Debug)]
763struct SniResolver {
764 entries: Vec<(String, Arc<rustls::sign::CertifiedKey>)>,
767 default_key: Arc<rustls::sign::CertifiedKey>,
768}
769
770impl rustls::server::ResolvesServerCert for SniResolver {
771 fn resolve(
772 &self,
773 client_hello: rustls::server::ClientHello<'_>,
774 ) -> Option<Arc<rustls::sign::CertifiedKey>> {
775 let sni = client_hello
776 .server_name()
777 .unwrap_or("")
778 .to_ascii_lowercase();
779 for (pattern, key) in &self.entries {
780 if matches_pattern(&sni, pattern) {
781 return Some(key.clone());
782 }
783 }
784 Some(self.default_key.clone())
785 }
786}
787
788fn build_sni_resolver(tls: &TlsConfig) -> std::io::Result<SniResolver> {
789 let default_key = load_certified_key(&tls.cert, &tls.key)?;
790 let mut entries = Vec::with_capacity(tls.additional_certs.len());
791 for entry in &tls.additional_certs {
792 let key = load_certified_key(&entry.cert, &entry.key)?;
793 entries.push((entry.server_name.to_ascii_lowercase(), key));
794 }
795 tracing::info!(
796 default_cert = %tls.cert.display(),
797 additional = tls.additional_certs.len(),
798 "tls: SNI resolver active"
799 );
800 Ok(SniResolver {
801 entries,
802 default_key,
803 })
804}
805
806fn load_certified_key(
807 cert_path: &std::path::Path,
808 key_path: &std::path::Path,
809) -> std::io::Result<Arc<rustls::sign::CertifiedKey>> {
810 use std::io::BufReader;
811
812 let cert_file = std::fs::File::open(cert_path).map_err(|e| {
813 std::io::Error::new(
814 e.kind(),
815 format!("opening cert {}: {e}", cert_path.display()),
816 )
817 })?;
818 let mut cert_reader = BufReader::new(cert_file);
819 let certs: Vec<rustls::pki_types::CertificateDer<'static>> =
820 rustls_pemfile::certs(&mut cert_reader).collect::<std::io::Result<_>>()?;
821 if certs.is_empty() {
822 return Err(std::io::Error::new(
823 std::io::ErrorKind::InvalidData,
824 format!("no certificates in {}", cert_path.display()),
825 ));
826 }
827
828 let key_file = std::fs::File::open(key_path).map_err(|e| {
829 std::io::Error::new(e.kind(), format!("opening key {}: {e}", key_path.display()))
830 })?;
831 let mut key_reader = BufReader::new(key_file);
832 let key = rustls_pemfile::private_key(&mut key_reader)?.ok_or_else(|| {
833 std::io::Error::new(
834 std::io::ErrorKind::InvalidData,
835 format!("no private key in {}", key_path.display()),
836 )
837 })?;
838
839 let signing_key = rustls::crypto::ring::sign::any_supported_type(&key)
840 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("sign: {e}")))?;
841
842 Ok(Arc::new(rustls::sign::CertifiedKey::new(
843 certs,
844 signing_key,
845 )))
846}
847
848fn watch_tls_certs(
856 config: axum_server::tls_rustls::RustlsConfig,
857 cert: std::path::PathBuf,
858 key: std::path::PathBuf,
859 watch_paths: [std::path::PathBuf; 2],
860) -> std::io::Result<()> {
861 use notify::{RecursiveMode, Watcher};
862 use std::sync::mpsc::channel;
863
864 let (tx, rx) = channel::<notify::Result<notify::Event>>();
865 let mut watcher = notify::recommended_watcher(move |res| {
866 let _ = tx.send(res);
867 })
868 .map_err(|e| std::io::Error::other(format!("notify init: {e}")))?;
869
870 for p in &watch_paths {
874 if let Some(parent) = p.parent() {
875 watcher
876 .watch(parent, RecursiveMode::NonRecursive)
877 .map_err(|e| std::io::Error::other(format!("notify watch: {e}")))?;
878 }
879 }
880
881 let runtime = tokio::runtime::Handle::try_current().ok();
882 while let Ok(event) = rx.recv() {
883 let Ok(event) = event else { continue };
884 let touches_us = event.paths.iter().any(|p| p == &cert || p == &key);
887 if !touches_us {
888 continue;
889 }
890 tracing::info!(
891 cert = %cert.display(),
892 key = %key.display(),
893 "tls cert change detected — reloading"
894 );
895 let cert = cert.clone();
896 let key = key.clone();
897 let config = config.clone();
898 let reload = async move {
899 if let Err(e) = config.reload_from_pem_file(&cert, &key).await {
900 tracing::warn!(error = %e, "tls reload failed");
901 } else {
902 tracing::info!("tls cert reloaded successfully");
903 }
904 };
905 if let Some(rt) = &runtime {
906 rt.spawn(reload);
907 } else {
908 std::thread::spawn(|| {
911 let rt = tokio::runtime::Builder::new_current_thread()
912 .enable_all()
913 .build();
914 if let Ok(rt) = rt {
915 rt.block_on(reload);
916 }
917 });
918 }
919 }
920 Ok(())
921}
922
923pub struct RateLimiter {
926 state: moka::sync::Cache<String, (Instant, u32)>,
928 default_rule: Option<RateRule>,
929 route_rules: Vec<(MatchKey, RateRule)>,
930}
931
932#[derive(Clone, Copy)]
933struct RateRule {
934 count: u32,
935 window: Duration,
936}
937
938#[derive(Clone)]
939struct MatchKey {
940 method: Option<Method>,
941 path: String,
942}
943
944impl RateLimiter {
945 pub fn from_config(cfg: &RateLimitConfig) -> Self {
946 let default_rule = cfg.per_ip.as_deref().and_then(|s| {
947 crate::server_config::parse_rate(s)
948 .map(|(count, window)| RateRule { count, window })
949 .ok()
950 });
951 let route_rules = cfg
952 .routes
953 .iter()
954 .filter_map(|(spec, rate)| {
955 let (count, window) = crate::server_config::parse_rate(rate).ok()?;
956 let (method, path) = parse_route_spec(spec);
957 Some((MatchKey { method, path }, RateRule { count, window }))
958 })
959 .collect();
960
961 Self {
962 state: moka::sync::Cache::builder()
963 .max_capacity(10_000)
964 .time_to_idle(Duration::from_secs(600))
965 .build(),
966 default_rule,
967 route_rules,
968 }
969 }
970
971 fn rule_for(&self, method: &Method, path: &str) -> Option<RateRule> {
972 for (key, rule) in &self.route_rules {
973 if key.path == path && key.method.as_ref().is_none_or(|m| m == method) {
974 return Some(*rule);
975 }
976 }
977 self.default_rule
978 }
979
980 fn check(&self, bucket: &str, rule: RateRule) -> bool {
981 let now = Instant::now();
982 let mut allowed = true;
983 self.state
984 .entry(bucket.to_string())
985 .and_compute_with(|existing| match existing {
986 Some(entry) => {
987 let (window_end, count) = entry.into_value();
988 if now >= window_end {
989 moka::ops::compute::Op::Put((
990 now + rule.window,
991 rule.count.saturating_sub(1),
992 ))
993 } else if count > 0 {
994 moka::ops::compute::Op::Put((window_end, count - 1))
995 } else {
996 allowed = false;
997 moka::ops::compute::Op::Put((window_end, 0))
998 }
999 }
1000 None => {
1001 moka::ops::compute::Op::Put((now + rule.window, rule.count.saturating_sub(1)))
1002 }
1003 });
1004 allowed
1005 }
1006}
1007
1008fn parse_route_spec(spec: &str) -> (Option<Method>, String) {
1009 let trimmed = spec.trim();
1010 if let Some((m, p)) = trimmed.split_once(char::is_whitespace) {
1011 let method = m.parse::<Method>().ok();
1012 (method, p.trim().to_string())
1013 } else {
1014 (None, trimmed.to_string())
1015 }
1016}
1017
1018async fn rate_limit_mw(
1019 limiter: Arc<RateLimiter>,
1020 trusted: Arc<Vec<ipnet::IpNet>>,
1021 req: Request<Body>,
1022 next: Next,
1023) -> Response<Body> {
1024 let method = req.method().clone();
1025 let path = req.uri().path().to_string();
1026 let bucket = format!("{}|{}|{}", client_ip(&req, &trusted), method, path);
1027
1028 if let Some(rule) = limiter.rule_for(&method, &path) {
1029 if !limiter.check(&bucket, rule) {
1030 tracing::debug!(%method, %path, %bucket, "rate limited");
1031 let mut resp = Response::new(Body::from("rate limit exceeded"));
1032 *resp.status_mut() = StatusCode::TOO_MANY_REQUESTS;
1033 return resp;
1034 }
1035 }
1036 next.run(req).await
1037}
1038
1039fn client_ip(req: &Request<Body>, trusted: &[ipnet::IpNet]) -> String {
1052 let peer: Option<std::net::IpAddr> = req
1053 .extensions()
1054 .get::<ConnectInfo<SocketAddr>>()
1055 .map(|ci| ci.0.ip());
1056
1057 let peer_trusted = match peer {
1058 Some(addr) => !trusted.is_empty() && trusted.iter().any(|net| net.contains(&addr)),
1059 None => false,
1060 };
1061
1062 if peer_trusted {
1063 if let Some(v) = req.headers().get("x-forwarded-for") {
1064 if let Ok(s) = v.to_str() {
1065 if let Some(first) = s.split(',').next() {
1066 let candidate = first.trim();
1067 if !candidate.is_empty() {
1068 return candidate.to_string();
1069 }
1070 }
1071 }
1072 }
1073 }
1074
1075 peer.map(|a| a.to_string())
1076 .unwrap_or_else(|| "unknown".into())
1077}
1078
1079async fn route_timeout_mw(
1085 rules: Arc<Vec<RouteTimeoutRule>>,
1086 req: Request<Body>,
1087 next: Next,
1088) -> Response<Body> {
1089 let path = req.uri().path().to_string();
1090 let matching = rules.iter().find(|r| path.starts_with(&r.prefix));
1091 match matching {
1092 Some(rule) => {
1093 let timeout = rule.timeout;
1094 match tokio::time::timeout(timeout, next.run(req)).await {
1095 Ok(resp) => resp,
1096 Err(_) => {
1097 tracing::debug!(
1098 path,
1099 timeout_ms = timeout.as_millis(),
1100 "route timeout exceeded"
1101 );
1102 let mut resp = Response::new(Body::from("request timed out"));
1103 *resp.status_mut() = StatusCode::REQUEST_TIMEOUT;
1104 resp
1105 }
1106 }
1107 }
1108 None => next.run(req).await,
1109 }
1110}
1111
1112async fn request_id_mw(mut req: Request<Body>, next: Next) -> Response<Body> {
1117 const HEADER: &str = "x-request-id";
1118
1119 let inbound = req
1120 .headers()
1121 .get(HEADER)
1122 .and_then(|v| v.to_str().ok())
1123 .map(|s| s.to_string());
1124
1125 let request_id = inbound.unwrap_or_else(|| uuid::Uuid::now_v7().to_string());
1126
1127 if let Ok(v) = HeaderValue::from_str(&request_id) {
1128 req.headers_mut()
1129 .insert(HeaderName::from_static(HEADER), v.clone());
1130 req.extensions_mut().insert(RequestId(request_id.clone()));
1134
1135 let mut resp = next.run(req).await;
1136 resp.headers_mut()
1137 .insert(HeaderName::from_static(HEADER), v);
1138 resp
1139 } else {
1140 next.run(req).await
1141 }
1142}
1143
/// Per-request identifier placed in request extensions by `request_id_mw`.
///
/// Holds the inbound `x-request-id` when the client sent a textual one,
/// otherwise a freshly generated UUIDv7. It is only inserted when that value
/// is a valid header value, and the same value is echoed on the response.
/// `access_log_mw` reads it from the extensions.
#[derive(Debug, Clone)]
pub struct RequestId(pub String);
1150
1151async fn access_log_mw(
1152 format: AccessLogFormat,
1153 trusted: Arc<Vec<ipnet::IpNet>>,
1154 req: Request<Body>,
1155 next: Next,
1156) -> Response<Body> {
1157 let started = Instant::now();
1158 let method = req.method().clone();
1159 let path = req.uri().path().to_string();
1160 let host = req
1161 .headers()
1162 .get("host")
1163 .and_then(|v| v.to_str().ok())
1164 .unwrap_or("-")
1165 .to_string();
1166 let referer = req
1167 .headers()
1168 .get("referer")
1169 .and_then(|v| v.to_str().ok())
1170 .map(String::from);
1171 let ua = req
1172 .headers()
1173 .get("user-agent")
1174 .and_then(|v| v.to_str().ok())
1175 .map(String::from);
1176 let ip = client_ip(&req, &trusted);
1177 let request_id = req.extensions().get::<RequestId>().map(|id| id.0.clone());
1178
1179 let resp = next.run(req).await;
1180 let elapsed = started.elapsed();
1181 let status = resp.status().as_u16();
1182 let bytes = response_size(resp.headers()).unwrap_or(0);
1183
1184 match format {
1185 AccessLogFormat::Combined => {
1186 tracing::info!(
1187 target: "access_log",
1188 "{} - - \"{} {} HTTP/1.1\" {} {} \"{}\" \"{}\" {}ms",
1189 ip,
1190 method,
1191 path,
1192 status,
1193 bytes,
1194 referer.as_deref().unwrap_or("-"),
1195 ua.as_deref().unwrap_or("-"),
1196 elapsed.as_millis(),
1197 );
1198 }
1199 AccessLogFormat::Json => {
1200 tracing::info!(
1201 target: "access_log",
1202 json = %serde_json::json!({
1203 "ip": ip,
1204 "method": method.as_str(),
1205 "path": path,
1206 "host": host,
1207 "status": status,
1208 "bytes": bytes,
1209 "referer": referer,
1210 "user_agent": ua,
1211 "duration_ms": elapsed.as_millis(),
1212 "request_id": request_id,
1213 }),
1214 "request"
1215 );
1216 }
1217 AccessLogFormat::Off => {}
1218 }
1219 resp
1220}
1221
1222fn response_size(headers: &HeaderMap) -> Option<u64> {
1223 headers
1224 .get("content-length")
1225 .and_then(|v| v.to_str().ok())
1226 .and_then(|s| s.parse().ok())
1227}
1228
1229#[derive(Clone)]
1232struct CompiledRewrite {
1233 pattern: regex::Regex,
1234 to: String,
1235 status: Option<u16>,
1236 match_query: bool,
1237}
1238
1239struct CompiledRewrites {
1240 rules: Vec<CompiledRewrite>,
1241}
1242
1243impl CompiledRewrites {
1244 fn compile(rules: &[RewriteRule]) -> Self {
1245 let compiled = rules
1246 .iter()
1247 .filter_map(|r| match regex::Regex::new(&r.from) {
1248 Ok(pattern) => Some(CompiledRewrite {
1249 pattern,
1250 to: r.to.clone(),
1251 status: r.status,
1252 match_query: r.match_query,
1253 }),
1254 Err(e) => {
1255 tracing::warn!(rule = %r.from, error = %e, "invalid rewrite regex, skipping");
1256 None
1257 }
1258 })
1259 .collect();
1260 Self { rules: compiled }
1261 }
1262}
1263
1264async fn rewrite_mw(
1265 rules: Arc<CompiledRewrites>,
1266 mut req: Request<Body>,
1267 next: Next,
1268) -> Response<Body> {
1269 let path = req.uri().path().to_string();
1270 let path_and_query = req
1271 .uri()
1272 .path_and_query()
1273 .map(|p| p.as_str().to_string())
1274 .unwrap_or_else(|| path.clone());
1275
1276 let target_str_path = path.clone();
1277 let target_str_full = path_and_query.clone();
1278
1279 let mut applied: Option<(String, Option<u16>)> = None;
1281 for rule in &rules.rules {
1282 let subject = if rule.match_query {
1283 &target_str_full
1284 } else {
1285 &target_str_path
1286 };
1287 if rule.pattern.is_match(subject) {
1288 let replaced = rule.pattern.replace(subject, rule.to.as_str()).to_string();
1289 applied = Some((replaced, rule.status));
1290 break;
1291 }
1292 }
1293
1294 let Some((new_target, status)) = applied else {
1295 return next.run(req).await;
1296 };
1297
1298 match status {
1299 Some(code @ (301 | 302 | 303 | 307 | 308)) => {
1300 let mut resp = Response::new(Body::from(format!("Redirecting to {new_target}\n")));
1301 *resp.status_mut() =
1302 StatusCode::from_u16(code).unwrap_or(StatusCode::MOVED_PERMANENTLY);
1303 if let Ok(loc) = HeaderValue::from_str(&new_target) {
1304 resp.headers_mut().insert("location", loc);
1305 }
1306 resp
1307 }
1308 _ => {
1309 let mut parts = req.uri().clone().into_parts();
1311 if let Ok(new_pq) = new_target.parse::<axum::http::uri::PathAndQuery>() {
1312 parts.path_and_query = Some(new_pq);
1313 }
1314 if let Ok(new_uri) = axum::http::Uri::from_parts(parts) {
1315 *req.uri_mut() = new_uri;
1316 }
1317 next.run(req).await
1318 }
1319 }
1320}
1321
1322async fn trailing_slash_mw(
1325 cfg: TrailingSlashConfig,
1326 mut req: Request<Body>,
1327 next: Next,
1328) -> Response<Body> {
1329 let path = req.uri().path().to_string();
1330 if path == "/" {
1331 return next.run(req).await;
1332 }
1333
1334 let want_slash = matches!(cfg.mode, TrailingSlashMode::Always);
1335 let has_slash = path.ends_with('/');
1336
1337 if want_slash == has_slash {
1338 return next.run(req).await;
1339 }
1340
1341 let new_path = if want_slash {
1342 format!("{path}/")
1343 } else {
1344 path.trim_end_matches('/').to_string()
1345 };
1346
1347 let query = req
1348 .uri()
1349 .query()
1350 .map(|q| format!("?{q}"))
1351 .unwrap_or_default();
1352 let new_target = format!("{new_path}{query}");
1353
1354 match cfg.action {
1355 TrailingSlashAction::Redirect => {
1356 let mut resp = Response::new(Body::from(format!("Redirecting to {new_target}\n")));
1357 *resp.status_mut() = StatusCode::MOVED_PERMANENTLY;
1358 if let Ok(loc) = HeaderValue::from_str(&new_target) {
1359 resp.headers_mut().insert("location", loc);
1360 }
1361 resp
1362 }
1363 TrailingSlashAction::Rewrite => {
1364 let mut parts = req.uri().clone().into_parts();
1365 if let Ok(pq) = new_target.parse::<axum::http::uri::PathAndQuery>() {
1366 parts.path_and_query = Some(pq);
1367 }
1368 if let Ok(new_uri) = axum::http::Uri::from_parts(parts) {
1369 *req.uri_mut() = new_uri;
1370 }
1371 next.run(req).await
1372 }
1373 }
1374}
1375
1376struct LoadedErrorPages {
1379 by_status: std::collections::HashMap<u16, (String, &'static str)>,
1380}
1381
1382fn load_error_pages(
1383 raw: &std::collections::BTreeMap<String, std::path::PathBuf>,
1384) -> LoadedErrorPages {
1385 let mut by_status = std::collections::HashMap::new();
1386 for (key, path) in raw {
1387 let Ok(code) = key.parse::<u16>() else {
1388 tracing::warn!(key, "error_pages: invalid status code, skipping");
1389 continue;
1390 };
1391 let body = match std::fs::read_to_string(path) {
1392 Ok(s) => s,
1393 Err(e) => {
1394 tracing::warn!(?path, ?e, "error_pages: failed to read file, skipping");
1395 continue;
1396 }
1397 };
1398 let content_type = guess_content_type(path);
1399 by_status.insert(code, (body, content_type));
1400 }
1401 LoadedErrorPages { by_status }
1402}
1403
/// Picks a `Content-Type` for an error-page file from its extension.
///
/// The extension is compared case-insensitively, so `404.HTML` is served as HTML just
/// like `404.html`. Unknown or missing extensions fall back to UTF-8 plain text.
fn guess_content_type(path: &std::path::Path) -> &'static str {
    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase);
    match ext.as_deref() {
        Some("html" | "htm") => "text/html; charset=utf-8",
        Some("json") => "application/json",
        _ => "text/plain; charset=utf-8",
    }
}
1412
1413async fn error_pages_mw(
1414 pages: Arc<LoadedErrorPages>,
1415 req: Request<Body>,
1416 next: Next,
1417) -> Response<Body> {
1418 let resp = next.run(req).await;
1419 let status = resp.status().as_u16();
1420
1421 let Some((body, ctype)) = pages.by_status.get(&status) else {
1422 return resp;
1423 };
1424
1425 let mut out = Response::new(Body::from(body.clone()));
1426 *out.status_mut() = resp.status();
1427 if let Ok(ct) = HeaderValue::from_str(ctype) {
1428 out.headers_mut().insert("content-type", ct);
1429 }
1430 for h in ["cache-control", "x-request-id"] {
1432 if let Some(v) = resp.headers().get(h) {
1433 out.headers_mut().insert(h, v.clone());
1434 }
1435 }
1436 out
1437}
1438
1439#[derive(Clone)]
1442struct CompiledProxy {
1443 prefix: String,
1444 upstream: String,
1445 strip_prefix: bool,
1446 preserve_host: bool,
1447 timeout: Duration,
1448 retries: u8,
1449}
1450
1451struct CompiledProxies {
1452 rules: Vec<CompiledProxy>,
1453 client: reqwest::Client,
1454}
1455
1456impl CompiledProxies {
1457 fn compile(rules: &[ProxyRule]) -> Self {
1458 let mut compiled: Vec<CompiledProxy> = rules
1459 .iter()
1460 .map(|r| CompiledProxy {
1461 prefix: r.prefix.clone(),
1462 upstream: r.upstream.trim_end_matches('/').to_string(),
1463 strip_prefix: r.strip_prefix,
1464 preserve_host: r.preserve_host,
1465 timeout: r.timeout.unwrap_or(Duration::from_secs(30)),
1466 retries: r.retries,
1467 })
1468 .collect();
1469 compiled.sort_by_key(|r| std::cmp::Reverse(r.prefix.len()));
1471
1472 let client = reqwest::Client::builder()
1473 .redirect(reqwest::redirect::Policy::none())
1474 .build()
1475 .unwrap_or_else(|_| reqwest::Client::new());
1476
1477 Self {
1478 rules: compiled,
1479 client,
1480 }
1481 }
1482
1483 fn matching(&self, path: &str) -> Option<&CompiledProxy> {
1484 self.rules.iter().find(|r| path.starts_with(&r.prefix))
1485 }
1486}
1487
1488async fn proxy_mw(proxies: Arc<CompiledProxies>, req: Request<Body>, next: Next) -> Response<Body> {
1489 let path = req.uri().path().to_string();
1490 let Some(rule) = proxies.matching(&path) else {
1491 return next.run(req).await;
1492 };
1493 let rule = rule.clone();
1494
1495 match proxy_forward(&proxies.client, &rule, req).await {
1496 Ok(resp) => resp,
1497 Err(e) => {
1498 tracing::warn!(?e, prefix = %rule.prefix, upstream = %rule.upstream, "proxy error");
1499 let mut resp = Response::new(Body::from(format!("upstream error: {e}")));
1500 *resp.status_mut() = StatusCode::BAD_GATEWAY;
1501 resp
1502 }
1503 }
1504}
1505
1506async fn proxy_forward(
1507 client: &reqwest::Client,
1508 rule: &CompiledProxy,
1509 req: Request<Body>,
1510) -> Result<Response<Body>, String> {
1511 let (parts, body) = req.into_parts();
1512 let body_bytes = axum::body::to_bytes(body, usize::MAX)
1513 .await
1514 .map_err(|e| format!("body read: {e}"))?;
1515
1516 let original_path = parts.uri.path();
1517 let upstream_path = if rule.strip_prefix {
1518 original_path
1519 .strip_prefix(&rule.prefix)
1520 .unwrap_or(original_path)
1521 } else {
1522 original_path
1523 };
1524 let upstream_path = if upstream_path.is_empty() {
1525 "/"
1526 } else {
1527 upstream_path
1528 };
1529 let query = parts
1530 .uri
1531 .query()
1532 .map(|q| format!("?{q}"))
1533 .unwrap_or_default();
1534 let upstream_url = format!("{}{}{}", rule.upstream, upstream_path, query);
1535
1536 let method = parts.method.clone();
1537 let mut last_err = String::new();
1538 for attempt in 0..=rule.retries {
1539 let mut request = client
1540 .request(
1541 reqwest::Method::from_bytes(method.as_str().as_bytes())
1542 .unwrap_or(reqwest::Method::GET),
1543 &upstream_url,
1544 )
1545 .timeout(rule.timeout)
1546 .body(body_bytes.clone());
1547
1548 for (name, value) in parts.headers.iter() {
1549 let n = name.as_str().to_ascii_lowercase();
1551 if matches!(
1552 n.as_str(),
1553 "connection"
1554 | "keep-alive"
1555 | "proxy-authenticate"
1556 | "proxy-authorization"
1557 | "te"
1558 | "trailers"
1559 | "transfer-encoding"
1560 | "upgrade"
1561 | "content-length"
1562 ) {
1563 continue;
1564 }
1565 if !rule.preserve_host && n == "host" {
1566 continue;
1567 }
1568 if let Ok(v) = value.to_str() {
1569 request = request.header(name.as_str(), v);
1570 }
1571 }
1572
1573 if let Some(host) = parts.headers.get("host").and_then(|v| v.to_str().ok()) {
1575 request = request.header("x-forwarded-host", host);
1576 }
1577 request = request.header("x-forwarded-proto", "http");
1578
1579 match request.send().await {
1580 Ok(resp) => return upstream_to_axum(resp).await,
1581 Err(e) => {
1582 last_err = format!("attempt {} → {e}", attempt + 1);
1583 tracing::debug!(error = %e, attempt, "proxy retry");
1584 continue;
1585 }
1586 }
1587 }
1588 Err(last_err)
1589}
1590
1591async fn cors_mw(cfg: Arc<CorsConfig>, req: Request<Body>, next: Next) -> Response<Body> {
1594 let origin = req
1595 .headers()
1596 .get("origin")
1597 .and_then(|v| v.to_str().ok())
1598 .map(String::from);
1599
1600 let is_allowed_origin = origin.as_deref().is_some_and(|o| {
1601 cfg.allow_origins
1602 .iter()
1603 .any(|allowed| allowed == "*" || allowed == o)
1604 });
1605
1606 if req.method() == Method::OPTIONS && origin.is_some() {
1608 let mut resp = Response::new(Body::empty());
1609 *resp.status_mut() = StatusCode::NO_CONTENT;
1610 apply_cors_headers(
1611 resp.headers_mut(),
1612 &cfg,
1613 origin.as_deref(),
1614 is_allowed_origin,
1615 );
1616 return resp;
1617 }
1618
1619 let mut resp = next.run(req).await;
1620 apply_cors_headers(
1621 resp.headers_mut(),
1622 &cfg,
1623 origin.as_deref(),
1624 is_allowed_origin,
1625 );
1626 resp
1627}
1628
1629fn apply_cors_headers(
1630 headers: &mut HeaderMap,
1631 cfg: &CorsConfig,
1632 origin: Option<&str>,
1633 is_allowed_origin: bool,
1634) {
1635 if !is_allowed_origin {
1636 return;
1637 }
1638 if let Some(origin) = origin {
1639 if let Ok(v) = HeaderValue::from_str(origin) {
1640 headers.insert("access-control-allow-origin", v);
1641 }
1642 headers.insert("vary", HeaderValue::from_static("Origin"));
1643 } else if cfg.allow_origins.iter().any(|o| o == "*") {
1644 headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
1645 }
1646
1647 let methods = if cfg.allow_methods.is_empty() {
1648 "GET, POST, PUT, PATCH, DELETE, OPTIONS".to_string()
1649 } else {
1650 cfg.allow_methods.join(", ")
1651 };
1652 if let Ok(v) = HeaderValue::from_str(&methods) {
1653 headers.insert("access-control-allow-methods", v);
1654 }
1655
1656 let allow_headers = if cfg.allow_headers.is_empty() {
1657 "Content-Type, Authorization, X-CSRF-TOKEN, X-Requested-With".to_string()
1658 } else {
1659 cfg.allow_headers.join(", ")
1660 };
1661 if let Ok(v) = HeaderValue::from_str(&allow_headers) {
1662 headers.insert("access-control-allow-headers", v);
1663 }
1664
1665 if !cfg.expose_headers.is_empty() {
1666 if let Ok(v) = HeaderValue::from_str(&cfg.expose_headers.join(", ")) {
1667 headers.insert("access-control-expose-headers", v);
1668 }
1669 }
1670
1671 if cfg.allow_credentials {
1672 headers.insert(
1673 "access-control-allow-credentials",
1674 HeaderValue::from_static("true"),
1675 );
1676 }
1677
1678 if let Some(max_age) = cfg.max_age {
1679 if let Ok(v) = HeaderValue::from_str(&max_age.as_secs().to_string()) {
1680 headers.insert("access-control-max-age", v);
1681 }
1682 }
1683}
1684
1685async fn ip_rules_mw(
1688 rules: Arc<Vec<IpRule>>,
1689 trusted: Arc<Vec<ipnet::IpNet>>,
1690 req: Request<Body>,
1691 next: Next,
1692) -> Response<Body> {
1693 let path = req.uri().path().to_string();
1694 let ip_str = client_ip(&req, &trusted);
1695 let ip = ip_str.parse::<std::net::IpAddr>().ok();
1696
1697 for rule in rules.iter() {
1698 if !path.starts_with(&rule.prefix) {
1699 continue;
1700 }
1701 let matches_range = ip
1702 .map(|addr| rule.ranges.iter().any(|net| net.contains(&addr)))
1703 .unwrap_or(false);
1704 let allowed = match rule.action {
1705 IpAction::Allow => matches_range,
1706 IpAction::Deny => !matches_range,
1707 };
1708 if !allowed {
1709 tracing::debug!(path, ip = %ip_str, "ip rule denied request");
1710 let mut resp = Response::new(Body::from("forbidden"));
1711 *resp.status_mut() = StatusCode::FORBIDDEN;
1712 return resp;
1713 }
1714 break;
1716 }
1717
1718 next.run(req).await
1719}
1720
1721use base64::engine::general_purpose::STANDARD as B64;
1724use base64::Engine as _;
1725
1726struct CompiledBasicAuth {
1727 rules: Vec<(BasicAuthRule, Vec<(String, String)>)>,
1728}
1729
1730fn compile_basic_auth(rules: &[BasicAuthRule]) -> CompiledBasicAuth {
1731 let compiled = rules
1732 .iter()
1733 .map(|r| {
1734 let creds = r
1735 .credentials
1736 .iter()
1737 .filter_map(|c| {
1738 c.split_once(':')
1739 .map(|(u, p)| (u.to_string(), p.to_string()))
1740 })
1741 .collect();
1742 (r.clone(), creds)
1743 })
1744 .collect();
1745 CompiledBasicAuth { rules: compiled }
1746}
1747
1748async fn basic_auth_mw(
1749 rules: Arc<CompiledBasicAuth>,
1750 req: Request<Body>,
1751 next: Next,
1752) -> Response<Body> {
1753 let path = req.uri().path().to_string();
1754 for (rule, creds) in &rules.rules {
1755 if !path.starts_with(&rule.prefix) {
1756 continue;
1757 }
1758 let supplied = req
1759 .headers()
1760 .get("authorization")
1761 .and_then(|v| v.to_str().ok())
1762 .and_then(|s| s.strip_prefix("Basic "))
1763 .and_then(|b64| B64.decode(b64).ok())
1764 .and_then(|bytes| String::from_utf8(bytes).ok())
1765 .and_then(|pair| {
1766 pair.split_once(':')
1767 .map(|(u, p)| (u.to_string(), p.to_string()))
1768 });
1769
1770 let ok = supplied
1771 .as_ref()
1772 .map(|(u, p)| creds.iter().any(|(cu, cp)| cu == u && cp == p))
1773 .unwrap_or(false);
1774
1775 if ok {
1776 return next.run(req).await;
1777 }
1778
1779 let challenge = format!("Basic realm=\"{}\"", rule.realm);
1780 let mut resp = Response::new(Body::from("authentication required"));
1781 *resp.status_mut() = StatusCode::UNAUTHORIZED;
1782 if let Ok(v) = HeaderValue::from_str(&challenge) {
1783 resp.headers_mut().insert("www-authenticate", v);
1784 }
1785 return resp;
1786 }
1787 next.run(req).await
1788}
1789
1790async fn upstream_to_axum(resp: reqwest::Response) -> Result<Response<Body>, String> {
1791 let status = resp.status();
1792 let headers = resp.headers().clone();
1793 let bytes = resp
1794 .bytes()
1795 .await
1796 .map_err(|e| format!("upstream body: {e}"))?;
1797 let mut out = Response::new(Body::from(bytes));
1798 *out.status_mut() =
1799 axum::http::StatusCode::from_u16(status.as_u16()).unwrap_or(axum::http::StatusCode::OK);
1800 for (name, value) in headers.iter() {
1801 let n = name.as_str().to_ascii_lowercase();
1802 if matches!(
1803 n.as_str(),
1804 "connection"
1805 | "keep-alive"
1806 | "proxy-authenticate"
1807 | "proxy-authorization"
1808 | "te"
1809 | "trailers"
1810 | "transfer-encoding"
1811 | "upgrade"
1812 ) {
1813 continue;
1814 }
1815 if let Ok(v) = HeaderValue::from_bytes(value.as_bytes()) {
1816 if let Ok(name) = HeaderName::from_bytes(name.as_str().as_bytes()) {
1817 out.headers_mut().append(name, v);
1818 }
1819 }
1820 }
1821 Ok(out)
1822}
1823
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Builds a request with an optional peer address (stored as `ConnectInfo`) and an
    /// optional raw `X-Forwarded-For` header value.
    fn make_req(peer: Option<SocketAddr>, xff: Option<&str>) -> Request<Body> {
        let builder = match xff {
            Some(value) => Request::builder().header("x-forwarded-for", value),
            None => Request::builder(),
        };
        let mut req = builder.body(Body::empty()).unwrap();
        if let Some(addr) = peer {
            req.extensions_mut().insert(ConnectInfo(addr));
        }
        req
    }

    /// Shorthand for an IPv4 socket address.
    fn v4(a: u8, b: u8, c: u8, d: u8, port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(a, b, c, d)), port)
    }

    fn cidr(s: &str) -> ipnet::IpNet {
        s.parse().unwrap()
    }

    #[test]
    fn xff_ignored_when_peer_is_not_trusted() {
        let req = make_req(Some(v4(203, 0, 113, 5, 12345)), Some("198.51.100.1"));
        let trusted = vec![cidr("10.0.0.0/8")];
        assert_eq!(client_ip(&req, &trusted), "203.0.113.5");
    }

    #[test]
    fn xff_honored_when_peer_is_a_trusted_proxy() {
        let req = make_req(Some(v4(10, 0, 0, 5, 443)), Some("198.51.100.1, 10.0.0.5"));
        let trusted = vec![cidr("10.0.0.0/8")];
        assert_eq!(client_ip(&req, &trusted), "198.51.100.1");
    }

    #[test]
    fn empty_trusted_list_means_xff_is_never_honored() {
        let req = make_req(Some(v4(10, 0, 0, 5, 443)), Some("198.51.100.1"));
        let trusted: Vec<ipnet::IpNet> = Vec::new();
        assert_eq!(client_ip(&req, &trusted), "10.0.0.5");
    }

    #[test]
    fn no_xff_falls_back_to_peer() {
        let req = make_req(Some(v4(10, 0, 0, 5, 443)), None);
        let trusted = vec![cidr("10.0.0.0/8")];
        assert_eq!(client_ip(&req, &trusted), "10.0.0.5");
    }

    #[test]
    fn missing_connect_info_returns_unknown() {
        let req = make_req(None, Some("198.51.100.1"));
        let trusted = vec![cidr("10.0.0.0/8")];
        assert_eq!(client_ip(&req, &trusted), "unknown");
    }

    #[test]
    fn xff_with_whitespace_and_multiple_hops_picks_first() {
        let req = make_req(
            Some(v4(10, 0, 0, 5, 443)),
            Some(" 198.51.100.1 ,10.0.0.5, 10.0.0.7"),
        );
        let trusted = vec![cidr("10.0.0.0/8")];
        assert_eq!(client_ip(&req, &trusted), "198.51.100.1");
    }
}