1use chrono::{DateTime, Utc};
30use llmtrace_core::{SecurityFinding, SecuritySeverity};
31use serde::{Deserialize, Serialize};
32use std::collections::HashMap;
33
34#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
40pub enum ThreatCategory {
41 InjectionDirect,
42 InjectionIndirect,
43 Jailbreak,
44 PiiLeak,
45 ToxicContent,
46 DataExfiltration,
47 PromptExtraction,
48 CodeExecution,
49 Benign,
50}
51
52impl ThreatCategory {
53 #[must_use]
56 pub fn as_finding_type(&self) -> &'static str {
57 match self {
58 Self::InjectionDirect => "prompt_injection_direct",
59 Self::InjectionIndirect => "prompt_injection_indirect",
60 Self::Jailbreak => "jailbreak",
61 Self::PiiLeak => "pii_leak",
62 Self::ToxicContent => "toxic_content",
63 Self::DataExfiltration => "data_exfiltration",
64 Self::PromptExtraction => "prompt_extraction",
65 Self::CodeExecution => "code_execution",
66 Self::Benign => "benign",
67 }
68 }
69
70 #[must_use]
72 pub fn is_threat(&self) -> bool {
73 *self != Self::Benign
74 }
75}
76
77#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
79pub enum DetectorType {
80 Classifier,
81 Heuristic,
82 Semantic,
83 LlmJudge,
84 Ensemble,
85 Canary,
86}
87
88#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
90pub enum AggregationStrategy {
91 MajorityVote,
93 WeightedVote,
95 Conservative,
97 Permissive,
99 Cascade,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct DetectorResult {
110 pub detector_name: String,
111 pub detector_type: DetectorType,
112 pub threat_category: ThreatCategory,
113 pub severity: SecuritySeverity,
114 pub confidence: f64,
115 pub raw_output: Option<String>,
116 pub metadata: HashMap<String, String>,
117 pub latency_ms: Option<u64>,
118}
119
120impl DetectorResult {
121 #[must_use]
123 pub fn new(name: &str, detector_type: DetectorType) -> Self {
124 Self {
125 detector_name: name.to_string(),
126 detector_type,
127 threat_category: ThreatCategory::Benign,
128 severity: SecuritySeverity::Info,
129 confidence: 0.0,
130 raw_output: None,
131 metadata: HashMap::new(),
132 latency_ms: None,
133 }
134 }
135
136 #[must_use]
137 pub fn with_threat(mut self, cat: ThreatCategory) -> Self {
138 self.threat_category = cat;
139 self
140 }
141
142 #[must_use]
143 pub fn with_confidence(mut self, c: f64) -> Self {
144 self.confidence = c;
145 self
146 }
147
148 #[must_use]
149 pub fn with_severity(mut self, s: SecuritySeverity) -> Self {
150 self.severity = s;
151 self
152 }
153
154 #[must_use]
155 pub fn with_raw_output(mut self, raw: String) -> Self {
156 self.raw_output = Some(raw);
157 self
158 }
159
160 #[must_use]
161 pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
162 self.metadata.insert(key.to_string(), value.to_string());
163 self
164 }
165
166 #[must_use]
167 pub fn with_latency_ms(mut self, ms: u64) -> Self {
168 self.latency_ms = Some(ms);
169 self
170 }
171}
172
173impl From<&DetectorResult> for SecurityFinding {
174 fn from(dr: &DetectorResult) -> Self {
175 let desc = format!(
176 "[{}] detected {} (confidence {:.2})",
177 dr.detector_name,
178 dr.threat_category.as_finding_type(),
179 dr.confidence,
180 );
181 let mut finding = SecurityFinding::new(
182 dr.severity.clone(),
183 dr.threat_category.as_finding_type().to_string(),
184 desc,
185 dr.confidence,
186 );
187 for (k, v) in &dr.metadata {
188 finding = finding.with_metadata(k.clone(), v.clone());
189 }
190 finding = finding.with_metadata("detector_name".to_string(), dr.detector_name.clone());
191 finding
192 }
193}
194
195#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct AggregatedResult {
202 pub threat_category: ThreatCategory,
203 pub confidence: f64,
204 pub severity: SecuritySeverity,
205 pub contributing_detectors: Vec<String>,
206 pub strategy_used: AggregationStrategy,
207}
208
209#[derive(Debug, Clone)]
215pub struct ResultAggregator {
216 pub strategy: AggregationStrategy,
217 pub confidence_threshold: f64,
218 pub detector_weights: HashMap<String, f64>,
219}
220
221impl ResultAggregator {
222 #[must_use]
223 pub fn new(strategy: AggregationStrategy) -> Self {
224 Self {
225 strategy,
226 confidence_threshold: 0.5,
227 detector_weights: HashMap::new(),
228 }
229 }
230
231 #[must_use]
232 pub fn with_weights(strategy: AggregationStrategy, weights: HashMap<String, f64>) -> Self {
233 Self {
234 strategy,
235 confidence_threshold: 0.5,
236 detector_weights: weights,
237 }
238 }
239
240 #[must_use]
241 pub fn with_threshold(mut self, threshold: f64) -> Self {
242 self.confidence_threshold = threshold;
243 self
244 }
245
246 #[must_use]
248 pub fn aggregate(&self, results: &[DetectorResult]) -> AggregatedResult {
249 match self.strategy {
250 AggregationStrategy::MajorityVote => self.aggregate_majority_vote(results),
251 AggregationStrategy::WeightedVote => self.aggregate_weighted_vote(results),
252 AggregationStrategy::Conservative => self.aggregate_conservative(results),
253 AggregationStrategy::Permissive => self.aggregate_permissive(results),
254 AggregationStrategy::Cascade => self.aggregate_cascade(results),
255 }
256 }
257
258 #[must_use]
261 pub fn aggregate_majority_vote(&self, results: &[DetectorResult]) -> AggregatedResult {
262 if results.is_empty() {
263 return benign_aggregated(AggregationStrategy::MajorityVote);
264 }
265
266 let eligible: Vec<&DetectorResult> = results
267 .iter()
268 .filter(|r| r.confidence >= self.confidence_threshold)
269 .collect();
270
271 if eligible.is_empty() {
272 return benign_aggregated(AggregationStrategy::MajorityVote);
273 }
274
275 let mut votes: HashMap<&ThreatCategory, (usize, f64, &SecuritySeverity, Vec<&str>)> =
276 HashMap::new();
277 for r in &eligible {
278 let entry = votes.entry(&r.threat_category).or_insert((
279 0,
280 0.0,
281 &SecuritySeverity::Info,
282 Vec::new(),
283 ));
284 entry.0 += 1;
285 entry.1 += r.confidence;
286 if r.severity > *entry.2 {
287 entry.2 = &r.severity;
288 }
289 entry.3.push(&r.detector_name);
290 }
291
292 let winner = votes
293 .iter()
294 .max_by_key(|(_, (count, _, _, _))| *count)
295 .unwrap();
296
297 let (cat, (count, conf_sum, sev, names)) = winner;
298 AggregatedResult {
299 threat_category: (*cat).clone(),
300 confidence: conf_sum / *count as f64,
301 severity: (*sev).clone(),
302 contributing_detectors: names.iter().map(|s| (*s).to_string()).collect(),
303 strategy_used: AggregationStrategy::MajorityVote,
304 }
305 }
306
307 #[must_use]
310 pub fn aggregate_weighted_vote(&self, results: &[DetectorResult]) -> AggregatedResult {
311 if results.is_empty() {
312 return benign_aggregated(AggregationStrategy::WeightedVote);
313 }
314
315 let mut scores: HashMap<&ThreatCategory, (f64, &SecuritySeverity, Vec<&str>)> =
316 HashMap::new();
317 for r in results {
318 let w = self
319 .detector_weights
320 .get(&r.detector_name)
321 .copied()
322 .unwrap_or(1.0);
323 let entry = scores.entry(&r.threat_category).or_insert((
324 0.0,
325 &SecuritySeverity::Info,
326 Vec::new(),
327 ));
328 entry.0 += r.confidence * w;
329 if r.severity > *entry.1 {
330 entry.1 = &r.severity;
331 }
332 entry.2.push(&r.detector_name);
333 }
334
335 let winner = scores
336 .iter()
337 .max_by(|a, b| a.1 .0.partial_cmp(&b.1 .0).unwrap())
338 .unwrap();
339
340 let (cat, (score, sev, names)) = winner;
341 let total_weight: f64 = names
342 .iter()
343 .map(|n| self.detector_weights.get(*n).copied().unwrap_or(1.0))
344 .sum();
345 let avg_conf = if total_weight > 0.0 {
346 score / total_weight
347 } else {
348 0.0
349 };
350
351 AggregatedResult {
352 threat_category: (*cat).clone(),
353 confidence: avg_conf,
354 severity: (*sev).clone(),
355 contributing_detectors: names.iter().map(|s| (*s).to_string()).collect(),
356 strategy_used: AggregationStrategy::WeightedVote,
357 }
358 }
359
360 #[must_use]
363 pub fn aggregate_conservative(&self, results: &[DetectorResult]) -> AggregatedResult {
364 if results.is_empty() {
365 return benign_aggregated(AggregationStrategy::Conservative);
366 }
367
368 let threats: Vec<&DetectorResult> = results
369 .iter()
370 .filter(|r| r.threat_category.is_threat() && r.confidence >= self.confidence_threshold)
371 .collect();
372
373 if threats.is_empty() {
374 return AggregatedResult {
375 threat_category: ThreatCategory::Benign,
376 confidence: results.iter().map(|r| 1.0 - r.confidence).product(),
377 severity: SecuritySeverity::Info,
378 contributing_detectors: Vec::new(),
379 strategy_used: AggregationStrategy::Conservative,
380 };
381 }
382
383 let worst = threats
384 .iter()
385 .max_by(|a, b| a.severity.cmp(&b.severity))
386 .unwrap();
387
388 AggregatedResult {
389 threat_category: worst.threat_category.clone(),
390 confidence: threats.iter().map(|r| r.confidence).fold(0.0_f64, f64::max),
391 severity: worst.severity.clone(),
392 contributing_detectors: threats.iter().map(|r| r.detector_name.clone()).collect(),
393 strategy_used: AggregationStrategy::Conservative,
394 }
395 }
396
397 #[must_use]
400 pub fn aggregate_permissive(&self, results: &[DetectorResult]) -> AggregatedResult {
401 if results.is_empty() {
402 return benign_aggregated(AggregationStrategy::Permissive);
403 }
404
405 let eligible: Vec<&DetectorResult> = results
406 .iter()
407 .filter(|r| r.confidence >= self.confidence_threshold)
408 .collect();
409
410 if eligible.is_empty() {
411 return benign_aggregated(AggregationStrategy::Permissive);
412 }
413
414 let first_cat = &eligible[0].threat_category;
415 let all_agree = eligible.iter().all(|r| r.threat_category == *first_cat);
416
417 if !all_agree || !first_cat.is_threat() {
418 return AggregatedResult {
419 threat_category: ThreatCategory::Benign,
420 confidence: eligible.iter().map(|r| r.confidence).sum::<f64>()
421 / eligible.len() as f64,
422 severity: SecuritySeverity::Info,
423 contributing_detectors: Vec::new(),
424 strategy_used: AggregationStrategy::Permissive,
425 };
426 }
427
428 let max_sev = eligible.iter().map(|r| &r.severity).max().unwrap();
429
430 AggregatedResult {
431 threat_category: first_cat.clone(),
432 confidence: eligible.iter().map(|r| r.confidence).sum::<f64>() / eligible.len() as f64,
433 severity: max_sev.clone(),
434 contributing_detectors: eligible.iter().map(|r| r.detector_name.clone()).collect(),
435 strategy_used: AggregationStrategy::Permissive,
436 }
437 }
438
439 #[must_use]
442 pub fn aggregate_cascade(&self, results: &[DetectorResult]) -> AggregatedResult {
443 for r in results {
444 if r.confidence >= self.confidence_threshold {
445 return AggregatedResult {
446 threat_category: r.threat_category.clone(),
447 confidence: r.confidence,
448 severity: r.severity.clone(),
449 contributing_detectors: vec![r.detector_name.clone()],
450 strategy_used: AggregationStrategy::Cascade,
451 };
452 }
453 }
454 benign_aggregated(AggregationStrategy::Cascade)
455 }
456}
457
458#[must_use]
460fn benign_aggregated(strategy: AggregationStrategy) -> AggregatedResult {
461 AggregatedResult {
462 threat_category: ThreatCategory::Benign,
463 confidence: 0.0,
464 severity: SecuritySeverity::Info,
465 contributing_detectors: Vec::new(),
466 strategy_used: strategy,
467 }
468}
469
470#[derive(Debug, Clone, Serialize, Deserialize)]
476pub struct ScanResult {
477 pub input_hash: String,
478 pub detector_results: Vec<DetectorResult>,
479 pub aggregate_threat: ThreatCategory,
480 pub aggregate_confidence: f64,
481 pub aggregate_severity: SecuritySeverity,
482 pub scan_duration_ms: u64,
483 pub timestamp: DateTime<Utc>,
484}
485
486impl ScanResult {
487 #[must_use]
490 pub fn to_security_findings(&self) -> Vec<SecurityFinding> {
491 self.detector_results
492 .iter()
493 .filter(|r| r.threat_category.is_threat())
494 .map(SecurityFinding::from)
495 .collect()
496 }
497}
498
499pub struct ScanResultBuilder {
501 input_hash: String,
502 results: Vec<DetectorResult>,
503 start: std::time::Instant,
504}
505
506impl ScanResultBuilder {
507 #[must_use]
509 pub fn new(input_text: &str) -> Self {
510 Self {
511 input_hash: compute_input_hash(input_text),
512 results: Vec::new(),
513 start: std::time::Instant::now(),
514 }
515 }
516
517 pub fn add_result(&mut self, result: DetectorResult) -> &mut Self {
519 self.results.push(result);
520 self
521 }
522
523 #[must_use]
525 pub fn build(self, aggregator: &ResultAggregator) -> ScanResult {
526 let agg = aggregator.aggregate(&self.results);
527 let duration = self.start.elapsed().as_millis() as u64;
528 ScanResult {
529 input_hash: self.input_hash,
530 detector_results: self.results,
531 aggregate_threat: agg.threat_category,
532 aggregate_confidence: agg.confidence,
533 aggregate_severity: agg.severity,
534 scan_duration_ms: duration,
535 timestamp: Utc::now(),
536 }
537 }
538}
539
/// Fingerprints `input` as 16 lowercase hex digits.
///
/// The hash is a polynomial (multiplier 31) over the UTF-8 bytes with wrapping `u64`
/// arithmetic; the empty string hashes to all zeros. It is NOT collision-resistant: for
/// example, `"Aa"` and `"BB"` produce the same value. Do not rely on it where an attacker
/// could pick colliding inputs.
#[must_use]
pub fn compute_input_hash(input: &str) -> String {
    let mut h: u64 = 0;
    for &byte in input.as_bytes() {
        h = h.wrapping_mul(31).wrapping_add(u64::from(byte));
    }
    format!("{h:016x}")
}
553
#[cfg(test)]
mod tests {
    use super::*;
    use super::{
        AggregationStrategy as Strategy, DetectorType as Kind, SecuritySeverity as Sev,
        ThreatCategory as Threat,
    };

