use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use notify::{Event, EventKind, RecursiveMode, Watcher};
use parking_lot::RwLock;
use tokio::sync::mpsc;
use tracing::{debug, error, info, trace, warn};
use grapsus_config::{GeoDatabaseType, GeoFailureMode, GeoFilter, GeoFilterAction};
/// Errors raised while loading a geo database or resolving an address in it.
#[derive(Debug, Clone)]
pub enum GeoLookupError {
    /// An IP address was not valid; the payload is the text shown in the message.
    InvalidIp(String),
    /// The database backend failed during a lookup or while decoding a record.
    DatabaseError(String),
    /// The database file could not be opened or loaded.
    LoadError(String),
}

impl std::fmt::Display for GeoLookupError {
    /// Renders a short, human-readable message prefixed with the failure category.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidIp(addr) => {
                f.write_fmt(format_args!("invalid IP address: {}", addr))
            }
            Self::DatabaseError(detail) => {
                f.write_fmt(format_args!("database error: {}", detail))
            }
            Self::LoadError(detail) => {
                f.write_fmt(format_args!("failed to load database: {}", detail))
            }
        }
    }
}

impl std::error::Error for GeoLookupError {}
/// Common interface over the GeoIP database backends used by this module.
///
/// Implementations must be `Send + Sync`.
pub trait GeoDatabase: Send + Sync {
/// Resolves `ip` to a country code.
///
/// Returns `Ok(Some(code))` when the database has a country for the address,
/// `Ok(None)` when it has none, and a [`GeoLookupError`] when the lookup fails.
fn lookup(&self, ip: IpAddr) -> Result<Option<String>, GeoLookupError>;
/// Reports which backend this database is.
fn database_type(&self) -> GeoDatabaseType;
}
/// [`GeoDatabase`] backend backed by a `maxminddb::Reader`.
pub struct MaxMindDatabase {
// Reader over the database contents, held as a `Vec<u8>`.
reader: maxminddb::Reader<Vec<u8>>,
}
impl MaxMindDatabase {
    /// Opens the MaxMind database stored at `path`.
    ///
    /// # Errors
    ///
    /// Returns [`GeoLookupError::LoadError`] naming the path and the reader's
    /// error when the file cannot be opened.
    pub fn open(path: impl AsRef<Path>) -> Result<Self, GeoLookupError> {
        let db_path = path.as_ref();
        match maxminddb::Reader::open_readfile(db_path) {
            Ok(reader) => {
                debug!(path = ?db_path, "Opened MaxMind GeoIP database");
                Ok(Self { reader })
            }
            Err(e) => Err(GeoLookupError::LoadError(format!(
                "failed to open MaxMind database {:?}: {}",
                db_path, e
            ))),
        }
    }
}
impl GeoDatabase for MaxMindDatabase {
    /// Looks up `ip` and returns the ISO country code of its country record.
    ///
    /// Yields `Ok(None)` when the address is absent from the database, when the
    /// record decodes to nothing, or when the record carries no ISO code.
    ///
    /// # Errors
    ///
    /// Returns [`GeoLookupError::DatabaseError`] when the reader fails to look
    /// up the address or to decode the record.
    fn lookup(&self, ip: IpAddr) -> Result<Option<String>, GeoLookupError> {
        let result = self.reader.lookup(ip).map_err(|e| {
            warn!(ip = %ip, error = %e, "MaxMind lookup error");
            GeoLookupError::DatabaseError(e.to_string())
        })?;

        if !result.has_data() {
            trace!(ip = %ip, "IP not found in MaxMind database");
            return Ok(None);
        }

        let decoded = result
            .decode::<maxminddb::geoip2::Country>()
            .map_err(|e| {
                warn!(ip = %ip, error = %e, "MaxMind decode error");
                GeoLookupError::DatabaseError(e.to_string())
            })?;

        let Some(record) = decoded else {
            trace!(ip = %ip, "No country data for IP in MaxMind database");
            return Ok(None);
        };

        let country_code = record.country.iso_code.map(|s| s.to_string());
        trace!(ip = %ip, country = ?country_code, "MaxMind lookup");
        Ok(country_code)
    }

