use std::fmt;
use std::path::Path;
use lettre::Address;
use serde::Deserialize;
use thiserror::Error;
/// A string holding sensitive material; in this file it backs
/// `ApiKeyConfig::secret` and `SmtpConfig::auth_password`.
///
/// `Debug` is deliberately not derived: the hand-written `Debug` and `Display`
/// impls in this file print `[REDACTED]` instead of the wrapped value.
/// `#[serde(transparent)]` makes the type deserialize directly from a plain
/// string value rather than from a one-element tuple.
#[derive(Clone, Deserialize)]
#[serde(transparent)]
pub struct SecretString(String);
impl SecretString {
pub fn new(s: impl Into<String>) -> Self {
Self(s.into())
}
pub fn expose(&self) -> &str {
&self.0
}
}
impl fmt::Debug for SecretString {
    /// Emits a fixed placeholder so `{:?}` output never contains the secret.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "[REDACTED]")
    }
}
impl fmt::Display for SecretString {
    /// Emits a fixed placeholder so `{}` output never contains the secret.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const PLACEHOLDER: &str = "[REDACTED]";
        f.write_str(PLACEHOLDER)
    }
}
/// Root of the configuration document that `load` deserializes from TOML.
///
/// None of the section fields carries a serde default, so every section
/// (`server`, `security`, `mail`, `smtp`, `rate_limit`, `logging`) must be
/// present in the file, even if it is empty.
#[derive(Debug, Clone, Deserialize)]
pub struct AppConfig {
/// `[server]` section.
pub server: ServerConfig,
/// `[security]` section.
pub security: SecurityConfig,
/// `[mail]` section.
pub mail: MailConfig,
/// `[smtp]` section.
pub smtp: SmtpConfig,
/// `[rate_limit]` section.
pub rate_limit: RateLimitConfig,
/// `[logging]` section.
pub logging: LoggingConfig,
}
/// `[server]` section of the configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct ServerConfig {
/// Required. `validate` rejects the config unless this parses as a
/// `std::net::SocketAddr` (for example `127.0.0.1:8080`).
pub bind_address: String,
/// Defaults to 1_048_576 when omitted.
#[serde(default = "default_max_request_body_bytes")]
pub max_request_body_bytes: usize,
/// Defaults to 30 when omitted.
#[serde(default = "default_request_timeout_seconds")]
pub request_timeout_seconds: u64,
/// Defaults to 30 when omitted.
#[serde(default = "default_shutdown_timeout_seconds")]
pub shutdown_timeout_seconds: u64,
/// Defaults to 0 when omitted.
/// NOTE(review): the meaning of 0 is not visible in this file — presumably
/// "no limit"; confirm against the code that consumes it.
#[serde(default)]
pub concurrency_limit: usize,
}
/// `[security]` section of the configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct SecurityConfig {
/// Defaults to `true` when omitted. When `true`, `validate` requires
/// `api_keys` to contain at least one entry with `enabled = true`.
#[serde(default = "default_true")]
pub require_auth: bool,
/// Defaults to `false` when omitted.
#[serde(default)]
pub trust_proxy_headers: bool,
/// Defaults to empty. Every entry must parse as an `ipnet::IpNet`,
/// otherwise `validate` fails with `ConfigError::InvalidCidr`.
#[serde(default)]
pub trusted_source_cidrs: Vec<String>,
/// Defaults to empty. Every entry must parse as an `ipnet::IpNet`,
/// otherwise `validate` fails with `ConfigError::InvalidCidr`.
/// NOTE(review): whether an empty list means "allow all" is not visible
/// here — confirm against the consumer.
#[serde(default)]
pub allowed_source_cidrs: Vec<String>,
/// Defaults to empty.
#[serde(default)]
pub api_keys: Vec<ApiKeyConfig>,
}
/// One entry of `security.api_keys`.
#[derive(Debug, Clone, Deserialize)]
pub struct ApiKeyConfig {
/// Required identifier for the key.
pub id: String,
/// Required key secret; redacted in `Debug` output because it is a
/// `SecretString`.
pub secret: SecretString,
/// Defaults to `true` when omitted. `validate` only counts keys with
/// `enabled = true` when `security.require_auth` is set.
#[serde(default = "default_true")]
pub enabled: bool,
/// Optional free-form description.
pub description: Option<String>,
/// Defaults to empty.
#[serde(default)]
pub allowed_recipient_domains: Vec<String>,
/// Defaults to empty.
#[serde(default)]
pub allowed_recipients: Vec<String>,
/// Optional per-key limit.
/// NOTE(review): presumably overrides `rate_limit.per_key_per_min` for
/// this key — confirm against the rate limiter.
pub rate_limit_per_min: Option<u32>,
/// Defaults to 0 when omitted.
#[serde(default)]
pub burst: u32,
}
/// `[mail]` section of the configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct MailConfig {
/// Required. `validate` rejects the config unless this parses as a
/// `lettre::Address`.
pub default_from: String,
/// Optional display name.
pub default_from_name: Option<String>,
/// Defaults to empty.
/// NOTE(review): whether an empty list means "any domain" is not visible
/// in this file — confirm against the consumer.
#[serde(default)]
pub allowed_recipient_domains: Vec<String>,
/// Defaults to 255 when omitted.
#[serde(default = "default_max_subject_chars")]
pub max_subject_chars: usize,
/// Defaults to 65_536 when omitted.
#[serde(default = "default_max_body_bytes")]
pub max_body_bytes: usize,
/// Defaults to 10 when omitted.
#[serde(default = "default_max_recipients")]
pub max_recipients: usize,
}
/// `[smtp]` section of the configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct SmtpConfig {
/// Defaults to `"smtp"` when omitted. `validate` special-cases `"pipe"`
/// (credentials are rejected in that mode) but does not otherwise
/// restrict the value.
#[serde(default = "default_smtp_mode")]
pub mode: String,
/// Defaults to `"127.0.0.1"` when omitted.
#[serde(default = "default_smtp_host")]
pub host: String,
/// Defaults to 25 when omitted. `validate` rejects 0.
#[serde(default = "default_smtp_port")]
pub port: u16,
/// Defaults to 5 when omitted.
#[serde(default = "default_connect_timeout_seconds")]
pub connect_timeout_seconds: u64,
/// Defaults to 30 when omitted.
#[serde(default = "default_submission_timeout_seconds")]
pub submission_timeout_seconds: u64,
/// Optional. `validate` requires `auth_user` and `auth_password` to be
/// either both set or both absent.
pub auth_user: Option<String>,
/// Optional; redacted in `Debug` output because it is a `SecretString`.
pub auth_password: Option<SecretString>,
/// Defaults to `"/usr/sbin/sendmail"` when omitted.
#[serde(default = "default_pipe_command")]
pub pipe_command: String,
}
/// `[rate_limit]` section of the configuration.
///
/// The burst fields should normally be read through the
/// `effective_*_burst` methods, which apply the `burst_size` fallback.
///
/// NOTE(review): the three `*_burst` fields default to non-zero values
/// (10, 5 and 5), so `burst_size` is only consulted when the specific field
/// is explicitly set to 0 in the file. Confirm this precedence is intended.
#[derive(Debug, Clone, Deserialize)]
pub struct RateLimitConfig {
/// Defaults to 60 when omitted. `validate` rejects 0.
#[serde(default = "default_global_per_min")]
pub global_per_min: u32,
/// Defaults to 20 when omitted. `validate` rejects 0.
#[serde(default = "default_per_ip_per_min")]
pub per_ip_per_min: u32,
/// Defaults to 30 when omitted.
#[serde(default = "default_per_key_per_min")]
pub per_key_per_min: u32,
/// Defaults to 10 when omitted.
#[serde(default = "default_global_burst")]
pub global_burst: u32,
/// Defaults to 5 when omitted.
#[serde(default = "default_per_ip_burst")]
pub per_ip_burst: u32,
/// Defaults to 5 when omitted.
#[serde(default = "default_per_key_burst")]
pub per_key_burst: u32,
/// Shared fallback burst used by the `effective_*_burst` methods when the
/// specific burst field is 0. Defaults to 0 when omitted.
#[serde(default)]
pub burst_size: u32,
/// Defaults to 10_000 when omitted.
#[serde(default = "default_ip_table_size")]
pub ip_table_size: usize,
}
impl RateLimitConfig {
    /// Burst to apply for the global limiter.
    ///
    /// Resolution order: `global_burst` if non-zero, else `burst_size` if
    /// non-zero, else the built-in global default.
    pub fn effective_global_burst(&self) -> u32 {
        match (self.global_burst, self.burst_size) {
            (0, 0) => default_global_burst(),
            (0, shared) => shared,
            (specific, _) => specific,
        }
    }

