Skip to main content

anvil_core/
server_config.rs

1//! Production HTTP serving configuration — the NGINX-equivalent surface.
2//!
3//! Apps can load this from `config/anvil.toml` via `ServerConfig::from_file`,
4//! or build it programmatically via the typed structs. Env vars override file
5//! values where applicable (Laravel-style precedence).
6
7use std::collections::BTreeMap;
8use std::path::PathBuf;
9use std::time::Duration;
10
11use serde::{Deserialize, Deserializer, Serialize};
12
/// Top-level HTTP server configuration, deserialized from the root of the
/// TOML config file.
///
/// Unknown keys are rejected (`deny_unknown_fields`) and every key may be
/// omitted from the file. Note that the derived `Default` leaves `bind` as an
/// empty string; the `127.0.0.1:8080` fallback is only applied while
/// deserializing, via `default_bind`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct ServerConfig {
    /// Bind address. Deserializes to `127.0.0.1:8080` when the key is absent
    /// from the file; `apply_env_overrides` replaces it with `APP_ADDR` when
    /// that variable is set.
    #[serde(default = "default_bind")]
    pub bind: String,

    /// Virtual host names this server answers to. Empty = match all hosts.
    /// Supports wildcard prefixes: `"*.example.com"` matches any subdomain.
    #[serde(default)]
    pub server_name: Vec<String>,

    /// Optional TLS config. If present, the server runs HTTPS.
    pub tls: Option<TlsConfig>,

    /// Optional HTTP-to-HTTPS auto-redirect listener. Typically binds :80 and
    /// 301-redirects every request to the equivalent `https://` URL.
    pub redirect_http: Option<RedirectHttpConfig>,

    /// HTTP Strict Transport Security (HSTS) header. Off by default.
    #[serde(default)]
    pub hsts: HstsConfig,

    /// Body/timeout limits.
    #[serde(default)]
    pub limits: LimitsConfig,

    /// Compression layer config.
    #[serde(default)]
    pub compression: CompressionConfig,

    /// Static file mounts — map of URL prefix → on-disk dir + cache policy.
    #[serde(default)]
    pub static_files: BTreeMap<String, StaticMount>,

    /// Rate limiting rules.
    #[serde(default)]
    pub rate_limit: RateLimitConfig,

    /// Trusted reverse-proxy ranges. Forwarded headers from outside these
    /// CIDRs are ignored.
    #[serde(default)]
    pub trusted_proxies: TrustedProxiesConfig,

    /// Per-route timeout overrides. Each entry matches by path prefix and
    /// applies its own timeout to requests under that prefix, overriding
    /// `limits.request_timeout`. First matching prefix wins. Useful for
    /// slow endpoints (large uploads, long polls) that don't want the
    /// global timeout raised.
    ///
    /// The TOML key is the singular `route_timeout` (an array of tables).
    #[serde(default, rename = "route_timeout")]
    pub route_timeouts: Vec<RouteTimeoutRule>,

    /// Access log config.
    #[serde(default)]
    pub access_log: AccessLogConfig,

    /// URL rewrite rules (regex `from` → `to`, optionally as a redirect).
    #[serde(default)]
    pub rewrites: Vec<RewriteRule>,

    /// Custom error pages — map of status code (as a string key) → file path.
    #[serde(default)]
    pub error_pages: BTreeMap<String, std::path::PathBuf>,

    /// Trailing-slash policy.
    #[serde(default)]
    pub trailing_slash: TrailingSlashConfig,

    /// Reverse-proxy rules — path prefix → upstream URL.
    ///
    /// The TOML key is the singular `proxy` (an array of tables).
    #[serde(default, rename = "proxy")]
    pub proxies: Vec<ProxyRule>,

    /// CORS configuration.
    #[serde(default)]
    pub cors: CorsConfig,

    /// Path-prefixed IP allow/deny rules.
    #[serde(default)]
    pub ip_rules: Vec<IpRule>,

    /// Path-prefixed HTTP Basic Auth blocks.
    // The rename below matches the field name, so it has no effect.
    #[serde(default, rename = "basic_auth")]
    pub basic_auth: Vec<BasicAuthRule>,
}
97
/// Serde fallback for `ServerConfig::bind`: loopback on port 8080.
fn default_bind() -> String {
    String::from("127.0.0.1:8080")
}
101
/// TLS settings for the `[tls]` table: a default certificate/key pair plus
/// optional ACME and per-hostname (SNI) certificates.
///
/// Unlike most sibling structs this one does not set `deny_unknown_fields`,
/// so unrecognised keys under `[tls]` are not rejected by this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsConfig {
    /// Path to the default certificate file.
    pub cert: PathBuf,
    /// Path to the private key belonging to `cert`.
    pub key: PathBuf,

    /// Optional ACME (Let's Encrypt) auto-cert. When present and the
    /// framework was built with `--features acme`, the server obtains and
    /// rotates certs in-process via TLS-ALPN-01 — `cert`/`key` above act
    /// as the cache locations rather than pre-existing files.
    #[serde(default)]
    pub acme: Option<AcmeConfig>,

    /// Additional `(server_name, cert, key)` triples for SNI-based cert
    /// selection. The top-level `cert`/`key` above acts as the default
    /// when no SNI hostname matches an entry here. Empty list = single-cert
    /// behaviour (backward compatible).
    ///
    /// Example TOML:
    ///
    /// ```toml
    /// [tls]
    /// cert = "/etc/letsencrypt/live/example.com/fullchain.pem"
    /// key  = "/etc/letsencrypt/live/example.com/privkey.pem"
    ///
    /// [[tls.certs]]
    /// server_name = "api.example.com"
    /// cert        = "/etc/letsencrypt/live/api.example.com/fullchain.pem"
    /// key         = "/etc/letsencrypt/live/api.example.com/privkey.pem"
    ///
    /// [[tls.certs]]
    /// server_name = "admin.example.com"
    /// cert        = "/etc/letsencrypt/live/admin.example.com/fullchain.pem"
    /// key         = "/etc/letsencrypt/live/admin.example.com/privkey.pem"
    /// ```
    // Exposed in TOML as `certs`, as the example above shows.
    #[serde(default, rename = "certs")]
    pub additional_certs: Vec<SniCertEntry>,
}
139
/// ACME settings nested under `TlsConfig::acme`. Only `domains` is required;
/// the other keys fall back to the defaults documented on each field.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct AcmeConfig {
    /// Domains to request certs for. The first is the primary SAN; the
    /// rest are added as alternative names on the same cert.
    pub domains: Vec<String>,

    /// Contact email passed to the ACME directory (Let's Encrypt sends
    /// expiry notices here). Optional but strongly recommended.
    #[serde(default)]
    pub contact: Option<String>,

    /// On-disk cache directory for issued certs + the ACME account key.
    /// Avoids hammering Let's Encrypt's rate limits on restart. Default
    /// `./database/acme-cache` — created on first run.
    #[serde(default = "default_acme_cache")]
    pub cache_dir: PathBuf,

    /// ACME directory URL. Default = Let's Encrypt production. Override
    /// with `https://acme-staging-v02.api.letsencrypt.org/directory` for
    /// staging while wiring this up — production has aggressive rate
    /// limits on certificate issuance.
    #[serde(default = "default_acme_directory")]
    pub directory: String,
}
165
/// Serde fallback for `AcmeConfig::cache_dir`.
fn default_acme_cache() -> PathBuf {
    "./database/acme-cache".into()
}
169
/// Serde fallback for `AcmeConfig::directory`: the Let's Encrypt production
/// endpoint.
fn default_acme_directory() -> String {
    String::from("https://acme-v02.api.letsencrypt.org/directory")
}
173
/// One `[[tls.certs]]` entry: a certificate/key pair bound to a hostname.
/// All three keys are required and unknown keys are rejected.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct SniCertEntry {
    /// Hostname this cert should be served for. Wildcards work the
    /// same as `server_name` matching: `*.example.com` matches any
    /// subdomain but not the apex.
    pub server_name: String,
    /// Path to the certificate served for `server_name`.
    pub cert: PathBuf,
    /// Path to the private key belonging to `cert`.
    pub key: PathBuf,
}
184
/// Settings for the `[redirect_http]` table — the plain-HTTP listener that
/// redirects to HTTPS. `bind` is required; the other keys are optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct RedirectHttpConfig {
    /// Plain-HTTP listener address (typically `"0.0.0.0:80"`).
    pub bind: String,

    /// 301 (permanent) when `true`, 302 (temporary) when `false`. Default: 301.
    #[serde(default = "yes")]
    pub permanent: bool,

    /// Target host for the redirect. If unset, the request's Host header is
    /// reused (with the scheme flipped to `https`).
    pub target_host: Option<String>,
}
199
/// Settings for the `[hsts]` table. Every key is optional; the derived
/// `Default` yields a disabled policy with no `max_age`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct HstsConfig {
    /// Master switch for the HSTS header.
    pub enabled: bool,

    /// `max-age=<seconds>`. Defaults to `1y` when HSTS is enabled.
    // NOTE(review): this struct stores `None` when the key is absent; the
    // `1y` fallback is presumably applied by the code that emits the header.
    #[serde(deserialize_with = "deserialize_opt_duration", default)]
    pub max_age: Option<Duration>,

    /// Presumably maps to the `includeSubDomains` directive — confirm in the
    /// header-emitting code.
    pub include_subdomains: bool,
    /// Presumably maps to the `preload` directive — confirm in the
    /// header-emitting code.
    pub preload: bool,
}
212
/// Settings for the `[limits]` table: body size, timeouts and concurrency.
/// Missing keys take the values from this type's `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct LimitsConfig {
    /// Max request body size. Accepts `"10MB"`, `"500KB"`, `"2GB"`, raw byte count.
    #[serde(deserialize_with = "deserialize_size", default = "default_body_max")]
    pub body_max: u64,

    /// Per-request timeout for the handler. `None` = no timeout.
    #[serde(deserialize_with = "deserialize_opt_duration", default)]
    pub request_timeout: Option<Duration>,

    /// Graceful shutdown timeout — on `SIGTERM` the server stops accepting
    /// new connections, then waits up to this long for in-flight requests
    /// to complete before dropping them. Default 10s matches SystemD's
    /// typical drain window. Raise for long-poll / large-upload workloads,
    /// lower for fast-iteration deploys.
    #[serde(
        deserialize_with = "deserialize_duration",
        default = "default_drain_timeout"
    )]
    pub drain_timeout: Duration,

    /// Maximum concurrent in-flight requests across the whole server.
    /// `None` = unlimited (Tokio's default). When set, requests above the
    /// cap immediately return HTTP 503 instead of queueing — this protects
    /// against thundering-herd overload and lets a load balancer steer
    /// traffic to healthy peers.
    ///
    /// Set in `config/anvil.toml` as `[limits] max_concurrency = 500`.
    /// Pick a value that matches the size of your DB pool × expected
    /// per-request DB ops; over-provisioning here just shifts the
    /// bottleneck elsewhere.
    pub max_concurrency: Option<u32>,
}
247
248impl Default for LimitsConfig {
249    fn default() -> Self {
250        Self {
251            body_max: default_body_max(),
252            request_timeout: None,
253            drain_timeout: default_drain_timeout(),
254            max_concurrency: None,
255        }
256    }
257}
258
/// Fallback for `LimitsConfig::drain_timeout`: ten seconds.
fn default_drain_timeout() -> Duration {
    Duration::new(10, 0)
}
262
/// Fallback for `LimitsConfig::body_max`: 16 MiB, expressed in bytes.
fn default_body_max() -> u64 {
    16 << 20
}
266
/// Settings for the `[compression]` table. Missing keys take the values from
/// this type's `Default` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct CompressionConfig {
    /// Enable compression. Off by default — flip via config or env.
    pub enabled: bool,

    /// Algorithms to advertise via `Accept-Encoding` matching. Order matters.
    /// Accepts `"gzip"`, `"br"`, `"deflate"`.
    // Stored as free-form strings: nothing in this file validates the names.
    pub algorithms: Vec<String>,

    /// Minimum response size (bytes) below which compression is skipped.
    /// Accepts `"1KB"`, raw bytes.
    #[serde(deserialize_with = "deserialize_size", default = "default_min_size")]
    pub min_size: u64,
}
282
283impl Default for CompressionConfig {
284    fn default() -> Self {
285        Self {
286            enabled: false,
287            algorithms: vec!["gzip".to_string()],
288            min_size: default_min_size(),
289        }
290    }
291}
292
/// Fallback for `CompressionConfig::min_size`: 1 KiB, expressed in bytes.
fn default_min_size() -> u64 {
    1 << 10
}
296
/// One static-file mount: the value side of `ServerConfig::static_files`,
/// whose map key is the URL prefix. `dir` is required.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct StaticMount {
    /// On-disk directory served at this URL prefix.
    pub dir: PathBuf,

    /// `Cache-Control: max-age=<seconds>` value. Accepts `"1y"`, `"30d"`, `"3600"`.
    /// Default: no Cache-Control header is set.
    #[serde(deserialize_with = "deserialize_opt_duration", default)]
    pub cache: Option<Duration>,

    /// Whether to enable byte-range requests (default: true).
    #[serde(default = "yes")]
    pub ranges: bool,
}
312
/// Serde `default = "yes"` helper for boolean fields that should default to
/// `true` (`RedirectHttpConfig::permanent`, `StaticMount::ranges`).
fn yes() -> bool {
    true
}
316
/// Settings for the `[rate_limit]` table. Rates are kept as raw strings
/// here; `parse_rate` in this file turns `"<count>/<window>"` into numbers.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct RateLimitConfig {
    /// Default per-IP rate (e.g. `"60/minute"`). `None` disables the default rate.
    pub per_ip: Option<String>,

    /// Per-route overrides: `{"POST /login" = "5/minute"}`.
    #[serde(default)]
    pub routes: BTreeMap<String, String>,
}
327
/// Settings for the `[trusted_proxies]` table. Defaults to an empty list.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct TrustedProxiesConfig {
    /// CIDR ranges from which X-Forwarded-* headers will be honored.
    pub ranges: Vec<ipnet::IpNet>,
}
334
/// One `[[route_timeout]]` entry: a path prefix and the timeout applied to
/// requests under it. Both keys are required.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct RouteTimeoutRule {
    /// URL path prefix this rule matches against. E.g. `/api/v2/uploads`.
    pub prefix: String,

    /// Timeout applied to matching requests. Accepts `"30s"`, `"2m"`,
    /// `"500ms"`, or a bare integer (seconds).
    // NOTE(review): string values go through `parse_duration`, so `"500ms"`
    // only works if that function recognises the `ms` unit — verify.
    #[serde(deserialize_with = "deserialize_duration")]
    pub timeout: Duration,
}
346
/// One `[[rewrites]]` entry. `from` and `to` are required. The pattern is
/// stored as a plain string; it is not compiled or validated in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct RewriteRule {
    /// Regex applied to the request path (or full path+query, when `match_query` is true).
    pub from: String,

    /// Replacement template. Capture groups available as `$1`, `$2`, etc.
    pub to: String,

    /// HTTP status to return. `301`/`302`/`307`/`308` send a redirect. Any other
    /// value (or unset) does an in-place internal rewrite — the request URI is
    /// rewritten before reaching the handler.
    #[serde(default)]
    pub status: Option<u16>,

    /// If true, the regex is applied to `path?query` instead of just `path`.
    #[serde(default)]
    pub match_query: bool,
}
366
/// Settings for the `[trailing_slash]` table. Missing keys take the values
/// from this type's `Default` impl (`ignore` / `redirect`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct TrailingSlashConfig {
    /// `"always"` — append `/` to paths missing one (redirect or rewrite).
    /// `"never"` — strip trailing `/`.
    /// `"ignore"` (default) — leave alone.
    pub mode: TrailingSlashMode,

    /// `"redirect"` (default) returns a 301; `"rewrite"` modifies the URI in place.
    pub action: TrailingSlashAction,
}
378
379impl Default for TrailingSlashConfig {
380    fn default() -> Self {
381        Self {
382            mode: TrailingSlashMode::Ignore,
383            action: TrailingSlashAction::Redirect,
384        }
385    }
386}
387
/// Trailing-slash normalisation mode. Serialized in lowercase
/// (`"always"`, `"never"`, `"ignore"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum TrailingSlashMode {
    /// Append `/` to paths that lack one.
    Always,
    /// Strip a trailing `/`.
    Never,
    /// Leave paths as they are.
    Ignore,
}
395
/// How a trailing-slash correction is delivered. Serialized in lowercase
/// (`"redirect"`, `"rewrite"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum TrailingSlashAction {
    /// Respond with a 301 to the corrected path.
    Redirect,
    /// Modify the request URI in place.
    Rewrite,
}
402
/// Settings for the `[cors]` table. Every key is optional; the derived
/// `Default` yields a disabled policy with empty lists and no `max_age`.
// NOTE(review): the "default when enabled" values mentioned on the fields
// below are not applied in this file — presumably the CORS layer fills them
// in when the lists are empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct CorsConfig {
    /// Master switch for CORS handling.
    pub enabled: bool,
    /// Allowed origins. `["*"]` allows any. Default: empty.
    pub allow_origins: Vec<String>,
    /// Allowed methods. Default: `["GET", "POST", "OPTIONS"]` when enabled.
    pub allow_methods: Vec<String>,
    /// Allowed headers. Default: a reasonable set when enabled.
    pub allow_headers: Vec<String>,
    /// Expose these response headers to the JS layer.
    pub expose_headers: Vec<String>,
    /// Whether credentials (cookies, auth headers) are allowed cross-origin.
    pub allow_credentials: bool,
    /// `Access-Control-Max-Age` for preflight cache.
    #[serde(deserialize_with = "deserialize_opt_duration", default)]
    pub max_age: Option<Duration>,
}
421
/// One `[[ip_rules]]` entry: allow or deny a set of networks under a path
/// prefix. All three keys are required.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct IpRule {
    /// Path prefix this rule applies to.
    pub prefix: String,
    /// `"allow"` or `"deny"`.
    pub action: IpAction,
    /// CIDR ranges (or single IPs) covered by this rule.
    pub ranges: Vec<ipnet::IpNet>,
}
432
/// Verdict carried by an `IpRule`. Serialized in lowercase
/// (`"allow"`, `"deny"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum IpAction {
    Allow,
    Deny,
}
439
/// One `[[basic_auth]]` entry protecting a path prefix. `prefix` and
/// `credentials` are required.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct BasicAuthRule {
    /// Path prefix this rule applies to.
    pub prefix: String,
    /// `realm` shown in the browser's auth prompt.
    #[serde(default = "default_realm")]
    pub realm: String,
    /// Inline credentials as `user:password` pairs.
    // NOTE(review): passwords are held in plaintext in the config file and in
    // memory, and the derived `Debug` impl will print them.
    pub credentials: Vec<String>,
}
450
/// Serde fallback for `BasicAuthRule::realm`.
fn default_realm() -> String {
    String::from("Restricted")
}
454
/// One `[[proxy]]` entry: forwards requests under `prefix` to `upstream`.
/// `prefix` and `upstream` are required; the upstream URL is stored as a
/// plain string and is not validated in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ProxyRule {
    /// Path prefix that triggers this proxy (e.g. `"/api/v2"`).
    pub prefix: String,

    /// Upstream base URL (e.g. `"http://api-v2.internal:8080"`).
    pub upstream: String,

    /// Strip the prefix from the request path before forwarding. Default: false.
    #[serde(default)]
    pub strip_prefix: bool,

    /// Keep the original Host header instead of using the upstream host. Default: false.
    #[serde(default)]
    pub preserve_host: bool,

    /// Per-request timeout. Defaults to 30s.
    // NOTE(review): an absent key deserializes to `None`; the 30s fallback is
    // presumably applied by the proxy layer, not here.
    #[serde(deserialize_with = "deserialize_opt_duration", default)]
    pub timeout: Option<Duration>,

    /// How many times to retry on connection failure. Default: 0.
    #[serde(default)]
    pub retries: u8,
}
480
/// Settings for the `[access_log]` table. Missing keys take the values from
/// this type's `Default` impl (`combined` format, no path).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct AccessLogConfig {
    /// Output format of each access-log entry.
    pub format: AccessLogFormat,
    /// Optional log file path.
    // NOTE(review): where entries go when this is `None` is decided by the
    // logging layer, which is not visible in this file.
    pub path: Option<PathBuf>,
}
487
488impl Default for AccessLogConfig {
489    fn default() -> Self {
490        Self {
491            format: AccessLogFormat::Combined,
492            path: None,
493        }
494    }
495}
496
/// Access-log output format. Serialized in lowercase
/// (`"combined"`, `"json"`, `"off"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum AccessLogFormat {
    /// Apache "combined" format: `host - - [time] "method path proto" status bytes`
    Combined,
    /// Newline-delimited JSON, one object per request.
    Json,
    /// Off — only the framework's TraceLayer fires.
    Off,
}
507
508impl ServerConfig {
509    /// Load from `config/anvil.toml` if present, otherwise return defaults.
510    pub fn from_file_or_default(path: impl AsRef<std::path::Path>) -> Self {
511        match Self::from_file(path.as_ref()) {
512            Ok(c) => c,
513            Err(crate::Error::Io(e)) if e.kind() == std::io::ErrorKind::NotFound => Self::default(),
514            Err(e) => {
515                tracing::warn!(?e, path = %path.as_ref().display(), "failed to load server config; using defaults");
516                Self::default()
517            }
518        }
519    }
520
521    pub fn from_file(path: &std::path::Path) -> crate::Result<Self> {
522        let bytes = std::fs::read_to_string(path)?;
523        let cfg: Self = toml::from_str(&bytes)
524            .map_err(|e| crate::Error::Config(format!("toml parse {}: {e}", path.display())))?;
525        Ok(cfg.apply_env_overrides())
526    }
527
528    /// Apply env-var overrides for the most common keys, mirroring Laravel's
529    /// `config(...)` + `.env` precedence.
530    pub fn apply_env_overrides(mut self) -> Self {
531        if let Ok(v) = std::env::var("APP_ADDR") {
532            self.bind = v;
533        }
534        if let (Ok(cert), Ok(key)) = (std::env::var("TLS_CERT"), std::env::var("TLS_KEY")) {
535            self.tls = Some(TlsConfig {
536                cert: PathBuf::from(cert),
537                key: PathBuf::from(key),
538                acme: None,
539                additional_certs: Vec::new(),
540            });
541        }
542        self
543    }
544}
545
546// ─── Helpers: parse human-readable sizes / durations ────────────────────────
547
548fn deserialize_size<'de, D: Deserializer<'de>>(d: D) -> Result<u64, D::Error> {
549    use serde::de::Error;
550    let v = toml::Value::deserialize(d)?;
551    match v {
552        toml::Value::Integer(n) => Ok(n.max(0) as u64),
553        toml::Value::String(s) => parse_size(&s).map_err(D::Error::custom),
554        other => Err(D::Error::custom(format!(
555            "expected integer or size string, got {other:?}"
556        ))),
557    }
558}
559
560fn deserialize_duration<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
561    use serde::de::Error;
562    let v = toml::Value::deserialize(d)?;
563    match v {
564        toml::Value::Integer(n) => Ok(Duration::from_secs(n.max(0) as u64)),
565        toml::Value::String(s) => parse_duration(&s).map_err(D::Error::custom),
566        other => Err(D::Error::custom(format!(
567            "expected integer (seconds) or duration string, got {other:?}"
568        ))),
569    }
570}
571
572fn deserialize_opt_duration<'de, D: Deserializer<'de>>(d: D) -> Result<Option<Duration>, D::Error> {
573    use serde::de::Error;
574    let v = Option::<toml::Value>::deserialize(d)?;
575    match v {
576        None | Some(toml::Value::String(_)) if matches!(&v, Some(toml::Value::String(s)) if s.is_empty()) => {
577            Ok(None)
578        }
579        None => Ok(None),
580        Some(toml::Value::Integer(n)) => Ok(Some(Duration::from_secs(n.max(0) as u64))),
581        Some(toml::Value::String(s)) => parse_duration(&s).map(Some).map_err(D::Error::custom),
582        Some(other) => Err(D::Error::custom(format!(
583            "expected integer (seconds) or duration string, got {other:?}"
584        ))),
585    }
586}
587
588/// Parse `"10MB"`, `"500KB"`, `"2GB"`, or a bare integer (bytes).
589pub fn parse_size(s: &str) -> Result<u64, String> {
590    let s = s.trim();
591    if s.is_empty() {
592        return Err("empty size".into());
593    }
594    if let Ok(n) = s.parse::<u64>() {
595        return Ok(n);
596    }
597    let (num_part, unit_part) = split_num_unit(s);
598    let num: f64 = num_part
599        .parse()
600        .map_err(|e| format!("invalid size number `{num_part}`: {e}"))?;
601    let mult: u64 = match unit_part.trim().to_ascii_uppercase().as_str() {
602        "" | "B" => 1,
603        "K" | "KB" | "KIB" => 1024,
604        "M" | "MB" | "MIB" => 1024 * 1024,
605        "G" | "GB" | "GIB" => 1024 * 1024 * 1024,
606        other => return Err(format!("unknown size unit `{other}`")),
607    };
608    Ok((num * mult as f64) as u64)
609}
610
611/// Parse `"30s"`, `"5m"`, `"1h"`, `"1d"`, `"1y"`, or a bare integer (seconds).
612/// Bare unit strings like `"m"` (without a count) are interpreted as `"1m"` so
613/// rate-limit specs like `"5/m"` parse cleanly.
614pub fn parse_duration(s: &str) -> Result<Duration, String> {
615    let s = s.trim();
616    if s.is_empty() {
617        return Err("empty duration".into());
618    }
619    if let Ok(n) = s.parse::<u64>() {
620        return Ok(Duration::from_secs(n));
621    }
622    let (num_part, unit_part) = split_num_unit(s);
623    let num: u64 = if num_part.is_empty() {
624        1
625    } else {
626        num_part
627            .parse()
628            .map_err(|e| format!("invalid duration number `{num_part}`: {e}"))?
629    };
630    let secs: u64 = match unit_part.trim().to_ascii_lowercase().as_str() {
631        "s" | "sec" | "secs" | "second" | "seconds" => num,
632        "m" | "min" | "mins" | "minute" | "minutes" => num * 60,
633        "h" | "hr" | "hrs" | "hour" | "hours" => num * 3600,
634        "d" | "day" | "days" => num * 86400,
635        "w" | "wk" | "wks" | "week" | "weeks" => num * 86400 * 7,
636        "mo" | "month" | "months" => num * 86400 * 30,
637        "y" | "yr" | "yrs" | "year" | "years" => num * 86400 * 365,
638        other => return Err(format!("unknown duration unit `{other}`")),
639    };
640    Ok(Duration::from_secs(secs))
641}
642
/// Split `s` into its leading numeric run (ASCII digits, `.` and `-`) and the
/// remainder, trimming whitespace from both halves. Either half may be empty.
fn split_num_unit(s: &str) -> (&str, &str) {
    let is_numeric = |ch: char| ch.is_ascii_digit() || matches!(ch, '.' | '-');
    let boundary = s
        .char_indices()
        .find(|&(_, ch)| !is_numeric(ch))
        .map_or(s.len(), |(idx, _)| idx);
    let (number, unit) = s.split_at(boundary);
    (number.trim(), unit.trim())
}
649
650/// Parse `"60/minute"` → (count, window).
651pub fn parse_rate(s: &str) -> Result<(u32, Duration), String> {
652    let (count, window) = s
653        .split_once('/')
654        .ok_or_else(|| format!("rate must be `<count>/<window>`: {s}"))?;
655    let count: u32 = count
656        .trim()
657        .parse()
658        .map_err(|e| format!("invalid count `{count}`: {e}"))?;
659    let dur = parse_duration(window.trim())?;
660    Ok((count, dur))
661}
662
// Unit tests for the string parsers plus end-to-end TOML deserialization of
// `ServerConfig`. The TOML tests call `toml::from_str` directly, so env-var
// overrides are not involved.
#[cfg(test)]
mod tests {
    use super::*;

