1use std::collections::BTreeMap;
8use std::path::PathBuf;
9use std::time::Duration;
10
11use serde::{Deserialize, Deserializer, Serialize};
12
13#[derive(Debug, Clone, Default, Serialize, Deserialize)]
14#[serde(default, deny_unknown_fields)]
15pub struct ServerConfig {
16 #[serde(default = "default_bind")]
18 pub bind: String,
19
20 #[serde(default)]
23 pub server_name: Vec<String>,
24
25 pub tls: Option<TlsConfig>,
27
28 pub redirect_http: Option<RedirectHttpConfig>,
31
32 #[serde(default)]
34 pub hsts: HstsConfig,
35
36 #[serde(default)]
38 pub limits: LimitsConfig,
39
40 #[serde(default)]
42 pub compression: CompressionConfig,
43
44 #[serde(default)]
46 pub static_files: BTreeMap<String, StaticMount>,
47
48 #[serde(default)]
50 pub rate_limit: RateLimitConfig,
51
52 #[serde(default)]
55 pub trusted_proxies: TrustedProxiesConfig,
56
57 #[serde(default)]
59 pub access_log: AccessLogConfig,
60
61 #[serde(default)]
63 pub rewrites: Vec<RewriteRule>,
64
65 #[serde(default)]
67 pub error_pages: BTreeMap<String, std::path::PathBuf>,
68
69 #[serde(default)]
71 pub trailing_slash: TrailingSlashConfig,
72
73 #[serde(default, rename = "proxy")]
75 pub proxies: Vec<ProxyRule>,
76
77 #[serde(default)]
79 pub cors: CorsConfig,
80
81 #[serde(default)]
83 pub ip_rules: Vec<IpRule>,
84
85 #[serde(default, rename = "basic_auth")]
87 pub basic_auth: Vec<BasicAuthRule>,
88}
89
/// Bind address used for `ServerConfig::bind` when the config omits `bind`.
fn default_bind() -> String {
    String::from("127.0.0.1:8080")
}
93
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub struct TlsConfig {
96 pub cert: PathBuf,
97 pub key: PathBuf,
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
101#[serde(deny_unknown_fields)]
102pub struct RedirectHttpConfig {
103 pub bind: String,
105
106 #[serde(default = "yes")]
108 pub permanent: bool,
109
110 pub target_host: Option<String>,
113}
114
115#[derive(Debug, Clone, Default, Serialize, Deserialize)]
116#[serde(default, deny_unknown_fields)]
117pub struct HstsConfig {
118 pub enabled: bool,
119
120 #[serde(deserialize_with = "deserialize_opt_duration", default)]
122 pub max_age: Option<Duration>,
123
124 pub include_subdomains: bool,
125 pub preload: bool,
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
129#[serde(default, deny_unknown_fields)]
130pub struct LimitsConfig {
131 #[serde(deserialize_with = "deserialize_size", default = "default_body_max")]
133 pub body_max: u64,
134
135 #[serde(deserialize_with = "deserialize_opt_duration", default)]
137 pub request_timeout: Option<Duration>,
138}
139
140impl Default for LimitsConfig {
141 fn default() -> Self {
142 Self {
143 body_max: default_body_max(),
144 request_timeout: None,
145 }
146 }
147}
148
/// Default for `LimitsConfig::body_max`: 16 MiB, expressed in bytes.
fn default_body_max() -> u64 {
    16 << 20
}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
154#[serde(default, deny_unknown_fields)]
155pub struct CompressionConfig {
156 pub enabled: bool,
158
159 pub algorithms: Vec<String>,
162
163 #[serde(deserialize_with = "deserialize_size", default = "default_min_size")]
166 pub min_size: u64,
167}
168
169impl Default for CompressionConfig {
170 fn default() -> Self {
171 Self {
172 enabled: false,
173 algorithms: vec!["gzip".to_string()],
174 min_size: default_min_size(),
175 }
176 }
177}
178
/// Default for `CompressionConfig::min_size`: 1 KiB, expressed in bytes.
fn default_min_size() -> u64 {
    1 << 10
}
182
183#[derive(Debug, Clone, Serialize, Deserialize)]
184#[serde(deny_unknown_fields)]
185pub struct StaticMount {
186 pub dir: PathBuf,
188
189 #[serde(deserialize_with = "deserialize_opt_duration", default)]
192 pub cache: Option<Duration>,
193
194 #[serde(default = "yes")]
196 pub ranges: bool,
197}
198
/// Serde default helper for boolean fields that should be `true` when the
/// key is omitted (referenced as `#[serde(default = "yes")]`).
fn yes() -> bool {
    true
}
202
203#[derive(Debug, Clone, Serialize, Deserialize)]
204#[serde(default, deny_unknown_fields)]
205pub struct RateLimitConfig {
206 pub per_ip: Option<String>,
208
209 #[serde(default)]
211 pub routes: BTreeMap<String, String>,
212}
213
214impl Default for RateLimitConfig {
215 fn default() -> Self {
216 Self {
217 per_ip: None,
218 routes: BTreeMap::new(),
219 }
220 }
221}
222
223#[derive(Debug, Clone, Default, Serialize, Deserialize)]
224#[serde(default, deny_unknown_fields)]
225pub struct TrustedProxiesConfig {
226 pub ranges: Vec<ipnet::IpNet>,
228}
229
230#[derive(Debug, Clone, Serialize, Deserialize)]
231#[serde(deny_unknown_fields)]
232pub struct RewriteRule {
233 pub from: String,
235
236 pub to: String,
238
239 #[serde(default)]
243 pub status: Option<u16>,
244
245 #[serde(default)]
247 pub match_query: bool,
248}
249
250#[derive(Debug, Clone, Serialize, Deserialize)]
251#[serde(default, deny_unknown_fields)]
252pub struct TrailingSlashConfig {
253 pub mode: TrailingSlashMode,
257
258 pub action: TrailingSlashAction,
260}
261
262impl Default for TrailingSlashConfig {
263 fn default() -> Self {
264 Self {
265 mode: TrailingSlashMode::Ignore,
266 action: TrailingSlashAction::Redirect,
267 }
268 }
269}
270
271#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
272#[serde(rename_all = "lowercase")]
273pub enum TrailingSlashMode {
274 Always,
275 Never,
276 Ignore,
277}
278
279#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
280#[serde(rename_all = "lowercase")]
281pub enum TrailingSlashAction {
282 Redirect,
283 Rewrite,
284}
285
286#[derive(Debug, Clone, Default, Serialize, Deserialize)]
287#[serde(default, deny_unknown_fields)]
288pub struct CorsConfig {
289 pub enabled: bool,
290 pub allow_origins: Vec<String>,
292 pub allow_methods: Vec<String>,
294 pub allow_headers: Vec<String>,
296 pub expose_headers: Vec<String>,
298 pub allow_credentials: bool,
300 #[serde(deserialize_with = "deserialize_opt_duration", default)]
302 pub max_age: Option<Duration>,
303}
304
305#[derive(Debug, Clone, Serialize, Deserialize)]
306#[serde(deny_unknown_fields)]
307pub struct IpRule {
308 pub prefix: String,
310 pub action: IpAction,
312 pub ranges: Vec<ipnet::IpNet>,
314}
315
316#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
317#[serde(rename_all = "lowercase")]
318pub enum IpAction {
319 Allow,
320 Deny,
321}
322
323#[derive(Debug, Clone, Serialize, Deserialize)]
324#[serde(deny_unknown_fields)]
325pub struct BasicAuthRule {
326 pub prefix: String,
327 #[serde(default = "default_realm")]
329 pub realm: String,
330 pub credentials: Vec<String>,
332}
333
/// Realm used for `BasicAuthRule::realm` when the config omits `realm`.
fn default_realm() -> String {
    String::from("Restricted")
}
337
338#[derive(Debug, Clone, Serialize, Deserialize)]
339#[serde(deny_unknown_fields)]
340pub struct ProxyRule {
341 pub prefix: String,
343
344 pub upstream: String,
346
347 #[serde(default)]
349 pub strip_prefix: bool,
350
351 #[serde(default)]
353 pub preserve_host: bool,
354
355 #[serde(deserialize_with = "deserialize_opt_duration", default)]
357 pub timeout: Option<Duration>,
358
359 #[serde(default)]
361 pub retries: u8,
362}
363
364#[derive(Debug, Clone, Serialize, Deserialize)]
365#[serde(default, deny_unknown_fields)]
366pub struct AccessLogConfig {
367 pub format: AccessLogFormat,
368 pub path: Option<PathBuf>,
369}
370
371impl Default for AccessLogConfig {
372 fn default() -> Self {
373 Self {
374 format: AccessLogFormat::Combined,
375 path: None,
376 }
377 }
378}
379
380#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
381#[serde(rename_all = "lowercase")]
382pub enum AccessLogFormat {
383 Combined,
385 Json,
387 Off,
389}
390
391impl ServerConfig {
392 pub fn from_file_or_default(path: impl AsRef<std::path::Path>) -> Self {
394 match Self::from_file(path.as_ref()) {
395 Ok(c) => c,
396 Err(crate::Error::Io(e)) if e.kind() == std::io::ErrorKind::NotFound => Self::default(),
397 Err(e) => {
398 tracing::warn!(?e, path = %path.as_ref().display(), "failed to load server config; using defaults");
399 Self::default()
400 }
401 }
402 }
403
404 pub fn from_file(path: &std::path::Path) -> crate::Result<Self> {
405 let bytes = std::fs::read_to_string(path)?;
406 let cfg: Self = toml::from_str(&bytes)
407 .map_err(|e| crate::Error::Config(format!("toml parse {}: {e}", path.display())))?;
408 Ok(cfg.apply_env_overrides())
409 }
410
411 pub fn apply_env_overrides(mut self) -> Self {
414 if let Ok(v) = std::env::var("APP_ADDR") {
415 self.bind = v;
416 }
417 if let (Ok(cert), Ok(key)) = (std::env::var("TLS_CERT"), std::env::var("TLS_KEY")) {
418 self.tls = Some(TlsConfig {
419 cert: PathBuf::from(cert),
420 key: PathBuf::from(key),
421 });
422 }
423 self
424 }
425}
426
427fn deserialize_size<'de, D: Deserializer<'de>>(d: D) -> Result<u64, D::Error> {
430 use serde::de::Error;
431 let v = toml::Value::deserialize(d)?;
432 match v {
433 toml::Value::Integer(n) => Ok(n.max(0) as u64),
434 toml::Value::String(s) => parse_size(&s).map_err(D::Error::custom),
435 other => Err(D::Error::custom(format!(
436 "expected integer or size string, got {other:?}"
437 ))),
438 }
439}
440
441fn deserialize_opt_duration<'de, D: Deserializer<'de>>(d: D) -> Result<Option<Duration>, D::Error> {
442 use serde::de::Error;
443 let v = Option::<toml::Value>::deserialize(d)?;
444 match v {
445 None | Some(toml::Value::String(_)) if matches!(&v, Some(toml::Value::String(s)) if s.is_empty()) => {
446 Ok(None)
447 }
448 None => Ok(None),
449 Some(toml::Value::Integer(n)) => Ok(Some(Duration::from_secs(n.max(0) as u64))),
450 Some(toml::Value::String(s)) => parse_duration(&s).map(Some).map_err(D::Error::custom),
451 Some(other) => Err(D::Error::custom(format!(
452 "expected integer (seconds) or duration string, got {other:?}"
453 ))),
454 }
455}
456
457pub fn parse_size(s: &str) -> Result<u64, String> {
459 let s = s.trim();
460 if s.is_empty() {
461 return Err("empty size".into());
462 }
463 if let Ok(n) = s.parse::<u64>() {
464 return Ok(n);
465 }
466 let (num_part, unit_part) = split_num_unit(s);
467 let num: f64 = num_part
468 .parse()
469 .map_err(|e| format!("invalid size number `{num_part}`: {e}"))?;
470 let mult: u64 = match unit_part.trim().to_ascii_uppercase().as_str() {
471 "" | "B" => 1,
472 "K" | "KB" | "KIB" => 1024,
473 "M" | "MB" | "MIB" => 1024 * 1024,
474 "G" | "GB" | "GIB" => 1024 * 1024 * 1024,
475 other => return Err(format!("unknown size unit `{other}`")),
476 };
477 Ok((num * mult as f64) as u64)
478}
479
/// Parses a duration such as `"42"`, `"30s"`, `"5m"`, `"1y"` or `"minute"`.
///
/// A plain unsigned integer is a number of seconds. Otherwise the text is
/// split into an integer count and a case-insensitive unit; a unit with no
/// count (e.g. `"minute"`, as used by rates like `60/minute`) means one unit.
/// Months are 30 days and years are 365 days.
///
/// # Errors
///
/// Returns a message when the input is empty, the count is not an unsigned
/// integer, the unit is unknown, or the total number of seconds does not fit
/// in a `u64`.
pub fn parse_duration(s: &str) -> Result<Duration, String> {
    let s = s.trim();
    if s.is_empty() {
        return Err("empty duration".into());
    }
    if let Ok(n) = s.parse::<u64>() {
        return Ok(Duration::from_secs(n));
    }
    let (num_part, unit_part) = split_num_unit(s);
    let num: u64 = if num_part.is_empty() {
        1
    } else {
        num_part
            .parse()
            .map_err(|e| format!("invalid duration number `{num_part}`: {e}"))?
    };
    let secs_per_unit: u64 = match unit_part.trim().to_ascii_lowercase().as_str() {
        "s" | "sec" | "secs" | "second" | "seconds" => 1,
        "m" | "min" | "mins" | "minute" | "minutes" => 60,
        "h" | "hr" | "hrs" | "hour" | "hours" => 3600,
        "d" | "day" | "days" => 86400,
        "w" | "wk" | "wks" | "week" | "weeks" => 86400 * 7,
        "mo" | "month" | "months" => 86400 * 30,
        "y" | "yr" | "yrs" | "year" | "years" => 86400 * 365,
        other => return Err(format!("unknown duration unit `{other}`")),
    };
    // The count is user input, so a plain `*` could overflow (panic in debug
    // builds, silent wrap-around in release builds).
    let secs = num
        .checked_mul(secs_per_unit)
        .ok_or_else(|| format!("duration `{s}` is too large"))?;
    Ok(Duration::from_secs(secs))
}

/// Splits `s` at the first character that is not an ASCII digit, `.` or `-`,
/// returning the trimmed `(number, unit)` halves. Either half may be empty.
fn split_num_unit(s: &str) -> (&str, &str) {
    let split = s
        .find(|c: char| !(c.is_ascii_digit() || c == '.' || c == '-'))
        .unwrap_or(s.len());
    let (num, unit) = s.split_at(split);
    (num.trim(), unit.trim())
}
518
519pub fn parse_rate(s: &str) -> Result<(u32, Duration), String> {
521 let (count, window) = s
522 .split_once('/')
523 .ok_or_else(|| format!("rate must be `<count>/<window>`: {s}"))?;
524 let count: u32 = count
525 .trim()
526 .parse()
527 .map_err(|e| format!("invalid count `{count}`: {e}"))?;
528 let dur = parse_duration(window.trim())?;
529 Ok((count, dur))
530}
531
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a whole-second `Duration` in expectations.
    fn secs(n: u64) -> Duration {
        Duration::from_secs(n)
    }

    #[test]
    fn parses_sizes() {
        let cases: [(&str, u64); 5] = [
            ("10", 10),
            ("10KB", 10 * 1024),
            ("2MB", 2 * 1024 * 1024),
            ("1GB", 1024 * 1024 * 1024),
            ("1.5MB", 1_572_864),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_size(input).unwrap(), expected);
        }
        assert!(parse_size("bad").is_err());
    }

    #[test]
    fn parses_durations() {
        let cases: [(&str, u64); 6] = [
            ("30s", 30),
            ("5m", 300),
            ("1h", 3600),
            ("1d", 86400),
            ("1y", 86400 * 365),
            ("42", 42),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_duration(input).unwrap(), secs(expected));
        }
        assert!(parse_duration("bad").is_err());
    }

    #[test]
    fn parses_rates() {
        // Both the long and the short unit spelling give a one-minute window.
        assert_eq!(parse_rate("60/minute").unwrap(), (60, secs(60)));
        assert_eq!(parse_rate("5/m").unwrap(), (5, secs(60)));
    }

    #[test]
    fn loads_vhost_and_security_toml() {
        let doc = r#"
            bind = "0.0.0.0:443"
            server_name = ["example.com", "www.example.com", "*.example.com"]

            [tls]
            cert = "/etc/cert.pem"
            key = "/etc/key.pem"

            [redirect_http]
            bind = "0.0.0.0:80"
            permanent = true
            target_host = "example.com"

            [hsts]
            enabled = true
            max_age = "1y"
            include_subdomains = true
            preload = false

            [cors]
            enabled = true
            allow_origins = ["*"]
            allow_credentials = false
            max_age = "1h"

            [[ip_rules]]
            prefix = "/admin"
            action = "allow"
            ranges = ["10.0.0.0/8"]

            [[basic_auth]]
            prefix = "/admin"
            realm = "Admin"
            credentials = ["alice:secret", "bob:second"]
        "#;
        let parsed: ServerConfig = toml::from_str(doc).unwrap();

        assert_eq!(
            parsed.server_name,
            vec!["example.com", "www.example.com", "*.example.com"]
        );

        assert!(parsed.redirect_http.is_some());
        let redirect = parsed.redirect_http.as_ref().unwrap();
        assert_eq!(redirect.target_host.as_deref(), Some("example.com"));

        assert!(parsed.hsts.enabled);
        assert_eq!(parsed.hsts.max_age, Some(secs(86400 * 365)));
        assert!(parsed.cors.enabled);
        assert_eq!(parsed.ip_rules.len(), 1);
        assert_eq!(parsed.basic_auth.len(), 1);
        assert_eq!(parsed.basic_auth[0].credentials.len(), 2);
    }

    #[test]
    fn loads_rewrites_and_proxies_toml() {
        let doc = r#"
            [[rewrites]]
            from = "^/old/(.*)$"
            to = "/new/$1"
            status = 301

            [[rewrites]]
            from = "^/legacy/(.*)$"
            to = "/v2/$1"

            [trailing_slash]
            mode = "always"
            action = "redirect"

            [error_pages]
            404 = "errors/404.html"
            500 = "errors/500.html"

            [[proxy]]
            prefix = "/api/v2"
            upstream = "http://api-v2.internal:8080"
            strip_prefix = true
            timeout = "10s"
            retries = 3
        "#;
        let parsed: ServerConfig = toml::from_str(doc).unwrap();

        assert_eq!(parsed.rewrites.len(), 2);
        assert_eq!(parsed.rewrites[0].status, Some(301));
        assert!(parsed.rewrites[1].status.is_none());

        assert_eq!(parsed.trailing_slash.mode, TrailingSlashMode::Always);
        assert_eq!(parsed.trailing_slash.action, TrailingSlashAction::Redirect);

        assert_eq!(parsed.error_pages.len(), 2);
        assert!(parsed.error_pages.contains_key("404"));

        assert_eq!(parsed.proxies.len(), 1);
        let proxy = &parsed.proxies[0];
        assert_eq!(proxy.upstream, "http://api-v2.internal:8080");
        assert_eq!(proxy.retries, 3);
        assert_eq!(proxy.timeout, Some(secs(10)));
    }

    #[test]
    fn loads_full_toml() {
        let doc = r#"
            bind = "0.0.0.0:443"

            [tls]
            cert = "/etc/letsencrypt/live/example.com/fullchain.pem"
            key = "/etc/letsencrypt/live/example.com/privkey.pem"

            [limits]
            body_max = "10MB"
            request_timeout = "30s"

            [compression]
            enabled = true
            algorithms = ["gzip", "br"]
            min_size = "1KB"

            [static_files."/assets"]
            dir = "public/build"
            cache = "1y"

            [rate_limit]
            per_ip = "60/minute"

            [rate_limit.routes]
            "POST /login" = "5/minute"

            [trusted_proxies]
            ranges = ["10.0.0.0/8", "127.0.0.1/32"]

            [access_log]
            format = "json"
            path = "storage/logs/access.log"
        "#;
        let parsed: ServerConfig = toml::from_str(doc).unwrap();

        assert_eq!(parsed.bind, "0.0.0.0:443");
        assert!(parsed.tls.is_some());

        assert_eq!(parsed.limits.body_max, 10 * 1024 * 1024);
        assert_eq!(parsed.limits.request_timeout, Some(secs(30)));

        assert!(parsed.compression.enabled);
        assert_eq!(parsed.compression.algorithms, vec!["gzip", "br"]);
        assert_eq!(parsed.compression.min_size, 1024);

        assert!(parsed.static_files.contains_key("/assets"));
        assert_eq!(
            parsed.static_files["/assets"].cache,
            Some(secs(86400 * 365))
        );

        assert_eq!(parsed.rate_limit.per_ip.as_deref(), Some("60/minute"));
        let login_rate = parsed.rate_limit.routes.get("POST /login");
        assert_eq!(login_rate.map(String::as_str), Some("5/minute"));

        assert_eq!(parsed.trusted_proxies.ranges.len(), 2);
        assert_eq!(parsed.access_log.format, AccessLogFormat::Json);
    }
}