1use std::collections::BTreeMap;
8use std::path::PathBuf;
9use std::time::Duration;
10
11use serde::{Deserialize, Deserializer, Serialize};
12
13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
14#[serde(default, deny_unknown_fields)]
15pub struct ServerConfig {
16 #[serde(default = "default_bind")]
18 pub bind: String,
19
20 #[serde(default)]
23 pub server_name: Vec<String>,
24
25 pub tls: Option<TlsConfig>,
27
28 pub redirect_http: Option<RedirectHttpConfig>,
31
32 #[serde(default)]
34 pub hsts: HstsConfig,
35
36 #[serde(default)]
38 pub limits: LimitsConfig,
39
40 #[serde(default)]
42 pub compression: CompressionConfig,
43
44 #[serde(default)]
46 pub static_files: BTreeMap<String, StaticMount>,
47
48 #[serde(default)]
50 pub rate_limit: RateLimitConfig,
51
52 #[serde(default)]
55 pub trusted_proxies: TrustedProxiesConfig,
56
57 #[serde(default, rename = "route_timeout")]
63 pub route_timeouts: Vec<RouteTimeoutRule>,
64
65 #[serde(default)]
67 pub access_log: AccessLogConfig,
68
69 #[serde(default)]
71 pub rewrites: Vec<RewriteRule>,
72
73 #[serde(default)]
75 pub error_pages: BTreeMap<String, std::path::PathBuf>,
76
77 #[serde(default)]
79 pub trailing_slash: TrailingSlashConfig,
80
81 #[serde(default, rename = "proxy")]
83 pub proxies: Vec<ProxyRule>,
84
85 #[serde(default)]
87 pub cors: CorsConfig,
88
89 #[serde(default)]
91 pub ip_rules: Vec<IpRule>,
92
93 #[serde(default, rename = "basic_auth")]
95 pub basic_auth: Vec<BasicAuthRule>,
96}
97
98fn default_bind() -> String {
99 "127.0.0.1:8080".to_string()
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct TlsConfig {
104 pub cert: PathBuf,
105 pub key: PathBuf,
106
107 #[serde(default)]
112 pub acme: Option<AcmeConfig>,
113
114 #[serde(default, rename = "certs")]
137 pub additional_certs: Vec<SniCertEntry>,
138}
139
140#[derive(Debug, Clone, Serialize, Deserialize)]
141#[serde(deny_unknown_fields)]
142pub struct AcmeConfig {
143 pub domains: Vec<String>,
146
147 #[serde(default)]
150 pub contact: Option<String>,
151
152 #[serde(default = "default_acme_cache")]
156 pub cache_dir: PathBuf,
157
158 #[serde(default = "default_acme_directory")]
163 pub directory: String,
164}
165
166fn default_acme_cache() -> PathBuf {
167 PathBuf::from("./database/acme-cache")
168}
169
170fn default_acme_directory() -> String {
171 "https://acme-v02.api.letsencrypt.org/directory".to_string()
172}
173
174#[derive(Debug, Clone, Serialize, Deserialize)]
175#[serde(deny_unknown_fields)]
176pub struct SniCertEntry {
177 pub server_name: String,
181 pub cert: PathBuf,
182 pub key: PathBuf,
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
186#[serde(deny_unknown_fields)]
187pub struct RedirectHttpConfig {
188 pub bind: String,
190
191 #[serde(default = "yes")]
193 pub permanent: bool,
194
195 pub target_host: Option<String>,
198}
199
200#[derive(Debug, Clone, Default, Serialize, Deserialize)]
201#[serde(default, deny_unknown_fields)]
202pub struct HstsConfig {
203 pub enabled: bool,
204
205 #[serde(deserialize_with = "deserialize_opt_duration", default)]
207 pub max_age: Option<Duration>,
208
209 pub include_subdomains: bool,
210 pub preload: bool,
211}
212
213#[derive(Debug, Clone, Serialize, Deserialize)]
214#[serde(default, deny_unknown_fields)]
215pub struct LimitsConfig {
216 #[serde(deserialize_with = "deserialize_size", default = "default_body_max")]
218 pub body_max: u64,
219
220 #[serde(deserialize_with = "deserialize_opt_duration", default)]
222 pub request_timeout: Option<Duration>,
223
224 #[serde(
230 deserialize_with = "deserialize_duration",
231 default = "default_drain_timeout"
232 )]
233 pub drain_timeout: Duration,
234
235 pub max_concurrency: Option<u32>,
246}
247
248impl Default for LimitsConfig {
249 fn default() -> Self {
250 Self {
251 body_max: default_body_max(),
252 request_timeout: None,
253 drain_timeout: default_drain_timeout(),
254 max_concurrency: None,
255 }
256 }
257}
258
259fn default_drain_timeout() -> Duration {
260 Duration::from_secs(10)
261}
262
263fn default_body_max() -> u64 {
264 16 * 1024 * 1024 }
266
267#[derive(Debug, Clone, Serialize, Deserialize)]
268#[serde(default, deny_unknown_fields)]
269pub struct CompressionConfig {
270 pub enabled: bool,
272
273 pub algorithms: Vec<String>,
276
277 #[serde(deserialize_with = "deserialize_size", default = "default_min_size")]
280 pub min_size: u64,
281}
282
283impl Default for CompressionConfig {
284 fn default() -> Self {
285 Self {
286 enabled: false,
287 algorithms: vec!["gzip".to_string()],
288 min_size: default_min_size(),
289 }
290 }
291}
292
293fn default_min_size() -> u64 {
294 1024
295}
296
297#[derive(Debug, Clone, Serialize, Deserialize)]
298#[serde(deny_unknown_fields)]
299pub struct StaticMount {
300 pub dir: PathBuf,
302
303 #[serde(deserialize_with = "deserialize_opt_duration", default)]
306 pub cache: Option<Duration>,
307
308 #[serde(default = "yes")]
310 pub ranges: bool,
311}
312
313fn yes() -> bool {
314 true
315}
316
317#[derive(Debug, Clone, Default, Serialize, Deserialize)]
318#[serde(default, deny_unknown_fields)]
319pub struct RateLimitConfig {
320 pub per_ip: Option<String>,
322
323 #[serde(default)]
325 pub routes: BTreeMap<String, String>,
326}
327
328#[derive(Debug, Clone, Default, Serialize, Deserialize)]
329#[serde(default, deny_unknown_fields)]
330pub struct TrustedProxiesConfig {
331 pub ranges: Vec<ipnet::IpNet>,
333}
334
335#[derive(Debug, Clone, Serialize, Deserialize)]
336#[serde(deny_unknown_fields)]
337pub struct RouteTimeoutRule {
338 pub prefix: String,
340
341 #[serde(deserialize_with = "deserialize_duration")]
344 pub timeout: Duration,
345}
346
347#[derive(Debug, Clone, Serialize, Deserialize)]
348#[serde(deny_unknown_fields)]
349pub struct RewriteRule {
350 pub from: String,
352
353 pub to: String,
355
356 #[serde(default)]
360 pub status: Option<u16>,
361
362 #[serde(default)]
364 pub match_query: bool,
365}
366
367#[derive(Debug, Clone, Serialize, Deserialize)]
368#[serde(default, deny_unknown_fields)]
369pub struct TrailingSlashConfig {
370 pub mode: TrailingSlashMode,
374
375 pub action: TrailingSlashAction,
377}
378
379impl Default for TrailingSlashConfig {
380 fn default() -> Self {
381 Self {
382 mode: TrailingSlashMode::Ignore,
383 action: TrailingSlashAction::Redirect,
384 }
385 }
386}
387
388#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
389#[serde(rename_all = "lowercase")]
390pub enum TrailingSlashMode {
391 Always,
392 Never,
393 Ignore,
394}
395
396#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
397#[serde(rename_all = "lowercase")]
398pub enum TrailingSlashAction {
399 Redirect,
400 Rewrite,
401}
402
403#[derive(Debug, Clone, Default, Serialize, Deserialize)]
404#[serde(default, deny_unknown_fields)]
405pub struct CorsConfig {
406 pub enabled: bool,
407 pub allow_origins: Vec<String>,
409 pub allow_methods: Vec<String>,
411 pub allow_headers: Vec<String>,
413 pub expose_headers: Vec<String>,
415 pub allow_credentials: bool,
417 #[serde(deserialize_with = "deserialize_opt_duration", default)]
419 pub max_age: Option<Duration>,
420}
421
422#[derive(Debug, Clone, Serialize, Deserialize)]
423#[serde(deny_unknown_fields)]
424pub struct IpRule {
425 pub prefix: String,
427 pub action: IpAction,
429 pub ranges: Vec<ipnet::IpNet>,
431}
432
433#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
434#[serde(rename_all = "lowercase")]
435pub enum IpAction {
436 Allow,
437 Deny,
438}
439
440#[derive(Debug, Clone, Serialize, Deserialize)]
441#[serde(deny_unknown_fields)]
442pub struct BasicAuthRule {
443 pub prefix: String,
444 #[serde(default = "default_realm")]
446 pub realm: String,
447 pub credentials: Vec<String>,
449}
450
451fn default_realm() -> String {
452 "Restricted".to_string()
453}
454
455#[derive(Debug, Clone, Serialize, Deserialize)]
456#[serde(deny_unknown_fields)]
457pub struct ProxyRule {
458 pub prefix: String,
460
461 pub upstream: String,
463
464 #[serde(default)]
466 pub strip_prefix: bool,
467
468 #[serde(default)]
470 pub preserve_host: bool,
471
472 #[serde(deserialize_with = "deserialize_opt_duration", default)]
474 pub timeout: Option<Duration>,
475
476 #[serde(default)]
478 pub retries: u8,
479}
480
481#[derive(Debug, Clone, Serialize, Deserialize)]
482#[serde(default, deny_unknown_fields)]
483pub struct AccessLogConfig {
484 pub format: AccessLogFormat,
485 pub path: Option<PathBuf>,
486}
487
488impl Default for AccessLogConfig {
489 fn default() -> Self {
490 Self {
491 format: AccessLogFormat::Combined,
492 path: None,
493 }
494 }
495}
496
497#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
498#[serde(rename_all = "lowercase")]
499pub enum AccessLogFormat {
500 Combined,
502 Json,
504 Off,
506}
507
508impl ServerConfig {
509 pub fn from_file_or_default(path: impl AsRef<std::path::Path>) -> Self {
511 match Self::from_file(path.as_ref()) {
512 Ok(c) => c,
513 Err(crate::Error::Io(e)) if e.kind() == std::io::ErrorKind::NotFound => Self::default(),
514 Err(e) => {
515 tracing::warn!(?e, path = %path.as_ref().display(), "failed to load server config; using defaults");
516 Self::default()
517 }
518 }
519 }
520
521 pub fn from_file(path: &std::path::Path) -> crate::Result<Self> {
522 let bytes = std::fs::read_to_string(path)?;
523 let cfg: Self = toml::from_str(&bytes)
524 .map_err(|e| crate::Error::Config(format!("toml parse {}: {e}", path.display())))?;
525 Ok(cfg.apply_env_overrides())
526 }
527
528 pub fn apply_env_overrides(mut self) -> Self {
531 if let Ok(v) = std::env::var("APP_ADDR") {
532 self.bind = v;
533 }
534 if let (Ok(cert), Ok(key)) = (std::env::var("TLS_CERT"), std::env::var("TLS_KEY")) {
535 self.tls = Some(TlsConfig {
536 cert: PathBuf::from(cert),
537 key: PathBuf::from(key),
538 acme: None,
539 additional_certs: Vec::new(),
540 });
541 }
542 self
543 }
544}
545
546fn deserialize_size<'de, D: Deserializer<'de>>(d: D) -> Result<u64, D::Error> {
549 use serde::de::Error;
550 let v = toml::Value::deserialize(d)?;
551 match v {
552 toml::Value::Integer(n) => Ok(n.max(0) as u64),
553 toml::Value::String(s) => parse_size(&s).map_err(D::Error::custom),
554 other => Err(D::Error::custom(format!(
555 "expected integer or size string, got {other:?}"
556 ))),
557 }
558}
559
560fn deserialize_duration<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
561 use serde::de::Error;
562 let v = toml::Value::deserialize(d)?;
563 match v {
564 toml::Value::Integer(n) => Ok(Duration::from_secs(n.max(0) as u64)),
565 toml::Value::String(s) => parse_duration(&s).map_err(D::Error::custom),
566 other => Err(D::Error::custom(format!(
567 "expected integer (seconds) or duration string, got {other:?}"
568 ))),
569 }
570}
571
572fn deserialize_opt_duration<'de, D: Deserializer<'de>>(d: D) -> Result<Option<Duration>, D::Error> {
573 use serde::de::Error;
574 let v = Option::<toml::Value>::deserialize(d)?;
575 match v {
576 None | Some(toml::Value::String(_)) if matches!(&v, Some(toml::Value::String(s)) if s.is_empty()) => {
577 Ok(None)
578 }
579 None => Ok(None),
580 Some(toml::Value::Integer(n)) => Ok(Some(Duration::from_secs(n.max(0) as u64))),
581 Some(toml::Value::String(s)) => parse_duration(&s).map(Some).map_err(D::Error::custom),
582 Some(other) => Err(D::Error::custom(format!(
583 "expected integer (seconds) or duration string, got {other:?}"
584 ))),
585 }
586}
587
/// Parses a human-readable byte size such as `"512"`, `"10KB"`, `"1.5 MiB"` or `"2T"`.
///
/// Units are case-insensitive and binary (`K` = 1024, up to `T` = 1024⁴). A bare number is
/// a byte count. Fractional values are scaled and then truncated toward zero.
///
/// # Errors
/// Returns a message for empty input, an unparsable number, an unknown unit, a negative
/// value, or a result that does not fit in `u64`.
pub fn parse_size(s: &str) -> Result<u64, String> {
    let s = s.trim();
    if s.is_empty() {
        return Err("empty size".into());
    }
    if let Ok(n) = s.parse::<u64>() {
        return Ok(n);
    }
    let (num_part, unit_part) = split_num_unit(s);
    let num: f64 = num_part
        .parse()
        .map_err(|e| format!("invalid size number `{num_part}`: {e}"))?;
    let mult: u64 = match unit_part.to_ascii_uppercase().as_str() {
        "" | "B" => 1,
        "K" | "KB" | "KIB" => 1 << 10,
        "M" | "MB" | "MIB" => 1 << 20,
        "G" | "GB" | "GIB" => 1 << 30,
        "T" | "TB" | "TIB" => 1 << 40,
        other => return Err(format!("unknown size unit `{other}`")),
    };
    // `as u64` saturates, so without these checks "-5MB" would silently become 0 and an
    // oversized value would silently become u64::MAX.
    if num < 0.0 {
        return Err(format!("size must not be negative: `{s}`"));
    }
    let bytes = num * mult as f64;
    // `u64::MAX as f64` rounds up to 2^64, so `>=` catches every value that would saturate,
    // including infinity from absurdly long digit strings.
    if !bytes.is_finite() || bytes >= u64::MAX as f64 {
        return Err(format!("size `{s}` is too large"));
    }
    Ok(bytes as u64)
}