    /// Burst to apply for the per-IP limiter.
    ///
    /// Resolution order: `per_ip_burst` if non-zero, else `burst_size` if
    /// non-zero, else the built-in per-IP default.
    pub fn effective_per_ip_burst(&self) -> u32 {
        match (self.per_ip_burst, self.burst_size) {
            (0, 0) => default_per_ip_burst(),
            (0, shared) => shared,
            (specific, _) => specific,
        }
    }

    /// Burst to apply for the per-key limiter.
    ///
    /// Resolution order: `per_key_burst` if non-zero, else `burst_size` if
    /// non-zero, else the built-in per-key default.
    pub fn effective_per_key_burst(&self) -> u32 {
        match (self.per_key_burst, self.burst_size) {
            (0, 0) => default_per_key_burst(),
            (0, shared) => shared,
            (specific, _) => specific,
        }
    }
}
/// `[logging]` section of the configuration.
#[derive(Debug, Clone, Deserialize)]
pub struct LoggingConfig {
/// Defaults to `"text"` when omitted. `validate` accepts only `"text"`
/// or `"json"`.
#[serde(default = "default_log_format")]
pub format: String,
/// Defaults to `"info"` when omitted. `validate` accepts only `"trace"`,
/// `"debug"`, `"info"`, `"warn"` or `"error"`.
#[serde(default = "default_log_level")]
pub level: String,
/// Defaults to `false` when omitted.
#[serde(default)]
pub mask_recipient: bool,
}
// Fallback values named by the `#[serde(default = "...")]` attributes on the
// config structs. Each function is called by serde when its key is absent.

