use std::cmp::Ordering as CmpOrdering;
use std::collections::HashSet;
use std::net::IpAddr;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use dashmap::DashMap;
use parking_lot::RwLock as PLRwLock;
use serde::{Deserialize, Serialize};
use tokio::sync::Notify;
/// Tunables for `ActorManager`.
#[derive(Debug, Clone)]
pub struct ActorConfig {
    /// Soft cap on the number of tracked actors; eviction starts once the count reaches it.
    pub max_actors: usize,
    /// Seconds between background risk-decay / eviction passes.
    pub decay_interval_secs: u64,
    /// Seconds between persistence snapshots. Not read in this file; presumably consumed by
    /// whoever persists `ActorManager::snapshot` — verify.
    pub persist_interval_secs: u64,
    /// Correlation threshold. Not read in this file — TODO confirm where it applies.
    pub correlation_threshold: f64,
    /// Factor every positive risk score is multiplied by on each decay pass.
    pub risk_decay_factor: f64,
    /// Maximum rule matches retained per actor; the oldest are dropped first.
    pub max_rule_matches: usize,
    /// Maximum session ids retained per actor; the oldest are dropped first.
    pub max_session_ids: usize,
    /// Maximum distinct fingerprints recorded on a single actor.
    pub max_fingerprints_per_actor: usize,
    /// Cap on fingerprint -> actor index entries; reaching it trims the index to ~90%.
    pub max_fingerprint_mappings: usize,
    /// When `false`, `get_or_create_actor` hands out a fresh untracked id on every call and
    /// rule matches / touches are ignored.
    pub enabled: bool,
    /// Upper bound on an actor's accumulated risk score.
    pub max_risk: f64,
}

impl Default for ActorConfig {
    fn default() -> Self {
        ActorConfig {
            max_actors: 100_000,
            // 15 minutes.
            decay_interval_secs: 900,
            // 5 minutes.
            persist_interval_secs: 300,
            correlation_threshold: 0.7,
            risk_decay_factor: 0.9,
            max_rule_matches: 100,
            max_session_ids: 50,
            max_fingerprints_per_actor: 20,
            max_fingerprint_mappings: 500_000,
            enabled: true,
            max_risk: 100.0,
        }
    }
}
/// One rule hit attributed to an actor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuleMatch {
    /// Identifier of the rule that fired.
    pub rule_id: String,
    /// When the match was recorded, in milliseconds since the Unix epoch.
    pub timestamp: u64,
    /// Risk reported for this match (before the actor-level score is capped).
    pub risk_contribution: f64,
    /// Free-form rule category, e.g. `"sqli"`.
    pub category: String,
}

impl RuleMatch {
    /// Builds a match stamped with the current wall-clock time.
    pub fn new(rule_id: String, risk_contribution: f64, category: String) -> Self {
        let timestamp = now_ms();
        RuleMatch {
            rule_id,
            timestamp,
            risk_contribution,
            category,
        }
    }
}
/// Everything tracked about one actor (a client correlated across IPs, fingerprints and
/// sessions). This is also the persisted form: `ActorManager::snapshot` / `restore` exchange
/// it directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ActorState {
    /// Actor identifier; actors created by the manager get an id from `generate_actor_id`.
    pub actor_id: String,
    /// Accumulated risk: raised by rule matches (capped at `max_risk`) and lowered by the
    /// periodic decay pass.
    pub risk_score: f64,
    /// Recent rule matches, oldest first, bounded by the `max_rule_matches` passed to
    /// `add_rule_match`.
    pub rule_matches: Vec<RuleMatch>,
    /// Anomaly counter. Not updated anywhere in this file — presumably maintained by callers;
    /// verify.
    pub anomaly_count: u64,
    /// Bound session ids, oldest first.
    pub session_ids: Vec<String>,
    /// Creation time, in milliseconds since the Unix epoch.
    pub first_seen: u64,
    /// Last activity time, in milliseconds since the Unix epoch (refreshed by `touch`).
    pub last_seen: u64,
    /// Every IP address this actor has been seen from; serialized as a list of strings.
    #[serde(with = "ip_set_serde")]
    pub ips: HashSet<IpAddr>,
    /// Client fingerprints attributed to this actor (capped per actor by `add_fingerprint`).
    pub fingerprints: HashSet<String>,
    /// Whether the actor is currently blocked.
    pub is_blocked: bool,
    /// Reason supplied when the actor was blocked; `None` when not blocked.
    pub block_reason: Option<String>,
    /// When the block was applied, in milliseconds since the Unix epoch.
    pub blocked_since: Option<u64>,
}
impl ActorState {
pub fn new(actor_id: String) -> Self {
let now = now_ms();
Self {
actor_id,
risk_score: 0.0,
rule_matches: Vec::new(),
anomaly_count: 0,
session_ids: Vec::new(),
first_seen: now,
last_seen: now,
ips: HashSet::new(),
fingerprints: HashSet::new(),
is_blocked: false,
block_reason: None,
blocked_since: None,
}
}
pub fn touch(&mut self) {
self.last_seen = now_ms();
}
pub fn add_ip(&mut self, ip: IpAddr) {
self.ips.insert(ip);
self.touch();
}
pub fn add_fingerprint(&mut self, fingerprint: String, max_fingerprints: usize) -> bool {
if fingerprint.is_empty() {
return false;
}
if self.fingerprints.contains(&fingerprint) {
self.touch();
return true;
}
if self.fingerprints.len() >= max_fingerprints {
self.touch();
return false;
}
self.fingerprints.insert(fingerprint);
self.touch();
true
}
pub fn add_rule_match(&mut self, rule_match: RuleMatch, max_matches: usize) {
self.rule_matches.push(rule_match);
self.touch();
if self.rule_matches.len() > max_matches {
let excess = self.rule_matches.len() - max_matches;
self.rule_matches.drain(0..excess);
}
}
pub fn get_rule_match_count(&self, rule_id: &str) -> usize {
self.rule_matches
.iter()
.filter(|m| m.rule_id == rule_id)
.count()
}
}
/// Serde adapter that stores a `HashSet<IpAddr>` as a sequence of IP address strings.
///
/// Deserialization is lenient: entries that do not parse as an IP address are skipped
/// instead of failing the whole document.
mod ip_set_serde {
    use serde::{Deserialize, Deserializer, Serialize, Serializer};
    use std::collections::HashSet;
    use std::net::IpAddr;

    /// Writes each address in its `Display` form, in the set's (unspecified) iteration order.
    pub fn serialize<S>(set: &HashSet<IpAddr>, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        set.iter()
            .map(ToString::to_string)
            .collect::<Vec<String>>()
            .serialize(serializer)
    }

