use std::net::SocketAddr;
use std::path::Path;
use std::time::Duration;
use regex::Regex;
use serde::{Deserialize, Serialize};
use crate::{ProxyError, Result};
// ---- Request handling ----

/// Fallback for `Config::max_body_size` when it is unset: 10 MiB
/// (`10 << 20` == `10 * 1024 * 1024`; presumably bytes).
pub const DEFAULT_MAX_BODY_SIZE: u64 = 10 << 20;
/// Fallback for `Config::max_concurrent_requests` when it is unset.
pub const DEFAULT_MAX_CONCURRENT_REQUESTS: usize = 1000;
/// Fallback for `Config::listen` when it is unset; parsed as a `SocketAddr`.
pub const DEFAULT_LISTEN_ADDR: &str = "127.0.0.1:8100";
/// Fallback for `Config::shutdown_timeout` when it is unset.
pub const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);

// ---- Upstream connections ----

/// Default for `TimeoutsConfig::connect` (stored there as whole seconds).
pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Default for `TimeoutsConfig::request` (stored there as whole seconds).
pub const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// Default for both `PoolConfig::idle_timeout` and `TimeoutsConfig::idle`
/// (stored there as whole seconds).
pub const DEFAULT_POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(60);
/// Default for `PoolConfig::max_idle_per_host`.
pub const DEFAULT_POOL_MAX_IDLE_PER_HOST: usize = 32;
/// Default for `UpstreamConfig::weight`.
pub const DEFAULT_UPSTREAM_WEIGHT: u32 = 1;

// ---- Health checks ----

/// Default for `HealthCheckConfig::unhealthy_threshold`; also used as the
/// runtime failure threshold when no health check is configured.
pub const DEFAULT_FAILURE_THRESHOLD: u32 = 3;
/// Default for `HealthCheckConfig::healthy_threshold`; also used at runtime
/// when no health check is configured.
pub const DEFAULT_HEALTHY_THRESHOLD: u32 = 1;
/// Default for `HealthCheckConfig::cooldown`; also used at runtime when no
/// health check is configured.
pub const DEFAULT_HEALTH_CHECK_COOLDOWN: Duration = Duration::from_secs(30);
/// Default for `HealthCheckConfig::interval` (stored there as whole seconds).
pub const DEFAULT_HEALTH_CHECK_INTERVAL: Duration = Duration::from_secs(10);
/// Default for `HealthCheckConfig::path`.
pub const DEFAULT_HEALTH_CHECK_PATH: &str = "/health";
/// Default for `HealthCheckConfig::timeout` (stored there as whole seconds).
pub const DEFAULT_HEALTH_CHECK_TIMEOUT: Duration = Duration::from_secs(3);

// ---- Rate limiting ----

