use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::FailureMode;
/// A filter definition paired with the identifier used to refer to it.
///
/// `filter` is flattened by serde, so the serialized map holds `id` next to
/// the filter's own `type` tag and settings rather than nesting them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FilterConfig {
/// Identifier of this filter instance.
pub id: String,
/// The concrete filter and its settings (serialized inline).
#[serde(flatten)]
pub filter: Filter,
}
impl FilterConfig {
pub fn new(id: impl Into<String>, filter: Filter) -> Self {
Self {
id: id.into(),
filter,
}
}
pub fn phase(&self) -> FilterPhase {
self.filter.phase()
}
pub fn filter_type(&self) -> &'static str {
self.filter.type_name()
}
pub fn validate(&self, available_agents: &[String]) -> Result<(), String> {
self.filter.validate(available_agents)
}
}
/// Side(s) of the exchange on which a filter runs.
///
/// Serialized in snake_case (`request`, `response`, `both`). `Request` is the
/// `Default` value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum FilterPhase {
/// Runs on the request only.
#[default]
Request,
/// Runs on the response only.
Response,
/// Runs on both the request and the response.
Both,
}
/// All supported filter kinds, each carrying its own settings struct.
///
/// Internally tagged: the serialized map contains a `type` key holding the
/// kebab-case variant name (e.g. `rate-limit`, `url-rewrite`) alongside the
/// variant's fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum Filter {
/// Rate limiting; always a request-phase filter.
RateLimit(RateLimitFilter),
/// Header manipulation; the phase is taken from the filter's own `phase` field.
Headers(HeadersFilter),
/// Compression; always a response-phase filter.
Compress(CompressFilter),
/// CORS handling; runs in both phases.
Cors(CorsFilter),
/// Timeout settings; request phase.
Timeout(TimeoutFilter),
/// Logging; the phase follows the `log_request` / `log_response` flags.
Log(LogFilter),
/// Country-based filtering; request phase.
Geo(GeoFilter),
/// Delegation to a named agent; request phase unless overridden.
Agent(AgentFilter),
/// Redirect settings; request phase.
Redirect(RedirectFilter),
/// URL rewrite settings; request phase.
UrlRewrite(UrlRewriteFilter),
}
impl Filter {
pub fn phase(&self) -> FilterPhase {
match self {
Filter::RateLimit(_) => FilterPhase::Request,
Filter::Headers(h) => h.phase,
Filter::Compress(_) => FilterPhase::Response,
Filter::Cors(_) => FilterPhase::Both,
Filter::Timeout(_) => FilterPhase::Request,
Filter::Log(l) => {
match (l.log_request, l.log_response) {
(true, true) => FilterPhase::Both,
(true, false) => FilterPhase::Request,
(false, true) => FilterPhase::Response,
(false, false) => FilterPhase::Request, }
}
Filter::Geo(_) => FilterPhase::Request,
Filter::Agent(a) => a.phase.unwrap_or(FilterPhase::Request),
Filter::Redirect(_) => FilterPhase::Request,
Filter::UrlRewrite(_) => FilterPhase::Request,
}
}
pub fn type_name(&self) -> &'static str {
match self {
Filter::RateLimit(_) => "rate-limit",
Filter::Headers(_) => "headers",
Filter::Compress(_) => "compress",
Filter::Cors(_) => "cors",
Filter::Timeout(_) => "timeout",
Filter::Log(_) => "log",
Filter::Geo(_) => "geo",
Filter::Agent(_) => "agent",
Filter::Redirect(_) => "redirect",
Filter::UrlRewrite(_) => "url-rewrite",
}
}
pub fn runs_on_request(&self) -> bool {
matches!(self.phase(), FilterPhase::Request | FilterPhase::Both)
}
pub fn runs_on_response(&self) -> bool {
matches!(self.phase(), FilterPhase::Response | FilterPhase::Both)
}
pub fn validate(&self, available_agents: &[String]) -> Result<(), String> {
match self {
Filter::RateLimit(r) => {
if r.max_rps == 0 {
return Err("rate-limit max-rps must be > 0".into());
}
}
Filter::Compress(c) => {
if c.algorithms.is_empty() {
return Err("compress filter requires at least one algorithm".into());
}
}
Filter::Geo(g) => {
if g.database_path.is_empty() {
return Err("geo filter requires 'database-path'".into());
}
for code in &g.countries {
if code.len() != 2 || !code.chars().all(|c| c.is_ascii_uppercase()) {
return Err(format!(
"geo filter: invalid country code '{}' (expected ISO 3166-1 alpha-2 like 'US', 'CN')",
code
));
}
}
}
Filter::Agent(a) => {
if !available_agents.contains(&a.agent) {
return Err(format!(
"agent filter references unknown agent '{}'. Available: {:?}",
a.agent, available_agents
));
}
}
_ => {}
}
Ok(())
}
}
/// Settings for the rate-limit filter.
///
/// Only `max-rps` is required when deserializing; every other field has a
/// serde default or is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitFilter {
/// Rate ceiling, serialized as `max-rps`. `Filter::validate` rejects 0.
#[serde(rename = "max-rps")]
pub max_rps: u32,
/// Burst allowance; defaults to `default_burst()` (10).
#[serde(default = "default_burst")]
pub burst: u32,
/// What requests are keyed on; defaults to `RateLimitKey::ClientIp`.
#[serde(default)]
pub key: RateLimitKey,
/// Action taken when the limit is hit (`on-limit`); defaults to `Reject`.
#[serde(default, rename = "on-limit")]
pub on_limit: RateLimitAction,
/// Status code (`status-code`); defaults to 429.
#[serde(default = "default_limit_status", rename = "status-code")]
pub status_code: u16,
/// Optional message (`limit-message`).
// NOTE(review): presumably the response body sent on rejection — confirm against the consumer.
#[serde(rename = "limit-message")]
pub limit_message: Option<String>,
/// Storage backend for counters; defaults to `RateLimitBackend::Local`.
#[serde(default)]
pub backend: RateLimitBackend,
/// Delay cap in milliseconds (`max-delay-ms`); defaults to 5000.
// NOTE(review): presumably only relevant when `on_limit` is `Delay` — confirm.
#[serde(default = "default_max_delay_ms", rename = "max-delay-ms")]
pub max_delay_ms: u64,
}
/// Serde default for `RateLimitFilter::max_delay_ms`, in milliseconds.
fn default_max_delay_ms() -> u64 {
    const MAX_DELAY_MS: u64 = 5_000;
    MAX_DELAY_MS
}
/// Rate-limit settings that are not tied to a single filter.
///
/// Every field has a serde default, and the derived `Default` yields no
/// rates, the `ClientIp` key and no global limit.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct GlobalRateLimitConfig {
/// Optional default rate (`default-rps`).
#[serde(default, rename = "default-rps")]
pub default_rps: Option<u32>,
/// Optional default burst (`default-burst`).
#[serde(default, rename = "default-burst")]
pub default_burst: Option<u32>,
/// Key used for limiting; defaults to `RateLimitKey::ClientIp`.
#[serde(default)]
pub key: RateLimitKey,
/// Optional global limit settings.
#[serde(default)]
pub global: Option<GlobalLimitConfig>,
}
/// The global limit carried by `GlobalRateLimitConfig::global`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalLimitConfig {
/// Rate ceiling (`max-rps`); required when deserializing.
#[serde(rename = "max-rps")]
pub max_rps: u32,
/// Burst allowance; defaults to `default_burst()` (10).
#[serde(default = "default_burst")]
pub burst: u32,
/// Key used for limiting; defaults to `RateLimitKey::ClientIp`.
#[serde(default)]
pub key: RateLimitKey,
}
/// Where rate-limit state is kept.
///
/// Variant names are serialized in kebab-case (`local`, `redis`,
/// `memcached`). `Local` is the default.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum RateLimitBackend {
/// No external store configured.
// NOTE(review): presumably in-process counters — confirm against the limiter implementation.
#[default]
Local,
/// Redis-backed store with its connection settings.
Redis(RedisBackendConfig),
/// Memcached-backed store with its connection settings.
Memcached(MemcachedBackendConfig),
}
/// Connection settings for the Redis rate-limit backend.
///
/// `url` is required when deserializing; the remaining fields have serde
/// defaults.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RedisBackendConfig {
/// Redis connection URL.
pub url: String,
/// Key prefix (`key-prefix`); defaults to `zentinel:ratelimit:`.
#[serde(default = "default_redis_prefix", rename = "key-prefix")]
pub key_prefix: String,
/// Pool size (`pool-size`); defaults to 10.
#[serde(default = "default_redis_pool_size", rename = "pool-size")]
pub pool_size: u32,
/// Timeout in milliseconds (`timeout-ms`); defaults to 50.
#[serde(default = "default_redis_timeout_ms", rename = "timeout-ms")]
pub timeout_ms: u64,
/// `fallback-local` flag; defaults to `true`.
// NOTE(review): presumably falls back to local limiting when Redis is unavailable — confirm.
#[serde(default = "default_true", rename = "fallback-local")]
pub fallback_local: bool,
}
impl Default for RedisBackendConfig {
fn default() -> Self {
Self {
url: "redis://127.0.0.1:6379".to_string(),
key_prefix: default_redis_prefix(),
pool_size: default_redis_pool_size(),
timeout_ms: default_redis_timeout_ms(),
fallback_local: true,
}
}
}
/// Serde default for `RedisBackendConfig::key_prefix`.
fn default_redis_prefix() -> String {
    String::from("zentinel:ratelimit:")
}
/// Serde default for `RedisBackendConfig::pool_size`.
fn default_redis_pool_size() -> u32 {
    const POOL_SIZE: u32 = 10;
    POOL_SIZE
}
/// Serde default for `RedisBackendConfig::timeout_ms`, in milliseconds.
fn default_redis_timeout_ms() -> u64 {
    const TIMEOUT_MS: u64 = 50;
    TIMEOUT_MS
}
/// Connection settings for the Memcached rate-limit backend.
///
/// `url` is required when deserializing; the remaining fields have serde
/// defaults.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct MemcachedBackendConfig {
/// Memcached connection URL.
pub url: String,
/// Key prefix (`key-prefix`); defaults to `zentinel:ratelimit:`.
#[serde(default = "default_memcached_prefix", rename = "key-prefix")]
pub key_prefix: String,
/// Pool size (`pool-size`); defaults to 10.
#[serde(default = "default_memcached_pool_size", rename = "pool-size")]
pub pool_size: u32,
/// Timeout in milliseconds (`timeout-ms`); defaults to 50.
#[serde(default = "default_memcached_timeout_ms", rename = "timeout-ms")]
pub timeout_ms: u64,
/// `fallback-local` flag; defaults to `true`.
// NOTE(review): presumably falls back to local limiting when Memcached is unavailable — confirm.
#[serde(default = "default_true", rename = "fallback-local")]
pub fallback_local: bool,
/// Entry time-to-live in seconds (`ttl-secs`); defaults to 2.
#[serde(default = "default_memcached_ttl", rename = "ttl-secs")]
pub ttl_secs: u32,
}
impl Default for MemcachedBackendConfig {
fn default() -> Self {
Self {
url: "memcache://127.0.0.1:11211".to_string(),
key_prefix: default_memcached_prefix(),
pool_size: default_memcached_pool_size(),
timeout_ms: default_memcached_timeout_ms(),
fallback_local: true,
ttl_secs: default_memcached_ttl(),
}
}
}
/// Serde default for `MemcachedBackendConfig::key_prefix`.
fn default_memcached_prefix() -> String {
    String::from("zentinel:ratelimit:")
}
/// Serde default for `MemcachedBackendConfig::pool_size`.
fn default_memcached_pool_size() -> u32 {
    const POOL_SIZE: u32 = 10;
    POOL_SIZE
}
/// Serde default for `MemcachedBackendConfig::timeout_ms`, in milliseconds.
fn default_memcached_timeout_ms() -> u64 {
    const TIMEOUT_MS: u64 = 50;
    TIMEOUT_MS
}
/// Serde default for `MemcachedBackendConfig::ttl_secs`, in seconds.
fn default_memcached_ttl() -> u32 {
    const TTL_SECS: u32 = 2;
    TTL_SECS
}
/// Serde default for the `burst` field of `RateLimitFilter` and
/// `GlobalLimitConfig`.
fn default_burst() -> u32 {
    const BURST: u32 = 10;
    BURST
}
/// Serde default for `RateLimitFilter::status_code` (HTTP 429).
fn default_limit_status() -> u16 {
    const TOO_MANY_REQUESTS: u16 = 429;
    TOO_MANY_REQUESTS
}
/// What a rate limit is keyed on.
///
/// Variant names are serialized in kebab-case (`client-ip`, `header`, `path`,
/// `route`, `client-ip-and-path`). `ClientIp` is the default.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum RateLimitKey {
/// Key on the client IP address.
#[default]
ClientIp,
/// Key on a header; the payload is presumably the header name (confirm against the limiter).
Header(String),
/// Key on the request path.
Path,
/// Key on the route.
Route,
/// Key on the client IP combined with the path.
ClientIpAndPath,
}
/// What to do when a rate limit is exceeded.
///
/// Variant names are serialized in kebab-case (`reject`, `delay`,
/// `log-only`). `Reject` is the default.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum RateLimitAction {
/// Reject the request.
#[default]
Reject,
/// Delay the request.
Delay,
/// Only log the event.
LogOnly,
}
/// Settings for the header-manipulation filter.
///
/// Every field has a serde default, so an empty map deserializes to a filter
/// that runs in the request phase and changes nothing.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HeadersFilter {
/// Phase the filter runs in; defaults to `FilterPhase::Request`.
#[serde(default)]
pub phase: FilterPhase,
/// Headers to rename.
// NOTE(review): presumably old name -> new name — confirm against the consumer.
#[serde(default)]
pub rename: HashMap<String, String>,
/// Headers to set (name -> value).
#[serde(default)]
pub set: HashMap<String, String>,
/// Headers to add (name -> value).
#[serde(default)]
pub add: HashMap<String, String>,
/// Header names to remove.
#[serde(default)]
pub remove: Vec<String>,
}
/// Settings for the compression filter. Every field has a serde default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompressFilter {
/// Enabled algorithms; defaults to gzip and brotli. `Filter::validate`
/// rejects an empty list.
#[serde(default = "default_algorithms")]
pub algorithms: Vec<CompressionAlgorithm>,
/// Size threshold (`min-size`); defaults to 1024.
// NOTE(review): presumably the minimum body size in bytes before compressing — confirm.
#[serde(default = "default_min_size", rename = "min-size")]
pub min_size: usize,
/// Content types considered (`content-types`); defaults to the list in
/// `default_content_types`.
#[serde(default = "default_content_types", rename = "content-types")]
pub content_types: Vec<String>,
/// Compression level; defaults to 6.
#[serde(default = "default_compression_level")]
pub level: u8,
}
impl Default for CompressFilter {
fn default() -> Self {
Self {
algorithms: default_algorithms(),
min_size: default_min_size(),
content_types: default_content_types(),
level: default_compression_level(),
}
}
}
/// Serde default for `CompressFilter::algorithms`: gzip first, then brotli.
fn default_algorithms() -> Vec<CompressionAlgorithm> {
    Vec::from([CompressionAlgorithm::Gzip, CompressionAlgorithm::Brotli])
}
/// Serde default for `CompressFilter::min_size`.
fn default_min_size() -> usize {
    const MIN_SIZE: usize = 1024;
    MIN_SIZE
}
/// Serde default for `CompressFilter::content_types`.
fn default_content_types() -> Vec<String> {
    const MIME_TYPES: [&str; 8] = [
        "text/html",
        "text/css",
        "text/plain",
        "text/xml",
        "application/json",
        "application/javascript",
        "application/xml",
        "image/svg+xml",
    ];
    MIME_TYPES.iter().map(|mime| (*mime).to_string()).collect()
}
/// Serde default for `CompressFilter::level`.
fn default_compression_level() -> u8 {
    const LEVEL: u8 = 6;
    LEVEL
}
/// Compression algorithms selectable in `CompressFilter::algorithms`.
///
/// Serialized in lowercase (`gzip`, `brotli`, `deflate`, `zstd`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CompressionAlgorithm {
/// gzip.
Gzip,
/// Brotli.
Brotli,
/// DEFLATE.
Deflate,
/// Zstandard.
Zstd,
}
/// Settings for the CORS filter. Every field has a serde default.
///
/// Note that a missing `allowed-origins` deserializes to an empty list,
/// whereas `CorsFilter::default()` uses `["*"]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CorsFilter {
/// Allowed origins (`allowed-origins`); empty when omitted.
#[serde(default, rename = "allowed-origins")]
pub allowed_origins: Vec<String>,
/// Allowed methods (`allowed-methods`); defaults to the list in
/// `default_cors_methods`.
#[serde(default = "default_cors_methods", rename = "allowed-methods")]
pub allowed_methods: Vec<String>,
/// Allowed request headers (`allowed-headers`); empty when omitted.
#[serde(default, rename = "allowed-headers")]
pub allowed_headers: Vec<String>,
/// Exposed response headers (`exposed-headers`); empty when omitted.
#[serde(default, rename = "exposed-headers")]
pub exposed_headers: Vec<String>,
/// Whether credentials are allowed (`allow-credentials`); defaults to `false`.
#[serde(default, rename = "allow-credentials")]
pub allow_credentials: bool,
/// Max age in seconds (`max-age-secs`); defaults to 86400.
#[serde(default = "default_cors_max_age", rename = "max-age-secs")]
pub max_age_secs: u64,
}
impl Default for CorsFilter {
fn default() -> Self {
Self {
allowed_origins: vec!["*".into()],
allowed_methods: default_cors_methods(),
allowed_headers: vec![],
exposed_headers: vec![],
allow_credentials: false,
max_age_secs: default_cors_max_age(),
}
}
}
/// Serde default for `CorsFilter::allowed_methods`.
fn default_cors_methods() -> Vec<String> {
    const METHODS: [&str; 7] = ["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "PATCH"];
    METHODS.iter().map(|method| (*method).to_string()).collect()
}
/// Serde default for `CorsFilter::max_age_secs` (86400 seconds, i.e. 24 hours).
fn default_cors_max_age() -> u64 {
    const MAX_AGE_SECS: u64 = 86_400;
    MAX_AGE_SECS
}
/// Settings for the timeout filter. All three timeouts are optional and
/// expressed in seconds.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct TimeoutFilter {
/// Request timeout (`request-timeout-secs`).
#[serde(rename = "request-timeout-secs")]
pub request_timeout_secs: Option<u64>,
/// Upstream timeout (`upstream-timeout-secs`).
#[serde(rename = "upstream-timeout-secs")]
pub upstream_timeout_secs: Option<u64>,
/// Connect timeout (`connect-timeout-secs`).
#[serde(rename = "connect-timeout-secs")]
pub connect_timeout_secs: Option<u64>,
}
/// Settings for the logging filter. Every field has a serde default.
///
/// `log_request` and `log_response` together determine the phase reported by
/// `Filter::phase`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LogFilter {
/// Log requests (`log-request`); defaults to `true`.
#[serde(default = "default_true", rename = "log-request")]
pub log_request: bool,
/// Log responses (`log-response`); defaults to `true`.
#[serde(default = "default_true", rename = "log-response")]
pub log_response: bool,
/// Log bodies (`log-body`); defaults to `false`.
#[serde(default, rename = "log-body")]
pub log_body: bool,
/// Body logging size cap (`max-body-log-size`); defaults to 4096.
// NOTE(review): presumably bytes — confirm against the logger.
#[serde(default = "default_max_body_log", rename = "max-body-log-size")]
pub max_body_log_size: usize,
/// Field names to log; empty when omitted.
#[serde(default)]
pub fields: Vec<String>,
/// Log level name; defaults to `info`.
#[serde(default = "default_log_level")]
pub level: String,
}
impl Default for LogFilter {
fn default() -> Self {
Self {
log_request: true,
log_response: true,
log_body: false,
max_body_log_size: default_max_body_log(),
fields: vec![],
level: default_log_level(),
}
}
}
/// Serde default provider for boolean fields that should be on when omitted.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
/// Serde default for `LogFilter::max_body_log_size`.
fn default_max_body_log() -> usize {
    const MAX_BODY_LOG: usize = 4096;
    MAX_BODY_LOG
}
/// Serde default for `LogFilter::level`.
fn default_log_level() -> String {
    String::from("info")
}
/// Format of the geo-IP database referenced by `GeoFilter::database_path`.
///
/// NOTE(review): under serde's kebab-case rule the derived names are
/// `max-mind` and `ip2-location` (a hyphen is inserted before each interior
/// capital) — confirm these are the spellings expected in configuration.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum GeoDatabaseType {
/// MaxMind database.
MaxMind,
/// IP2Location database.
Ip2Location,
}
/// How the geo filter treats the countries listed in `GeoFilter::countries`.
///
/// Serialized in kebab-case (`block`, `allow`, `log-only`). `Block` is the
/// default.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum GeoFilterAction {
/// Block mode.
#[default]
Block,
/// Allow mode.
Allow,
/// Log-only mode.
LogOnly,
}
/// Behaviour of the geo filter on failure (`GeoFilter::on_failure`).
///
/// Serialized in kebab-case (`open`, `closed`). `Open` is the default.
// NOTE(review): presumably fail-open lets traffic through and fail-closed blocks it — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "kebab-case")]
pub enum GeoFailureMode {
/// Fail open.
#[default]
Open,
/// Fail closed.
Closed,
}
/// Settings for the geo filter.
///
/// `database-path` is the only key without a serde default, and
/// `Filter::validate` rejects an empty path.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeoFilter {
/// Path to the geo-IP database (`database-path`).
#[serde(rename = "database-path")]
pub database_path: String,
/// Optional database format (`database-type`).
#[serde(default, rename = "database-type")]
pub database_type: Option<GeoDatabaseType>,
/// Action applied to the listed countries; defaults to `Block`.
#[serde(default)]
pub action: GeoFilterAction,
/// Country codes; `Filter::validate` requires each to be exactly two ASCII
/// uppercase letters.
#[serde(default)]
pub countries: Vec<String>,
/// Failure behaviour (`on-failure`); defaults to `Open`.
#[serde(default, rename = "on-failure")]
pub on_failure: GeoFailureMode,
/// Status code (`status-code`); defaults to 403.
#[serde(default = "default_geo_status", rename = "status-code")]
pub status_code: u16,
/// Optional message (`block-message`).
#[serde(rename = "block-message")]
pub block_message: Option<String>,
/// Cache TTL in seconds (`cache-ttl-secs`); defaults to 3600.
#[serde(default = "default_geo_cache_ttl", rename = "cache-ttl-secs")]
pub cache_ttl_secs: u64,
/// `add-country-header` flag; defaults to `true`.
#[serde(default = "default_true", rename = "add-country-header")]
pub add_country_header: bool,
}
impl Default for GeoFilter {
fn default() -> Self {
Self {
database_path: String::new(),
database_type: None,
action: GeoFilterAction::Block,
countries: Vec::new(),
on_failure: GeoFailureMode::Open,
status_code: default_geo_status(),
block_message: None,
cache_ttl_secs: default_geo_cache_ttl(),
add_country_header: true,
}
}
}
/// Serde default for `GeoFilter::status_code` (HTTP 403).
fn default_geo_status() -> u16 {
    const FORBIDDEN: u16 = 403;
    FORBIDDEN
}
/// Serde default for `GeoFilter::cache_ttl_secs` (3600 seconds, i.e. one hour).
fn default_geo_cache_ttl() -> u64 {
    const CACHE_TTL_SECS: u64 = 3_600;
    CACHE_TTL_SECS
}
/// Settings for a filter that delegates to a named agent.
///
/// `agent` is required; `Filter::validate` checks it against the list of
/// available agents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentFilter {
/// Name of the agent to call.
pub agent: String,
/// Optional phase override; `Filter::phase` uses `Request` when unset.
#[serde(default)]
pub phase: Option<FilterPhase>,
/// Optional timeout in milliseconds (`timeout-ms`).
#[serde(rename = "timeout-ms")]
pub timeout_ms: Option<u64>,
/// Optional failure mode (`failure-mode`).
#[serde(rename = "failure-mode")]
pub failure_mode: Option<FailureMode>,
/// `inspect-body` flag; defaults to `false`.
#[serde(default, rename = "inspect-body")]
pub inspect_body: bool,
/// Optional body size cap in bytes (`max-body-bytes`).
#[serde(rename = "max-body-bytes")]
pub max_body_bytes: Option<usize>,
}
impl AgentFilter {
pub fn new(agent: impl Into<String>) -> Self {
Self {
agent: agent.into(),
phase: None,
timeout_ms: None,
failure_mode: None,
inspect_body: false,
max_body_bytes: None,
}
}
pub fn agent_id(&self) -> &str {
&self.agent
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline rate-limit filter shared by several tests; individual tests
    /// override fields with struct-update syntax.
    fn sample_rate_limit() -> RateLimitFilter {
        RateLimitFilter {
            max_rps: 100,
            burst: 10,
            key: RateLimitKey::ClientIp,
            on_limit: RateLimitAction::Reject,
            status_code: 429,
            limit_message: None,
            backend: RateLimitBackend::Local,
            max_delay_ms: 5000,
        }
    }

    #[test]
    fn test_filter_phases() {
        assert_eq!(
            Filter::RateLimit(sample_rate_limit()).phase(),
            FilterPhase::Request
        );
        assert_eq!(
            Filter::Compress(CompressFilter::default()).phase(),
            FilterPhase::Response
        );
        assert_eq!(
            Filter::Cors(CorsFilter::default()).phase(),
            FilterPhase::Both
        );
    }

    #[test]
    fn test_log_filter_phase() {
        // Both flags on (the default) -> runs in both phases.
        let both = Filter::Log(LogFilter::default());
        assert_eq!(both.phase(), FilterPhase::Both);
        assert!(both.runs_on_request());
        assert!(both.runs_on_response());

        // Response-only logging.
        let response_only = Filter::Log(LogFilter {
            log_request: false,
            ..Default::default()
        });
        assert_eq!(response_only.phase(), FilterPhase::Response);
        assert!(!response_only.runs_on_request());
        assert!(response_only.runs_on_response());

        // With nothing enabled the filter still reports the request phase.
        let neither = Filter::Log(LogFilter {
            log_request: false,
            log_response: false,
            ..Default::default()
        });
        assert_eq!(neither.phase(), FilterPhase::Request);
    }

    #[test]
    fn test_agent_filter_validation() {
        let filter = Filter::Agent(AgentFilter::new("auth-agent"));
        assert!(filter.validate(&["auth-agent".into()]).is_ok());
        assert!(filter.validate(&["other-agent".into()]).is_err());
    }

    #[test]
    fn test_filter_config() {
        let config = FilterConfig::new("my-rate-limit", Filter::RateLimit(sample_rate_limit()));
        assert_eq!(config.id, "my-rate-limit");
        assert_eq!(config.filter_type(), "rate-limit");
        assert_eq!(config.phase(), FilterPhase::Request);
    }

    #[test]
    fn test_redis_backend_config() {
        let config = RedisBackendConfig::default();
        assert_eq!(config.url, "redis://127.0.0.1:6379");
        assert_eq!(config.key_prefix, "zentinel:ratelimit:");
        assert_eq!(config.pool_size, 10);
        assert_eq!(config.timeout_ms, 50);
        assert!(config.fallback_local);
    }

    #[test]
    fn test_global_rate_limit_config_default() {
        let config = GlobalRateLimitConfig::default();
        assert!(config.default_rps.is_none());
        assert!(config.default_burst.is_none());
        assert_eq!(config.key, RateLimitKey::ClientIp);
        assert!(config.global.is_none());
    }

    #[test]
    fn test_global_rate_limit_config_with_values() {
        let config = GlobalRateLimitConfig {
            default_rps: Some(100),
            default_burst: Some(20),
            key: RateLimitKey::Path,
            global: Some(GlobalLimitConfig {
                max_rps: 10000,
                burst: 1000,
                key: RateLimitKey::ClientIp,
            }),
        };
        assert_eq!(config.default_rps, Some(100));
        assert_eq!(config.default_burst, Some(20));
        assert_eq!(config.key, RateLimitKey::Path);
        assert!(config.global.is_some());
        let global = config.global.unwrap();
        assert_eq!(global.max_rps, 10000);
        assert_eq!(global.burst, 1000);
        assert_eq!(global.key, RateLimitKey::ClientIp);
    }

    #[test]
    fn test_rate_limit_filter_max_delay_ms() {
        let filter = RateLimitFilter {
            on_limit: RateLimitAction::Delay,
            max_delay_ms: 3000,
            ..sample_rate_limit()
        };
        assert_eq!(filter.max_delay_ms, 3000);
        assert_eq!(filter.on_limit, RateLimitAction::Delay);
    }

    #[test]
    fn test_rate_limit_filter_default_max_delay() {
        // Check the serde default provider itself; asserting on a literal we
        // just wrote into the struct would prove nothing about the default.
        assert_eq!(default_max_delay_ms(), 5000);
        let filter = RateLimitFilter {
            max_delay_ms: default_max_delay_ms(),
            ..sample_rate_limit()
        };
        assert_eq!(filter.max_delay_ms, 5000);
    }

    #[test]
    fn test_geo_filter_default() {
        let filter = GeoFilter::default();
        assert!(filter.database_path.is_empty());
        assert!(filter.database_type.is_none());
        assert_eq!(filter.action, GeoFilterAction::Block);
        assert!(filter.countries.is_empty());
        assert_eq!(filter.on_failure, GeoFailureMode::Open);
        assert_eq!(filter.status_code, 403);
        assert!(filter.block_message.is_none());
        assert_eq!(filter.cache_ttl_secs, 3600);
        assert!(filter.add_country_header);
    }

    #[test]
    fn test_geo_filter_action_enum() {
        assert_eq!(GeoFilterAction::default(), GeoFilterAction::Block);
        assert_ne!(GeoFilterAction::Allow, GeoFilterAction::Block);
        assert_ne!(GeoFilterAction::LogOnly, GeoFilterAction::Block);
    }

    #[test]
    fn test_geo_failure_mode_enum() {
        assert_eq!(GeoFailureMode::default(), GeoFailureMode::Open);
        assert_ne!(GeoFailureMode::Closed, GeoFailureMode::Open);
    }

    #[test]
    fn test_geo_database_type_enum() {
        let maxmind = GeoDatabaseType::MaxMind;
        let ip2loc = GeoDatabaseType::Ip2Location;
        assert_ne!(maxmind, ip2loc);
    }

    #[test]
    fn test_geo_filter_validation_missing_path() {
        let filter = Filter::Geo(GeoFilter::default());
        let result = filter.validate(&[]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("database-path"));
    }

    #[test]
    fn test_geo_filter_validation_invalid_country_code() {
        let filter = Filter::Geo(GeoFilter {
            database_path: "/path/to/db.mmdb".to_string(),
            countries: vec!["invalid".to_string()],
            ..Default::default()
        });
        let result = filter.validate(&[]);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("invalid country code"));
    }

