use crate::shadow::ShadowMirrorConfig;
use crate::trap::TrapConfig;
use crate::vhost::SiteConfig;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::fmt;
use std::fs;
use std::path::Path;
use tracing::{debug, info, warn};
/// Largest configuration file, in bytes, that the loader will accept (10 MiB).
const MAX_CONFIG_SIZE: u64 = 10 << 20;
/// Server-wide settings, deserialized from the `server:` section of the config file.
///
/// Every field has a serde default, and `ConfigFile` marks the section itself as
/// `#[serde(default)]`, so the whole section may be omitted.
/// `Debug` is implemented by hand so that `admin_api_key` is never printed.
#[derive(Clone, Serialize, Deserialize, JsonSchema)]
pub struct GlobalConfig {
    /// HTTP address (default `0.0.0.0:80`).
    #[serde(default = "default_http_addr")]
    pub http_addr: String,
    /// HTTPS address (default `0.0.0.0:443`).
    #[serde(default = "default_https_addr")]
    pub https_addr: String,
    /// Worker count (default 0). Validation rejects values above 1024.
    // NOTE(review): 0 presumably means "choose automatically" — confirm with the runtime.
    #[serde(default)]
    pub workers: usize,
    /// Shutdown timeout in seconds (default 30). Validation rejects values above 3600.
    #[serde(default = "default_shutdown_timeout")]
    pub shutdown_timeout_secs: u64,
    /// Global WAF threshold (default 70). Validation requires a value in 1..=100.
    #[serde(default = "default_waf_threshold")]
    pub waf_threshold: u8,
    /// Global WAF switch (default `true`). A warning is logged when it is `false`.
    #[serde(default = "default_true")]
    pub waf_enabled: bool,
    /// Log level name (default `"info"`).
    #[serde(default = "default_log_level")]
    pub log_level: String,
    /// Optional admin API key; rendered as `[REDACTED]` by `Debug`.
    #[serde(default)]
    pub admin_api_key: Option<String>,
    /// Optional trap configuration (see `crate::trap::TrapConfig`).
    #[serde(default)]
    pub trap_config: Option<TrapConfig>,
    /// WAF regex timeout in milliseconds (default 100). Validation rejects values above 500.
    #[serde(default = "default_waf_regex_timeout_ms")]
    pub waf_regex_timeout_ms: u64,
}
// Serde `default = "..."` helpers for `GlobalConfig` (and other structs using
// `default_true`). `GlobalConfig::default` calls the same functions so both
// paths agree.

/// Default WAF regex timeout: 100 ms.
fn default_waf_regex_timeout_ms() -> u64 {
    100
}

/// Default plain-HTTP address: all interfaces, port 80.
fn default_http_addr() -> String {
    String::from("0.0.0.0:80")
}

/// Default HTTPS address: all interfaces, port 443.
fn default_https_addr() -> String {
    String::from("0.0.0.0:443")
}

/// Default graceful-shutdown timeout: 30 seconds.
fn default_shutdown_timeout() -> u64 {
    30
}

/// Default global WAF threshold: 70.
fn default_waf_threshold() -> u8 {
    70
}

/// Serde helper for boolean fields that should default to `true`.
fn default_true() -> bool {
    !false
}

/// Default log level: `"info"`.
fn default_log_level() -> String {
    String::from("info")
}
impl fmt::Debug for GlobalConfig {
    /// Prints every field, but replaces the admin API key with `[REDACTED]`
    /// (or `None` when no key is configured) so secrets never reach logs.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field to `GlobalConfig` becomes a
        // compile error here, forcing a decision on whether it is safe to print.
        let Self {
            http_addr,
            https_addr,
            workers,
            shutdown_timeout_secs,
            waf_threshold,
            waf_enabled,
            log_level,
            admin_api_key,
            trap_config,
            waf_regex_timeout_ms,
        } = self;

        let masked_key: Option<&str> = admin_api_key.is_some().then_some("[REDACTED]");