/// Parses a duration such as `"30s"`, `"5 minutes"`, `"1y"` or a bare number of seconds.
///
/// A unit with no number (e.g. `"minute"`) means one of that unit; [`parse_rate`] relies on
/// this for strings like `"60/minute"`. Months count as 30 days and years as 365 days.
///
/// # Errors
/// Returns a message for empty input, an invalid or negative number, an unknown unit, or a
/// value whose seconds overflow `u64`.
pub fn parse_duration(s: &str) -> Result<Duration, String> {
    let s = s.trim();
    if s.is_empty() {
        return Err("empty duration".into());
    }
    if let Ok(n) = s.parse::<u64>() {
        return Ok(Duration::from_secs(n));
    }
    let (num_part, unit_part) = split_num_unit(s);
    let num: u64 = if num_part.is_empty() {
        1
    } else {
        num_part
            .parse()
            .map_err(|e| format!("invalid duration number `{num_part}`: {e}"))?
    };
    let unit_secs: u64 = match unit_part.to_ascii_lowercase().as_str() {
        "s" | "sec" | "secs" | "second" | "seconds" => 1,
        "m" | "min" | "mins" | "minute" | "minutes" => 60,
        "h" | "hr" | "hrs" | "hour" | "hours" => 3_600,
        "d" | "day" | "days" => 86_400,
        "w" | "wk" | "wks" | "week" | "weeks" => 7 * 86_400,
        "mo" | "month" | "months" => 30 * 86_400,
        "y" | "yr" | "yrs" | "year" | "years" => 365 * 86_400,
        other => return Err(format!("unknown duration unit `{other}`")),
    };
    // Checked: a plain `*` panics in debug builds and wraps in release on huge inputs.
    num.checked_mul(unit_secs)
        .map(Duration::from_secs)
        .ok_or_else(|| format!("duration `{s}` is too large"))
}

