use std::net::IpAddr;
use serde::{Deserialize, Serialize};
use crate::{WafDecision, WafRequest};
/// How the configured country list is interpreted by the geo-blocker.
///
/// Serialized in snake_case (`"allow"` / `"deny"`); the default is [`GeoMode::Deny`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GeoMode {
    /// Allow-list: only requests from listed countries are admitted.
    Allow,
    /// Deny-list: requests from listed countries are rejected; everything else is admitted.
    #[default]
    Deny,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeoConfig {
#[serde(default)]
pub enabled: bool,
#[serde(default)]
pub mode: GeoMode,
#[serde(default)]
pub countries: Vec<String>,
#[serde(default)]
pub bypass_paths: Vec<String>,
#[serde(default)]
pub real_ip_header: Option<String>,
}
impl Default for GeoConfig {
fn default() -> Self {
Self {
enabled: false,
mode: GeoMode::Deny,
countries: Vec::new(),
bypass_paths: Vec::new(),
real_ip_header: None,
}
}
}
/// Geo-blocking filter that turns a [`GeoConfig`] into allow/block decisions.
///
/// Without the `geoip` cargo feature there is no database field at all, and every
/// country lookup reports "unknown".
pub struct GeoBlocker {
// Settings driving every decision made by this blocker.
config: GeoConfig,
// Opened MaxMind database; `None` when geo-blocking was disabled at construction.
#[cfg(feature = "geoip")]
reader: Option<maxminddb::Reader<Vec<u8>>>,
}
/// The slice of a MaxMind record this module reads: only the country entry.
#[cfg(feature = "geoip")]
#[derive(Deserialize)]
struct GeoLookup {
    /// `None` when the record carries no `country` entry.
    country: Option<CountryRecord>,
}
/// The `country` entry of a MaxMind record; only its ISO code is kept.
#[cfg(feature = "geoip")]
#[derive(Deserialize)]
struct CountryRecord {
    /// Country code exactly as stored in the database; `None` if the entry lacks one.
    iso_code: Option<String>,
}
impl GeoBlocker {
pub fn new(config: GeoConfig, _db_path: &str) -> anyhow::Result<Self> {
#[cfg(feature = "geoip")]
{
let reader = if config.enabled {
Some(maxminddb::Reader::open_readfile(_db_path)?)
} else {
None
};
Ok(Self { config, reader })
}
#[cfg(not(feature = "geoip"))]
{
if config.enabled {
tracing::warn!(
"geo-blocking enabled but 'geoip' feature not compiled in — skipping"
);
}
Ok(Self { config })
}
}
pub fn disabled() -> Self {
Self {
config: GeoConfig::default(),
#[cfg(feature = "geoip")]
reader: None,
}
}
fn is_bypass_path(&self, path: &str) -> bool {
self.config
.bypass_paths
.iter()
.any(|bp| path.starts_with(bp.as_str()))
}
#[allow(unused_variables)]
fn lookup_country(&self, ip: IpAddr) -> Option<String> {
#[cfg(feature = "geoip")]
{
let reader = self.reader.as_ref()?;
let result: Result<GeoLookup, _> = reader.lookup(ip);
match result {
Ok(geo) => geo.country.and_then(|c| c.iso_code),
Err(e) => {
tracing::debug!(ip = %ip, error = %e, "geoip lookup failed");
None
}
}
}
#[cfg(not(feature = "geoip"))]
{
None
}
}
pub fn check(&self, ip: IpAddr, path: &str) -> Option<WafDecision> {
if !self.config.enabled {
return None;
}
if self.is_bypass_path(path) {
return None;
}
let country = match self.lookup_country(ip) {
Some(c) => c.to_uppercase(),
None => {
return match self.config.mode {
GeoMode::Allow => Some(WafDecision::Block {
status: 403,
reason: "GeoIP lookup failed — blocked by allow-list policy".into(),
rule: "geo".into(),
}),
GeoMode::Deny => None, };
}
};
let in_list = self
.config
.countries
.iter()
.any(|c| c.eq_ignore_ascii_case(&country));
match self.config.mode {
GeoMode::Deny => {
if in_list {
Some(WafDecision::Block {
status: 403,
reason: format!("country {country} is geo-blocked"),
rule: "geo_block_deny".into(),
})
} else {
None
}
}
GeoMode::Allow => {
if in_list {
None
} else {
Some(WafDecision::Block {
status: 403,
reason: format!("country {country} is not in allowed list"),
rule: "geo_block_allow".into(),
})
}
}
}
}
pub fn check_request(&self, req: &WafRequest) -> Option<WafDecision> {
let ip = self.resolve_ip(req);
self.check(ip, &req.path)
}
fn resolve_ip(&self, req: &WafRequest) -> IpAddr {
if let Some(ref header_name) = self.config.real_ip_header {
let lower = header_name.to_lowercase();
if let Some(header_val) = req
.headers
.iter()
.find(|(k, _)| k.to_lowercase() == lower)
.map(|(_, v)| v)
{
let ip_str = header_val.split(',').next().unwrap_or(header_val).trim();
if let Ok(ip) = ip_str.parse::<IpAddr>() {
return ip;
}
tracing::debug!(
header = header_name.as_str(),
value = header_val.as_str(),
"failed to parse IP from real_ip_header, falling back to client_ip"
);
}
}
req.client_ip
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a blocker for `config` with no GeoIP database attached.
    ///
    /// The previous pattern was
    /// `GeoBlocker::new(config, "/nonexistent.mmdb").unwrap_or_else(|_| GeoBlocker::disabled())`.
    /// With the `geoip` feature compiled in, opening the missing file fails and the fallback
    /// silently swapped in a *disabled* blocker, so the allow-mode tests could not pass in that
    /// build. Constructing the struct directly keeps `config` active in both builds. Lookups
    /// then always fail, which is exactly the "unknown country" path exercised here.
    fn offline_blocker(config: GeoConfig) -> GeoBlocker {
        GeoBlocker {
            config,
            #[cfg(feature = "geoip")]
            reader: None,
        }
    }

    /// Enabled config with the given mode and countries, no bypass paths and no header.
    fn enabled_config(mode: GeoMode, countries: &[&str]) -> GeoConfig {
        GeoConfig {
            enabled: true,
            mode,
            countries: countries.iter().map(|c| (*c).to_string()).collect(),
            bypass_paths: vec![],
            real_ip_header: None,
        }
    }

    /// `GET /page` request from `client_ip` carrying `headers`.
    fn request(client_ip: &str, headers: &[(&str, &str)]) -> WafRequest {
        let mut header_map = HashMap::new();
        for (name, value) in headers {
            header_map.insert((*name).into(), (*value).into());
        }
        WafRequest {
            client_ip: client_ip.parse().unwrap(),
            method: "GET".into(),
            path: "/page".into(),
            query: None,
            headers: header_map,
            body: None,
            user_agent: None,
        }
    }

    fn ip(s: &str) -> IpAddr {
        s.parse().unwrap()
    }

    #[test]
    fn disabled_allows_all() {
        let blocker = GeoBlocker::disabled();
        assert!(blocker.check(ip("1.2.3.4"), "/api/data").is_none());
    }

    #[test]
    fn bypass_path_skips_check() {
        // Allow mode without a database blocks every non-bypassed request, so a `None`
        // here can only come from the bypass list.
        let mut config = enabled_config(GeoMode::Allow, &["US"]);
        config.bypass_paths = vec!["/health".into(), "/api/status".into()];
        let blocker = offline_blocker(config);
        assert!(blocker.check(ip("1.2.3.4"), "/health").is_none());
        assert!(blocker.check(ip("1.2.3.4"), "/api/status/deep").is_none());
        assert!(blocker.check(ip("1.2.3.4"), "/page").is_some());
    }

    #[test]
    fn default_config_is_disabled() {
        let config = GeoConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.mode, GeoMode::Deny);
        assert!(config.countries.is_empty());
        assert!(config.bypass_paths.is_empty());
        assert!(config.real_ip_header.is_none());
    }

