use std::net::IpAddr;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use futures::future::join_all;
use parking_lot::RwLock as ParkingLotRwLock;
use tokio::sync::{Mutex, RwLock};
use tokio::time::interval;
use crate::access::AccessListManager;
use crate::correlation::{
Campaign, CampaignStatus, CampaignStore, CampaignStoreStats, CampaignUpdate, FingerprintGroup,
FingerprintIndex, IndexStats,
};
use crate::telemetry::{TelemetryClient, TelemetryEvent};
use crate::correlation::detectors::{
AttackPayload,
AttackSequenceConfig,
AttackSequenceDetector,
AuthTokenConfig,
AuthTokenDetector,
BehavioralConfig,
BehavioralSimilarityDetector,
Detector,
DetectorError,
DetectorResult,
GraphConfig,
GraphDetector,
Ja4RotationDetector,
NetworkProximityConfig,
NetworkProximityDetector,
RotationConfig,
SharedFingerprintDetector,
TimingConfig,
TimingCorrelationDetector,
};
#[derive(Debug)]
pub struct MitigationRateLimiter {
bans_in_window: AtomicU64,
window_start: Mutex<Instant>,
max_bans_per_window: u64,
window_duration: Duration,
max_ips_per_campaign: usize,
}
impl MitigationRateLimiter {
pub fn new(
max_bans_per_window: u64,
window_duration: Duration,
max_ips_per_campaign: usize,
) -> Self {
Self {
bans_in_window: AtomicU64::new(0),
window_start: Mutex::new(Instant::now()),
max_bans_per_window,
window_duration,
max_ips_per_campaign,
}
}
pub async fn try_ban(&self) -> Result<(), String> {
self.maybe_reset_window().await;
let current = self.bans_in_window.fetch_add(1, Ordering::SeqCst);
if current >= self.max_bans_per_window {
self.bans_in_window.fetch_sub(1, Ordering::SeqCst);
return Err(format!(
"Rate limit exceeded: {} bans in {:?} window",
self.max_bans_per_window, self.window_duration
));
}
Ok(())
}
async fn maybe_reset_window(&self) {
let mut start = self.window_start.lock().await;
if start.elapsed() >= self.window_duration {
*start = Instant::now();
self.bans_in_window.store(0, Ordering::SeqCst);
}
}
pub fn max_ips_per_campaign(&self) -> usize {
self.max_ips_per_campaign
}
pub fn current_count(&self) -> u64 {
self.bans_in_window.load(Ordering::SeqCst)
}
}
impl Default for MitigationRateLimiter {
fn default() -> Self {
Self::new(
50, Duration::from_secs(60), 10, )
}
}
/// Tuning knobs for `CampaignManager` and the detectors it builds.
///
/// Start from [`ManagerConfig::default`] / [`ManagerConfig::new`], adjust with
/// the `with_*` builders or struct-update syntax, and check the result with
/// [`ManagerConfig::validate`].
#[derive(Debug, Clone)]
pub struct ManagerConfig {
    /// Threshold handed to the shared-fingerprint detector (presumably the
    /// minimum number of IPs sharing a fingerprint). Must be at least 2.
    pub shared_threshold: usize,
    /// Observation window of the JA4 rotation detector.
    pub rotation_window: Duration,
    /// Minimum fingerprint count for the JA4 rotation detector. Must be at least 2.
    pub rotation_threshold: usize,
    /// Period of the background detection worker. Must be non-zero.
    pub scan_interval: Duration,
    /// Not read by the manager itself. NOTE(review): presumably tells callers
    /// whether to start the background worker.
    pub background_scanning: bool,
    /// Whether combined fingerprints are also fed to the rotation detector.
    pub track_combined: bool,
    /// Confidence value passed to the shared-fingerprint detector, in `[0, 1]`.
    pub shared_confidence: f64,
    /// Minimum IP count for the attack-sequence detector.
    pub attack_sequence_min_ips: usize,
    /// Observation window of the attack-sequence detector.
    pub attack_sequence_window: Duration,
    /// Minimum IP count for the auth-token detector.
    pub auth_token_min_ips: usize,
    /// Observation window of the auth-token detector.
    pub auth_token_window: Duration,
    /// Minimum IP count for the behavioral-similarity detector.
    pub behavioral_min_ips: usize,
    /// Minimum request-sequence length for the behavioral-similarity detector.
    pub behavioral_min_sequence: usize,
    /// Observation window of the behavioral-similarity detector.
    pub behavioral_window: Duration,
    /// Minimum IP count for the timing-correlation detector.
    pub timing_min_ips: usize,
    /// Timing bucket width in milliseconds. Must be non-zero.
    pub timing_bucket_ms: u64,
    /// Minimum hits per bucket for the timing-correlation detector.
    pub timing_min_bucket_hits: usize,
    /// Observation window of the timing-correlation detector.
    pub timing_window: Duration,
    /// Minimum IP count for the network-proximity detector.
    pub network_min_ips: usize,
    /// Whether the network-proximity detector checks subnets.
    pub network_check_subnet: bool,
    /// Minimum connected-component size for the graph detector. Must be at least 2.
    pub graph_min_component_size: usize,
    /// Maximum traversal depth of the graph detector.
    pub graph_max_depth: usize,
    /// Time-to-live of graph edges.
    pub graph_edge_ttl: Duration,
    /// Whether campaigns are automatically blocked through the access list.
    pub auto_mitigation_enabled: bool,
    /// Minimum campaign confidence that triggers auto-mitigation. Must be at
    /// least 0.7 when auto-mitigation is enabled.
    pub auto_mitigation_threshold: f64,
}

impl Default for ManagerConfig {
    fn default() -> Self {
        Self {
            shared_threshold: 3,
            rotation_window: Duration::from_secs(60),
            rotation_threshold: 3,
            scan_interval: Duration::from_secs(5),
            background_scanning: true,
            track_combined: true,
            shared_confidence: 0.85,
            attack_sequence_min_ips: 2,
            attack_sequence_window: Duration::from_secs(300),
            auth_token_min_ips: 2,
            auth_token_window: Duration::from_secs(600),
            behavioral_min_ips: 2,
            behavioral_min_sequence: 3,
            behavioral_window: Duration::from_secs(300),
            timing_min_ips: 3,
            timing_bucket_ms: 100,
            timing_min_bucket_hits: 5,
            timing_window: Duration::from_secs(60),
            network_min_ips: 3,
            network_check_subnet: true,
            graph_min_component_size: 3,
            graph_max_depth: 3,
            graph_edge_ttl: Duration::from_secs(3600),
            auto_mitigation_enabled: false,
            auto_mitigation_threshold: 0.90,
        }
    }
}

impl ManagerConfig {
    /// Same as [`ManagerConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets [`ManagerConfig::shared_threshold`].
    pub fn with_shared_threshold(mut self, threshold: usize) -> Self {
        self.shared_threshold = threshold;
        self
    }

    /// Sets [`ManagerConfig::rotation_window`].
    pub fn with_rotation_window(mut self, window: Duration) -> Self {
        self.rotation_window = window;
        self
    }

    /// Sets [`ManagerConfig::rotation_threshold`].
    pub fn with_rotation_threshold(mut self, threshold: usize) -> Self {
        self.rotation_threshold = threshold;
        self
    }

    /// Sets [`ManagerConfig::scan_interval`].
    pub fn with_scan_interval(mut self, interval: Duration) -> Self {
        self.scan_interval = interval;
        self
    }

    /// Sets [`ManagerConfig::background_scanning`].
    pub fn with_background_scanning(mut self, enabled: bool) -> Self {
        self.background_scanning = enabled;
        self
    }

