use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock};
use std::time::{Duration, Instant};
use arc_swap::ArcSwap;
use crossbeam_queue::SegQueue;
use crate::{get_host_from_url, CAT_BAD};
type DynMap = HashMap<Box<str>, DynEntry, foldhash::fast::RandomState>;

/// Creates an empty [`DynMap`] with room for `cap` entries, using a
/// default-constructed `foldhash` hasher state.
#[inline]
fn new_map(cap: usize) -> DynMap {
    let hasher = foldhash::fast::RandomState::default();
    DynMap::with_capacity_and_hasher(cap, hasher)
}
/// Why a host was reported to the dynamic overlay.
///
/// The reason is only forwarded to the report sink (via `BadReport::reason`);
/// the overlay itself stores categories and expiry, not the reason.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum BadReason {
    /// Reported because of a WAF challenge.
    WafChallenge,
    /// Reported because of a firewall block.
    FirewallBlock,
    /// Reported because of repeated hard failures.
    HardFailures,
    /// Reported manually.
    Manual,
    /// Any other reason; used by `report_bad`.
    Other,
}
/// Notification handed to the report sink for every accepted report.
///
/// Reports that are blank after trimming never reach the sink. Seeding via
/// `seed_dynamic` does not produce `BadReport`s.
#[derive(Debug)]
#[non_exhaustive]
pub struct BadReport<'a> {
    /// The trimmed, ASCII-lowercased host that will be queued.
    pub host: &'a str,
    /// Category bitmask; never `0` (a `0` request is replaced by `CAT_BAD`).
    pub cats: u64,
    /// Why the host was reported.
    pub reason: BadReason,
    /// When the report was made.
    pub at: Instant,
    /// The TTL requested by the caller.
    pub ttl: Duration,
}
/// Per-host value stored in the overlay.
#[derive(Copy, Clone)]
struct DynEntry {
    /// Union (bitwise OR) of the categories of every report merged for the host.
    cats: u64,
    /// Latest expiry among the reports merged for the host.
    expires: Instant,
}
/// Immutable published view of the overlay. Readers load it through
/// `ArcSwap`, and each merge replaces it wholesale.
struct DynSnapshot {
    /// Lowercased host -> categories and expiry. May still contain entries
    /// that expired after `built`; lookups re-check expiry against `now`.
    map: DynMap,
    /// The `now` value the merge that produced this snapshot ran with.
    built: Instant,
}
/// A normalized report waiting in the queue for the next merge.
struct PendingReport {
    /// Trimmed, ASCII-lowercased host.
    host: Box<str>,
    /// Category bitmask (producers replace `0` with `CAT_BAD`).
    cats: u64,
    /// Absolute time after which the report no longer applies.
    expires: Instant,
}
/// Signature of the optional callback notified for each accepted report.
type ReportSinkFn = dyn Fn(&BadReport) + Send + Sync + 'static;

/// Currently published overlay; created lazily by `snapshot_cell`.
static SNAPSHOT: OnceLock<ArcSwap<DynSnapshot>> = OnceLock::new();
/// Reports waiting to be folded into the next snapshot.
static QUEUE: OnceLock<SegQueue<PendingReport>> = OnceLock::new();
/// Count of queued reports. It is updated separately from the queue
/// operations, so it can be transiently out of step with the queue.
static PENDING: AtomicUsize = AtomicUsize::new(0);
/// Held (true) while one thread rebuilds and publishes the snapshot.
static MERGE_IN_PROGRESS: AtomicBool = AtomicBool::new(false);
/// Optional sink installed once through `set_report_sink`.
static REPORT_SINK: OnceLock<Box<ReportSinkFn>> = OnceLock::new();
/// Ensures at most one background pruner thread is started.
static PRUNER_STARTED: AtomicBool = AtomicBool::new(false);