    #[test]
    fn country_list_case_insensitive() {
        let config = enabled_config(GeoMode::Deny, &["cn", "RU"]);
        assert!(config
            .countries
            .iter()
            .any(|c| c.eq_ignore_ascii_case("CN")));
        assert!(config
            .countries
            .iter()
            .any(|c| c.eq_ignore_ascii_case("ru")));
    }

    #[test]
    fn unknown_country_is_allowed_in_deny_mode() {
        let blocker = offline_blocker(enabled_config(GeoMode::Deny, &["CN"]));
        assert!(blocker.check(ip("8.8.8.8"), "/page").is_none());
    }

    #[test]
    fn unknown_country_is_blocked_in_allow_mode() {
        let blocker = offline_blocker(enabled_config(GeoMode::Allow, &["US"]));
        let decision = blocker.check(ip("8.8.8.8"), "/page");
        assert!(matches!(
            decision,
            Some(WafDecision::Block { status: 403, .. })
        ));
    }

    #[test]
    fn unknown_country_blocked_in_allow_mode_has_correct_rule() {
        let blocker = offline_blocker(enabled_config(GeoMode::Allow, &["US"]));
        match blocker.check(ip("8.8.8.8"), "/page") {
            Some(WafDecision::Block { rule, reason, .. }) => {
                assert_eq!(rule, "geo");
                assert!(reason.contains("allow-list policy"));
            }
            other => panic!("expected Block, got {other:?}"),
        }
    }