// --- shared ---

/// Default for boolean options that are on unless disabled.
fn default_true() -> bool {
    true
}

// --- [server] ---

/// 1 MiB.
fn default_max_request_body_bytes() -> usize {
    1024 * 1024
}

fn default_request_timeout_seconds() -> u64 {
    30
}

fn default_shutdown_timeout_seconds() -> u64 {
    30
}

// --- [mail] ---

fn default_max_subject_chars() -> usize {
    255
}

/// 64 KiB.
fn default_max_body_bytes() -> usize {
    64 * 1024
}

fn default_max_recipients() -> usize {
    10
}

// --- [smtp] ---

fn default_smtp_mode() -> String {
    String::from("smtp")
}

fn default_smtp_host() -> String {
    String::from("127.0.0.1")
}

fn default_smtp_port() -> u16 {
    25
}

fn default_connect_timeout_seconds() -> u64 {
    5
}

fn default_submission_timeout_seconds() -> u64 {
    30
}

fn default_pipe_command() -> String {
    String::from("/usr/sbin/sendmail")
}

// --- [rate_limit] ---

fn default_global_per_min() -> u32 {
    60
}

fn default_per_ip_per_min() -> u32 {
    20
}

fn default_per_key_per_min() -> u32 {
    30
}

// Not referenced by any serde attribute in this file (`burst_size` uses the
// plain `#[serde(default)]`), hence the lint allowance.
#[allow(dead_code)]
fn default_burst_size() -> u32 {
    5
}

fn default_global_burst() -> u32 {
    10
}

fn default_per_ip_burst() -> u32 {
    5
}

fn default_per_key_burst() -> u32 {
    5
}

fn default_ip_table_size() -> usize {
    10_000
}

// --- [logging] ---

fn default_log_format() -> String {
    String::from("text")
}

