use hashiverse_lib::transport::ddos::ddos::{DdosProtection, DdosScore};
use parking_lot::Mutex;
use log::{info, warn};
use moka::sync::Cache;
use std::collections::HashMap;
use std::process::Command;
use std::sync::Arc;
use std::time::Duration;
/// DDoS protection that tracks per-IP request scores in memory and pushes
/// offending addresses into a Linux `ipset` set.
///
/// The constructor installs `iptables` rules that `DROP` traffic whose source
/// matches the set. Scoring, connection counting and the firewall calls are all
/// driven through the `DdosProtection` trait implementation.
pub struct IpsetDdosProtection {
    // --- firewall side -------------------------------------------------------
    /// Name of the ipset set that banned addresses are added to.
    set_name: String,
    /// IPs for which an `ipset add` was recently issued; entries live for 10
    /// seconds so repeated offences do not spawn a command per request.
    ipset_throttle: Cache<String, ()>,

    // --- scoring -------------------------------------------------------------
    /// Score at or above which an IP is rejected and banned.
    score_threshold: f64,
    /// Decay rate handed to `DdosScore` whenever a score is updated or read.
    decay_per_second: f64,
    /// Points charged by `report_bad_request`.
    bad_request_penalty: f64,
    /// Per-IP scores; entries are dropped after an idle period derived from
    /// `score_threshold` and `decay_per_second` in `new`.
    scores: Cache<String, Arc<Mutex<DdosScore>>>,

    // --- connection limiting -------------------------------------------------
    /// Maximum number of concurrently held connections per IP.
    max_connections_per_ip: usize,
    /// Open-connection count per IP; an entry is removed when its count
    /// returns to zero.
    connections: Mutex<HashMap<String, usize>>,
}
impl IpsetDdosProtection {
    /// Largest `timeout` value (in seconds) that `ipset` accepts. Larger values make
    /// `ipset create` fail, which would leave banning silently disabled.
    const IPSET_MAX_TIMEOUT_SECS: u64 = 2_147_483;

    /// Upper bound on how many stale copies of the DROP rule are deleted per chain
    /// before the rule is re-inserted, so the cleanup loop can never spin forever.
    const MAX_STALE_RULE_DELETIONS: usize = 16;

    /// Creates the protector and sets up the firewall plumbing it relies on.
    ///
    /// Side effects are best-effort. Failures are logged with `warn!` and never
    /// returned:
    /// * runs `ipset create <set_name> hash:ip timeout <idle_secs> --exist`, so entries
    ///   added later expire after `idle_secs`;
    /// * for the `INPUT` and `FORWARD` chains, deletes existing
    ///   `-m set --match-set <set_name> src -j DROP` rules and inserts one at the top.
    ///
    /// NOTE(review): `--exist` keeps an already-existing set untouched, including its
    /// old default timeout. Per ipset(8), a plain `hash:ip` set is IPv4-only
    /// (`family inet`), so IPv6 bans will fail and be logged. Confirm whether IPv6
    /// clients need a separate inet6 set and an `ip6tables` rule.
    ///
    /// # Parameters
    /// * `set_name` - ipset set name used for banned addresses.
    /// * `score_threshold` - score at or above which an IP is rejected and banned.
    /// * `decay_per_second` - decay rate handed to `DdosScore`; also used to size
    ///   the idle timeout.
    /// * `bad_request_penalty` - points charged by `report_bad_request`.
    /// * `max_connections_per_ip` - cap on concurrently held connections per IP.
    pub fn new(
        set_name: impl Into<String>,
        score_threshold: f64,
        decay_per_second: f64,
        bad_request_penalty: f64,
        max_connections_per_ip: usize,
    ) -> Self {
        let set_name = set_name.into();
        let idle_secs = Self::compute_idle_secs(score_threshold, decay_per_second);
        Self::ensure_ipset_set(&set_name, idle_secs);
        Self::install_drop_rules(&set_name);
        Self {
            set_name,
            score_threshold,
            decay_per_second,
            bad_request_penalty,
            max_connections_per_ip,
            scores: Cache::builder().time_to_idle(Duration::from_secs(idle_secs)).build(),
            ipset_throttle: Cache::builder().time_to_live(Duration::from_secs(10)).build(),
            connections: Mutex::new(HashMap::new()),
        }
    }

