1use std::collections::{HashMap, VecDeque};
15use std::time::Duration;
16
17use serde::{Deserialize, Serialize};
18use serde_json::Value;
19use uuid::Uuid;
20
21use super::incident::{IncludeMode, RiskEntityView, RiskIncidentResult, RiskRef};
22use super::snapshot::{EntitySnapshot, RiskStateSnapshot, SNAPSHOT_VERSION};
23
/// Default for [`RiskCaps::max_open_entities`]: how many entities may be tracked at once.
pub const DEFAULT_MAX_OPEN_ENTITIES: usize = 100_000;
/// Default for [`RiskCaps::max_sources_per_entity`]: how many distinct source rules are listed per entity.
pub const DEFAULT_MAX_SOURCES_PER_ENTITY: usize = 1_000;
/// Default for [`RiskCaps::max_results_per_incident`]: how many contributions are kept per entity.
pub const DEFAULT_MAX_RESULTS_PER_INCIDENT: usize = 1_000;
33
/// Size limits applied while tracking per-entity risk.
#[derive(Debug, Clone, Copy)]
pub struct RiskCaps {
    /// Maximum number of entities tracked at once. When this many entities are
    /// already open, a contribution for a not-yet-tracked entity is rejected.
    pub max_open_entities: usize,
    /// Maximum number of distinct source rule names listed in the aggregated
    /// stats for one entity. The reported source *count* is taken before this
    /// truncation and is therefore not capped.
    pub max_sources_per_entity: usize,
    /// Maximum number of contributions retained per entity; the oldest are
    /// dropped first. It also bounds how many refs/results one incident carries.
    pub max_results_per_incident: usize,
}
44
45impl Default for RiskCaps {
46 fn default() -> Self {
47 RiskCaps {
48 max_open_entities: DEFAULT_MAX_OPEN_ENTITIES,
49 max_sources_per_entity: DEFAULT_MAX_SOURCES_PER_ENTITY,
50 max_results_per_incident: DEFAULT_MAX_RESULTS_PER_INCIDENT,
51 }
52 }
53}
54
/// Configuration controlling when an entity's accumulated risk becomes an incident.
#[derive(Debug, Clone)]
pub struct IncidentConfig {
    /// Length of the sliding window. Contributions whose timestamp is at or
    /// before `now - window` (whole seconds) are pruned.
    pub window: Duration,
    /// Fire when the summed score in the window is at least this value.
    /// `None` disables the score trigger.
    pub score_threshold: Option<i64>,
    /// Fire when the number of distinct tactics in the window is at least this
    /// value. `None` disables the tactic-count trigger. Only consulted when the
    /// score trigger did not match.
    pub tactic_count_threshold: Option<u64>,
    /// Minimum time (whole seconds) after an incident fired for an entity
    /// before another incident may fire for the same entity.
    pub cooldown: Duration,
    /// Selects whether an incident carries compact refs or the full results
    /// of its contributions.
    pub include: IncludeMode,
    /// Optional NATS subject. Not read by the code shown here.
    // NOTE(review): presumably the subject incidents are published to — confirm against the caller.
    pub nats_subject: Option<String>,
    /// Size limits applied while tracking.
    pub caps: RiskCaps,
}
73
74impl IncidentConfig {
75 fn window_secs(&self) -> i64 {
76 self.window.as_secs() as i64
77 }
78}
79
/// One scored event attributed to an entity.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Contribution {
    /// Timestamp of the event, in the same unit as the `now` values passed to
    /// the state (seconds, since it is compared against `now - window_secs`).
    // NOTE(review): presumably Unix epoch seconds — confirm against the caller.
    pub ts: i64,
    /// Score added to the entity's window total.
    pub score: i64,
    /// Tactic names attached to the event; distinct values across the window
    /// make up the tactic count.
    pub tactics: Vec<String>,
    /// Name of the rule that produced the event; distinct values across the
    /// window make up the source list.
    pub rule: String,
    /// Optional level label, copied into incident refs. Omitted from the
    /// serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub level: Option<String>,
    /// Optional JSON payload of the event, included in incidents when the
    /// include mode is `Results`. Omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub result: Option<Value>,
}
98
/// Per-entity tracking state: the retained contributions plus firing metadata.
#[derive(Debug, Default)]
struct EntityWindow {
    /// Retained contributions in arrival order: new ones are pushed to the
    /// back, expired or over-cap ones are popped from the front.
    contributions: VecDeque<Contribution>,
    /// `now` of the most recent incident fired for this entity, if any.
    last_fired: Option<i64>,
    /// `now` of the most recent recorded contribution (or the value carried
    /// over from a restored snapshot).
    last_seen: i64,
}
106
107impl EntityWindow {
108 fn prune(&mut self, cutoff: i64) {
110 while let Some(front) = self.contributions.front() {
111 if front.ts <= cutoff {
112 self.contributions.pop_front();
113 } else {
114 break;
115 }
116 }
117 }
118
119 fn is_empty(&self) -> bool {
120 self.contributions.is_empty()
121 }
122}
123
/// Aggregates computed over a set of contributions by `window_stats`.
struct WindowStats {
    /// Sum of the contributions' scores.
    score: i64,
    /// Distinct tactic names, in first-seen order.
    tactics: Vec<String>,
    /// Distinct rule names, in first-seen order, truncated to the source cap.
    sources: Vec<String>,
    /// Number of distinct rule names before truncation.
    source_count: u64,
    /// Number of contributions aggregated.
    result_count: u64,
    /// Smallest contribution timestamp, or `0` when there were none.
    window_start: i64,
    /// Largest contribution timestamp, or `0` when there were none.
    window_end: i64,
}
134
135fn window_stats<'a>(
138 contributions: impl Iterator<Item = &'a Contribution>,
139 max_sources: usize,
140) -> WindowStats {
141 let mut score: i64 = 0;
142 let mut tactics: Vec<String> = Vec::new();
143 let mut sources: Vec<String> = Vec::new();
144 let mut result_count: u64 = 0;
145 let mut window_start = i64::MAX;
146 let mut window_end = i64::MIN;
147 for c in contributions {
148 score += c.score;
149 result_count += 1;
150 window_start = window_start.min(c.ts);
151 window_end = window_end.max(c.ts);
152 for t in &c.tactics {
153 if !tactics.contains(t) {
154 tactics.push(t.clone());
155 }
156 }
157 if !sources.contains(&c.rule) {
158 sources.push(c.rule.clone());
159 }
160 }
161 let source_count = sources.len() as u64;
162 if sources.len() > max_sources {
163 sources.truncate(max_sources);
164 }
165 if result_count == 0 {
166 window_start = 0;
167 window_end = 0;
168 }
169 WindowStats {
170 score,
171 tactics,
172 sources,
173 source_count,
174 result_count,
175 window_start,
176 window_end,
177 }
178}
179
/// In-memory risk tracking state for all currently open entities.
#[derive(Debug, Default)]
pub struct RiskState {
    /// Per-entity windows keyed by `(entity_type, entity_value)`.
    entities: HashMap<(String, String), EntityWindow>,
}
186
/// Result of recording a single contribution.
pub struct RecordOutcome {
    /// The incident fired by this contribution, if a threshold was met and the
    /// entity was not in cooldown.
    pub incident: Option<RiskIncidentResult>,
    /// `true` when the contribution was dropped because its entity was not yet
    /// tracked and the open-entity cap had been reached.
    pub evicted: bool,
}
194
195impl RiskState {
196 pub fn len(&self) -> usize {
198 self.entities.len()
199 }
200
201 pub fn is_empty(&self) -> bool {
203 self.entities.is_empty()
204 }
205
206 pub fn total_entries(&self) -> usize {
208 self.entities.values().map(|e| e.contributions.len()).sum()
209 }
210
211 pub fn record(
214 &mut self,
215 cfg: &IncidentConfig,
216 entity_type: &str,
217 entity_value: &str,
218 contribution: Contribution,
219 now: i64,
220 ) -> RecordOutcome {
221 let key = (entity_type.to_string(), entity_value.to_string());
222 let cutoff = now - cfg.window_secs();
223
224 if !self.entities.contains_key(&key) && self.entities.len() >= cfg.caps.max_open_entities {
225 return RecordOutcome {
229 incident: None,
230 evicted: true,
231 };
232 }
233
234 let entity = self.entities.entry(key.clone()).or_default();
235 entity.prune(cutoff);
236 entity.last_seen = now;
237 entity.contributions.push_back(contribution);
238 while entity.contributions.len() > cfg.caps.max_results_per_incident {
239 entity.contributions.pop_front();
240 }
241
242 let stats = window_stats(entity.contributions.iter(), cfg.caps.max_sources_per_entity);
243 let tactic_count = stats.tactics.len() as u64;
244
245 let trigger = if cfg.score_threshold.is_some_and(|t| stats.score >= t) {
246 Some("score")
247 } else if cfg
248 .tactic_count_threshold
249 .is_some_and(|t| tactic_count >= t)
250 {
251 Some("tactic_count")
252 } else {
253 None
254 };
255
256 let incident = trigger.and_then(|trigger| {
257 let cooling = entity
258 .last_fired
259 .is_some_and(|lf| now - lf < cfg.cooldown.as_secs() as i64);
260 if cooling {
261 return None;
262 }
263 entity.last_fired = Some(now);
264 Some(build_incident(
265 cfg,
266 entity_type,
267 entity_value,
268 trigger,
269 tactic_count,
270 &stats,
271 entity.contributions.iter(),
272 ))
273 });
274
275 RecordOutcome {
276 incident,
277 evicted: false,
278 }
279 }
280
281 pub fn tick(&mut self, cfg: &IncidentConfig, now: i64) -> usize {
284 let cutoff = now - cfg.window_secs();
285 let before = self.entities.len();
286 self.entities.retain(|_, entity| {
287 entity.prune(cutoff);
288 !entity.is_empty()
289 });
290 before - self.entities.len()
291 }
292
293 pub fn snapshot(&self) -> RiskStateSnapshot {
295 let entities = self
296 .entities
297 .iter()
298 .map(|((entity_type, entity_value), window)| EntitySnapshot {
299 entity_type: entity_type.clone(),
300 entity_value: entity_value.clone(),
301 last_fired: window.last_fired,
302 last_seen: window.last_seen,
303 contributions: window.contributions.iter().cloned().collect(),
304 })
305 .collect();
306 RiskStateSnapshot {
307 version: SNAPSHOT_VERSION,
308 entities,
309 }
310 }
311
312 pub fn restore(
317 &mut self,
318 snap: RiskStateSnapshot,
319 window_secs: i64,
320 max_open_entities: usize,
321 now: i64,
322 ) -> bool {
323 if snap.version != SNAPSHOT_VERSION {
324 return false;
325 }
326 let cutoff = now - window_secs;
327 for entity in snap.entities {
328 if self.entities.len() >= max_open_entities {
329 break;
330 }
331 let contributions: VecDeque<Contribution> = entity
332 .contributions
333 .into_iter()
334 .filter(|c| c.ts > cutoff)
335 .collect();
336 if contributions.is_empty() {
337 continue;
338 }
339 self.entities.insert(
340 (entity.entity_type, entity.entity_value),
341 EntityWindow {
342 contributions,
343 last_fired: entity.last_fired,
344 last_seen: entity.last_seen,
345 },
346 );
347 }
348 true
349 }
350
351 pub fn views(&self, cfg: &IncidentConfig, now: i64) -> Vec<RiskEntityView> {
353 let cutoff = now - cfg.window_secs();
354 let mut out = Vec::new();
355 for ((entity_type, entity_value), entity) in &self.entities {
356 let live = entity.contributions.iter().filter(|c| c.ts > cutoff);
357 let stats = window_stats(live, cfg.caps.max_sources_per_entity);
358 if stats.result_count == 0 {
359 continue;
360 }
361 out.push(RiskEntityView {
362 entity_type: entity_type.clone(),
363 entity_value: entity_value.clone(),
364 score: stats.score,
365 tactic_count: stats.tactics.len() as u64,
366 source_count: stats.source_count,
367 result_count: stats.result_count,
368 window_start: stats.window_start,
369 window_end: stats.window_end,
370 last_fired: entity.last_fired,
371 });
372 }
373 out
374 }
375}
376
377fn build_incident<'a>(
379 cfg: &IncidentConfig,
380 entity_type: &str,
381 entity_value: &str,
382 trigger: &'static str,
383 tactic_count: u64,
384 stats: &WindowStats,
385 contributions: impl Iterator<Item = &'a Contribution>,
386) -> RiskIncidentResult {
387 let recent: Vec<&Contribution> = {
388 let all: Vec<&Contribution> = contributions.collect();
389 let take = cfg.caps.max_results_per_incident.min(all.len());
390 all[all.len() - take..].to_vec()
391 };
392
393 let (refs, results) = match cfg.include {
394 IncludeMode::Refs => {
395 let refs = recent
396 .iter()
397 .map(|c| RiskRef {
398 rule: c.rule.clone(),
399 level: c.level.clone(),
400 score: c.score,
401 timestamp: c.ts,
402 })
403 .collect();
404 (Some(refs), None)
405 }
406 IncludeMode::Results => {
407 let results = recent.iter().filter_map(|c| c.result.clone()).collect();
408 (None, Some(results))
409 }
410 };
411
412 RiskIncidentResult {
413 risk_incident_id: Uuid::new_v4().to_string(),
414 entity_type: entity_type.to_string(),
415 entity_value: entity_value.to_string(),
416 trigger,
417 score: stats.score,
418 score_threshold: cfg.score_threshold,
419 tactic_count,
420 tactic_count_threshold: cfg.tactic_count_threshold,
421 tactics: stats.tactics.clone(),
422 sources: stats.sources.clone(),
423 source_count: stats.source_count,
424 window_start: stats.window_start,
425 window_end: stats.window_end,
426 result_count: stats.result_count,
427 refs,
428 results,
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    /// Test config: one-hour window, ten-minute cooldown, refs mode, default
    /// caps, with the given thresholds.
    fn cfg(score: Option<i64>, tactics: Option<u64>) -> IncidentConfig {
        let window = Duration::from_secs(3600);
        let cooldown = Duration::from_secs(600);
        IncidentConfig {
            window,
            score_threshold: score,
            tactic_count_threshold: tactics,
            cooldown,
            include: IncludeMode::Refs,
            nats_subject: None,
            caps: RiskCaps::default(),
        }
    }

    /// Test contribution with level `"high"` and no result payload.
    fn contrib(ts: i64, score: i64, tactics: &[&str], rule: &str) -> Contribution {
        let tactics = tactics.iter().copied().map(String::from).collect();
        Contribution {
            ts,
            score,
            tactics,
            rule: rule.to_owned(),
            level: Some(String::from("high")),
            result: None,
        }
    }

    #[test]
    fn fires_on_score_threshold() {
        let config = cfg(Some(100), None);
        let mut state = RiskState::default();

        // 60 alone is below the threshold of 100.
        let first = state.record(
            &config,
            "user",
            "alice",
            contrib(0, 60, &["execution"], "r1"),
            0,
        );
        assert!(first.incident.is_none());

        // 60 + 60 crosses it.
        let second = state.record(
            &config,
            "user",
            "alice",
            contrib(10, 60, &["persistence"], "r2"),
            10,
        );
        let incident = second.incident.expect("threshold crossed");
        assert_eq!(incident.trigger, "score");
        assert_eq!(incident.score, 120);
        assert_eq!(incident.entity_value, "alice");
        assert_eq!(incident.source_count, 2);
    }

    #[test]
    fn fires_on_tactic_count_threshold() {
        let config = cfg(None, Some(3));
        let mut state = RiskState::default();

        state.record(&config, "host", "dc01", contrib(0, 1, &["execution"], "r1"), 0);
        state.record(
            &config,
            "host",
            "dc01",
            contrib(1, 1, &["persistence"], "r2"),
            1,
        );
        let outcome = state.record(&config, "host", "dc01", contrib(2, 1, &["impact"], "r3"), 2);

        let incident = outcome.incident.expect("three distinct tactics");
        assert_eq!(incident.trigger, "tactic_count");
        assert_eq!(incident.tactic_count, 3);
    }

    #[test]
    fn cooldown_suppresses_refire() {
        let config = cfg(Some(50), None);
        let mut state = RiskState::default();

        let at_0 = state.record(&config, "user", "bob", contrib(0, 50, &["execution"], "r1"), 0);
        assert!(at_0.incident.is_some());

        // 100s after firing: still inside the 600s cooldown.
        let at_100 = state.record(
            &config,
            "user",
            "bob",
            contrib(100, 50, &["execution"], "r1"),
            100,
        );
        assert!(at_100.incident.is_none());

        // 700s after firing: cooldown has elapsed.
        let at_700 = state.record(
            &config,
            "user",
            "bob",
            contrib(700, 50, &["execution"], "r1"),
            700,
        );
        assert!(at_700.incident.is_some());
    }

    #[test]
    fn window_prunes_old_contributions() {
        let config = cfg(Some(100), None);
        let mut state = RiskState::default();

        state.record(&config, "user", "carol", contrib(0, 60, &["execution"], "r1"), 0);

        // 4000s later the first contribution lies outside the 3600s window.
        let outcome = state.record(
            &config,
            "user",
            "carol",
            contrib(4000, 60, &["execution"], "r1"),
            4000,
        );
        assert!(
            outcome.incident.is_none(),
            "old contribution pruned, sum is 60"
        );
    }

    #[test]
    fn at_capacity_new_entity_is_not_tracked() {
        let mut config = cfg(Some(1), None);
        config.caps.max_open_entities = 1;
        let mut state = RiskState::default();

        let tracked = state.record(&config, "user", "a", contrib(0, 1, &[], "r1"), 0);
        assert!(tracked.incident.is_some());

        let rejected = state.record(&config, "user", "b", contrib(0, 1, &[], "r1"), 0);
        assert!(rejected.evicted, "second distinct entity rejected at capacity");
        assert_eq!(state.len(), 1);
    }

    #[test]
    fn tick_evicts_fully_aged_entities() {
        let config = cfg(Some(1000), None);
        let mut state = RiskState::default();

        state.record(&config, "user", "dan", contrib(0, 10, &["execution"], "r1"), 0);
        assert_eq!(state.len(), 1);

        let evicted = state.tick(&config, 4000);
        assert_eq!(evicted, 1);
        assert!(state.is_empty());
    }

    #[test]
    fn snapshot_round_trips_and_prunes() {
        let config = cfg(Some(100), None);
        let window_secs = config.window.as_secs() as i64;
        let max_open = config.caps.max_open_entities;

        let mut state = RiskState::default();
        state.record(
            &config,
            "user",
            "erin",
            contrib(100, 50, &["execution"], "r1"),
            100,
        );

        // Round-trip through JSON and restore while the entry is still live.
        let encoded = serde_json::to_string(&state.snapshot()).unwrap();
        let decoded: RiskStateSnapshot = serde_json::from_str(&encoded).unwrap();

        let mut fresh = RiskState::default();
        assert!(fresh.restore(decoded, window_secs, max_open, 200));
        assert_eq!(fresh.len(), 1);

        let outcome = fresh.record(
            &config,
            "user",
            "erin",
            contrib(200, 50, &["execution"], "r1"),
            200,
        );
        assert_eq!(outcome.incident.unwrap().score, 100, "restored 50 + new 50");

        // Restoring after the window has passed drops the stale entity.
        let encoded_again = serde_json::to_string(&state.snapshot()).unwrap();
        let decoded_again: RiskStateSnapshot = serde_json::from_str(&encoded_again).unwrap();

        let mut aged = RiskState::default();
        assert!(aged.restore(decoded_again, window_secs, max_open, 100 + 3600 + 5));
        assert!(aged.is_empty(), "stale entity pruned on restore");
    }

    #[test]
    fn restore_honors_max_open_entities() {
        let config = cfg(Some(1_000_000), None);
        let mut source = RiskState::default();
        for i in 0..5 {
            let user = format!("u{i}");
            source.record(
                &config,
                "user",
                &user,
                contrib(0, 10, &["execution"], "r1"),
                0,
            );
        }
        let snap = source.snapshot();

        // Only three of the five entities fit under the cap.
        let mut restored = RiskState::default();
        assert!(restored.restore(snap, config.window.as_secs() as i64, 3, 0));
        assert_eq!(restored.len(), 3);
    }

    #[test]
    fn restore_rejects_version_mismatch() {
        let snap = RiskStateSnapshot {
            version: SNAPSHOT_VERSION + 1,
            entities: vec![],
        };
        let mut state = RiskState::default();
        assert!(!state.restore(snap, 3600, usize::MAX, 0));
    }
}