1use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use axum::body::Body;
9use axum::http::{HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode};
10use axum::middleware::Next;
11use axum::Router as AxumRouter;
12use tower_http::compression::predicate::SizeAbove;
13use tower_http::compression::CompressionLayer;
14use tower_http::limit::RequestBodyLimitLayer;
15use tower_http::services::ServeDir;
16use tower_http::set_header::SetResponseHeaderLayer;
17use tower_http::timeout::TimeoutLayer;
18use tower_http::trace::TraceLayer;
19
20use crate::container::Container;
21use crate::server_config::{
22 AccessLogFormat, BasicAuthRule, CorsConfig, HstsConfig, IpAction, IpRule, ProxyRule,
23 RateLimitConfig, RewriteRule, ServerConfig, StaticMount, TlsConfig, TrailingSlashAction,
24 TrailingSlashConfig, TrailingSlashMode,
25};
26use crate::Error;
27
28pub fn apply_layers(web: AxumRouter<Container>, cfg: &ServerConfig) -> AxumRouter<Container> {
31 let mut router = web;
32
33 for (prefix, mount) in &cfg.static_files {
36 router = mount_static(router, prefix, mount);
37 }
38
39 let body_max = cfg.limits.body_max as usize;
41 router = router
42 .layer(RequestBodyLimitLayer::new(body_max))
43 .layer(TraceLayer::new_for_http());
44
45 if let Some(timeout) = cfg.limits.request_timeout {
46 router = router.layer(TimeoutLayer::new(timeout));
47 }
48
49 if !cfg.server_name.is_empty() {
52 let allowed = cfg.server_name.clone();
53 router = router.layer(axum::middleware::from_fn(
54 move |req: Request<Body>, next: Next| {
55 let allowed = allowed.clone();
56 async move { host_match_mw(allowed, req, next).await }
57 },
58 ));
59 }
60
61 if !cfg.ip_rules.is_empty() {
64 let rules = Arc::new(cfg.ip_rules.clone());
65 let rules_clone = rules.clone();
66 router = router.layer(axum::middleware::from_fn(
67 move |req: Request<Body>, next: Next| {
68 let rules = rules_clone.clone();
69 async move { ip_rules_mw(rules, req, next).await }
70 },
71 ));
72 }
73 if !cfg.basic_auth.is_empty() {
74 let rules = Arc::new(compile_basic_auth(&cfg.basic_auth));
75 let rules_clone = rules.clone();
76 router = router.layer(axum::middleware::from_fn(
77 move |req: Request<Body>, next: Next| {
78 let rules = rules_clone.clone();
79 async move { basic_auth_mw(rules, req, next).await }
80 },
81 ));
82 }
83
84 if cfg.cors.enabled {
87 let cors = Arc::new(cfg.cors.clone());
88 let cors_clone = cors.clone();
89 router = router.layer(axum::middleware::from_fn(
90 move |req: Request<Body>, next: Next| {
91 let cors = cors_clone.clone();
92 async move { cors_mw(cors, req, next).await }
93 },
94 ));
95 }
96
97 if !cfg.proxies.is_empty() {
100 let proxies = Arc::new(CompiledProxies::compile(&cfg.proxies));
101 let proxies_clone = proxies.clone();
102 router = router.layer(axum::middleware::from_fn(
103 move |req: Request<Body>, next: Next| {
104 let proxies = proxies_clone.clone();
105 async move { proxy_mw(proxies, req, next).await }
106 },
107 ));
108 }
109
110 if !cfg.rewrites.is_empty() {
112 let compiled = Arc::new(CompiledRewrites::compile(&cfg.rewrites));
113 let compiled_clone = compiled.clone();
114 router = router.layer(axum::middleware::from_fn(
115 move |req: Request<Body>, next: Next| {
116 let rules = compiled_clone.clone();
117 async move { rewrite_mw(rules, req, next).await }
118 },
119 ));
120 }
121
122 if cfg.trailing_slash.mode != TrailingSlashMode::Ignore {
124 let ts = cfg.trailing_slash.clone();
125 router = router.layer(axum::middleware::from_fn(
126 move |req: Request<Body>, next: Next| {
127 let ts = ts.clone();
128 async move { trailing_slash_mw(ts, req, next).await }
129 },
130 ));
131 }
132
133 if !cfg.error_pages.is_empty() {
136 let pages = Arc::new(load_error_pages(&cfg.error_pages));
137 let pages_clone = pages.clone();
138 router = router.layer(axum::middleware::from_fn(
139 move |req: Request<Body>, next: Next| {
140 let pages = pages_clone.clone();
141 async move { error_pages_mw(pages, req, next).await }
142 },
143 ));
144 }
145
146 if cfg.tls.is_some() && cfg.hsts.enabled {
148 if let Some(value) = build_hsts_header(&cfg.hsts) {
149 router = router.layer(SetResponseHeaderLayer::if_not_present(
150 HeaderName::from_static("strict-transport-security"),
151 value,
152 ));
153 }
154 }
155
156 if cfg.compression.enabled {
157 let min_size = u16::try_from(cfg.compression.min_size).unwrap_or(u16::MAX);
162 let mut layer = CompressionLayer::new();
163 if !cfg
164 .compression
165 .algorithms
166 .iter()
167 .any(|a| a.eq_ignore_ascii_case("gzip"))
168 {
169 layer = layer.no_gzip();
170 }
171 if !cfg
172 .compression
173 .algorithms
174 .iter()
175 .any(|a| a.eq_ignore_ascii_case("br") || a.eq_ignore_ascii_case("brotli"))
176 {
177 layer = layer.no_br();
178 }
179 if !cfg
180 .compression
181 .algorithms
182 .iter()
183 .any(|a| a.eq_ignore_ascii_case("deflate"))
184 {
185 layer = layer.no_deflate();
186 }
187 let layer = layer.compress_when(SizeAbove::new(min_size));
188 router = router.layer(layer);
189 }
190
191 if cfg.rate_limit.per_ip.is_some() || !cfg.rate_limit.routes.is_empty() {
192 let limiter = Arc::new(RateLimiter::from_config(&cfg.rate_limit));
193 let limiter_clone = limiter.clone();
194 router = router.layer(axum::middleware::from_fn(
195 move |req: Request<Body>, next: Next| {
196 let limiter = limiter_clone.clone();
197 async move { rate_limit_mw(limiter, req, next).await }
198 },
199 ));
200 }
201
202 if matches!(
203 cfg.access_log.format,
204 AccessLogFormat::Combined | AccessLogFormat::Json
205 ) {
206 let format = cfg.access_log.format;
207 router =
208 router.layer(axum::middleware::from_fn(
209 move |req: Request<Body>, next: Next| async move {
210 access_log_mw(format, req, next).await
211 },
212 ));
213 }
214
215 router
216}
217
218fn mount_static(
219 router: AxumRouter<Container>,
220 prefix: &str,
221 mount: &StaticMount,
222) -> AxumRouter<Container> {
223 let _ = mount.ranges;
226 let svc = ServeDir::new(&mount.dir);
227
228 let nested = AxumRouter::<Container>::new().nest_service(prefix, svc);
229 let nested = if let Some(cache) = mount.cache {
230 let value = HeaderValue::from_str(&format!("public, max-age={}", cache.as_secs()))
231 .unwrap_or_else(|_| HeaderValue::from_static("public"));
232 nested.layer(SetResponseHeaderLayer::if_not_present(
233 HeaderName::from_static("cache-control"),
234 value,
235 ))
236 } else {
237 nested
238 };
239 router.merge(nested)
240}
241
242pub async fn serve(
245 router: AxumRouter,
246 cfg: &ServerConfig,
247 shutdown: tokio::sync::oneshot::Receiver<()>,
248) -> Result<(), Error> {
249 let addr: SocketAddr = cfg
250 .bind
251 .parse()
252 .map_err(|e| Error::Config(format!("invalid bind addr `{}`: {e}", cfg.bind)))?;
253
254 tracing::info!(%addr, tls = cfg.tls.is_some(), server_name = ?cfg.server_name, "anvil server starting");
255
256 let (shutdown_main_tx, shutdown_main_rx) = tokio::sync::oneshot::channel::<()>();
258 let (shutdown_redir_tx, shutdown_redir_rx) = tokio::sync::oneshot::channel::<()>();
259 tokio::spawn(async move {
260 let _ = shutdown.await;
261 let _ = shutdown_main_tx.send(());
262 let _ = shutdown_redir_tx.send(());
263 });
264
265 let redirect_task = cfg.redirect_http.clone().map(|redir| {
266 let target_host = redir
267 .target_host
268 .clone()
269 .or_else(|| cfg.server_name.first().cloned());
270 let permanent = redir.permanent;
271 let bind = redir.bind.clone();
272 tokio::spawn(async move {
273 if let Err(e) =
274 serve_redirect_http(&bind, target_host, permanent, shutdown_redir_rx).await
275 {
276 tracing::warn!(?e, "redirect_http listener exited with error");
277 }
278 })
279 });
280
281 let main_result = if let Some(tls) = &cfg.tls {
282 serve_tls(router, addr, tls, shutdown_main_rx).await
283 } else {
284 serve_plain(router, addr, shutdown_main_rx).await
285 };
286
287 if let Some(task) = redirect_task {
288 task.abort();
289 }
290
291 main_result
292}
293
294async fn serve_redirect_http(
297 bind: &str,
298 target_host: Option<String>,
299 permanent: bool,
300 shutdown: tokio::sync::oneshot::Receiver<()>,
301) -> Result<(), Error> {
302 let addr: SocketAddr = bind
303 .parse()
304 .map_err(|e| Error::Config(format!("invalid redirect_http bind `{bind}`: {e}")))?;
305 tracing::info!(%addr, target_host = ?target_host, permanent, "http→https redirect listener");
306
307 let target_host = Arc::new(target_host);
308 let router: AxumRouter = AxumRouter::new().fallback(axum::routing::any({
309 let target_host = target_host.clone();
310 move |req: Request<Body>| {
311 let target_host = target_host.clone();
312 async move { http_redirect_handler(req, target_host, permanent).await }
313 }
314 }));
315
316 let listener = tokio::net::TcpListener::bind(addr).await?;
317 axum::serve(listener, router)
318 .with_graceful_shutdown(async move {
319 let _ = shutdown.await;
320 })
321 .await?;
322 Ok(())
323}
324
325async fn http_redirect_handler(
326 req: Request<Body>,
327 target_host: Arc<Option<String>>,
328 permanent: bool,
329) -> Response<Body> {
330 let host = target_host.as_ref().clone().unwrap_or_else(|| {
331 req.headers()
332 .get("host")
333 .and_then(|v| v.to_str().ok())
334 .map(String::from)
335 .unwrap_or_default()
336 });
337 let path_and_query = req
338 .uri()
339 .path_and_query()
340 .map(|p| p.as_str().to_string())
341 .unwrap_or_else(|| "/".to_string());
342 let location = format!("https://{host}{path_and_query}");
343
344 let status = if permanent {
345 StatusCode::MOVED_PERMANENTLY
346 } else {
347 StatusCode::FOUND
348 };
349 let mut resp = Response::new(Body::from(format!("Redirecting to {location}\n")));
350 *resp.status_mut() = status;
351 if let Ok(loc) = HeaderValue::from_str(&location) {
352 resp.headers_mut().insert("location", loc);
353 }
354 resp
355}
356
357fn build_hsts_header(cfg: &HstsConfig) -> Option<HeaderValue> {
358 let max_age = cfg.max_age.unwrap_or(Duration::from_secs(86400 * 365));
359 let mut value = format!("max-age={}", max_age.as_secs());
360 if cfg.include_subdomains {
361 value.push_str("; includeSubDomains");
362 }
363 if cfg.preload {
364 value.push_str("; preload");
365 }
366 HeaderValue::from_str(&value).ok()
367}
368
369async fn host_match_mw(allowed: Vec<String>, req: Request<Body>, next: Next) -> Response<Body> {
371 let host = req
372 .headers()
373 .get("host")
374 .and_then(|v| v.to_str().ok())
375 .unwrap_or("")
376 .to_string();
377
378 let host_no_port = host.split(':').next().unwrap_or("").to_ascii_lowercase();
380
381 if matches_any(&host_no_port, &allowed) {
382 return next.run(req).await;
383 }
384
385 tracing::debug!(host, allowed = ?allowed, "rejected host: no server_name match");
386 let mut resp = Response::new(Body::from(format!(
387 "404 not found (unknown host: {host})\n"
388 )));
389 *resp.status_mut() = StatusCode::NOT_FOUND;
390 resp
391}
392
393fn matches_any(host: &str, patterns: &[String]) -> bool {
394 patterns.iter().any(|pat| matches_pattern(host, pat))
395}
396
/// Matches a host (expected already lowercased and port-free) against a
/// single `server_name` pattern. The pattern is lowercased before comparing.
///
/// - `*` matches every host.
/// - `*.example.com` matches any host ending in `.example.com`, but not
///   `example.com` itself.
/// - Any other pattern must equal the host exactly.
fn matches_pattern(host: &str, pattern: &str) -> bool {
    let pattern = pattern.to_ascii_lowercase();
    match pattern.as_str() {
        "*" => true,
        exact_or_wildcard => match exact_or_wildcard.strip_prefix("*.") {
            // The part before the suffix must end in a dot, i.e. be a subdomain.
            Some(suffix) => host
                .strip_suffix(suffix)
                .is_some_and(|subdomain| subdomain.ends_with('.')),
            None => host == exact_or_wildcard,
        },
    }
}
410
411async fn serve_plain(
412 router: AxumRouter,
413 addr: SocketAddr,
414 shutdown: tokio::sync::oneshot::Receiver<()>,
415) -> Result<(), Error> {
416 let listener = tokio::net::TcpListener::bind(addr).await?;
417 axum::serve(listener, router)
418 .with_graceful_shutdown(async move {
419 let _ = shutdown.await;
420 })
421 .await?;
422 Ok(())
423}
424
425async fn serve_tls(
426 router: AxumRouter,
427 addr: SocketAddr,
428 tls: &TlsConfig,
429 shutdown: tokio::sync::oneshot::Receiver<()>,
430) -> Result<(), Error> {
431 let config = axum_server::tls_rustls::RustlsConfig::from_pem_file(&tls.cert, &tls.key)
432 .await
433 .map_err(|e| Error::Config(format!("tls load: {e}")))?;
434
435 let handle = axum_server::Handle::new();
436 let handle_for_shutdown = handle.clone();
437 tokio::spawn(async move {
438 let _ = shutdown.await;
439 handle_for_shutdown.graceful_shutdown(Some(Duration::from_secs(10)));
440 });
441
442 axum_server::bind_rustls(addr, config)
443 .handle(handle)
444 .serve(router.into_make_service())
445 .await
446 .map_err(|e| Error::Internal(format!("tls serve: {e}")))?;
447 Ok(())
448}
449
/// Fixed-window request limiter used by the rate-limit middleware.
///
/// Capacity and eviction policy of `state` are configured in `from_config`.
pub struct RateLimiter {
    /// Bucket key → (instant the current window ends, requests still allowed in it).
    state: moka::sync::Cache<String, (Instant, u32)>,
    /// Rule applied when no route rule matches (from `rate_limit.per_ip`).
    default_rule: Option<RateRule>,
    /// Route-specific rules, checked in order; the first match wins.
    route_rules: Vec<(MatchKey, RateRule)>,
}
458
/// A parsed rate: at most `count` requests per `window`.
#[derive(Clone, Copy)]
struct RateRule {
    /// Requests allowed per window.
    count: u32,
    /// Length of one window.
    window: Duration,
}
464
/// Request selector for a route-specific rate-limit rule.
#[derive(Clone)]
struct MatchKey {
    /// Method to match; `None` matches any method.
    method: Option<Method>,
    /// Request path; compared for exact equality.
    path: String,
}
470
471impl RateLimiter {
472 pub fn from_config(cfg: &RateLimitConfig) -> Self {
473 let default_rule = cfg.per_ip.as_deref().and_then(|s| {
474 crate::server_config::parse_rate(s)
475 .map(|(count, window)| RateRule { count, window })
476 .ok()
477 });
478 let route_rules = cfg
479 .routes
480 .iter()
481 .filter_map(|(spec, rate)| {
482 let (count, window) = crate::server_config::parse_rate(rate).ok()?;
483 let (method, path) = parse_route_spec(spec);
484 Some((MatchKey { method, path }, RateRule { count, window }))
485 })
486 .collect();
487
488 Self {
489 state: moka::sync::Cache::builder()
490 .max_capacity(10_000)
491 .time_to_idle(Duration::from_secs(600))
492 .build(),
493 default_rule,
494 route_rules,
495 }
496 }
497
498 fn rule_for(&self, method: &Method, path: &str) -> Option<RateRule> {
499 for (key, rule) in &self.route_rules {
500 if key.path == path && key.method.as_ref().map_or(true, |m| m == method) {
501 return Some(*rule);
502 }
503 }
504 self.default_rule
505 }
506
507 fn check(&self, bucket: &str, rule: RateRule) -> bool {
508 let now = Instant::now();
509 let mut allowed = true;
510 self.state
511 .entry(bucket.to_string())
512 .and_compute_with(|existing| match existing {
513 Some(entry) => {
514 let (window_end, count) = entry.into_value();
515 if now >= window_end {
516 moka::ops::compute::Op::Put((
517 now + rule.window,
518 rule.count.saturating_sub(1),
519 ))
520 } else if count > 0 {
521 moka::ops::compute::Op::Put((window_end, count - 1))
522 } else {
523 allowed = false;
524 moka::ops::compute::Op::Put((window_end, 0))
525 }
526 }
527 None => {
528 moka::ops::compute::Op::Put((now + rule.window, rule.count.saturating_sub(1)))
529 }
530 });
531 allowed
532 }
533}
534
535fn parse_route_spec(spec: &str) -> (Option<Method>, String) {
536 let trimmed = spec.trim();
537 if let Some((m, p)) = trimmed.split_once(char::is_whitespace) {
538 let method = m.parse::<Method>().ok();
539 (method, p.trim().to_string())
540 } else {
541 (None, trimmed.to_string())
542 }
543}
544
545async fn rate_limit_mw(
546 limiter: Arc<RateLimiter>,
547 req: Request<Body>,
548 next: Next,
549) -> Response<Body> {
550 let method = req.method().clone();
551 let path = req.uri().path().to_string();
552 let bucket = format!("{}|{}|{}", client_ip(&req), method, path);
553
554 if let Some(rule) = limiter.rule_for(&method, &path) {
555 if !limiter.check(&bucket, rule) {
556 tracing::debug!(%method, %path, %bucket, "rate limited");
557 let mut resp = Response::new(Body::from("rate limit exceeded"));
558 *resp.status_mut() = StatusCode::TOO_MANY_REQUESTS;
559 return resp;
560 }
561 }
562 next.run(req).await
563}
564
565fn client_ip(req: &Request<Body>) -> String {
566 if let Some(v) = req.headers().get("x-forwarded-for") {
570 if let Ok(s) = v.to_str() {
571 if let Some(first) = s.split(',').next() {
572 return first.trim().to_string();
573 }
574 }
575 }
576 "unknown".into()
579}
580
581async fn access_log_mw(format: AccessLogFormat, req: Request<Body>, next: Next) -> Response<Body> {
584 let started = Instant::now();
585 let method = req.method().clone();
586 let path = req.uri().path().to_string();
587 let host = req
588 .headers()
589 .get("host")
590 .and_then(|v| v.to_str().ok())
591 .unwrap_or("-")
592 .to_string();
593 let referer = req
594 .headers()
595 .get("referer")
596 .and_then(|v| v.to_str().ok())
597 .map(String::from);
598 let ua = req
599 .headers()
600 .get("user-agent")
601 .and_then(|v| v.to_str().ok())
602 .map(String::from);
603 let ip = client_ip(&req);
604
605 let resp = next.run(req).await;
606 let elapsed = started.elapsed();
607 let status = resp.status().as_u16();
608 let bytes = response_size(resp.headers()).unwrap_or(0);
609
610 match format {
611 AccessLogFormat::Combined => {
612 tracing::info!(
613 target: "access_log",
614 "{} - - \"{} {} HTTP/1.1\" {} {} \"{}\" \"{}\" {}ms",
615 ip,
616 method,
617 path,
618 status,
619 bytes,
620 referer.as_deref().unwrap_or("-"),
621 ua.as_deref().unwrap_or("-"),
622 elapsed.as_millis(),
623 );
624 }
625 AccessLogFormat::Json => {
626 tracing::info!(
627 target: "access_log",
628 json = %serde_json::json!({
629 "ip": ip,
630 "method": method.as_str(),
631 "path": path,
632 "host": host,
633 "status": status,
634 "bytes": bytes,
635 "referer": referer,
636 "user_agent": ua,
637 "duration_ms": elapsed.as_millis(),
638 }),
639 "request"
640 );
641 }
642 AccessLogFormat::Off => {}
643 }
644 resp
645}
646
647fn response_size(headers: &HeaderMap) -> Option<u64> {
648 headers
649 .get("content-length")
650 .and_then(|v| v.to_str().ok())
651 .and_then(|s| s.parse().ok())
652}
653
/// A rewrite rule whose `from` pattern has been compiled.
#[derive(Clone)]
struct CompiledRewrite {
    /// Regex matched against the request path, or path + query when `match_query` is set.
    pattern: regex::Regex,
    /// Replacement passed to `Regex::replace`; may reference capture groups.
    to: String,
    /// 301/302/303/307/308 to redirect; any other value or `None` rewrites internally.
    status: Option<u16>,
    /// Match against path + query instead of the path alone.
    match_query: bool,
}
663
/// All rewrite rules that compiled successfully, in configuration order.
struct CompiledRewrites {
    rules: Vec<CompiledRewrite>,
}
667
668impl CompiledRewrites {
669 fn compile(rules: &[RewriteRule]) -> Self {
670 let compiled = rules
671 .iter()
672 .filter_map(|r| match regex::Regex::new(&r.from) {
673 Ok(pattern) => Some(CompiledRewrite {
674 pattern,
675 to: r.to.clone(),
676 status: r.status,
677 match_query: r.match_query,
678 }),
679 Err(e) => {
680 tracing::warn!(rule = %r.from, error = %e, "invalid rewrite regex, skipping");
681 None
682 }
683 })
684 .collect();
685 Self { rules: compiled }
686 }
687}
688
689async fn rewrite_mw(
690 rules: Arc<CompiledRewrites>,
691 mut req: Request<Body>,
692 next: Next,
693) -> Response<Body> {
694 let path = req.uri().path().to_string();
695 let path_and_query = req
696 .uri()
697 .path_and_query()
698 .map(|p| p.as_str().to_string())
699 .unwrap_or_else(|| path.clone());
700
701 let target_str_path = path.clone();
702 let target_str_full = path_and_query.clone();
703
704 let mut applied: Option<(String, Option<u16>)> = None;
706 for rule in &rules.rules {
707 let subject = if rule.match_query {
708 &target_str_full
709 } else {
710 &target_str_path
711 };
712 if rule.pattern.is_match(subject) {
713 let replaced = rule.pattern.replace(subject, rule.to.as_str()).to_string();
714 applied = Some((replaced, rule.status));
715 break;
716 }
717 }
718
719 let Some((new_target, status)) = applied else {
720 return next.run(req).await;
721 };
722
723 match status {
724 Some(code @ (301 | 302 | 303 | 307 | 308)) => {
725 let mut resp = Response::new(Body::from(format!("Redirecting to {new_target}\n")));
726 *resp.status_mut() =
727 StatusCode::from_u16(code).unwrap_or(StatusCode::MOVED_PERMANENTLY);
728 if let Ok(loc) = HeaderValue::from_str(&new_target) {
729 resp.headers_mut().insert("location", loc);
730 }
731 resp
732 }
733 _ => {
734 let mut parts = req.uri().clone().into_parts();
736 if let Ok(new_pq) = new_target.parse::<axum::http::uri::PathAndQuery>() {
737 parts.path_and_query = Some(new_pq);
738 }
739 if let Ok(new_uri) = axum::http::Uri::from_parts(parts) {
740 *req.uri_mut() = new_uri;
741 }
742 next.run(req).await
743 }
744 }
745}
746
747async fn trailing_slash_mw(
750 cfg: TrailingSlashConfig,
751 mut req: Request<Body>,
752 next: Next,
753) -> Response<Body> {
754 let path = req.uri().path().to_string();
755 if path == "/" {
756 return next.run(req).await;
757 }
758
759 let want_slash = matches!(cfg.mode, TrailingSlashMode::Always);
760 let has_slash = path.ends_with('/');
761
762 if want_slash == has_slash {
763 return next.run(req).await;
764 }
765
766 let new_path = if want_slash {
767 format!("{path}/")
768 } else {
769 path.trim_end_matches('/').to_string()
770 };
771
772 let query = req
773 .uri()
774 .query()
775 .map(|q| format!("?{q}"))
776 .unwrap_or_default();
777 let new_target = format!("{new_path}{query}");
778
779 match cfg.action {
780 TrailingSlashAction::Redirect => {
781 let mut resp = Response::new(Body::from(format!("Redirecting to {new_target}\n")));
782 *resp.status_mut() = StatusCode::MOVED_PERMANENTLY;
783 if let Ok(loc) = HeaderValue::from_str(&new_target) {
784 resp.headers_mut().insert("location", loc);
785 }
786 resp
787 }
788 TrailingSlashAction::Rewrite => {
789 let mut parts = req.uri().clone().into_parts();
790 if let Ok(pq) = new_target.parse::<axum::http::uri::PathAndQuery>() {
791 parts.path_and_query = Some(pq);
792 }
793 if let Ok(new_uri) = axum::http::Uri::from_parts(parts) {
794 *req.uri_mut() = new_uri;
795 }
796 next.run(req).await
797 }
798 }
799}
800
/// Custom error-page bodies, read once at startup.
struct LoadedErrorPages {
    /// HTTP status code → (page body, `Content-Type` header value).
    by_status: std::collections::HashMap<u16, (String, &'static str)>,
}
806
807fn load_error_pages(
808 raw: &std::collections::BTreeMap<String, std::path::PathBuf>,
809) -> LoadedErrorPages {
810 let mut by_status = std::collections::HashMap::new();
811 for (key, path) in raw {
812 let Ok(code) = key.parse::<u16>() else {
813 tracing::warn!(key, "error_pages: invalid status code, skipping");
814 continue;
815 };
816 let body = match std::fs::read_to_string(path) {
817 Ok(s) => s,
818 Err(e) => {
819 tracing::warn!(?path, ?e, "error_pages: failed to read file, skipping");
820 continue;
821 }
822 };
823 let content_type = guess_content_type(path);
824 by_status.insert(code, (body, content_type));
825 }
826 LoadedErrorPages { by_status }
827}
828
/// Picks a `Content-Type` for an error page from its file extension.
///
/// The extension is compared case-insensitively, so `404.HTML` is treated as
/// HTML too. Unknown or missing extensions fall back to UTF-8 plain text.
fn guess_content_type(path: &std::path::Path) -> &'static str {
    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase);
    match ext.as_deref() {
        Some("html" | "htm") => "text/html; charset=utf-8",
        Some("json") => "application/json",
        _ => "text/plain; charset=utf-8",
    }
}
837
838async fn error_pages_mw(
839 pages: Arc<LoadedErrorPages>,
840 req: Request<Body>,
841 next: Next,
842) -> Response<Body> {
843 let resp = next.run(req).await;
844 let status = resp.status().as_u16();
845
846 let Some((body, ctype)) = pages.by_status.get(&status) else {
847 return resp;
848 };
849
850 let mut out = Response::new(Body::from(body.clone()));
851 *out.status_mut() = resp.status();
852 if let Ok(ct) = HeaderValue::from_str(ctype) {
853 out.headers_mut().insert("content-type", ct);
854 }
855 for h in ["cache-control", "x-request-id"] {
857 if let Some(v) = resp.headers().get(h) {
858 out.headers_mut().insert(h, v.clone());
859 }
860 }
861 out
862}
863
/// One reverse-proxy rule after normalization by `CompiledProxies::compile`.
#[derive(Clone)]
struct CompiledProxy {
    /// Request-path prefix routed to this upstream.
    prefix: String,
    /// Upstream base URL with trailing slashes removed.
    upstream: String,
    /// Remove `prefix` from the path before forwarding.
    strip_prefix: bool,
    /// Forward the client's `Host` header instead of omitting it.
    preserve_host: bool,
    /// Timeout applied to each individual attempt.
    timeout: Duration,
    /// Additional attempts after the first one fails.
    retries: u8,
}
875
/// Proxy rules sorted longest-prefix-first, plus the shared upstream HTTP client.
struct CompiledProxies {
    rules: Vec<CompiledProxy>,
    client: reqwest::Client,
}
880
881impl CompiledProxies {
882 fn compile(rules: &[ProxyRule]) -> Self {
883 let mut compiled: Vec<CompiledProxy> = rules
884 .iter()
885 .map(|r| CompiledProxy {
886 prefix: r.prefix.clone(),
887 upstream: r.upstream.trim_end_matches('/').to_string(),
888 strip_prefix: r.strip_prefix,
889 preserve_host: r.preserve_host,
890 timeout: r.timeout.unwrap_or(Duration::from_secs(30)),
891 retries: r.retries,
892 })
893 .collect();
894 compiled.sort_by(|a, b| b.prefix.len().cmp(&a.prefix.len()));
896
897 let client = reqwest::Client::builder()
898 .redirect(reqwest::redirect::Policy::none())
899 .build()
900 .unwrap_or_else(|_| reqwest::Client::new());
901
902 Self {
903 rules: compiled,
904 client,
905 }
906 }
907
908 fn matching(&self, path: &str) -> Option<&CompiledProxy> {
909 self.rules.iter().find(|r| path.starts_with(&r.prefix))
910 }
911}
912
913async fn proxy_mw(proxies: Arc<CompiledProxies>, req: Request<Body>, next: Next) -> Response<Body> {
914 let path = req.uri().path().to_string();
915 let Some(rule) = proxies.matching(&path) else {
916 return next.run(req).await;
917 };
918 let rule = rule.clone();
919
920 match proxy_forward(&proxies.client, &rule, req).await {
921 Ok(resp) => resp,
922 Err(e) => {
923 tracing::warn!(?e, prefix = %rule.prefix, upstream = %rule.upstream, "proxy error");
924 let mut resp = Response::new(Body::from(format!("upstream error: {e}")));
925 *resp.status_mut() = StatusCode::BAD_GATEWAY;
926 resp
927 }
928 }
929}
930
931async fn proxy_forward(
932 client: &reqwest::Client,
933 rule: &CompiledProxy,
934 req: Request<Body>,
935) -> Result<Response<Body>, String> {
936 let (parts, body) = req.into_parts();
937 let body_bytes = axum::body::to_bytes(body, usize::MAX)
938 .await
939 .map_err(|e| format!("body read: {e}"))?;
940
941 let original_path = parts.uri.path();
942 let upstream_path = if rule.strip_prefix {
943 original_path
944 .strip_prefix(&rule.prefix)
945 .unwrap_or(original_path)
946 } else {
947 original_path
948 };
949 let upstream_path = if upstream_path.is_empty() {
950 "/"
951 } else {
952 upstream_path
953 };
954 let query = parts
955 .uri
956 .query()
957 .map(|q| format!("?{q}"))
958 .unwrap_or_default();
959 let upstream_url = format!("{}{}{}", rule.upstream, upstream_path, query);
960
961 let method = parts.method.clone();
962 let mut last_err = String::new();
963 for attempt in 0..=rule.retries {
964 let mut request = client
965 .request(
966 reqwest::Method::from_bytes(method.as_str().as_bytes())
967 .unwrap_or(reqwest::Method::GET),
968 &upstream_url,
969 )
970 .timeout(rule.timeout)
971 .body(body_bytes.clone());
972
973 for (name, value) in parts.headers.iter() {
974 let n = name.as_str().to_ascii_lowercase();
976 if matches!(
977 n.as_str(),
978 "connection"
979 | "keep-alive"
980 | "proxy-authenticate"
981 | "proxy-authorization"
982 | "te"
983 | "trailers"
984 | "transfer-encoding"
985 | "upgrade"
986 | "content-length"
987 ) {
988 continue;
989 }
990 if !rule.preserve_host && n == "host" {
991 continue;
992 }
993 if let Ok(v) = value.to_str() {
994 request = request.header(name.as_str(), v);
995 }
996 }
997
998 if let Some(host) = parts.headers.get("host").and_then(|v| v.to_str().ok()) {
1000 request = request.header("x-forwarded-host", host);
1001 }
1002 request = request.header("x-forwarded-proto", "http");
1003
1004 match request.send().await {
1005 Ok(resp) => return upstream_to_axum(resp).await,
1006 Err(e) => {
1007 last_err = format!("attempt {} → {e}", attempt + 1);
1008 tracing::debug!(error = %e, attempt, "proxy retry");
1009 continue;
1010 }
1011 }
1012 }
1013 Err(last_err)
1014}
1015
1016async fn cors_mw(cfg: Arc<CorsConfig>, req: Request<Body>, next: Next) -> Response<Body> {
1019 let origin = req
1020 .headers()
1021 .get("origin")
1022 .and_then(|v| v.to_str().ok())
1023 .map(String::from);
1024
1025 let is_allowed_origin = origin.as_deref().is_some_and(|o| {
1026 cfg.allow_origins
1027 .iter()
1028 .any(|allowed| allowed == "*" || allowed == o)
1029 });
1030
1031 if req.method() == Method::OPTIONS && origin.is_some() {
1033 let mut resp = Response::new(Body::empty());
1034 *resp.status_mut() = StatusCode::NO_CONTENT;
1035 apply_cors_headers(
1036 resp.headers_mut(),
1037 &cfg,
1038 origin.as_deref(),
1039 is_allowed_origin,
1040 );
1041 return resp;
1042 }
1043
1044 let mut resp = next.run(req).await;
1045 apply_cors_headers(
1046 resp.headers_mut(),
1047 &cfg,
1048 origin.as_deref(),
1049 is_allowed_origin,
1050 );
1051 resp
1052}
1053
1054fn apply_cors_headers(
1055 headers: &mut HeaderMap,
1056 cfg: &CorsConfig,
1057 origin: Option<&str>,
1058 is_allowed_origin: bool,
1059) {
1060 if !is_allowed_origin {
1061 return;
1062 }
1063 if let Some(origin) = origin {
1064 if let Ok(v) = HeaderValue::from_str(origin) {
1065 headers.insert("access-control-allow-origin", v);
1066 }
1067 headers.insert("vary", HeaderValue::from_static("Origin"));
1068 } else if cfg.allow_origins.iter().any(|o| o == "*") {
1069 headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
1070 }
1071
1072 let methods = if cfg.allow_methods.is_empty() {
1073 "GET, POST, PUT, PATCH, DELETE, OPTIONS".to_string()
1074 } else {
1075 cfg.allow_methods.join(", ")
1076 };
1077 if let Ok(v) = HeaderValue::from_str(&methods) {
1078 headers.insert("access-control-allow-methods", v);
1079 }
1080
1081 let allow_headers = if cfg.allow_headers.is_empty() {
1082 "Content-Type, Authorization, X-CSRF-TOKEN, X-Requested-With".to_string()
1083 } else {
1084 cfg.allow_headers.join(", ")
1085 };
1086 if let Ok(v) = HeaderValue::from_str(&allow_headers) {
1087 headers.insert("access-control-allow-headers", v);
1088 }
1089
1090 if !cfg.expose_headers.is_empty() {
1091 if let Ok(v) = HeaderValue::from_str(&cfg.expose_headers.join(", ")) {
1092 headers.insert("access-control-expose-headers", v);
1093 }
1094 }
1095
1096 if cfg.allow_credentials {
1097 headers.insert(
1098 "access-control-allow-credentials",
1099 HeaderValue::from_static("true"),
1100 );
1101 }
1102
1103 if let Some(max_age) = cfg.max_age {
1104 if let Ok(v) = HeaderValue::from_str(&max_age.as_secs().to_string()) {
1105 headers.insert("access-control-max-age", v);
1106 }
1107 }
1108}
1109
1110async fn ip_rules_mw(rules: Arc<Vec<IpRule>>, req: Request<Body>, next: Next) -> Response<Body> {
1113 let path = req.uri().path().to_string();
1114 let ip_str = client_ip(&req);
1115 let ip = ip_str.parse::<std::net::IpAddr>().ok();
1116
1117 for rule in rules.iter() {
1118 if !path.starts_with(&rule.prefix) {
1119 continue;
1120 }
1121 let matches_range = ip
1122 .map(|addr| rule.ranges.iter().any(|net| net.contains(&addr)))
1123 .unwrap_or(false);
1124 let allowed = match rule.action {
1125 IpAction::Allow => matches_range,
1126 IpAction::Deny => !matches_range,
1127 };
1128 if !allowed {
1129 tracing::debug!(path, ip = %ip_str, "ip rule denied request");
1130 let mut resp = Response::new(Body::from("forbidden"));
1131 *resp.status_mut() = StatusCode::FORBIDDEN;
1132 return resp;
1133 }
1134 break;
1136 }
1137
1138 next.run(req).await
1139}
1140
1141use base64::engine::general_purpose::STANDARD as B64;
1144use base64::Engine as _;
1145
/// Basic-auth rules, each paired with its parsed `(user, password)` credentials.
struct CompiledBasicAuth {
    rules: Vec<(BasicAuthRule, Vec<(String, String)>)>,
}
1149
1150fn compile_basic_auth(rules: &[BasicAuthRule]) -> CompiledBasicAuth {
1151 let compiled = rules
1152 .iter()
1153 .map(|r| {
1154 let creds = r
1155 .credentials
1156 .iter()
1157 .filter_map(|c| {
1158 c.split_once(':')
1159 .map(|(u, p)| (u.to_string(), p.to_string()))
1160 })
1161 .collect();
1162 (r.clone(), creds)
1163 })
1164 .collect();
1165 CompiledBasicAuth { rules: compiled }
1166}
1167
1168async fn basic_auth_mw(
1169 rules: Arc<CompiledBasicAuth>,
1170 req: Request<Body>,
1171 next: Next,
1172) -> Response<Body> {
1173 let path = req.uri().path().to_string();
1174 for (rule, creds) in &rules.rules {
1175 if !path.starts_with(&rule.prefix) {
1176 continue;
1177 }
1178 let supplied = req
1179 .headers()
1180 .get("authorization")
1181 .and_then(|v| v.to_str().ok())
1182 .and_then(|s| s.strip_prefix("Basic "))
1183 .and_then(|b64| B64.decode(b64).ok())
1184 .and_then(|bytes| String::from_utf8(bytes).ok())
1185 .and_then(|pair| {
1186 pair.split_once(':')
1187 .map(|(u, p)| (u.to_string(), p.to_string()))
1188 });
1189
1190 let ok = supplied
1191 .as_ref()
1192 .map(|(u, p)| creds.iter().any(|(cu, cp)| cu == u && cp == p))
1193 .unwrap_or(false);
1194
1195 if ok {
1196 return next.run(req).await;
1197 }
1198
1199 let challenge = format!("Basic realm=\"{}\"", rule.realm);
1200 let mut resp = Response::new(Body::from("authentication required"));
1201 *resp.status_mut() = StatusCode::UNAUTHORIZED;
1202 if let Ok(v) = HeaderValue::from_str(&challenge) {
1203 resp.headers_mut().insert("www-authenticate", v);
1204 }
1205 return resp;
1206 }
1207 next.run(req).await
1208}
1209
1210async fn upstream_to_axum(resp: reqwest::Response) -> Result<Response<Body>, String> {
1211 let status = resp.status();
1212 let headers = resp.headers().clone();
1213 let bytes = resp
1214 .bytes()
1215 .await
1216 .map_err(|e| format!("upstream body: {e}"))?;
1217 let mut out = Response::new(Body::from(bytes));
1218 *out.status_mut() =
1219 axum::http::StatusCode::from_u16(status.as_u16()).unwrap_or(axum::http::StatusCode::OK);
1220 for (name, value) in headers.iter() {
1221 let n = name.as_str().to_ascii_lowercase();
1222 if matches!(
1223 n.as_str(),
1224 "connection"
1225 | "keep-alive"
1226 | "proxy-authenticate"
1227 | "proxy-authorization"
1228 | "te"
1229 | "trailers"
1230 | "transfer-encoding"
1231 | "upgrade"
1232 ) {
1233 continue;
1234 }
1235 if let Ok(v) = HeaderValue::from_bytes(value.as_bytes()) {
1236 if let Ok(name) = HeaderName::from_bytes(name.as_str().as_bytes()) {
1237 out.headers_mut().append(name, v);
1238 }
1239 }
1240 }
1241 Ok(out)
1242}