    /// Always [`GeoDatabaseType::MaxMind`].
    fn database_type(&self) -> GeoDatabaseType {
        GeoDatabaseType::MaxMind
    }
}
/// [`GeoDatabase`] backend backed by an `ip2location::DB`.
pub struct Ip2LocationDatabase {
// Handle to the opened IP2Location database.
db: ip2location::DB,
}
impl Ip2LocationDatabase {
    /// Opens the IP2Location database stored at `path`.
    ///
    /// # Errors
    ///
    /// Returns [`GeoLookupError::LoadError`] naming the path and the library's
    /// error when the file cannot be opened.
    pub fn open(path: impl AsRef<Path>) -> Result<Self, GeoLookupError> {
        let db_path = path.as_ref();
        match ip2location::DB::from_file(db_path) {
            Ok(db) => {
                debug!(path = ?db_path, "Opened IP2Location GeoIP database");
                Ok(Self { db })
            }
            Err(e) => Err(GeoLookupError::LoadError(format!(
                "failed to open IP2Location database {:?}: {}",
                db_path, e
            ))),
        }
    }
}
impl GeoDatabase for Ip2LocationDatabase {
    /// Looks up `ip` and returns the short (two-letter) country name of its record.
    ///
    /// Both location and proxy records are supported. Yields `Ok(None)` when
    /// the address has no record, when the record has no country, or when the
    /// country is the `"-"` placeholder (see below).
    ///
    /// # Errors
    ///
    /// Returns [`GeoLookupError::DatabaseError`] for any lookup failure other
    /// than "record not found".
    fn lookup(&self, ip: IpAddr) -> Result<Option<String>, GeoLookupError> {
        let record = match self.db.ip_lookup(ip) {
            Ok(record) => record,
            Err(ip2location::error::Error::RecordNotFound) => {
                trace!(ip = %ip, "IP not found in IP2Location database");
                return Ok(None);
            }
            Err(e) => {
                warn!(ip = %ip, error = %e, "IP2Location lookup error");
                return Err(GeoLookupError::DatabaseError(e.to_string()));
            }
        };

        let raw_code = match record {
            ip2location::Record::LocationDb(loc) => {
                loc.country.map(|c| c.short_name.to_string())
            }
            ip2location::Record::ProxyDb(proxy) => {
                proxy.country.map(|c| c.short_name.to_string())
            }
        };

        // IP2Location data files use "-" as the country of unassigned/reserved
        // ranges. It is not a real country code, so report "no country"
        // instead of letting it reach filter evaluation and response headers.
        let country_code = raw_code.filter(|code| !code.is_empty() && code.as_str() != "-");

        trace!(ip = %ip, country = ?country_code, "IP2Location lookup");
        Ok(country_code)
    }