/// Splits `s` at the first character that cannot be part of a number (ASCII digit, `.`,
/// `-`), returning the trimmed numeric prefix and the trimmed unit suffix.
fn split_num_unit(s: &str) -> (&str, &str) {
    let boundary = s
        .find(|c: char| !(c.is_ascii_digit() || c == '.' || c == '-'))
        .unwrap_or(s.len());
    let (num, unit) = s.split_at(boundary);
    (num.trim(), unit.trim())
}

/// Parses a rate limit of the form `<count>/<window>`, e.g. `"60/minute"` or `"5/m"`.
///
/// The window is parsed with [`parse_duration`]; whitespace around either part is ignored.
///
/// # Errors
/// Returns a message when the `/` is missing, the count is not a `u32`, or the window is
/// not a valid duration.
pub fn parse_rate(s: &str) -> Result<(u32, Duration), String> {
    let (count, window) = s
        .split_once('/')
        .ok_or_else(|| format!("rate must be `<count>/<window>`: {s}"))?;
    let count: u32 = count
        .trim()
        .parse()
        .map_err(|e| format!("invalid count `{count}`: {e}"))?;
    let dur = parse_duration(window.trim())?;
    Ok((count, dur))
}
662
/// Unit tests for the parsing helpers and TOML loading of [`ServerConfig`].
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_sizes() {
        assert_eq!(parse_size("10").unwrap(), 10);
        assert_eq!(parse_size("10KB").unwrap(), 10 * 1024);
        assert_eq!(parse_size("2MB").unwrap(), 2 * 1024 * 1024);
        assert_eq!(parse_size("1GB").unwrap(), 1024 * 1024 * 1024);
        assert_eq!(parse_size("1.5MB").unwrap(), (1.5 * 1024.0 * 1024.0) as u64);
        assert!(parse_size("bad").is_err());
    }

    #[test]
    fn parses_durations() {
        assert_eq!(parse_duration("30s").unwrap(), Duration::from_secs(30));
        assert_eq!(parse_duration("5m").unwrap(), Duration::from_secs(300));
        assert_eq!(parse_duration("1h").unwrap(), Duration::from_secs(3600));
        assert_eq!(parse_duration("1d").unwrap(), Duration::from_secs(86400));
        assert_eq!(
            parse_duration("1y").unwrap(),
            Duration::from_secs(86400 * 365)
        );
        assert_eq!(parse_duration("42").unwrap(), Duration::from_secs(42));
        assert!(parse_duration("bad").is_err());
    }

    #[test]
    fn parses_rates() {
        let (count, win) = parse_rate("60/minute").unwrap();
        assert_eq!(count, 60);
        assert_eq!(win, Duration::from_secs(60));
        let (count, win) = parse_rate("5/m").unwrap();
        assert_eq!(count, 5);
        assert_eq!(win, Duration::from_secs(60));
    }

    /// Whitespace between number and unit, and the error paths of the parsers.
    #[test]
    fn parses_spaced_units_and_rejects_malformed_input() {
        assert_eq!(parse_size("512 KiB").unwrap(), 512 * 1024);
        assert_eq!(parse_duration("2 hours").unwrap(), Duration::from_secs(7200));
        assert!(parse_size("").is_err());
        assert!(parse_duration("").is_err());
        assert!(parse_rate("60").is_err());
        assert!(parse_rate("x/minute").is_err());
    }

    /// An empty document must deserialize to the documented defaults.
    #[test]
    fn empty_toml_uses_defaults() {
        let cfg: ServerConfig = toml::from_str("").unwrap();
        assert_eq!(cfg.bind, "127.0.0.1:8080");
        assert!(cfg.tls.is_none());
        assert!(cfg.redirect_http.is_none());
        assert_eq!(cfg.limits.body_max, 16 * 1024 * 1024);
        assert_eq!(cfg.limits.drain_timeout, Duration::from_secs(10));
        assert!(cfg.limits.request_timeout.is_none());
        assert!(!cfg.compression.enabled);
        assert_eq!(cfg.compression.algorithms, vec!["gzip"]);
        assert_eq!(cfg.compression.min_size, 1024);
        assert_eq!(cfg.trailing_slash.mode, TrailingSlashMode::Ignore);
        assert_eq!(cfg.trailing_slash.action, TrailingSlashAction::Redirect);
        assert_eq!(cfg.access_log.format, AccessLogFormat::Combined);
    }

    /// `deny_unknown_fields` must surface typos at the top level and in nested tables.
    #[test]
    fn rejects_unknown_keys() {
        assert!(toml::from_str::<ServerConfig>("bogus = 1").is_err());
        assert!(toml::from_str::<ServerConfig>("[limits]\nbody_maxx = 1").is_err());
    }

    #[test]
    fn empty_optional_duration_is_none() {
        let cfg: ServerConfig =
            toml::from_str("[hsts]\nenabled = true\nmax_age = \"\"").unwrap();
        assert!(cfg.hsts.enabled);
        assert!(cfg.hsts.max_age.is_none());
    }

    #[test]
    fn static_mount_defaults() {
        let cfg: ServerConfig = toml::from_str("[static_files.\"/s\"]\ndir = \"public\"").unwrap();
        assert!(cfg.static_files["/s"].ranges);
        assert!(cfg.static_files["/s"].cache.is_none());
    }

    #[test]
    fn loads_vhost_and_security_toml() {
        let toml = r#"
            bind = "0.0.0.0:443"
            server_name = ["example.com", "www.example.com", "*.example.com"]

            [tls]
            cert = "/etc/cert.pem"
            key = "/etc/key.pem"

            [redirect_http]
            bind = "0.0.0.0:80"
            permanent = true
            target_host = "example.com"

            [hsts]
            enabled = true
            max_age = "1y"
            include_subdomains = true
            preload = false

            [cors]
            enabled = true
            allow_origins = ["*"]
            allow_credentials = false
            max_age = "1h"

            [[ip_rules]]
            prefix = "/admin"
            action = "allow"
            ranges = ["10.0.0.0/8"]

            [[basic_auth]]
            prefix = "/admin"
            realm = "Admin"
            credentials = ["alice:secret", "bob:second"]
        "#;
        let cfg: ServerConfig = toml::from_str(toml).unwrap();
        assert_eq!(
            cfg.server_name,
            vec!["example.com", "www.example.com", "*.example.com"]
        );
        assert!(cfg.redirect_http.is_some());
        assert_eq!(
            cfg.redirect_http.as_ref().unwrap().target_host.as_deref(),
            Some("example.com")
        );
        assert!(cfg.hsts.enabled);
        assert_eq!(cfg.hsts.max_age, Some(Duration::from_secs(86400 * 365)));
        assert!(cfg.cors.enabled);
        assert_eq!(cfg.ip_rules.len(), 1);
        assert_eq!(cfg.basic_auth.len(), 1);
        assert_eq!(cfg.basic_auth[0].credentials.len(), 2);
    }

    #[test]
    fn loads_rewrites_and_proxies_toml() {
        let toml = r#"
            [[rewrites]]
            from = "^/old/(.*)$"
            to = "/new/$1"
            status = 301

            [[rewrites]]
            from = "^/legacy/(.*)$"
            to = "/v2/$1"

            [trailing_slash]
            mode = "always"
            action = "redirect"

            [error_pages]
            404 = "errors/404.html"
            500 = "errors/500.html"

            [[proxy]]
            prefix = "/api/v2"
            upstream = "http://api-v2.internal:8080"
            strip_prefix = true
            timeout = "10s"
            retries = 3
        "#;
        let cfg: ServerConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.rewrites.len(), 2);
        assert_eq!(cfg.rewrites[0].status, Some(301));
        assert!(cfg.rewrites[1].status.is_none());
        assert_eq!(cfg.trailing_slash.mode, TrailingSlashMode::Always);
        assert_eq!(cfg.trailing_slash.action, TrailingSlashAction::Redirect);
        assert_eq!(cfg.error_pages.len(), 2);
        assert!(cfg.error_pages.contains_key("404"));
        assert_eq!(cfg.proxies.len(), 1);
        assert_eq!(cfg.proxies[0].upstream, "http://api-v2.internal:8080");
        assert_eq!(cfg.proxies[0].retries, 3);
        assert_eq!(cfg.proxies[0].timeout, Some(Duration::from_secs(10)));
    }

    #[test]
    fn loads_full_toml() {
        let toml = r#"
            bind = "0.0.0.0:443"

            [tls]
            cert = "/etc/letsencrypt/live/example.com/fullchain.pem"
            key = "/etc/letsencrypt/live/example.com/privkey.pem"

            [limits]
            body_max = "10MB"
            request_timeout = "30s"

            [compression]
            enabled = true
            algorithms = ["gzip", "br"]
            min_size = "1KB"

            [static_files."/assets"]
            dir = "public/build"
            cache = "1y"

            [rate_limit]
            per_ip = "60/minute"

            [rate_limit.routes]
            "POST /login" = "5/minute"

            [trusted_proxies]
            ranges = ["10.0.0.0/8", "127.0.0.1/32"]

            [access_log]
            format = "json"
            path = "storage/logs/access.log"
        "#;
        let cfg: ServerConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.bind, "0.0.0.0:443");
        assert!(cfg.tls.is_some());
        assert_eq!(cfg.limits.body_max, 10 * 1024 * 1024);
        assert_eq!(cfg.limits.request_timeout, Some(Duration::from_secs(30)));
        assert!(cfg.compression.enabled);
        assert_eq!(cfg.compression.algorithms, vec!["gzip", "br"]);
        assert_eq!(cfg.compression.min_size, 1024);
        assert!(cfg.static_files.contains_key("/assets"));
        assert_eq!(
            cfg.static_files["/assets"].cache,
            Some(Duration::from_secs(86400 * 365))
        );
        assert_eq!(cfg.rate_limit.per_ip.as_deref(), Some("60/minute"));
        assert_eq!(
            cfg.rate_limit.routes.get("POST /login").map(String::as_str),
            Some("5/minute")
        );
        assert_eq!(cfg.trusted_proxies.ranges.len(), 2);
        assert_eq!(cfg.access_log.format, AccessLogFormat::Json);
    }
}