    /// Builds a detector result with every verdict field set explicitly.
    fn det(name: &str, kind: Kind, threat: Threat, confidence: f64, severity: Sev) -> DetectorResult {
        DetectorResult::new(name, kind)
            .with_threat(threat)
            .with_confidence(confidence)
            .with_severity(severity)
    }

    /// Every aggregation strategy, in declaration order.
    fn all_strategies() -> [Strategy; 5] {
        [
            Strategy::MajorityVote,
            Strategy::WeightedVote,
            Strategy::Conservative,
            Strategy::Permissive,
            Strategy::Cascade,
        ]
    }

    /// Runs `strategy` over an empty result set.
    fn aggregate_empty(strategy: Strategy) -> AggregatedResult {
        ResultAggregator::new(strategy).aggregate(&[])
    }

    // ---- ThreatCategory ------------------------------------------------------------------

    #[test]
    fn threat_category_serde_roundtrip() {
        let every_category = [
            Threat::InjectionDirect,
            Threat::InjectionIndirect,
            Threat::Jailbreak,
            Threat::PiiLeak,
            Threat::ToxicContent,
            Threat::DataExfiltration,
            Threat::PromptExtraction,
            Threat::CodeExecution,
            Threat::Benign,
        ];
        for category in &every_category {
            let encoded = serde_json::to_string(category).unwrap();
            let decoded: Threat = serde_json::from_str(&encoded).unwrap();
            assert_eq!(*category, decoded);
        }
    }