    /// Returns the number of seconds an idle score, and an ipset ban, is kept.
    ///
    /// The value is `2 * score_threshold / decay_per_second`, rounded up. It is 3600
    /// when `decay_per_second` is not positive.
    ///
    /// The result is clamped to `1..=IPSET_MAX_TIMEOUT_SECS`, for two reasons:
    /// * A value of `0` would make ipset entries permanent (timeout 0 means "never
    ///   expire") and would give the score cache a zero idle time.
    /// * Oversized values make `ipset create` reject the set.
    ///
    /// NaN, negative and infinite ratios saturate through `as u64` (to 0 or
    /// `u64::MAX`), and the same clamp catches them.
    fn compute_idle_secs(score_threshold: f64, decay_per_second: f64) -> u64 {
        if decay_per_second > 0.0 {
            let secs = (score_threshold / decay_per_second * 2.0).ceil() as u64;
            secs.clamp(1, Self::IPSET_MAX_TIMEOUT_SECS)
        } else {
            3600
        }
    }

    /// Creates the `hash:ip` set, or leaves an existing one alone (`--exist`).
    /// Each new entry defaults to `timeout_secs`.
    fn ensure_ipset_set(set_name: &str, timeout_secs: u64) {
        let timeout = timeout_secs.to_string();
        let result = Command::new("ipset")
            .args(["create", set_name, "hash:ip", "timeout", &timeout, "--exist"])
            .status();
        match result {
            Ok(status) if status.success() => info!("DDoS: ipset set '{}' ready", set_name),
            Ok(status) => warn!("DDoS: ipset create '{}' failed with status {}", set_name, status),
            Err(e) => warn!("DDoS: failed to run ipset create '{}': {}", set_name, e),
        }
    }

    /// Ensures each of `INPUT` and `FORWARD` has exactly one DROP rule for the set,
    /// placed at the top of the chain.
    fn install_drop_rules(set_name: &str) {
        let rule = ["-m", "set", "--match-set", set_name, "src", "-j", "DROP"];
        for chain in ["INPUT", "FORWARD"] {
            // A single `-D` removes only one copy. Keep deleting so that duplicates
            // left by earlier runs do not accumulate. A failure ("no such rule") is
            // expected once all copies are gone, so its stderr noise is discarded.
            for _ in 0..Self::MAX_STALE_RULE_DELETIONS {
                let deleted = Command::new("iptables")
                    .args(["-D", chain])
                    .args(rule)
                    .stderr(std::process::Stdio::null())
                    .status()
                    .map(|status| status.success())
                    .unwrap_or(false);
                if !deleted {
                    break;
                }
            }
            let result = Command::new("iptables").args(["-I", chain]).args(rule).status();
            match result {
                Ok(status) if status.success() => {
                    info!("DDoS: iptables {} rule for '{}' installed", chain, set_name)
                }
                Ok(status) => warn!(
                    "DDoS: iptables -I {} for '{}' failed with status {}",
                    chain, set_name, status
                ),
                Err(e) => warn!("DDoS: failed to run iptables {} for '{}': {}", chain, set_name, e),
            }
        }
    }

    /// Adds `points` to `ip`'s score, creating the score on first sight, and returns
    /// the updated value as reported by `DdosScore::increment`.
    fn increment_score(&self, ip: &str, points: f64) -> f64 {
        let entry = self.scores.get_with(ip.to_string(), || Arc::new(Mutex::new(DdosScore::new())));
        entry.lock().increment(points, self.decay_per_second)
    }

    /// Returns `true` when `ip` has a tracked score whose current value is at or
    /// above the threshold. Unknown IPs are never banned.
    fn is_score_banned(&self, ip: &str) -> bool {
        self.scores
            .get(ip)
            .map(|entry| entry.lock().current(self.decay_per_second) >= self.score_threshold)
            .unwrap_or(false)
    }