    /// Sets [`ManagerConfig::track_combined`].
    pub fn with_track_combined(mut self, enabled: bool) -> Self {
        self.track_combined = enabled;
        self
    }

    /// Sets [`ManagerConfig::shared_confidence`], clamped to `[0, 1]`.
    ///
    /// NaN passes through `clamp` unchanged and is rejected by `validate`.
    pub fn with_shared_confidence(mut self, confidence: f64) -> Self {
        self.shared_confidence = confidence.clamp(0.0, 1.0);
        self
    }

    /// Sets [`ManagerConfig::auto_mitigation_enabled`].
    pub fn with_auto_mitigation(mut self, enabled: bool) -> Self {
        self.auto_mitigation_enabled = enabled;
        self
    }

    /// Sets [`ManagerConfig::auto_mitigation_threshold`], clamped to `[0, 1]`.
    ///
    /// NaN passes through `clamp` unchanged and is rejected by `validate`
    /// when auto-mitigation is enabled.
    pub fn with_auto_mitigation_threshold(mut self, threshold: f64) -> Self {
        self.auto_mitigation_threshold = threshold.clamp(0.0, 1.0);
        self
    }

    /// Checks the configuration for values the manager cannot work with.
    ///
    /// # Errors
    /// Returns a message describing the first invalid setting found.
    pub fn validate(&self) -> Result<(), String> {
        if self.shared_threshold < 2 {
            return Err("shared_threshold must be at least 2".to_string());
        }
        if self.rotation_threshold < 2 {
            return Err("rotation_threshold must be at least 2".to_string());
        }
        if self.rotation_window.is_zero() {
            return Err("rotation_window must be positive".to_string());
        }
        if self.scan_interval.is_zero() {
            return Err("scan_interval must be positive".to_string());
        }
        // NaN compares false against everything, so `< 0.7` alone would
        // silently accept it.
        if self.auto_mitigation_enabled
            && (self.auto_mitigation_threshold.is_nan() || self.auto_mitigation_threshold < 0.7)
        {
            return Err(
                "auto_mitigation_threshold must be >= 0.7 when auto_mitigation is enabled to prevent false positives"
                    .to_string(),
            );
        }
        if self.graph_min_component_size < 2 {
            return Err("graph_min_component_size must be at least 2".to_string());
        }
        // Struct-literal construction bypasses the clamping builder. The range
        // check also rejects NaN.
        if !(0.0..=1.0).contains(&self.shared_confidence) {
            return Err("shared_confidence must be within [0.0, 1.0]".to_string());
        }
        // A zero-width timing bucket is meaningless.
        if self.timing_bucket_ms == 0 {
            return Err("timing_bucket_ms must be positive".to_string());
        }
        Ok(())
    }
}
/// Point-in-time snapshot of manager activity, as returned by `CampaignManager::stats`.
#[derive(Debug, Clone, Default)]
pub struct ManagerStats {
/// Counted fingerprint registrations; each `register_*` call adds at most one.
pub fingerprints_registered: u64,
/// Number of completed detection cycles (`run_detection_cycle` calls).
pub detections_run: u64,
/// Campaigns newly created and successfully stored from detector updates.
pub campaigns_created: u64,
/// When the most recent detection cycle finished. `None` if no cycle has run
/// yet, or if the lock was contended when the snapshot was taken.
pub last_scan: Option<Instant>,
/// Statistics reported by the fingerprint index.
pub index_stats: IndexStats,
/// Statistics reported by the campaign store.
pub campaign_stats: CampaignStoreStats,
/// Cumulative campaign updates per detector name (e.g. "attack_sequence").
/// Only detectors that produced at least one update appear.
pub detections_by_type: std::collections::HashMap<String, u64>,
}
/// How long a computed group list stays reusable by `get_cached_groups`.
const GROUP_CACHE_TTL: Duration = Duration::from_millis(100);

/// Short-lived memo of `FingerprintIndex::get_groups_above_threshold`
/// results, tagged with the threshold they were computed for.
struct GroupCache {
    threshold: usize,
    cached_at: Instant,
    groups: Vec<FingerprintGroup>,
}

impl GroupCache {
    /// Wraps freshly computed `groups` and stamps them with the current time.
    fn new(groups: Vec<FingerprintGroup>, threshold: usize) -> Self {
        let cached_at = Instant::now();
        GroupCache {
            threshold,
            cached_at,
            groups,
        }
    }

