use std::collections::BTreeMap;
use std::path::PathBuf;
use std::time::Duration;
use serde::{Deserialize, Deserializer, Serialize};
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default, deny_unknown_fields)]
pub struct ServerConfig {
#[serde(default = "default_bind")]
pub bind: String,
#[serde(default)]
pub server_name: Vec<String>,
pub tls: Option<TlsConfig>,
pub redirect_http: Option<RedirectHttpConfig>,
#[serde(default)]
pub hsts: HstsConfig,
#[serde(default)]
pub limits: LimitsConfig,
#[serde(default)]
pub compression: CompressionConfig,
#[serde(default)]
pub static_files: BTreeMap<String, StaticMount>,
#[serde(default)]
pub rate_limit: RateLimitConfig,
#[serde(default)]
pub trusted_proxies: TrustedProxiesConfig,
#[serde(default, rename = "route_timeout")]
pub route_timeouts: Vec<RouteTimeoutRule>,
#[serde(default)]
pub access_log: AccessLogConfig,
#[serde(default)]
pub rewrites: Vec<RewriteRule>,
#[serde(default)]
pub error_pages: BTreeMap<String, std::path::PathBuf>,
#[serde(default)]
pub trailing_slash: TrailingSlashConfig,
#[serde(default, rename = "proxy")]
pub proxies: Vec<ProxyRule>,
#[serde(default)]
pub cors: CorsConfig,
#[serde(default)]
pub ip_rules: Vec<IpRule>,
#[serde(default, rename = "basic_auth")]
pub basic_auth: Vec<BasicAuthRule>,
}
/// Listen address used when `bind` is not configured.
fn default_bind() -> String {
    String::from("127.0.0.1:8080")
}
/// TLS settings for the main listener (the `[tls]` table).
///
/// Unknown keys are rejected, as in every other nested section of the config, so a
/// misspelled key (e.g. `acem`) fails loudly instead of being silently ignored.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct TlsConfig {
    /// Path to the certificate file.
    pub cert: PathBuf,
    /// Path to the private key matching `cert`.
    pub key: PathBuf,
    /// Optional ACME settings (`[tls.acme]`); `None` when absent.
    #[serde(default)]
    pub acme: Option<AcmeConfig>,
    /// Extra certificates keyed by server name, written as `[[tls.certs]]` in TOML.
    #[serde(default, rename = "certs")]
    pub additional_certs: Vec<SniCertEntry>,
}
/// ACME (automatic certificate issuance) settings nested under `[tls.acme]`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct AcmeConfig {
    /// Domain names to request certificates for (required).
    pub domains: Vec<String>,
    /// Optional account contact; the format is not validated here
    /// (presumably an e-mail / `mailto:` URI — TODO confirm).
    #[serde(default)]
    pub contact: Option<String>,
    /// Directory for cached ACME state; defaults to `./database/acme-cache`.
    #[serde(default = "default_acme_cache")]
    pub cache_dir: PathBuf,
    /// ACME directory URL; defaults to Let's Encrypt's production v2 endpoint.
    #[serde(default = "default_acme_directory")]
    pub directory: String,
}
/// Default on-disk location for cached ACME state (relative to the working directory).
fn default_acme_cache() -> PathBuf {
    "./database/acme-cache".into()
}
/// Default ACME directory: the Let's Encrypt production v2 endpoint.
fn default_acme_directory() -> String {
    const LETS_ENCRYPT_PRODUCTION: &str = "https://acme-v02.api.letsencrypt.org/directory";
    LETS_ENCRYPT_PRODUCTION.to_owned()
}
/// An additional certificate/key pair bound to one server name,
/// written as a `[[tls.certs]]` entry.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct SniCertEntry {
    /// Server name this pair belongs to (presumably matched against the TLS SNI value).
    pub server_name: String,
    /// Certificate file path.
    pub cert: PathBuf,
    /// Private-key file path.
    pub key: PathBuf,
}
/// Secondary plain-HTTP listener settings (`[redirect_http]`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct RedirectHttpConfig {
    /// Address of the plain-HTTP listener (required), e.g. `0.0.0.0:80`.
    pub bind: String,
    /// Whether the redirect is permanent; defaults to `true`.
    #[serde(default = "yes")]
    pub permanent: bool,
    /// Optional host to redirect to; when `None` the caller decides
    /// (presumably the request's own host — TODO confirm).
    pub target_host: Option<String>,
}
/// `Strict-Transport-Security` settings (`[hsts]`); all flags off and no max-age by default.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
#[serde(deny_unknown_fields)]
pub struct HstsConfig {
    pub enabled: bool,
    /// Integer seconds or a duration string such as `"1y"`; an empty string means unset.
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_opt_duration")]
    pub max_age: Option<Duration>,
    pub include_subdomains: bool,
    pub preload: bool,
}
/// Request limits and timeouts (`[limits]`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(default)]
#[serde(deny_unknown_fields)]
pub struct LimitsConfig {
    /// Maximum request body size in bytes: an integer or a size string such as
    /// `"10MB"`. Defaults to 16 MiB.
    #[serde(default = "default_body_max")]
    #[serde(deserialize_with = "deserialize_size")]
    pub body_max: u64,
    /// Optional per-request timeout (integer seconds or duration string).
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_opt_duration")]
    pub request_timeout: Option<Duration>,
    /// Drain period; defaults to 10 seconds.
    #[serde(default = "default_drain_timeout")]
    #[serde(deserialize_with = "deserialize_duration")]
    pub drain_timeout: Duration,
    /// Optional concurrency cap; `None` means none is configured here.
    pub max_concurrency: Option<u32>,
}
impl Default for LimitsConfig {
    /// Mirrors the per-field serde defaults, so an absent `[limits]` table and an
    /// empty one yield the same values.
    fn default() -> Self {
        LimitsConfig {
            max_concurrency: None,
            request_timeout: None,
            body_max: default_body_max(),
            drain_timeout: default_drain_timeout(),
        }
    }
}
/// Default `drain_timeout`: ten seconds (presumably the graceful-shutdown window).
fn default_drain_timeout() -> Duration {
    Duration::new(10, 0)
}
/// Default `body_max`: 16 MiB.
fn default_body_max() -> u64 {
    16 << 20
}
/// Response compression settings (`[compression]`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(default)]
#[serde(deny_unknown_fields)]
pub struct CompressionConfig {
    /// Off by default.
    pub enabled: bool,
    /// Algorithm names, e.g. `["gzip", "br"]`; defaults to `["gzip"]`.
    pub algorithms: Vec<String>,
    /// Minimum size in bytes (integer or size string such as `"1KB"`); defaults to 1024.
    #[serde(default = "default_min_size")]
    #[serde(deserialize_with = "deserialize_size")]
    pub min_size: u64,
}
impl Default for CompressionConfig {
    fn default() -> Self {
        let algorithms = vec![String::from("gzip")];
        CompressionConfig {
            algorithms,
            min_size: default_min_size(),
            enabled: false,
        }
    }
}
/// Default compression threshold: 1 KiB.
fn default_min_size() -> u64 {
    1 << 10
}
/// A directory served under a URL prefix (the key of `[static_files."<prefix>"]`).
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct StaticMount {
    /// Directory on disk to serve from.
    pub dir: PathBuf,
    /// Optional cache lifetime (integer seconds or duration string such as `"1y"`).
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_opt_duration")]
    pub cache: Option<Duration>,
    /// Whether range requests are enabled; defaults to `true`.
    #[serde(default = "yes")]
    pub ranges: bool,
}
/// Serde `default = "yes"` helper: boolean fields that should default to `true`.
fn yes() -> bool {
    true
}
/// Rate-limit settings (`[rate_limit]`). Rates are stored as raw strings and are
/// not validated here (presumably parsed later with `parse_rate` — TODO confirm).
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
#[serde(deny_unknown_fields)]
pub struct RateLimitConfig {
    /// Per-client-IP rate such as `"60/minute"`.
    pub per_ip: Option<String>,
    /// Per-route rates keyed like `"POST /login"`, same value syntax as `per_ip`.
    #[serde(default)]
    pub routes: BTreeMap<String, String>,
}
/// Networks treated as trusted proxies (`[trusted_proxies]`), e.g. `["10.0.0.0/8"]`.
/// Empty by default.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
#[serde(deny_unknown_fields)]
pub struct TrustedProxiesConfig {
    pub ranges: Vec<ipnet::IpNet>,
}
/// One `[[route_timeout]]` entry: a timeout presumably applied to paths under `prefix`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct RouteTimeoutRule {
    pub prefix: String,
    /// Required; integer seconds or a duration string such as `"30s"`.
    #[serde(deserialize_with = "deserialize_duration")]
    pub timeout: Duration,
}
/// One `[[rewrites]]` entry.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct RewriteRule {
    /// Source pattern; the fixtures use regex syntax such as `^/old/(.*)$`.
    pub from: String,
    /// Replacement, which may reference captures (e.g. `/new/$1`).
    pub to: String,
    /// Optional status code (e.g. `301`); `None` when omitted.
    #[serde(default)]
    pub status: Option<u16>,
    /// Whether the query string takes part in matching; defaults to `false`.
    #[serde(default)]
    pub match_query: bool,
}
/// Trailing-slash policy (`[trailing_slash]`); defaults to `ignore` + `redirect`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(default)]
#[serde(deny_unknown_fields)]
pub struct TrailingSlashConfig {
    pub mode: TrailingSlashMode,
    pub action: TrailingSlashAction,
}
impl Default for TrailingSlashConfig {
    fn default() -> Self {
        TrailingSlashConfig {
            action: TrailingSlashAction::Redirect,
            mode: TrailingSlashMode::Ignore,
        }
    }
}
/// Trailing-slash mode, spelled `"always"`, `"never"` or `"ignore"` in TOML.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum TrailingSlashMode {
    /// `"always"`
    Always,
    /// `"never"`
    Never,
    /// `"ignore"` (the default in `TrailingSlashConfig`)
    Ignore,
}
/// How a trailing-slash mismatch is resolved, spelled `"redirect"` or `"rewrite"` in TOML.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum TrailingSlashAction {
    /// `"redirect"`
    Redirect,
    /// `"rewrite"`
    Rewrite,
}
/// CORS settings (`[cors]`); disabled with empty lists by default.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[serde(default)]
#[serde(deny_unknown_fields)]
pub struct CorsConfig {
    pub enabled: bool,
    /// Allowed origins; the fixtures use `["*"]`.
    pub allow_origins: Vec<String>,
    pub allow_methods: Vec<String>,
    pub allow_headers: Vec<String>,
    pub expose_headers: Vec<String>,
    pub allow_credentials: bool,
    /// Optional preflight max-age (integer seconds or duration string such as `"1h"`).
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_opt_duration")]
    pub max_age: Option<Duration>,
}
/// One `[[ip_rules]]` entry: allow or deny the listed networks, presumably for
/// paths under `prefix`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct IpRule {
    pub prefix: String,
    pub action: IpAction,
    /// CIDR networks, e.g. `["10.0.0.0/8"]`.
    pub ranges: Vec<ipnet::IpNet>,
}
/// Outcome of an `IpRule`, spelled `"allow"` or `"deny"` in TOML.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum IpAction {
    /// `"allow"`
    Allow,
    /// `"deny"`
    Deny,
}
/// One `[[basic_auth]]` entry guarding paths under `prefix`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct BasicAuthRule {
    pub prefix: String,
    /// Realm name; defaults to `"Restricted"`.
    #[serde(default = "default_realm")]
    pub realm: String,
    /// `user:password` entries as in the fixtures (`"alice:secret"`).
    /// NOTE(review): these look like plaintext passwords — confirm whether hashes are expected.
    pub credentials: Vec<String>,
}
/// Realm used when a basic-auth rule does not name one.
fn default_realm() -> String {
    "Restricted".to_owned()
}
/// One `[[proxy]]` entry: requests under `prefix` go to `upstream`.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct ProxyRule {
    pub prefix: String,
    /// Upstream base URL, e.g. `http://api-v2.internal:8080`.
    pub upstream: String,
    /// Presumably removes `prefix` from the forwarded path; defaults to `false`.
    #[serde(default)]
    pub strip_prefix: bool,
    /// Presumably forwards the client's `Host` header; defaults to `false`.
    #[serde(default)]
    pub preserve_host: bool,
    /// Optional timeout (integer seconds or duration string such as `"10s"`).
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_opt_duration")]
    pub timeout: Option<Duration>,
    /// Retry count; defaults to 0.
    #[serde(default)]
    pub retries: u8,
}
/// Access-log settings (`[access_log]`); defaults to the `combined` format with no path.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(default)]
#[serde(deny_unknown_fields)]
pub struct AccessLogConfig {
    pub format: AccessLogFormat,
    /// Optional log file path; `None` leaves the destination to the caller.
    pub path: Option<PathBuf>,
}
impl Default for AccessLogConfig {
    fn default() -> Self {
        AccessLogConfig {
            path: None,
            format: AccessLogFormat::Combined,
        }
    }
}
/// Access-log format, spelled `"combined"`, `"json"` or `"off"` in TOML.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum AccessLogFormat {
    /// `"combined"` (the default)
    Combined,
    /// `"json"`
    Json,
    /// `"off"`
    Off,
}
impl ServerConfig {
pub fn from_file_or_default(path: impl AsRef<std::path::Path>) -> Self {
match Self::from_file(path.as_ref()) {
Ok(c) => c,
Err(crate::Error::Io(e)) if e.kind() == std::io::ErrorKind::NotFound => Self::default(),
Err(e) => {
tracing::warn!(?e, path = %path.as_ref().display(), "failed to load server config; using defaults");
Self::default()
}
}
}
pub fn from_file(path: &std::path::Path) -> crate::Result<Self> {
let bytes = std::fs::read_to_string(path)?;
let cfg: Self = toml::from_str(&bytes)
.map_err(|e| crate::Error::Config(format!("toml parse {}: {e}", path.display())))?;
Ok(cfg.apply_env_overrides())
}
pub fn apply_env_overrides(mut self) -> Self {
if let Ok(v) = std::env::var("APP_ADDR") {
self.bind = v;
}
if let (Ok(cert), Ok(key)) = (std::env::var("TLS_CERT"), std::env::var("TLS_KEY")) {
self.tls = Some(TlsConfig {
cert: PathBuf::from(cert),
key: PathBuf::from(key),
acme: None,
additional_certs: Vec::new(),
});
}
self
}
}
/// Deserializes a byte size from either a TOML integer or a [`parse_size`] string
/// (e.g. `"10MB"`).
///
/// Negative integers are rejected instead of being silently clamped to 0.
fn deserialize_size<'de, D: Deserializer<'de>>(d: D) -> Result<u64, D::Error> {
    use serde::de::Error;
    match toml::Value::deserialize(d)? {
        toml::Value::Integer(n) => u64::try_from(n)
            .map_err(|_| D::Error::custom(format!("size must not be negative, got {n}"))),
        toml::Value::String(s) => parse_size(&s).map_err(D::Error::custom),
        other => Err(D::Error::custom(format!(
            "expected integer or size string, got {other:?}"
        ))),
    }
}
/// Deserializes a required duration from TOML integer seconds or a
/// [`parse_duration`] string (e.g. `"30s"`).
///
/// Negative integers are rejected instead of being silently clamped to zero.
fn deserialize_duration<'de, D: Deserializer<'de>>(d: D) -> Result<Duration, D::Error> {
    use serde::de::Error;
    match toml::Value::deserialize(d)? {
        toml::Value::Integer(n) => u64::try_from(n)
            .map(Duration::from_secs)
            .map_err(|_| D::Error::custom(format!("duration must not be negative, got {n}"))),
        toml::Value::String(s) => parse_duration(&s).map_err(D::Error::custom),
        other => Err(D::Error::custom(format!(
            "expected integer (seconds) or duration string, got {other:?}"
        ))),
    }
}
/// Deserializes an optional duration from TOML integer seconds or a
/// [`parse_duration`] string.
///
/// An empty string (`""`) means "unset" and yields `None`. Negative integers are
/// rejected instead of being silently clamped to zero.
fn deserialize_opt_duration<'de, D: Deserializer<'de>>(d: D) -> Result<Option<Duration>, D::Error> {
    use serde::de::Error;
    match Option::<toml::Value>::deserialize(d)? {
        None => Ok(None),
        Some(toml::Value::String(s)) if s.is_empty() => Ok(None),
        Some(toml::Value::String(s)) => parse_duration(&s).map(Some).map_err(D::Error::custom),
        Some(toml::Value::Integer(n)) => u64::try_from(n)
            .map(|secs| Some(Duration::from_secs(secs)))
            .map_err(|_| D::Error::custom(format!("duration must not be negative, got {n}"))),
        Some(other) => Err(D::Error::custom(format!(
            "expected integer (seconds) or duration string, got {other:?}"
        ))),
    }
}
/// Parses a human-readable byte size such as `"512"`, `"10KB"`, `"1.5 MiB"` or `"2TB"`.
///
/// Units are case-insensitive binary multiples: `K`/`KB`/`KiB` = 1024, `M…` = 1024²,
/// `G…` = 1024³, `T…` = 1024⁴. A bare number or `B` means bytes. Fractional values
/// are allowed and truncated toward zero.
///
/// # Errors
/// Returns a message for empty input, an unparsable number, an unknown unit, a
/// negative value, or a result too large for `u64`.
pub fn parse_size(s: &str) -> Result<u64, String> {
    let s = s.trim();
    if s.is_empty() {
        return Err("empty size".into());
    }
    // Fast path: a plain integer byte count needs no float arithmetic.
    if let Ok(n) = s.parse::<u64>() {
        return Ok(n);
    }
    let (num_part, unit_part) = split_num_unit(s);
    let num: f64 = num_part
        .parse()
        .map_err(|e| format!("invalid size number `{num_part}`: {e}"))?;
    // `f64 as u64` saturates, so without this check "-5MB" silently became 0.
    if num < 0.0 {
        return Err(format!("size must not be negative: `{s}`"));
    }
    let mult: u64 = match unit_part.trim().to_ascii_uppercase().as_str() {
        "" | "B" => 1,
        "K" | "KB" | "KIB" => 1024,
        "M" | "MB" | "MIB" => 1024 * 1024,
        "G" | "GB" | "GIB" => 1024 * 1024 * 1024,
        "T" | "TB" | "TIB" => 1024 * 1024 * 1024 * 1024,
        other => return Err(format!("unknown size unit `{other}`")),
    };
    let bytes = num * mult as f64;
    // `u64::MAX as f64` is 2^64; anything at or above it would saturate to u64::MAX.
    if bytes >= u64::MAX as f64 {
        return Err(format!("size `{s}` is too large"));
    }
    Ok(bytes as u64)
}
/// Parses a human-readable duration such as `"30s"`, `"5 min"`, `"250ms"` or `"1y"`.
///
/// A bare integer means seconds. A bare unit (e.g. `"minute"`) means one of that unit.
/// Supported units (case-insensitive):
/// - `ms` (milliseconds), `s`, `m` (minutes), `h`, `d`, `w`
/// - `mo` (30 days), `y` (365 days)
///
/// Only whole numbers are accepted.
///
/// # Errors
/// Returns a message for empty input, a non-integer or negative number, an unknown
/// unit, or a value whose length in seconds overflows `u64`.
pub fn parse_duration(s: &str) -> Result<Duration, String> {
    let s = s.trim();
    if s.is_empty() {
        return Err("empty duration".into());
    }
    if let Ok(n) = s.parse::<u64>() {
        return Ok(Duration::from_secs(n));
    }
    let (num_part, unit_part) = split_num_unit(s);
    let num: u64 = if num_part.is_empty() {
        1
    } else {
        num_part
            .parse()
            .map_err(|e| format!("invalid duration number `{num_part}`: {e}"))?
    };
    let unit = unit_part.trim().to_ascii_lowercase();
    if matches!(
        unit.as_str(),
        "ms" | "msec" | "msecs" | "millisecond" | "milliseconds"
    ) {
        return Ok(Duration::from_millis(num));
    }
    let secs_per_unit: u64 = match unit.as_str() {
        "s" | "sec" | "secs" | "second" | "seconds" => 1,
        "m" | "min" | "mins" | "minute" | "minutes" => 60,
        "h" | "hr" | "hrs" | "hour" | "hours" => 3600,
        "d" | "day" | "days" => 86400,
        "w" | "wk" | "wks" | "week" | "weeks" => 86400 * 7,
        "mo" | "month" | "months" => 86400 * 30,
        "y" | "yr" | "yrs" | "year" | "years" => 86400 * 365,
        other => return Err(format!("unknown duration unit `{other}`")),
    };
    // Plain multiplication panicked in debug builds (and wrapped in release) on huge inputs.
    num.checked_mul(secs_per_unit)
        .map(Duration::from_secs)
        .ok_or_else(|| format!("duration `{s}` is too large"))
}
/// Splits `s` at the first character that is not an ASCII digit, `.` or `-`,
/// returning the trimmed numeric prefix and the trimmed unit suffix.
///
/// Either half may be empty (e.g. `"minute"` → `("", "minute")`, `"42"` → `("42", "")`).
fn split_num_unit(s: &str) -> (&str, &str) {
    let is_numeric = |c: char| c.is_ascii_digit() || c == '.' || c == '-';
    let boundary = s
        .char_indices()
        .find(|&(_, c)| !is_numeric(c))
        .map_or(s.len(), |(idx, _)| idx);
    let (number, unit) = s.split_at(boundary);
    (number.trim(), unit.trim())
}
/// Parses a rate of the form `<count>/<window>`, e.g. `"60/minute"` or `"5/m"`.
///
/// `<window>` uses [`parse_duration`] syntax, so a bare unit means one of that unit.
///
/// # Errors
/// Returns a message if the `/` separator is missing, the count is not a `u32`,
/// the window does not parse, or the window is zero. A rate over a zero-length
/// window is meaningless and was previously accepted silently.
pub fn parse_rate(s: &str) -> Result<(u32, Duration), String> {
    let (count_str, window_str) = s
        .split_once('/')
        .ok_or_else(|| format!("rate must be `<count>/<window>`: {s}"))?;
    let count: u32 = count_str
        .trim()
        .parse()
        .map_err(|e| format!("invalid count `{count_str}`: {e}"))?;
    let window = parse_duration(window_str.trim())?;
    if window.is_zero() {
        return Err(format!("rate window must be greater than zero: {s}"));
    }
    Ok((count, window))
}
// Unit tests for the size/duration/rate parsers and for deserializing
// representative TOML documents into `ServerConfig`.
#[cfg(test)]
mod tests {
use super::*;
// Byte sizes: plain integers, binary (1024-based) unit suffixes, and a fractional
// value truncated via `as u64`; an unparsable string must error.
#[test]
fn parses_sizes() {
assert_eq!(parse_size("10").unwrap(), 10);
assert_eq!(parse_size("10KB").unwrap(), 10 * 1024);
assert_eq!(parse_size("2MB").unwrap(), 2 * 1024 * 1024);
assert_eq!(parse_size("1GB").unwrap(), 1024 * 1024 * 1024);
assert_eq!(parse_size("1.5MB").unwrap(), (1.5 * 1024.0 * 1024.0) as u64);
assert!(parse_size("bad").is_err());
}
// Durations: single-letter units (s/m/h/d/y, where `y` is 365 days) and a bare
// integer meaning seconds; an unknown unit must error.
#[test]
fn parses_durations() {
assert_eq!(parse_duration("30s").unwrap(), Duration::from_secs(30));
assert_eq!(parse_duration("5m").unwrap(), Duration::from_secs(300));
assert_eq!(parse_duration("1h").unwrap(), Duration::from_secs(3600));
assert_eq!(parse_duration("1d").unwrap(), Duration::from_secs(86400));
assert_eq!(
parse_duration("1y").unwrap(),
Duration::from_secs(86400 * 365)
);
assert_eq!(parse_duration("42").unwrap(), Duration::from_secs(42));
assert!(parse_duration("bad").is_err());
}
// Rates: `<count>/<window>` where the window is a bare unit, in both the long
// (`minute`) and short (`m`) spellings.
#[test]
fn parses_rates() {
let (count, win) = parse_rate("60/minute").unwrap();
assert_eq!(count, 60);
assert_eq!(win, Duration::from_secs(60));
let (count, win) = parse_rate("5/m").unwrap();
assert_eq!(count, 5);
assert_eq!(win, Duration::from_secs(60));
}
// Virtual-host and security sections: server names, HTTP redirect listener,
// HSTS/CORS durations given as strings, and the `[[ip_rules]]` / `[[basic_auth]]`
// arrays of tables. The fixture text below is raw TOML and must stay unindented.
#[test]
fn loads_vhost_and_security_toml() {
let toml = r#"
bind = "0.0.0.0:443"
server_name = ["example.com", "www.example.com", "*.example.com"]
[tls]
cert = "/etc/cert.pem"
key = "/etc/key.pem"
[redirect_http]
bind = "0.0.0.0:80"
permanent = true
target_host = "example.com"
[hsts]
enabled = true
max_age = "1y"
include_subdomains = true
preload = false
[cors]
enabled = true
allow_origins = ["*"]
allow_credentials = false
max_age = "1h"
[[ip_rules]]
prefix = "/admin"
action = "allow"
ranges = ["10.0.0.0/8"]
[[basic_auth]]
prefix = "/admin"
realm = "Admin"
credentials = ["alice:secret", "bob:second"]
"#;
// The local `toml` string lives in the value namespace; `toml::from_str` still
// resolves to the crate.
let cfg: ServerConfig = toml::from_str(toml).unwrap();
assert_eq!(
cfg.server_name,
vec!["example.com", "www.example.com", "*.example.com"]
);
assert!(cfg.redirect_http.is_some());
assert_eq!(
cfg.redirect_http.as_ref().unwrap().target_host.as_deref(),
Some("example.com")
);
assert!(cfg.hsts.enabled);
assert_eq!(cfg.hsts.max_age, Some(Duration::from_secs(86400 * 365)));
assert!(cfg.cors.enabled);
assert_eq!(cfg.ip_rules.len(), 1);
assert_eq!(cfg.basic_auth.len(), 1);
assert_eq!(cfg.basic_auth[0].credentials.len(), 2);
}
// Rewrites (with and without an explicit status), trailing-slash policy,
// numeric-keyed error pages, and a `[[proxy]]` rule (which maps to `proxies`).
#[test]
fn loads_rewrites_and_proxies_toml() {
let toml = r#"
[[rewrites]]
from = "^/old/(.*)$"
to = "/new/$1"
status = 301
[[rewrites]]
from = "^/legacy/(.*)$"
to = "/v2/$1"
[trailing_slash]
mode = "always"
action = "redirect"
[error_pages]
404 = "errors/404.html"
500 = "errors/500.html"
[[proxy]]
prefix = "/api/v2"
upstream = "http://api-v2.internal:8080"
strip_prefix = true
timeout = "10s"
retries = 3
"#;
let cfg: ServerConfig = toml::from_str(toml).unwrap();
assert_eq!(cfg.rewrites.len(), 2);
assert_eq!(cfg.rewrites[0].status, Some(301));
assert!(cfg.rewrites[1].status.is_none());
assert_eq!(cfg.trailing_slash.mode, TrailingSlashMode::Always);
assert_eq!(cfg.trailing_slash.action, TrailingSlashAction::Redirect);
assert_eq!(cfg.error_pages.len(), 2);
// Bare numeric TOML keys deserialize as string map keys.
assert!(cfg.error_pages.contains_key("404"));
assert_eq!(cfg.proxies.len(), 1);
assert_eq!(cfg.proxies[0].upstream, "http://api-v2.internal:8080");
assert_eq!(cfg.proxies[0].retries, 3);
assert_eq!(cfg.proxies[0].timeout, Some(Duration::from_secs(10)));
}
// Broad end-to-end document: TLS, limits and compression with size/duration
// strings, a quoted-key static mount, rate limits, trusted proxy CIDRs, and
// JSON access logging.
#[test]
fn loads_full_toml() {
let toml = r#"
bind = "0.0.0.0:443"
[tls]
cert = "/etc/letsencrypt/live/example.com/fullchain.pem"
key = "/etc/letsencrypt/live/example.com/privkey.pem"
[limits]
body_max = "10MB"
request_timeout = "30s"
[compression]
enabled = true
algorithms = ["gzip", "br"]
min_size = "1KB"
[static_files."/assets"]
dir = "public/build"
cache = "1y"
[rate_limit]
per_ip = "60/minute"
[rate_limit.routes]
"POST /login" = "5/minute"
[trusted_proxies]
ranges = ["10.0.0.0/8", "127.0.0.1/32"]
[access_log]
format = "json"
path = "storage/logs/access.log"
"#;
let cfg: ServerConfig = toml::from_str(toml).unwrap();
assert_eq!(cfg.bind, "0.0.0.0:443");
assert!(cfg.tls.is_some());
assert_eq!(cfg.limits.body_max, 10 * 1024 * 1024);
assert_eq!(cfg.limits.request_timeout, Some(Duration::from_secs(30)));
assert!(cfg.compression.enabled);
assert_eq!(cfg.compression.algorithms, vec!["gzip", "br"]);
assert_eq!(cfg.compression.min_size, 1024);
assert!(cfg.static_files.contains_key("/assets"));
assert_eq!(
cfg.static_files["/assets"].cache,
Some(Duration::from_secs(86400 * 365))
);
// Rate strings are stored verbatim; they are not parsed at deserialization time.
assert_eq!(cfg.rate_limit.per_ip.as_deref(), Some("60/minute"));
assert_eq!(
cfg.rate_limit.routes.get("POST /login").map(String::as_str),
Some("5/minute")
);
assert_eq!(cfg.trusted_proxies.ranges.len(), 2);
assert_eq!(cfg.access_log.format, AccessLogFormat::Json);
}
}