    /// Asynchronously runs `ipset add <set_name> <ip> --exist`, at most once per IP
    /// per 10-second throttle window.
    fn maybe_call_ipset(&self, ip: &str) {
        // `or_insert` followed by `is_fresh` is an atomic check-and-insert: only the
        // caller that actually created the throttle entry proceeds. A separate
        // `contains_key` followed by `insert` would let concurrent callers all pass.
        let throttle_entry = self.ipset_throttle.entry_by_ref(ip).or_insert(());
        if !throttle_entry.is_fresh() {
            return;
        }

        let set_name = self.set_name.clone();
        let ip = ip.to_string();
        info!("Banning DDoS ip: {}", ip);

        // `tokio::spawn` panics outside a Tokio runtime, and the trait methods that
        // call this are synchronous. Use the runtime when one is available,
        // otherwise fall back to a plain OS thread.
        match tokio::runtime::Handle::try_current() {
            Ok(runtime) => {
                runtime.spawn(async move {
                    let result = tokio::process::Command::new("ipset")
                        .args(["add", &set_name, &ip, "--exist"])
                        .status()
                        .await;
                    Self::log_ipset_add_result(&set_name, &ip, result);
                });
            }
            Err(_) => {
                let spawned = std::thread::Builder::new().name("ddos-ipset".to_string()).spawn(
                    move || {
                        let result = Command::new("ipset")
                            .args(["add", &set_name, &ip, "--exist"])
                            .status();
                        Self::log_ipset_add_result(&set_name, &ip, result);
                    },
                );
                if let Err(e) = spawned {
                    warn!("DDoS: failed to spawn ipset thread: {}", e);
                }
            }
        }
    }

    /// Logs the outcome of an `ipset add` invocation.
    fn log_ipset_add_result(
        set_name: &str,
        ip: &str,
        result: std::io::Result<std::process::ExitStatus>,
    ) {
        match result {
            Ok(status) if status.success() => info!("DDoS: banned {} via ipset set '{}'", ip, set_name),
            Ok(status) => warn!("DDoS: ipset add {} failed with status {}", ip, status),
            Err(e) => warn!("DDoS: failed to run ipset for {}: {}", ip, e),
        }
    }
}
impl DdosProtection for IpsetDdosProtection {
    /// Charges one point to `ip` and decides whether the request may proceed.
    ///
    /// Returns `false` once the updated score is at or above `score_threshold`. In
    /// that case a throttled firewall ban is also requested via `maybe_call_ipset`.
    fn allow_request(&self, ip: &str) -> bool {
        let score = self.increment_score(ip, 1.0);
        let over_threshold = score >= self.score_threshold;
        if over_threshold {
            self.maybe_call_ipset(ip);
        }
        !over_threshold
    }

    /// Charges `bad_request_penalty` points to `ip`. If that pushes the score to the
    /// threshold or above, requests a (throttled) firewall ban.
    fn report_bad_request(&self, ip: &str) {
        let score = self.increment_score(ip, self.bad_request_penalty);
        if score >= self.score_threshold {
            self.maybe_call_ipset(ip);
        }
    }

    /// Tries to reserve a connection slot for `ip`.
    ///
    /// Returns `false`, without reserving anything, when the IP's score is at or
    /// above the threshold or when it already holds `max_connections_per_ip`
    /// connections. Every `true` must be paired with one `release_connection`.
    fn try_acquire_connection(&self, ip: &str) -> bool {
        if self.is_score_banned(ip) {
            return false;
        }
        let mut connections = self.connections.lock();
        if let Some(count) = connections.get_mut(ip) {
            if *count >= self.max_connections_per_ip {
                return false;
            }
            *count += 1;
            return true;
        }
        // Never insert an entry for a rejected connection. `release_connection` is
        // only called after a successful acquire, so a zero-count entry would never
        // be removed and the map would grow without bound when the cap is 0.
        if self.max_connections_per_ip == 0 {
            return false;
        }
        connections.insert(ip.to_string(), 1);
        true
    }

    /// Releases one connection slot previously reserved for `ip`. When the count
    /// reaches zero the entry is dropped. Unknown IPs are ignored.
    fn release_connection(&self, ip: &str) {
        let mut connections = self.connections.lock();
        if let Some(count) = connections.get_mut(ip) {
            *count = count.saturating_sub(1);
            if *count == 0 {
                connections.remove(ip);
            }
        }
    }
}