#![cfg(test)]
use std::net::IpAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::correlation::detectors::{AttackPayload, TokenFingerprint};
use crate::correlation::{
Campaign, CampaignManager, CorrelationReason, CorrelationType, ManagerConfig,
};
fn ip(last_octet: u8) -> IpAddr {
format!("192.168.1.{}", last_octet).parse().unwrap()
}
fn test_manager() -> CampaignManager {
let config = ManagerConfig {
shared_threshold: 2,
rotation_threshold: 2,
rotation_window: Duration::from_secs(60),
scan_interval: Duration::from_millis(100),
background_scanning: false,
track_combined: true,
shared_confidence: 0.85,
attack_sequence_min_ips: 2,
attack_sequence_window: Duration::from_secs(300),
auth_token_min_ips: 2,
auth_token_window: Duration::from_secs(600),
behavioral_min_ips: 2,
behavioral_min_sequence: 3,
behavioral_window: Duration::from_secs(300),
timing_min_ips: 2,
timing_bucket_ms: 100,
timing_min_bucket_hits: 3,
timing_window: Duration::from_secs(60),
network_min_ips: 2,
network_check_subnet: true,
graph_min_component_size: 3,
graph_max_depth: 3,
graph_edge_ttl: Duration::from_secs(3600),
auto_mitigation_enabled: false,
auto_mitigation_threshold: 0.90,
};
CampaignManager::with_config(config)
}
fn mock_jwt() -> String {
"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwiaXNzIjoiaHR0cHM6Ly9leGFtcGxlLmNvbSJ9.signature".to_string()
}
fn attack_payload(payload_hash: &str, attack_type: &str) -> AttackPayload {
AttackPayload {
payload_hash: payload_hash.to_string(),
attack_type: attack_type.to_string(),
target_path: "/api/vulnerable".to_string(),
timestamp: Instant::now(),
}
}
#[allow(dead_code)]
fn token_fingerprint(issuer: &str) -> TokenFingerprint {
TokenFingerprint {
issuer: Some(issuer.to_string()),
algorithm: "RS256".to_string(),
header_fields: vec!["alg".to_string(), "typ".to_string()],
payload_fields: vec!["sub".to_string(), "iss".to_string()],
}
}
#[tokio::test]
async fn test_multi_detector_campaign_detection() {
let manager = test_manager();
let ips: Vec<IpAddr> = (1..=5).map(ip).collect();
let tls_fingerprint = "t13d1516h2_shared_campaign_fp".to_string();
for &test_ip in &ips {
manager.register_ja4(test_ip, tls_fingerprint.clone());
}
let shared_payload = attack_payload("sqli_probe_hash_abc123", "sqli");
for &test_ip in &ips {
manager.record_attack(
test_ip,
shared_payload.payload_hash.clone(),
shared_payload.attack_type.clone(),
shared_payload.target_path.clone(),
);
}
let jwt = mock_jwt();
for &test_ip in &ips {
manager.record_token(test_ip, &jwt);
}
for &test_ip in &ips {
manager.record_request(test_ip, "GET", "/");
manager.record_request(test_ip, "GET", "/api/users");
manager.record_request(test_ip, "POST", "/api/login");
}
let updates = manager.run_detection_cycle().await.unwrap();
assert!(updates > 0, "Expected at least one campaign update");
let campaigns = manager.get_campaigns();
assert!(
!campaigns.is_empty(),
"Expected at least one campaign to be created"
);
let stats = manager.stats();
let detections_by_type = &stats.detections_by_type;
let detector_types_with_detections: Vec<&String> = detections_by_type
.iter()
.filter(|(_, &count)| count > 0)
.map(|(name, _)| name)
.collect();
assert!(
detector_types_with_detections.len() >= 1,
"Expected detections from multiple detector types, got: {:?}",
detector_types_with_detections
);
}
#[tokio::test]
async fn test_weighted_scoring_aggregation() {
let mut campaign = Campaign::new(
"test-weighted-scoring".to_string(),
vec!["192.168.1.1".to_string(), "192.168.1.2".to_string()],
0.5,
);
campaign.correlation_reasons.push(CorrelationReason::new(
CorrelationType::AttackSequence,
0.95,
"Same SQLi payload",
vec!["192.168.1.1".to_string(), "192.168.1.2".to_string()],
));
campaign.correlation_reasons.push(CorrelationReason::new(
CorrelationType::AuthToken,
0.90,
"Same JWT issuer",
vec!["192.168.1.1".to_string(), "192.168.1.2".to_string()],
));
campaign.correlation_reasons.push(CorrelationReason::new(
CorrelationType::TlsFingerprint,
0.85,
"Same JA4 fingerprint",
vec!["192.168.1.1".to_string(), "192.168.1.2".to_string()],
));
assert_eq!(CorrelationType::AttackSequence.weight(), 50);
assert_eq!(CorrelationType::AuthToken.weight(), 45);
assert_eq!(CorrelationType::HttpFingerprint.weight(), 40);
assert_eq!(CorrelationType::TlsFingerprint.weight(), 35);
assert_eq!(CorrelationType::BehavioralSimilarity.weight(), 30);
assert_eq!(CorrelationType::TimingCorrelation.weight(), 25);
assert_eq!(CorrelationType::NetworkProximity.weight(), 15);
let manager = test_manager();
let calculated_score = manager.calculate_campaign_score(&campaign);
let expected_score = (50.0 * 0.95 + 45.0 * 0.90 + 35.0 * 0.85) / 3.0;
let tolerance = 0.01;
assert!(
(calculated_score - expected_score).abs() < tolerance,
"Score mismatch: calculated={:.4}, expected={:.4}",
calculated_score,
expected_score
);
let attack_only_score = 50.0 * 0.95; let network_only_score = 15.0 * 0.95; assert!(
attack_only_score > network_only_score,
"AttackSequence (50) should contribute more than NetworkProximity (15)"
);
}
#[test]
fn test_correlation_type_weight_ordering() {
let all_types = CorrelationType::all_by_weight();
assert_eq!(all_types.len(), 7, "Should have 7 correlation types");
assert_eq!(all_types[0], CorrelationType::AttackSequence);
assert_eq!(all_types[1], CorrelationType::AuthToken);
assert_eq!(all_types[2], CorrelationType::HttpFingerprint);
assert_eq!(all_types[3], CorrelationType::TlsFingerprint);
assert_eq!(all_types[4], CorrelationType::BehavioralSimilarity);
assert_eq!(all_types[5], CorrelationType::TimingCorrelation);
assert_eq!(all_types[6], CorrelationType::NetworkProximity);
for i in 0..all_types.len() - 1 {
assert!(
all_types[i].weight() > all_types[i + 1].weight(),
"Weight ordering violated at index {}: {} ({}) should be > {} ({})",
i,
all_types[i],
all_types[i].weight(),
all_types[i + 1],
all_types[i + 1].weight()
);
}
}
#[tokio::test]
async fn test_detection_cycle_runs_all_detectors() {
let manager = test_manager();
let test_ips: Vec<IpAddr> = (1..=5).map(ip).collect();
for &test_ip in &test_ips {
manager.register_ja4(test_ip, "shared_tls_fp".to_string());
}
for &test_ip in &test_ips {
manager.record_attack(
test_ip,
"attack_hash_xyz".to_string(),
"sqli".to_string(),
"/api/data".to_string(),
);
}
for &test_ip in &test_ips {
manager.record_token(test_ip, &mock_jwt());
}
for &test_ip in &test_ips {
manager.record_request(test_ip, "GET", "/home");
manager.record_request(test_ip, "GET", "/products");
manager.record_request(test_ip, "POST", "/cart");
}
let stats_before = manager.stats();
assert_eq!(stats_before.detections_run, 0);
let updates = manager.run_detection_cycle().await.unwrap();
let stats_after = manager.stats();
assert_eq!(stats_after.detections_run, 1);
assert!(stats_after.last_scan.is_some());
assert!(updates > 0, "Expected detection updates from cycle");
let detection_stats = &stats_after.detections_by_type;
let active_detectors: Vec<&String> = detection_stats
.iter()
.filter(|(_, &count)| count > 0)
.map(|(name, _)| name)
.collect();
println!("Active detectors with detections: {:?}", active_detectors);
println!("Detection stats: {:?}", detection_stats);
assert!(
!active_detectors.is_empty(),
"Expected at least one detector to produce detections"
);
}
#[tokio::test]
async fn test_detector_independence() {
let config_no_network = ManagerConfig {
shared_threshold: 2,
rotation_threshold: 2,
rotation_window: Duration::from_secs(60),
scan_interval: Duration::from_millis(100),
background_scanning: false,
network_check_subnet: false, ..Default::default()
};
let manager_no_network = CampaignManager::with_config(config_no_network);
let config_with_network = ManagerConfig {
shared_threshold: 2,
rotation_threshold: 2,
rotation_window: Duration::from_secs(60),
scan_interval: Duration::from_millis(100),
background_scanning: false,
network_check_subnet: true, ..Default::default()
};
let manager_with_network = CampaignManager::with_config(config_with_network);
let test_ips: Vec<IpAddr> = (1..=5).map(ip).collect();
for &test_ip in &test_ips {
manager_no_network.register_ja4(test_ip, "detector_independence_test".to_string());
manager_with_network.register_ja4(test_ip, "detector_independence_test".to_string());
manager_no_network.record_request(test_ip, "GET", "/");
manager_no_network.record_request(test_ip, "GET", "/api");
manager_no_network.record_request(test_ip, "POST", "/submit");
manager_with_network.record_request(test_ip, "GET", "/");
manager_with_network.record_request(test_ip, "GET", "/api");
manager_with_network.record_request(test_ip, "POST", "/submit");
}
let updates_no_network = manager_no_network.run_detection_cycle().await.unwrap();
let updates_with_network = manager_with_network.run_detection_cycle().await.unwrap();
assert!(
updates_no_network > 0,
"Manager without network detector should still have detections"
);
assert!(
updates_with_network > 0,
"Manager with network detector should have detections"
);
let stats_no_network = manager_no_network.stats();
let stats_with_network = manager_with_network.stats();
let network_detections_disabled = stats_no_network
.detections_by_type
.get("network")
.copied()
.unwrap_or(0);
assert_eq!(
network_detections_disabled, 0,
"Disabled network detector should not contribute detections"
);
let _network_detections_enabled = stats_with_network
.detections_by_type
.get("network")
.copied()
.unwrap_or(0);
}
#[tokio::test]
async fn test_detector_threshold_independence() {
let config = ManagerConfig {
shared_threshold: 3, attack_sequence_min_ips: 2, auth_token_min_ips: 4, behavioral_min_ips: 2, timing_min_ips: 5, network_min_ips: 3, rotation_threshold: 3,
rotation_window: Duration::from_secs(60),
scan_interval: Duration::from_millis(100),
background_scanning: false,
..Default::default()
};
let manager = CampaignManager::with_config(config);
let two_ips: Vec<IpAddr> = (1..=2).map(ip).collect();
for &test_ip in &two_ips {
manager.record_attack(
test_ip,
"threshold_test_hash".to_string(),
"xss".to_string(),
"/test".to_string(),
);
manager.register_ja4(test_ip, "threshold_test_fp".to_string());
manager.record_token(test_ip, &mock_jwt());
}
let _updates = manager.run_detection_cycle().await.unwrap();
let stats = manager.stats();
let attack_detections = stats
.detections_by_type
.get("attack_sequence")
.copied()
.unwrap_or(0);
assert!(
attack_detections > 0,
"Attack sequence should detect with 2 IPs (threshold: 2)"
);
let auth_detections = stats
.detections_by_type
.get("auth_token")
.copied()
.unwrap_or(0);
assert_eq!(
auth_detections, 0,
"Auth token should not detect with 2 IPs (threshold: 4)"
);
}
#[tokio::test]
async fn test_concurrent_multi_detector_access() {
let manager = Arc::new(test_manager());
let mut handles = vec![];
for task_id in 0..5 {
let manager = Arc::clone(&manager);
handles.push(tokio::spawn(async move {
for i in 0..20 {
let test_ip: IpAddr = format!("10.{}.{}.{}", task_id, i / 256, i % 256)
.parse()
.unwrap();
manager.register_ja4(test_ip, format!("concurrent_fp_{}", task_id));
manager.record_attack(
test_ip,
format!("attack_{}", task_id),
"sqli".to_string(),
"/api".to_string(),
);
manager.record_token(test_ip, &mock_jwt());
manager.record_request(test_ip, "GET", "/");
manager.record_request(test_ip, "POST", "/api");
manager.record_request(test_ip, "GET", "/done");
tokio::task::yield_now().await;
}
}));
}
for _ in 0..3 {
let manager = Arc::clone(&manager);
handles.push(tokio::spawn(async move {
for _ in 0..5 {
let _ = manager.run_detection_cycle().await;
tokio::time::sleep(Duration::from_millis(10)).await;
}
}));
}
let timeout_result = tokio::time::timeout(Duration::from_secs(10), async {
for handle in handles {
handle.await.expect("Task panicked");
}
})
.await;
assert!(
timeout_result.is_ok(),
"Concurrent access test timed out - possible deadlock"
);
let stats = manager.stats();
assert!(
stats.fingerprints_registered > 0,
"Should have registered fingerprints"
);
assert!(stats.detections_run > 0, "Should have run detection cycles");
let campaigns = manager.get_campaigns();
let all_campaigns = manager.get_all_campaigns();
println!(
"After concurrent test: {} active campaigns, {} total campaigns",
campaigns.len(),
all_campaigns.len()
);
}
#[tokio::test]
async fn test_concurrent_read_write_operations() {
let manager = Arc::new(test_manager());
let num_writers = 4;
let num_readers = 4;
let operations_per_task = 50;
let mut handles = vec![];
for writer_id in 0..num_writers {
let manager = Arc::clone(&manager);
handles.push(tokio::spawn(async move {
for i in 0..operations_per_task {
let test_ip: IpAddr = format!("10.{}.0.{}", writer_id, i).parse().unwrap();
manager.register_ja4(test_ip, format!("writer_{}_fp", writer_id));
manager.record_request(test_ip, "GET", "/");
manager.record_request(test_ip, "GET", "/api");
manager.record_request(test_ip, "POST", "/submit");
tokio::task::yield_now().await;
}
}));
}
for _reader_id in 0..num_readers {
let manager = Arc::clone(&manager);
handles.push(tokio::spawn(async move {
for _ in 0..operations_per_task {
let _ = manager.stats();
let _ = manager.get_campaigns();
let _ = manager.get_all_campaigns();
tokio::task::yield_now().await;
}
}));
}
for _ in 0..2 {
let manager = Arc::clone(&manager);
handles.push(tokio::spawn(async move {
for _ in 0..10 {
let _ = manager.run_detection_cycle().await;
tokio::time::sleep(Duration::from_millis(5)).await;
}
}));
}
for handle in handles {
handle.await.expect("Task failed");
}
let final_stats = manager.stats();
assert!(
final_stats.fingerprints_registered > 0,
"Should have data after concurrent operations"
);
}
#[tokio::test]
async fn test_campaign_merging_same_actors() {
let manager = test_manager();
let shared_ips: Vec<IpAddr> = (1..=3).map(ip).collect();
for &test_ip in &shared_ips {
manager.register_ja4(test_ip, "merge_test_tls".to_string());
}
manager.run_detection_cycle().await.unwrap();
let campaigns_after_first = manager.get_campaigns();
for &test_ip in &shared_ips {
manager.record_attack(
test_ip,
"merge_test_attack".to_string(),
"sqli".to_string(),
"/api".to_string(),
);
}
manager.run_detection_cycle().await.unwrap();
let campaigns_after_second = manager.get_campaigns();
for &test_ip in &shared_ips {
manager.record_request(test_ip, "GET", "/merge");
manager.record_request(test_ip, "GET", "/test");
manager.record_request(test_ip, "POST", "/data");
}
manager.run_detection_cycle().await.unwrap();
let campaigns_final = manager.get_campaigns();
println!(
"Campaigns after first: {}, second: {}, final: {}",
campaigns_after_first.len(),
campaigns_after_second.len(),
campaigns_final.len()
);
let _campaigns_with_all_ips = campaigns_final
.iter()
.filter(|c| {
shared_ips
.iter()
.all(|ip| c.actors.contains(&ip.to_string()))
})
.count();
assert!(
!campaigns_final.is_empty(),
"Should have at least one campaign"
);
for (i, campaign) in campaigns_final.iter().enumerate() {
println!(
"Campaign {}: {} actors, {} correlation reasons",
i,
campaign.actor_count,
campaign.correlation_reasons.len()
);
}
}
#[tokio::test]
async fn test_detection_stats_accuracy() {
let manager = test_manager();
let initial_stats = manager.stats();
assert_eq!(initial_stats.detections_run, 0);
assert_eq!(initial_stats.campaigns_created, 0);
assert!(
initial_stats.detections_by_type.is_empty()
|| initial_stats.detections_by_type.values().all(|&v| v == 0)
);
let test_ips: Vec<IpAddr> = (1..=4).map(ip).collect();
for &test_ip in &test_ips {
manager.register_ja4(test_ip, "stats_test_fp".to_string());
manager.record_attack(
test_ip,
"stats_test_payload".to_string(),
"xss".to_string(),
"/".to_string(),
);
}
manager.run_detection_cycle().await.unwrap();
let stats_after = manager.stats();
assert_eq!(
stats_after.detections_run, 1,
"Should count 1 detection cycle"
);
assert!(
stats_after.campaigns_created > 0,
"Should have created campaigns"
);
manager.run_detection_cycle().await.unwrap();
let stats_final = manager.stats();
assert_eq!(
stats_final.detections_run, 2,
"Should count 2 detection cycles"
);
}
#[tokio::test]
async fn test_empty_data_handling() {
let manager = test_manager();
let updates = manager.run_detection_cycle().await.unwrap();
assert_eq!(updates, 0, "Should have no updates with no data");
let campaigns = manager.get_campaigns();
assert!(
campaigns.is_empty(),
"Should have no campaigns with no data"
);
let stats = manager.stats();
assert_eq!(stats.detections_run, 1, "Detection cycle should be counted");
assert_eq!(stats.campaigns_created, 0, "No campaigns should be created");
}
#[tokio::test]
async fn test_single_ip_no_detection() {
let manager = test_manager();
let single_ip = ip(1);
manager.register_ja4(single_ip, "single_ip_fp".to_string());
manager.record_attack(
single_ip,
"single_ip_attack".to_string(),
"sqli".to_string(),
"/".to_string(),
);
manager.record_token(single_ip, &mock_jwt());
manager.record_request(single_ip, "GET", "/");
manager.record_request(single_ip, "GET", "/api");
manager.record_request(single_ip, "POST", "/submit");
let updates = manager.run_detection_cycle().await.unwrap();
assert_eq!(updates, 0, "Single IP should not trigger any detections");
}
#[tokio::test]
async fn test_mixed_ip_versions() {
let manager = test_manager();
let ipv4_1: IpAddr = "192.168.1.1".parse().unwrap();
let ipv4_2: IpAddr = "192.168.1.2".parse().unwrap();
let ipv6_1: IpAddr = "2001:db8::1".parse().unwrap();
let ipv6_2: IpAddr = "2001:db8::2".parse().unwrap();
let mixed_ips = vec![ipv4_1, ipv4_2, ipv6_1, ipv6_2];
for test_ip in &mixed_ips {
manager.register_ja4(*test_ip, "mixed_ip_fp".to_string());
}
let updates = manager.run_detection_cycle().await.unwrap();
assert!(updates > 0, "Should detect mixed IP campaign");
let campaigns = manager.get_campaigns();
assert!(
!campaigns.is_empty(),
"Should create campaign with mixed IPs"
);
if let Some(campaign) = campaigns.first() {
let has_ipv4 = campaign.actors.iter().any(|a| !a.contains(':'));
let has_ipv6 = campaign.actors.iter().any(|a| a.contains(':'));
assert!(
has_ipv4 && has_ipv6,
"Campaign should contain both IPv4 and IPv6 addresses"
);
}
}