1use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use axum::body::Body;
9use axum::http::{HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode};
10use axum::middleware::Next;
11use axum::Router as AxumRouter;
12use tower_http::compression::predicate::SizeAbove;
13use tower_http::compression::CompressionLayer;
14use tower_http::limit::RequestBodyLimitLayer;
15use tower_http::services::ServeDir;
16use tower_http::set_header::SetResponseHeaderLayer;
17use tower_http::timeout::TimeoutLayer;
18use tower_http::trace::TraceLayer;
19
20use crate::container::Container;
21use crate::server_config::{
22 AccessLogFormat, BasicAuthRule, CorsConfig, HstsConfig, IpAction, IpRule, ProxyRule,
23 RateLimitConfig, RewriteRule, RouteTimeoutRule, ServerConfig, StaticMount, TlsConfig,
24 TrailingSlashAction, TrailingSlashConfig, TrailingSlashMode,
25};
26use crate::Error;
27use axum::extract::ConnectInfo;
28
29pub fn apply_layers(web: AxumRouter<Container>, cfg: &ServerConfig) -> AxumRouter<Container> {
32 let mut router = web;
33
34 for (prefix, mount) in &cfg.static_files {
37 router = mount_static(router, prefix, mount);
38 }
39
40 let body_max = cfg.limits.body_max as usize;
42 router = router
43 .layer(RequestBodyLimitLayer::new(body_max))
44 .layer(TraceLayer::new_for_http());
45
46 if let Some(timeout) = cfg.limits.request_timeout {
47 router = router.layer(TimeoutLayer::new(timeout));
48 }
49
50 if let Some(max) = cfg.limits.max_concurrency {
57 let max = max.max(1) as usize;
58 let limiter = Arc::new(tokio::sync::Semaphore::new(max));
59 let limiter_clone = limiter.clone();
60 router = router.layer(axum::middleware::from_fn(
61 move |req: Request<Body>, next: Next| {
62 let limiter = limiter_clone.clone();
63 async move {
64 match limiter.try_acquire() {
65 Ok(_permit) => next.run(req).await,
66 Err(_) => {
67 let mut resp = Response::new(Body::from(
68 "service overloaded — too many concurrent requests",
69 ));
70 *resp.status_mut() = StatusCode::SERVICE_UNAVAILABLE;
71 resp
72 }
73 }
74 }
75 },
76 ));
77 }
78
79 if !cfg.route_timeouts.is_empty() {
84 let rules: Arc<Vec<RouteTimeoutRule>> = Arc::new(cfg.route_timeouts.clone());
85 let rules_clone = rules.clone();
86 router = router.layer(axum::middleware::from_fn(
87 move |req: Request<Body>, next: Next| {
88 let rules = rules_clone.clone();
89 async move { route_timeout_mw(rules, req, next).await }
90 },
91 ));
92 }
93
94 if !cfg.server_name.is_empty() {
97 let allowed = cfg.server_name.clone();
98 router = router.layer(axum::middleware::from_fn(
99 move |req: Request<Body>, next: Next| {
100 let allowed = allowed.clone();
101 async move { host_match_mw(allowed, req, next).await }
102 },
103 ));
104 }
105
106 let trusted: Arc<Vec<ipnet::IpNet>> = Arc::new(cfg.trusted_proxies.ranges.clone());
111
112 if !cfg.ip_rules.is_empty() {
115 let rules = Arc::new(cfg.ip_rules.clone());
116 let rules_clone = rules.clone();
117 let trusted_clone = trusted.clone();
118 router = router.layer(axum::middleware::from_fn(
119 move |req: Request<Body>, next: Next| {
120 let rules = rules_clone.clone();
121 let trusted = trusted_clone.clone();
122 async move { ip_rules_mw(rules, trusted, req, next).await }
123 },
124 ));
125 }
126 if !cfg.basic_auth.is_empty() {
127 let rules = Arc::new(compile_basic_auth(&cfg.basic_auth));
128 let rules_clone = rules.clone();
129 router = router.layer(axum::middleware::from_fn(
130 move |req: Request<Body>, next: Next| {
131 let rules = rules_clone.clone();
132 async move { basic_auth_mw(rules, req, next).await }
133 },
134 ));
135 }
136
137 if cfg.cors.enabled {
140 let cors = Arc::new(cfg.cors.clone());
141 let cors_clone = cors.clone();
142 router = router.layer(axum::middleware::from_fn(
143 move |req: Request<Body>, next: Next| {
144 let cors = cors_clone.clone();
145 async move { cors_mw(cors, req, next).await }
146 },
147 ));
148 }
149
150 if !cfg.proxies.is_empty() {
153 let proxies = Arc::new(CompiledProxies::compile(&cfg.proxies));
154 let proxies_clone = proxies.clone();
155 router = router.layer(axum::middleware::from_fn(
156 move |req: Request<Body>, next: Next| {
157 let proxies = proxies_clone.clone();
158 async move { proxy_mw(proxies, req, next).await }
159 },
160 ));
161 }
162
163 if !cfg.rewrites.is_empty() {
165 let compiled = Arc::new(CompiledRewrites::compile(&cfg.rewrites));
166 let compiled_clone = compiled.clone();
167 router = router.layer(axum::middleware::from_fn(
168 move |req: Request<Body>, next: Next| {
169 let rules = compiled_clone.clone();
170 async move { rewrite_mw(rules, req, next).await }
171 },
172 ));
173 }
174
175 if cfg.trailing_slash.mode != TrailingSlashMode::Ignore {
177 let ts = cfg.trailing_slash.clone();
178 router = router.layer(axum::middleware::from_fn(
179 move |req: Request<Body>, next: Next| {
180 let ts = ts.clone();
181 async move { trailing_slash_mw(ts, req, next).await }
182 },
183 ));
184 }
185
186 if !cfg.error_pages.is_empty() {
189 let pages = Arc::new(load_error_pages(&cfg.error_pages));
190 let pages_clone = pages.clone();
191 router = router.layer(axum::middleware::from_fn(
192 move |req: Request<Body>, next: Next| {
193 let pages = pages_clone.clone();
194 async move { error_pages_mw(pages, req, next).await }
195 },
196 ));
197 }
198
199 if cfg.tls.is_some() && cfg.hsts.enabled {
201 if let Some(value) = build_hsts_header(&cfg.hsts) {
202 router = router.layer(SetResponseHeaderLayer::if_not_present(
203 HeaderName::from_static("strict-transport-security"),
204 value,
205 ));
206 }
207 }
208
209 if cfg.compression.enabled {
210 let min_size = u16::try_from(cfg.compression.min_size).unwrap_or(u16::MAX);
215 let mut layer = CompressionLayer::new();
216 if !cfg
217 .compression
218 .algorithms
219 .iter()
220 .any(|a| a.eq_ignore_ascii_case("gzip"))
221 {
222 layer = layer.no_gzip();
223 }
224 if !cfg
225 .compression
226 .algorithms
227 .iter()
228 .any(|a| a.eq_ignore_ascii_case("br") || a.eq_ignore_ascii_case("brotli"))
229 {
230 layer = layer.no_br();
231 }
232 if !cfg
233 .compression
234 .algorithms
235 .iter()
236 .any(|a| a.eq_ignore_ascii_case("deflate"))
237 {
238 layer = layer.no_deflate();
239 }
240 let layer = layer.compress_when(SizeAbove::new(min_size));
241 router = router.layer(layer);
242 }
243
244 if cfg.rate_limit.per_ip.is_some() || !cfg.rate_limit.routes.is_empty() {
245 let limiter = Arc::new(RateLimiter::from_config(&cfg.rate_limit));
246 let limiter_clone = limiter.clone();
247 let trusted_clone = trusted.clone();
248 router = router.layer(axum::middleware::from_fn(
249 move |req: Request<Body>, next: Next| {
250 let limiter = limiter_clone.clone();
251 let trusted = trusted_clone.clone();
252 async move { rate_limit_mw(limiter, trusted, req, next).await }
253 },
254 ));
255 }
256
257 if matches!(
258 cfg.access_log.format,
259 AccessLogFormat::Combined | AccessLogFormat::Json
260 ) {
261 let format = cfg.access_log.format;
262 let trusted_clone = trusted.clone();
263 router = router.layer(axum::middleware::from_fn(
264 move |req: Request<Body>, next: Next| {
265 let trusted = trusted_clone.clone();
266 async move { access_log_mw(format, trusted, req, next).await }
267 },
268 ));
269 }
270
271 router = router.layer(axum::middleware::from_fn(request_id_mw));
276
277 router
278}
279
280fn mount_static(
281 router: AxumRouter<Container>,
282 prefix: &str,
283 mount: &StaticMount,
284) -> AxumRouter<Container> {
285 let _ = mount.ranges;
288
289 if let Some(fetcher) = crate::embedded::lookup(prefix) {
292 let cache = mount.cache;
293 let route_pat = format!("{}/*path", prefix.trim_end_matches('/'));
294 let nested = AxumRouter::<Container>::new().route(
295 &route_pat,
296 axum::routing::get(
297 move |axum::extract::Path(path): axum::extract::Path<String>,
298 headers: HeaderMap| async move {
299 serve_embedded(fetcher, cache, &path, &headers)
300 },
301 ),
302 );
303 return router.merge(nested);
304 }
305
306 let svc = ServeDir::new(&mount.dir);
307
308 let nested = AxumRouter::<Container>::new().nest_service(prefix, svc);
309 let nested = if let Some(cache) = mount.cache {
310 let value = HeaderValue::from_str(&format!("public, max-age={}", cache.as_secs()))
311 .unwrap_or_else(|_| HeaderValue::from_static("public"));
312 nested.layer(SetResponseHeaderLayer::if_not_present(
313 HeaderName::from_static("cache-control"),
314 value,
315 ))
316 } else {
317 nested
318 };
319 router.merge(nested)
320}
321
322fn serve_embedded(
326 fetcher: crate::embedded::EmbeddedAssetFetcher,
327 cache: Option<Duration>,
328 path: &str,
329 headers: &HeaderMap,
330) -> Response<Body> {
331 let asset = match fetcher(path) {
332 Some(a) => a,
333 None => return not_found(),
334 };
335
336 if let (Some(client_tag), Some(asset_tag)) = (
337 headers
338 .get(axum::http::header::IF_NONE_MATCH)
339 .and_then(|v| v.to_str().ok()),
340 asset.etag.as_deref(),
341 ) {
342 if etag_matches(client_tag, asset_tag) {
343 let mut resp = Response::builder()
344 .status(StatusCode::NOT_MODIFIED)
345 .body(Body::empty())
346 .expect("304 body");
347 if let Some(d) = cache {
348 if let Ok(v) = HeaderValue::from_str(&format!("public, max-age={}", d.as_secs())) {
349 resp.headers_mut().insert("cache-control", v);
350 }
351 }
352 return resp;
353 }
354 }
355
356 let mut builder = Response::builder()
357 .status(StatusCode::OK)
358 .header("content-type", asset.content_type.as_str())
359 .header("content-length", asset.data.len());
360 if let Some(tag) = asset.etag.as_deref() {
361 builder = builder.header("etag", quote_etag(tag));
362 }
363 if let Some(d) = cache {
364 builder = builder.header("cache-control", format!("public, max-age={}", d.as_secs()));
365 }
366 builder
367 .body(Body::from(asset.data.into_owned()))
368 .unwrap_or_else(|_| not_found())
369}
370
371fn not_found() -> Response<Body> {
372 Response::builder()
373 .status(StatusCode::NOT_FOUND)
374 .body(Body::from("not found"))
375 .expect("404 body")
376}
377
/// Formats `raw` as an HTTP entity-tag.
///
/// Values that are already entity-tags, whether strong (`"x"`) or weak
/// (`W/"x"`), are returned unchanged. Anything else is wrapped in double
/// quotes.
fn quote_etag(raw: &str) -> String {
    let already_quoted = raw.starts_with('"') || raw.starts_with("W/\"");
    if already_quoted {
        raw.to_string()
    } else {
        format!("\"{raw}\"")
    }
}
385
/// Weak comparison of an `If-None-Match` header value against the server's ETag.
///
/// `client` may be a comma-separated list of entity-tags or `*`. Each tag is
/// reduced to its opaque value by trimming whitespace, removing a `W/` weak
/// prefix and then the surrounding quotes, so `W/"a"`, `"a"` and `a` all
/// compare equal. Empty list elements are ignored. Returns `true` for `*` or
/// for any tag equal to the server's.
fn etag_matches(client: &str, server: &str) -> bool {
    // The `W/` prefix sits outside the quotes, so it must be removed first.
    fn opaque(tag: &str) -> &str {
        let tag = tag.trim();
        let tag = tag.strip_prefix("W/").unwrap_or(tag);
        tag.trim_matches('"')
    }

    let server_tag = opaque(server);
    client
        .split(',')
        .map(opaque)
        .any(|tag| tag == "*" || (!tag.is_empty() && tag == server_tag))
}
403
404pub async fn serve(
407 router: AxumRouter,
408 cfg: &ServerConfig,
409 shutdown: tokio::sync::oneshot::Receiver<()>,
410) -> Result<(), Error> {
411 let addr: SocketAddr = cfg
412 .bind
413 .parse()
414 .map_err(|e| Error::Config(format!("invalid bind addr `{}`: {e}", cfg.bind)))?;
415
416 tracing::info!(%addr, tls = cfg.tls.is_some(), server_name = ?cfg.server_name, "anvil server starting");
417
418 let (shutdown_main_tx, shutdown_main_rx) = tokio::sync::oneshot::channel::<()>();
420 let (shutdown_redir_tx, shutdown_redir_rx) = tokio::sync::oneshot::channel::<()>();
421 tokio::spawn(async move {
422 let _ = shutdown.await;
423 let _ = shutdown_main_tx.send(());
424 let _ = shutdown_redir_tx.send(());
425 });
426
427 let redirect_task = cfg.redirect_http.clone().map(|redir| {
428 let target_host = redir
429 .target_host
430 .clone()
431 .or_else(|| cfg.server_name.first().cloned());
432 let permanent = redir.permanent;
433 let bind = redir.bind.clone();
434 tokio::spawn(async move {
435 if let Err(e) =
436 serve_redirect_http(&bind, target_host, permanent, shutdown_redir_rx).await
437 {
438 tracing::warn!(?e, "redirect_http listener exited with error");
439 }
440 })
441 });
442
443 let main_result = if let Some(tls) = &cfg.tls {
444 if tls.acme.is_some() {
445 serve_acme(
446 router,
447 addr,
448 tls,
449 cfg.limits.drain_timeout,
450 shutdown_main_rx,
451 )
452 .await
453 } else {
454 serve_tls(
455 router,
456 addr,
457 tls,
458 cfg.limits.drain_timeout,
459 shutdown_main_rx,
460 )
461 .await
462 }
463 } else {
464 serve_plain(router, addr, shutdown_main_rx).await
465 };
466
467 if let Some(task) = redirect_task {
468 task.abort();
469 }
470
471 main_result
472}
473
474async fn serve_redirect_http(
477 bind: &str,
478 target_host: Option<String>,
479 permanent: bool,
480 shutdown: tokio::sync::oneshot::Receiver<()>,
481) -> Result<(), Error> {
482 let addr: SocketAddr = bind
483 .parse()
484 .map_err(|e| Error::Config(format!("invalid redirect_http bind `{bind}`: {e}")))?;
485 tracing::info!(%addr, target_host = ?target_host, permanent, "http→https redirect listener");
486
487 let target_host = Arc::new(target_host);
488 let router: AxumRouter = AxumRouter::new().fallback(axum::routing::any({
489 let target_host = target_host.clone();
490 move |req: Request<Body>| {
491 let target_host = target_host.clone();
492 async move { http_redirect_handler(req, target_host, permanent).await }
493 }
494 }));
495
496 let listener = tokio::net::TcpListener::bind(addr).await?;
497 axum::serve(listener, router)
498 .with_graceful_shutdown(async move {
499 let _ = shutdown.await;
500 })
501 .await?;
502 Ok(())
503}
504
505async fn http_redirect_handler(
506 req: Request<Body>,
507 target_host: Arc<Option<String>>,
508 permanent: bool,
509) -> Response<Body> {
510 let host = target_host.as_ref().clone().unwrap_or_else(|| {
511 req.headers()
512 .get("host")
513 .and_then(|v| v.to_str().ok())
514 .map(String::from)
515 .unwrap_or_default()
516 });
517 let path_and_query = req
518 .uri()
519 .path_and_query()
520 .map(|p| p.as_str().to_string())
521 .unwrap_or_else(|| "/".to_string());
522 let location = format!("https://{host}{path_and_query}");
523
524 let status = if permanent {
525 StatusCode::MOVED_PERMANENTLY
526 } else {
527 StatusCode::FOUND
528 };
529 let mut resp = Response::new(Body::from(format!("Redirecting to {location}\n")));
530 *resp.status_mut() = status;
531 if let Ok(loc) = HeaderValue::from_str(&location) {
532 resp.headers_mut().insert("location", loc);
533 }
534 resp
535}
536
537fn build_hsts_header(cfg: &HstsConfig) -> Option<HeaderValue> {
538 let max_age = cfg.max_age.unwrap_or(Duration::from_secs(86400 * 365));
539 let mut value = format!("max-age={}", max_age.as_secs());
540 if cfg.include_subdomains {
541 value.push_str("; includeSubDomains");
542 }
543 if cfg.preload {
544 value.push_str("; preload");
545 }
546 HeaderValue::from_str(&value).ok()
547}
548
549async fn host_match_mw(allowed: Vec<String>, req: Request<Body>, next: Next) -> Response<Body> {
551 let host = req
552 .headers()
553 .get("host")
554 .and_then(|v| v.to_str().ok())
555 .unwrap_or("")
556 .to_string();
557
558 let host_no_port = host.split(':').next().unwrap_or("").to_ascii_lowercase();
560
561 if matches_any(&host_no_port, &allowed) {
562 return next.run(req).await;
563 }
564
565 tracing::debug!(host, allowed = ?allowed, "rejected host: no server_name match");
566 let mut resp = Response::new(Body::from(format!(
567 "404 not found (unknown host: {host})\n"
568 )));
569 *resp.status_mut() = StatusCode::NOT_FOUND;
570 resp
571}
572
573fn matches_any(host: &str, patterns: &[String]) -> bool {
574 patterns.iter().any(|pat| matches_pattern(host, pat))
575}
576
/// Matches `host` against one server-name `pattern`.
///
/// The pattern is lowercased; `host` is compared as given. `*` matches
/// anything. `*.suffix` matches any host ending in `.suffix` (one or more
/// extra labels, but not `suffix` itself). Anything else must match exactly.
fn matches_pattern(host: &str, pattern: &str) -> bool {
    let pattern = pattern.to_ascii_lowercase();
    match pattern.as_str() {
        "*" => true,
        p => match p.strip_prefix("*.") {
            Some(suffix) => host
                .strip_suffix(suffix)
                .is_some_and(|head| head.ends_with('.')),
            None => host == p,
        },
    }
}
590
591async fn serve_plain(
592 router: AxumRouter,
593 addr: SocketAddr,
594 shutdown: tokio::sync::oneshot::Receiver<()>,
595) -> Result<(), Error> {
596 let listener = tokio::net::TcpListener::bind(addr).await?;
597 axum::serve(
598 listener,
599 router.into_make_service_with_connect_info::<SocketAddr>(),
600 )
601 .with_graceful_shutdown(async move {
602 let _ = shutdown.await;
603 })
604 .await?;
605 Ok(())
606}
607
608async fn serve_tls(
609 router: AxumRouter,
610 addr: SocketAddr,
611 tls: &TlsConfig,
612 drain: Duration,
613 shutdown: tokio::sync::oneshot::Receiver<()>,
614) -> Result<(), Error> {
615 let config = if tls.additional_certs.is_empty() {
620 axum_server::tls_rustls::RustlsConfig::from_pem_file(&tls.cert, &tls.key)
621 .await
622 .map_err(|e| Error::Config(format!("tls load: {e}")))?
623 } else {
624 let resolver = build_sni_resolver(tls)
625 .map_err(|e| Error::Config(format!("tls multi-cert load: {e}")))?;
626 let server_config = rustls::ServerConfig::builder()
627 .with_no_client_auth()
628 .with_cert_resolver(Arc::new(resolver));
629 axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(server_config))
630 };
631
632 let watch_paths = [tls.cert.clone(), tls.key.clone()];
638 let config_for_watch = config.clone();
639 let cert_path = tls.cert.clone();
640 let key_path = tls.key.clone();
641 tokio::task::spawn_blocking(move || {
642 if let Err(e) = watch_tls_certs(config_for_watch, cert_path, key_path, watch_paths) {
643 tracing::warn!(error = %e, "cert hot-reload watcher exited");
644 }
645 });
646
647 let handle = axum_server::Handle::new();
648 let handle_for_shutdown = handle.clone();
649 tokio::spawn(async move {
650 let _ = shutdown.await;
651 handle_for_shutdown.graceful_shutdown(Some(drain));
652 });
653
654 axum_server::bind_rustls(addr, config)
655 .handle(handle)
656 .serve(router.into_make_service_with_connect_info::<SocketAddr>())
657 .await
658 .map_err(|e| Error::Internal(format!("tls serve: {e}")))?;
659 Ok(())
660}
661
662async fn serve_acme(
674 _router: AxumRouter,
675 _addr: SocketAddr,
676 tls: &TlsConfig,
677 _drain: Duration,
678 _shutdown: tokio::sync::oneshot::Receiver<()>,
679) -> Result<(), Error> {
680 let acme = tls
681 .acme
682 .as_ref()
683 .expect("serve_acme called without [tls.acme]");
684 Err(Error::Config(format!(
685 "[tls.acme] is configured for {n} domain(s) but ACME runtime support \
686 is pending a follow-up PR (rustls-acme version pin). For now, use \
687 certbot in TLS-ALPN-01 mode and `[tls] cert`/`key` pointing at the \
688 certbot output; cert hot-reload handles renewals without restart.",
689 n = acme.domains.len(),
690 )))
691}
692
693#[derive(Debug)]
699struct SniResolver {
700 entries: Vec<(String, Arc<rustls::sign::CertifiedKey>)>,
703 default_key: Arc<rustls::sign::CertifiedKey>,
704}
705
706impl rustls::server::ResolvesServerCert for SniResolver {
707 fn resolve(
708 &self,
709 client_hello: rustls::server::ClientHello<'_>,
710 ) -> Option<Arc<rustls::sign::CertifiedKey>> {
711 let sni = client_hello
712 .server_name()
713 .unwrap_or("")
714 .to_ascii_lowercase();
715 for (pattern, key) in &self.entries {
716 if matches_pattern(&sni, pattern) {
717 return Some(key.clone());
718 }
719 }
720 Some(self.default_key.clone())
721 }
722}
723
724fn build_sni_resolver(tls: &TlsConfig) -> std::io::Result<SniResolver> {
725 let default_key = load_certified_key(&tls.cert, &tls.key)?;
726 let mut entries = Vec::with_capacity(tls.additional_certs.len());
727 for entry in &tls.additional_certs {
728 let key = load_certified_key(&entry.cert, &entry.key)?;
729 entries.push((entry.server_name.to_ascii_lowercase(), key));
730 }
731 tracing::info!(
732 default_cert = %tls.cert.display(),
733 additional = tls.additional_certs.len(),
734 "tls: SNI resolver active"
735 );
736 Ok(SniResolver {
737 entries,
738 default_key,
739 })
740}
741
742fn load_certified_key(
743 cert_path: &std::path::Path,
744 key_path: &std::path::Path,
745) -> std::io::Result<Arc<rustls::sign::CertifiedKey>> {
746 use std::io::BufReader;
747
748 let cert_file = std::fs::File::open(cert_path).map_err(|e| {
749 std::io::Error::new(
750 e.kind(),
751 format!("opening cert {}: {e}", cert_path.display()),
752 )
753 })?;
754 let mut cert_reader = BufReader::new(cert_file);
755 let certs: Vec<rustls::pki_types::CertificateDer<'static>> =
756 rustls_pemfile::certs(&mut cert_reader).collect::<std::io::Result<_>>()?;
757 if certs.is_empty() {
758 return Err(std::io::Error::new(
759 std::io::ErrorKind::InvalidData,
760 format!("no certificates in {}", cert_path.display()),
761 ));
762 }
763
764 let key_file = std::fs::File::open(key_path).map_err(|e| {
765 std::io::Error::new(e.kind(), format!("opening key {}: {e}", key_path.display()))
766 })?;
767 let mut key_reader = BufReader::new(key_file);
768 let key = rustls_pemfile::private_key(&mut key_reader)?.ok_or_else(|| {
769 std::io::Error::new(
770 std::io::ErrorKind::InvalidData,
771 format!("no private key in {}", key_path.display()),
772 )
773 })?;
774
775 let signing_key = rustls::crypto::ring::sign::any_supported_type(&key)
776 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("sign: {e}")))?;
777
778 Ok(Arc::new(rustls::sign::CertifiedKey::new(
779 certs,
780 signing_key,
781 )))
782}
783
784fn watch_tls_certs(
792 config: axum_server::tls_rustls::RustlsConfig,
793 cert: std::path::PathBuf,
794 key: std::path::PathBuf,
795 watch_paths: [std::path::PathBuf; 2],
796) -> std::io::Result<()> {
797 use notify::{RecursiveMode, Watcher};
798 use std::sync::mpsc::channel;
799
800 let (tx, rx) = channel::<notify::Result<notify::Event>>();
801 let mut watcher = notify::recommended_watcher(move |res| {
802 let _ = tx.send(res);
803 })
804 .map_err(|e| std::io::Error::other(format!("notify init: {e}")))?;
805
806 for p in &watch_paths {
810 if let Some(parent) = p.parent() {
811 watcher
812 .watch(parent, RecursiveMode::NonRecursive)
813 .map_err(|e| std::io::Error::other(format!("notify watch: {e}")))?;
814 }
815 }
816
817 let runtime = tokio::runtime::Handle::try_current().ok();
818 while let Ok(event) = rx.recv() {
819 let Ok(event) = event else { continue };
820 let touches_us = event.paths.iter().any(|p| p == &cert || p == &key);
823 if !touches_us {
824 continue;
825 }
826 tracing::info!(
827 cert = %cert.display(),
828 key = %key.display(),
829 "tls cert change detected — reloading"
830 );
831 let cert = cert.clone();
832 let key = key.clone();
833 let config = config.clone();
834 let reload = async move {
835 if let Err(e) = config.reload_from_pem_file(&cert, &key).await {
836 tracing::warn!(error = %e, "tls reload failed");
837 } else {
838 tracing::info!("tls cert reloaded successfully");
839 }
840 };
841 if let Some(rt) = &runtime {
842 rt.spawn(reload);
843 } else {
844 std::thread::spawn(|| {
847 let rt = tokio::runtime::Builder::new_current_thread()
848 .enable_all()
849 .build();
850 if let Ok(rt) = rt {
851 rt.block_on(reload);
852 }
853 });
854 }
855 }
856 Ok(())
857}
858
859pub struct RateLimiter {
862 state: moka::sync::Cache<String, (Instant, u32)>,
864 default_rule: Option<RateRule>,
865 route_rules: Vec<(MatchKey, RateRule)>,
866}
867
868#[derive(Clone, Copy)]
869struct RateRule {
870 count: u32,
871 window: Duration,
872}
873
874#[derive(Clone)]
875struct MatchKey {
876 method: Option<Method>,
877 path: String,
878}
879
880impl RateLimiter {
881 pub fn from_config(cfg: &RateLimitConfig) -> Self {
882 let default_rule = cfg.per_ip.as_deref().and_then(|s| {
883 crate::server_config::parse_rate(s)
884 .map(|(count, window)| RateRule { count, window })
885 .ok()
886 });
887 let route_rules = cfg
888 .routes
889 .iter()
890 .filter_map(|(spec, rate)| {
891 let (count, window) = crate::server_config::parse_rate(rate).ok()?;
892 let (method, path) = parse_route_spec(spec);
893 Some((MatchKey { method, path }, RateRule { count, window }))
894 })
895 .collect();
896
897 Self {
898 state: moka::sync::Cache::builder()
899 .max_capacity(10_000)
900 .time_to_idle(Duration::from_secs(600))
901 .build(),
902 default_rule,
903 route_rules,
904 }
905 }
906
907 fn rule_for(&self, method: &Method, path: &str) -> Option<RateRule> {
908 for (key, rule) in &self.route_rules {
909 if key.path == path && key.method.as_ref().is_none_or(|m| m == method) {
910 return Some(*rule);
911 }
912 }
913 self.default_rule
914 }
915
916 fn check(&self, bucket: &str, rule: RateRule) -> bool {
917 let now = Instant::now();
918 let mut allowed = true;
919 self.state
920 .entry(bucket.to_string())
921 .and_compute_with(|existing| match existing {
922 Some(entry) => {
923 let (window_end, count) = entry.into_value();
924 if now >= window_end {
925 moka::ops::compute::Op::Put((
926 now + rule.window,
927 rule.count.saturating_sub(1),
928 ))
929 } else if count > 0 {
930 moka::ops::compute::Op::Put((window_end, count - 1))
931 } else {
932 allowed = false;
933 moka::ops::compute::Op::Put((window_end, 0))
934 }
935 }
936 None => {
937 moka::ops::compute::Op::Put((now + rule.window, rule.count.saturating_sub(1)))
938 }
939 });
940 allowed
941 }
942}
943
944fn parse_route_spec(spec: &str) -> (Option<Method>, String) {
945 let trimmed = spec.trim();
946 if let Some((m, p)) = trimmed.split_once(char::is_whitespace) {
947 let method = m.parse::<Method>().ok();
948 (method, p.trim().to_string())
949 } else {
950 (None, trimmed.to_string())
951 }
952}
953
954async fn rate_limit_mw(
955 limiter: Arc<RateLimiter>,
956 trusted: Arc<Vec<ipnet::IpNet>>,
957 req: Request<Body>,
958 next: Next,
959) -> Response<Body> {
960 let method = req.method().clone();
961 let path = req.uri().path().to_string();
962 let bucket = format!("{}|{}|{}", client_ip(&req, &trusted), method, path);
963
964 if let Some(rule) = limiter.rule_for(&method, &path) {
965 if !limiter.check(&bucket, rule) {
966 tracing::debug!(%method, %path, %bucket, "rate limited");
967 let mut resp = Response::new(Body::from("rate limit exceeded"));
968 *resp.status_mut() = StatusCode::TOO_MANY_REQUESTS;
969 return resp;
970 }
971 }
972 next.run(req).await
973}
974
975fn client_ip(req: &Request<Body>, trusted: &[ipnet::IpNet]) -> String {
988 let peer: Option<std::net::IpAddr> = req
989 .extensions()
990 .get::<ConnectInfo<SocketAddr>>()
991 .map(|ci| ci.0.ip());
992
993 let peer_trusted = match peer {
994 Some(addr) => !trusted.is_empty() && trusted.iter().any(|net| net.contains(&addr)),
995 None => false,
996 };
997
998 if peer_trusted {
999 if let Some(v) = req.headers().get("x-forwarded-for") {
1000 if let Ok(s) = v.to_str() {
1001 if let Some(first) = s.split(',').next() {
1002 let candidate = first.trim();
1003 if !candidate.is_empty() {
1004 return candidate.to_string();
1005 }
1006 }
1007 }
1008 }
1009 }
1010
1011 peer.map(|a| a.to_string())
1012 .unwrap_or_else(|| "unknown".into())
1013}
1014
1015async fn route_timeout_mw(
1021 rules: Arc<Vec<RouteTimeoutRule>>,
1022 req: Request<Body>,
1023 next: Next,
1024) -> Response<Body> {
1025 let path = req.uri().path().to_string();
1026 let matching = rules.iter().find(|r| path.starts_with(&r.prefix));
1027 match matching {
1028 Some(rule) => {
1029 let timeout = rule.timeout;
1030 match tokio::time::timeout(timeout, next.run(req)).await {
1031 Ok(resp) => resp,
1032 Err(_) => {
1033 tracing::debug!(
1034 path,
1035 timeout_ms = timeout.as_millis(),
1036 "route timeout exceeded"
1037 );
1038 let mut resp = Response::new(Body::from("request timed out"));
1039 *resp.status_mut() = StatusCode::REQUEST_TIMEOUT;
1040 resp
1041 }
1042 }
1043 }
1044 None => next.run(req).await,
1045 }
1046}
1047
1048async fn request_id_mw(mut req: Request<Body>, next: Next) -> Response<Body> {
1053 const HEADER: &str = "x-request-id";
1054
1055 let inbound = req
1056 .headers()
1057 .get(HEADER)
1058 .and_then(|v| v.to_str().ok())
1059 .map(|s| s.to_string());
1060
1061 let request_id = inbound.unwrap_or_else(|| uuid::Uuid::now_v7().to_string());
1062
1063 if let Ok(v) = HeaderValue::from_str(&request_id) {
1064 req.headers_mut()
1065 .insert(HeaderName::from_static(HEADER), v.clone());
1066 req.extensions_mut().insert(RequestId(request_id.clone()));
1070
1071 let mut resp = next.run(req).await;
1072 resp.headers_mut()
1073 .insert(HeaderName::from_static(HEADER), v);
1074 resp
1075 } else {
1076 next.run(req).await
1077 }
1078}
1079
1080#[derive(Debug, Clone)]
1085pub struct RequestId(pub String);
1086
1087async fn access_log_mw(
1088 format: AccessLogFormat,
1089 trusted: Arc<Vec<ipnet::IpNet>>,
1090 req: Request<Body>,
1091 next: Next,
1092) -> Response<Body> {
1093 let started = Instant::now();
1094 let method = req.method().clone();
1095 let path = req.uri().path().to_string();
1096 let host = req
1097 .headers()
1098 .get("host")
1099 .and_then(|v| v.to_str().ok())
1100 .unwrap_or("-")
1101 .to_string();
1102 let referer = req
1103 .headers()
1104 .get("referer")
1105 .and_then(|v| v.to_str().ok())
1106 .map(String::from);
1107 let ua = req
1108 .headers()
1109 .get("user-agent")
1110 .and_then(|v| v.to_str().ok())
1111 .map(String::from);
1112 let ip = client_ip(&req, &trusted);
1113 let request_id = req.extensions().get::<RequestId>().map(|id| id.0.clone());
1114
1115 let resp = next.run(req).await;
1116 let elapsed = started.elapsed();
1117 let status = resp.status().as_u16();
1118 let bytes = response_size(resp.headers()).unwrap_or(0);
1119
1120 match format {
1121 AccessLogFormat::Combined => {
1122 tracing::info!(
1123 target: "access_log",
1124 "{} - - \"{} {} HTTP/1.1\" {} {} \"{}\" \"{}\" {}ms",
1125 ip,
1126 method,
1127 path,
1128 status,
1129 bytes,
1130 referer.as_deref().unwrap_or("-"),
1131 ua.as_deref().unwrap_or("-"),
1132 elapsed.as_millis(),
1133 );
1134 }
1135 AccessLogFormat::Json => {
1136 tracing::info!(
1137 target: "access_log",
1138 json = %serde_json::json!({
1139 "ip": ip,
1140 "method": method.as_str(),
1141 "path": path,
1142 "host": host,
1143 "status": status,
1144 "bytes": bytes,
1145 "referer": referer,
1146 "user_agent": ua,
1147 "duration_ms": elapsed.as_millis(),
1148 "request_id": request_id,
1149 }),
1150 "request"
1151 );
1152 }
1153 AccessLogFormat::Off => {}
1154 }
1155 resp
1156}
1157
1158fn response_size(headers: &HeaderMap) -> Option<u64> {
1159 headers
1160 .get("content-length")
1161 .and_then(|v| v.to_str().ok())
1162 .and_then(|s| s.parse().ok())
1163}
1164
1165#[derive(Clone)]
1168struct CompiledRewrite {
1169 pattern: regex::Regex,
1170 to: String,
1171 status: Option<u16>,
1172 match_query: bool,
1173}
1174
1175struct CompiledRewrites {
1176 rules: Vec<CompiledRewrite>,
1177}
1178
1179impl CompiledRewrites {
1180 fn compile(rules: &[RewriteRule]) -> Self {
1181 let compiled = rules
1182 .iter()
1183 .filter_map(|r| match regex::Regex::new(&r.from) {
1184 Ok(pattern) => Some(CompiledRewrite {
1185 pattern,
1186 to: r.to.clone(),
1187 status: r.status,
1188 match_query: r.match_query,
1189 }),
1190 Err(e) => {
1191 tracing::warn!(rule = %r.from, error = %e, "invalid rewrite regex, skipping");
1192 None
1193 }
1194 })
1195 .collect();
1196 Self { rules: compiled }
1197 }
1198}
1199
1200async fn rewrite_mw(
1201 rules: Arc<CompiledRewrites>,
1202 mut req: Request<Body>,
1203 next: Next,
1204) -> Response<Body> {
1205 let path = req.uri().path().to_string();
1206 let path_and_query = req
1207 .uri()
1208 .path_and_query()
1209 .map(|p| p.as_str().to_string())
1210 .unwrap_or_else(|| path.clone());
1211
1212 let target_str_path = path.clone();
1213 let target_str_full = path_and_query.clone();
1214
1215 let mut applied: Option<(String, Option<u16>)> = None;
1217 for rule in &rules.rules {
1218 let subject = if rule.match_query {
1219 &target_str_full
1220 } else {
1221 &target_str_path
1222 };
1223 if rule.pattern.is_match(subject) {
1224 let replaced = rule.pattern.replace(subject, rule.to.as_str()).to_string();
1225 applied = Some((replaced, rule.status));
1226 break;
1227 }
1228 }
1229
1230 let Some((new_target, status)) = applied else {
1231 return next.run(req).await;
1232 };
1233
1234 match status {
1235 Some(code @ (301 | 302 | 303 | 307 | 308)) => {
1236 let mut resp = Response::new(Body::from(format!("Redirecting to {new_target}\n")));
1237 *resp.status_mut() =
1238 StatusCode::from_u16(code).unwrap_or(StatusCode::MOVED_PERMANENTLY);
1239 if let Ok(loc) = HeaderValue::from_str(&new_target) {
1240 resp.headers_mut().insert("location", loc);
1241 }
1242 resp
1243 }
1244 _ => {
1245 let mut parts = req.uri().clone().into_parts();
1247 if let Ok(new_pq) = new_target.parse::<axum::http::uri::PathAndQuery>() {
1248 parts.path_and_query = Some(new_pq);
1249 }
1250 if let Ok(new_uri) = axum::http::Uri::from_parts(parts) {
1251 *req.uri_mut() = new_uri;
1252 }
1253 next.run(req).await
1254 }
1255 }
1256}
1257
1258async fn trailing_slash_mw(
1261 cfg: TrailingSlashConfig,
1262 mut req: Request<Body>,
1263 next: Next,
1264) -> Response<Body> {
1265 let path = req.uri().path().to_string();
1266 if path == "/" {
1267 return next.run(req).await;
1268 }
1269
1270 let want_slash = matches!(cfg.mode, TrailingSlashMode::Always);
1271 let has_slash = path.ends_with('/');
1272
1273 if want_slash == has_slash {
1274 return next.run(req).await;
1275 }
1276
1277 let new_path = if want_slash {
1278 format!("{path}/")
1279 } else {
1280 path.trim_end_matches('/').to_string()
1281 };
1282
1283 let query = req
1284 .uri()
1285 .query()
1286 .map(|q| format!("?{q}"))
1287 .unwrap_or_default();
1288 let new_target = format!("{new_path}{query}");
1289
1290 match cfg.action {
1291 TrailingSlashAction::Redirect => {
1292 let mut resp = Response::new(Body::from(format!("Redirecting to {new_target}\n")));
1293 *resp.status_mut() = StatusCode::MOVED_PERMANENTLY;
1294 if let Ok(loc) = HeaderValue::from_str(&new_target) {
1295 resp.headers_mut().insert("location", loc);
1296 }
1297 resp
1298 }
1299 TrailingSlashAction::Rewrite => {
1300 let mut parts = req.uri().clone().into_parts();
1301 if let Ok(pq) = new_target.parse::<axum::http::uri::PathAndQuery>() {
1302 parts.path_and_query = Some(pq);
1303 }
1304 if let Ok(new_uri) = axum::http::Uri::from_parts(parts) {
1305 *req.uri_mut() = new_uri;
1306 }
1307 next.run(req).await
1308 }
1309 }
1310}
1311
/// Custom error-page bodies read from disk by `load_error_pages`, keyed by
/// HTTP status code.
struct LoadedErrorPages {
    /// Status code -> (page contents, `Content-Type` chosen from the file extension).
    by_status: std::collections::HashMap<u16, (String, &'static str)>,
}
1317
1318fn load_error_pages(
1319 raw: &std::collections::BTreeMap<String, std::path::PathBuf>,
1320) -> LoadedErrorPages {
1321 let mut by_status = std::collections::HashMap::new();
1322 for (key, path) in raw {
1323 let Ok(code) = key.parse::<u16>() else {
1324 tracing::warn!(key, "error_pages: invalid status code, skipping");
1325 continue;
1326 };
1327 let body = match std::fs::read_to_string(path) {
1328 Ok(s) => s,
1329 Err(e) => {
1330 tracing::warn!(?path, ?e, "error_pages: failed to read file, skipping");
1331 continue;
1332 }
1333 };
1334 let content_type = guess_content_type(path);
1335 by_status.insert(code, (body, content_type));
1336 }
1337 LoadedErrorPages { by_status }
1338}
1339
/// Picks a `Content-Type` for an error-page file from its extension.
///
/// The extension is compared case-insensitively, so `404.HTML` is served as
/// HTML just like `404.html`.
///
/// * `html` / `htm` -> `text/html; charset=utf-8`
/// * `json` -> `application/json`
/// * anything else, including `txt`, no extension, or a non-UTF-8 extension
///   -> `text/plain; charset=utf-8`
fn guess_content_type(path: &std::path::Path) -> &'static str {
    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase);
    match ext.as_deref() {
        Some("html" | "htm") => "text/html; charset=utf-8",
        Some("json") => "application/json",
        _ => "text/plain; charset=utf-8",
    }
}
1348
1349async fn error_pages_mw(
1350 pages: Arc<LoadedErrorPages>,
1351 req: Request<Body>,
1352 next: Next,
1353) -> Response<Body> {
1354 let resp = next.run(req).await;
1355 let status = resp.status().as_u16();
1356
1357 let Some((body, ctype)) = pages.by_status.get(&status) else {
1358 return resp;
1359 };
1360
1361 let mut out = Response::new(Body::from(body.clone()));
1362 *out.status_mut() = resp.status();
1363 if let Ok(ct) = HeaderValue::from_str(ctype) {
1364 out.headers_mut().insert("content-type", ct);
1365 }
1366 for h in ["cache-control", "x-request-id"] {
1368 if let Some(v) = resp.headers().get(h) {
1369 out.headers_mut().insert(h, v.clone());
1370 }
1371 }
1372 out
1373}
1374
/// A single reverse-proxy rule prepared by `CompiledProxies::compile`.
#[derive(Clone)]
struct CompiledProxy {
    /// Request-path prefix that selects this rule.
    prefix: String,
    /// Upstream base URL with any trailing `/` characters removed.
    upstream: String,
    /// When true, `prefix` is removed from the request path before forwarding.
    strip_prefix: bool,
    /// When false, the client's `Host` header is not forwarded to the upstream.
    preserve_host: bool,
    /// Per-attempt upstream timeout; 30 seconds when the config leaves it unset.
    timeout: Duration,
    /// Extra send attempts after the first; up to `retries + 1` attempts in total.
    retries: u8,
}
1386
/// All configured proxy rules plus the HTTP client shared by every upstream request.
struct CompiledProxies {
    /// Rules sorted by descending prefix length, so the most specific prefix is tried first.
    rules: Vec<CompiledProxy>,
    /// Built by `compile` with redirect-following disabled, so upstream 3xx
    /// responses reach the client as-is (a default `reqwest::Client::new()`
    /// is used if building that client fails).
    client: reqwest::Client,
}
1391
1392impl CompiledProxies {
1393 fn compile(rules: &[ProxyRule]) -> Self {
1394 let mut compiled: Vec<CompiledProxy> = rules
1395 .iter()
1396 .map(|r| CompiledProxy {
1397 prefix: r.prefix.clone(),
1398 upstream: r.upstream.trim_end_matches('/').to_string(),
1399 strip_prefix: r.strip_prefix,
1400 preserve_host: r.preserve_host,
1401 timeout: r.timeout.unwrap_or(Duration::from_secs(30)),
1402 retries: r.retries,
1403 })
1404 .collect();
1405 compiled.sort_by_key(|r| std::cmp::Reverse(r.prefix.len()));
1407
1408 let client = reqwest::Client::builder()
1409 .redirect(reqwest::redirect::Policy::none())
1410 .build()
1411 .unwrap_or_else(|_| reqwest::Client::new());
1412
1413 Self {
1414 rules: compiled,
1415 client,
1416 }
1417 }
1418
1419 fn matching(&self, path: &str) -> Option<&CompiledProxy> {
1420 self.rules.iter().find(|r| path.starts_with(&r.prefix))
1421 }
1422}
1423
1424async fn proxy_mw(proxies: Arc<CompiledProxies>, req: Request<Body>, next: Next) -> Response<Body> {
1425 let path = req.uri().path().to_string();
1426 let Some(rule) = proxies.matching(&path) else {
1427 return next.run(req).await;
1428 };
1429 let rule = rule.clone();
1430
1431 match proxy_forward(&proxies.client, &rule, req).await {
1432 Ok(resp) => resp,
1433 Err(e) => {
1434 tracing::warn!(?e, prefix = %rule.prefix, upstream = %rule.upstream, "proxy error");
1435 let mut resp = Response::new(Body::from(format!("upstream error: {e}")));
1436 *resp.status_mut() = StatusCode::BAD_GATEWAY;
1437 resp
1438 }
1439 }
1440}
1441
1442async fn proxy_forward(
1443 client: &reqwest::Client,
1444 rule: &CompiledProxy,
1445 req: Request<Body>,
1446) -> Result<Response<Body>, String> {
1447 let (parts, body) = req.into_parts();
1448 let body_bytes = axum::body::to_bytes(body, usize::MAX)
1449 .await
1450 .map_err(|e| format!("body read: {e}"))?;
1451
1452 let original_path = parts.uri.path();
1453 let upstream_path = if rule.strip_prefix {
1454 original_path
1455 .strip_prefix(&rule.prefix)
1456 .unwrap_or(original_path)
1457 } else {
1458 original_path
1459 };
1460 let upstream_path = if upstream_path.is_empty() {
1461 "/"
1462 } else {
1463 upstream_path
1464 };
1465 let query = parts
1466 .uri
1467 .query()
1468 .map(|q| format!("?{q}"))
1469 .unwrap_or_default();
1470 let upstream_url = format!("{}{}{}", rule.upstream, upstream_path, query);
1471
1472 let method = parts.method.clone();
1473 let mut last_err = String::new();
1474 for attempt in 0..=rule.retries {
1475 let mut request = client
1476 .request(
1477 reqwest::Method::from_bytes(method.as_str().as_bytes())
1478 .unwrap_or(reqwest::Method::GET),
1479 &upstream_url,
1480 )
1481 .timeout(rule.timeout)
1482 .body(body_bytes.clone());
1483
1484 for (name, value) in parts.headers.iter() {
1485 let n = name.as_str().to_ascii_lowercase();
1487 if matches!(
1488 n.as_str(),
1489 "connection"
1490 | "keep-alive"
1491 | "proxy-authenticate"
1492 | "proxy-authorization"
1493 | "te"
1494 | "trailers"
1495 | "transfer-encoding"
1496 | "upgrade"
1497 | "content-length"
1498 ) {
1499 continue;
1500 }
1501 if !rule.preserve_host && n == "host" {
1502 continue;
1503 }
1504 if let Ok(v) = value.to_str() {
1505 request = request.header(name.as_str(), v);
1506 }
1507 }
1508
1509 if let Some(host) = parts.headers.get("host").and_then(|v| v.to_str().ok()) {
1511 request = request.header("x-forwarded-host", host);
1512 }
1513 request = request.header("x-forwarded-proto", "http");
1514
1515 match request.send().await {
1516 Ok(resp) => return upstream_to_axum(resp).await,
1517 Err(e) => {
1518 last_err = format!("attempt {} → {e}", attempt + 1);
1519 tracing::debug!(error = %e, attempt, "proxy retry");
1520 continue;
1521 }
1522 }
1523 }
1524 Err(last_err)
1525}
1526
1527async fn cors_mw(cfg: Arc<CorsConfig>, req: Request<Body>, next: Next) -> Response<Body> {
1530 let origin = req
1531 .headers()
1532 .get("origin")
1533 .and_then(|v| v.to_str().ok())
1534 .map(String::from);
1535
1536 let is_allowed_origin = origin.as_deref().is_some_and(|o| {
1537 cfg.allow_origins
1538 .iter()
1539 .any(|allowed| allowed == "*" || allowed == o)
1540 });
1541
1542 if req.method() == Method::OPTIONS && origin.is_some() {
1544 let mut resp = Response::new(Body::empty());
1545 *resp.status_mut() = StatusCode::NO_CONTENT;
1546 apply_cors_headers(
1547 resp.headers_mut(),
1548 &cfg,
1549 origin.as_deref(),
1550 is_allowed_origin,
1551 );
1552 return resp;
1553 }
1554
1555 let mut resp = next.run(req).await;
1556 apply_cors_headers(
1557 resp.headers_mut(),
1558 &cfg,
1559 origin.as_deref(),
1560 is_allowed_origin,
1561 );
1562 resp
1563}
1564
1565fn apply_cors_headers(
1566 headers: &mut HeaderMap,
1567 cfg: &CorsConfig,
1568 origin: Option<&str>,
1569 is_allowed_origin: bool,
1570) {
1571 if !is_allowed_origin {
1572 return;
1573 }
1574 if let Some(origin) = origin {
1575 if let Ok(v) = HeaderValue::from_str(origin) {
1576 headers.insert("access-control-allow-origin", v);
1577 }
1578 headers.insert("vary", HeaderValue::from_static("Origin"));
1579 } else if cfg.allow_origins.iter().any(|o| o == "*") {
1580 headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
1581 }
1582
1583 let methods = if cfg.allow_methods.is_empty() {
1584 "GET, POST, PUT, PATCH, DELETE, OPTIONS".to_string()
1585 } else {
1586 cfg.allow_methods.join(", ")
1587 };
1588 if let Ok(v) = HeaderValue::from_str(&methods) {
1589 headers.insert("access-control-allow-methods", v);
1590 }
1591
1592 let allow_headers = if cfg.allow_headers.is_empty() {
1593 "Content-Type, Authorization, X-CSRF-TOKEN, X-Requested-With".to_string()
1594 } else {
1595 cfg.allow_headers.join(", ")
1596 };
1597 if let Ok(v) = HeaderValue::from_str(&allow_headers) {
1598 headers.insert("access-control-allow-headers", v);
1599 }
1600
1601 if !cfg.expose_headers.is_empty() {
1602 if let Ok(v) = HeaderValue::from_str(&cfg.expose_headers.join(", ")) {
1603 headers.insert("access-control-expose-headers", v);
1604 }
1605 }
1606
1607 if cfg.allow_credentials {
1608 headers.insert(
1609 "access-control-allow-credentials",
1610 HeaderValue::from_static("true"),
1611 );
1612 }
1613
1614 if let Some(max_age) = cfg.max_age {
1615 if let Ok(v) = HeaderValue::from_str(&max_age.as_secs().to_string()) {
1616 headers.insert("access-control-max-age", v);
1617 }
1618 }
1619}
1620
1621async fn ip_rules_mw(
1624 rules: Arc<Vec<IpRule>>,
1625 trusted: Arc<Vec<ipnet::IpNet>>,
1626 req: Request<Body>,
1627 next: Next,
1628) -> Response<Body> {
1629 let path = req.uri().path().to_string();
1630 let ip_str = client_ip(&req, &trusted);
1631 let ip = ip_str.parse::<std::net::IpAddr>().ok();
1632
1633 for rule in rules.iter() {
1634 if !path.starts_with(&rule.prefix) {
1635 continue;
1636 }
1637 let matches_range = ip
1638 .map(|addr| rule.ranges.iter().any(|net| net.contains(&addr)))
1639 .unwrap_or(false);
1640 let allowed = match rule.action {
1641 IpAction::Allow => matches_range,
1642 IpAction::Deny => !matches_range,
1643 };
1644 if !allowed {
1645 tracing::debug!(path, ip = %ip_str, "ip rule denied request");
1646 let mut resp = Response::new(Body::from("forbidden"));
1647 *resp.status_mut() = StatusCode::FORBIDDEN;
1648 return resp;
1649 }
1650 break;
1652 }
1653
1654 next.run(req).await
1655}
1656
1657use base64::engine::general_purpose::STANDARD as B64;
1660use base64::Engine as _;
1661
/// Basic-auth rules paired with their credentials pre-split for matching.
struct CompiledBasicAuth {
    /// Each entry is the configured rule plus its credentials parsed into
    /// `(username, password)` pairs, split at the first `:`; entries without
    /// a `:` are not included.
    rules: Vec<(BasicAuthRule, Vec<(String, String)>)>,
}
1665
1666fn compile_basic_auth(rules: &[BasicAuthRule]) -> CompiledBasicAuth {
1667 let compiled = rules
1668 .iter()
1669 .map(|r| {
1670 let creds = r
1671 .credentials
1672 .iter()
1673 .filter_map(|c| {
1674 c.split_once(':')
1675 .map(|(u, p)| (u.to_string(), p.to_string()))
1676 })
1677 .collect();
1678 (r.clone(), creds)
1679 })
1680 .collect();
1681 CompiledBasicAuth { rules: compiled }
1682}
1683
1684async fn basic_auth_mw(
1685 rules: Arc<CompiledBasicAuth>,
1686 req: Request<Body>,
1687 next: Next,
1688) -> Response<Body> {
1689 let path = req.uri().path().to_string();
1690 for (rule, creds) in &rules.rules {
1691 if !path.starts_with(&rule.prefix) {
1692 continue;
1693 }
1694 let supplied = req
1695 .headers()
1696 .get("authorization")
1697 .and_then(|v| v.to_str().ok())
1698 .and_then(|s| s.strip_prefix("Basic "))
1699 .and_then(|b64| B64.decode(b64).ok())
1700 .and_then(|bytes| String::from_utf8(bytes).ok())
1701 .and_then(|pair| {
1702 pair.split_once(':')
1703 .map(|(u, p)| (u.to_string(), p.to_string()))
1704 });
1705
1706 let ok = supplied
1707 .as_ref()
1708 .map(|(u, p)| creds.iter().any(|(cu, cp)| cu == u && cp == p))
1709 .unwrap_or(false);
1710
1711 if ok {
1712 return next.run(req).await;
1713 }
1714
1715 let challenge = format!("Basic realm=\"{}\"", rule.realm);
1716 let mut resp = Response::new(Body::from("authentication required"));
1717 *resp.status_mut() = StatusCode::UNAUTHORIZED;
1718 if let Ok(v) = HeaderValue::from_str(&challenge) {
1719 resp.headers_mut().insert("www-authenticate", v);
1720 }
1721 return resp;
1722 }
1723 next.run(req).await
1724}
1725
1726async fn upstream_to_axum(resp: reqwest::Response) -> Result<Response<Body>, String> {
1727 let status = resp.status();
1728 let headers = resp.headers().clone();
1729 let bytes = resp
1730 .bytes()
1731 .await
1732 .map_err(|e| format!("upstream body: {e}"))?;
1733 let mut out = Response::new(Body::from(bytes));
1734 *out.status_mut() =
1735 axum::http::StatusCode::from_u16(status.as_u16()).unwrap_or(axum::http::StatusCode::OK);
1736 for (name, value) in headers.iter() {
1737 let n = name.as_str().to_ascii_lowercase();
1738 if matches!(
1739 n.as_str(),
1740 "connection"
1741 | "keep-alive"
1742 | "proxy-authenticate"
1743 | "proxy-authorization"
1744 | "te"
1745 | "trailers"
1746 | "transfer-encoding"
1747 | "upgrade"
1748 ) {
1749 continue;
1750 }
1751 if let Ok(v) = HeaderValue::from_bytes(value.as_bytes()) {
1752 if let Ok(name) = HeaderName::from_bytes(name.as_str().as_bytes()) {
1753 out.headers_mut().append(name, v);
1754 }
1755 }
1756 }
1757 Ok(out)
1758}
1759
#[cfg(test)]
mod tests {
    //! Unit tests for `client_ip`. `X-Forwarded-For` must be honoured only
    //! when the directly connected peer lies inside a trusted CIDR, and then
    //! the first (client-most) hop is used.
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};

    /// Builds a bodiless request, optionally carrying an `X-Forwarded-For`
    /// header and a `ConnectInfo` peer-address extension.
    fn make_req(peer: Option<SocketAddr>, xff: Option<&str>) -> Request<Body> {
        let builder = match xff {
            Some(value) => Request::builder().header("x-forwarded-for", value),
            None => Request::builder(),
        };
        let mut req = builder.body(Body::empty()).unwrap();
        if let Some(addr) = peer {
            req.extensions_mut().insert(ConnectInfo(addr));
        }
        req
    }

    /// Shorthand for an IPv4 socket address.
    fn v4(a: u8, b: u8, c: u8, d: u8, port: u16) -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(a, b, c, d)), port)
    }

    /// Parses a CIDR literal, panicking on malformed test input.
    fn cidr(s: &str) -> ipnet::IpNet {
        s.parse().unwrap()
    }

    #[test]
    fn xff_ignored_when_peer_is_not_trusted() {
        // Public peer outside the trusted range: its XFF claim must be ignored.
        let req = make_req(Some(v4(203, 0, 113, 5, 12345)), Some("198.51.100.1"));
        let trusted = vec![cidr("10.0.0.0/8")];
        assert_eq!(client_ip(&req, &trusted), "203.0.113.5");
    }

    #[test]
    fn xff_honored_when_peer_is_a_trusted_proxy() {
        // Peer is an internal proxy: the original client is the first XFF hop.
        let req = make_req(Some(v4(10, 0, 0, 5, 443)), Some("198.51.100.1, 10.0.0.5"));
        let trusted = vec![cidr("10.0.0.0/8")];
        assert_eq!(client_ip(&req, &trusted), "198.51.100.1");
    }

    #[test]
    fn empty_trusted_list_means_xff_is_never_honored() {
        let req = make_req(Some(v4(10, 0, 0, 5, 443)), Some("198.51.100.1"));
        let trusted: Vec<ipnet::IpNet> = vec![];
        assert_eq!(client_ip(&req, &trusted), "10.0.0.5");
    }

    #[test]
    fn no_xff_falls_back_to_peer() {
        let req = make_req(Some(v4(10, 0, 0, 5, 443)), None);
        let trusted = vec![cidr("10.0.0.0/8")];
        assert_eq!(client_ip(&req, &trusted), "10.0.0.5");
    }

    #[test]
    fn missing_connect_info_returns_unknown() {
        // Without a peer address nothing is trusted, so XFF cannot be used.
        let req = make_req(None, Some("198.51.100.1"));
        let trusted = vec![cidr("10.0.0.0/8")];
        assert_eq!(client_ip(&req, &trusted), "unknown");
    }

    #[test]
    fn xff_with_whitespace_and_multiple_hops_picks_first() {
        let req = make_req(
            Some(v4(10, 0, 0, 5, 443)),
            Some(" 198.51.100.1 ,10.0.0.5, 10.0.0.7"),
        );
        let trusted = vec![cidr("10.0.0.0/8")];
        assert_eq!(client_ip(&req, &trusted), "198.51.100.1");
    }
}