/// Default for `RateLimitConfig::requests_per_second`.
pub const DEFAULT_RATE_LIMIT_RPS: u32 = 100;
/// Default for `RateLimitConfig::burst`.
pub const DEFAULT_RATE_LIMIT_BURST: u32 = 50;
/// Proxy configuration as deserialized from the YAML file read by
/// [`Config::load_from_file`].
///
/// Every key is optional in the file. The container-level `#[serde(default)]`
/// fills each missing key from [`Config::default()`], which is derived: `None`
/// for the `Option` fields, an empty list for the `Vec` fields, and the
/// [`TimeoutsConfig`] / [`PoolConfig`] defaults for `timeouts` / `pool`.
/// Call [`Config::into_runtime`] to validate it and resolve the remaining
/// defaults.
#[derive(Debug, Default, Serialize, Deserialize, PartialEq)]
#[serde(default)]
pub struct Config {
    /// Listen address, parsed as a `SocketAddr`; `None` means
    /// [`DEFAULT_LISTEN_ADDR`].
    pub listen: Option<String>,
    /// Upstream entries; `into_runtime` requires at least one.
    pub upstreams: Vec<UpstreamConfig>,
    /// Header names to block; ASCII-lowercased by `into_runtime`.
    pub blocked_headers: Vec<String>,
    /// Parameter names to block; copied unchanged into the runtime config.
    pub blocked_params: Vec<String>,
    /// Parameter names whose values `RuntimeConfig::mask_sensitive_data`
    /// replaces with `****`.
    pub masked_params: Vec<String>,
    /// Body size limit; `None` means [`DEFAULT_MAX_BODY_SIZE`].
    pub max_body_size: Option<u64>,
    /// Response header names to strip; ASCII-lowercased by `into_runtime`.
    pub strip_response_headers: Vec<String>,
    /// Concurrency limit; `None` means [`DEFAULT_MAX_CONCURRENT_REQUESTS`].
    pub max_concurrent_requests: Option<usize>,
    /// Timeouts, in whole seconds.
    pub timeouts: TimeoutsConfig,
    /// Connection-pool settings.
    pub pool: PoolConfig,
    /// TLS certificate/key paths; copied unchanged into the runtime config.
    pub tls: Option<TlsConfig>,
    /// Health-check settings; when `None`, the runtime uses the
    /// threshold/cooldown defaults.
    pub health_check: Option<HealthCheckConfig>,
    /// Rate-limit settings; copied unchanged into the runtime config.
    pub rate_limit: Option<RateLimitConfig>,
    /// Shutdown timeout in whole seconds; `None` means
    /// [`DEFAULT_SHUTDOWN_TIMEOUT`].
    pub shutdown_timeout: Option<u64>,
}
/// Timeout settings as written in the config file, in whole seconds.
///
/// A missing key falls back to the matching `DEFAULT_*` constant, truncated
/// to whole seconds.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TimeoutsConfig {
    /// Connect timeout in seconds; becomes `RuntimeConfig::connect_timeout`.
    #[serde(default = "default_connect_timeout_secs")]
    pub connect: u64,
    /// Request timeout in seconds; becomes `RuntimeConfig::request_timeout`.
    #[serde(default = "default_request_timeout_secs")]
    pub request: u64,
    /// Idle timeout in seconds.
    ///
    /// NOTE(review): `Config::into_runtime` does not read this field; the
    /// runtime pool idle timeout comes from `PoolConfig::idle_timeout`.
    /// Confirm whether this setting is meant to be used.
    #[serde(default = "default_idle_timeout_secs")]
    pub idle: u64,
}

impl Default for TimeoutsConfig {
    fn default() -> Self {
        Self {
            connect: DEFAULT_CONNECT_TIMEOUT.as_secs(),
            request: DEFAULT_REQUEST_TIMEOUT.as_secs(),
            idle: DEFAULT_POOL_IDLE_TIMEOUT.as_secs(),
        }
    }
}

/// Serde default for [`TimeoutsConfig::connect`].
const fn default_connect_timeout_secs() -> u64 {
    DEFAULT_CONNECT_TIMEOUT.as_secs()
}

/// Serde default for [`TimeoutsConfig::request`].
const fn default_request_timeout_secs() -> u64 {
    DEFAULT_REQUEST_TIMEOUT.as_secs()
}

/// Serde default for [`TimeoutsConfig::idle`].
const fn default_idle_timeout_secs() -> u64 {
    DEFAULT_POOL_IDLE_TIMEOUT.as_secs()
}
/// Connection-pool settings as written in the config file.
///
/// Missing keys fall back to [`DEFAULT_POOL_IDLE_TIMEOUT`] (as whole seconds)
/// and [`DEFAULT_POOL_MAX_IDLE_PER_HOST`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct PoolConfig {
    /// Idle timeout in seconds; becomes `RuntimeConfig::pool_idle_timeout`.
    #[serde(default = "default_pool_idle_timeout_secs")]
    pub idle_timeout: u64,
    /// Becomes `RuntimeConfig::pool_max_idle_per_host`.
    #[serde(default = "default_pool_max_idle_per_host")]
    pub max_idle_per_host: usize,
}

impl Default for PoolConfig {
    fn default() -> Self {
        Self {
            idle_timeout: DEFAULT_POOL_IDLE_TIMEOUT.as_secs(),
            max_idle_per_host: DEFAULT_POOL_MAX_IDLE_PER_HOST,
        }
    }
}