        let mut out = f.debug_struct("GlobalConfig");
        out.field("http_addr", http_addr);
        out.field("https_addr", https_addr);
        out.field("workers", workers);
        out.field("shutdown_timeout_secs", shutdown_timeout_secs);
        out.field("waf_threshold", waf_threshold);
        out.field("waf_enabled", waf_enabled);
        out.field("log_level", log_level);
        out.field("admin_api_key", &masked_key);
        out.field("trap_config", trap_config);
        out.field("waf_regex_timeout_ms", waf_regex_timeout_ms);
        out.finish()
    }
}
impl Default for GlobalConfig {
fn default() -> Self {
Self {
http_addr: default_http_addr(),
https_addr: default_https_addr(),
workers: 0,
shutdown_timeout_secs: default_shutdown_timeout(),
waf_threshold: default_waf_threshold(),
waf_enabled: true,
log_level: default_log_level(),
admin_api_key: None,
trap_config: None,
waf_regex_timeout_ms: default_waf_regex_timeout_ms(),
}
}
}
/// Rate-limit settings, used for the top-level `rate_limit:` section and for
/// each site's optional `rate_limit:` block.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct RateLimitConfig {
    /// Rate in requests per second. Required whenever a `rate_limit` block is
    /// present (no serde default). Per-site validation rejects values above
    /// 1,000,000 and warns when enabled with 0.
    pub rps: u32,
    /// Whether limiting is active (default `true`).
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Optional burst allowance; `None` when omitted.
    pub burst: Option<u32>,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