    /// Always [`GeoDatabaseType::Ip2Location`].
    fn database_type(&self) -> GeoDatabaseType {
        GeoDatabaseType::Ip2Location
    }
}
/// Cache entry pairing a lookup result with the instant it was stored.
struct CachedCountry {
/// Result of the database lookup; `None` when the lookup yielded no country.
country_code: Option<String>,
/// When the entry was created; compared against the pool's TTL to decide freshness.
cached_at: Instant,
}
/// Outcome of evaluating one client address against a geo filter.
#[derive(Debug, Clone)]
pub struct GeoFilterResult {
/// Whether the filter lets the request through.
pub allowed: bool,
/// Country code resolved for the client, if any.
pub country_code: Option<String>,
/// `true` when the country came from the pool's cache rather than a database lookup.
pub cache_hit: bool,
/// Copied from the filter's `add_country_header` setting; always `false` for failure results.
pub add_header: bool,
/// Copied from the filter's `status_code` setting.
// NOTE(review): presumably the HTTP status to use when blocking — confirm with the caller.
pub status_code: u16,
/// Copied from the filter's `block_message` setting.
pub block_message: Option<String>,
}
/// Runtime state for one configured geo filter: its database, lookup cache and rules.
pub struct GeoFilterPool {
/// Active database; behind a lock so `reload_database` can swap it in place.
database: RwLock<Arc<dyn GeoDatabase>>,
/// Per-IP cache of lookup results.
cache: DashMap<IpAddr, CachedCountry>,
/// Filter configuration this pool was created from.
config: GeoFilter,
/// `config.countries` collected into a set for membership tests.
countries_set: HashSet<String>,
/// How long a cache entry stays valid (from `config.cache_ttl_secs`).
cache_ttl: Duration,
/// Path the database was opened from; reused on reload.
database_path: PathBuf,
/// Backend type chosen at construction; reused on reload.
database_type: GeoDatabaseType,
}
impl GeoFilterPool {
pub fn new(config: GeoFilter) -> Result<Self, GeoLookupError> {
let db_type = config.database_type.clone().unwrap_or_else(|| {
if config.database_path.ends_with(".mmdb") {
GeoDatabaseType::MaxMind
} else {
GeoDatabaseType::Ip2Location
}
});
let database_path = PathBuf::from(&config.database_path);
let database: Arc<dyn GeoDatabase> = match db_type {
GeoDatabaseType::MaxMind => Arc::new(MaxMindDatabase::open(&config.database_path)?),
GeoDatabaseType::Ip2Location => {
Arc::new(Ip2LocationDatabase::open(&config.database_path)?)
}
};
let countries_set: HashSet<String> = config.countries.iter().cloned().collect();
let cache_ttl = Duration::from_secs(config.cache_ttl_secs);
debug!(
database_path = %config.database_path,
database_type = ?db_type,
action = ?config.action,
countries_count = countries_set.len(),
cache_ttl_secs = config.cache_ttl_secs,
"Created GeoFilterPool"
);
Ok(Self {
database: RwLock::new(database),
cache: DashMap::new(),
config,
countries_set,
cache_ttl,
database_path,
database_type: db_type,
})
}
pub fn reload_database(&self) -> Result<(), GeoLookupError> {
info!(
database_path = %self.database_path.display(),
database_type = ?self.database_type,
"Reloading geo database"
);
let new_database: Arc<dyn GeoDatabase> = match self.database_type {
GeoDatabaseType::MaxMind => Arc::new(MaxMindDatabase::open(&self.database_path)?),
GeoDatabaseType::Ip2Location => {
Arc::new(Ip2LocationDatabase::open(&self.database_path)?)
}
};
{
let mut db = self.database.write();
*db = new_database;
}
self.cache.clear();
info!(
database_path = %self.database_path.display(),
"Geo database reloaded successfully"
);
Ok(())
}
pub fn database_path(&self) -> &Path {
&self.database_path
}
pub fn check(&self, client_ip: &str) -> GeoFilterResult {
let ip: IpAddr = match client_ip.parse() {
Ok(ip) => ip,
Err(_) => {
warn!(client_ip = %client_ip, "Failed to parse client IP for geo filter");
return self.handle_failure();
}
};
let now = Instant::now();
if let Some(entry) = self.cache.get(&ip) {
if now.duration_since(entry.cached_at) < self.cache_ttl {
trace!(ip = %ip, country = ?entry.country_code, "Geo cache hit");
return self.evaluate(entry.country_code.clone(), true);
}
}
let database = self.database.read();
match database.lookup(ip) {
Ok(country_code) => {
self.cache.insert(
ip,
CachedCountry {
country_code: country_code.clone(),
cached_at: now,
},
);
self.evaluate(country_code, false)
}
Err(e) => {
warn!(ip = %ip, error = %e, "Geo lookup failed");
self.handle_failure()
}
}
}
fn evaluate(&self, country_code: Option<String>, cache_hit: bool) -> GeoFilterResult {
let in_list = country_code
.as_ref()
.map(|c| self.countries_set.contains(c))
.unwrap_or(false);
let allowed = match self.config.action {
GeoFilterAction::Block => {
!in_list
}
GeoFilterAction::Allow => {
if self.countries_set.is_empty() {
true
} else {
in_list
}
}
GeoFilterAction::LogOnly => {
true
}
};
trace!(
country = ?country_code,
in_list = in_list,
action = ?self.config.action,
allowed = allowed,
"Geo filter evaluation"
);
GeoFilterResult {
allowed,
country_code,
cache_hit,
add_header: self.config.add_country_header,
status_code: self.config.status_code,
block_message: self.config.block_message.clone(),
}
}
fn handle_failure(&self) -> GeoFilterResult {
let allowed = match self.config.on_failure {
GeoFailureMode::Open => true,
GeoFailureMode::Closed => false,
};
GeoFilterResult {
allowed,
country_code: None,
cache_hit: false,
add_header: false,
status_code: self.config.status_code,
block_message: self.config.block_message.clone(),
}
}
pub fn cache_stats(&self) -> (usize, usize) {
let now = Instant::now();
let total = self.cache.len();
let valid = self
.cache
.iter()
.filter(|e| now.duration_since(e.cached_at) < self.cache_ttl)
.count();
(total, valid)
}
pub fn clear_expired(&self) {
let now = Instant::now();
self.cache
.retain(|_, v| now.duration_since(v.cached_at) < self.cache_ttl);
}
}
/// Registry of geo filter pools keyed by filter id.
pub struct GeoFilterManager {
/// Registered pools; each is behind an `Arc` so `get_pool` can hand out shared handles.
filter_pools: DashMap<String, Arc<GeoFilterPool>>,
}
impl GeoFilterManager {
    /// Creates a manager with no filters registered.
    pub fn new() -> Self {
        Self {
            filter_pools: DashMap::new(),
        }
    }