/// Serde default for [`PoolConfig::idle_timeout`].
const fn default_pool_idle_timeout_secs() -> u64 {
    DEFAULT_POOL_IDLE_TIMEOUT.as_secs()
}

/// Serde default for [`PoolConfig::max_idle_per_host`].
const fn default_pool_max_idle_per_host() -> usize {
    DEFAULT_POOL_MAX_IDLE_PER_HOST
}
/// One upstream entry as written in the config file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct UpstreamConfig {
    /// Upstream URI string; parsed and checked by `validate_upstream`.
    pub address: String,
    /// Weight; [`DEFAULT_UPSTREAM_WEIGHT`] when omitted. Zero is rejected by
    /// `validate_upstream`.
    #[serde(default = "default_weight")]
    pub weight: u32,
}

/// Serde default for [`UpstreamConfig::weight`].
const fn default_weight() -> u32 {
    DEFAULT_UPSTREAM_WEIGHT
}
/// Health-check settings as written in the config file.
///
/// Every key is optional. A missing key falls back to the matching constant;
/// `Duration` constants are stored here as whole seconds.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct HealthCheckConfig {
    /// Health-check path; default [`DEFAULT_HEALTH_CHECK_PATH`].
    #[serde(default = "default_health_path")]
    pub path: String,
    /// Interval in seconds; default [`DEFAULT_HEALTH_CHECK_INTERVAL`].
    #[serde(default = "default_health_interval_secs")]
    pub interval: u64,
    /// Becomes `RuntimeConfig::failure_threshold`; default
    /// [`DEFAULT_FAILURE_THRESHOLD`].
    #[serde(default = "default_failure_threshold")]
    pub unhealthy_threshold: u32,
    /// Becomes `RuntimeConfig::healthy_threshold`; default
    /// [`DEFAULT_HEALTHY_THRESHOLD`].
    #[serde(default = "default_healthy_threshold")]
    pub healthy_threshold: u32,
    /// Cooldown in seconds; becomes `RuntimeConfig::health_check_cooldown`.
    /// Default [`DEFAULT_HEALTH_CHECK_COOLDOWN`].
    #[serde(default = "default_cooldown_secs")]
    pub cooldown: u64,
    /// Timeout in seconds; default [`DEFAULT_HEALTH_CHECK_TIMEOUT`].
    #[serde(default = "default_health_timeout_secs")]
    pub timeout: u64,
}

impl Default for HealthCheckConfig {
    fn default() -> Self {
        Self {
            path: DEFAULT_HEALTH_CHECK_PATH.to_owned(),
            interval: DEFAULT_HEALTH_CHECK_INTERVAL.as_secs(),
            unhealthy_threshold: DEFAULT_FAILURE_THRESHOLD,
            healthy_threshold: DEFAULT_HEALTHY_THRESHOLD,
            cooldown: DEFAULT_HEALTH_CHECK_COOLDOWN.as_secs(),
            timeout: DEFAULT_HEALTH_CHECK_TIMEOUT.as_secs(),
        }
    }
}

/// Serde default for [`HealthCheckConfig::path`].
fn default_health_path() -> String {
    String::from(DEFAULT_HEALTH_CHECK_PATH)
}

/// Serde default for [`HealthCheckConfig::interval`].
const fn default_health_interval_secs() -> u64 {
    DEFAULT_HEALTH_CHECK_INTERVAL.as_secs()
}

/// Serde default for [`HealthCheckConfig::unhealthy_threshold`].
const fn default_failure_threshold() -> u32 {
    DEFAULT_FAILURE_THRESHOLD
}

/// Serde default for [`HealthCheckConfig::healthy_threshold`].
const fn default_healthy_threshold() -> u32 {
    DEFAULT_HEALTHY_THRESHOLD
}

/// Serde default for [`HealthCheckConfig::cooldown`].
const fn default_cooldown_secs() -> u64 {
    DEFAULT_HEALTH_CHECK_COOLDOWN.as_secs()
}