    /// Reads a sequence of strings, keeping only those that parse as IP addresses.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<HashSet<IpAddr>, D::Error>
    where
        D: Deserializer<'de>,
    {
        let raw = Vec::<String>::deserialize(deserializer)?;
        Ok(raw
            .iter()
            .filter_map(|text| text.parse::<IpAddr>().ok())
            .collect())
    }
}
/// Live counters for an `ActorManager`, updated with relaxed atomics.
#[derive(Debug, Default)]
pub struct ActorStats {
    /// Actors currently tracked (maintained incrementally; reset by `clear`, set by `restore`).
    pub total_actors: AtomicU64,
    /// Actors currently blocked.
    pub blocked_actors: AtomicU64,
    /// Incremented by `correlate_actor` when a lookup resolves through only one index, or when
    /// the IP and fingerprint indexes point at different actors.
    pub correlations_made: AtomicU64,
    /// Actors removed by capacity eviction.
    pub evictions: AtomicU64,
    /// Actors created (overwritten with the restored count by `restore`).
    pub total_created: AtomicU64,
    /// Rule matches recorded via `record_rule_match`.
    pub total_rule_matches: AtomicU64,
    /// Fingerprint -> actor index entries dropped to respect `max_fingerprint_mappings`.
    pub fingerprint_evictions: AtomicU64,
}
impl ActorStats {
pub fn new() -> Self {
Self::default()
}
pub fn snapshot(&self) -> ActorStatsSnapshot {
ActorStatsSnapshot {
total_actors: self.total_actors.load(Ordering::Relaxed),
blocked_actors: self.blocked_actors.load(Ordering::Relaxed),
correlations_made: self.correlations_made.load(Ordering::Relaxed),
evictions: self.evictions.load(Ordering::Relaxed),
total_created: self.total_created.load(Ordering::Relaxed),
total_rule_matches: self.total_rule_matches.load(Ordering::Relaxed),
fingerprint_evictions: self.fingerprint_evictions.load(Ordering::Relaxed),
}
}
}
/// Plain-value copy of `ActorStats`, produced by `ActorStats::snapshot` for reporting.
#[derive(Debug, Clone, Serialize)]
pub struct ActorStatsSnapshot {
    /// See `ActorStats::total_actors`.
    pub total_actors: u64,
    /// See `ActorStats::blocked_actors`.
    pub blocked_actors: u64,
    /// See `ActorStats::correlations_made`.
    pub correlations_made: u64,
    /// See `ActorStats::evictions`.
    pub evictions: u64,
    /// See `ActorStats::total_created`.
    pub total_created: u64,
    /// See `ActorStats::total_rule_matches`.
    pub total_rule_matches: u64,
    /// See `ActorStats::fingerprint_evictions`.
    pub fingerprint_evictions: u64,
}
/// Cached result of `ActorManager::get_fingerprint_groups`: when it was computed, plus the
/// full `(fingerprint, actor_ids, max_risk_score)` list. `None` until first computed.
type FingerprintClusterCache = Option<(Instant, Vec<(String, Vec<String>, f64)>)>;
/// Concurrent registry of actors with IP and fingerprint indexes used for correlation.
#[derive(Debug)]
pub struct ActorManager {
    /// Primary store, keyed by actor id.
    actors: DashMap<String, ActorState>,
    /// IP -> id of the actor most recently associated with that IP.
    ip_to_actor: DashMap<IpAddr, String>,
    /// Fingerprint -> owning actor id; bounded by `config.max_fingerprint_mappings`.
    fingerprint_to_actor: DashMap<String, String>,
    /// Tunables supplied at construction.
    config: ActorConfig,
    /// Shared counters exposed through `stats()`.
    stats: Arc<ActorStats>,
    /// Signalled by `shutdown()` to stop the background task.
    shutdown: Arc<Notify>,
    /// Counts `get_or_create_actor` calls so the capacity check only runs every 100th call.
    touch_counter: AtomicU32,
    /// Short-lived (one second) cache for `get_fingerprint_groups`.
    fingerprint_groups_cache: PLRwLock<FingerprintClusterCache>,
}
impl ActorManager {
    /// Creates an empty manager. The actor and IP maps are pre-sized to `config.max_actors`
    /// and the fingerprint index to twice that.
    pub fn new(config: ActorConfig) -> Self {
        Self {
            actors: DashMap::with_capacity(config.max_actors),
            ip_to_actor: DashMap::with_capacity(config.max_actors),
            fingerprint_to_actor: DashMap::with_capacity(config.max_actors * 2),
            config,
            stats: Arc::new(ActorStats::new()),
            shutdown: Arc::new(Notify::new()),
            touch_counter: AtomicU32::new(0),
            fingerprint_groups_cache: PLRwLock::new(None),
        }
    }

    /// Returns the configuration this manager was built with.
    pub fn config(&self) -> &ActorConfig {
        &self.config
    }