    #[test]
    fn test_geo_filter_validation_valid() {
        let filter = Filter::Geo(GeoFilter {
            database_path: "/path/to/db.mmdb".to_string(),
            countries: vec!["US".to_string(), "CA".to_string()],
            ..Default::default()
        });
        assert!(filter.validate(&[]).is_ok());
    }

    #[test]
    fn test_geo_filter_phase() {
        let filter = Filter::Geo(GeoFilter::default());
        assert_eq!(filter.phase(), FilterPhase::Request);
    }

    #[test]
    fn test_geo_filter_type_name() {
        let filter = Filter::Geo(GeoFilter::default());
        assert_eq!(filter.type_name(), "geo");
    }

    #[test]
    fn test_geo_filter_config() {
        let config = FilterConfig::new(
            "block-countries",
            Filter::Geo(GeoFilter {
                database_path: "/etc/zentinel/GeoLite2-Country.mmdb".to_string(),
                database_type: Some(GeoDatabaseType::MaxMind),
                action: GeoFilterAction::Block,
                countries: vec!["RU".to_string(), "CN".to_string()],
                on_failure: GeoFailureMode::Closed,
                status_code: 403,
                block_message: Some("Access denied from your region".to_string()),
                cache_ttl_secs: 7200,
                add_country_header: true,
            }),
        );
        assert_eq!(config.id, "block-countries");
        assert_eq!(config.filter_type(), "geo");
        assert_eq!(config.phase(), FilterPhase::Request);
    }
}
/// Settings for the redirect filter. Every field is optional or defaulted.
// NOTE(review): the optional fields presumably replace the matching parts of the redirect target — confirm against the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RedirectFilter {
/// Optional hostname.
#[serde(default)]
pub hostname: Option<String>,
/// Redirect status code (`status-code`); defaults to 302.
#[serde(default = "default_redirect_status", rename = "status-code")]
pub status_code: u16,
/// Optional scheme.
#[serde(default)]
pub scheme: Option<String>,
/// Optional port.
#[serde(default)]
pub port: Option<u16>,
/// Optional path modification.
#[serde(default)]
pub path: Option<PathModifier>,
}
/// Serde default for `RedirectFilter::status_code` (HTTP 302).
fn default_redirect_status() -> u16 {
    const FOUND: u16 = 302;
    FOUND
}
/// Settings for the URL-rewrite filter. Both fields are optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UrlRewriteFilter {
/// Optional hostname.
#[serde(default)]
pub hostname: Option<String>,
/// Optional path modification.
#[serde(default)]
pub path: Option<PathModifier>,
}
/// Path modification used by `RedirectFilter` and `UrlRewriteFilter`.
///
/// Internally tagged: the serialized map has a `type` key holding
/// `replace-full-path` or `replace-prefix-match`, plus a `value` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "kebab-case")]
pub enum PathModifier {
/// Carries the replacement for the full path.
ReplaceFullPath {
value: String,
},
/// Carries the replacement for a matched prefix.
// NOTE(review): presumably replaces only the matched route prefix — confirm against the rewrite logic.
ReplacePrefixMatch {
value: String,
},
}