// Runtime tunables, adjusted through the `set_*` functions.
/// Default report TTL in milliseconds (24 hours).
static DEFAULT_TTL_MS: AtomicU64 = AtomicU64::new(86_400_000);
/// Number of pending reports that forces a merge.
static MERGE_BATCH: AtomicUsize = AtomicUsize::new(64);
/// Snapshot age in milliseconds after which the next report triggers a merge.
static MERGE_INTERVAL_MS: AtomicU64 = AtomicU64::new(5_000);
/// Maximum number of entries kept in a snapshot.
static MAX_ENTRIES: AtomicUsize = AtomicUsize::new(100_000);
/// Returns the global snapshot cell, installing an empty snapshot on first use.
#[inline]
fn snapshot_cell() -> &'static ArcSwap<DynSnapshot> {
    SNAPSHOT.get_or_init(|| {
        let empty = DynSnapshot {
            map: new_map(0),
            built: Instant::now(),
        };
        ArcSwap::from_pointee(empty)
    })
}
/// Returns the global pending-report queue, creating it on first use.
#[inline]
fn queue() -> &'static SegQueue<PendingReport> {
    QUEUE.get_or_init(SegQueue::<PendingReport>::new)
}
/// Reads the default TTL, stored as milliseconds, and returns it as a `Duration`.
#[inline]
fn default_ttl_dur() -> Duration {
    let millis = DEFAULT_TTL_MS.load(Ordering::Relaxed);
    Duration::from_millis(millis)
}
/// Returns `true` if `host`, or one of its parent domains, has an unexpired
/// dynamic entry whose category mask shares at least one bit with `cat`.
///
/// The lookup is ASCII case-insensitive, matching the lowercasing applied when
/// reports are queued. `host` itself is always checked. Parent suffixes are
/// checked only while they still contain a dot, so a single-label suffix such
/// as the bare TLD `test` of `a.evil.test` is never consulted. Returns `false`
/// before any snapshot has been published.
#[inline]
pub fn dynamic_has_category(host: &str, cat: u64) -> bool {
    let Some(cell) = SNAPSHOT.get() else {
        return false;
    };
    let snap = cell.load();
    if snap.map.is_empty() {
        return false;
    }
    // Keys are stored lowercased. Allocate only when the query actually
    // contains uppercase ASCII, so the common path stays allocation-free.
    let lowered;
    let host = if host.bytes().any(|b| b.is_ascii_uppercase()) {
        lowered = host.to_ascii_lowercase();
        lowered.as_str()
    } else {
        host
    };
    let now = Instant::now();
    let mut h = host;
    loop {
        if let Some(e) = snap.map.get(h) {
            if e.cats & cat != 0 && e.expires > now {
                return true;
            }
        }
        match h.find('.') {
            Some(dot) => {
                h = &h[dot + 1..];
                // Stop before looking up a single-label suffix (TLD).
                if !h.contains('.') {
                    break;
                }
            }
            None => break,
        }
    }
    false
}
/// Returns `true` if `host`, or one of its parent domains, has any unexpired
/// dynamic entry, regardless of category.
///
/// The lookup is ASCII case-insensitive, matching the lowercasing applied when
/// reports are queued. `host` itself is always checked. Parent suffixes are
/// checked only while they still contain a dot, so a bare TLD is never
/// consulted. Returns `false` before any snapshot has been published.
#[inline]
pub fn dynamic_contains(host: &str) -> bool {
    let Some(cell) = SNAPSHOT.get() else {
        return false;
    };
    let snap = cell.load();
    if snap.map.is_empty() {
        return false;
    }
    // Keys are stored lowercased. Allocate only when the query actually
    // contains uppercase ASCII, so the common path stays allocation-free.
    let lowered;
    let host = if host.bytes().any(|b| b.is_ascii_uppercase()) {
        lowered = host.to_ascii_lowercase();
        lowered.as_str()
    } else {
        host
    };
    let now = Instant::now();
    let mut h = host;
    loop {
        if let Some(e) = snap.map.get(h) {
            if e.expires > now {
                return true;
            }
        }
        match h.find('.') {
            Some(dot) => {
                h = &h[dot + 1..];
                // Stop before looking up a single-label suffix (TLD).
                if !h.contains('.') {
                    break;
                }
            }
            None => break,
        }
    }
    false
}
/// Number of entries in the currently published snapshot, or `0` if none
/// exists yet. Entries that have expired but not yet been pruned by a merge
/// are still counted.
pub fn dynamic_len() -> usize {
    match SNAPSHOT.get() {
        Some(cell) => cell.load().map.len(),
        None => 0,
    }
}
/// Reports `host` (a bare host or a URL) under `CAT_BAD` with reason
/// [`BadReason::Other`] and the current default TTL.
#[inline]
pub fn report_bad(host: &str) {
    let ttl = default_ttl_dur();
    report_bad_with_ttl(host, CAT_BAD, BadReason::Other, ttl);
}
/// Reports `host` under the categories in `cats` with the given `reason`,
/// using the current default TTL.
#[inline]
pub fn report_bad_categorized(host: &str, cats: u64, reason: BadReason) {
    let ttl = default_ttl_dur();
    report_bad_with_ttl(host, cats, reason, ttl);
}
/// Queues a report that `host` belongs to the categories in `cats` for `ttl`.
///
/// `host` may be a bare host or a URL (resolved via `get_host_from_url`). It is
/// trimmed and ASCII-lowercased, and blank input is ignored. A `cats` of `0`
/// is treated as `CAT_BAD`. The installed report sink, if any, is called
/// synchronously. The report is then queued, and a merge is attempted once the
/// batch size or the merge interval is reached.
///
/// If `now + ttl` cannot be represented as an `Instant` (e.g. `Duration::MAX`),
/// the expiry is clamped to roughly a century ahead instead of the report
/// expiring immediately.
pub fn report_bad_with_ttl(host: &str, cats: u64, reason: BadReason, ttl: Duration) {
    // About 100 years; stands in for "no expiry" when `at + ttl` overflows.
    const FAR_FUTURE: Duration = Duration::from_secs(100 * 365 * 24 * 60 * 60);
    let cats = if cats == 0 { CAT_BAD } else { cats };
    let raw = get_host_from_url(host).unwrap_or(host).trim();
    if raw.is_empty() {
        return;
    }
    let host_norm = raw.to_ascii_lowercase();
    let at = Instant::now();
    let expires = at
        .checked_add(ttl)
        .or_else(|| at.checked_add(FAR_FUTURE))
        .unwrap_or(at);
    invoke_sink(&BadReport {
        host: &host_norm,
        cats,
        reason,
        at,
        ttl,
    });
    // Count the report *before* publishing it. A concurrent merge may pop it
    // immediately, and decrementing first would wrap PENDING below zero.
    // `wrapping_add` keeps this from panicking in debug builds even if some
    // other producer still increments after pushing.
    let pending = PENDING.fetch_add(1, Ordering::Relaxed).wrapping_add(1);
    queue().push(PendingReport {
        host: host_norm.into_boxed_str(),
        cats,
        expires,
    });
    maybe_merge(pending, at);
}
/// Bulk-loads `(host, cats, ttl)` entries into the overlay, e.g. from
/// persisted state.
///
/// Hosts are normalized like `report_bad_with_ttl`: URL → host, trimmed,
/// ASCII-lowercased, blank entries skipped, `cats == 0` treated as `CAT_BAD`.
/// The report sink is *not* notified. A single merge is attempted at the end.
/// If another merge is already running, the entries stay queued until a later
/// merge (call `flush` to force them in). A `ttl` that overflows `Instant` is
/// clamped to roughly a century instead of expiring immediately.
pub fn seed_dynamic<I>(entries: I)
where
    I: IntoIterator<Item = (String, u64, Duration)>,
{
    // About 100 years; stands in for "no expiry" when `now + ttl` overflows.
    const FAR_FUTURE: Duration = Duration::from_secs(100 * 365 * 24 * 60 * 60);
    let now = Instant::now();
    let q = queue();
    let mut any = false;
    for (host, cats, ttl) in entries {
        let cats = if cats == 0 { CAT_BAD } else { cats };
        let raw = get_host_from_url(&host).unwrap_or(&host).trim();
        if raw.is_empty() {
            continue;
        }
        let expires = now
            .checked_add(ttl)
            .or_else(|| now.checked_add(FAR_FUTURE))
            .unwrap_or(now);
        // Increment before pushing so a concurrent merge's decrement can never
        // observe the counter below zero.
        PENDING.fetch_add(1, Ordering::Relaxed);
        q.push(PendingReport {
            host: raw.to_ascii_lowercase().into_boxed_str(),
            cats,
            expires,
        });
        any = true;
    }
    if any {
        try_merge(now);
    }
}
/// Seeds plain hosts as `CAT_BAD` using the current default TTL; see
/// `seed_dynamic` for normalization and merge behavior.
pub fn seed_dynamic_hosts<I>(hosts: I)
where
    I: IntoIterator<Item = String>,
{
    let ttl = default_ttl_dur();
    let entries = hosts.into_iter().map(move |host| (host, CAT_BAD, ttl));
    seed_dynamic(entries);
}
/// Attempts a merge when `pending` has reached the configured batch size, or
/// when the published snapshot is at least `MERGE_INTERVAL_MS` old (or no
/// snapshot exists yet).
#[inline]
fn maybe_merge(pending: usize, now: Instant) {
    let batch = MERGE_BATCH.load(Ordering::Relaxed).max(1);
    // `||` short-circuits, so the staleness check only runs below the batch size.
    let due = pending >= batch
        || match SNAPSHOT.get() {
            None => true,
            Some(cell) => {
                let interval = Duration::from_millis(MERGE_INTERVAL_MS.load(Ordering::Relaxed));
                now.duration_since(cell.load().built) >= interval
            }
        };
    if due {
        try_merge(now);
    }
}
/// Token held for the duration of a merge. Dropping it clears
/// `MERGE_IN_PROGRESS`, so the flag is released on every exit path of
/// `try_merge`, including unwinding.
struct MergeGuard;