    // Bare integers, binary units, fractional values and a rejected input.
    #[test]
    fn parses_sizes() {
        assert_eq!(parse_size("10").unwrap(), 10);
        assert_eq!(parse_size("10KB").unwrap(), 10 * 1024);
        assert_eq!(parse_size("2MB").unwrap(), 2 * 1024 * 1024);
        assert_eq!(parse_size("1GB").unwrap(), 1024 * 1024 * 1024);
        assert_eq!(parse_size("1.5MB").unwrap(), (1.5 * 1024.0 * 1024.0) as u64);
        assert!(parse_size("bad").is_err());
    }

    // One sample per common unit, a bare integer (seconds) and a rejected input.
    #[test]
    fn parses_durations() {
        assert_eq!(parse_duration("30s").unwrap(), Duration::from_secs(30));
        assert_eq!(parse_duration("5m").unwrap(), Duration::from_secs(300));
        assert_eq!(parse_duration("1h").unwrap(), Duration::from_secs(3600));
        assert_eq!(parse_duration("1d").unwrap(), Duration::from_secs(86400));
        assert_eq!(
            parse_duration("1y").unwrap(),
            Duration::from_secs(86400 * 365)
        );
        assert_eq!(parse_duration("42").unwrap(), Duration::from_secs(42));
        assert!(parse_duration("bad").is_err());
    }