    #[test]
    fn threat_category_is_threat() {
        assert!(Threat::InjectionDirect.is_threat());
        assert!(Threat::Jailbreak.is_threat());
        assert!(!Threat::Benign.is_threat());
    }

    #[test]
    fn threat_category_finding_type_mapping() {
        assert_eq!(Threat::InjectionDirect.as_finding_type(), "prompt_injection_direct");
        assert_eq!(Threat::Benign.as_finding_type(), "benign");
    }

    // ---- DetectorResult ------------------------------------------------------------------

    #[test]
    fn detector_result_construction_and_metadata() {
        let result = det("test-det", Kind::Classifier, Threat::Jailbreak, 0.95, Sev::High)
            .with_raw_output("raw output".to_string())
            .with_metadata("model", "v2")
            .with_latency_ms(42);

        assert_eq!(result.detector_name, "test-det");
        assert_eq!(result.detector_type, Kind::Classifier);
        assert_eq!(result.threat_category, Threat::Jailbreak);
        assert_eq!(result.severity, Sev::High);
        assert!((result.confidence - 0.95).abs() < f64::EPSILON);
        assert_eq!(result.raw_output.as_deref(), Some("raw output"));
        assert_eq!(result.metadata.get("model").unwrap(), "v2");
        assert_eq!(result.latency_ms, Some(42));
    }