impl Drop for MergeGuard {
    fn drop(&mut self) {
        // Release makes this merge's work visible to whoever acquires the flag next.
        MERGE_IN_PROGRESS.store(false, Ordering::Release);
    }
}
fn try_merge(now: Instant) {
if MERGE_IN_PROGRESS
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
.is_err()
{
return;
}
let _slot = MergeGuard;
let cell = snapshot_cell();
let cur = cell.load_full();
let mut next: DynMap = new_map(cur.map.len());
for (k, v) in cur.map.iter() {
if v.expires > now {
next.insert(k.clone(), *v);
}
}
let q = queue();
while let Some(p) = q.pop() {
PENDING.fetch_sub(1, Ordering::Relaxed);
if p.expires <= now {
continue;
}
match next.get_mut(&p.host) {
Some(e) => {
e.cats |= p.cats;
if p.expires > e.expires {
e.expires = p.expires;
}
}
None => {
next.insert(
p.host,
DynEntry {
cats: p.cats,
expires: p.expires,
},
);
}
}
}
let cap = MAX_ENTRIES.load(Ordering::Relaxed);
if next.len() > cap {
evict_to_cap(&mut next, cap);
}
cell.store(Arc::new(DynSnapshot {
map: next,
built: now,
}));
}
/// Shrinks `map` to at most `cap` entries by removing the entries that expire
/// soonest. Ties at the cutoff are broken arbitrarily.
///
/// Only the `over` earliest expiries need to be identified, not ordered, so a
/// linear-time selection is used instead of a full O(n log n) sort.
fn evict_to_cap(map: &mut DynMap, cap: usize) {
    let over = map.len().saturating_sub(cap);
    if over == 0 {
        return;
    }
    let mut by_exp: Vec<(Instant, Box<str>)> =
        map.iter().map(|(k, e)| (e.expires, k.clone())).collect();
    // Partition so the `over` earliest-expiring entries occupy `..over`.
    // Here 0 < over <= by_exp.len(), so `over - 1` is a valid index.
    by_exp.select_nth_unstable_by_key(over - 1, |(exp, _)| *exp);
    by_exp.truncate(over);
    for (_, k) in by_exp {
        map.remove(&k);
    }
}
pub fn flush() {
for _ in 0..256 {
try_merge(Instant::now());
if PENDING.load(Ordering::Relaxed) == 0 && queue().is_empty() {
return;
}
std::thread::yield_now();
}
}
pub fn set_report_sink<F>(sink: F)
where
F: Fn(&BadReport) + Send + Sync + 'static,
{
let _ = REPORT_SINK.set(Box::new(sink));
}
/// Forwards `report` to the installed sink, if any.
///
/// Unwinding panics raised by the sink are caught and discarded so they do
/// not propagate into the reporting caller. Builds using `panic = "abort"`
/// still abort.
#[inline]
fn invoke_sink(report: &BadReport) {
    let Some(sink) = REPORT_SINK.get() else {
        return;
    };
    let call = std::panic::AssertUnwindSafe(|| sink(report));
    let _ = std::panic::catch_unwind(call);
}
/// Sets the TTL used by `report_bad`, `report_bad_categorized` and
/// `seed_dynamic_hosts`.
///
/// Stored at millisecond precision: sub-millisecond remainders are truncated,
/// and values beyond `u64::MAX` ms saturate.
pub fn set_default_ttl(ttl: Duration) {
    let millis = u64::try_from(ttl.as_millis()).unwrap_or(u64::MAX);
    DEFAULT_TTL_MS.store(millis, Ordering::Relaxed);
}
pub fn default_ttl() -> Duration {
default_ttl_dur()
}
/// Sets how many pending reports force an immediate merge. `0` is treated as `1`.
pub fn set_merge_batch(n: usize) {
    let batch = if n == 0 { 1 } else { n };
    MERGE_BATCH.store(batch, Ordering::Relaxed);
}
/// Sets the snapshot age after which the next report triggers a merge.
/// Stored at millisecond precision; values beyond `u64::MAX` ms saturate.
pub fn set_merge_interval(d: Duration) {
    let millis = u64::try_from(d.as_millis()).unwrap_or(u64::MAX);
    MERGE_INTERVAL_MS.store(millis, Ordering::Relaxed);
}
/// Caps the number of entries kept in a snapshot (minimum 1). The cap is
/// enforced at the next merge by evicting the soonest-to-expire entries.
pub fn set_max_entries(n: usize) {
    MAX_ENTRIES.store(std::cmp::max(n, 1), Ordering::Relaxed);
}
/// Starts, at most once per process, a detached thread named
/// `spider-firewall-pruner` that calls `try_merge` every `interval`. Each call
/// folds in queued reports and drops expired entries even when no new reports
/// arrive.
///
/// `interval` is clamped to at least 1 ms; a zero interval would otherwise
/// rebuild the whole snapshot in a tight loop. If the thread cannot be
/// spawned, the started flag is reset so a later call can retry.
pub fn enable_background_pruner(interval: Duration) {
    if PRUNER_STARTED
        .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
        .is_err()
    {
        return;
    }
    let interval = interval.max(Duration::from_millis(1));
    let spawned = std::thread::Builder::new()
        .name("spider-firewall-pruner".into())
        .spawn(move || loop {
            std::thread::sleep(interval);
            try_merge(Instant::now());
        });
    if spawned.is_err() {
        // No thread is running; allow a future call to try again.
        PRUNER_STARTED.store(false, Ordering::Release);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dynamic_overlay_behavior() {
        let hour = Duration::from_secs(3600);
        assert!(!dynamic_contains("rt-evil.test"));
        report_bad("rt-evil.test");
        flush();
        assert!(dynamic_contains("rt-evil.test"));
        assert!(crate::is_url_bad("rt-evil.test"));

        report_bad("sub-evil.test");
        flush();
        assert!(dynamic_contains("a.b.sub-evil.test"));
        assert!(crate::is_url_bad("deep.a.b.sub-evil.test"));

        // Generous TTL window, so a slow flush on a loaded machine cannot let
        // the entry expire before the positive check.
        report_bad_with_ttl(
            "ttl-evil.test",
            CAT_BAD,
            BadReason::Manual,
            Duration::from_millis(250),
        );
        flush();
        assert!(dynamic_contains("ttl-evil.test"));
        std::thread::sleep(Duration::from_millis(400));
        assert!(!dynamic_contains("ttl-evil.test"));

        report_bad_categorized("ads-only.test", crate::CAT_ADS, BadReason::Other);
        flush();
        assert!(dynamic_has_category("ads-only.test", crate::CAT_ADS));
        assert!(!dynamic_has_category("ads-only.test", CAT_BAD));
        assert!(crate::is_url_bad("ads-only.test"));

        report_bad("https://URL-Evil.test/path?x=1");
        flush();
        assert!(dynamic_contains("url-evil.test"));

        seed_dynamic(vec![
            ("seed-a.test".to_string(), CAT_BAD, hour),
            ("seed-b.test".to_string(), CAT_BAD, hour),
        ]);
        flush();
        assert!(dynamic_contains("seed-a.test"));
        assert!(dynamic_contains("seed-b.test"));

        static HITS: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
        static EMPTY_HOSTS: std::sync::atomic::AtomicUsize =
            std::sync::atomic::AtomicUsize::new(0);
        // Panics inside the sink are swallowed by `invoke_sink`, so an
        // `assert!` here could never fail the test. Record violations instead
        // and assert on them below.
        set_report_sink(|r: &BadReport| {
            if r.host.is_empty() {
                EMPTY_HOSTS.fetch_add(1, Ordering::Relaxed);
            }
            HITS.fetch_add(1, Ordering::Relaxed);
        });
        report_bad("sink-evil.test");
        flush();
        assert!(HITS.load(Ordering::Relaxed) >= 1);

        report_bad("");
        report_bad(" ");
        flush();
        assert!(!dynamic_contains(""));
        assert!(!dynamic_contains(" "));
        assert_eq!(EMPTY_HOSTS.load(Ordering::Relaxed), 0);
    }
}