    // The window may be a full unit name or a bare unit letter without a count.
    #[test]
    fn parses_rates() {
        let (count, win) = parse_rate("60/minute").unwrap();
        assert_eq!(count, 60);
        assert_eq!(win, Duration::from_secs(60));
        let (count, win) = parse_rate("5/m").unwrap();
        assert_eq!(count, 5);
        assert_eq!(win, Duration::from_secs(60));
    }

    // Virtual hosts, TLS, HTTP redirect, HSTS, CORS, IP rules and basic auth.
    #[test]
    fn loads_vhost_and_security_toml() {
        let toml = r#"
            bind = "0.0.0.0:443"
            server_name = ["example.com", "www.example.com", "*.example.com"]

            [tls]
            cert = "/etc/cert.pem"
            key  = "/etc/key.pem"

            [redirect_http]
            bind = "0.0.0.0:80"
            permanent = true
            target_host = "example.com"

            [hsts]
            enabled = true
            max_age = "1y"
            include_subdomains = true
            preload = false

            [cors]
            enabled = true
            allow_origins = ["*"]
            allow_credentials = false
            max_age = "1h"

            [[ip_rules]]
            prefix = "/admin"
            action = "allow"
            ranges = ["10.0.0.0/8"]

            [[basic_auth]]
            prefix = "/admin"
            realm = "Admin"
            credentials = ["alice:secret", "bob:second"]
        "#;
        let cfg: ServerConfig = toml::from_str(toml).unwrap();
        assert_eq!(
            cfg.server_name,
            vec!["example.com", "www.example.com", "*.example.com"]
        );
        assert!(cfg.redirect_http.is_some());
        assert_eq!(
            cfg.redirect_http.as_ref().unwrap().target_host.as_deref(),
            Some("example.com")
        );
        assert!(cfg.hsts.enabled);
        assert_eq!(cfg.hsts.max_age, Some(Duration::from_secs(86400 * 365)));
        assert!(cfg.cors.enabled);
        assert_eq!(cfg.ip_rules.len(), 1);
        assert_eq!(cfg.basic_auth.len(), 1);
        assert_eq!(cfg.basic_auth[0].credentials.len(), 2);
    }