    #[test]
    fn detector_result_serde_roundtrip() {
        let original = det("d1", Kind::Semantic, Threat::PiiLeak, 0.77, Sev::Medium);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: DetectorResult = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.detector_name, "d1");
        assert_eq!(decoded.threat_category, Threat::PiiLeak);
        assert!((decoded.confidence - 0.77).abs() < f64::EPSILON);
    }

    #[test]
    fn detector_result_to_security_finding() {
        let result = det("ig", Kind::Classifier, Threat::InjectionDirect, 0.88, Sev::Critical)
            .with_metadata("source", "unit-test");

        let finding: SecurityFinding = (&result).into();

        assert_eq!(finding.severity, Sev::Critical);
        assert_eq!(finding.finding_type, "prompt_injection_direct");
        assert!((finding.confidence_score - 0.88).abs() < f64::EPSILON);
        assert_eq!(finding.metadata.get("detector_name").unwrap(), "ig");
        assert_eq!(finding.metadata.get("source").unwrap(), "unit-test");
        assert!(finding.requires_alert);
    }

    // ---- compute_input_hash --------------------------------------------------------------

    #[test]
    fn input_hash_deterministic() {
        let first = compute_input_hash("hello world");
        let second = compute_input_hash("hello world");
        assert_eq!(first, second);
        assert_eq!(first.len(), 16);
    }

    #[test]
    fn input_hash_differs_for_different_inputs() {
        assert_ne!(compute_input_hash("input A"), compute_input_hash("input B"));
    }

    #[test]
    fn input_hash_empty_string() {
        let hash = compute_input_hash("");
        assert_eq!(hash.len(), 16);
        assert_eq!(hash, "0000000000000000");
    }

    // ---- MajorityVote --------------------------------------------------------------------

    #[test]
    fn majority_vote_single_threat_wins() {
        let results = [
            det("a", Kind::Classifier, Threat::Jailbreak, 0.9, Sev::High),
            det("b", Kind::Heuristic, Threat::Jailbreak, 0.8, Sev::Medium),
            det("c", Kind::Semantic, Threat::Benign, 0.7, Sev::Info),
        ];

        let out = ResultAggregator::new(Strategy::MajorityVote).aggregate(&results);

        assert_eq!(out.threat_category, Threat::Jailbreak);
        assert_eq!(out.strategy_used, Strategy::MajorityVote);
        assert_eq!(out.contributing_detectors.len(), 2);
        assert_eq!(out.severity, Sev::High);
    }

    #[test]
    fn majority_vote_all_benign() {
        let results = [
            det("a", Kind::Classifier, Threat::Benign, 0.9, Sev::Info),
            det("b", Kind::Heuristic, Threat::Benign, 0.85, Sev::Info),
        ];

        let out = ResultAggregator::new(Strategy::MajorityVote).aggregate(&results);
        assert_eq!(out.threat_category, Threat::Benign);
    }

    #[test]
    fn majority_vote_below_threshold_ignored() {
        let results = [
            det("a", Kind::Classifier, Threat::Jailbreak, 0.3, Sev::High),
            det("b", Kind::Heuristic, Threat::Benign, 0.8, Sev::Info),
        ];

        let aggregator = ResultAggregator::new(Strategy::MajorityVote).with_threshold(0.5);
        assert_eq!(aggregator.aggregate(&results).threat_category, Threat::Benign);
    }

    #[test]
    fn majority_vote_mixed_threats() {
        let results = [
            det("a", Kind::Classifier, Threat::InjectionDirect, 0.9, Sev::High),
            det("b", Kind::Heuristic, Threat::Jailbreak, 0.8, Sev::Medium),
            det("c", Kind::LlmJudge, Threat::InjectionDirect, 0.85, Sev::Critical),
        ];

        let out = ResultAggregator::new(Strategy::MajorityVote).aggregate(&results);

        assert_eq!(out.threat_category, Threat::InjectionDirect);
        assert_eq!(out.severity, Sev::Critical);
        assert_eq!(out.contributing_detectors.len(), 2);
    }

    // ---- WeightedVote --------------------------------------------------------------------

    #[test]
    fn weighted_vote_respects_weights() {
        let weights: HashMap<String, f64> = [("trusted", 5.0), ("weak", 1.0)]
            .iter()
            .map(|&(name, weight)| (name.to_string(), weight))
            .collect();
        let results = [
            det("trusted", Kind::Ensemble, Threat::InjectionDirect, 0.7, Sev::High),
            det("weak", Kind::Heuristic, Threat::Benign, 0.9, Sev::Info),
        ];

        let aggregator = ResultAggregator::with_weights(Strategy::WeightedVote, weights);
        assert_eq!(aggregator.aggregate(&results).threat_category, Threat::InjectionDirect);
    }

    #[test]
    fn weighted_vote_default_weight_is_one() {
        let results = [
            det("a", Kind::Classifier, Threat::Jailbreak, 0.8, Sev::High),
            det("b", Kind::Heuristic, Threat::Jailbreak, 0.7, Sev::Medium),
        ];

        let out = ResultAggregator::new(Strategy::WeightedVote).aggregate(&results);

        assert_eq!(out.threat_category, Threat::Jailbreak);
        assert!((out.confidence - 0.75).abs() < f64::EPSILON);
    }

    // ---- Conservative --------------------------------------------------------------------

    #[test]
    fn conservative_any_threat_flags() {
        let results = [
            det("a", Kind::Classifier, Threat::Benign, 0.95, Sev::Info),
            det("b", Kind::Canary, Threat::PromptExtraction, 0.6, Sev::Critical),
        ];

        let out = ResultAggregator::new(Strategy::Conservative).aggregate(&results);

        assert_eq!(out.threat_category, Threat::PromptExtraction);
        assert_eq!(out.severity, Sev::Critical);
        assert_eq!(out.contributing_detectors, vec!["b"]);
    }

    #[test]
    fn conservative_all_benign_returns_benign() {
        let results = [det("a", Kind::Classifier, Threat::Benign, 0.99, Sev::Info)];

        let out = ResultAggregator::new(Strategy::Conservative).aggregate(&results);
        assert_eq!(out.threat_category, Threat::Benign);
    }

    #[test]
    fn conservative_below_threshold_ignored() {
        let results = [det("a", Kind::Classifier, Threat::Jailbreak, 0.3, Sev::High)];

        let aggregator = ResultAggregator::new(Strategy::Conservative).with_threshold(0.5);
        assert_eq!(aggregator.aggregate(&results).threat_category, Threat::Benign);
    }

    // ---- Permissive ----------------------------------------------------------------------

    #[test]
    fn permissive_all_agree_on_threat() {
        let results = [
            det("a", Kind::Classifier, Threat::DataExfiltration, 0.8, Sev::High),
            det("b", Kind::Semantic, Threat::DataExfiltration, 0.75, Sev::Medium),
        ];

        let out = ResultAggregator::new(Strategy::Permissive).aggregate(&results);

        assert_eq!(out.threat_category, Threat::DataExfiltration);
        assert_eq!(out.contributing_detectors.len(), 2);
    }

    #[test]
    fn permissive_disagreement_returns_benign() {
        let results = [
            det("a", Kind::Classifier, Threat::Jailbreak, 0.9, Sev::High),
            det("b", Kind::Heuristic, Threat::InjectionDirect, 0.85, Sev::High),
        ];

        let out = ResultAggregator::new(Strategy::Permissive).aggregate(&results);

        assert_eq!(out.threat_category, Threat::Benign);
        assert!(out.contributing_detectors.is_empty());
    }

    #[test]
    fn permissive_all_agree_benign() {
        let results = [
            det("a", Kind::Classifier, Threat::Benign, 0.9, Sev::Info),
            det("b", Kind::Heuristic, Threat::Benign, 0.85, Sev::Info),
        ];

        let out = ResultAggregator::new(Strategy::Permissive).aggregate(&results);
        assert_eq!(out.threat_category, Threat::Benign);
    }

    // ---- Cascade -------------------------------------------------------------------------

    #[test]
    fn cascade_first_high_confidence_wins() {
        let results = [
            det("fast", Kind::Heuristic, Threat::ToxicContent, 0.3, Sev::Low),
            det("accurate", Kind::Classifier, Threat::InjectionDirect, 0.85, Sev::High),
            det("slow", Kind::LlmJudge, Threat::Jailbreak, 0.99, Sev::Critical),
        ];

        let aggregator = ResultAggregator::new(Strategy::Cascade).with_threshold(0.5);
        let out = aggregator.aggregate(&results);

        assert_eq!(out.threat_category, Threat::InjectionDirect);
        assert_eq!(out.contributing_detectors, vec!["accurate"]);
    }

    #[test]
    fn cascade_none_above_threshold_returns_benign() {
        let results = [det("a", Kind::Heuristic, Threat::Jailbreak, 0.2, Sev::High)];

        let aggregator = ResultAggregator::new(Strategy::Cascade).with_threshold(0.5);
        assert_eq!(aggregator.aggregate(&results).threat_category, Threat::Benign);
    }

    // ---- Empty input ---------------------------------------------------------------------

    #[test]
    fn empty_results_majority_vote() {
        let out = aggregate_empty(Strategy::MajorityVote);
        assert_eq!(out.threat_category, Threat::Benign);
        assert!((out.confidence - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn empty_results_conservative() {
        assert_eq!(aggregate_empty(Strategy::Conservative).threat_category, Threat::Benign);
    }

    #[test]
    fn empty_results_permissive() {
        assert_eq!(aggregate_empty(Strategy::Permissive).threat_category, Threat::Benign);
    }

    #[test]
    fn empty_results_cascade() {
        assert_eq!(aggregate_empty(Strategy::Cascade).threat_category, Threat::Benign);
    }

    #[test]
    fn empty_results_weighted_vote() {
        assert_eq!(aggregate_empty(Strategy::WeightedVote).threat_category, Threat::Benign);
    }

    // ---- ScanResult / ScanResultBuilder --------------------------------------------------

    #[test]
    fn scan_result_builder_basic() {
        let mut builder = ScanResultBuilder::new("test input");
        builder
            .add_result(det("d1", Kind::Classifier, Threat::Jailbreak, 0.9, Sev::High))
            .add_result(det("d2", Kind::Heuristic, Threat::Benign, 0.7, Sev::Info));

        let scan = builder.build(&ResultAggregator::new(Strategy::MajorityVote));

        assert_eq!(scan.input_hash, compute_input_hash("test input"));
        assert_eq!(scan.detector_results.len(), 2);
        assert!(scan.timestamp <= Utc::now());
    }

    #[test]
    fn scan_result_to_security_findings_excludes_benign() {
        let mut builder = ScanResultBuilder::new("probe");
        builder
            .add_result(det("d1", Kind::Classifier, Threat::InjectionDirect, 0.88, Sev::High))
            .add_result(det("d2", Kind::Heuristic, Threat::Benign, 0.7, Sev::Info));

        let findings = builder
            .build(&ResultAggregator::new(Strategy::Conservative))
            .to_security_findings();

        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].finding_type, "prompt_injection_direct");
    }

    // ---- Cross-strategy checks -----------------------------------------------------------

    #[test]
    fn single_detector_all_strategies() {
        let result = det("sole", Kind::Classifier, Threat::CodeExecution, 0.91, Sev::Critical);

        for strategy in &all_strategies() {
            let aggregator = ResultAggregator::new(strategy.clone());
            let out = aggregator.aggregate(std::slice::from_ref(&result));
            assert_eq!(out.threat_category, Threat::CodeExecution);
            assert_eq!(out.severity, Sev::Critical);
        }
    }

    #[test]
    fn custom_threshold_enforcement() {
        let result = det("d", Kind::Classifier, Threat::Jailbreak, 0.6, Sev::High);

        let strict = ResultAggregator::new(Strategy::Cascade).with_threshold(0.7);
        let strict_out = strict.aggregate(std::slice::from_ref(&result));
        assert_eq!(strict_out.threat_category, Threat::Benign);

        let lenient = ResultAggregator::new(Strategy::Cascade).with_threshold(0.5);
        let lenient_out = lenient.aggregate(&[result]);
        assert_eq!(lenient_out.threat_category, Threat::Jailbreak);
    }

    #[test]
    fn scan_result_serde_roundtrip() {
        let mut builder = ScanResultBuilder::new("serialize me");
        builder.add_result(det("d1", Kind::Canary, Threat::PromptExtraction, 0.99, Sev::Critical));
        let scan = builder.build(&ResultAggregator::new(Strategy::Conservative));

        let encoded = serde_json::to_string(&scan).unwrap();
        let decoded: ScanResult = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.input_hash, scan.input_hash);
        assert_eq!(decoded.detector_results.len(), 1);
        assert_eq!(decoded.aggregate_threat, Threat::PromptExtraction);
    }

    #[test]
    fn aggregated_result_has_correct_strategy() {
        let result = det("d", Kind::Classifier, Threat::Benign, 0.9, Sev::Info);

        for strategy in &all_strategies() {
            let aggregator = ResultAggregator::new(strategy.clone());
            let out = aggregator.aggregate(std::slice::from_ref(&result));
            assert_eq!(out.strategy_used, *strategy);
        }
    }
}