fn default_log_level() -> String {
    String::from("info")
}
/// Everything that can go wrong while loading and validating the
/// configuration file. Display messages come from the `#[error]` attributes.
#[derive(Debug, Error)]
pub enum ConfigError {
/// Reading the file failed; converted from `std::io::Error` via `?` in `load`.
#[error("cannot read config file: {0}")]
Io(#[from] std::io::Error),
/// The file is not valid TOML or does not match `AppConfig`; converted
/// from `toml::de::Error` via `?` in `load`.
#[error("config parse error: {0}")]
Parse(#[from] toml::de::Error),
/// `server.bind_address` did not parse as a `std::net::SocketAddr`.
#[error("invalid server.bind_address: must be host:port (e.g. 127.0.0.1:8080)")]
InvalidBindAddress,
/// `mail.default_from` did not parse as a `lettre::Address`.
#[error("invalid mail.default_from: must be a valid email address")]
InvalidDefaultFrom,
/// `security.require_auth` is set and `security.api_keys` is empty.
#[error("security.require_auth is true but no api_keys are defined")]
NoApiKeys,
/// `security.require_auth` is set and no key has `enabled = true`.
#[error("no api_keys entries have enabled = true")]
NoEnabledApiKeys,
/// A trusted/allowed source CIDR entry did not parse as an `ipnet::IpNet`;
/// carries the offending entry.
#[error("invalid CIDR: {0}")]
InvalidCidr(String),
/// Free-form validation failure; `validate` uses it for the SMTP
/// credential consistency checks.
#[error("configuration error: {0}")]
Validation(String),
/// `smtp.port` is 0.
#[error("invalid smtp.port: must be 1-65535")]
InvalidSmtpPort,
/// Returned by `validate` when a `rate_limit` per-minute value is zero.
#[error("invalid rate_limit values: all per_min values must be > 0")]
InvalidRateLimit,
/// `logging.level` is not one of the accepted level names.
#[error("invalid logging.level: must be trace, debug, info, warn, or error")]
InvalidLogLevel,
/// `logging.format` is neither `text` nor `json`.
#[error("invalid logging.format: must be 'text' or 'json'")]
InvalidLogFormat,
}
/// Reads the TOML file at `path`, deserializes it into an [`AppConfig`] and
/// runs `validate` on the result.
///
/// # Errors
///
/// * [`ConfigError::Io`] if the file cannot be read.
/// * [`ConfigError::Parse`] if the content is not valid TOML for `AppConfig`.
/// * Any other [`ConfigError`] variant produced by `validate`.
pub fn load(path: &Path) -> Result<AppConfig, ConfigError> {
    let raw = std::fs::read_to_string(path)?;
    let parsed = toml::from_str::<AppConfig>(&raw)?;
    // Hand the config back only once validation has accepted it.
    validate(&parsed).map(|()| parsed)
}
/// Checks the semantic constraints that deserialization alone cannot enforce.
///
/// Checks run in a fixed order and the first failure is returned:
///
/// 1. `server.bind_address` parses as a `std::net::SocketAddr`.
/// 2. `mail.default_from` parses as a `lettre::Address`.
/// 3. With `security.require_auth`, at least one API key exists and at least
///    one is enabled.
/// 4. Every trusted/allowed source CIDR parses as an `ipnet::IpNet`.
/// 5. `smtp.port` is non-zero.
/// 6. `smtp.auth_user` and `smtp.auth_password` are both set or both absent,
///    and neither is set when `smtp.mode` is `"pipe"`.
/// 7. All three `rate_limit.*_per_min` values are non-zero.
/// 8. `logging.level` and `logging.format` are one of the accepted values.
///
/// # Errors
///
/// Returns the [`ConfigError`] variant matching the first failed check.
fn validate(config: &AppConfig) -> Result<(), ConfigError> {
    config
        .server
        .bind_address
        .parse::<std::net::SocketAddr>()
        .map_err(|_| ConfigError::InvalidBindAddress)?;

    config
        .mail
        .default_from
        .parse::<Address>()
        .map_err(|_| ConfigError::InvalidDefaultFrom)?;

    let security = &config.security;
    if security.require_auth {
        // Distinguish "no keys at all" from "keys exist but all disabled" so
        // the operator gets a precise message.
        if security.api_keys.is_empty() {
            return Err(ConfigError::NoApiKeys);
        }
        if !security.api_keys.iter().any(|key| key.enabled) {
            return Err(ConfigError::NoEnabledApiKeys);
        }
    }

    for cidr in security
        .trusted_source_cidrs
        .iter()
        .chain(security.allowed_source_cidrs.iter())
    {
        cidr.parse::<ipnet::IpNet>()
            .map_err(|_| ConfigError::InvalidCidr(cidr.clone()))?;
    }

    let smtp = &config.smtp;
    if smtp.port == 0 {
        return Err(ConfigError::InvalidSmtpPort);
    }

    // Credentials only make sense as a pair.
    if smtp.auth_user.is_some() != smtp.auth_password.is_some() {
        return Err(ConfigError::Validation(
            "smtp.auth_user and smtp.auth_password must both be set or both absent".into(),
        ));
    }

    // NOTE(review): `smtp.mode` itself is not restricted to a known set here;
    // an unrecognised value passes validation. Confirm whether that is intended.
    if smtp.mode == "pipe" && (smtp.auth_user.is_some() || smtp.auth_password.is_some()) {
        return Err(ConfigError::Validation(
            r#"smtp.auth_user/auth_password are not applicable when smtp.mode = "pipe""#.into(),
        ));
    }

    // The error message promises that *all* per-minute limits are positive,
    // so the per-key limit is checked alongside the global and per-IP ones.
    let limits = &config.rate_limit;
    if [
        limits.global_per_min,
        limits.per_ip_per_min,
        limits.per_key_per_min,
    ]
    .contains(&0)
    {
        return Err(ConfigError::InvalidRateLimit);
    }

    if !matches!(
        config.logging.level.as_str(),
        "trace" | "debug" | "info" | "warn" | "error"
    ) {
        return Err(ConfigError::InvalidLogLevel);
    }

    if !matches!(config.logging.format.as_str(), "text" | "json") {
        return Err(ConfigError::InvalidLogFormat);
    }

    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest TOML document that deserializes into `AppConfig`: every
    /// section is present and only the keys without defaults are set.
    fn minimal_config_str() -> String {
        String::from(
            r#"
[server]
bind_address = "127.0.0.1:8080"
[security]
require_auth = false
[rate_limit]
[mail]
default_from = "noreply@example.com"
[smtp]
[logging]
"#,
        )
    }

    /// Deserializes `toml_text`, panicking if it does not match `AppConfig`.
    fn parse(toml_text: &str) -> AppConfig {
        toml::from_str(toml_text).unwrap()
    }

    /// Parses the minimal config after substituting `needle` with `replacement`.
    fn parse_with(needle: &str, replacement: &str) -> AppConfig {
        parse(&minimal_config_str().replace(needle, replacement))
    }

    #[test]
    fn valid_config_parses() {
        let cfg = parse(&minimal_config_str());
        assert!(validate(&cfg).is_ok());
    }

    #[test]
    fn invalid_bind_address() {
        let cfg = parse_with("127.0.0.1:8080", "notanaddress");
        let outcome = validate(&cfg);
        assert!(matches!(outcome, Err(ConfigError::InvalidBindAddress)));
    }

    #[test]
    fn invalid_default_from() {
        let cfg = parse_with("noreply@example.com", "notanemail");
        let outcome = validate(&cfg);
        assert!(matches!(outcome, Err(ConfigError::InvalidDefaultFrom)));
    }

    #[test]
    fn require_auth_no_keys() {
        let cfg = parse_with("require_auth = false", "require_auth = true");
        let outcome = validate(&cfg);
        assert!(matches!(outcome, Err(ConfigError::NoApiKeys)));
    }

    #[test]
    fn secret_string_is_redacted_in_debug() {
        let secret = SecretString::new("very-secret");
        let debug_text = format!("{:?}", secret);
        let display_text = secret.to_string();
        assert!(!debug_text.contains("very-secret"));
        assert!(!display_text.contains("very-secret"));
        assert_eq!(secret.expose(), "very-secret");
    }

    #[test]
    fn defaults_are_sensible() {
        let cfg = parse(&minimal_config_str());
        assert_eq!(cfg.server.max_request_body_bytes, 1_048_576);
        assert_eq!(cfg.smtp.port, 25);
        assert_eq!(cfg.rate_limit.global_per_min, 60);
        assert_eq!(cfg.logging.format, "text");
    }
}