rps: 10000,
enabled: true,
burst: None,
}
}
}
/// One backend server for a site. `ConfigLoader::to_site_configs` turns it into
/// a `"host:port"` string.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct UpstreamConfig {
    /// Backend host name or address.
    pub host: String,
    /// Backend port.
    pub port: u16,
    /// Relative weight (default 1).
    #[serde(default = "default_weight")]
    pub weight: u32,
    /// Runtime health flag. It is never read from or written to YAML, and it is
    /// `false` after deserialization because serde fills skipped fields with `Default`.
    #[serde(skip)]
    pub healthy: bool,
}
/// Serde default for `UpstreamConfig::weight`: every upstream weighs 1 unless configured.
fn default_weight() -> u32 {
    1_u32
}
/// Per-site TLS settings. Both paths must exist and contain no traversal
/// sequences; this is checked by `ConfigLoader::validate_tls`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct TlsConfig {
    /// Path to the certificate file.
    pub cert_path: String,
    /// Path to the private key file.
    pub key_path: String,
    /// Minimum TLS version (default `"1.2"`). Only `"1.2"` and `"1.3"` pass validation.
    #[serde(default = "default_min_tls")]
    pub min_version: String,
}
/// Serde default for `TlsConfig::min_version`.
fn default_min_tls() -> String {
    String::from("1.2")
}
/// Per-site allow/deny lists. It is passed through unchanged to `SiteConfig`;
/// the entries are interpreted elsewhere.
// NOTE(review): entry format (IP / CIDR?) and valid `default_action` values are
// not visible in this module — confirm against the consumer.
#[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)]
pub struct AccessControlConfig {
    /// Allow-list entries (default empty).
    #[serde(default)]
    pub allow: Vec<String>,
    /// Deny-list entries (default empty).
    #[serde(default)]
    pub deny: Vec<String>,
    /// Action when no entry matches (default empty string; not validated here).
    #[serde(default)]
    pub default_action: String,
}
/// Header rewrite rules for a site, split into request and response sides.
/// `ConfigLoader::to_site_configs` converts it with `compile()`, which is defined outside this view.
#[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)]
pub struct HeaderConfig {
    /// Operations applied to requests (default: none).
    #[serde(default)]
    pub request: HeaderOps,
    /// Operations applied to responses (default: none).
    #[serde(default)]
    pub response: HeaderOps,
}
/// A set of header operations; every list or map defaults to empty.
// NOTE(review): the exact difference between `add` and `set` (append vs.
// replace?) is implemented elsewhere — confirm before relying on it.
#[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)]
pub struct HeaderOps {
    /// Header name → value pairs to add.
    #[serde(default)]
    pub add: std::collections::HashMap<String, String>,
    /// Header name → value pairs to set.
    #[serde(default)]
    pub set: std::collections::HashMap<String, String>,
    /// Header names to remove.
    #[serde(default)]
    pub remove: Vec<String>,
}
/// Per-site WAF settings from a site's `waf:` block.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SiteWafConfig {
    /// WAF switch for this site (default `true`). A warning is logged when `false`.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Optional site threshold; validation requires 1..=100 when set.
    // NOTE(review): `None` is forwarded as-is to `SiteConfig::waf_threshold`;
    // presumably the runtime then falls back to `server.waf_threshold` — confirm.
    pub threshold: Option<u8>,
    /// Per-rule overrides keyed by string (default empty); semantics defined elsewhere.
    #[serde(default)]
    pub rule_overrides: std::collections::HashMap<String, String>,
}
impl Default for SiteWafConfig {
fn default() -> Self {
Self {
enabled: true,
threshold: None,
rule_overrides: std::collections::HashMap::new(),
}
}
}
/// One entry of the top-level `sites:` list, as written in YAML.
///
/// `Option` fields may be omitted (serde maps a missing `Option` field to `None`).
/// `ConfigLoader::to_site_configs` converts this into the runtime `SiteConfig`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SiteYamlConfig {
    /// Host name; must be unique across sites (compared case-insensitively).
    pub hostname: String,
    /// Backends for this site; validation requires at least one.
    pub upstreams: Vec<UpstreamConfig>,
    /// TLS settings; presence sets `SiteConfig::tls_enabled`.
    pub tls: Option<TlsConfig>,
    /// Site WAF settings.
    pub waf: Option<SiteWafConfig>,
    /// Site rate limit.
    pub rate_limit: Option<RateLimitConfig>,
    /// Site allow/deny lists.
    pub access_control: Option<AccessControlConfig>,
    /// Site header rewrite rules.
    pub headers: Option<HeaderConfig>,
    /// Shadow-mirroring settings, validated via `ShadowMirrorConfig::validate`.
    #[serde(default)]
    pub shadow_mirror: Option<ShadowMirrorConfig>,
}
/// Settings for the `profiler:` section; every field has a serde default.
// NOTE(review): the statistical meaning of the thresholds (z-scores, ratios) is
// implemented outside this module; only the defaults are established here.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ProfilerConfig {
    /// Profiler switch (default `true`).
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Maximum number of profiles (default 1000).
    #[serde(default = "default_max_profiles")]
    pub max_profiles: usize,
    /// Maximum number of schemas (default 500).
    #[serde(default = "default_max_schemas")]
    pub max_schemas: usize,
    /// Minimum samples before validation applies (default 100).
    #[serde(default = "default_min_samples")]
    pub min_samples_for_validation: u32,
    /// Payload z threshold (default 3.0).
    #[serde(default = "default_payload_z_threshold")]
    pub payload_z_threshold: f64,
    /// Parameter z threshold (default 4.0).
    #[serde(default = "default_param_z_threshold")]
    pub param_z_threshold: f64,
    /// Response z threshold (default 4.0).
    #[serde(default = "default_response_z_threshold")]
    pub response_z_threshold: f64,
    /// Minimum standard deviation (default 0.01).
    #[serde(default = "default_min_stddev")]
    pub min_stddev: f64,
    /// Type ratio threshold (default 0.9).
    #[serde(default = "default_type_ratio_threshold")]
    pub type_ratio_threshold: f64,
    /// Maximum tracked type counts (default 10).
    #[serde(default = "default_max_type_counts")]
    pub max_type_counts: usize,
    /// PII redaction switch (default `true`).
    #[serde(default = "default_true")]
    pub redact_pii: bool,
    /// Sample count after which profiles freeze (default 0).
    // NOTE(review): 0 presumably means "never freeze" — confirm with the profiler.
    #[serde(default)]
    pub freeze_after_samples: u32,
}
// Serde `default = "..."` helpers for `ProfilerConfig`; `ProfilerConfig::default`
// calls the same functions so both construction paths agree.

