1use llmtrace_core::{AgentActionType, SecurityFinding, SecuritySeverity};
34use regex::Regex;
35use std::collections::{HashMap, VecDeque};
36use std::time::{Duration, Instant};
37
38#[derive(Debug, Clone)]
44pub struct TrackedAction {
45 pub action_type: AgentActionType,
47 pub target: String,
49 pub timestamp: Instant,
51 pub session_id: String,
53 pub risk_score: f64,
55}
56
57#[derive(Debug, Clone)]
63pub struct PatternStep {
64 pub action_type: Option<AgentActionType>,
66 pub target_pattern: Option<String>,
68 pub min_risk: f64,
70}
71
72#[derive(Debug, Clone)]
78pub struct AttackPattern {
79 pub name: String,
81 pub description: String,
83 pub steps: Vec<PatternStep>,
85 pub max_time_window: Duration,
87 pub severity: SecuritySeverity,
89 pub confidence: f64,
91}
92
93#[derive(Debug, Clone)]
99pub struct CorrelationConfig {
100 pub max_history_per_session: usize,
102 pub session_timeout: Duration,
104 pub patterns: Vec<AttackPattern>,
106 pub enable_temporal_analysis: bool,
108 pub rapid_action_threshold: Duration,
110 pub rapid_action_count: usize,
112}
113
114impl Default for CorrelationConfig {
115 fn default() -> Self {
116 Self {
117 max_history_per_session: 500,
118 session_timeout: Duration::from_secs(3600),
119 patterns: Vec::new(),
120 enable_temporal_analysis: true,
121 rapid_action_threshold: Duration::from_secs(1),
122 rapid_action_count: 10,
123 }
124 }
125}
126
127#[derive(Debug, Clone)]
133pub struct CorrelationResult {
134 pub session_id: String,
136 pub pattern_matches: Vec<PatternMatch>,
138 pub rapid_actions: Option<RapidActionAlert>,
140 pub escalation: Option<EscalationSequence>,
142 pub total_risk: f64,
144}
145
146#[derive(Debug, Clone)]
148pub struct PatternMatch {
149 pub pattern_name: String,
151 pub matched_actions: Vec<usize>,
153 pub confidence: f64,
155 pub severity: SecuritySeverity,
157 pub time_span: Duration,
159}
160
/// Alert raised when a session's most recent actions arrive in a rapid burst.
#[derive(Clone, Debug)]
pub struct RapidActionAlert {
    /// Number of actions in the burst.
    pub action_count: usize,
    /// Time from the first to the last action of the burst.
    pub time_window: Duration,
    /// Mean gap between consecutive actions of the burst.
    pub avg_interval: Duration,
}
171
172#[derive(Debug, Clone)]
174pub struct EscalationSequence {
175 pub steps: Vec<(AgentActionType, String, f64)>,
177 pub risk_trajectory: Vec<f64>,
179}
180
181#[derive(Debug)]
187struct CompiledStep {
188 action_type: Option<AgentActionType>,
189 target_regex: Option<Regex>,
190 min_risk: f64,
191}
192
193#[derive(Debug)]
195struct CompiledPattern {
196 pattern: AttackPattern,
197 compiled_steps: Vec<CompiledStep>,
198}
199
200pub struct ActionCorrelator {
207 config: CorrelationConfig,
208 session_histories: HashMap<String, VecDeque<TrackedAction>>,
209 compiled_patterns: Vec<CompiledPattern>,
210}
211
212impl std::fmt::Debug for ActionCorrelator {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 f.debug_struct("ActionCorrelator")
215 .field("config", &self.config)
216 .field("session_count", &self.session_histories.len())
217 .field("pattern_count", &self.compiled_patterns.len())
218 .finish()
219 }
220}
221
222impl ActionCorrelator {
223 pub fn new(config: CorrelationConfig) -> Self {
225 let compiled_patterns = config
226 .patterns
227 .iter()
228 .filter_map(|p| compile_pattern(p.clone()))
229 .collect();
230
231 Self {
232 config,
233 session_histories: HashMap::new(),
234 compiled_patterns,
235 }
236 }
237
238 pub fn with_defaults() -> Self {
240 let config = CorrelationConfig {
241 patterns: default_attack_patterns(),
242 ..Default::default()
243 };
244 Self::new(config)
245 }
246
247 pub fn record_action(&mut self, action: TrackedAction) -> CorrelationResult {
250 let session_id = action.session_id.clone();
251
252 let history = self
253 .session_histories
254 .entry(session_id.clone())
255 .or_default();
256 history.push_back(action);
257
258 while history.len() > self.config.max_history_per_session {
260 history.pop_front();
261 }
262
263 let pattern_matches = self.check_patterns(&session_id);
264
265 let rapid_actions = if self.config.enable_temporal_analysis {
266 self.detect_rapid_actions(&session_id)
267 } else {
268 None
269 };
270
271 let escalation = self.detect_privilege_escalation_sequence(&session_id);
272
273 let total_risk = compute_total_risk(&pattern_matches, &rapid_actions, &escalation);
274
275 CorrelationResult {
276 session_id,
277 pattern_matches,
278 rapid_actions,
279 escalation,
280 total_risk,
281 }
282 }
283
284 #[must_use]
286 pub fn check_patterns(&self, session_id: &str) -> Vec<PatternMatch> {
287 let history = match self.session_histories.get(session_id) {
288 Some(h) if !h.is_empty() => h,
289 _ => return Vec::new(),
290 };
291
292 let mut matches = Vec::new();
293
294 for cp in &self.compiled_patterns {
295 if let Some(m) = match_pattern(cp, history) {
296 matches.push(m);
297 }
298 }
299
300 matches
301 }
302
303 #[must_use]
305 pub fn detect_rapid_actions(&self, session_id: &str) -> Option<RapidActionAlert> {
306 let history = self.session_histories.get(session_id)?;
307
308 if history.len() < self.config.rapid_action_count {
309 return None;
310 }
311
312 let start = history.len() - self.config.rapid_action_count;
314 let window: Vec<&TrackedAction> = history.iter().skip(start).collect();
315
316 let first_ts = window.first()?.timestamp;
317 let last_ts = window.last()?.timestamp;
318 let time_window = last_ts.duration_since(first_ts);
319
320 let mut all_rapid = true;
322 for pair in window.windows(2) {
323 let interval = pair[1].timestamp.duration_since(pair[0].timestamp);
324 if interval > self.config.rapid_action_threshold {
325 all_rapid = false;
326 break;
327 }
328 }
329
330 if !all_rapid {
331 return None;
332 }
333
334 let count = window.len();
335 let avg_interval = if count > 1 {
336 time_window / (count as u32 - 1)
337 } else {
338 Duration::ZERO
339 };
340
341 Some(RapidActionAlert {
342 action_count: count,
343 time_window,
344 avg_interval,
345 })
346 }
347
348 #[must_use]
352 pub fn detect_privilege_escalation_sequence(
353 &self,
354 session_id: &str,
355 ) -> Option<EscalationSequence> {
356 let history = self.session_histories.get(session_id)?;
357
358 if history.len() < 3 {
359 return None;
360 }
361
362 let mut escalation_steps: Vec<(AgentActionType, String, f64)> = Vec::new();
364 let mut trajectory: Vec<f64> = Vec::new();
365
366 for action in history.iter() {
367 let extends = trajectory
368 .last()
369 .is_none_or(|&prev| action.risk_score > prev);
370
371 if extends {
372 escalation_steps.push((
373 action.action_type.clone(),
374 action.target.clone(),
375 action.risk_score,
376 ));
377 trajectory.push(action.risk_score);
378 } else {
379 escalation_steps.clear();
381 trajectory.clear();
382 escalation_steps.push((
383 action.action_type.clone(),
384 action.target.clone(),
385 action.risk_score,
386 ));
387 trajectory.push(action.risk_score);
388 }
389 }
390
391 if escalation_steps.len() < 3 {
392 return None;
393 }
394
395 let final_risk = trajectory.last().copied().unwrap_or(0.0);
396 if final_risk < 0.7 {
397 return None;
398 }
399
400 Some(EscalationSequence {
401 steps: escalation_steps,
402 risk_trajectory: trajectory,
403 })
404 }
405
406 pub fn cleanup_expired_sessions(&mut self) {
408 let timeout = self.config.session_timeout;
409 let now = Instant::now();
410
411 self.session_histories.retain(|_session_id, history| {
412 match history.back() {
413 Some(last_action) => now.duration_since(last_action.timestamp) < timeout,
414 None => false, }
416 });
417 }
418
419 #[must_use]
421 pub fn session_count(&self) -> usize {
422 self.session_histories.len()
423 }
424
425 #[must_use]
427 pub fn to_security_findings(result: &CorrelationResult) -> Vec<SecurityFinding> {
428 let mut findings = Vec::new();
429
430 for pm in &result.pattern_matches {
431 let finding = SecurityFinding::new(
432 pm.severity.clone(),
433 format!("attack_pattern_{}", pm.pattern_name),
434 format!(
435 "Multi-step attack pattern '{}' matched in session '{}' \
436 ({} actions over {:.1}s, confidence {:.2})",
437 pm.pattern_name,
438 result.session_id,
439 pm.matched_actions.len(),
440 pm.time_span.as_secs_f64(),
441 pm.confidence,
442 ),
443 pm.confidence,
444 )
445 .with_location(format!("session:{}", result.session_id))
446 .with_metadata("pattern_name".to_string(), pm.pattern_name.clone())
447 .with_metadata(
448 "matched_action_count".to_string(),
449 pm.matched_actions.len().to_string(),
450 )
451 .with_metadata(
452 "time_span_ms".to_string(),
453 pm.time_span.as_millis().to_string(),
454 );
455
456 findings.push(finding);
457 }
458
459 if let Some(ref rapid) = result.rapid_actions {
460 let finding = SecurityFinding::new(
461 SecuritySeverity::Medium,
462 "rapid_action_alert".to_string(),
463 format!(
464 "Rapid-fire actions detected in session '{}': \
465 {} actions in {:.1}s (avg interval {:.0}ms)",
466 result.session_id,
467 rapid.action_count,
468 rapid.time_window.as_secs_f64(),
469 rapid.avg_interval.as_secs_f64() * 1000.0,
470 ),
471 0.8,
472 )
473 .with_location(format!("session:{}", result.session_id))
474 .with_metadata("action_count".to_string(), rapid.action_count.to_string())
475 .with_metadata(
476 "time_window_ms".to_string(),
477 rapid.time_window.as_millis().to_string(),
478 );
479
480 findings.push(finding);
481 }
482
483 if let Some(ref esc) = result.escalation {
484 let finding = SecurityFinding::new(
485 SecuritySeverity::High,
486 "privilege_escalation_sequence".to_string(),
487 format!(
488 "Privilege escalation detected in session '{}': \
489 {} steps with risk trajectory {:?}",
490 result.session_id,
491 esc.steps.len(),
492 esc.risk_trajectory,
493 ),
494 0.85,
495 )
496 .with_location(format!("session:{}", result.session_id))
497 .with_metadata("step_count".to_string(), esc.steps.len().to_string())
498 .with_metadata(
499 "final_risk".to_string(),
500 esc.risk_trajectory
501 .last()
502 .map_or("0.0".to_string(), |r| format!("{r:.2}")),
503 );
504
505 findings.push(finding);
506 }
507
508 findings
509 }
510}
511
512fn compile_pattern(pattern: AttackPattern) -> Option<CompiledPattern> {
518 let mut compiled_steps = Vec::with_capacity(pattern.steps.len());
519
520 for step in &pattern.steps {
521 let target_regex = match &step.target_pattern {
522 Some(pat) => Some(Regex::new(pat).ok()?),
523 None => None,
524 };
525 compiled_steps.push(CompiledStep {
526 action_type: step.action_type.clone(),
527 target_regex,
528 min_risk: step.min_risk,
529 });
530 }
531
532 Some(CompiledPattern {
533 pattern,
534 compiled_steps,
535 })
536}
537
538fn step_matches(step: &CompiledStep, action: &TrackedAction) -> bool {
540 if let Some(ref required_type) = step.action_type {
541 if &action.action_type != required_type {
542 return false;
543 }
544 }
545
546 if let Some(ref re) = step.target_regex {
547 if !re.is_match(&action.target) {
548 return false;
549 }
550 }
551
552 if action.risk_score < step.min_risk {
553 return false;
554 }
555
556 true
557}
558
559fn match_pattern(
563 compiled: &CompiledPattern,
564 history: &VecDeque<TrackedAction>,
565) -> Option<PatternMatch> {
566 if compiled.compiled_steps.is_empty() {
567 return None;
568 }
569
570 let mut step_idx = 0;
571 let mut matched_indices: Vec<usize> = Vec::new();
572
573 for (i, action) in history.iter().enumerate() {
574 if step_idx >= compiled.compiled_steps.len() {
575 break;
576 }
577 if step_matches(&compiled.compiled_steps[step_idx], action) {
578 matched_indices.push(i);
579 step_idx += 1;
580 }
581 }
582
583 if step_idx < compiled.compiled_steps.len() {
585 return None;
586 }
587
588 let first = &history[matched_indices[0]];
590 let last = &history[*matched_indices.last().unwrap()];
591 let time_span = last.timestamp.duration_since(first.timestamp);
592
593 if time_span > compiled.pattern.max_time_window {
594 return None;
595 }
596
597 Some(PatternMatch {
598 pattern_name: compiled.pattern.name.clone(),
599 matched_actions: matched_indices,
600 confidence: compiled.pattern.confidence,
601 severity: compiled.pattern.severity.clone(),
602 time_span,
603 })
604}
605
606fn compute_total_risk(
608 matches: &[PatternMatch],
609 rapid: &Option<RapidActionAlert>,
610 escalation: &Option<EscalationSequence>,
611) -> f64 {
612 let mut risk = 0.0_f64;
613
614 for m in matches {
615 risk += m.confidence
616 * match m.severity {
617 SecuritySeverity::Critical => 1.0,
618 SecuritySeverity::High => 0.8,
619 SecuritySeverity::Medium => 0.5,
620 SecuritySeverity::Low => 0.3,
621 SecuritySeverity::Info => 0.1,
622 };
623 }
624
625 if rapid.is_some() {
626 risk += 0.3;
627 }
628
629 if let Some(ref esc) = escalation {
630 risk += esc.risk_trajectory.last().copied().unwrap_or(0.0);
631 }
632
633 risk.min(1.0)
634}
635
636fn default_attack_patterns() -> Vec<AttackPattern> {
642 vec![
643 AttackPattern {
645 name: "data_exfiltration_chain".to_string(),
646 description: "File read followed by web access to an external URL".to_string(),
647 steps: vec![
648 PatternStep {
649 action_type: Some(AgentActionType::FileAccess),
650 target_pattern: None,
651 min_risk: 0.0,
652 },
653 PatternStep {
654 action_type: Some(AgentActionType::WebAccess),
655 target_pattern: Some(r"(?i)https?://".to_string()),
656 min_risk: 0.0,
657 },
658 ],
659 max_time_window: Duration::from_secs(300),
660 severity: SecuritySeverity::High,
661 confidence: 0.8,
662 },
663 AttackPattern {
665 name: "credential_theft".to_string(),
666 description: "Access to credential/secret files followed by web or skill call"
667 .to_string(),
668 steps: vec![
669 PatternStep {
670 action_type: Some(AgentActionType::FileAccess),
671 target_pattern: Some(r"(?i)\.(env|key|pem|credentials|secret)".to_string()),
672 min_risk: 0.0,
673 },
674 PatternStep {
675 action_type: Some(AgentActionType::WebAccess),
676 target_pattern: None,
677 min_risk: 0.0,
678 },
679 ],
680 max_time_window: Duration::from_secs(300),
681 severity: SecuritySeverity::Critical,
682 confidence: 0.9,
683 },
684 AttackPattern {
686 name: "reconnaissance_then_exploit".to_string(),
687 description: "Multiple tool calls with increasing risk followed by command execution"
688 .to_string(),
689 steps: vec![
690 PatternStep {
691 action_type: Some(AgentActionType::ToolCall),
692 target_pattern: None,
693 min_risk: 0.2,
694 },
695 PatternStep {
696 action_type: Some(AgentActionType::ToolCall),
697 target_pattern: None,
698 min_risk: 0.5,
699 },
700 PatternStep {
701 action_type: Some(AgentActionType::CommandExecution),
702 target_pattern: None,
703 min_risk: 0.6,
704 },
705 ],
706 max_time_window: Duration::from_secs(600),
707 severity: SecuritySeverity::High,
708 confidence: 0.75,
709 },
710 AttackPattern {
712 name: "lateral_movement".to_string(),
713 description: "Skill invocation to another agent followed by privileged tool call"
714 .to_string(),
715 steps: vec![
716 PatternStep {
717 action_type: Some(AgentActionType::SkillInvocation),
718 target_pattern: None,
719 min_risk: 0.0,
720 },
721 PatternStep {
722 action_type: Some(AgentActionType::ToolCall),
723 target_pattern: Some(r"(?i)(admin|sudo|escalat|privil)".to_string()),
724 min_risk: 0.5,
725 },
726 ],
727 max_time_window: Duration::from_secs(300),
728 severity: SecuritySeverity::High,
729 confidence: 0.85,
730 },
731 ]
732}
733
734#[cfg(test)]
739mod tests {
740 use super::*;
741 use std::thread;
742 use std::time::{Duration, Instant};
743
744 fn make_action(
746 action_type: AgentActionType,
747 target: &str,
748 session_id: &str,
749 risk: f64,
750 ) -> TrackedAction {
751 TrackedAction {
752 action_type,
753 target: target.to_string(),
754 timestamp: Instant::now(),
755 session_id: session_id.to_string(),
756 risk_score: risk,
757 }
758 }
759
760 fn make_action_at(
761 action_type: AgentActionType,
762 target: &str,
763 session_id: &str,
764 risk: f64,
765 timestamp: Instant,
766 ) -> TrackedAction {
767 TrackedAction {
768 action_type,
769 target: target.to_string(),
770 timestamp,
771 session_id: session_id.to_string(),
772 risk_score: risk,
773 }
774 }
775
776 #[test]
781 fn test_action_recording_and_history() {
782 let mut correlator = ActionCorrelator::new(CorrelationConfig::default());
783
784 let a1 = make_action(AgentActionType::ToolCall, "search", "s1", 0.1);
785 correlator.record_action(a1);
786
787 let a2 = make_action(AgentActionType::FileAccess, "/tmp/file", "s1", 0.3);
788 correlator.record_action(a2);
789
790 assert_eq!(correlator.session_count(), 1);
791 let history = correlator.session_histories.get("s1").unwrap();
792 assert_eq!(history.len(), 2);
793 assert_eq!(history[0].target, "search");
794 assert_eq!(history[1].target, "/tmp/file");
795 }
796
797 #[test]
802 fn test_data_exfiltration_chain_detected() {
803 let mut correlator = ActionCorrelator::with_defaults();
804 let now = Instant::now();
805
806 let a1 = make_action_at(AgentActionType::FileAccess, "/etc/shadow", "s1", 0.7, now);
807 correlator.record_action(a1);
808
809 let a2 = make_action_at(
810 AgentActionType::WebAccess,
811 "https://evil.example.com/upload",
812 "s1",
813 0.6,
814 now + Duration::from_secs(30),
815 );
816 let result = correlator.record_action(a2);
817
818 assert!(
819 result
820 .pattern_matches
821 .iter()
822 .any(|m| m.pattern_name == "data_exfiltration_chain"),
823 "Expected data_exfiltration_chain to match"
824 );
825 }
826
827 #[test]
832 fn test_benign_sequence_not_matched() {
833 let mut correlator = ActionCorrelator::with_defaults();
834 let now = Instant::now();
835
836 let a1 = make_action_at(AgentActionType::ToolCall, "search", "s1", 0.1, now);
838 correlator.record_action(a1);
839
840 let a2 = make_action_at(
841 AgentActionType::ToolCall,
842 "calculator",
843 "s1",
844 0.1,
845 now + Duration::from_secs(5),
846 );
847 let result = correlator.record_action(a2);
848
849 assert!(
850 result.pattern_matches.is_empty(),
851 "Benign tool calls should not match any attack pattern"
852 );
853 }
854
855 #[test]
860 fn test_rapid_action_detection() {
861 let config = CorrelationConfig {
862 rapid_action_threshold: Duration::from_millis(500),
863 rapid_action_count: 5,
864 ..Default::default()
865 };
866 let mut correlator = ActionCorrelator::new(config);
867
868 let base = Instant::now();
869 for i in 0..5 {
870 let action = make_action_at(
871 AgentActionType::ToolCall,
872 &format!("tool_{i}"),
873 "s1",
874 0.1,
875 base + Duration::from_millis(i * 100), );
877 correlator.record_action(action);
878 }
879
880 let rapid = correlator.detect_rapid_actions("s1");
881 assert!(rapid.is_some(), "Should detect rapid actions");
882 let alert = rapid.unwrap();
883 assert_eq!(alert.action_count, 5);
884 assert!(alert.avg_interval < Duration::from_millis(200));
885 }
886
887 #[test]
892 fn test_rapid_action_not_triggered_for_slow_actions() {
893 let config = CorrelationConfig {
894 rapid_action_threshold: Duration::from_millis(100),
895 rapid_action_count: 3,
896 ..Default::default()
897 };
898 let mut correlator = ActionCorrelator::new(config);
899
900 let base = Instant::now();
901 for i in 0..3u64 {
902 let action = make_action_at(
903 AgentActionType::ToolCall,
904 &format!("tool_{i}"),
905 "s1",
906 0.1,
907 base + Duration::from_secs(i * 5), );
909 correlator.record_action(action);
910 }
911
912 let rapid = correlator.detect_rapid_actions("s1");
913 assert!(
914 rapid.is_none(),
915 "Slow actions should not trigger rapid alert"
916 );
917 }
918
919 #[test]
924 fn test_privilege_escalation_sequence_detection() {
925 let mut correlator = ActionCorrelator::new(CorrelationConfig::default());
926 let now = Instant::now();
927
928 let actions = [
929 (AgentActionType::ToolCall, "list_users", 0.2),
930 (AgentActionType::ToolCall, "read_config", 0.4),
931 (AgentActionType::ToolCall, "modify_permissions", 0.6),
932 (AgentActionType::CommandExecution, "sudo rm -rf", 0.9),
933 ];
934
935 for (i, (atype, target, risk)) in actions.iter().enumerate() {
936 let action = make_action_at(
937 atype.clone(),
938 target,
939 "s1",
940 *risk,
941 now + Duration::from_secs(i as u64),
942 );
943 correlator.record_action(action);
944 }
945
946 let esc = correlator.detect_privilege_escalation_sequence("s1");
947 assert!(esc.is_some(), "Should detect escalation");
948 let seq = esc.unwrap();
949 assert_eq!(seq.steps.len(), 4);
950 assert_eq!(seq.risk_trajectory, vec![0.2, 0.4, 0.6, 0.9]);
951 }
952
953 #[test]
958 fn test_time_window_enforcement() {
959 let mut correlator = ActionCorrelator::with_defaults();
960 let now = Instant::now();
961
962 let a1 = make_action_at(AgentActionType::FileAccess, "/data/secret", "s1", 0.5, now);
964 correlator.record_action(a1);
965
966 let a2 = make_action_at(
968 AgentActionType::WebAccess,
969 "https://evil.com/exfil",
970 "s1",
971 0.5,
972 now + Duration::from_secs(601),
973 );
974 let result = correlator.record_action(a2);
975
976 let exfil = result
977 .pattern_matches
978 .iter()
979 .find(|m| m.pattern_name == "data_exfiltration_chain");
980 assert!(
981 exfil.is_none(),
982 "Pattern should not match when outside time window"
983 );
984 }
985
986 #[test]
991 fn test_multiple_sessions_independent() {
992 let mut correlator = ActionCorrelator::with_defaults();
993 let now = Instant::now();
994
995 let a1 = make_action_at(AgentActionType::FileAccess, "/etc/passwd", "s1", 0.5, now);
997 correlator.record_action(a1);
998
999 let a2 = make_action_at(
1001 AgentActionType::WebAccess,
1002 "https://evil.com",
1003 "s2",
1004 0.5,
1005 now + Duration::from_secs(10),
1006 );
1007 let result = correlator.record_action(a2);
1008
1009 assert_eq!(correlator.session_count(), 2);
1010 assert!(
1011 result.pattern_matches.is_empty(),
1012 "Cross-session actions should not match patterns"
1013 );
1014 }
1015
1016 #[test]
1021 fn test_session_cleanup_expired() {
1022 let config = CorrelationConfig {
1023 session_timeout: Duration::from_millis(50),
1024 ..Default::default()
1025 };
1026 let mut correlator = ActionCorrelator::new(config);
1027
1028 let a1 = make_action(AgentActionType::ToolCall, "tool_a", "expired_session", 0.1);
1029 correlator.record_action(a1);
1030
1031 assert_eq!(correlator.session_count(), 1);
1032
1033 thread::sleep(Duration::from_millis(60));
1035
1036 let a2 = make_action(AgentActionType::ToolCall, "tool_b", "fresh_session", 0.1);
1038 correlator.record_action(a2);
1039
1040 correlator.cleanup_expired_sessions();
1041
1042 assert_eq!(correlator.session_count(), 1);
1044 assert!(correlator.session_histories.contains_key("fresh_session"));
1045 assert!(!correlator.session_histories.contains_key("expired_session"));
1046 }
1047
1048 #[test]
1053 fn test_max_history_enforcement() {
1054 let config = CorrelationConfig {
1055 max_history_per_session: 5,
1056 ..Default::default()
1057 };
1058 let mut correlator = ActionCorrelator::new(config);
1059
1060 for i in 0..10 {
1061 let action = make_action(AgentActionType::ToolCall, &format!("tool_{i}"), "s1", 0.1);
1062 correlator.record_action(action);
1063 }
1064
1065 let history = correlator.session_histories.get("s1").unwrap();
1066 assert_eq!(history.len(), 5);
1067 assert_eq!(history[0].target, "tool_5");
1069 }
1070
1071 #[test]
1076 fn test_custom_pattern_matching() {
1077 let pattern = AttackPattern {
1078 name: "custom_test".to_string(),
1079 description: "Test pattern".to_string(),
1080 steps: vec![
1081 PatternStep {
1082 action_type: Some(AgentActionType::CommandExecution),
1083 target_pattern: Some(r"^whoami$".to_string()),
1084 min_risk: 0.0,
1085 },
1086 PatternStep {
1087 action_type: Some(AgentActionType::CommandExecution),
1088 target_pattern: Some(r"(?i)cat /etc/".to_string()),
1089 min_risk: 0.3,
1090 },
1091 ],
1092 max_time_window: Duration::from_secs(60),
1093 severity: SecuritySeverity::Medium,
1094 confidence: 0.7,
1095 };
1096
1097 let config = CorrelationConfig {
1098 patterns: vec![pattern],
1099 ..Default::default()
1100 };
1101 let mut correlator = ActionCorrelator::new(config);
1102 let now = Instant::now();
1103
1104 let a1 = make_action_at(AgentActionType::CommandExecution, "whoami", "s1", 0.2, now);
1105 correlator.record_action(a1);
1106
1107 let a2 = make_action_at(
1108 AgentActionType::CommandExecution,
1109 "cat /etc/shadow",
1110 "s1",
1111 0.5,
1112 now + Duration::from_secs(10),
1113 );
1114 let result = correlator.record_action(a2);
1115
1116 assert_eq!(result.pattern_matches.len(), 1);
1117 assert_eq!(result.pattern_matches[0].pattern_name, "custom_test");
1118 assert_eq!(result.pattern_matches[0].matched_actions, vec![0, 1]);
1119 }
1120
1121 #[test]
1126 fn test_security_finding_generation() {
1127 let result = CorrelationResult {
1128 session_id: "s1".to_string(),
1129 pattern_matches: vec![PatternMatch {
1130 pattern_name: "data_exfiltration_chain".to_string(),
1131 matched_actions: vec![0, 1],
1132 confidence: 0.8,
1133 severity: SecuritySeverity::High,
1134 time_span: Duration::from_secs(30),
1135 }],
1136 rapid_actions: Some(RapidActionAlert {
1137 action_count: 10,
1138 time_window: Duration::from_secs(5),
1139 avg_interval: Duration::from_millis(500),
1140 }),
1141 escalation: Some(EscalationSequence {
1142 steps: vec![
1143 (AgentActionType::ToolCall, "ls".to_string(), 0.3),
1144 (
1145 AgentActionType::CommandExecution,
1146 "sudo rm".to_string(),
1147 0.9,
1148 ),
1149 ],
1150 risk_trajectory: vec![0.3, 0.9],
1151 }),
1152 total_risk: 0.9,
1153 };
1154
1155 let findings = ActionCorrelator::to_security_findings(&result);
1156 assert_eq!(findings.len(), 3);
1157
1158 assert_eq!(
1160 findings[0].finding_type,
1161 "attack_pattern_data_exfiltration_chain"
1162 );
1163 assert_eq!(findings[0].severity, SecuritySeverity::High);
1164 assert!((findings[0].confidence_score - 0.8).abs() < f64::EPSILON);
1165
1166 assert_eq!(findings[1].finding_type, "rapid_action_alert");
1168 assert_eq!(findings[1].severity, SecuritySeverity::Medium);
1169
1170 assert_eq!(findings[2].finding_type, "privilege_escalation_sequence");
1172 assert_eq!(findings[2].severity, SecuritySeverity::High);
1173 }
1174
1175 #[test]
1180 fn test_compiled_pattern_regex_matching() {
1181 let pattern = AttackPattern {
1182 name: "regex_test".to_string(),
1183 description: "Regex test".to_string(),
1184 steps: vec![PatternStep {
1185 action_type: Some(AgentActionType::FileAccess),
1186 target_pattern: Some(r"\.env$".to_string()),
1187 min_risk: 0.0,
1188 }],
1189 max_time_window: Duration::from_secs(60),
1190 severity: SecuritySeverity::Medium,
1191 confidence: 0.7,
1192 };
1193
1194 let compiled = compile_pattern(pattern).unwrap();
1195 assert!(compiled.compiled_steps[0].target_regex.is_some());
1196
1197 let action_match = TrackedAction {
1199 action_type: AgentActionType::FileAccess,
1200 target: "/app/.env".to_string(),
1201 timestamp: Instant::now(),
1202 session_id: "s1".to_string(),
1203 risk_score: 0.5,
1204 };
1205 assert!(step_matches(&compiled.compiled_steps[0], &action_match));
1206
1207 let action_no_match = TrackedAction {
1209 action_type: AgentActionType::FileAccess,
1210 target: "/app/config.json".to_string(),
1211 timestamp: Instant::now(),
1212 session_id: "s1".to_string(),
1213 risk_score: 0.5,
1214 };
1215 assert!(!step_matches(&compiled.compiled_steps[0], &action_no_match));
1216
1217 let action_wrong_type = TrackedAction {
1219 action_type: AgentActionType::WebAccess,
1220 target: "/app/.env".to_string(),
1221 timestamp: Instant::now(),
1222 session_id: "s1".to_string(),
1223 risk_score: 0.5,
1224 };
1225 assert!(!step_matches(
1226 &compiled.compiled_steps[0],
1227 &action_wrong_type
1228 ));
1229 }
1230
1231 #[test]
1236 fn test_multiple_patterns_matching_same_sequence() {
1237 let mut correlator = ActionCorrelator::with_defaults();
1238 let now = Instant::now();
1239
1240 let a1 = make_action_at(AgentActionType::FileAccess, "/app/.env", "s1", 0.5, now);
1242 correlator.record_action(a1);
1243
1244 let a2 = make_action_at(
1245 AgentActionType::WebAccess,
1246 "https://attacker.com/collect",
1247 "s1",
1248 0.6,
1249 now + Duration::from_secs(10),
1250 );
1251 let result = correlator.record_action(a2);
1252
1253 let pattern_names: Vec<&str> = result
1254 .pattern_matches
1255 .iter()
1256 .map(|m| m.pattern_name.as_str())
1257 .collect();
1258
1259 assert!(
1260 pattern_names.contains(&"data_exfiltration_chain"),
1261 "Should match data_exfiltration_chain"
1262 );
1263 assert!(
1264 pattern_names.contains(&"credential_theft"),
1265 "Should match credential_theft (.env file pattern)"
1266 );
1267 assert!(result.pattern_matches.len() >= 2);
1268 }
1269
1270 #[test]
1275 fn test_empty_history_check_patterns() {
1276 let correlator = ActionCorrelator::with_defaults();
1277 let matches = correlator.check_patterns("nonexistent");
1278 assert!(matches.is_empty());
1279 }
1280
1281 #[test]
1282 fn test_empty_history_detect_rapid() {
1283 let correlator = ActionCorrelator::with_defaults();
1284 let rapid = correlator.detect_rapid_actions("nonexistent");
1285 assert!(rapid.is_none());
1286 }
1287
1288 #[test]
1289 fn test_empty_history_detect_escalation() {
1290 let correlator = ActionCorrelator::with_defaults();
1291 let esc = correlator.detect_privilege_escalation_sequence("nonexistent");
1292 assert!(esc.is_none());
1293 }
1294
1295 #[test]
1300 fn test_single_action_no_match() {
1301 let mut correlator = ActionCorrelator::with_defaults();
1302
1303 let action = make_action(AgentActionType::FileAccess, "/etc/passwd", "s1", 0.8);
1304 let result = correlator.record_action(action);
1305
1306 assert!(result.pattern_matches.is_empty());
1307 assert!(result.escalation.is_none());
1308 }
1309
1310 #[test]
1315 fn test_temporal_ordering_required() {
1316 let mut correlator = ActionCorrelator::with_defaults();
1317 let now = Instant::now();
1318
1319 let a1 = make_action_at(
1322 AgentActionType::WebAccess,
1323 "https://example.com",
1324 "s1",
1325 0.5,
1326 now,
1327 );
1328 correlator.record_action(a1);
1329
1330 let a2 = make_action_at(
1331 AgentActionType::FileAccess,
1332 "/etc/passwd",
1333 "s1",
1334 0.5,
1335 now + Duration::from_secs(10),
1336 );
1337 let result = correlator.record_action(a2);
1338
1339 let exfil = result
1340 .pattern_matches
1341 .iter()
1342 .find(|m| m.pattern_name == "data_exfiltration_chain");
1343 assert!(
1344 exfil.is_none(),
1345 "Reversed order should not match the exfiltration pattern"
1346 );
1347 }
1348
1349 #[test]
1354 fn test_correlation_result_aggregation() {
1355 let mut correlator = ActionCorrelator::with_defaults();
1356 let now = Instant::now();
1357
1358 let a1 = make_action_at(AgentActionType::FileAccess, "/etc/shadow", "s1", 0.7, now);
1360 correlator.record_action(a1);
1361
1362 let a2 = make_action_at(
1363 AgentActionType::WebAccess,
1364 "https://evil.com",
1365 "s1",
1366 0.6,
1367 now + Duration::from_secs(10),
1368 );
1369 let result = correlator.record_action(a2);
1370
1371 assert_eq!(result.session_id, "s1");
1372 assert!(!result.pattern_matches.is_empty());
1373 assert!(result.total_risk > 0.0);
1374
1375 let m = &result.pattern_matches[0];
1377 assert_eq!(m.time_span, Duration::from_secs(10));
1378 }
1379
1380 #[test]
1385 fn test_credential_theft_pattern() {
1386 let mut correlator = ActionCorrelator::with_defaults();
1387 let now = Instant::now();
1388
1389 let a1 = make_action_at(
1390 AgentActionType::FileAccess,
1391 "/home/user/.credentials",
1392 "s1",
1393 0.6,
1394 now,
1395 );
1396 correlator.record_action(a1);
1397
1398 let a2 = make_action_at(
1399 AgentActionType::WebAccess,
1400 "http://internal-api/store",
1401 "s1",
1402 0.5,
1403 now + Duration::from_secs(60),
1404 );
1405 let result = correlator.record_action(a2);
1406
1407 assert!(result
1408 .pattern_matches
1409 .iter()
1410 .any(|m| m.pattern_name == "credential_theft"));
1411 }
1412
1413 #[test]
1418 fn test_reconnaissance_then_exploit_pattern() {
1419 let mut correlator = ActionCorrelator::with_defaults();
1420 let now = Instant::now();
1421
1422 let a1 = make_action_at(AgentActionType::ToolCall, "scan_network", "s1", 0.3, now);
1423 correlator.record_action(a1);
1424
1425 let a2 = make_action_at(
1426 AgentActionType::ToolCall,
1427 "enumerate_services",
1428 "s1",
1429 0.5,
1430 now + Duration::from_secs(30),
1431 );
1432 correlator.record_action(a2);
1433
1434 let a3 = make_action_at(
1435 AgentActionType::CommandExecution,
1436 "exploit_payload",
1437 "s1",
1438 0.8,
1439 now + Duration::from_secs(60),
1440 );
1441 let result = correlator.record_action(a3);
1442
1443 assert!(result
1444 .pattern_matches
1445 .iter()
1446 .any(|m| m.pattern_name == "reconnaissance_then_exploit"));
1447 }
1448
1449 #[test]
1454 fn test_lateral_movement_pattern() {
1455 let mut correlator = ActionCorrelator::with_defaults();
1456 let now = Instant::now();
1457
1458 let a1 = make_action_at(
1459 AgentActionType::SkillInvocation,
1460 "agent_b_proxy",
1461 "s1",
1462 0.4,
1463 now,
1464 );
1465 correlator.record_action(a1);
1466
1467 let a2 = make_action_at(
1468 AgentActionType::ToolCall,
1469 "admin_panel_access",
1470 "s1",
1471 0.7,
1472 now + Duration::from_secs(20),
1473 );
1474 let result = correlator.record_action(a2);
1475
1476 assert!(result
1477 .pattern_matches
1478 .iter()
1479 .any(|m| m.pattern_name == "lateral_movement"));
1480 }
1481
1482 #[test]
1487 fn test_escalation_not_triggered_for_two_steps() {
1488 let mut correlator = ActionCorrelator::new(CorrelationConfig::default());
1489 let now = Instant::now();
1490
1491 let a1 = make_action_at(AgentActionType::ToolCall, "a", "s1", 0.3, now);
1492 correlator.record_action(a1);
1493
1494 let a2 = make_action_at(
1495 AgentActionType::ToolCall,
1496 "b",
1497 "s1",
1498 0.9,
1499 now + Duration::from_secs(1),
1500 );
1501 correlator.record_action(a2);
1502
1503 let esc = correlator.detect_privilege_escalation_sequence("s1");
1504 assert!(
1505 esc.is_none(),
1506 "Two steps should be insufficient for escalation"
1507 );
1508 }
1509
1510 #[test]
1515 fn test_escalation_not_triggered_low_risk() {
1516 let mut correlator = ActionCorrelator::new(CorrelationConfig::default());
1517 let now = Instant::now();
1518
1519 let actions = [0.1, 0.2, 0.3, 0.4];
1520 for (i, risk) in actions.iter().enumerate() {
1521 let a = make_action_at(
1522 AgentActionType::ToolCall,
1523 &format!("t{i}"),
1524 "s1",
1525 *risk,
1526 now + Duration::from_secs(i as u64),
1527 );
1528 correlator.record_action(a);
1529 }
1530
1531 let esc = correlator.detect_privilege_escalation_sequence("s1");
1532 assert!(
1533 esc.is_none(),
1534 "Escalation should not trigger when final risk < 0.7"
1535 );
1536 }
1537
1538 #[test]
1543 fn test_with_defaults_has_builtin_patterns() {
1544 let correlator = ActionCorrelator::with_defaults();
1545 assert_eq!(correlator.compiled_patterns.len(), 4);
1546
1547 let names: Vec<&str> = correlator
1548 .compiled_patterns
1549 .iter()
1550 .map(|cp| cp.pattern.name.as_str())
1551 .collect();
1552 assert!(names.contains(&"data_exfiltration_chain"));
1553 assert!(names.contains(&"credential_theft"));
1554 assert!(names.contains(&"reconnaissance_then_exploit"));
1555 assert!(names.contains(&"lateral_movement"));
1556 }
1557
1558 #[test]
1563 fn test_security_finding_empty_result() {
1564 let result = CorrelationResult {
1565 session_id: "s1".to_string(),
1566 pattern_matches: Vec::new(),
1567 rapid_actions: None,
1568 escalation: None,
1569 total_risk: 0.0,
1570 };
1571
1572 let findings = ActionCorrelator::to_security_findings(&result);
1573 assert!(findings.is_empty());
1574 }
1575
1576 #[test]
1581 fn test_pattern_step_min_risk_filter() {
1582 let step = CompiledStep {
1583 action_type: Some(AgentActionType::ToolCall),
1584 target_regex: None,
1585 min_risk: 0.5,
1586 };
1587
1588 let low_risk = TrackedAction {
1589 action_type: AgentActionType::ToolCall,
1590 target: "tool".to_string(),
1591 timestamp: Instant::now(),
1592 session_id: "s1".to_string(),
1593 risk_score: 0.3,
1594 };
1595 assert!(!step_matches(&step, &low_risk));
1596
1597 let high_risk = TrackedAction {
1598 action_type: AgentActionType::ToolCall,
1599 target: "tool".to_string(),
1600 timestamp: Instant::now(),
1601 session_id: "s1".to_string(),
1602 risk_score: 0.6,
1603 };
1604 assert!(step_matches(&step, &high_risk));
1605 }
1606
1607 #[test]
1612 fn test_wildcard_pattern_step_matches_anything() {
1613 let step = CompiledStep {
1614 action_type: None,
1615 target_regex: None,
1616 min_risk: 0.0,
1617 };
1618
1619 let action = TrackedAction {
1620 action_type: AgentActionType::WebAccess,
1621 target: "https://any.url".to_string(),
1622 timestamp: Instant::now(),
1623 session_id: "s1".to_string(),
1624 risk_score: 0.0,
1625 };
1626 assert!(step_matches(&step, &action));
1627 }
1628
1629 #[test]
1634 fn test_invalid_regex_pattern_rejected() {
1635 let pattern = AttackPattern {
1636 name: "bad_regex".to_string(),
1637 description: "Pattern with invalid regex".to_string(),
1638 steps: vec![PatternStep {
1639 action_type: None,
1640 target_pattern: Some("[invalid(regex".to_string()),
1641 min_risk: 0.0,
1642 }],
1643 max_time_window: Duration::from_secs(60),
1644 severity: SecuritySeverity::Low,
1645 confidence: 0.5,
1646 };
1647
1648 let compiled = compile_pattern(pattern);
1649 assert!(
1650 compiled.is_none(),
1651 "Invalid regex should cause pattern compilation to fail"
1652 );
1653 }
1654
1655 #[test]
1660 fn test_total_risk_capped_at_one() {
1661 let matches = vec![
1662 PatternMatch {
1663 pattern_name: "a".to_string(),
1664 matched_actions: vec![0],
1665 confidence: 1.0,
1666 severity: SecuritySeverity::Critical,
1667 time_span: Duration::ZERO,
1668 },
1669 PatternMatch {
1670 pattern_name: "b".to_string(),
1671 matched_actions: vec![0],
1672 confidence: 1.0,
1673 severity: SecuritySeverity::Critical,
1674 time_span: Duration::ZERO,
1675 },
1676 ];
1677
1678 let risk = compute_total_risk(&matches, &None, &None);
1679 assert!(
1680 (risk - 1.0).abs() < f64::EPSILON,
1681 "Risk should be capped at 1.0"
1682 );
1683 }
1684}