use anyhow::{Context, Result};
use serde::Deserialize;
use std::collections::BTreeMap;
use std::env;
use std::time::Duration;
/// Complete gateway configuration.
///
/// Built by [`Config::load`], which parses a TOML file into this type (or uses
/// `Config::default()` when no path is given) and then applies environment
/// overrides. Because of `#[serde(default)]`, any section omitted from the file
/// takes that section type's `Default` value.
///
/// NOTE(review): the derived `Debug` prints credentials held in `auth` (JWT
/// secret, API keys, users map) and `ratelimit.redis_url` verbatim; avoid logging
/// this struct with `{:?}`.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub struct Config {
/// `[server]` table: ports, upstream target and admin listener settings.
pub server: ServerCfg,
/// `[auth]` table: authentication mode and credentials.
pub auth: AuthCfg,
/// `[ratelimit]` table: global, per-route and per-key rate limits.
pub ratelimit: RateLimitCfg,
/// `[validation]` table: size, timeout and method limits.
pub validation: ValidationCfg,
/// `[headers]` table: response header policy (HSTS, CSP, ...).
pub headers: HeadersCfg,
/// `[tls]` table: certificate files and ACME settings.
pub tls: TlsCfg,
/// `[waf]` table: request inspection switches and custom rules.
pub waf: WafCfg,
}
/// `[server]` table: gateway ports and upstream target.
///
/// Defaults come from the hand-written `Default` impl in this module.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct ServerCfg {
/// Port of the gateway itself (default 8080), presumably its public listener.
/// Overridden by the `PORT` environment variable in `Config::load`.
pub port: u16,
/// Port of the local application (default 3000). When `upstream` is empty the
/// upstream becomes `http://127.0.0.1:{app_port}`. Overridden by `APP_PORT`.
pub app_port: u16,
/// Explicit upstream base URL (empty by default, meaning "use `app_port`").
/// Overridden by a non-empty `UPSTREAM` environment variable.
pub upstream: String,
/// Default `false`. Presumably controls whether a forwarded-for header is trusted
/// for client-IP detection — confirm at the use site.
pub trust_forwarded_for: bool,
/// Admin listener port (default 0). Overridden by `ADMIN_PORT`.
/// NOTE(review): 0 likely means "admin listener disabled" — confirm.
pub admin_port: u16,
/// Admin address (default `127.0.0.1`), presumably the admin bind address.
pub admin_addr: String,
}
/// `[auth]` table: authentication mode and credential sources.
///
/// NOTE(review): the derived `Debug` prints `users`, `api_keys` and the nested
/// JWT `secret` verbatim — do not log this struct with `{:?}`.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct AuthCfg {
/// Authentication mode (default `"none"`). The accepted values are decided by
/// the consumer of this config, not in this module.
pub mode: String,
/// Realm name (default `"EdgeGuard"`), presumably for `WWW-Authenticate`.
pub realm: String,
/// User name -> credential map. The credential format (plaintext vs hash) is
/// not visible here — TODO confirm.
pub users: BTreeMap<String, String>,
/// Accepted API keys. Replaced by a non-empty, comma-separated
/// `EDGEGUARD_API_KEYS` environment variable in `Config::load`.
pub api_keys: Vec<String>,
/// Name of the header expected to carry the API key (default `"X-API-Key"`).
pub api_key_header: String,
/// JWT verification settings (`[auth.jwt]`).
pub jwt: JwtCfg,
}
/// `[auth.jwt]` table: JWT verification parameters.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct JwtCfg {
/// Signing algorithm name (default `"HS256"`).
pub algorithm: String,
/// Shared secret, presumably for HMAC algorithms (empty by default). Overridden
/// by a non-empty `EDGEGUARD_JWT_SECRET`. Printed by the derived `Debug`.
pub secret: String,
/// PEM-encoded public key, presumably for asymmetric algorithms (empty by default).
pub public_key_pem: String,
/// JWKS endpoint URL (empty by default).
pub jwks_url: String,
/// JWKS cache lifetime (default 300); seconds, judging by the name.
pub jwks_cache_secs: u64,
/// Expected issuer (empty by default; presumably "not checked" when empty).
pub issuer: String,
/// Expected audience (empty by default; presumably "not checked" when empty).
pub audience: String,
/// Clock-skew tolerance (default 60); seconds, judging by the name.
pub leeway_secs: u64,
}
/// `[ratelimit]` table: global, per-route and per-key rate limits.
///
/// Rate strings use the `N/unit` form accepted by `parse_rate` in this module
/// (e.g. `"60/min"`); presumably they are parsed with it — confirm at the use site.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct RateLimitCfg {
/// Master switch (default `true`).
pub enabled: bool,
/// Default rate (default `"60/min"`).
pub rate: String,
/// Burst allowance (default 20).
pub burst: u32,
/// Per-path overrides (`[[ratelimit.routes]]`, empty by default).
pub routes: Vec<RouteRateLimit>,
/// Per-key limits (`[ratelimit.per_key]`, disabled by default).
pub per_key: PerKeyRateLimit,
/// Counter store name (default `"local"`). `redis_url`/`redis_prefix` presumably
/// apply only to a Redis-backed store.
pub store: String,
/// Redis URL (default `redis://127.0.0.1:6379`). Overridden by a non-empty
/// `EDGEGUARD_REDIS_URL`. May embed a password; printed by the derived `Debug`.
pub redis_url: String,
/// Key prefix for Redis entries (default `"edgeguard"`).
pub redis_prefix: String,
/// Default `false`. Presumably whether requests are allowed through when the
/// store is unavailable — confirm at the use site.
pub fail_open: bool,
}
/// One `[[ratelimit.routes]]` entry: a rate limit associated with `path`.
///
/// How `path` is matched (exact vs prefix) is decided by the consumer — TODO confirm.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct RouteRateLimit {
/// Request path this limit applies to (empty by default).
pub path: String,
/// Rate in `N/unit` form (default `"60/min"`).
pub rate: String,
/// Burst allowance (default 20).
pub burst: u32,
}
impl Default for RouteRateLimit {
fn default() -> Self {
RouteRateLimit {
path: String::new(),
rate: "60/min".into(),
burst: 20,
}
}
}
/// `[ratelimit.per_key]` table: an optional limit, presumably keyed by API key.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct PerKeyRateLimit {
/// Whether this limit is active (default `false`).
pub enabled: bool,
/// Rate in `N/unit` form (default `"1000/hour"`).
pub rate: String,
/// Burst allowance (default 100).
pub burst: u32,
}
impl Default for PerKeyRateLimit {
fn default() -> Self {
PerKeyRateLimit {
enabled: false,
rate: "1000/hour".into(),
burst: 100,
}
}
}
/// `[validation]` table: request/response limits.
///
/// Size fields use the format accepted by `parse_size` (`"2MiB"`, `"512B"`, plain
/// bytes) and `upstream_timeout` the format accepted by `parse_duration` (`"30s"`,
/// `"500ms"`). Presumably they are parsed with those helpers — confirm at the use site.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct ValidationCfg {
/// Maximum request body size (default `"2MiB"`).
pub max_body: String,
/// Maximum response body size (default `"0"`; presumably 0 = unlimited).
pub max_response_body: String,
/// Upstream request timeout (default `"30s"`).
pub upstream_timeout: String,
/// Maximum header size (default `"0"`; presumably 0 = unlimited).
pub max_header_bytes: String,
/// Allowed HTTP methods (empty by default; presumably empty = allow all).
pub allow_methods: Vec<String>,
}
/// `[headers]` table: response header policy.
///
/// The field-to-header mapping below is inferred from field names and default
/// values — confirm against the header middleware.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct HeadersCfg {
/// Emit HSTS (`Strict-Transport-Security`) (default `true`).
pub hsts: bool,
/// Content-Security-Policy value (default `default-src 'self'`).
pub csp: String,
/// Send the CSP in report-only mode (default `false`).
pub csp_report_only: bool,
/// CSP report URI (empty by default).
pub csp_report_uri: String,
/// `Referrer-Policy` value (default `no-referrer`).
pub referrer_policy: String,
/// `Permissions-Policy` value (default denies geolocation, microphone and camera).
pub permissions_policy: String,
/// `X-Frame-Options` value (default `DENY`).
pub frame_options: String,
/// Default `true`; presumably forces the `Secure` attribute on `Set-Cookie`.
pub force_secure_cookies: bool,
/// Header names to strip (default `Server`, `X-Powered-By`), presumably from
/// upstream responses.
pub strip: Vec<String>,
}
impl Default for ServerCfg {
fn default() -> Self {
ServerCfg {
port: 8080,
app_port: 3000,
upstream: String::new(),
trust_forwarded_for: false,
admin_port: 0,
admin_addr: "127.0.0.1".into(),
}
}
}
impl Default for AuthCfg {
fn default() -> Self {
AuthCfg {
mode: "none".into(),
realm: "EdgeGuard".into(),
users: BTreeMap::new(),
api_keys: vec![],
api_key_header: "X-API-Key".into(),
jwt: JwtCfg::default(),
}
}
}
impl Default for JwtCfg {
fn default() -> Self {
JwtCfg {
algorithm: "HS256".into(),
secret: String::new(),
public_key_pem: String::new(),
jwks_url: String::new(),
jwks_cache_secs: 300,
issuer: String::new(),
audience: String::new(),
leeway_secs: 60,
}
}
}
impl Default for RateLimitCfg {
fn default() -> Self {
RateLimitCfg {
enabled: true,
rate: "60/min".into(),
burst: 20,
routes: vec![],
per_key: PerKeyRateLimit::default(),
store: "local".into(),
redis_url: "redis://127.0.0.1:6379".into(),
redis_prefix: "edgeguard".into(),
fail_open: false,
}
}
}
impl Default for ValidationCfg {
fn default() -> Self {
ValidationCfg {
max_body: "2MiB".into(),
max_response_body: "0".into(),
upstream_timeout: "30s".into(),
max_header_bytes: "0".into(),
allow_methods: vec![],
}
}
}
impl Default for HeadersCfg {
fn default() -> Self {
HeadersCfg {
hsts: true,
csp: "default-src 'self'".into(),
csp_report_only: false,
csp_report_uri: String::new(),
referrer_policy: "no-referrer".into(),
permissions_policy: "geolocation=(), microphone=(), camera=()".into(),
frame_options: "DENY".into(),
force_secure_cookies: true,
strip: vec!["Server".into(), "X-Powered-By".into()],
}
}
}
/// `[tls]` table: static certificate files and ACME settings.
///
/// Defaults come from the derived `Default`: TLS disabled, empty paths, and
/// `AcmeCfg::default()` for `acme`.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub struct TlsCfg {
/// Whether TLS is enabled (default `false`).
pub enabled: bool,
/// Certificate file path (presumably PEM).
pub cert_path: String,
/// Private key file path (presumably PEM).
pub key_path: String,
/// ACME certificate settings (`[tls.acme]`).
pub acme: AcmeCfg,
}
/// `[tls.acme]` table: ACME certificate provisioning.
///
/// NOTE(review): `directory_url` defaults to the Let's Encrypt *staging*
/// directory, whose certificates are not publicly trusted; real deployments must
/// set the production directory explicitly.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct AcmeCfg {
/// Whether ACME is enabled (default `false`).
pub enabled: bool,
/// Domain names to obtain certificates for (empty by default).
pub domains: Vec<String>,
/// Contact e-mail for the ACME account (empty by default).
pub email: String,
/// ACME directory URL (default: Let's Encrypt staging).
pub directory_url: String,
/// Local cache directory (default `./acme`), presumably for account and certificate state.
pub cache_dir: String,
/// Acceptance of the CA's terms of service (default `false`).
pub accept_tos: bool,
}
impl Default for AcmeCfg {
fn default() -> Self {
AcmeCfg {
enabled: false,
domains: vec![],
email: String::new(),
directory_url: "https://acme-staging-v02.api.letsencrypt.org/directory".into(),
cache_dir: "./acme".into(),
accept_tos: false,
}
}
}
/// `[waf]` table: request inspection switches and custom rules.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct WafCfg {
/// WAF mode (default `"off"`). Accepted values are decided by the consumer.
pub mode: String,
/// Built-in SQL-injection checks (default `true`); presumably inert while `mode` is `"off"`.
pub sqli: bool,
/// Built-in XSS checks (default `true`).
pub xss: bool,
/// Built-in path-traversal checks (default `true`).
pub path_traversal: bool,
/// Inspect the request path (default `true`).
pub inspect_path: bool,
/// Inspect request headers (default `false`).
pub inspect_headers: bool,
/// Inspect request bodies (default `false`).
pub inspect_body: bool,
/// Custom rules (`[[waf.rules]]`, empty by default).
pub rules: Vec<WafRule>,
}
/// One custom `[[waf.rules]]` entry.
#[derive(Debug, Clone, Deserialize)]
#[serde(default)]
pub struct WafRule {
/// Rule identifier (empty by default), presumably reported when the rule matches.
pub id: String,
/// Match pattern (empty by default) — presumably a regex; confirm where it is compiled.
pub pattern: String,
/// What the pattern is applied to (default `"path"`).
pub target: String,
}
impl Default for WafCfg {
fn default() -> Self {
WafCfg {
mode: "off".into(),
sqli: true,
xss: true,
path_traversal: true,
inspect_path: true,
inspect_headers: false,
inspect_body: false,
rules: vec![],
}
}
}
impl Default for WafRule {
fn default() -> Self {
WafRule {
id: String::new(),
pattern: String::new(),
target: "path".into(),
}
}
}
impl Config {
pub fn load(path: Option<&str>) -> Result<Config> {
let mut cfg = if let Some(p) = path {
let raw =
std::fs::read_to_string(p).with_context(|| format!("reading config file {p}"))?;
toml::from_str::<Config>(&raw).with_context(|| format!("parsing config file {p}"))?
} else {
Config::default()
};
if let Ok(p) = env::var("PORT") {
if let Ok(v) = p.parse() {
cfg.server.port = v;
}
}
if let Ok(p) = env::var("APP_PORT") {
if let Ok(v) = p.parse() {
cfg.server.app_port = v;
}
}
if let Ok(p) = env::var("ADMIN_PORT") {
if let Ok(v) = p.parse() {
cfg.server.admin_port = v;
}
}
if let Ok(u) = env::var("UPSTREAM") {
if !u.is_empty() {
cfg.server.upstream = u;
}
}
if let Ok(s) = env::var("EDGEGUARD_JWT_SECRET") {
if !s.is_empty() {
cfg.auth.jwt.secret = s;
}
}
if let Ok(u) = env::var("EDGEGUARD_REDIS_URL") {
if !u.is_empty() {
cfg.ratelimit.redis_url = u;
}
}
if let Ok(keys) = env::var("EDGEGUARD_API_KEYS") {
let keys: Vec<String> = keys
.split(',')
.map(|k| k.trim().to_string())
.filter(|k| !k.is_empty())
.collect();
if !keys.is_empty() {
cfg.auth.api_keys = keys;
}
}
Ok(cfg)
}
pub fn upstream_base(&self) -> String {
if self.server.upstream.is_empty() {
format!("http://127.0.0.1:{}", self.server.app_port)
} else {
self.server.upstream.trim_end_matches('/').to_string()
}
}
pub fn upstream_probe_addr(&self) -> Option<(String, u16)> {
if self.server.upstream.is_empty() {
Some(("127.0.0.1".to_string(), self.server.app_port))
} else {
parse_host_port(&self.server.upstream)
}
}
}
/// Extracts the `(host, port)` pair from an upstream URL such as
/// `http://app.internal:8080/health`.
///
/// * The scheme is matched case-insensitively. `https://` defaults to port 443;
///   `http://` or no scheme defaults to port 80.
/// * The authority ends at the first `/`, `?` or `#`, and any `userinfo@` prefix
///   is discarded.
/// * Bracketed IPv6 literals (`[::1]:3000`) are returned without brackets.
///
/// Returns `None` in these cases:
/// * the host is empty;
/// * the port is missing after `:` or is not a valid `u16`;
/// * a bracketed literal is followed by anything other than `:port`;
/// * the host contains an unbracketed `:`, such as a bare IPv6 address or an
///   unsupported `scheme://`.
fn parse_host_port(url: &str) -> Option<(String, u16)> {
    /// Strips `prefix` from the start of `url`, comparing ASCII case-insensitively.
    fn strip_scheme<'a>(url: &'a str, prefix: &str) -> Option<&'a str> {
        let head = url.get(..prefix.len())?;
        head.eq_ignore_ascii_case(prefix)
            .then(|| &url[prefix.len()..])
    }

    let (default_port, rest) = if let Some(r) = strip_scheme(url, "http://") {
        (80u16, r)
    } else if let Some(r) = strip_scheme(url, "https://") {
        (443u16, r)
    } else {
        (80u16, url)
    };

    // Per RFC 3986 the authority is terminated by the path, query or fragment.
    let authority = rest.split(|c| matches!(c, '/' | '?' | '#')).next()?;
    // Drop any `user:pass@` prefix; the host is whatever follows the last '@'.
    let authority = authority.rsplit('@').next()?;
    if authority.is_empty() {
        return None;
    }

    if let Some(after) = authority.strip_prefix('[') {
        let (host, tail) = after.split_once(']')?;
        if host.is_empty() {
            return None;
        }
        let port = if tail.is_empty() {
            default_port
        } else {
            // Anything after ']' must be ":port"; trailing garbage is rejected.
            tail.strip_prefix(':')?.parse().ok()?
        };
        return Some((host.to_string(), port));
    }

    match authority.rsplit_once(':') {
        // A further ':' in the host means an unbracketed IPv6 literal, which is
        // ambiguous (is the last group a port?), so refuse to guess.
        Some((host, port)) if !host.is_empty() && !host.contains(':') => {
            Some((host.to_string(), port.parse().ok()?))
        }
        Some(_) => None,
        None => Some((authority.to_string(), default_port)),
    }
}
/// Parses a human-readable byte size into a byte count.
///
/// The input may be plain bytes (`"1048576"`) or carry a case-sensitive suffix:
/// `B`, SI `KB`/`MB`/`GB` (powers of 1000), or binary `KiB`/`MiB`/`GiB` (powers
/// of 1024). Whitespace around the whole value and between number and unit is
/// ignored, so `" 4 MiB "` is accepted.
///
/// # Errors
///
/// Fails if the numeric part is not a non-negative integer, or if the scaled
/// value does not fit in `usize`.
pub fn parse_size(s: &str) -> Result<usize> {
    // Binary suffixes must be tried before the SI ones, and all of them before the
    // bare "B", because each longer suffix also ends with the shorter one.
    const UNITS: [(&str, usize); 7] = [
        ("GiB", 1 << 30),
        ("MiB", 1 << 20),
        ("KiB", 1 << 10),
        ("GB", 1_000_000_000),
        ("MB", 1_000_000),
        ("KB", 1_000),
        ("B", 1),
    ];

    let s = s.trim();
    let (digits, factor) = UNITS
        .iter()
        .find_map(|&(suffix, factor)| s.strip_suffix(suffix).map(|rest| (rest, factor)))
        .unwrap_or((s, 1));
    let value: usize = digits
        .trim()
        .parse()
        .with_context(|| format!("invalid size: {s}"))?;
    value
        .checked_mul(factor)
        .with_context(|| format!("size too large: {s}"))
}
/// Parses a rate specification of the form `N/unit` into `(count, period)`.
///
/// `N` is a `u32`. `unit` is one of `s`/`sec`/`second`, `m`/`min`/`minute`, or
/// `h`/`hour`. Whitespace around either part is ignored, so `" 5 / m "` means
/// five per minute.
///
/// # Errors
///
/// Fails in any of these cases, checked in order: the `/` separator is missing,
/// the count is not a valid `u32`, or the unit is not recognised.
pub fn parse_rate(s: &str) -> Result<(u32, Duration)> {
    let (count_text, unit_text) = s
        .split_once('/')
        .with_context(|| format!("invalid rate (expected N/unit): {s}"))?;
    let count = count_text
        .trim()
        .parse::<u32>()
        .with_context(|| format!("invalid rate count: {s}"))?;
    let secs_per_period: u64 = match unit_text.trim() {
        "s" | "sec" | "second" => 1,
        "m" | "min" | "minute" => 60,
        "h" | "hour" => 3600,
        other => anyhow::bail!("unsupported rate unit: {other}"),
    };
    Ok((count, Duration::from_secs(secs_per_period)))
}
/// Parses a human-readable duration.
///
/// Accepted forms are `"500ms"` (milliseconds), `"30s"` (seconds), `"2m"`
/// (minutes), `"1h"` (hours) and a bare integer such as `"45"`, which means
/// seconds. Whitespace around the value and between number and unit is ignored.
/// Hours were added for consistency with `parse_rate`; inputs that parsed before
/// are unaffected.
///
/// # Errors
///
/// Fails if the numeric part is not a non-negative `u64`, or if converting
/// minutes/hours to seconds overflows `u64`.
pub fn parse_duration(s: &str) -> Result<Duration> {
    let s = s.trim();
    let parse_count = |digits: &str| -> Result<u64> {
        digits
            .trim()
            .parse::<u64>()
            .with_context(|| format!("invalid duration: {s}"))
    };

    // "ms" must be checked before the single-letter 's' and 'm' suffixes.
    if let Some(millis) = s.strip_suffix("ms") {
        return Ok(Duration::from_millis(parse_count(millis)?));
    }

    let (digits, secs_per_unit): (&str, u64) = if let Some(n) = s.strip_suffix('s') {
        (n, 1)
    } else if let Some(n) = s.strip_suffix('m') {
        (n, 60)
    } else if let Some(n) = s.strip_suffix('h') {
        (n, 3_600)
    } else {
        (s, 1)
    };
    let secs = parse_count(digits)?
        .checked_mul(secs_per_unit)
        .with_context(|| format!("duration too large: {s}"))?;
    Ok(Duration::from_secs(secs))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_size_units_and_plain_bytes() {
        assert_eq!(parse_size("0").unwrap(), 0);
        assert_eq!(parse_size("1048576").unwrap(), 1_048_576);
        assert_eq!(parse_size("512B").unwrap(), 512);
        assert_eq!(parse_size("1KB").unwrap(), 1_000);
        assert_eq!(parse_size("1KiB").unwrap(), 1_024);
        assert_eq!(parse_size("2MiB").unwrap(), 2 * 1024 * 1024);
        assert_eq!(parse_size("16MiB").unwrap(), 16 * 1024 * 1024);
        assert_eq!(parse_size("1GiB").unwrap(), 1024 * 1024 * 1024);
        assert_eq!(parse_size(" 4 MiB ").unwrap(), 4 * 1024 * 1024);
    }

    #[test]
    fn parse_size_rejects_garbage_and_overflow() {
        assert!(parse_size("abc").is_err());
        assert!(parse_size("MiB").is_err());
        assert!(parse_size("99999999999999999999GiB").is_err());
    }

    #[test]
    fn parse_rate_counts_and_units() {
        assert_eq!(parse_rate("60/min").unwrap(), (60, Duration::from_secs(60)));
        assert_eq!(parse_rate("10/sec").unwrap(), (10, Duration::from_secs(1)));
        assert_eq!(
            parse_rate("1000/hour").unwrap(),
            (1000, Duration::from_secs(3600))
        );
        assert_eq!(parse_rate(" 5 / m ").unwrap(), (5, Duration::from_secs(60)));
    }

    #[test]
    fn parse_rate_rejects_garbage() {
        assert!(parse_rate("60").is_err());
        assert!(parse_rate("x/min").is_err());
        assert!(parse_rate("60/year").is_err());
    }

    #[test]
    fn parse_rate_rejects_count_overflow() {
        // One more than u32::MAX.
        assert!(parse_rate("4294967296/min").is_err());
    }

    #[test]
    fn probe_addr_defaults_to_app_port_in_coprocess_mode() {
        let cfg = Config::default();
        assert_eq!(
            cfg.upstream_probe_addr(),
            Some(("127.0.0.1".to_string(), cfg.server.app_port))
        );
    }

    #[test]
    fn upstream_base_defaults_and_trims_trailing_slash() {
        let mut cfg = Config::default();
        assert_eq!(
            cfg.upstream_base(),
            format!("http://127.0.0.1:{}", cfg.server.app_port)
        );
        cfg.server.upstream = "http://app.internal:8080/".to_string();
        assert_eq!(cfg.upstream_base(), "http://app.internal:8080");
    }

    #[test]
    fn parse_host_port_handles_schemes_paths_and_ipv6() {
        assert_eq!(
            parse_host_port("http://127.0.0.1:3000"),
            Some(("127.0.0.1".to_string(), 3000))
        );
        assert_eq!(
            parse_host_port("http://app.internal:8080/health"),
            Some(("app.internal".to_string(), 8080))
        );
        assert_eq!(
            parse_host_port("https://example.com"),
            Some(("example.com".to_string(), 443))
        );
        assert_eq!(
            parse_host_port("http://example.com"),
            Some(("example.com".to_string(), 80))
        );
        assert_eq!(
            parse_host_port("http://[::1]:3000"),
            Some(("::1".to_string(), 3000))
        );
        assert_eq!(
            parse_host_port("http://[2001:db8::1]"),
            Some(("2001:db8::1".to_string(), 80))
        );
    }

    #[test]
    fn parse_host_port_strips_userinfo() {
        assert_eq!(
            parse_host_port("http://user:pw@app:81/x"),
            Some(("app".to_string(), 81))
        );
    }

    #[test]
    fn parse_host_port_rejects_empty_or_unusable_host() {
        assert_eq!(parse_host_port("http://:3000"), None);
        assert_eq!(parse_host_port("http://host:notaport"), None);
    }

    #[test]
    fn parse_duration_units_and_bare_seconds() {
        assert_eq!(parse_duration("30s").unwrap(), Duration::from_secs(30));
        assert_eq!(parse_duration("500ms").unwrap(), Duration::from_millis(500));
        assert_eq!(parse_duration("2m").unwrap(), Duration::from_secs(120));
        assert_eq!(parse_duration("45").unwrap(), Duration::from_secs(45));
        assert_eq!(parse_duration("0").unwrap(), Duration::ZERO);
        assert_eq!(parse_duration(" 10s ").unwrap(), Duration::from_secs(10));
    }

    #[test]
    fn parse_duration_rejects_garbage() {
        assert!(parse_duration("abc").is_err());
        assert!(parse_duration("10x").is_err());
        assert!(parse_duration("s").is_err());
    }

    #[test]
    fn parse_duration_rejects_minute_overflow() {
        // u64::MAX minutes cannot be expressed in seconds as a u64.
        assert!(parse_duration(&format!("{}m", u64::MAX)).is_err());
    }
}