    /// Whether actor tracking is enabled.
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }

    /// Number of actors currently tracked.
    pub fn len(&self) -> usize {
        self.actors.len()
    }

    /// Whether no actors are tracked.
    pub fn is_empty(&self) -> bool {
        self.actors.is_empty()
    }

    /// Resolves the actor for a request from `ip` with an optional `fingerprint`, creating one
    /// if nothing correlates, and returns its id.
    ///
    /// Correlation prefers the fingerprint index over the IP index (see `correlate_actor`). The
    /// matched actor gains `ip` and, if accepted under the per-actor cap, `fingerprint`, and the
    /// indexes are pointed at it. If the index refers to an actor that no longer exists (e.g.
    /// evicted concurrently) a new actor is created instead of returning a dangling id. Empty
    /// fingerprints are treated as absent. When the manager is disabled, a fresh untracked id is
    /// returned on every call.
    ///
    /// Locking discipline: no `actors` guard is held while the index maps are written, and the
    /// index readers release their guard before reading `actors`, so two maps are never locked
    /// in opposite orders (which could deadlock).
    pub fn get_or_create_actor(&self, ip: IpAddr, fingerprint: Option<&str>) -> String {
        if !self.config.enabled {
            return generate_actor_id();
        }
        self.maybe_evict();
        let fingerprint = fingerprint.filter(|fp| !fp.is_empty());
        let max_fingerprints = self.config.max_fingerprints_per_actor;

        if let Some(actor_id) = self.correlate_actor(ip, fingerprint) {
            // `Some(accepted_fp)` when the actor still exists. The entry guard lives only inside
            // the closure, so it is released before the index maps are touched.
            let updated = self.actors.get_mut(&actor_id).map(|mut entry| {
                entry.add_ip(ip);
                fingerprint.filter(|fp| entry.add_fingerprint(fp.to_string(), max_fingerprints))
            });
            if let Some(accepted_fp) = updated {
                self.index_actor(&actor_id, ip, accepted_fp);
                return actor_id;
            }
            // Stale index entry: fall through and create a fresh actor, which also overwrites
            // the stale mappings below.
        }

        let actor_id = generate_actor_id();
        let mut actor = ActorState::new(actor_id.clone());
        actor.add_ip(ip);
        let accepted_fp =
            fingerprint.filter(|fp| actor.add_fingerprint(fp.to_string(), max_fingerprints));
        // Publish the actor before its index entries so a concurrent lookup never resolves to
        // an id that is not yet present in `actors`.
        self.actors.insert(actor_id.clone(), actor);
        self.index_actor(&actor_id, ip, accepted_fp);
        self.stats.total_actors.fetch_add(1, Ordering::Relaxed);
        self.stats.total_created.fetch_add(1, Ordering::Relaxed);
        actor_id
    }

    /// Points the IP index (and, when given, the fingerprint index) at `actor_id`.
    ///
    /// Must be called without holding any `actors` guard (see `get_or_create_actor`).
    fn index_actor(&self, actor_id: &str, ip: IpAddr, fingerprint: Option<&str>) {
        if let Some(fp) = fingerprint {
            self.maybe_evict_fingerprint_mappings();
            self.fingerprint_to_actor
                .insert(fp.to_string(), actor_id.to_string());
        }
        self.ip_to_actor.insert(ip, actor_id.to_string());
    }

    /// Records a rule match on `actor_id` and adds `risk_contribution` to its risk score.
    ///
    /// The score is kept within `[0, max_risk]`. A NaN contribution is treated as `0.0`:
    /// `f64::min` ignores NaN, so it would otherwise silently saturate the score at `max_risk`.
    /// No-op when the manager is disabled or the actor is unknown.
    pub fn record_rule_match(
        &self,
        actor_id: &str,
        rule_id: &str,
        risk_contribution: f64,
        category: &str,
    ) {
        if !self.config.enabled {
            return;
        }
        let Some(mut entry) = self.actors.get_mut(actor_id) else {
            return;
        };
        let contribution = if risk_contribution.is_nan() {
            0.0
        } else {
            risk_contribution
        };
        // `max`/`min` rather than `clamp`: `clamp` panics if the configured `max_risk` is NaN
        // or negative.
        entry.risk_score = (entry.risk_score + contribution)
            .max(0.0)
            .min(self.config.max_risk);
        let rule_match = RuleMatch::new(rule_id.to_string(), contribution, category.to_string());
        entry.add_rule_match(rule_match, self.config.max_rule_matches);
        self.stats
            .total_rule_matches
            .fetch_add(1, Ordering::Relaxed);
    }

    /// Refreshes the actor's `last_seen`. No-op when disabled or the actor is unknown.
    pub fn touch_actor(&self, actor_id: &str) {
        if !self.config.enabled {
            return;
        }
        if let Some(mut entry) = self.actors.get_mut(actor_id) {
            entry.touch();
        }
    }

    /// Returns a copy of the actor's state, if tracked.
    pub fn get_actor(&self, actor_id: &str) -> Option<ActorState> {
        self.actors.get(actor_id).map(|entry| entry.value().clone())
    }

    /// Returns a copy of the actor currently indexed under `ip`, if any.
    pub fn get_actor_by_ip(&self, ip: IpAddr) -> Option<ActorState> {
        // Copy the id out so the index guard is released before `actors` is read.
        let actor_id = self.ip_to_actor.get(&ip)?.value().clone();
        self.get_actor(&actor_id)
    }

    /// Returns a copy of the actor currently indexed under `fingerprint`, if any.
    pub fn get_actor_by_fingerprint(&self, fingerprint: &str) -> Option<ActorState> {
        // Copy the id out so the index guard is released before `actors` is read.
        let actor_id = self.fingerprint_to_actor.get(fingerprint)?.value().clone();
        self.get_actor(&actor_id)
    }

    /// Marks the actor blocked with `reason`. Returns `false` if the actor is unknown.
    /// Blocking an already-blocked actor keeps the original reason and timestamp.
    pub fn block_actor(&self, actor_id: &str, reason: &str) -> bool {
        if let Some(mut entry) = self.actors.get_mut(actor_id) {
            if !entry.is_blocked {
                entry.is_blocked = true;
                entry.block_reason = Some(reason.to_string());
                entry.blocked_since = Some(now_ms());
                self.stats.blocked_actors.fetch_add(1, Ordering::Relaxed);
            }
            true
        } else {
            false
        }
    }

    /// Clears the actor's block. Returns `false` if the actor is unknown.
    pub fn unblock_actor(&self, actor_id: &str) -> bool {
        if let Some(mut entry) = self.actors.get_mut(actor_id) {
            if entry.is_blocked {
                entry.is_blocked = false;
                entry.block_reason = None;
                entry.blocked_since = None;
                self.stats.blocked_actors.fetch_sub(1, Ordering::Relaxed);
            }
            true
        } else {
            false
        }
    }

    /// Whether the actor is tracked and blocked.
    pub fn is_blocked(&self, actor_id: &str) -> bool {
        self.actors
            .get(actor_id)
            .map(|entry| entry.is_blocked)
            .unwrap_or(false)
    }

    /// Associates `session_id` with the actor, keeping at most `max_session_ids` ids (oldest
    /// dropped first). Re-binding a known session is a no-op; with `max_session_ids == 0`
    /// nothing is stored.
    pub fn bind_session(&self, actor_id: &str, session_id: &str) {
        let max_sessions = self.config.max_session_ids;
        let Some(mut entry) = self.actors.get_mut(actor_id) else {
            return;
        };
        if entry.session_ids.iter().any(|s| s == session_id) {
            return;
        }
        if max_sessions == 0 {
            // Nothing may be retained (and `remove(0)` on an empty Vec would panic).
            return;
        }
        let len = entry.session_ids.len();
        if len >= max_sessions {
            // Leave room for exactly one more id, even if the list was over the cap.
            entry.session_ids.drain(..=len - max_sessions);
        }
        entry.session_ids.push(session_id.to_string());
        entry.touch();
    }

    /// Returns one page of actors ordered by most recent `last_seen` first.
    /// Clones every actor before paging.
    pub fn list_actors(&self, limit: usize, offset: usize) -> Vec<ActorState> {
        let mut actors: Vec<ActorState> = self
            .actors
            .iter()
            .map(|entry| entry.value().clone())
            .collect();
        actors.sort_by_key(|a| std::cmp::Reverse(a.last_seen));
        actors.into_iter().skip(offset).take(limit).collect()
    }

    /// Returns one page of actors with `risk_score >= min_risk`, highest risk first, ties
    /// broken by most recent `last_seen`.
    pub fn list_by_min_risk(&self, min_risk: f64, limit: usize, offset: usize) -> Vec<ActorState> {
        let mut actors: Vec<ActorState> = self
            .actors
            .iter()
            .filter(|entry| entry.value().risk_score >= min_risk)
            .map(|entry| entry.value().clone())
            .collect();
        actors.sort_by(|a, b| {
            b.risk_score
                .partial_cmp(&a.risk_score)
                .unwrap_or(CmpOrdering::Equal)
                .then_with(|| b.last_seen.cmp(&a.last_seen))
        });
        actors.into_iter().skip(offset).take(limit).collect()
    }

    /// Returns copies of all blocked actors (unordered).
    pub fn list_blocked_actors(&self) -> Vec<ActorState> {
        self.actors
            .iter()
            .filter(|entry| entry.is_blocked)
            .map(|entry| entry.value().clone())
            .collect()
    }

    /// Groups actors by shared fingerprint.
    ///
    /// Returns up to `limit` `(fingerprint, actor_ids, max_risk_score)` tuples, largest groups
    /// first, ties ordered by fingerprint so the output is deterministic. The full result is
    /// cached for one second.
    pub fn get_fingerprint_groups(&self, limit: usize) -> Vec<(String, Vec<String>, f64)> {
        {
            let cache = self.fingerprint_groups_cache.read();
            if let Some((computed_at, groups)) = &*cache {
                if computed_at.elapsed() < Duration::from_secs(1) {
                    return groups.iter().take(limit).cloned().collect();
                }
            }
        }
        use std::collections::HashMap;
        let mut groups: HashMap<String, (Vec<String>, f64)> = HashMap::new();
        for entry in self.actors.iter() {
            let actor = entry.value();
            for fp in &actor.fingerprints {
                let group = groups.entry(fp.clone()).or_default();
                group.0.push(actor.actor_id.clone());
                group.1 = group.1.max(actor.risk_score);
            }
        }
        let mut sorted_groups: Vec<_> = groups
            .into_iter()
            .map(|(fp, (actors, risk))| (fp, actors, risk))
            .collect();
        // Fingerprints are unique keys, so this is a total order.
        sorted_groups.sort_unstable_by(|a, b| b.1.len().cmp(&a.1.len()).then_with(|| a.0.cmp(&b.0)));
        *self.fingerprint_groups_cache.write() = Some((Instant::now(), sorted_groups.clone()));
        sorted_groups.truncate(limit);
        sorted_groups
    }

    /// Spawns the periodic decay/eviction task on the current tokio runtime.
    ///
    /// The first pass runs one full `decay_interval_secs` after start (clamped to at least one
    /// second, since tokio intervals panic on a zero period). The task stops when `shutdown()`
    /// is called, or on a tick where it holds the last remaining reference to the manager.
    pub fn start_background_tasks(self: Arc<Self>) {
        let manager = self;
        let decay_interval = Duration::from_secs(manager.config.decay_interval_secs.max(1));
        tokio::spawn(async move {
            // `tokio::time::interval` completes its first tick immediately, which would decay
            // every score right at startup; start one period out instead.
            let first_tick = tokio::time::Instant::now() + decay_interval;
            let mut interval = tokio::time::interval_at(first_tick, decay_interval);
            loop {
                tokio::select! {
                    _ = interval.tick() => {
                        // This task owns one strong reference; if it is the only one left,
                        // nobody else can use the manager, so stop.
                        if Arc::strong_count(&manager) == 1 {
                            break;
                        }
                        manager.decay_scores();
                        manager.evict_if_needed();
                    }
                    _ = manager.shutdown.notified() => {
                        log::info!("Actor manager background tasks shutting down");
                        break;
                    }
                }
            }
        });
    }

    /// Asks the background task to stop. `Notify::notify_one` stores a permit when no task is
    /// currently waiting, so a call made between ticks is still observed.
    pub fn shutdown(&self) {
        self.shutdown.notify_one();
    }

    /// Live counters for this manager.
    pub fn stats(&self) -> &ActorStats {
        &self.stats
    }

    /// Removes all actors and index entries, drops the fingerprint-group cache, and resets
    /// the current-actor and blocked counters (cumulative counters are kept).
    pub fn clear(&self) {
        self.actors.clear();
        self.ip_to_actor.clear();
        self.fingerprint_to_actor.clear();
        *self.fingerprint_groups_cache.write() = None;
        self.stats.total_actors.store(0, Ordering::Relaxed);
        self.stats.blocked_actors.store(0, Ordering::Relaxed);
    }

    /// Returns copies of every tracked actor (e.g. for persistence).
    pub fn snapshot(&self) -> Vec<ActorState> {
        self.actors.iter().map(|e| e.value().clone()).collect()
    }

    /// Replaces all state with `actors`, rebuilding the indexes and counters.
    ///
    /// Later entries win when ids, IPs or fingerprints repeat. Configured limits are not
    /// re-applied to the restored data.
    pub fn restore(&self, actors: Vec<ActorState>) {
        self.clear();
        for actor in actors {
            let actor_id = actor.actor_id.clone();
            for ip in &actor.ips {
                self.ip_to_actor.insert(*ip, actor_id.clone());
            }
            for fp in &actor.fingerprints {
                self.fingerprint_to_actor
                    .insert(fp.clone(), actor_id.clone());
            }
            self.actors.insert(actor_id, actor);
        }
        // Count from the map rather than the input so duplicate ids are not double-counted.
        let actor_count = self.actors.len() as u64;
        let blocked_count = self
            .actors
            .iter()
            .filter(|entry| entry.value().is_blocked)
            .count() as u64;
        self.stats
            .total_actors
            .store(actor_count, Ordering::Relaxed);
        self.stats
            .blocked_actors
            .store(blocked_count, Ordering::Relaxed);
        self.stats
            .total_created
            .store(actor_count, Ordering::Relaxed);
    }

    /// Multiplies every positive risk score by `risk_decay_factor`, flooring results below
    /// 0.01 to zero.
    fn decay_scores(&self) {
        let decay_factor = self.config.risk_decay_factor;
        for mut entry in self.actors.iter_mut() {
            let actor = entry.value_mut();
            if actor.risk_score > 0.0 {
                actor.risk_score *= decay_factor;
                if actor.risk_score < 0.01 {
                    actor.risk_score = 0.0;
                }
            }
        }
    }

    /// Number of actors to evict when the map holds `current_len` entries: everything above
    /// `max_actors` plus a batch of 1% of capacity (at least one). Subject to the sampling
    /// limit in `evict_oldest`, this brings the map back under the cap.
    fn eviction_count(&self, current_len: usize) -> usize {
        let batch = (self.config.max_actors / 100).max(1);
        current_len.saturating_sub(self.config.max_actors) + batch
    }

    /// Background-task eviction: runs only when the map is strictly over capacity.
    fn evict_if_needed(&self) {
        let current_len = self.actors.len();
        if current_len <= self.config.max_actors {
            return;
        }
        self.evict_oldest(self.eviction_count(current_len));
    }

    /// Hot-path eviction: checks capacity only on every 100th call (to amortize `len()`, which
    /// visits every shard) and evicts once the map is at or over capacity.
    fn maybe_evict(&self) {
        let call = self.touch_counter.fetch_add(1, Ordering::Relaxed);
        if !call.is_multiple_of(100) {
            return;
        }
        let current_len = self.actors.len();
        if current_len < self.config.max_actors {
            return;
        }
        self.evict_oldest(self.eviction_count(current_len));
    }

    /// Trims the fingerprint index to ~90% of `max_fingerprint_mappings` once it reaches the
    /// cap. Entries are taken in map iteration order (not by age); the actors keep those
    /// fingerprints in their own sets, only the index entry is dropped.
    fn maybe_evict_fingerprint_mappings(&self) {
        let current_len = self.fingerprint_to_actor.len();
        if current_len < self.config.max_fingerprint_mappings {
            return;
        }
        let target_len = (self.config.max_fingerprint_mappings * 9) / 10;
        let to_evict = current_len.saturating_sub(target_len);
        if to_evict == 0 {
            return;
        }
        let keys_to_evict: Vec<String> = self
            .fingerprint_to_actor
            .iter()
            .take(to_evict)
            .map(|entry| entry.key().clone())
            .collect();
        // Count only entries actually removed; another thread may have evicted some already.
        let mut removed: u64 = 0;
        for key in keys_to_evict {
            if self.fingerprint_to_actor.remove(&key).is_some() {
                removed += 1;
            }
        }
        self.stats
            .fingerprint_evictions
            .fetch_add(removed, Ordering::Relaxed);
    }

    /// Evicts up to `count` actors with the oldest `last_seen` among a sample of at most
    /// `min(10 * count, 1000)` actors taken in map iteration order (approximate LRU).
    fn evict_oldest(&self, count: usize) {
        let sample_size = count
            .saturating_mul(10)
            .min(1000)
            .min(self.actors.len());
        if sample_size == 0 {
            return;
        }
        let mut candidates: Vec<(String, u64)> = self
            .actors
            .iter()
            .take(sample_size)
            .map(|entry| (entry.key().clone(), entry.value().last_seen))
            .collect();
        candidates.sort_unstable_by_key(|&(_, last_seen)| last_seen);
        for (actor_id, _) in candidates.into_iter().take(count) {
            // Skip actors another thread removed since the sample was taken.
            if self.remove_actor(&actor_id) {
                self.stats.evictions.fetch_add(1, Ordering::Relaxed);
            }
        }
    }

    /// Removes an actor and any index entries that still point at it; returns whether the
    /// actor was present.
    fn remove_actor(&self, actor_id: &str) -> bool {
        let Some((_, actor)) = self.actors.remove(actor_id) else {
            return false;
        };
        // An IP or fingerprint may since have been re-bound to a different actor; only drop
        // the mapping if it is still ours.
        for ip in &actor.ips {
            self.ip_to_actor.remove_if(ip, |_, owner| owner == actor_id);
        }
        for fp in &actor.fingerprints {
            self.fingerprint_to_actor
                .remove_if(fp, |_, owner| owner == actor_id);
        }
        self.stats.total_actors.fetch_sub(1, Ordering::Relaxed);
        if actor.is_blocked {
            self.stats.blocked_actors.fetch_sub(1, Ordering::Relaxed);
        }
        true
    }

    /// Finds a candidate actor id via the indexes. When the IP and fingerprint resolve to
    /// different actors, the fingerprint wins. The returned id may be stale (its actor already
    /// removed); callers must check.
    fn correlate_actor(&self, ip: IpAddr, fingerprint: Option<&str>) -> Option<String> {
        let ip_actor = self.ip_to_actor.get(&ip).map(|r| r.value().clone());
        let fp_actor = fingerprint.and_then(|fp| {
            if fp.is_empty() {
                None
            } else {
                self.fingerprint_to_actor.get(fp).map(|r| r.value().clone())
            }
        });
        match (ip_actor, fp_actor) {
            (Some(ip_id), Some(fp_id)) => {
                if ip_id == fp_id {
                    Some(ip_id)
                } else {
                    self.stats.correlations_made.fetch_add(1, Ordering::Relaxed);
                    Some(fp_id)
                }
            }
            (Some(id), None) => {
                self.stats.correlations_made.fetch_add(1, Ordering::Relaxed);
                Some(id)
            }
            (None, Some(id)) => {
                self.stats.correlations_made.fetch_add(1, Ordering::Relaxed);
                Some(id)
            }
            (None, None) => None,
        }
    }
}
impl Default for ActorManager {
fn default() -> Self {
Self::new(ActorConfig::default())
}
}
/// Generates a random UUID-v4-formatted id: lowercase hex
/// `xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx`, where `y` is one of `8`, `9`, `a`, `b`.
///
/// # Panics
/// Panics if the OS random source fails.
fn generate_actor_id() -> String {
    let mut raw = [0u8; 16];
    getrandom::getrandom(&mut raw).expect("Failed to get random bytes");
    // Stamp the version (0100) and variant (10xx) bits.
    raw[6] = (raw[6] & 0x0F) | 0x40;
    raw[8] = (raw[8] & 0x3F) | 0x80;
    let hex: String = raw.iter().map(|byte| format!("{:02x}", byte)).collect();
    format!(
        "{}-{}-{}-{}-{}",
        &hex[..8],
        &hex[8..12],
        &hex[12..16],
        &hex[16..20],
        &hex[20..]
    )
}
/// Current wall-clock time in milliseconds since the Unix epoch, or `0` if the system clock
/// reads earlier than the epoch.
#[inline]
fn now_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
// Unit tests for actor creation, correlation, risk accounting, blocking, eviction and
// concurrent use of `ActorManager`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    // Manager with a small capacity so tests stay cheap.
    fn create_test_manager() -> ActorManager {
        ActorManager::new(ActorConfig {
            max_actors: 1000,
            ..Default::default()
        })
    }
    // Builds 192.168.1.<last_octet>.
    fn create_test_ip(last_octet: u8) -> IpAddr {
        format!("192.168.1.{}", last_octet).parse().unwrap()
    }
    #[test]
    fn test_actor_creation() {
        let manager = create_test_manager();
        let ip = create_test_ip(1);
        let actor_id = manager.get_or_create_actor(ip, None);
        assert!(!actor_id.is_empty());
        assert_eq!(manager.len(), 1);
        let actor = manager.get_actor(&actor_id).unwrap();
        assert_eq!(actor.actor_id, actor_id);
        assert!(actor.ips.contains(&ip));
        assert!(!actor.is_blocked);
    }
    #[test]
    fn test_actor_retrieval_by_ip() {
        let manager = create_test_manager();
        let ip = create_test_ip(1);
        let actor_id = manager.get_or_create_actor(ip, None);
        let retrieved = manager.get_actor_by_ip(ip).unwrap();
        assert_eq!(retrieved.actor_id, actor_id);
    }
    #[test]
    fn test_actor_retrieval_by_fingerprint() {
        let manager = create_test_manager();
        let ip = create_test_ip(1);
        let fingerprint = "t13d1516h2_abc123";
        let actor_id = manager.get_or_create_actor(ip, Some(fingerprint));
        let retrieved = manager.get_actor_by_fingerprint(fingerprint).unwrap();
        assert_eq!(retrieved.actor_id, actor_id);
    }
    #[test]
    fn test_actor_nonexistent() {
        let manager = create_test_manager();
        assert!(manager.get_actor("nonexistent").is_none());
        assert!(manager.get_actor_by_ip(create_test_ip(99)).is_none());
        assert!(manager.get_actor_by_fingerprint("nonexistent").is_none());
    }
    #[test]
    fn test_ip_correlation() {
        let manager = create_test_manager();
        let ip = create_test_ip(1);
        let actor_id1 = manager.get_or_create_actor(ip, None);
        let actor_id2 = manager.get_or_create_actor(ip, None);
        assert_eq!(actor_id1, actor_id2);
        assert_eq!(manager.len(), 1);
    }
    #[test]
    fn test_fingerprint_correlation() {
        let manager = create_test_manager();
        let ip1 = create_test_ip(1);
        let ip2 = create_test_ip(2);
        let fingerprint = "t13d1516h2_shared";
        // Same fingerprint from two IPs must resolve to one actor holding both IPs.
        let actor_id1 = manager.get_or_create_actor(ip1, Some(fingerprint));
        let actor_id2 = manager.get_or_create_actor(ip2, Some(fingerprint));
        assert_eq!(actor_id1, actor_id2);
        assert_eq!(manager.len(), 1);
        let actor = manager.get_actor(&actor_id1).unwrap();
        assert!(actor.ips.contains(&ip1));
        assert!(actor.ips.contains(&ip2));
    }
    #[test]
    fn test_fingerprint_preferred_over_ip() {
        let manager = create_test_manager();
        let ip1 = create_test_ip(1);
        let ip2 = create_test_ip(2);
        let fp1 = "fingerprint_1";
        let fp2 = "fingerprint_2";
        let actor_id1 = manager.get_or_create_actor(ip1, Some(fp1));
        let actor_id2 = manager.get_or_create_actor(ip2, Some(fp2));
        assert_ne!(actor_id1, actor_id2);
        // ip1 is indexed to actor 1 but fp2 belongs to actor 2: the fingerprint match wins.
        let actor_id3 = manager.get_or_create_actor(ip1, Some(fp2));
        assert_eq!(actor_id3, actor_id2);
    }
    #[test]
    fn test_record_rule_match() {
        let manager = create_test_manager();
        let ip = create_test_ip(1);
        let actor_id = manager.get_or_create_actor(ip, None);
        manager.record_rule_match(&actor_id, "sqli-001", 25.0, "sqli");
        let actor = manager.get_actor(&actor_id).unwrap();
        assert_eq!(actor.rule_matches.len(), 1);
        assert_eq!(actor.rule_matches[0].rule_id, "sqli-001");
        assert_eq!(actor.rule_matches[0].risk_contribution, 25.0);
        assert_eq!(actor.rule_matches[0].category, "sqli");
        assert_eq!(actor.risk_score, 25.0);
    }
    #[test]
    fn test_rule_match_risk_accumulation() {
        let manager = create_test_manager();
        let ip = create_test_ip(1);
        let actor_id = manager.get_or_create_actor(ip, None);
        manager.record_rule_match(&actor_id, "sqli-001", 25.0, "sqli");
        manager.record_rule_match(&actor_id, "xss-001", 20.0, "xss");
        manager.record_rule_match(&actor_id, "sqli-002", 30.0, "sqli");
        let actor = manager.get_actor(&actor_id).unwrap();
        assert_eq!(actor.rule_matches.len(), 3);
        assert_eq!(actor.risk_score, 75.0);
    }
    #[test]
    fn test_rule_match_risk_capped() {
        let manager = create_test_manager();
        let ip = create_test_ip(1);
        let actor_id = manager.get_or_create_actor(ip, None);
        // 15 * 10.0 would exceed the default max_risk of 100.0.
        for _ in 0..15 {
            manager.record_rule_match(&actor_id, "sqli-001", 10.0, "sqli");
        }
        let actor = manager.get_actor(&actor_id).unwrap();
        assert!(actor.risk_score <= 100.0);
    }
    #[test]
    fn test_rule_match_history_limit() {
        let config = ActorConfig {
            max_rule_matches: 5,
            ..Default::default()
        };
        let manager = ActorManager::new(config);
        let ip = create_test_ip(1);
        let actor_id = manager.get_or_create_actor(ip, None);
        for i in 0..10 {
            manager.record_rule_match(&actor_id, &format!("rule-{}", i), 5.0, "test");
        }
        // Only the newest five matches (rule-5..rule-9) are retained, oldest first.
        let actor = manager.get_actor(&actor_id).unwrap();
        assert_eq!(actor.rule_matches.len(), 5);
        assert_eq!(actor.rule_matches[0].rule_id, "rule-5");
        assert_eq!(actor.rule_matches[4].rule_id, "rule-9");
    }
    #[test]
    fn test_block_actor() {
        let manager = create_test_manager();
        let ip = create_test_ip(1);
        let actor_id = manager.get_or_create_actor(ip, None);
        assert!(!manager.is_blocked(&actor_id));
        let result = manager.block_actor(&actor_id, "High risk score");
        assert!(result);
        assert!(manager.is_blocked(&actor_id));
        let actor = manager.get_actor(&actor_id).unwrap();
        assert!(actor.is_blocked);
        assert_eq!(actor.block_reason, Some("High risk score".to_string()));
        assert!(actor.blocked_since.is_some());
    }
    #[test]
    fn test_unblock_actor() {
        let manager = create_test_manager();
        let ip = create_test_ip(1);
        let actor_id = manager.get_or_create_actor(ip, None);
        manager.block_actor(&actor_id, "Test");
        assert!(manager.is_blocked(&actor_id));
        let result = manager.unblock_actor(&actor_id);
        assert!(result);
        assert!(!manager.is_blocked(&actor_id));
        let actor = manager.get_actor(&actor_id).unwrap();
        assert!(!actor.is_blocked);
        assert!(actor.block_reason.is_none());
        assert!(actor.blocked_since.is_none());
    }
    #[test]
    fn test_block_nonexistent() {
        let manager = create_test_manager();
        assert!(!manager.block_actor("nonexistent", "Test"));
        assert!(!manager.unblock_actor("nonexistent"));
        assert!(!manager.is_blocked("nonexistent"));
    }
    #[test]
    fn test_lru_eviction() {
        let config = ActorConfig {
            max_actors: 100,
            ..Default::default()
        };
        let manager = ActorManager::new(config);
        for i in 0..150 {
            let ip = format!("10.0.{}.{}", i / 256, i % 256).parse().unwrap();
            manager.get_or_create_actor(ip, None);
        }
        // NOTE(review): this bound is trivially true after 150 creations; the meaningful
        // checks are the eviction counters below.
        assert!(manager.len() <= 150);
        for i in 0..200 {
            let ip = format!("10.1.{}.{}", i / 256, i % 256).parse().unwrap();
            manager.get_or_create_actor(ip, None);
        }
        let final_len = manager.len();
        let evictions = manager.stats().evictions.load(Ordering::Relaxed);
        assert!(evictions > 0, "Expected evictions to occur, got 0");
        let created = manager.stats().total_created.load(Ordering::Relaxed);
        assert!(
            created > final_len as u64,
            "Expected some actors to be evicted"
        );
        println!(
            "LRU eviction test: created={}, evicted={}, final_len={}",
            created, evictions, final_len
        );
    }
    #[test]
    fn test_eviction_removes_mappings() {
        let config = ActorConfig {
            max_actors: 10,
            ..Default::default()
        };
        let manager = ActorManager::new(config);
        let first_ip = create_test_ip(1);
        let first_fingerprint = "first_fp";
        let first_actor_id = manager.get_or_create_actor(first_ip, Some(first_fingerprint));
        // Make the first actor strictly older than everything created afterwards.
        std::thread::sleep(std::time::Duration::from_millis(10));
        for i in 10..200 {
            let ip = format!("10.0.{}.{}", i / 256, i % 256).parse().unwrap();
            manager.get_or_create_actor(ip, Some(&format!("fp_{}", i)));
        }
        // Eviction samples rather than scanning every actor, so removal of the first actor
        // is not guaranteed; the mappings are only checked when it was evicted.
        if manager.get_actor(&first_actor_id).is_none() {
            assert!(manager.ip_to_actor.get(&first_ip).is_none());
            assert!(manager
                .fingerprint_to_actor
                .get(first_fingerprint)
                .is_none());
        }
    }
    #[test]
    fn test_decay_scores() {
        let config = ActorConfig {
            risk_decay_factor: 0.5,
            ..Default::default()
        };
        let manager = ActorManager::new(config);
        let ip = create_test_ip(1);
        let actor_id = manager.get_or_create_actor(ip, None);
        manager.record_rule_match(&actor_id, "test", 100.0, "test");
        let actor = manager.get_actor(&actor_id).unwrap();
        assert_eq!(actor.risk_score, 100.0);
        manager.decay_scores();
        let actor = manager.get_actor(&actor_id).unwrap();
        assert_eq!(actor.risk_score, 50.0);
        manager.decay_scores();
        let actor = manager.get_actor(&actor_id).unwrap();
        assert_eq!(actor.risk_score, 25.0);
    }
    #[test]
    fn test_decay_floors_to_zero() {
        let config = ActorConfig {
            risk_decay_factor: 0.001,
            ..Default::default()
        };
        let manager = ActorManager::new(config);
        let ip = create_test_ip(1);
        let actor_id = manager.get_or_create_actor(ip, None);
        manager.record_rule_match(&actor_id, "test", 1.0, "test");
        // 1.0 * 0.001 falls below the 0.01 floor on the first pass.
        for _ in 0..5 {
            manager.decay_scores();
        }
        let actor = manager.get_actor(&actor_id).unwrap();
        assert_eq!(actor.risk_score, 0.0);
    }
    #[test]
    fn test_bind_session() {
        let manager = create_test_manager();
        let ip = create_test_ip(1);
        let actor_id = manager.get_or_create_actor(ip, None);
        manager.bind_session(&actor_id, "session-123");
        manager.bind_session(&actor_id, "session-456");
        // Re-binding an existing session must not duplicate it.
        manager.bind_session(&actor_id, "session-123");
        let actor = manager.get_actor(&actor_id).unwrap();
        assert_eq!(actor.session_ids.len(), 2);
        assert!(actor.session_ids.contains(&"session-123".to_string()));
        assert!(actor.session_ids.contains(&"session-456".to_string()));
    }
    #[test]
    fn test_list_actors() {
        let manager = create_test_manager();
        for i in 0..10 {
            let ip = create_test_ip(i);
            manager.get_or_create_actor(ip, None);
            // Give each actor a distinct millisecond `last_seen`.
            std::thread::sleep(std::time::Duration::from_millis(1));
        }
        let first_page = manager.list_actors(5, 0);
        assert_eq!(first_page.len(), 5);
        let second_page = manager.list_actors(5, 5);
        assert_eq!(second_page.len(), 5);
        for window in first_page.windows(2) {
            assert!(window[0].last_seen >= window[1].last_seen);
        }
    }
    #[test]
    fn test_list_blocked_actors() {
        let manager = create_test_manager();
        for i in 0..10 {
            let ip = create_test_ip(i);
            let actor_id = manager.get_or_create_actor(ip, None);
            if i % 2 == 0 {
                manager.block_actor(&actor_id, "Test");
            }
        }
        let blocked = manager.list_blocked_actors();
        assert_eq!(blocked.len(), 5);
        for actor in blocked {
            assert!(actor.is_blocked);
        }
    }
    #[test]
    fn test_concurrent_access() {
        let manager = Arc::new(create_test_manager());
        let mut handles = vec![];
        for thread_id in 0..10 {
            let manager = Arc::clone(&manager);
            handles.push(thread::spawn(move || {
                for i in 0..100 {
                    let ip: IpAddr = format!("10.{}.0.{}", thread_id, i % 256).parse().unwrap();
                    let fingerprint = format!("fp_t{}_{}", thread_id, i % 5);
                    let actor_id = manager.get_or_create_actor(ip, Some(&fingerprint));
                    manager.record_rule_match(&actor_id, "test", 1.0, "test");
                }
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
        assert!(manager.len() > 0);
        assert!(manager.stats().total_created.load(Ordering::Relaxed) > 0);
    }
    #[test]
    fn test_stress_concurrent_updates() {
        let manager = Arc::new(ActorManager::new(ActorConfig {
            max_actors: 10_000,
            max_fingerprint_mappings: 50_000,
            ..Default::default()
        }));
        let mut handles = vec![];
        for thread_id in 0..16 {
            let manager = Arc::clone(&manager);
            handles.push(thread::spawn(move || {
                for i in 0..500 {
                    let ip: IpAddr = format!("10.{}.{}.{}", thread_id, i / 256, i % 256)
                        .parse()
                        .unwrap();
                    let fingerprint = format!("fp_t{}_{}", thread_id, i % 20);
                    let actor_id = manager.get_or_create_actor(ip, Some(&fingerprint));
                    manager.record_rule_match(&actor_id, "stress", 0.5, "stress");
                    if i % 5 == 0 {
                        manager.touch_actor(&actor_id);
                    }
                    if i % 200 == 0 {
                        manager.block_actor(&actor_id, "stress");
                    }
                }
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
        let stats = manager.stats();
        assert!(manager.len() > 0);
        assert!(stats.total_created.load(Ordering::Relaxed) > 0);
        assert!(stats.total_rule_matches.load(Ordering::Relaxed) > 0);
    }
    #[test]
    fn test_stats() {
        let manager = create_test_manager();
        let stats = manager.stats().snapshot();
        assert_eq!(stats.total_actors, 0);
        assert_eq!(stats.blocked_actors, 0);
        assert_eq!(stats.total_created, 0);
        for i in 0..5 {
            let ip = create_test_ip(i);
            let actor_id = manager.get_or_create_actor(ip, None);
            manager.record_rule_match(&actor_id, "test", 10.0, "test");
        }
        let actor = manager.list_actors(1, 0)[0].clone();
        manager.block_actor(&actor.actor_id, "Test");
        let stats = manager.stats().snapshot();
        assert_eq!(stats.total_actors, 5);
        assert_eq!(stats.blocked_actors, 1);
        assert_eq!(stats.total_created, 5);
        assert_eq!(stats.total_rule_matches, 5);
    }
    #[test]
    fn test_clear() {
        let manager = create_test_manager();
        for i in 0..10 {
            let ip = create_test_ip(i);
            let actor_id = manager.get_or_create_actor(ip, Some(&format!("fp_{}", i)));
            manager.block_actor(&actor_id, "Test");
        }
        assert_eq!(manager.len(), 10);
        manager.clear();
        assert_eq!(manager.len(), 0);
        assert!(manager.ip_to_actor.is_empty());
        assert!(manager.fingerprint_to_actor.is_empty());
        assert_eq!(manager.stats().total_actors.load(Ordering::Relaxed), 0);
        assert_eq!(manager.stats().blocked_actors.load(Ordering::Relaxed), 0);
    }
    #[test]
    fn test_default() {
        let manager = ActorManager::default();
        assert!(manager.is_enabled());
        assert!(manager.is_empty());
        assert_eq!(manager.config().max_actors, 100_000);
    }
    #[test]
    fn test_actor_id_uniqueness() {
        let mut ids = HashSet::new();
        for _ in 0..1000 {
            let id = generate_actor_id();
            assert!(!ids.contains(&id), "Duplicate ID generated: {}", id);
            ids.insert(id);
        }
    }
    #[test]
    fn test_actor_id_format() {
        let id = generate_actor_id();
        assert_eq!(id.len(), 36);
        assert_eq!(id.chars().nth(8), Some('-'));
        assert_eq!(id.chars().nth(13), Some('-'));
        // Position 14 is the version nibble, forced to 4 by generate_actor_id.
        assert_eq!(id.chars().nth(14), Some('4'));
        assert_eq!(id.chars().nth(18), Some('-'));
        assert_eq!(id.chars().nth(23), Some('-'));
    }
    #[test]
    fn test_empty_fingerprint() {
        let manager = create_test_manager();
        let ip = create_test_ip(1);
        // An empty fingerprint must be ignored, not indexed.
        let actor_id = manager.get_or_create_actor(ip, Some(""));
        let actor = manager.get_actor(&actor_id).unwrap();
        assert!(actor.fingerprints.is_empty());
        assert!(manager.fingerprint_to_actor.is_empty());
    }
    #[test]
    fn test_ipv6_addresses() {
        let manager = create_test_manager();
        let ipv6_1: IpAddr = "2001:db8::1".parse().unwrap();
        let ipv6_2: IpAddr = "2001:db8::2".parse().unwrap();
        let actor_id1 = manager.get_or_create_actor(ipv6_1, Some("ipv6_fp"));
        let actor_id2 = manager.get_or_create_actor(ipv6_2, Some("ipv6_fp"));
        assert_eq!(actor_id1, actor_id2);
        let actor = manager.get_actor(&actor_id1).unwrap();
        assert!(actor.ips.contains(&ipv6_1));
        assert!(actor.ips.contains(&ipv6_2));
    }
    #[test]
    fn test_disabled_manager() {
        let config = ActorConfig {
            enabled: false,
            ..Default::default()
        };
        let manager = ActorManager::new(config);
        assert!(!manager.is_enabled());
        let ip = create_test_ip(1);
        // A disabled manager still hands out an id but tracks nothing.
        let actor_id = manager.get_or_create_actor(ip, None);
        assert!(!actor_id.is_empty());
        assert!(manager.is_empty());
        manager.record_rule_match(&actor_id, "test", 10.0, "test");
        assert!(manager.get_actor(&actor_id).is_none());
    }
}