/// Default cap on stored profiles.
fn default_max_profiles() -> usize {
    1_000
}

/// Default cap on stored schemas.
fn default_max_schemas() -> usize {
    500
}

/// Default minimum sample count before validation kicks in.
fn default_min_samples() -> u32 {
    100
}

/// Default payload z threshold.
fn default_payload_z_threshold() -> f64 {
    3.0_f64
}

/// Default parameter z threshold.
fn default_param_z_threshold() -> f64 {
    4.0_f64
}

/// Default response z threshold.
fn default_response_z_threshold() -> f64 {
    4.0_f64
}

/// Default floor for the standard deviation.
fn default_min_stddev() -> f64 {
    0.01_f64
}

/// Default type ratio threshold.
fn default_type_ratio_threshold() -> f64 {
    0.9_f64
}

/// Default cap on tracked type counts.
fn default_max_type_counts() -> usize {
    10
}
impl Default for ProfilerConfig {
fn default() -> Self {
Self {
enabled: true,
max_profiles: default_max_profiles(),
max_schemas: default_max_schemas(),
min_samples_for_validation: default_min_samples(),
payload_z_threshold: default_payload_z_threshold(),
param_z_threshold: default_param_z_threshold(),
response_z_threshold: default_response_z_threshold(),
min_stddev: default_min_stddev(),
type_ratio_threshold: default_type_ratio_threshold(),
max_type_counts: default_max_type_counts(),
redact_pii: true,
freeze_after_samples: 0,
}
}
}
/// Root of the YAML configuration document.
///
/// Only `sites` is required. Unknown top-level keys are ignored, because the
/// struct does not use `deny_unknown_fields`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct ConfigFile {
    /// Server-wide settings; `GlobalConfig::default()` when omitted.
    #[serde(default)]
    pub server: GlobalConfig,
    /// Virtual hosts to serve.
    pub sites: Vec<SiteYamlConfig>,
    /// Global rate limit; `RateLimitConfig::default()` when omitted.
    #[serde(default)]
    pub rate_limit: RateLimitConfig,
    /// Profiler settings; `ProfilerConfig::default()` when omitted.
    #[serde(default)]
    pub profiler: ProfilerConfig,
}
/// Errors produced while loading or validating the configuration file.
/// Each message includes a remediation hint for the operator.
#[derive(Debug, thiserror::Error)]
pub enum ConfigError {
    /// Nothing exists at the configured path.
    #[error("configuration file not found: {path} (check the path or mount the file into the container)")]
    NotFound { path: String },
    /// The file exceeds the loader's size cap.
    #[error("configuration file too large: {size} bytes (max {max} bytes). Reduce size or split the configuration")]
    FileTooLarge { size: u64, max: u64 },
    /// Any other I/O failure while reading the file.
    #[error("failed to read configuration: {0}")]
    IoError(#[from] std::io::Error),
    /// The YAML could not be deserialized into `ConfigFile`.
    #[error("failed to parse configuration: {0}")]
    ParseError(#[from] serde_yaml::Error),
    /// A semantic check failed; the string describes which one.
    #[error("validation error: {0}")]
    ValidationError(String),
    /// `tls.cert_path` does not exist.
    #[error("TLS certificate not found: {path} (set tls.cert_path to a valid PEM file)")]
    CertNotFound { path: String },
    /// `tls.key_path` does not exist.
    #[error("TLS key not found: {path} (set tls.key_path to a valid PEM file)")]
    KeyNotFound { path: String },
    /// Two sites share a hostname (compared case-insensitively).
    #[error("duplicate hostname: {hostname} (hostnames must be unique; consider a wildcard like '*.example.com')")]
    DuplicateHostname { hostname: String },
    /// `tls.min_version` is neither `"1.2"` nor `"1.3"`.
    #[error("invalid TLS version: {version} (set min_version to '1.2' or '1.3')")]
    InvalidTlsVersion { version: String },
    /// A TLS path contains `..`, an encoded dot sequence, or a NUL byte.
    #[error("path traversal detected in: {path} (remove '..' or encoded traversal sequences)")]
    PathTraversal { path: String },
}
/// Conservatively flags path strings that could escape their intended directory.
///
/// Literal `..` and NUL bytes are rejected outright. Percent-encoded dots
/// (`%2e`), double-encoded dots (`%252e`) and encoded NULs (`%00`) are matched
/// case-insensitively. Any `..` substring counts, even inside a file name such
/// as `cert..pem`.
fn contains_path_traversal(path: &str) -> bool {
    // Checked against a lowercased copy so `%2E` and `%2e` are treated alike.
    const ENCODED_MARKERS: [&str; 5] = ["%2e%2e", "%2e.", ".%2e", "%252e", "%00"];

    if path.contains("..") || path.contains('\0') {
        return true;
    }
    let folded = path.to_lowercase();
    ENCODED_MARKERS
        .iter()
        .any(|marker| folded.contains(marker))
}
pub struct ConfigLoader;
impl ConfigLoader {
pub fn load<P: AsRef<Path>>(path: P) -> Result<ConfigFile, ConfigError> {
let path = path.as_ref();
info!("Loading configuration from: {}", path.display());
if !path.exists() {
return Err(ConfigError::NotFound {
path: path.display().to_string(),
});
}
let metadata = fs::metadata(path)?;
if metadata.len() > MAX_CONFIG_SIZE {
return Err(ConfigError::FileTooLarge {
size: metadata.len(),
max: MAX_CONFIG_SIZE,
});
}
let contents = fs::read_to_string(path)?;
let config: ConfigFile = serde_yaml::from_str(&contents)?;
Self::validate(&config)?;
info!("Loaded configuration with {} sites", config.sites.len());
Ok(config)
}
fn validate(config: &ConfigFile) -> Result<(), ConfigError> {
let mut hostnames = HashSet::new();
for site in &config.sites {
let normalized = site.hostname.to_lowercase();
if !hostnames.insert(normalized.clone()) {
return Err(ConfigError::DuplicateHostname {
hostname: site.hostname.clone(),
});
}
if site.upstreams.is_empty() {
return Err(ConfigError::ValidationError(format!(
"site '{}' has no upstreams configured; add at least one upstream with host and port",
site.hostname
)));
}
if let Some(tls) = &site.tls {
Self::validate_tls(tls)?;
}
if let Some(waf) = &site.waf {
if !waf.enabled {
warn!(
site = %site.hostname,
"WAF protection DISABLED for site - backend may be exposed to attacks"
);
}
if let Some(threshold) = waf.threshold {
if threshold == 0 {
return Err(ConfigError::ValidationError(format!(
"site '{}' has WAF threshold of 0, which effectively disables protection. \
Use waf.enabled: false to disable the WAF, or set threshold between 1-100",
site.hostname
)));
}
if threshold > 100 {
return Err(ConfigError::ValidationError(format!(
"site '{}' has invalid WAF threshold {} (must be 1-100); \
use waf.enabled: false to disable or set a valid threshold",
site.hostname, threshold
)));
}
}
}
if let Some(rl) = &site.rate_limit {
if rl.rps == 0 && rl.enabled {
warn!(
"Site '{}' has rate limiting enabled with 0 RPS; set rps > 0 or disable rate limiting",
site.hostname
);
}
if rl.rps > 1_000_000 {
return Err(ConfigError::ValidationError(format!(
"site '{}' has extreme RPS limit {} (max 1,000,000)",
site.hostname, rl.rps
)));
}
}
if let Some(shadow) = &site.shadow_mirror {
if let Err(e) = shadow.validate() {
return Err(ConfigError::ValidationError(format!(
"site '{}' has invalid shadow_mirror config: {}. Fix shadow_mirror settings or remove the block",
site.hostname,
e
)));
}
}
}
if config.server.workers > 1024 {
return Err(ConfigError::ValidationError(format!(
"extreme worker count {} (max 1024)",
config.server.workers
)));
}
if config.server.shutdown_timeout_secs > 3600 {
return Err(ConfigError::ValidationError(format!(
"extreme shutdown timeout {}s (max 3600)",
config.server.shutdown_timeout_secs
)));
}
if config.server.waf_regex_timeout_ms > 500 {
return Err(ConfigError::ValidationError(format!(
"extreme WAF regex timeout {}ms (max 500)",
config.server.waf_regex_timeout_ms
)));
}
if !config.server.waf_enabled {
warn!(
"Global WAF protection DISABLED - all sites may be exposed to attacks unless individually configured"
);
}
if config.server.waf_threshold == 0 {
return Err(ConfigError::ValidationError(
"global WAF threshold of 0 effectively disables protection. \
Use waf_enabled: false to disable globally, or set waf_threshold between 1-100"
.to_string(),
));
}
if config.server.waf_threshold > 100 {
return Err(ConfigError::ValidationError(format!(
"global WAF threshold {} is invalid (must be 1-100); set waf_threshold between 1-100",
config.server.waf_threshold
)));
}
Ok(())
}
fn validate_tls(tls: &TlsConfig) -> Result<(), ConfigError> {
if contains_path_traversal(&tls.cert_path) {
return Err(ConfigError::PathTraversal {
path: tls.cert_path.clone(),
});
}
if contains_path_traversal(&tls.key_path) {
return Err(ConfigError::PathTraversal {
path: tls.key_path.clone(),
});
}
if !Path::new(&tls.cert_path).exists() {
return Err(ConfigError::CertNotFound {
path: tls.cert_path.clone(),
});
}
if !Path::new(&tls.key_path).exists() {
return Err(ConfigError::KeyNotFound {
path: tls.key_path.clone(),
});
}
match tls.min_version.as_str() {
"1.2" | "1.3" => {}
_ => {
return Err(ConfigError::InvalidTlsVersion {
version: tls.min_version.clone(),
});
}
}
debug!(
"Validated TLS config: cert={}, key=[REDACTED]",
tls.cert_path
);
Ok(())
}
pub fn to_site_configs(config: &ConfigFile) -> Vec<SiteConfig> {
config
.sites
.iter()
.map(|site| SiteConfig {
hostname: site.hostname.clone(),
upstreams: site
.upstreams
.iter()
.map(|u| format!("{}:{}", u.host, u.port))
.collect(),
tls_enabled: site.tls.is_some(),
tls_cert: site.tls.as_ref().map(|t| t.cert_path.clone()),
tls_key: site.tls.as_ref().map(|t| t.key_path.clone()),
waf_threshold: site.waf.as_ref().and_then(|w| w.threshold),
waf_enabled: site.waf.as_ref().map(|w| w.enabled).unwrap_or(true),
access_control: site.access_control.clone(),
headers: site.headers.as_ref().map(|headers| headers.compile()),
shadow_mirror: site.shadow_mirror.clone(),
})
.collect()
}
}
#[cfg(test)]
mod tests {
    //! Tests for configuration loading, validation, and conversion.
    //!
    //! YAML fixtures sit at column 0 inside raw strings, so the indentation shown
    //! is exactly what the YAML parser sees. It must be nested correctly, or keys
    //! such as `upstreams` escape their site.

    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `content` to a fresh temporary file. The file is deleted when the
    /// returned handle drops, so keep it alive for the whole test.
    fn create_temp_config(content: &str) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(content.as_bytes()).unwrap();
        file
    }

    #[test]
    fn test_load_minimal_config() {
        let yaml = r#"
sites:
  - hostname: example.com
    upstreams:
      - host: 127.0.0.1
        port: 8080
"#;
        let file = create_temp_config(yaml);
        let config = ConfigLoader::load(file.path()).unwrap();
        assert_eq!(config.sites.len(), 1);
        assert_eq!(config.sites[0].hostname, "example.com");
    }

    #[test]
    fn test_load_full_config() {
        let yaml = r#"
server:
  http_addr: "0.0.0.0:8080"
  https_addr: "0.0.0.0:8443"
  workers: 4
  waf_threshold: 80
  log_level: debug
rate_limit:
  rps: 5000
  enabled: true
sites:
  - hostname: example.com
    upstreams:
      - host: 127.0.0.1
        port: 8080
        weight: 2
    waf:
      enabled: true
      threshold: 60
"#;
        let file = create_temp_config(yaml);
        let config = ConfigLoader::load(file.path()).unwrap();
        assert_eq!(config.server.http_addr, "0.0.0.0:8080");
        assert_eq!(config.server.workers, 4);
        assert_eq!(config.rate_limit.rps, 5000);
        assert_eq!(config.sites[0].upstreams[0].weight, 2);
        assert_eq!(config.sites[0].waf.as_ref().unwrap().threshold, Some(60));
    }

    #[test]
    fn test_duplicate_hostname() {
        let yaml = r#"
sites:
  - hostname: example.com
    upstreams:
      - host: 127.0.0.1
        port: 8080
  - hostname: example.com
    upstreams:
      - host: 127.0.0.1
        port: 8081
"#;
        let file = create_temp_config(yaml);
        let result = ConfigLoader::load(file.path());
        assert!(matches!(result, Err(ConfigError::DuplicateHostname { .. })));
    }

    #[test]
    fn test_duplicate_hostname_is_case_insensitive() {
        let yaml = r#"
sites:
  - hostname: Example.com
    upstreams:
      - host: 127.0.0.1
        port: 8080
  - hostname: example.COM
    upstreams:
      - host: 127.0.0.1
        port: 8081
"#;
        let file = create_temp_config(yaml);
        let result = ConfigLoader::load(file.path());
        assert!(matches!(result, Err(ConfigError::DuplicateHostname { .. })));
    }

    #[test]
    fn test_no_upstreams() {
        let yaml = r#"
sites:
  - hostname: example.com
    upstreams: []
"#;
        let file = create_temp_config(yaml);
        let result = ConfigLoader::load(file.path());
        assert!(matches!(result, Err(ConfigError::ValidationError(_))));
    }

    #[test]
    fn test_invalid_waf_threshold() {
        let yaml = r#"
sites:
  - hostname: example.com
    upstreams:
      - host: 127.0.0.1
        port: 8080
    waf:
      threshold: 150
"#;
        let file = create_temp_config(yaml);
        let result = ConfigLoader::load(file.path());
        assert!(matches!(result, Err(ConfigError::ValidationError(_))));
    }

    #[test]
    fn test_global_waf_threshold_zero_rejected() {
        let yaml = r#"
server:
  waf_threshold: 0
sites:
  - hostname: example.com
    upstreams:
      - host: 127.0.0.1
        port: 8080
"#;
        let file = create_temp_config(yaml);
        let result = ConfigLoader::load(file.path());
        assert!(matches!(result, Err(ConfigError::ValidationError(_))));
    }

    #[test]
    fn test_tls_path_traversal_rejected() {
        let yaml = r#"
sites:
  - hostname: example.com
    upstreams:
      - host: 127.0.0.1
        port: 8080
    tls:
      cert_path: "../etc/ssl/cert.pem"
      key_path: "key.pem"
"#;
        let file = create_temp_config(yaml);
        let result = ConfigLoader::load(file.path());
        assert!(matches!(result, Err(ConfigError::PathTraversal { .. })));
    }

    #[test]
    fn test_file_not_found() {
        let result = ConfigLoader::load("/nonexistent/config.yaml");
        assert!(matches!(result, Err(ConfigError::NotFound { .. })));
    }

    #[test]
    fn test_default_values() {
        let config = GlobalConfig::default();
        assert_eq!(config.http_addr, "0.0.0.0:80");
        assert_eq!(config.https_addr, "0.0.0.0:443");
        assert_eq!(config.waf_threshold, 70);
        assert!(config.waf_enabled);
        assert_eq!(config.waf_regex_timeout_ms, 100);
    }

    #[test]
    fn test_debug_redacts_admin_api_key() {
        let mut config = GlobalConfig::default();
        config.admin_api_key = Some("super-secret-key-12345".to_string());
        let debug_output = format!("{:?}", config);
        assert!(!debug_output.contains("super-secret-key-12345"));
        assert!(debug_output.contains("[REDACTED]"));
        assert!(debug_output.contains("0.0.0.0:80"));
    }

    #[test]
    fn test_debug_shows_none_when_no_key() {
        let config = GlobalConfig::default();
        let debug_output = format!("{:?}", config);
        // Check the key field specifically: `trap_config: None` would satisfy a
        // bare `contains("None")` even if the key were rendered incorrectly.
        assert!(debug_output.contains("admin_api_key: None"));
        assert!(!debug_output.contains("[REDACTED]"));
    }

    #[test]
    fn test_waf_regex_timeout_config() {
        let yaml = r#"
server:
  waf_regex_timeout_ms: 200
sites:
  - hostname: example.com
    upstreams:
      - host: 127.0.0.1
        port: 8080
"#;
        let file = create_temp_config(yaml);
        let config = ConfigLoader::load(file.path()).unwrap();
        assert_eq!(config.server.waf_regex_timeout_ms, 200);
    }

    #[test]
    fn test_waf_regex_timeout_default() {
        let yaml = r#"
sites:
  - hostname: example.com
    upstreams:
      - host: 127.0.0.1
        port: 8080
"#;
        let file = create_temp_config(yaml);
        let config = ConfigLoader::load(file.path()).unwrap();
        assert_eq!(config.server.waf_regex_timeout_ms, 100);
    }

    #[test]
    fn test_to_site_configs() {
        let yaml = r#"
sites:
  - hostname: example.com
    upstreams:
      - host: 127.0.0.1
        port: 8080
    waf:
      enabled: true
      threshold: 80
"#;
        let file = create_temp_config(yaml);
        let config = ConfigLoader::load(file.path()).unwrap();
        let sites = ConfigLoader::to_site_configs(&config);
        assert_eq!(sites.len(), 1);
        assert_eq!(sites[0].hostname, "example.com");
        assert_eq!(sites[0].waf_threshold, Some(80));
        assert!(sites[0].waf_enabled);
    }

    #[test]
    fn test_yaml_with_unknown_fields_passes() {
        let yaml = r#"
server:
  http_addr: "0.0.0.0:9090"
  unknown_field: "should be ignored"
  another_mystery: 42
sites:
  - hostname: example.com
    upstreams:
      - host: 127.0.0.1
        port: 8080
    extra_site_field: true
"#;
        let file = create_temp_config(yaml);
        let config = ConfigLoader::load(file.path()).unwrap();
        assert_eq!(config.server.http_addr, "0.0.0.0:9090");
        assert_eq!(config.sites.len(), 1);
        assert_eq!(config.sites[0].hostname, "example.com");
    }

    #[test]
    fn test_yaml_with_unknown_top_level_field_passes() {
        let yaml = r#"
some_future_feature:
  enabled: true
sites:
  - hostname: example.com
    upstreams:
      - host: 127.0.0.1
        port: 8080
"#;
        let file = create_temp_config(yaml);
        let config = ConfigLoader::load(file.path()).unwrap();
        assert_eq!(config.sites.len(), 1);
    }

    #[test]
    fn test_path_traversal_detection() {
        assert!(contains_path_traversal(".."));
        assert!(contains_path_traversal("../etc/passwd"));
        assert!(contains_path_traversal("/path/../secret"));
        assert!(contains_path_traversal("path/to/../../root"));
        assert!(contains_path_traversal("%2e%2e"));
        assert!(contains_path_traversal("%2E%2E"));
        assert!(contains_path_traversal("%2e."));
        assert!(contains_path_traversal(".%2e"));
        assert!(contains_path_traversal("%252e%252e"));
        assert!(contains_path_traversal("path/%252e%252e/file"));
        assert!(contains_path_traversal("\x00"));
        assert!(contains_path_traversal("path\x00/file"));
        assert!(contains_path_traversal("%00"));
        assert!(contains_path_traversal("path/%00/file"));
        assert!(!contains_path_traversal("/path/to/file"));
        assert!(!contains_path_traversal("certs/server.pem"));
        assert!(!contains_path_traversal("/etc/nginx/ssl/cert.pem"));
        assert!(!contains_path_traversal("./relative/path"));
    }
}