/// Serde default for [`HealthCheckConfig::timeout`].
const fn default_health_timeout_secs() -> u64 {
    DEFAULT_HEALTH_CHECK_TIMEOUT.as_secs()
}
/// Rate-limit settings as written in the config file.
///
/// Missing keys fall back to [`DEFAULT_RATE_LIMIT_RPS`] and
/// [`DEFAULT_RATE_LIMIT_BURST`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RateLimitConfig {
    /// Requests per second.
    #[serde(default = "default_rate_limit_rps")]
    pub requests_per_second: u32,
    /// Burst size.
    #[serde(default = "default_rate_limit_burst")]
    pub burst: u32,
}

impl Default for RateLimitConfig {
    fn default() -> Self {
        Self {
            requests_per_second: DEFAULT_RATE_LIMIT_RPS,
            burst: DEFAULT_RATE_LIMIT_BURST,
        }
    }
}

/// Serde default for [`RateLimitConfig::requests_per_second`].
const fn default_rate_limit_rps() -> u32 {
    DEFAULT_RATE_LIMIT_RPS
}

/// Serde default for [`RateLimitConfig::burst`].
const fn default_rate_limit_burst() -> u32 {
    DEFAULT_RATE_LIMIT_BURST
}
/// TLS file locations taken verbatim from the config file.
///
/// `Config::into_runtime` copies this into `RuntimeConfig::tls` without
/// checking that the paths exist.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TlsConfig {
/// Path to the certificate file.
pub cert_path: String,
/// Path to the private-key file.
pub key_path: String,
}
/// An upstream entry whose address has been parsed into a `hyper::Uri` and
/// checked by `validate_upstream`.
#[derive(Debug, Clone)]
pub struct ValidatedUpstream {
/// Parsed upstream URI; `validate_upstream` guarantees it has an authority.
pub uri: hyper::Uri,
/// Weight from the config; `validate_upstream` rejects zero.
pub weight: u32,
}
/// Validated configuration with all defaults resolved, produced by
/// `Config::into_runtime`.
#[derive(Debug)]
pub struct RuntimeConfig {
/// Socket address parsed from `Config::listen` (or `DEFAULT_LISTEN_ADDR`).
pub listen: SocketAddr,
/// Validated upstreams, in config order; never empty.
pub upstreams: Vec<ValidatedUpstream>,
/// ASCII-lowercased copy of `Config::blocked_headers`.
pub blocked_headers: Vec<String>,
/// `Config::blocked_params`, unchanged.
pub blocked_params: Vec<String>,
/// One compiled rule per `Config::masked_params` entry, in config order.
pub mask_rules: Vec<MaskRule>,
/// `Config::max_body_size` or `DEFAULT_MAX_BODY_SIZE`.
pub max_body_size: u64,
/// ASCII-lowercased copy of `Config::strip_response_headers`.
pub strip_response_headers: Vec<String>,
/// From `TimeoutsConfig::connect` (seconds).
pub connect_timeout: Duration,
/// From `TimeoutsConfig::request` (seconds).
pub request_timeout: Duration,
/// From `PoolConfig::idle_timeout` (seconds).
pub pool_idle_timeout: Duration,
/// From `PoolConfig::max_idle_per_host`.
pub pool_max_idle_per_host: usize,
/// `Config::max_concurrent_requests` or `DEFAULT_MAX_CONCURRENT_REQUESTS`.
pub max_concurrent_requests: usize,
/// `Config::tls`, unchanged.
pub tls: Option<TlsConfig>,
/// `Config::health_check`, unchanged.
pub health_check: Option<HealthCheckConfig>,
/// `HealthCheckConfig::unhealthy_threshold`, or `DEFAULT_FAILURE_THRESHOLD`
/// when no health check is configured.
pub failure_threshold: u32,
/// `HealthCheckConfig::healthy_threshold`, or `DEFAULT_HEALTHY_THRESHOLD`
/// when no health check is configured.
pub healthy_threshold: u32,
/// `HealthCheckConfig::cooldown` (seconds), or
/// `DEFAULT_HEALTH_CHECK_COOLDOWN` when no health check is configured.
pub health_check_cooldown: Duration,
/// `Config::rate_limit`, unchanged.
pub rate_limit: Option<RateLimitConfig>,
/// `Config::shutdown_timeout` (seconds) or `DEFAULT_SHUTDOWN_TIMEOUT`.
pub shutdown_timeout: Duration,
}
/// A compiled masking rule for one parameter name.
#[derive(Debug, Clone)]
pub struct MaskRule {
/// Parameter name as configured; reused to build the masked replacement text.
pub param: String,
/// Regex matching `<escaped param>=<one or more non-'&' chars>`.
pub pattern: Regex,
}
/// Parses and validates one configured upstream.
///
/// `address` must be a non-empty URI that has an authority (`host[:port]`)
/// and an `http` or `https` scheme (compared case-insensitively). `weight`
/// must be non-zero.
///
/// # Errors
///
/// * [`ProxyError::InvalidUpstream`] if `address` is empty, does not parse as
///   a `hyper::Uri`, has no authority, or has a missing or unsupported scheme.
/// * [`ProxyError::Config`] if `weight` is zero.
fn validate_upstream(address: &str, weight: u32) -> Result<ValidatedUpstream> {
    if address.is_empty() {
        return Err(ProxyError::InvalidUpstream(
            "upstream address must not be empty".into(),
        ));
    }
    let uri = address
        .parse::<hyper::Uri>()
        .map_err(|e| ProxyError::InvalidUpstream(format!("{e}")))?;
    if uri.authority().is_none() {
        return Err(ProxyError::InvalidUpstream(format!(
            "upstream URI has no authority: {address}"
        )));
    }
    // `hyper::Uri` also accepts authority-form strings such as
    // `localhost:3000` (no scheme) and arbitrary schemes such as `ftp://`.
    // Both pass the checks above but can only fail later, when a request is
    // forwarded, so reject them at config time instead.
    // NOTE(review): confirm no caller relies on scheme-less upstreams.
    match uri.scheme_str() {
        Some(s) if s.eq_ignore_ascii_case("http") || s.eq_ignore_ascii_case("https") => {}
        Some(s) => {
            return Err(ProxyError::InvalidUpstream(format!(
                "unsupported upstream scheme \"{s}\" (expected http or https): {address}"
            )));
        }
        None => {
            return Err(ProxyError::InvalidUpstream(format!(
                "upstream URI has no scheme (expected http:// or https://): {address}"
            )));
        }
    }
    if weight == 0 {
        return Err(ProxyError::Config(format!(
            "upstream weight must be positive: {address}"
        )));
    }
    Ok(ValidatedUpstream { uri, weight })
}
impl Config {
pub fn load_from_file(file_path: &(impl AsRef<Path> + ?Sized)) -> Result<Self> {
let file = std::fs::File::open(file_path).map_err(|e| {
ProxyError::Config(format!(
"failed to open {}: {e}",
file_path.as_ref().display()
))
})?;
serde_yaml::from_reader(file)
.map_err(|e| ProxyError::Config(format!("failed to parse config: {e}")))
}
pub fn into_runtime(self) -> Result<RuntimeConfig> {
if self.upstreams.is_empty() {
return Err(ProxyError::Config(
"at least one upstream must be configured".into(),
));
}
let listen_str = self.listen.as_deref().unwrap_or(DEFAULT_LISTEN_ADDR);
let listen = listen_str.parse::<SocketAddr>().map_err(|e| {
ProxyError::Config(format!("invalid listen address \"{listen_str}\": {e}"))
})?;
let upstreams = self
.upstreams
.iter()
.map(|u| validate_upstream(&u.address, u.weight))
.collect::<Result<Vec<_>>>()?;
let blocked_headers = self
.blocked_headers
.into_iter()
.map(|h| h.to_ascii_lowercase())
.collect();
let mask_rules = self
.masked_params
.iter()
.map(|param| {
let escaped = regex::escape(param);
Regex::new(&format!("{escaped}=([^&]+)"))
.map(|pattern| MaskRule {
param: param.clone(),
pattern,
})
.map_err(|e| {
ProxyError::Config(format!("invalid mask pattern for {param}: {e}"))
})
})
.collect::<Result<Vec<_>>>()?;
let max_body_size = self.max_body_size.unwrap_or(DEFAULT_MAX_BODY_SIZE);
let strip_response_headers = self
.strip_response_headers
.into_iter()
.map(|h| h.to_ascii_lowercase())
.collect();
let connect_timeout = Duration::from_secs(self.timeouts.connect);
let request_timeout = Duration::from_secs(self.timeouts.request);
let pool_idle_timeout = Duration::from_secs(self.pool.idle_timeout);
let pool_max_idle_per_host = self.pool.max_idle_per_host;
let max_concurrent_requests = self
.max_concurrent_requests
.unwrap_or(DEFAULT_MAX_CONCURRENT_REQUESTS);
let failure_threshold = self
.health_check
.as_ref()
.map_or(DEFAULT_FAILURE_THRESHOLD, |hc| hc.unhealthy_threshold);
let healthy_threshold = self
.health_check
.as_ref()
.map_or(DEFAULT_HEALTHY_THRESHOLD, |hc| hc.healthy_threshold);
let health_check_cooldown = self
.health_check
.as_ref()
.map_or(DEFAULT_HEALTH_CHECK_COOLDOWN, |hc| {
Duration::from_secs(hc.cooldown)
});
let shutdown_timeout = self
.shutdown_timeout
.map_or(DEFAULT_SHUTDOWN_TIMEOUT, Duration::from_secs);
Ok(RuntimeConfig {
listen,
upstreams,
blocked_headers,
blocked_params: self.blocked_params,
mask_rules,
max_body_size,
strip_response_headers,
connect_timeout,
request_timeout,
pool_idle_timeout,
pool_max_idle_per_host,
max_concurrent_requests,
tls: self.tls,
health_check: self.health_check,
failure_threshold,
healthy_threshold,
health_check_cooldown,
rate_limit: self.rate_limit,
shutdown_timeout,
})
}
}
impl RuntimeConfig {
pub fn has_https_upstream(&self) -> bool {
self.upstreams.iter().any(|u| {
u.uri
.scheme_str()
.is_some_and(|s| s.eq_ignore_ascii_case("https"))
})
}
pub fn mask_sensitive_data(&self, data: &str) -> String {
self.mask_rules.iter().fold(data.to_owned(), |acc, rule| {
rule.pattern
.replace_all(&acc, format!("{}=****", rule.param))
.into_owned()
})
}
}
// Unit tests for config loading (`Config::load_from_file`), validation and
// default resolution (`Config::into_runtime`), and the `RuntimeConfig`
// helpers.
#[cfg(test)]
mod tests {
use super::*;
// Builds a one-element upstream list with weight 1.
fn single_upstream(addr: &str) -> Vec<UpstreamConfig> {
vec![UpstreamConfig {
address: addr.into(),
weight: 1,
}]
}
// Reads `./Config.example.yml`. The path is relative to the test process's
// working directory, so this test depends on that file's current contents.
#[test]
fn loads_config_from_file() {
let config = Config::load_from_file("./Config.example.yml")
.expect("Config.example.yml should be loadable");
assert_eq!(config.listen, Some("127.0.0.1:8100".into()));
assert_eq!(config.upstreams.len(), 1);
assert_eq!(config.upstreams[0].address, "http://localhost:3000");
assert_eq!(
config.blocked_headers,
vec!["X-Debug-Token", "X-Internal-Auth"]
);
assert_eq!(config.blocked_params, vec!["access_token", "secret_key"]);
assert_eq!(config.masked_params, vec!["password", "ssn", "credit_card"]);
assert_eq!(config.timeouts.connect, 5);
assert_eq!(config.timeouts.request, 30);
assert_eq!(config.timeouts.idle, 60);
assert_eq!(config.pool.idle_timeout, 60);
assert_eq!(config.pool.max_idle_per_host, 32);
assert_eq!(config.max_concurrent_requests, Some(1000));
assert_eq!(
config.rate_limit,
Some(RateLimitConfig {
requests_per_second: 100,
burst: 50,
})
);
}
#[test]
fn into_runtime_rejects_empty_upstreams() {
let config = Config::default();
assert!(config.into_runtime().is_err());
}
#[test]
fn into_runtime_rejects_malformed_upstream() {
let config = Config {
upstreams: vec![UpstreamConfig {
address: "not a valid uri %%".into(),
weight: 1,
}],
..Default::default()
};
assert!(config.into_runtime().is_err());
}
#[test]
fn into_runtime_lowercases_blocked_headers() {
let config = Config {
upstreams: single_upstream("http://localhost:3000"),
blocked_headers: vec!["X-Custom-Header".into()],
..Default::default()
};
let rt = config.into_runtime().expect("valid config");
assert_eq!(rt.blocked_headers, vec!["x-custom-header"]);
}
#[test]
fn into_runtime_validates_upstreams() {
let config = Config {
upstreams: single_upstream("http://localhost:3000"),
..Default::default()
};
let rt = config.into_runtime().expect("valid config");
assert_eq!(rt.upstreams.len(), 1);
assert_eq!(
rt.upstreams[0].uri,
"http://localhost:3000".parse::<hyper::Uri>().unwrap()
);
assert_eq!(rt.upstreams[0].weight, 1);
}
// Upstream order and per-upstream weights must survive `into_runtime`.
#[test]
fn into_runtime_handles_multiple_upstreams() {
let config = Config {
upstreams: vec![
UpstreamConfig {
address: "http://backend1:3000".into(),
weight: 3,
},
UpstreamConfig {
address: "http://backend2:3000".into(),
weight: 1,
},
],
..Default::default()
};
let rt = config.into_runtime().expect("valid config");
assert_eq!(rt.upstreams.len(), 2);
assert_eq!(
rt.upstreams[0].uri,
"http://backend1:3000".parse::<hyper::Uri>().unwrap()
);
assert_eq!(rt.upstreams[0].weight, 3);
assert_eq!(rt.upstreams[1].weight, 1);
}
#[test]
fn into_runtime_rejects_zero_weight() {
let config = Config {
upstreams: vec![UpstreamConfig {
address: "http://localhost:3000".into(),
weight: 0,
}],
..Default::default()
};
assert!(config.into_runtime().is_err());
}
// A single https upstream among http ones is enough for a `true` result.
#[test]
fn has_https_upstream_detects_scheme() {
let config = Config {
upstreams: vec![
UpstreamConfig {
address: "http://backend1:3000".into(),
weight: 1,
},
UpstreamConfig {
address: "https://backend2:3000".into(),
weight: 1,
},
],
..Default::default()
};
let rt = config.into_runtime().unwrap();
assert!(rt.has_https_upstream());
}
#[test]
fn mask_sensitive_data_replaces_values() {
let config = Config {
upstreams: single_upstream("http://localhost:3000"),
masked_params: vec!["password".into(), "token".into()],
..Default::default()
};
let rt = config.into_runtime().expect("valid config");
let input = "username=john&password=secret&token=1234567890";
let masked = rt.mask_sensitive_data(input);
assert_eq!(masked, "username=john&password=****&token=****");
}
#[test]
fn mask_sensitive_data_leaves_unmatched_text_intact() {
let config = Config {
upstreams: single_upstream("http://localhost:3000"),
masked_params: vec!["password".into()],
..Default::default()
};
let rt = config.into_runtime().expect("valid config");
let input = "username=john&role=admin";
assert_eq!(rt.mask_sensitive_data(input), input);
}
// The `.` in the parameter name must be matched literally, not as a regex
// wildcard.
#[test]
fn mask_handles_regex_special_characters_in_param_name() {
let config = Config {
upstreams: single_upstream("http://localhost:3000"),
masked_params: vec!["user.password".into()],
..Default::default()
};
let rt = config.into_runtime().expect("valid config");
let input = "user.password=secret123&other=value";
assert_eq!(
rt.mask_sensitive_data(input),
"user.password=****&other=value"
);
}
#[test]
fn into_runtime_defaults_listen_address() {
let config = Config {
upstreams: single_upstream("http://localhost:3000"),
..Default::default()
};
let rt = config.into_runtime().unwrap();
assert_eq!(
rt.listen,
DEFAULT_LISTEN_ADDR.parse::<SocketAddr>().unwrap()
);
}
#[test]
fn into_runtime_parses_custom_listen_address() {
let config = Config {
upstreams: single_upstream("http://localhost:3000"),
listen: Some("0.0.0.0:9090".into()),
..Default::default()
};
let rt = config.into_runtime().unwrap();
assert_eq!(rt.listen, "0.0.0.0:9090".parse::<SocketAddr>().unwrap());
}
#[test]
fn into_runtime_rejects_invalid_listen_address() {
let config = Config {
upstreams: single_upstream("http://localhost:3000"),
listen: Some("not-an-address".into()),
..Default::default()
};
assert!(config.into_runtime().is_err());
}
// A default `HealthCheckConfig` must produce the same runtime thresholds
// and cooldown as the `DEFAULT_*` constants.
#[test]
fn health_check_config_uses_defaults() {
let config = Config {
upstreams: single_upstream("http://localhost:3000"),
health_check: Some(HealthCheckConfig::default()),
..Default::default()
};
let rt = config.into_runtime().unwrap();
assert_eq!(rt.failure_threshold, DEFAULT_FAILURE_THRESHOLD);
assert_eq!(rt.healthy_threshold, DEFAULT_HEALTHY_THRESHOLD);
assert_eq!(rt.health_check_cooldown, DEFAULT_HEALTH_CHECK_COOLDOWN);
assert!(rt.health_check.is_some());
let hc = rt.health_check.as_ref().unwrap();
assert_eq!(hc.timeout, DEFAULT_HEALTH_CHECK_TIMEOUT.as_secs());
}
#[test]
fn timeouts_config_uses_defaults() {
let config = Config {
upstreams: single_upstream("http://localhost:3000"),
..Default::default()
};
let rt = config.into_runtime().unwrap();
assert_eq!(rt.connect_timeout, DEFAULT_CONNECT_TIMEOUT);
assert_eq!(rt.request_timeout, DEFAULT_REQUEST_TIMEOUT);
assert_eq!(rt.pool_idle_timeout, DEFAULT_POOL_IDLE_TIMEOUT);
assert_eq!(rt.pool_max_idle_per_host, DEFAULT_POOL_MAX_IDLE_PER_HOST);
}
// `pool_idle_timeout` is taken from `pool.idle_timeout` (90), not from
// `timeouts.idle` (120).
#[test]
fn custom_timeouts_propagate() {
let config = Config {
upstreams: single_upstream("http://localhost:3000"),
timeouts: TimeoutsConfig {
connect: 2,
request: 10,
idle: 120,
},
pool: PoolConfig {
idle_timeout: 90,
max_idle_per_host: 16,
},
..Default::default()
};
let rt = config.into_runtime().unwrap();
assert_eq!(rt.connect_timeout, Duration::from_secs(2));
assert_eq!(rt.request_timeout, Duration::from_secs(10));
assert_eq!(rt.pool_idle_timeout, Duration::from_secs(90));
assert_eq!(rt.pool_max_idle_per_host, 16);
}
#[test]
fn shutdown_timeout_defaults() {
let config = Config {
upstreams: single_upstream("http://localhost:3000"),
..Default::default()
};
let rt = config.into_runtime().unwrap();
assert_eq!(rt.shutdown_timeout, DEFAULT_SHUTDOWN_TIMEOUT);
}
#[test]
fn custom_shutdown_timeout() {
let config = Config {
upstreams: single_upstream("http://localhost:3000"),
shutdown_timeout: Some(10),
..Default::default()
};
let rt = config.into_runtime().unwrap();
assert_eq!(rt.shutdown_timeout, Duration::from_secs(10));
}
#[test]
fn healthy_threshold_propagates() {
let config = Config {
upstreams: single_upstream("http://localhost:3000"),
health_check: Some(HealthCheckConfig {
healthy_threshold: 5,
..Default::default()
}),
..Default::default()
};
let rt = config.into_runtime().unwrap();
assert_eq!(rt.healthy_threshold, 5);
}
}