    /// Builds a pool from `config` and registers it under `filter_id`.
    ///
    /// A pool already registered under the same id is replaced.
    ///
    /// # Errors
    ///
    /// Returns the [`GeoLookupError`] from [`GeoFilterPool::new`]; nothing is
    /// registered in that case.
    pub fn register_filter(
        &self,
        filter_id: &str,
        config: GeoFilter,
    ) -> Result<(), GeoLookupError> {
        let pool = GeoFilterPool::new(config)?;
        self.filter_pools
            .insert(filter_id.to_string(), Arc::new(pool));
        debug!(filter_id = %filter_id, "Registered geo filter");
        Ok(())
    }

    /// Evaluates `client_ip` against the filter `filter_id`.
    ///
    /// Returns `None` when no such filter is registered.
    pub fn check(&self, filter_id: &str, client_ip: &str) -> Option<GeoFilterResult> {
        self.filter_pools
            .get(filter_id)
            .map(|pool| pool.check(client_ip))
    }

    /// Returns a shared handle to the pool registered under `filter_id`, if any.
    pub fn get_pool(&self, filter_id: &str) -> Option<Arc<GeoFilterPool>> {
        self.filter_pools
            .get(filter_id)
            .map(|entry| Arc::clone(entry.value()))
    }

    /// Whether a filter is registered under `filter_id`.
    pub fn has_filter(&self, filter_id: &str) -> bool {
        self.filter_pools.contains_key(filter_id)
    }

    /// Ids of all registered filters, in no particular order.
    pub fn filter_ids(&self) -> Vec<String> {
        self.filter_pools.iter().map(|r| r.key().clone()).collect()
    }

    /// Removes expired cache entries from every registered pool.
    pub fn clear_expired_caches(&self) {
        // Snapshot the pools first so no map guard is held during the sweeps.
        let pools: Vec<Arc<GeoFilterPool>> = self
            .filter_pools
            .iter()
            .map(|entry| Arc::clone(entry.value()))
            .collect();
        for pool in pools {
            pool.clear_expired();
        }
    }

    /// Reloads the database of the filter `filter_id`.
    ///
    /// # Errors
    ///
    /// Returns [`GeoLookupError::LoadError`] when the filter is unknown, or
    /// the error from [`GeoFilterPool::reload_database`].
    pub fn reload_filter(&self, filter_id: &str) -> Result<(), GeoLookupError> {
        // `get_pool` releases the map guard before the (slow) file reload, so
        // concurrent registrations and checks are not stalled behind disk I/O.
        match self.get_pool(filter_id) {
            Some(pool) => pool.reload_database(),
            None => Err(GeoLookupError::LoadError(format!(
                "Filter '{}' not found",
                filter_id
            ))),
        }
    }

    /// Reloads every filter whose database path equals `path`.
    ///
    /// Returns one `(filter_id, result)` pair per matching filter.
    pub fn reload_by_path(&self, path: &Path) -> Vec<(String, Result<(), GeoLookupError>)> {
        // Collect the matches first; the reloads then run without any map
        // guard held.
        let matching: Vec<(String, Arc<GeoFilterPool>)> = self
            .filter_pools
            .iter()
            .filter(|entry| entry.value().database_path() == path)
            .map(|entry| (entry.key().clone(), Arc::clone(entry.value())))
            .collect();

        matching
            .into_iter()
            .map(|(filter_id, pool)| {
                let result = pool.reload_database();
                (filter_id, result)
            })
            .collect()
    }