    /// True when the entry was built for `threshold` and is younger than
    /// [`GROUP_CACHE_TTL`].
    fn is_valid(&self, threshold: usize) -> bool {
        if threshold != self.threshold {
            return false;
        }
        let age = self.cached_at.elapsed();
        age < GROUP_CACHE_TTL
    }
}
/// Correlates per-IP signals (fingerprints, attacks, tokens, request patterns,
/// graph relations) into campaigns using a fixed set of detectors. Campaigns
/// are kept in a [`CampaignStore`]; the manager can optionally auto-mitigate
/// them through an [`AccessListManager`] and report them through telemetry.
pub struct CampaignManager {
/// Settings the detectors were built from; also gates auto-mitigation.
config: ManagerConfig,
/// Per-IP fingerprint index passed to every detector's `analyze`/`should_trigger`.
index: Arc<FingerprintIndex>,
/// Campaign storage, including the IP -> campaign lookup.
store: Arc<CampaignStore>,
/// Where auto-mitigation writes deny rules; `None` disables blocking.
access_list_manager: Option<Arc<ParkingLotRwLock<AccessListManager>>>,
/// Optional sink for campaign and mitigation reports.
telemetry_client: Option<Arc<TelemetryClient>>,
// Detectors; each is analyzed once per detection cycle.
attack_sequence_detector: AttackSequenceDetector,
auth_token_detector: AuthTokenDetector,
http_fingerprint_detector: SharedFingerprintDetector,
tls_fingerprint_detector: Ja4RotationDetector,
behavioral_detector: BehavioralSimilarityDetector,
timing_detector: TimingCorrelationDetector,
network_detector: NetworkProximityDetector,
graph_detector: GraphDetector,
// Counters surfaced through `stats()`.
stats_fingerprints_registered: AtomicU64,
stats_detections_run: AtomicU64,
stats_campaigns_created: AtomicU64,
/// Cumulative campaign updates per detector name.
stats_detections_by_type: RwLock<std::collections::HashMap<String, u64>>,
/// Completion time of the most recent detection cycle.
last_scan: RwLock<Option<Instant>>,
/// Set by `shutdown()`; the background worker checks it after each tick.
shutdown: AtomicBool,
/// Short-lived cache backing `get_cached_groups`.
group_cache: RwLock<Option<GroupCache>>,
/// Ban budget for auto-mitigation (per window and per campaign).
mitigation_rate_limiter: MitigationRateLimiter,
/// Ids of campaigns auto-mitigation has already handled; later calls skip them.
mitigated_campaigns: dashmap::DashSet<String>,
}
impl CampaignManager {
/// Builds a manager from [`ManagerConfig::default`].
pub fn new() -> Self {
    let defaults = ManagerConfig::default();
    Self::with_config(defaults)
}
pub fn with_config(config: ManagerConfig) -> Self {
let attack_sequence_config = AttackSequenceConfig {
min_ips: config.attack_sequence_min_ips,
window: config.attack_sequence_window,
similarity_threshold: 0.95, ..Default::default()
};
let attack_sequence_detector = AttackSequenceDetector::new(attack_sequence_config);
let auth_token_config = AuthTokenConfig {
min_ips: config.auth_token_min_ips,
window: config.auth_token_window,
..Default::default()
};
let auth_token_detector = AuthTokenDetector::new(auth_token_config);
let http_fingerprint_detector = SharedFingerprintDetector::with_config(
config.shared_threshold,
config.shared_confidence,
config.scan_interval.as_millis() as u64,
);
let rotation_config = RotationConfig {
min_fingerprints: config.rotation_threshold,
window: config.rotation_window,
track_combined: config.track_combined,
..Default::default()
};
let tls_fingerprint_detector = Ja4RotationDetector::new(rotation_config);
let behavioral_config = BehavioralConfig {
min_ips: config.behavioral_min_ips,
min_sequence_length: config.behavioral_min_sequence,
window: config.behavioral_window,
..Default::default()
};
let behavioral_detector = BehavioralSimilarityDetector::new(behavioral_config);
let timing_config = TimingConfig {
min_ips: config.timing_min_ips,
bucket_size: Duration::from_millis(config.timing_bucket_ms),
min_bucket_hits: config.timing_min_bucket_hits,
window: config.timing_window,
..Default::default()
};
let timing_detector = TimingCorrelationDetector::new(timing_config);
let network_config = NetworkProximityConfig {
min_ips: config.network_min_ips,
check_subnet: config.network_check_subnet,
check_asn: false, ..Default::default()
};
let network_detector = NetworkProximityDetector::new(network_config);
let graph_config = GraphConfig {
min_component_size: config.graph_min_component_size,
max_traversal_depth: config.graph_max_depth,
edge_ttl: config.graph_edge_ttl,
..Default::default()
};
let graph_detector = GraphDetector::new(graph_config);
Self {
config,
index: Arc::new(FingerprintIndex::new()),
store: Arc::new(CampaignStore::new()),
access_list_manager: None,
telemetry_client: None,
attack_sequence_detector,
auth_token_detector,
http_fingerprint_detector,
tls_fingerprint_detector,
behavioral_detector,
timing_detector,
network_detector,
graph_detector,
stats_fingerprints_registered: AtomicU64::new(0),
stats_detections_run: AtomicU64::new(0),
stats_campaigns_created: AtomicU64::new(0),
stats_detections_by_type: RwLock::new(std::collections::HashMap::new()),
last_scan: RwLock::new(None),
shutdown: AtomicBool::new(false),
group_cache: RwLock::new(None),
mitigation_rate_limiter: MitigationRateLimiter::default(),
mitigated_campaigns: dashmap::DashSet::new(),
}
}
/// Installs the access list that auto-mitigation writes deny rules into.
pub fn set_access_list_manager(&mut self, manager: Arc<ParkingLotRwLock<AccessListManager>>) {
    let _previous = self.access_list_manager.replace(manager);
}

/// Installs the client used for campaign and mitigation telemetry.
pub fn set_telemetry_client(&mut self, client: Arc<TelemetryClient>) {
    let _previous = self.telemetry_client.replace(client);
}
/// Records a JA4 (TLS) fingerprint observed for `ip`.
///
/// Empty fingerprints are ignored. Otherwise the value is stored as the IP's
/// JA4 entry in the index, fed to the TLS rotation detector, and the
/// registration counter is bumped once.
pub fn register_ja4(&self, ip: IpAddr, fingerprint: String) {
    if fingerprint.is_empty() {
        return;
    }
    let ip_str = ip.to_string();
    self.index.update_entity(&ip_str, Some(&fingerprint), None);
    self.tls_fingerprint_detector.record_fingerprint(ip, fingerprint);
    self.stats_fingerprints_registered.fetch_add(1, Ordering::Relaxed);
}

/// Same as [`Self::register_ja4`], for callers holding a shared `Arc<str>`.
pub fn register_ja4_arc(&self, ip: IpAddr, fingerprint: Arc<str>) {
    // Delegate so the two entry points cannot drift apart. The empty check in
    // `register_ja4` still applies, and an empty `String` does not allocate.
    self.register_ja4(ip, fingerprint.to_string());
}

/// Records a combined fingerprint observed for `ip`.
///
/// Empty fingerprints are ignored. The value is always stored as the IP's
/// combined entry in the index. It is fed to the TLS rotation detector only
/// when `track_combined` is enabled. The registration counter is bumped once
/// either way.
pub fn register_combined(&self, ip: IpAddr, fingerprint: String) {
    if fingerprint.is_empty() {
        return;
    }
    let ip_str = ip.to_string();
    self.index.update_entity(&ip_str, None, Some(&fingerprint));
    if self.config.track_combined {
        self.tls_fingerprint_detector.record_fingerprint(ip, fingerprint);
    }
    self.stats_fingerprints_registered.fetch_add(1, Ordering::Relaxed);
}

/// Same as [`Self::register_combined`], for callers holding a shared `Arc<str>`.
pub fn register_combined_arc(&self, ip: IpAddr, fingerprint: Arc<str>) {
    self.register_combined(ip, fingerprint.to_string());
}
/// Records the JA4 and/or JA4H fingerprints observed for `ip` in one call.
///
/// Empty strings are treated as absent, matching [`Self::register_ja4`] and
/// [`Self::register_combined`]. If nothing remains, the call is a no-op.
///
/// When `ja4h` is present, a combined `"{ja4}_{ja4h}"` fingerprint is derived
/// (with an empty JA4 part if `ja4` is absent). Both values go into the index.
/// The JA4 value always goes to the TLS rotation detector; the combined value
/// does so only when `track_combined` is enabled. The registration counter is
/// bumped at most once per call.
pub fn register_fingerprints(&self, ip: IpAddr, ja4: Option<String>, ja4h: Option<String>) {
    // Previously `Some("")` was written into the index as a real fingerprint.
    // Every client sending an empty value then landed in one shared group and
    // could be correlated into a (false) campaign.
    let ja4 = ja4.filter(|fp| !fp.is_empty());
    let ja4h = ja4h.filter(|fp| !fp.is_empty());
    if ja4.is_none() && ja4h.is_none() {
        return;
    }

    let combined = ja4h.map(|h| format!("{}_{}", ja4.as_deref().unwrap_or(""), h));
    let ip_str = ip.to_string();
    self.index.update_entity(&ip_str, ja4.as_deref(), combined.as_deref());

    let mut registered = false;
    if let Some(fp) = ja4 {
        self.tls_fingerprint_detector.record_fingerprint(ip, fp);
        registered = true;
    }
    if self.config.track_combined {
        if let Some(fp) = combined {
            self.tls_fingerprint_detector.record_fingerprint(ip, fp);
            registered = true;
        }
    }
    if registered {
        self.stats_fingerprints_registered.fetch_add(1, Ordering::Relaxed);
    }
}
pub fn record_attack(
&self,
ip: IpAddr,
payload_hash: String,
attack_type: String,
path: String,
) {
self.attack_sequence_detector.record_attack(
ip,
AttackPayload {
payload_hash,
attack_type,
target_path: path,
timestamp: std::time::Instant::now(),
},
);
}
pub fn record_token(&self, ip: IpAddr, jwt: &str) {
self.auth_token_detector.record_jwt(ip, jwt);
}
pub fn record_request(&self, ip: IpAddr, method: &str, path: &str) {
self.behavioral_detector.record_request(ip, method, path);
self.timing_detector.record_request(ip);
self.network_detector.register_ip(ip);
}
/// Records a request plus its optional JA4 fingerprint and JWT in one call.
///
/// The request always goes through [`Self::record_request`]. A non-empty
/// `ja4` is registered via [`Self::register_ja4`] and linked to the IP in the
/// relation graph. A non-empty `jwt` is recorded via [`Self::record_token`]
/// and linked to the IP through a node built from its first 16 characters.
pub fn record_request_full(
    &self,
    ip: IpAddr,
    method: &str,
    path: &str,
    ja4: Option<&str>,
    jwt: Option<&str>,
) {
    /// Number of leading token characters used for the token's graph node.
    const TOKEN_ID_CHARS: usize = 16;

    self.record_request(ip, method, path);
    let ip_id = GraphDetector::ip_id(&ip.to_string());

    if let Some(fp) = ja4.filter(|fp| !fp.is_empty()) {
        self.register_ja4(ip, fp.to_string());
        self.record_relation(&ip_id, &GraphDetector::fp_id(fp));
    }

    if let Some(token) = jwt.filter(|token| !token.is_empty()) {
        self.record_token(ip, token);
        // Cut on a char boundary. The token is client-supplied, and the former
        // byte slice `&token[..16]` panicked when byte 16 fell inside a
        // multi-byte UTF-8 character. ASCII tokens get the same prefix as before.
        let token_id = token
            .char_indices()
            .nth(TOKEN_ID_CHARS)
            .map_or(token, |(end, _)| &token[..end]);
        self.record_relation(&ip_id, &GraphDetector::token_id(token_id));
    }
}
/// Adds a relation between two graph entity ids to the graph detector
/// (e.g. ids produced by `GraphDetector::ip_id`, `fp_id` or `token_id`).
pub fn record_relation(&self, entity_a: &str, entity_b: &str) {
    let graph = &self.graph_detector;
    graph.record_relation(entity_a, entity_b);
}
/// Mean of `weight × confidence` over the campaign's correlation reasons.
///
/// Returns 0.0 when the campaign has no correlation reasons.
pub fn calculate_campaign_score(&self, campaign: &Campaign) -> f64 {
    let reasons = &campaign.correlation_reasons;
    let count = reasons.len();
    if count == 0 {
        return 0.0;
    }
    let weighted_sum = reasons
        .iter()
        .map(|reason| {
            let weight = reason.correlation_type.weight() as f64;
            weight * reason.confidence
        })
        .sum::<f64>();
    weighted_sum / count as f64
}
/// Runs every detector once against the index and applies the resulting
/// campaign updates.
///
/// A failing detector is logged and skipped; the cycle itself still returns
/// `Ok`. Per-detector update counts are added to the stats, the cycle counter
/// is bumped, and `last_scan` is set.
///
/// Returns the total number of campaign updates processed.
pub async fn run_detection_cycle(&self) -> DetectorResult<usize> {
    let index = &self.index;
    let roster: Vec<(&dyn Detector, &'static str)> = vec![
        (&self.attack_sequence_detector as &dyn Detector, "attack_sequence"),
        (&self.auth_token_detector as &dyn Detector, "auth_token"),
        (&self.http_fingerprint_detector as &dyn Detector, "http_fingerprint"),
        (&self.tls_fingerprint_detector as &dyn Detector, "tls_fingerprint"),
        (&self.behavioral_detector as &dyn Detector, "behavioral"),
        (&self.timing_detector as &dyn Detector, "timing"),
        (&self.network_detector as &dyn Detector, "network"),
        (&self.graph_detector as &dyn Detector, "graph"),
    ];
    let pending = roster
        .into_iter()
        .map(move |(detector, name)| async move { (name, detector.analyze(index)) });
    let outcomes = join_all(pending).await;

    let mut per_type: std::collections::HashMap<String, u64> = std::collections::HashMap::new();
    let mut processed: usize = 0;
    for (name, outcome) in outcomes {
        let updates = match outcome {
            Ok(updates) => updates,
            Err(e) => {
                tracing::warn!("Detector {} failed: {}", name, e);
                continue;
            }
        };
        let produced = updates.len();
        for update in updates {
            self.process_campaign_update(update).await;
        }
        processed += produced;
        if produced > 0 {
            *per_type.entry(name.to_string()).or_insert(0) += produced as u64;
        }
    }

    if !per_type.is_empty() {
        let mut totals = self.stats_detections_by_type.write().await;
        for (name, count) in per_type {
            *totals.entry(name).or_insert(0) += count;
        }
    }

    self.stats_detections_run.fetch_add(1, Ordering::Relaxed);
    let mut finished_at = self.last_scan.write().await;
    *finished_at = Some(Instant::now());
    drop(finished_at);
    Ok(processed)
}
pub async fn get_cached_groups(&self, threshold: usize) -> Vec<FingerprintGroup> {
{
let cache_guard = self.group_cache.read().await;
if let Some(ref cache) = *cache_guard {
if cache.is_valid(threshold) {
return cache.groups.clone();
}
}
}
let groups = self.index.get_groups_above_threshold(threshold);
{
let mut cache_guard = self.group_cache.write().await;
*cache_guard = Some(GroupCache::new(groups.clone(), threshold));
}
groups
}
pub async fn invalidate_group_cache(&self) {
let mut cache_guard = self.group_cache.write().await;
*cache_guard = None;
}
async fn process_campaign_update(&self, update: CampaignUpdate) {
let ips: Vec<String> = update
.add_correlation_reason
.as_ref()
.map(|reason| reason.evidence.clone())
.unwrap_or_default();
if ips.is_empty() {
return;
}
let existing_campaign_id = ips.iter().find_map(|ip| self.store.get_campaign_for_ip(ip));
let mut check_mitigation = false;
let mut target_campaign_id = String::new();
match existing_campaign_id {
Some(campaign_id) => {
let _ = self.store.update_campaign(&campaign_id, update);
for ip in &ips {
let _ = self.store.add_actor_to_campaign(&campaign_id, ip);
}
check_mitigation = true;
target_campaign_id = campaign_id;
}
None => {
let confidence = update.confidence.unwrap_or(0.5);
let mut campaign_id = Campaign::generate_id();
let mut retry_count = 0;
while self.store.get_campaign(&campaign_id).is_some() && retry_count < 10 {
campaign_id = format!("{}-{:x}", Campaign::generate_id(), fastrand::u32(..));
retry_count += 1;
}
let mut campaign = Campaign::new(campaign_id.clone(), ips, confidence);
if let Some(status) = update.status {
campaign.status = status;
}
if let Some(ref attack_types) = update.attack_types {
campaign.attack_types = attack_types.clone();
}
if let Some(reason) = update.add_correlation_reason {
campaign.correlation_reasons.push(reason);
}
if let Some(risk_score) = update.risk_score {
campaign.risk_score = risk_score;
}
if self.store.create_campaign(campaign).is_ok() {
self.stats_campaigns_created.fetch_add(1, Ordering::Relaxed);
check_mitigation = true;
target_campaign_id = campaign_id;
}
}
}
if check_mitigation {
if let Some(campaign) = self.store.get_campaign(&target_campaign_id) {
if self.config.auto_mitigation_enabled
&& campaign.confidence >= self.config.auto_mitigation_threshold
&& campaign.status != CampaignStatus::Resolved
{
self.mitigate_campaign(&campaign).await;
}
if campaign.confidence >= 0.8 {
self.report_campaign(&campaign);
}
}
}
}
/// Sends a `CampaignReport` telemetry event for `campaign` on a spawned task.
///
/// Does nothing when no telemetry client is configured or the client is
/// disabled. Delivery failures are only logged at debug level.
fn report_campaign(&self, campaign: &Campaign) {
    let client = match &self.telemetry_client {
        Some(client) if client.is_enabled() => Arc::clone(client),
        _ => return,
    };

    let attack_types = campaign
        .attack_types
        .iter()
        .map(|at| format!("{:?}", at))
        .collect();
    let correlation_reasons = campaign
        .correlation_reasons
        .iter()
        .map(|reason| reason.description.clone())
        .collect();
    let timestamp_ms = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_millis() as u64;

    let event = TelemetryEvent::CampaignReport {
        campaign_id: campaign.id.clone(),
        confidence: campaign.confidence,
        attack_types,
        actor_count: campaign.actor_count,
        correlation_reasons,
        timestamp_ms,
    };
    tokio::spawn(async move {
        if let Err(e) = client.report(event).await {
            tracing::debug!("Failed to report campaign telemetry: {}", e);
        }
    });
}
/// Blocks a campaign's actor IPs through the access list, subject to the
/// mitigation rate limiter, then logs and reports the outcome.
///
/// The method:
/// - skips campaigns already recorded as mitigated;
/// - does nothing without a configured access list;
/// - blocks at most `max_ips_per_campaign` parseable actor IPs.
///
/// The campaign is recorded as mitigated only when the rate limiter did not
/// cut the pass short, so IPs that were never attempted are retried on a
/// later call. NOTE: such a retry re-submits deny rules for IPs already
/// blocked in the earlier pass.
async fn mitigate_campaign(&self, campaign: &Campaign) {
    if self.mitigated_campaigns.contains(&campaign.id) {
        tracing::debug!(campaign_id = %campaign.id, "Campaign already mitigated, skipping");
        return;
    }
    let access_list = match &self.access_list_manager {
        Some(al) => al,
        None => {
            tracing::debug!("No AccessListManager configured, skipping mitigation");
            return;
        }
    };

    let max_ips = self.mitigation_rate_limiter.max_ips_per_campaign();
    let ips_to_block: Vec<IpAddr> = campaign
        .actors
        .iter()
        .filter_map(|ip_str| ip_str.parse::<IpAddr>().ok())
        .take(max_ips)
        .collect();
    if ips_to_block.is_empty() {
        tracing::debug!(campaign_id = %campaign.id, "No valid IPs to block");
        return;
    }

    // Identical for every IP of this campaign, so build it once.
    let comment = format!(
        "Campaign {} (confidence: {:.2})",
        campaign.id, campaign.confidence
    );
    let mut blocked_count: usize = 0;
    let mut rate_limited = false;
    for (attempted, ip) in ips_to_block.iter().enumerate() {
        if let Err(reason) = self.mitigation_rate_limiter.try_ban().await {
            tracing::warn!(
                campaign_id = %campaign.id,
                reason = %reason,
                blocked = blocked_count,
                // IPs not yet attempted. The old `len - blocked` also counted
                // IPs whose deny rule had failed.
                remaining = ips_to_block.len() - attempted,
                "Mitigation rate limited"
            );
            rate_limited = true;
            break;
        }
        {
            // Scope the parking_lot guard so it is never held across an `.await`.
            let mut al = access_list.write();
            if let Err(e) = al.add_deny_ip(ip, Some(&comment)) {
                tracing::error!(ip = %ip, error = %e, "Failed to add deny rule");
                continue;
            }
        }
        blocked_count += 1;
    }

    // Previously the campaign was marked even when the limiter stopped the
    // pass (possibly before any IP was blocked), so the rest were never blocked.
    if !rate_limited {
        self.mitigated_campaigns.insert(campaign.id.clone());
    }
    if blocked_count == 0 {
        // Nothing was blocked, and the failures were logged above. Do not log
        // or report a mitigation that had no effect.
        return;
    }

    let attack_types: Vec<String> = campaign
        .attack_types
        .iter()
        .map(|at| format!("{:?}", at))
        .collect();
    tracing::info!(
        campaign_id = %campaign.id,
        confidence = campaign.confidence,
        total_actors = campaign.actors.len(),
        blocked = blocked_count,
        rate_limited = rate_limited,
        attack_types = ?attack_types,
        "Auto-mitigation applied"
    );

    if let Some(ref client) = self.telemetry_client {
        if client.is_enabled() {
            let event = TelemetryEvent::CampaignReport {
                campaign_id: format!("mitigation:{}", campaign.id),
                confidence: campaign.confidence,
                attack_types,
                actor_count: blocked_count,
                correlation_reasons: vec![format!(
                    "Auto-mitigation applied: {} IPs blocked",
                    blocked_count
                )],
                timestamp_ms: std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_millis() as u64,
            };
            let client = Arc::clone(client);
            tokio::spawn(async move {
                if let Err(e) = client.report(event).await {
                    tracing::debug!("Failed to report mitigation telemetry: {}", e);
                }
            });
        }
    }
}
/// Returns `true` if any detector says activity from `ip` warrants a
/// detection run.
///
/// Detectors are consulted in a fixed order, and the first positive answer
/// short-circuits the rest.
pub fn should_trigger_detection(&self, ip: &IpAddr) -> bool {
    let index = &self.index;
    if self.attack_sequence_detector.should_trigger(ip, index) {
        return true;
    }
    if self.auth_token_detector.should_trigger(ip, index) {
        return true;
    }
    if self.http_fingerprint_detector.should_trigger(ip, index) {
        return true;
    }
    if self.tls_fingerprint_detector.should_trigger(ip, index) {
        return true;
    }
    if self.behavioral_detector.should_trigger(ip, index) {
        return true;
    }
    if self.timing_detector.should_trigger(ip, index) {
        return true;
    }
    if self.network_detector.should_trigger(ip, index) {
        return true;
    }
    self.graph_detector.should_trigger(ip, index)
}
/// Campaigns the store lists as active.
pub fn get_campaigns(&self) -> Vec<Campaign> {
    let store = &self.store;
    store.list_active_campaigns()
}

/// Every stored campaign (no filter passed to the store).
pub fn get_all_campaigns(&self) -> Vec<Campaign> {
    let store = &self.store;
    store.list_campaigns(Option::None)
}

/// Copy of all stored campaigns, suitable for a later [`Self::restore`].
/// Identical to [`Self::get_all_campaigns`].
pub fn snapshot(&self) -> Vec<Campaign> {
    self.get_all_campaigns()
}
/// Replaces the store and index contents with `campaigns`.
///
/// Each campaign's actor IPs are re-registered in the index without
/// fingerprints. Campaigns the store refuses to create are skipped. Detector
/// state and statistics are left untouched.
pub fn restore(&self, campaigns: Vec<Campaign>) {
    self.store.clear();
    self.index.clear();
    campaigns.into_iter().for_each(|campaign| {
        campaign.actors.iter().for_each(|actor| {
            self.index.update_entity(actor, None, None);
        });
        let _ = self.store.create_campaign(campaign);
    });
}
/// Looks up a campaign by id.
pub fn get_campaign(&self, id: &str) -> Option<Campaign> {
    let store = &self.store;
    store.get_campaign(id)
}

/// Actor IPs of a campaign. Actors that do not parse as an IP address are
/// dropped. Returns an empty list for an unknown campaign id.
pub fn get_campaign_actors(&self, campaign_id: &str) -> Vec<IpAddr> {
    match self.store.get_campaign(campaign_id) {
        Some(campaign) => campaign
            .actors
            .iter()
            .filter_map(|actor| actor.parse::<IpAddr>().ok())
            .collect(),
        None => Vec::new(),
    }
}

/// Cytoscape-format graph data for a campaign's actor IPs.
pub fn get_campaign_graph(&self, campaign_id: &str) -> serde_json::Value {
    let actor_ids: Vec<String> = self
        .get_campaign_actors(campaign_id)
        .iter()
        .map(ToString::to_string)
        .collect();
    self.graph_detector.get_cytoscape_data(&actor_ids)
}

/// Paginated Cytoscape-format graph data for a campaign's actor IPs.
///
/// `limit`, `offset` and `hash_identifiers` are forwarded unchanged to the
/// graph detector's export options.
pub fn get_campaign_graph_paginated(
    &self,
    campaign_id: &str,
    limit: Option<usize>,
    offset: Option<usize>,
    hash_identifiers: bool,
) -> crate::correlation::detectors::graph::PaginatedGraph {
    use crate::correlation::detectors::graph::GraphExportOptions;
    let actor_ids: Vec<String> = self
        .get_campaign_actors(campaign_id)
        .iter()
        .map(ToString::to_string)
        .collect();
    let export = GraphExportOptions {
        limit,
        offset,
        hash_identifiers,
    };
    self.graph_detector.get_cytoscape_data_paginated(&actor_ids, export)
}
pub fn stats(&self) -> ManagerStats {
let last_scan = {
self.last_scan
.try_read()
.map(|guard| *guard)
.unwrap_or(None)
};
let detections_by_type = self
.stats_detections_by_type
.try_read()
.map(|guard| guard.clone())
.unwrap_or_default();
ManagerStats {
fingerprints_registered: self.stats_fingerprints_registered.load(Ordering::Relaxed),
detections_run: self.stats_detections_run.load(Ordering::Relaxed),
campaigns_created: self.stats_campaigns_created.load(Ordering::Relaxed),
last_scan,
index_stats: self.index.stats(),
campaign_stats: self.store.stats(),
detections_by_type,
}
}
pub fn start_background_worker(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
let manager = self;
let scan_interval = manager.config.scan_interval;
tokio::spawn(async move {
let mut ticker = interval(scan_interval);
loop {
ticker.tick().await;
if manager.shutdown.load(Ordering::Relaxed) {
log::info!("Campaign manager background worker shutting down");
break;
}
match manager.run_detection_cycle().await {
Ok(updates) => {
if updates > 0 {
log::debug!("Detection cycle processed {} updates", updates);
}
}
Err(e) => {
log::warn!("Detection cycle error: {}", e);
}
}
}
})
}
/// Asks the background worker to stop; it exits after its next tick.
pub fn shutdown(&self) {
    let flag = &self.shutdown;
    flag.store(true, Ordering::Relaxed);
}

/// Whether [`Self::shutdown`] has been called.
pub fn is_shutdown(&self) -> bool {
    let flag = &self.shutdown;
    flag.load(Ordering::Relaxed)
}
/// Removes `ip` from the fingerprint index and from the campaign it belongs
/// to, if any. Store errors are ignored; detector state is not touched.
pub fn remove_ip(&self, ip: &IpAddr) {
    let key = ip.to_string();
    self.index.remove_entity(&key);
    let owner = self.store.get_campaign_for_ip(&key);
    if let Some(campaign_id) = owner {
        let _ = self.store.remove_actor_from_campaign(&campaign_id, &key);
    }
}
/// Shared fingerprint index used by all detectors.
pub fn index(&self) -> &Arc<FingerprintIndex> {
    let index = &self.index;
    index
}

/// Underlying campaign store.
pub fn store(&self) -> &Arc<CampaignStore> {
    let store = &self.store;
    store
}

/// Configuration this manager was built with.
pub fn config(&self) -> &ManagerConfig {
    let config = &self.config;
    config
}

/// Marks a campaign as resolved with the given reason.
///
/// # Errors
/// Store failures are returned as `DetectorError::DetectionFailed` carrying
/// the store error's message.
pub fn resolve_campaign(&self, campaign_id: &str, reason: &str) -> Result<(), DetectorError> {
    match self.store.resolve_campaign(campaign_id, reason) {
        Ok(_) => Ok(()),
        Err(err) => Err(DetectorError::DetectionFailed(err.to_string())),
    }
}
/// Drops all campaigns and indexed fingerprints so the manager can start over.
///
/// The shared-fingerprint detector's processed set is cleared, and the
/// rotation detector is asked to prune old observations. Other detectors keep
/// their state, and the counters in [`ManagerStats`] are not reset.
pub fn clear(&self) {
    self.index.clear();
    self.store.clear();
    self.http_fingerprint_detector.clear_processed();
    self.tls_fingerprint_detector.cleanup_old_observations();
    // The recorded ids refer to campaigns that no longer exist. Keeping them
    // leaks memory, and a regenerated id would wrongly skip auto-mitigation.
    self.mitigated_campaigns.clear();
    // Drop groups computed from the old index. This method is synchronous, so
    // only a non-blocking attempt is possible. If the lock is busy, the entry
    // still expires after GROUP_CACHE_TTL.
    if let Ok(mut cache) = self.group_cache.try_write() {
        *cache = None;
    }
}
}
impl Default for CampaignManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::thread;
fn create_test_manager() -> CampaignManager {
let config = ManagerConfig {
shared_threshold: 3,
rotation_threshold: 3,
rotation_window: Duration::from_secs(60),
scan_interval: Duration::from_millis(100),
background_scanning: false,
..Default::default()
};
CampaignManager::with_config(config)
}
fn create_test_ip(last_octet: u8) -> IpAddr {
format!("192.168.1.{}", last_octet).parse().unwrap()
}
#[test]
fn test_config_default() {
let config = ManagerConfig::default();
assert_eq!(config.shared_threshold, 3);
assert_eq!(config.rotation_threshold, 3);
assert_eq!(config.rotation_window, Duration::from_secs(60));
assert_eq!(config.scan_interval, Duration::from_secs(5));
assert!(config.background_scanning);
assert!(config.track_combined);
assert!((config.shared_confidence - 0.85).abs() < 0.001);
}
#[test]
fn test_config_builder() {
let config = ManagerConfig::new()
.with_shared_threshold(5)
.with_rotation_threshold(4)
.with_rotation_window(Duration::from_secs(120))
.with_scan_interval(Duration::from_secs(10))
.with_background_scanning(false)
.with_track_combined(false)
.with_shared_confidence(0.9);
assert_eq!(config.shared_threshold, 5);
assert_eq!(config.rotation_threshold, 4);
assert_eq!(config.rotation_window, Duration::from_secs(120));
assert_eq!(config.scan_interval, Duration::from_secs(10));
assert!(!config.background_scanning);
assert!(!config.track_combined);
assert!((config.shared_confidence - 0.9).abs() < 0.001);
}
#[tokio::test]
async fn test_mitigation_rate_limiter_limits() {
let limiter = MitigationRateLimiter::new(2, Duration::from_secs(60), 10);
assert!(limiter.try_ban().await.is_ok());
assert!(limiter.try_ban().await.is_ok());
assert!(limiter.try_ban().await.is_err());
}
#[test]
fn test_config_validation() {
let config = ManagerConfig::default();
assert!(config.validate().is_ok());
let config = ManagerConfig::new().with_shared_threshold(1);
assert!(config.validate().is_err());
let config = ManagerConfig::new().with_rotation_threshold(1);
assert!(config.validate().is_err());
let config = ManagerConfig {
rotation_window: Duration::ZERO,
..Default::default()
};
assert!(config.validate().is_err());
let config = ManagerConfig {
scan_interval: Duration::ZERO,
..Default::default()
};
assert!(config.validate().is_err());
let config = ManagerConfig {
auto_mitigation_enabled: true,
auto_mitigation_threshold: 0.5, ..Default::default()
};
assert!(config.validate().is_err());
let config = ManagerConfig {
auto_mitigation_enabled: true,
auto_mitigation_threshold: 0.9,
..Default::default()
};
assert!(config.validate().is_ok());
let config = ManagerConfig {
auto_mitigation_enabled: false,
auto_mitigation_threshold: 0.5, ..Default::default()
};
assert!(config.validate().is_ok());
}
#[test]
fn test_config_confidence_clamping() {
let config = ManagerConfig::new().with_shared_confidence(1.5);
assert!((config.shared_confidence - 1.0).abs() < 0.001);
let config = ManagerConfig::new().with_shared_confidence(-0.5);
assert!(config.shared_confidence >= 0.0);
}
#[test]
fn test_register_ja4() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.register_ja4(ip, "t13d1516h2_abc123".to_string());
let stats = manager.stats();
assert_eq!(stats.fingerprints_registered, 1);
assert_eq!(stats.index_stats.total_ips, 1);
assert_eq!(stats.index_stats.ja4_fingerprints, 1);
}
#[test]
fn test_register_ja4_empty_skipped() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.register_ja4(ip, "".to_string());
let stats = manager.stats();
assert_eq!(stats.fingerprints_registered, 0);
assert_eq!(stats.index_stats.total_ips, 0);
}
#[test]
fn test_register_combined() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.register_combined(ip, "combined_hash_xyz".to_string());
let stats = manager.stats();
assert_eq!(stats.fingerprints_registered, 1);
assert_eq!(stats.index_stats.total_ips, 1);
assert_eq!(stats.index_stats.combined_fingerprints, 1);
}
#[test]
fn test_register_fingerprints_both() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.register_fingerprints(
ip,
Some("ja4_test".to_string()),
Some("ja4h_test".to_string()),
);
let stats = manager.stats();
assert_eq!(stats.fingerprints_registered, 1);
assert_eq!(stats.index_stats.ja4_fingerprints, 1);
assert_eq!(stats.index_stats.combined_fingerprints, 1);
}
#[test]
fn test_register_fingerprints_ja4_only() {
let manager = create_test_manager();
let ip = create_test_ip(1);
manager.register_fingerprints(ip, Some("ja4_only".to_string()), None);
let stats = manager.stats();
assert_eq!(stats.fingerprints_registered, 1);
assert_eq!(stats.index_stats.ja4_fingerprints, 1);
assert_eq!(stats.index_stats.combined_fingerprints, 0);
}
#[tokio::test]
async fn test_detection_cycle_empty() {
let manager = create_test_manager();
let updates = manager.run_detection_cycle().await.unwrap();
assert_eq!(updates, 0);
assert_eq!(manager.stats().detections_run, 1);
}
#[tokio::test]
async fn test_detection_cycle_creates_campaign() {
let manager = create_test_manager();
for i in 1..=3 {
let ip = create_test_ip(i);
manager.register_ja4(ip, "shared_fingerprint".to_string());
}
let updates = manager.run_detection_cycle().await.unwrap();
assert!(updates >= 1);
assert_eq!(manager.stats().campaigns_created, 1);
let campaigns = manager.get_campaigns();
assert_eq!(campaigns.len(), 1);
}
#[tokio::test]
async fn test_detection_cycle_no_duplicate_campaigns() {
let manager = create_test_manager();
for i in 1..=3 {
let ip = create_test_ip(i);
manager.register_ja4(ip, "shared_fp".to_string());
}
manager.run_detection_cycle().await.unwrap();
let first_count = manager.stats().campaigns_created;
manager.run_detection_cycle().await.unwrap();
let second_count = manager.stats().campaigns_created;
assert_eq!(first_count, second_count);
}
#[tokio::test]
async fn test_get_campaigns() {
    let manager = create_test_manager();
    for n in 1..=3 {
        let addr = create_test_ip(n);
        manager.register_ja4(addr, "test_fp".to_string());
    }
    manager.run_detection_cycle().await.unwrap();

    let campaigns = manager.get_campaigns();
    assert!(!campaigns.is_empty());

    // All three registered IPs are expected to be actors of the detected campaign.
    let first = campaigns.first().unwrap();
    assert_eq!(first.actor_count, 3);
}
#[tokio::test]
async fn test_get_campaign_by_id() {
    let manager = create_test_manager();
    for n in 1..=3 {
        manager.register_ja4(create_test_ip(n), "get_by_id_fp".to_string());
    }
    manager.run_detection_cycle().await.unwrap();

    let campaigns = manager.get_campaigns();
    let wanted_id = &campaigns[0].id;

    // A known ID round-trips to the same campaign...
    let found = manager.get_campaign(wanted_id);
    assert!(found.is_some());
    assert_eq!(found.unwrap().id, *wanted_id);

    // ...while an unknown ID yields nothing.
    assert!(manager.get_campaign("nonexistent").is_none());
}
#[tokio::test]
async fn test_get_campaign_actors() {
    let manager = create_test_manager();
    for n in 1..=3 {
        manager.register_ja4(create_test_ip(n), "actors_fp".to_string());
    }
    manager.run_detection_cycle().await.unwrap();

    let campaigns = manager.get_campaigns();
    let id = &campaigns[0].id;

    // The detected campaign lists every registered IP; an unknown ID lists none.
    assert_eq!(manager.get_campaign_actors(id).len(), 3);
    assert!(manager.get_campaign_actors("nonexistent").is_empty());
}
#[tokio::test]
async fn test_stats_tracking() {
    let manager = create_test_manager();

    // Fresh manager: every counter is zero and no scan has been recorded.
    let fresh = manager.stats();
    assert_eq!(fresh.fingerprints_registered, 0);
    assert_eq!(fresh.detections_run, 0);
    assert_eq!(fresh.campaigns_created, 0);
    assert!(fresh.last_scan.is_none());

    // Registration bumps the fingerprint counter and the index's IP count.
    for n in 1..=5 {
        manager.register_ja4(create_test_ip(n), "stats_test_fp".to_string());
    }
    let registered = manager.stats();
    assert_eq!(registered.fingerprints_registered, 5);
    assert_eq!(registered.index_stats.total_ips, 5);

    // One detection pass records the run, stamps `last_scan`, and creates a campaign.
    manager.run_detection_cycle().await.unwrap();
    let detected = manager.stats();
    assert_eq!(detected.detections_run, 1);
    assert!(detected.last_scan.is_some());
    assert!(detected.campaigns_created >= 1);
}
#[tokio::test]
async fn test_remove_ip_cleanup() {
    let manager = create_test_manager();
    for n in 1..=3 {
        manager.register_ja4(create_test_ip(n), "remove_test_fp".to_string());
    }
    manager.run_detection_cycle().await.unwrap();
    assert_eq!(manager.get_campaigns()[0].actor_count, 3);

    // Removing one IP drops it from the index and from the campaign's actor count.
    manager.remove_ip(&create_test_ip(1));
    assert_eq!(manager.index.len(), 2);
    assert_eq!(manager.get_campaigns()[0].actor_count, 2);
}
#[test]
fn test_concurrent_registration() {
    let manager = Arc::new(create_test_manager());

    // Ten threads each register 100 distinct IPs (10.<thread>.0.<n>),
    // cycling through 5 fingerprints per thread.
    let workers: Vec<_> = (0u8..10)
        .map(|t| {
            let shared = Arc::clone(&manager);
            thread::spawn(move || {
                for n in 0u8..100 {
                    let addr = IpAddr::from(std::net::Ipv4Addr::new(10, t, 0, n));
                    shared.register_ja4(addr, format!("fp_t{}_{}", t, n % 5));
                }
            })
        })
        .collect();

    for worker in workers {
        worker.join().unwrap();
    }

    let stats = manager.stats();
    assert_eq!(stats.fingerprints_registered, 1000);
    assert!(stats.index_stats.total_ips > 0);
}
#[test]
fn test_should_trigger_detection_below_threshold() {
    // With the test manager's configuration, two IPs sharing a fingerprint
    // are not enough to trigger detection.
    let manager = create_test_manager();
    for n in 1..3 {
        manager.register_ja4(create_test_ip(n), "trigger_test_fp".to_string());
    }

    let probe = create_test_ip(1);
    let triggered = manager.should_trigger_detection(&probe);
    assert!(!triggered);
}
#[test]
fn test_should_trigger_detection_at_threshold() {
    // With the test manager's configuration, three IPs sharing a fingerprint
    // are enough to trigger detection.
    let manager = create_test_manager();
    for n in 1..4 {
        manager.register_ja4(create_test_ip(n), "trigger_threshold_fp".to_string());
    }

    let probe = create_test_ip(1);
    let triggered = manager.should_trigger_detection(&probe);
    assert!(triggered);
}
#[tokio::test]
async fn test_background_worker_lifecycle() {
    // A short scan interval lets the worker complete at least one pass
    // within the 200 ms sleep below.
    let config = ManagerConfig {
        scan_interval: Duration::from_millis(50),
        background_scanning: true,
        shared_threshold: 3,
        ..Default::default()
    };
    let manager = Arc::new(CampaignManager::with_config(config));
    for i in 1..=3 {
        let ip = create_test_ip(i);
        manager.register_ja4(ip, "worker_test_fp".to_string());
    }

    // The returned handle is never polled before the sleep, yet detections are
    // expected to run. So the worker must already be executing as a spawned task.
    let worker = Arc::clone(&manager).start_background_worker();
    tokio::time::sleep(Duration::from_millis(200)).await;
    let stats = manager.stats();
    assert!(stats.detections_run >= 1);

    manager.shutdown();

    // `timeout` yields `Err(Elapsed)` only if the worker did not finish in time.
    // The inner `Result` is the task's join outcome. Check it as well, otherwise
    // a worker that panicked (or was cancelled) during shutdown would still pass.
    let outcome = tokio::time::timeout(Duration::from_millis(500), worker).await;
    assert!(outcome.is_ok(), "Worker should shut down gracefully");
    if let Ok(Err(join_err)) = outcome {
        panic!("Worker task did not exit cleanly: {:?}", join_err);
    }
}
#[tokio::test]
async fn test_shutdown_flag() {
    // The shutdown flag starts cleared and is set by `shutdown()`.
    let manager = CampaignManager::new();
    let initially_shut_down = manager.is_shutdown();
    assert!(!initially_shut_down);

    manager.shutdown();
    assert!(manager.is_shutdown());
}
#[tokio::test]
async fn test_full_flow() {
    // End to end: register, detect, query, remove one actor, then check stats.
    let manager = create_test_manager();
    let shared_fp = "t13d1516h2_full_flow_test";
    for n in 1..=5 {
        manager.register_ja4(create_test_ip(n), shared_fp.to_string());
    }

    // Detection groups all five IPs into a single campaign.
    let updates = manager.run_detection_cycle().await.unwrap();
    assert!(updates >= 1);
    let campaigns = manager.get_campaigns();
    assert_eq!(campaigns.len(), 1);
    let campaign = &campaigns[0];
    assert_eq!(campaign.actor_count, 5);
    assert!(campaign.confidence >= 0.8);
    assert!(!campaign.correlation_reasons.is_empty());

    // Lookup by ID and the actor listing agree with the campaign summary.
    let by_id = manager.get_campaign(&campaign.id).unwrap();
    assert_eq!(by_id.actors.len(), 5);
    assert_eq!(manager.get_campaign_actors(&campaign.id).len(), 5);

    // Removing one IP shrinks the stored campaign's actor list.
    manager.remove_ip(&create_test_ip(1));
    let after_removal = manager.get_campaign(&campaign.id).unwrap();
    assert_eq!(after_removal.actors.len(), 4);

    let stats = manager.stats();
    assert_eq!(stats.fingerprints_registered, 5);
    assert_eq!(stats.campaigns_created, 1);
    assert_eq!(stats.campaign_stats.total_campaigns, 1);
}
#[test]
fn test_clear() {
    let manager = create_test_manager();
    for n in 1..=5 {
        manager.register_ja4(create_test_ip(n), "clear_test_fp".to_string());
    }
    assert_eq!(manager.index.len(), 5);

    // `clear()` empties both the fingerprint index and the campaign store.
    manager.clear();
    assert_eq!(manager.index.len(), 0);
    assert!(manager.store.is_empty());
}
#[tokio::test]
async fn test_resolve_campaign() {
    let manager = create_test_manager();
    for n in 1..=3 {
        manager.register_ja4(create_test_ip(n), "resolve_test_fp".to_string());
    }
    manager.run_detection_cycle().await.unwrap();

    let id = manager.get_campaigns()[0].id.clone();
    let outcome = manager.resolve_campaign(&id, "Threat mitigated");
    assert!(outcome.is_ok());

    // The campaign can still be fetched by ID, and it is marked resolved...
    let fetched = manager.get_campaign(&id).unwrap();
    assert_eq!(fetched.status, CampaignStatus::Resolved);
    // ...but `get_campaigns()` no longer returns it.
    assert!(manager.get_campaigns().is_empty());
}
#[test]
fn test_index_and_store_access() {
    // The accessors are callable, and a freshly built manager exposes an
    // empty index and an empty store.
    let manager = create_test_manager();
    let _idx = manager.index();
    let _campaign_store = manager.store();
    let _cfg = manager.config();

    assert!(manager.index().is_empty());
    assert!(manager.store().is_empty());
}
#[test]
fn test_ipv6_addresses() {
    // IPv6 addresses are indexed just like IPv4 ones.
    let manager = create_test_manager();
    for literal in ["2001:db8::1", "2001:db8::2", "2001:db8::3"].iter() {
        let addr: IpAddr = literal.parse().unwrap();
        manager.register_ja4(addr, "ipv6_fp".to_string());
    }

    let stats = manager.stats();
    assert_eq!(stats.fingerprints_registered, 3);
    assert_eq!(stats.index_stats.total_ips, 3);
}
#[test]
fn test_default_trait() {
    // `Default` yields an empty manager that has not been shut down.
    let defaulted = CampaignManager::default();
    assert!(defaulted.index.is_empty());
    assert!(defaulted.store.is_empty());
    assert!(!defaulted.is_shutdown());
}
#[tokio::test]
async fn test_multiple_fingerprint_groups() {
    // Two disjoint groups: 3 IPs on one fingerprint and 4 IPs on another.
    // Each group should become its own campaign.
    let manager = create_test_manager();
    for (ids, fingerprint) in [(1..=3, "group_a_fp"), (10..=13, "group_b_fp")].iter().cloned() {
        for n in ids {
            manager.register_ja4(create_test_ip(n), fingerprint.to_string());
        }
    }

    manager.run_detection_cycle().await.unwrap();
    let campaigns = manager.get_campaigns();
    assert_eq!(campaigns.len(), 2);

    let sizes: Vec<usize> = campaigns.iter().map(|c| c.actor_count).collect();
    assert!(sizes.contains(&3));
    assert!(sizes.contains(&4));
}
}