    // Rewrites (with and without a status), trailing-slash policy, error
    // pages keyed by status code, and the `[[proxy]]` → `proxies` rename.
    #[test]
    fn loads_rewrites_and_proxies_toml() {
        let toml = r#"
            [[rewrites]]
            from = "^/old/(.*)$"
            to = "/new/$1"
            status = 301

            [[rewrites]]
            from = "^/legacy/(.*)$"
            to = "/v2/$1"

            [trailing_slash]
            mode = "always"
            action = "redirect"

            [error_pages]
            404 = "errors/404.html"
            500 = "errors/500.html"

            [[proxy]]
            prefix = "/api/v2"
            upstream = "http://api-v2.internal:8080"
            strip_prefix = true
            timeout = "10s"
            retries = 3
        "#;
        let cfg: ServerConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.rewrites.len(), 2);
        assert_eq!(cfg.rewrites[0].status, Some(301));
        assert!(cfg.rewrites[1].status.is_none());
        assert_eq!(cfg.trailing_slash.mode, TrailingSlashMode::Always);
        assert_eq!(cfg.trailing_slash.action, TrailingSlashAction::Redirect);
        assert_eq!(cfg.error_pages.len(), 2);
        assert!(cfg.error_pages.contains_key("404"));
        assert_eq!(cfg.proxies.len(), 1);
        assert_eq!(cfg.proxies[0].upstream, "http://api-v2.internal:8080");
        assert_eq!(cfg.proxies[0].retries, 3);
        assert_eq!(cfg.proxies[0].timeout, Some(Duration::from_secs(10)));
    }

    // Limits, compression, static mounts, rate limits, trusted proxies and
    // the access log, including human-readable sizes and durations.
    #[test]
    fn loads_full_toml() {
        let toml = r#"
            bind = "0.0.0.0:443"

            [tls]
            cert = "/etc/letsencrypt/live/example.com/fullchain.pem"
            key = "/etc/letsencrypt/live/example.com/privkey.pem"

            [limits]
            body_max = "10MB"
            request_timeout = "30s"

            [compression]
            enabled = true
            algorithms = ["gzip", "br"]
            min_size = "1KB"

            [static_files."/assets"]
            dir = "public/build"
            cache = "1y"

            [rate_limit]
            per_ip = "60/minute"

            [rate_limit.routes]
            "POST /login" = "5/minute"

            [trusted_proxies]
            ranges = ["10.0.0.0/8", "127.0.0.1/32"]

            [access_log]
            format = "json"
            path = "storage/logs/access.log"
        "#;
        let cfg: ServerConfig = toml::from_str(toml).unwrap();
        assert_eq!(cfg.bind, "0.0.0.0:443");
        assert!(cfg.tls.is_some());
        assert_eq!(cfg.limits.body_max, 10 * 1024 * 1024);
        assert_eq!(cfg.limits.request_timeout, Some(Duration::from_secs(30)));
        assert!(cfg.compression.enabled);
        assert_eq!(cfg.compression.algorithms, vec!["gzip", "br"]);
        assert_eq!(cfg.compression.min_size, 1024);
        assert!(cfg.static_files.contains_key("/assets"));
        assert_eq!(
            cfg.static_files["/assets"].cache,
            Some(Duration::from_secs(86400 * 365))
        );
        assert_eq!(cfg.rate_limit.per_ip.as_deref(), Some("60/minute"));
        assert_eq!(
            cfg.rate_limit.routes.get("POST /login").map(String::as_str),
            Some("5/minute")
        );
        assert_eq!(cfg.trusted_proxies.ranges.len(), 2);
        assert_eq!(cfg.access_log.format, AccessLogFormat::Json);
    }
}
853}