use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicU64, Ordering};
/// Kind of indicator a blocklist entry is keyed on.
///
/// Serialized in SCREAMING_SNAKE_CASE, i.e. `"IP"` and `"FINGERPRINT"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum BlockType {
/// The indicator is an IP address string.
Ip,
/// The indicator is a client fingerprint string.
// NOTE(review): the tests use a value shaped like `t13d1516h2_abc123`; the
// exact fingerprint scheme is not established by this file — confirm.
Fingerprint,
}
/// A single blocked indicator together with its metadata.
///
/// Field names are serialized in camelCase (`blockType`, `expiresAt`,
/// `createdAt`, ...). Optional fields that are `None` are left out of the
/// serialized output entirely.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BlocklistEntry {
/// Which kind of indicator this entry blocks.
pub block_type: BlockType,
/// The blocked value itself; `BlocklistCache` uses it as the lookup key.
pub indicator: String,
/// Optional expiry, carried as an opaque string.
// NOTE(review): nothing in this file parses or enforces this value — an
// entry stays in the cache until it is explicitly removed. The expected
// timestamp format is not established here; confirm against the producer.
#[serde(skip_serializing_if = "Option::is_none")]
pub expires_at: Option<String>,
/// Free-form label for where the entry came from.
pub source: String,
/// Optional human-readable explanation for the block.
#[serde(skip_serializing_if = "Option::is_none")]
pub reason: Option<String>,
/// Optional creation time, carried as an opaque string (format not
/// established in this file).
#[serde(skip_serializing_if = "Option::is_none")]
pub created_at: Option<String>,
}
/// An incremental change to the blocklist: add or remove one indicator.
///
/// Field names are serialized in camelCase; optional fields that are `None`
/// are left out of the serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct BlocklistUpdate {
/// Whether the indicator is being added or removed.
pub action: BlocklistAction,
/// Which kind of indicator the update targets.
pub block_type: BlockType,
/// The indicator value to add or remove.
pub indicator: String,
/// Origin of the update. `BlocklistCache::apply_updates` substitutes
/// `"hub"` when this is `None` on an add.
#[serde(skip_serializing_if = "Option::is_none")]
pub source: Option<String>,
/// Optional human-readable explanation, copied onto the entry on add.
#[serde(skip_serializing_if = "Option::is_none")]
pub reason: Option<String>,
}
/// Operation carried by a `BlocklistUpdate`.
///
/// Serialized in lowercase, i.e. `"add"` and `"remove"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum BlocklistAction {
/// Insert the indicator, replacing any existing entry with the same key.
Add,
/// Delete the indicator if it is present.
Remove,
}
/// In-memory blocklist with separate maps for IP and fingerprint indicators.
///
/// Every method takes `&self`: the maps are `DashMap`s and the sequence id is
/// an atomic, so mutation goes through interior mutability.
pub struct BlocklistCache {
// Entries whose `block_type` is `BlockType::Ip`, keyed by indicator.
ips: DashMap<String, BlocklistEntry>,
// Entries whose `block_type` is `BlockType::Fingerprint`, keyed by indicator.
fingerprints: DashMap<String, BlocklistEntry>,
// Sequence id recorded by the most recent snapshot load or update batch;
// starts at 0 and is reset to 0 by `clear`.
sequence_id: AtomicU64,
}
impl Default for BlocklistCache {
fn default() -> Self {
Self::new()
}
}
impl BlocklistCache {
pub fn new() -> Self {
Self {
ips: DashMap::new(),
fingerprints: DashMap::new(),
sequence_id: AtomicU64::new(0),
}
}
#[inline]
pub fn is_ip_blocked(&self, ip: &str) -> bool {
self.ips.contains_key(ip)
}
#[inline]
pub fn is_fingerprint_blocked(&self, fingerprint: &str) -> bool {
self.fingerprints.contains_key(fingerprint)
}
#[inline]
pub fn is_blocked(&self, ip: Option<&str>, fingerprint: Option<&str>) -> bool {
if let Some(ip) = ip {
if self.is_ip_blocked(ip) {
return true;
}
}
if let Some(fp) = fingerprint {
if self.is_fingerprint_blocked(fp) {
return true;
}
}
false
}
pub fn get_ip(&self, ip: &str) -> Option<BlocklistEntry> {
self.ips.get(ip).map(|r| r.value().clone())
}
pub fn get_fingerprint(&self, fingerprint: &str) -> Option<BlocklistEntry> {
self.fingerprints
.get(fingerprint)
.map(|r| r.value().clone())
}
pub fn add(&self, entry: BlocklistEntry) {
match entry.block_type {
BlockType::Ip => {
self.ips.insert(entry.indicator.clone(), entry);
}
BlockType::Fingerprint => {
self.fingerprints.insert(entry.indicator.clone(), entry);
}
}
}
pub fn remove(&self, block_type: BlockType, indicator: &str) {
match block_type {
BlockType::Ip => {
self.ips.remove(indicator);
}
BlockType::Fingerprint => {
self.fingerprints.remove(indicator);
}
}
}
pub fn load_snapshot(&self, entries: Vec<BlocklistEntry>, sequence_id: u64) {
self.ips.clear();
self.fingerprints.clear();
for entry in entries {
self.add(entry);
}
self.sequence_id.store(sequence_id, Ordering::SeqCst);
}
pub fn apply_updates(&self, updates: Vec<BlocklistUpdate>, sequence_id: u64) {
for update in updates {
match update.action {
BlocklistAction::Add => {
self.add(BlocklistEntry {
block_type: update.block_type,
indicator: update.indicator,
expires_at: None,
source: update.source.unwrap_or_else(|| "hub".to_string()),
reason: update.reason,
created_at: None,
});
}
BlocklistAction::Remove => {
self.remove(update.block_type, &update.indicator);
}
}
}
self.sequence_id.store(sequence_id, Ordering::SeqCst);
}
pub fn size(&self) -> usize {
self.ips.len() + self.fingerprints.len()
}
pub fn ip_count(&self) -> usize {
self.ips.len()
}
pub fn fingerprint_count(&self) -> usize {
self.fingerprints.len()
}
pub fn sequence_id(&self) -> u64 {
self.sequence_id.load(Ordering::SeqCst)
}
pub fn clear(&self) {
self.ips.clear();
self.fingerprints.clear();
self.sequence_id.store(0, Ordering::SeqCst);
}
pub fn all_ips(&self) -> Vec<BlocklistEntry> {
self.ips.iter().map(|r| r.value().clone()).collect()
}
pub fn all_fingerprints(&self) -> Vec<BlocklistEntry> {
self.fingerprints
.iter()
.map(|r| r.value().clone())
.collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an entry with only the kind, indicator and source populated.
    fn entry(block_type: BlockType, indicator: &str, source: &str) -> BlocklistEntry {
        BlocklistEntry {
            block_type,
            indicator: indicator.to_string(),
            expires_at: None,
            source: source.to_string(),
            reason: None,
            created_at: None,
        }
    }

    /// Builds an update, converting the optional string slices to owned values.
    fn update(
        action: BlocklistAction,
        block_type: BlockType,
        indicator: &str,
        source: Option<&str>,
        reason: Option<&str>,
    ) -> BlocklistUpdate {
        BlocklistUpdate {
            action,
            block_type,
            indicator: indicator.to_string(),
            source: source.map(str::to_string),
            reason: reason.map(str::to_string),
        }
    }

    #[test]
    fn test_ip_blocking() {
        let cache = BlocklistCache::new();
        cache.add(entry(BlockType::Ip, "192.168.1.100", "test"));

        assert!(cache.is_ip_blocked("192.168.1.100"));
        assert!(!cache.is_ip_blocked("192.168.1.101"));
    }

    #[test]
    fn test_fingerprint_blocking() {
        let cache = BlocklistCache::new();
        cache.add(entry(BlockType::Fingerprint, "t13d1516h2_abc123", "test"));

        assert!(cache.is_fingerprint_blocked("t13d1516h2_abc123"));
        assert!(!cache.is_fingerprint_blocked("t13d1516h2_def456"));
    }

    #[test]
    fn test_is_blocked_combined() {
        let cache = BlocklistCache::new();
        cache.add(entry(BlockType::Ip, "10.0.0.1", "test"));

        assert!(cache.is_blocked(Some("10.0.0.1"), None));
        assert!(cache.is_blocked(Some("10.0.0.1"), Some("fp123")));
        assert!(!cache.is_blocked(Some("10.0.0.2"), Some("fp123")));
        assert!(!cache.is_blocked(None, None));
    }

    #[test]
    fn test_load_snapshot() {
        let cache = BlocklistCache::new();
        cache.add(entry(BlockType::Ip, "old-ip", "old"));

        // Loading a snapshot must drop whatever was cached before it.
        let snapshot = vec![entry(BlockType::Ip, "new-ip", "snapshot")];
        cache.load_snapshot(snapshot, 42);

        assert!(!cache.is_ip_blocked("old-ip"));
        assert!(cache.is_ip_blocked("new-ip"));
        assert_eq!(cache.sequence_id(), 42);
    }

    #[test]
    fn test_apply_updates() {
        let cache = BlocklistCache::new();

        let additions = vec![
            update(
                BlocklistAction::Add,
                BlockType::Ip,
                "10.0.0.1",
                Some("hub"),
                None,
            ),
            update(
                BlocklistAction::Add,
                BlockType::Fingerprint,
                "fp1",
                None,
                Some("malicious"),
            ),
        ];
        cache.apply_updates(additions, 100);

        assert!(cache.is_ip_blocked("10.0.0.1"));
        assert!(cache.is_fingerprint_blocked("fp1"));
        assert_eq!(cache.size(), 2);
        assert_eq!(cache.sequence_id(), 100);

        let removals = vec![update(
            BlocklistAction::Remove,
            BlockType::Ip,
            "10.0.0.1",
            None,
            None,
        )];
        cache.apply_updates(removals, 101);

        assert!(!cache.is_ip_blocked("10.0.0.1"));
        assert_eq!(cache.size(), 1);
    }
}