    #[test]
    fn unknown_country_bypass_path_still_allowed_in_allow_mode() {
        let mut config = enabled_config(GeoMode::Allow, &["US"]);
        config.bypass_paths = vec!["/health".into()];
        let blocker = offline_blocker(config);
        assert!(blocker.check(ip("8.8.8.8"), "/health").is_none());
    }

    #[test]
    fn check_request_uses_client_ip_when_no_real_ip_header() {
        let blocker = offline_blocker(enabled_config(GeoMode::Deny, &["CN"]));
        let req = request("8.8.8.8", &[("X-Real-IP", "1.2.3.4")]);
        assert_eq!(blocker.resolve_ip(&req), ip("8.8.8.8"));
        assert!(blocker.check_request(&req).is_none());
    }

    #[test]
    fn check_request_blocks_unknown_country_in_allow_mode() {
        let blocker = offline_blocker(enabled_config(GeoMode::Allow, &["US"]));
        let req = request("8.8.8.8", &[]);
        assert!(matches!(
            blocker.check_request(&req),
            Some(WafDecision::Block { status: 403, .. })
        ));
    }

    #[test]
    fn check_request_uses_real_ip_header_when_configured() {
        let mut config = enabled_config(GeoMode::Deny, &["CN"]);
        config.real_ip_header = Some("X-Real-IP".into());
        let blocker = offline_blocker(config);
        let req = request("8.8.8.8", &[("X-Real-IP", "1.2.3.4")]);
        assert_eq!(blocker.resolve_ip(&req), ip("1.2.3.4"));
    }

    #[test]
    fn real_ip_header_name_is_case_insensitive() {
        let mut config = enabled_config(GeoMode::Deny, &["CN"]);
        config.real_ip_header = Some("X-Real-IP".into());
        let blocker = offline_blocker(config);
        let req = request("8.8.8.8", &[("x-real-ip", "1.2.3.4")]);
        assert_eq!(blocker.resolve_ip(&req), ip("1.2.3.4"));
    }

    #[test]
    fn check_request_falls_back_to_client_ip_on_invalid_header() {
        let mut config = enabled_config(GeoMode::Deny, &["CN"]);
        config.real_ip_header = Some("X-Real-IP".into());
        let blocker = offline_blocker(config);
        let req = request("8.8.8.8", &[("X-Real-IP", "not-an-ip")]);
        assert_eq!(blocker.resolve_ip(&req), ip("8.8.8.8"));
    }

    #[test]
    fn check_request_falls_back_when_header_missing() {
        let mut config = enabled_config(GeoMode::Deny, &["CN"]);
        config.real_ip_header = Some("X-Real-IP".into());
        let blocker = offline_blocker(config);
        let req = request("8.8.8.8", &[]);
        assert_eq!(blocker.resolve_ip(&req), ip("8.8.8.8"));
    }

    #[test]
    fn check_request_parses_x_forwarded_for_first_ip() {
        let mut config = enabled_config(GeoMode::Deny, &[]);
        config.real_ip_header = Some("X-Forwarded-For".into());
        let blocker = offline_blocker(config);
        let req = request(
            "127.0.0.1",
            &[("X-Forwarded-For", "1.2.3.4, 5.6.7.8, 9.10.11.12")],
        );
        assert_eq!(blocker.resolve_ip(&req), ip("1.2.3.4"));
    }

    #[test]
    fn x_forwarded_for_first_entry_is_trimmed() {
        let mut config = enabled_config(GeoMode::Deny, &[]);
        config.real_ip_header = Some("X-Forwarded-For".into());
        let blocker = offline_blocker(config);
        let req = request("127.0.0.1", &[("X-Forwarded-For", "  1.2.3.4  , 5.6.7.8")]);
        assert_eq!(blocker.resolve_ip(&req), ip("1.2.3.4"));
    }
}