use std::path::PathBuf;
use serde::Deserialize;
use crate::error::{Error, Result};
/// A config value that may be written either as a single string or as an array of strings.
///
/// Deserialized untagged: a TOML string becomes [`OneOrMany::One`], an array of strings
/// becomes [`OneOrMany::Many`].
#[derive(Clone, Debug, Deserialize)]
#[serde(untagged)]
enum OneOrMany {
    /// A lone value, e.g. `listen = "127.0.0.1:8080"`.
    One(String),
    /// A list of values, e.g. `listen = ["127.0.0.1:8443", "[::1]:8443"]`.
    Many(Vec<String>),
}
impl OneOrMany {
fn into_vec(self) -> Vec<String> {
match self {
OneOrMany::One(s) => vec![s],
OneOrMany::Many(v) => v,
}
}
}
/// The `[tls]` table.
///
/// Either `cert` together with `key`, or `self_signed`, is expected. When `cert` and `key` are
/// both set they take precedence over `self_signed`.
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct TlsConfig {
    /// Certificate path, handed to `TlsAcceptor::from_pem_files` together with `key`
    /// (presumably a PEM file — confirm against `crate::tls`).
    pub cert: Option<String>,
    /// Private-key path paired with `cert`.
    pub key: Option<String>,
    /// Names handed to `TlsAcceptor::self_signed`; presumably the host names a self-signed
    /// certificate is generated for — confirm against `crate::tls`.
    pub self_signed: Option<Vec<String>>,
}
/// The `[compress]` table. Both keys may be omitted.
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct CompressConfig {
    /// Whether compression is turned on. Defaults to `true` when the key is absent.
    #[serde(default = "yes")]
    pub enabled: bool,
    /// Forwarded to `compress::Options::min_size`. Defaults to 256 when absent. Presumably the
    /// smallest body size, in bytes, worth compressing — confirm in `crate::compress`.
    #[serde(default = "default_min_size")]
    pub min_size: usize,
}
/// Serde default helper for boolean fields that stay on unless explicitly disabled.
fn yes() -> bool {
    true
}
/// Serde default for `CompressConfig::min_size` when the key is omitted.
fn default_min_size() -> usize {
    256
}
impl Default for CompressConfig {
    /// Mirrors what serde produces for an empty `[compress]` table: compression enabled, with
    /// the default minimum size.
    fn default() -> Self {
        Self {
            min_size: default_min_size(),
            enabled: true,
        }
    }
}
/// The `[acme]` table. `ServerConfig::into_server` turns it into `crate::acme::AcmeConfig` when
/// the `acme` feature is enabled, and rejects it otherwise.
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct AcmeFileConfig {
    /// Forwarded to `AcmeConfig::accept_tos`. Defaults to `false` when omitted.
    #[serde(default)]
    pub accept_tos: bool,
    /// Contact address forwarded to `AcmeConfig::email`.
    pub email: Option<String>,
    /// ACME directory URL. Falls back to `LETSENCRYPT_PRODUCTION` when unset, and is ignored
    /// entirely when `staging` is `true`.
    pub directory: Option<String>,
    /// Use `LETSENCRYPT_STAGING` instead of `directory`. Defaults to `false` when omitted.
    #[serde(default)]
    pub staging: bool,
    /// Each entry is trimmed, stripped of trailing dots and lowercased before being forwarded
    /// to `AcmeConfig::host_whitelist`. Presumably the hosts certificates may be requested
    /// for — confirm in `crate::acme`.
    pub host_whitelist: Option<Vec<String>>,
    /// Forwarded to `AcmeConfig::cert_dir`; presumably where certificates are persisted —
    /// confirm in `crate::acme`.
    pub cert_dir: Option<PathBuf>,
}
/// The `[hsts]` table, rendered by [`HstsConfig::header_value`].
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct HstsConfig {
    /// The `max-age` directive in seconds. Defaults to one year (31 536 000) when omitted.
    #[serde(default = "default_hsts_max_age")]
    pub max_age: u64,
    /// Emit the `includeSubDomains` directive. Defaults to `false`.
    #[serde(default)]
    pub include_subdomains: bool,
    /// Emit the `preload` directive. Defaults to `false`.
    #[serde(default)]
    pub preload: bool,
}
/// Serde default for `HstsConfig::max_age`: one 365-day year, expressed in seconds.
fn default_hsts_max_age() -> u64 {
    365 * 24 * 60 * 60
}
impl HstsConfig {
    /// Renders the HSTS (`Strict-Transport-Security`) header value, e.g.
    /// `max-age=31536000; includeSubDomains; preload`.
    ///
    /// Directives always appear in the order `max-age`, `includeSubDomains`, `preload`. Each
    /// flag is emitted only when enabled, and directives are separated by `"; "`.
    pub fn header_value(&self) -> String {
        let max_age = format!("max-age={}", self.max_age);
        let flags = [
            (self.include_subdomains, "includeSubDomains"),
            (self.preload, "preload"),
        ];
        std::iter::once(max_age.as_str())
            .chain(flags.iter().filter(|(on, _)| *on).map(|(_, name)| *name))
            .collect::<Vec<_>>()
            .join("; ")
    }
}
/// Top-level server configuration, typically loaded with [`ServerConfig::from_file`] or
/// [`ServerConfig::from_toml_str`]. Unknown keys are rejected.
#[derive(Clone, Debug, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ServerConfig {
    /// Required. One address string or an array of them; read via
    /// [`ServerConfig::listen_addrs`]. `into_server` currently binds only the first entry.
    listen: OneOrMany,
    /// Directory handed to `Server::serve_dir`.
    pub root: Option<PathBuf>,
    /// Handed to `Server::server_name` when set.
    pub server_name: Option<String>,
    /// Handed to `Server::workers` when set.
    pub workers: Option<usize>,
    /// Optional `[tls]` table.
    pub tls: Option<TlsConfig>,
    /// Optional `[compress]` table. When the `compress` feature is enabled and this is absent,
    /// `CompressConfig::default()` is applied.
    pub compress: Option<CompressConfig>,
    /// When `true`, `into_server` calls `Server::allow_http(true)`. Defaults to `false`.
    #[serde(default)]
    pub allow_http: bool,
    /// Address(es) resolved and handed to `Server::http_redirect`.
    http_listen: Option<OneOrMany>,
    /// Optional `[acme]` table.
    pub acme: Option<AcmeFileConfig>,
    /// Optional `[hsts]` table; its rendered header is handed to `Server::hsts`.
    pub hsts: Option<HstsConfig>,
}
impl ServerConfig {
pub fn from_toml_str(s: &str) -> Result<ServerConfig> {
toml::from_str(s).map_err(|e| Error::Config(e.to_string()))
}
pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<ServerConfig> {
let text = std::fs::read_to_string(path.as_ref())
.map_err(|e| Error::Config(format!("reading {}: {e}", path.as_ref().display())))?;
ServerConfig::from_toml_str(&text)
}
pub fn listen_addrs(&self) -> Vec<String> {
self.listen.clone().into_vec()
}
#[cfg(any(feature = "rt-threadpool", feature = "rt-tokio", feature = "rt-mio"))]
pub fn into_server(self) -> Result<crate::rt::Server> {
let addrs = self.listen_addrs();
let first = addrs
.first()
.ok_or_else(|| Error::Config("no listen address".into()))?;
let mut server = crate::rt::Server::bind(first.as_str())?;
if let Some(root) = &self.root {
server = server.serve_dir(root.clone());
}
if let Some(workers) = self.workers {
server = server.workers(workers);
}
if self.server_name.is_some() {
server = server.server_name(self.server_name.clone());
}
if let Some(hsts) = &self.hsts {
server = server.hsts(Some(hsts.header_value()));
}
if self.allow_http {
server = server.allow_http(true);
}
if let Some(http) = &self.http_listen {
use std::net::ToSocketAddrs;
let mut resolved = Vec::new();
for a in http.clone().into_vec() {
resolved.extend(a.to_socket_addrs()?);
}
server = server.http_redirect(resolved.as_slice())?;
}
server = self.apply_tls(server)?;
server = self.apply_compress(server);
server = self.apply_acme(server)?;
Ok(server)
}
#[cfg(all(
feature = "acme",
any(feature = "rt-threadpool", feature = "rt-tokio", feature = "rt-mio")
))]
fn apply_acme(&self, server: crate::rt::Server) -> Result<crate::rt::Server> {
let Some(acme) = &self.acme else {
return Ok(server);
};
let directory = if acme.staging {
crate::acme::client::LETSENCRYPT_STAGING.to_owned()
} else {
acme.directory
.clone()
.unwrap_or_else(|| crate::acme::client::LETSENCRYPT_PRODUCTION.to_owned())
};
let whitelist = acme.host_whitelist.as_ref().map(|hosts| {
hosts
.iter()
.map(|h| h.trim().trim_end_matches('.').to_ascii_lowercase())
.collect()
});
let cfg = crate::acme::AcmeConfig {
directory_url: directory,
accept_tos: acme.accept_tos,
email: acme.email.clone(),
host_whitelist: whitelist,
cert_dir: acme.cert_dir.clone(),
};
Ok(server.acme(crate::acme::AcmeManager::new(cfg)?))
}
#[cfg(all(
not(feature = "acme"),
any(feature = "rt-threadpool", feature = "rt-tokio", feature = "rt-mio")
))]
fn apply_acme(&self, server: crate::rt::Server) -> Result<crate::rt::Server> {
if self.acme.is_some() {
return Err(Error::Config(
"[acme] configured but the `acme` feature is not enabled".into(),
));
}
Ok(server)
}
#[cfg(all(
feature = "tls",
any(feature = "rt-threadpool", feature = "rt-tokio", feature = "rt-mio")
))]
fn apply_tls(&self, server: crate::rt::Server) -> Result<crate::rt::Server> {
let Some(tls) = &self.tls else {
return Ok(server);
};
let acceptor = match (&tls.cert, &tls.key, &tls.self_signed) {
(Some(cert), Some(key), _) => crate::tls::TlsAcceptor::from_pem_files(cert, key)?,
(_, _, Some(names)) => {
let refs: Vec<&str> = names.iter().map(String::as_str).collect();
crate::tls::TlsAcceptor::self_signed(&refs)?
}
_ => {
return Err(Error::Config(
"[tls] requires either cert+key or self_signed".into(),
));
}
};
Ok(server.tls(acceptor))
}
#[cfg(all(
not(feature = "tls"),
any(feature = "rt-threadpool", feature = "rt-tokio", feature = "rt-mio")
))]
fn apply_tls(&self, server: crate::rt::Server) -> Result<crate::rt::Server> {
if self.tls.is_some() {
return Err(Error::Config(
"[tls] configured but the `tls` feature is not enabled".into(),
));
}
Ok(server)
}
#[cfg(all(
feature = "compress",
any(feature = "rt-threadpool", feature = "rt-tokio", feature = "rt-mio")
))]
fn apply_compress(&self, server: crate::rt::Server) -> crate::rt::Server {
let c = self.compress.clone().unwrap_or_default();
server.compression(crate::compress::Options {
enabled: c.enabled,
min_size: c.min_size,
})
}
#[cfg(all(
not(feature = "compress"),
any(feature = "rt-threadpool", feature = "rt-tokio", feature = "rt-mio")
))]
fn apply_compress(&self, server: crate::rt::Server) -> crate::rt::Server {
server
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_minimal() {
        let cfg =
            ServerConfig::from_toml_str("listen = \"127.0.0.1:8080\"\nroot = \"/srv\"\n").unwrap();
        assert_eq!(cfg.listen_addrs(), vec!["127.0.0.1:8080"]);
        assert_eq!(cfg.root, Some(PathBuf::from("/srv")));
    }

    #[test]
    fn parses_full() {
        let toml = r#"
listen = ["127.0.0.1:8443", "[::1]:8443"]
root = "/var/www"
workers = 16
[tls]
self_signed = ["localhost"]
[compress]
enabled = false
min_size = 1024
"#;
        let cfg = ServerConfig::from_toml_str(toml).unwrap();
        assert_eq!(cfg.listen_addrs().len(), 2);
        assert_eq!(cfg.workers, Some(16));
        assert!(cfg.tls.is_some());
        assert!(!cfg.compress.as_ref().unwrap().enabled);
    }

    /// Every table uses `deny_unknown_fields`, so a misspelled key must fail loudly.
    #[test]
    fn rejects_unknown_fields() {
        assert!(ServerConfig::from_toml_str("listen = \"127.0.0.1:80\"\nbogus = 1\n").is_err());
        assert!(ServerConfig::from_toml_str(
            "listen = \"127.0.0.1:80\"\n[tls]\ncertificate = \"x\"\n"
        )
        .is_err());
    }

    #[test]
    fn requires_listen() {
        assert!(ServerConfig::from_toml_str("root = \"/srv\"\n").is_err());
    }

    #[test]
    fn empty_compress_table_uses_defaults() {
        let cfg = ServerConfig::from_toml_str("listen = \"127.0.0.1:80\"\n[compress]\n").unwrap();
        let c = cfg.compress.unwrap();
        assert!(c.enabled);
        assert_eq!(c.min_size, 256);
    }

    #[test]
    fn hsts_header_rendering() {
        let cfg = ServerConfig::from_toml_str("listen = \"127.0.0.1:443\"\n[hsts]\n").unwrap();
        assert_eq!(cfg.hsts.unwrap().header_value(), "max-age=31536000");

        let cfg = ServerConfig::from_toml_str(
            "listen = \"127.0.0.1:443\"\n[hsts]\nmax_age = 60\ninclude_subdomains = true\npreload = true\n",
        )
        .unwrap();
        assert_eq!(
            cfg.hsts.unwrap().header_value(),
            "max-age=60; includeSubDomains; preload"
        );
    }
}