    /// `(filter_id, database_path)` for every registered filter.
    pub fn database_paths(&self) -> Vec<(String, PathBuf)> {
        self.filter_pools
            .iter()
            .map(|e| (e.key().clone(), e.value().database_path().to_path_buf()))
            .collect()
    }
}
impl Default for GeoFilterManager {
fn default() -> Self {
Self::new()
}
}
/// Watches the database files of registered geo filters and reloads them on change.
pub struct GeoDatabaseWatcher {
/// Active file watcher; `None` until `start_watching` succeeds and after `stop`.
watcher: RwLock<Option<notify::RecommendedWatcher>>,
/// Maps each watched database path to the ids of the filters that use it.
path_to_filters: RwLock<HashMap<PathBuf, Vec<String>>>,
/// Manager whose filters are watched and reloaded.
manager: Arc<GeoFilterManager>,
}
impl GeoDatabaseWatcher {
pub fn new(manager: Arc<GeoFilterManager>) -> Self {
Self {
watcher: RwLock::new(None),
path_to_filters: RwLock::new(HashMap::new()),
manager,
}
}
pub fn start_watching(&self) -> Result<mpsc::Receiver<PathBuf>, GeoLookupError> {
let db_paths = self.manager.database_paths();
let mut path_map: HashMap<PathBuf, Vec<String>> = HashMap::new();
for (filter_id, path) in db_paths {
path_map.entry(path).or_default().push(filter_id);
}
if path_map.is_empty() {
debug!("No geo databases to watch");
let (_tx, rx) = mpsc::channel(1);
return Ok(rx);
}
*self.path_to_filters.write() = path_map.clone();
let (tx, rx) = mpsc::channel::<PathBuf>(10);
let paths: Vec<PathBuf> = path_map.keys().cloned().collect();
let watcher = notify::recommended_watcher(move |event: Result<Event, notify::Error>| {
if let Ok(event) = event {
if matches!(event.kind, EventKind::Modify(_) | EventKind::Create(_)) {
for path in &event.paths {
let _ = tx.blocking_send(path.clone());
}
}
}
})
.map_err(|e| GeoLookupError::LoadError(format!("Failed to create file watcher: {}", e)))?;
*self.watcher.write() = Some(watcher);
if let Some(ref mut watcher) = *self.watcher.write() {
for path in &paths {
if let Err(e) = watcher.watch(path, RecursiveMode::NonRecursive) {
warn!(
path = %path.display(),
error = %e,
"Failed to watch geo database file"
);
} else {
info!(
path = %path.display(),
"Watching geo database for changes"
);
}
}
}
Ok(rx)
}
pub fn handle_change(&self, path: &Path) {
let path_map = self.path_to_filters.read();
if let Some(filter_ids) = path_map.get(path) {
info!(
path = %path.display(),
filters = ?filter_ids,
"Geo database file changed, reloading"
);
for filter_id in filter_ids {
match self.manager.reload_filter(filter_id) {
Ok(()) => {
info!(
filter_id = %filter_id,
"Geo filter database reloaded successfully"
);
}
Err(e) => {
error!(
filter_id = %filter_id,
error = %e,
"Failed to reload geo filter database"
);
}
}
}
}
}
pub fn stop(&self) {
*self.watcher.write() = None;
info!("Stopped watching geo database files");
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every error variant renders with its category prefix.
    #[test]
    fn test_geo_lookup_error_display() {
        let cases = [
            (
                GeoLookupError::InvalidIp("not-an-ip".to_string()),
                "invalid IP",
            ),
            (
                GeoLookupError::DatabaseError("db error".to_string()),
                "database error",
            ),
            (
                GeoLookupError::LoadError("load error".to_string()),
                "failed to load",
            ),
        ];
        for (err, expected) in &cases {
            assert!(err.to_string().contains(expected));
        }
    }

    /// A hand-built result exposes the values it was constructed with.
    #[test]
    fn test_geo_filter_result_default() {
        let outcome = GeoFilterResult {
            allowed: true,
            country_code: Some("US".to_string()),
            cache_hit: false,
            add_header: true,
            status_code: 403,
            block_message: None,
        };
        assert!(outcome.allowed);
        assert_eq!(outcome.country_code, Some("US".to_string()));
        assert!(!outcome.cache_hit);
        assert!(outcome.add_header);
    }

    /// A freshly created manager has no filters.
    #[test]
    fn test_geo_filter_manager_new() {
        let registry = GeoFilterManager::new();
        assert!(registry.filter_ids().is_empty());
        assert!(!registry.has_filter("test"));
    }
}