Skip to main content

llmtrace_security/
result_parser.rs

1//! Standardized tool result parsing (R-AS-01).
2//!
3//! Diverse security detectors (classifiers, heuristics, LLM judges, etc.)
4//! produce wildly different output formats. This module normalises those
5//! outputs into a unified [`DetectorResult`] schema, then aggregates multiple
6//! detector results into a single [`ScanResult`] using pluggable
7//! [`AggregationStrategy`] policies.
8//!
9//! # Quick start
10//!
11//! ```
12//! use llmtrace_security::result_parser::*;
13//! use llmtrace_core::SecuritySeverity;
14//!
15//! let r1 = DetectorResult::new("injecguard", DetectorType::Classifier)
16//!     .with_threat(ThreatCategory::InjectionDirect)
17//!     .with_confidence(0.92)
18//!     .with_severity(SecuritySeverity::High);
19//!
20//! let r2 = DetectorResult::new("heuristic-v1", DetectorType::Heuristic)
21//!     .with_threat(ThreatCategory::Benign)
22//!     .with_confidence(0.85)
23//!     .with_severity(SecuritySeverity::Info);
24//!
25//! let aggregator = ResultAggregator::new(AggregationStrategy::MajorityVote);
26//! let agg = aggregator.aggregate(&[r1, r2]);
27//! ```
28
29use chrono::{DateTime, Utc};
30use llmtrace_core::{SecurityFinding, SecuritySeverity};
31use serde::{Deserialize, Serialize};
32use std::collections::HashMap;
33
34// ---------------------------------------------------------------------------
35// Enums
36// ---------------------------------------------------------------------------
37
38/// Threat classification for a scanned input.
39#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
40pub enum ThreatCategory {
41    InjectionDirect,
42    InjectionIndirect,
43    Jailbreak,
44    PiiLeak,
45    ToxicContent,
46    DataExfiltration,
47    PromptExtraction,
48    CodeExecution,
49    Benign,
50}
51
52impl ThreatCategory {
53    /// Map the category to a human-readable finding_type string for
54    /// `SecurityFinding`.
55    #[must_use]
56    pub fn as_finding_type(&self) -> &'static str {
57        match self {
58            Self::InjectionDirect => "prompt_injection_direct",
59            Self::InjectionIndirect => "prompt_injection_indirect",
60            Self::Jailbreak => "jailbreak",
61            Self::PiiLeak => "pii_leak",
62            Self::ToxicContent => "toxic_content",
63            Self::DataExfiltration => "data_exfiltration",
64            Self::PromptExtraction => "prompt_extraction",
65            Self::CodeExecution => "code_execution",
66            Self::Benign => "benign",
67        }
68    }
69
70    /// True when the category represents an actual threat (not benign).
71    #[must_use]
72    pub fn is_threat(&self) -> bool {
73        *self != Self::Benign
74    }
75}
76
77/// Kind of security detector that produced a result.
78#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
79pub enum DetectorType {
80    Classifier,
81    Heuristic,
82    Semantic,
83    LlmJudge,
84    Ensemble,
85    Canary,
86}
87
88/// Strategy for combining multiple detector results.
89#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
90pub enum AggregationStrategy {
91    /// Simple majority of detectors determines outcome.
92    MajorityVote,
93    /// Detectors contribute according to per-detector weights.
94    WeightedVote,
95    /// Any single detector flagging a threat causes a threat verdict.
96    Conservative,
97    /// All detectors must agree on the same threat to flag it.
98    Permissive,
99    /// First result above the confidence threshold wins.
100    Cascade,
101}
102
103// ---------------------------------------------------------------------------
104// DetectorResult
105// ---------------------------------------------------------------------------
106
107/// Normalised output from a single security detector.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct DetectorResult {
110    pub detector_name: String,
111    pub detector_type: DetectorType,
112    pub threat_category: ThreatCategory,
113    pub severity: SecuritySeverity,
114    pub confidence: f64,
115    pub raw_output: Option<String>,
116    pub metadata: HashMap<String, String>,
117    pub latency_ms: Option<u64>,
118}
119
120impl DetectorResult {
121    /// Start building a result for the given detector.
122    #[must_use]
123    pub fn new(name: &str, detector_type: DetectorType) -> Self {
124        Self {
125            detector_name: name.to_string(),
126            detector_type,
127            threat_category: ThreatCategory::Benign,
128            severity: SecuritySeverity::Info,
129            confidence: 0.0,
130            raw_output: None,
131            metadata: HashMap::new(),
132            latency_ms: None,
133        }
134    }
135
136    #[must_use]
137    pub fn with_threat(mut self, cat: ThreatCategory) -> Self {
138        self.threat_category = cat;
139        self
140    }
141
142    #[must_use]
143    pub fn with_confidence(mut self, c: f64) -> Self {
144        self.confidence = c;
145        self
146    }
147
148    #[must_use]
149    pub fn with_severity(mut self, s: SecuritySeverity) -> Self {
150        self.severity = s;
151        self
152    }
153
154    #[must_use]
155    pub fn with_raw_output(mut self, raw: String) -> Self {
156        self.raw_output = Some(raw);
157        self
158    }
159
160    #[must_use]
161    pub fn with_metadata(mut self, key: &str, value: &str) -> Self {
162        self.metadata.insert(key.to_string(), value.to_string());
163        self
164    }
165
166    #[must_use]
167    pub fn with_latency_ms(mut self, ms: u64) -> Self {
168        self.latency_ms = Some(ms);
169        self
170    }
171}
172
173impl From<&DetectorResult> for SecurityFinding {
174    fn from(dr: &DetectorResult) -> Self {
175        let desc = format!(
176            "[{}] detected {} (confidence {:.2})",
177            dr.detector_name,
178            dr.threat_category.as_finding_type(),
179            dr.confidence,
180        );
181        let mut finding = SecurityFinding::new(
182            dr.severity.clone(),
183            dr.threat_category.as_finding_type().to_string(),
184            desc,
185            dr.confidence,
186        );
187        for (k, v) in &dr.metadata {
188            finding = finding.with_metadata(k.clone(), v.clone());
189        }
190        finding = finding.with_metadata("detector_name".to_string(), dr.detector_name.clone());
191        finding
192    }
193}
194
195// ---------------------------------------------------------------------------
196// AggregatedResult
197// ---------------------------------------------------------------------------
198
199/// The outcome of combining multiple detector results.
200#[derive(Debug, Clone, Serialize, Deserialize)]
201pub struct AggregatedResult {
202    pub threat_category: ThreatCategory,
203    pub confidence: f64,
204    pub severity: SecuritySeverity,
205    pub contributing_detectors: Vec<String>,
206    pub strategy_used: AggregationStrategy,
207}
208
209// ---------------------------------------------------------------------------
210// ResultAggregator
211// ---------------------------------------------------------------------------
212
213/// Combines multiple [`DetectorResult`]s into a single verdict.
214#[derive(Debug, Clone)]
215pub struct ResultAggregator {
216    pub strategy: AggregationStrategy,
217    pub confidence_threshold: f64,
218    pub detector_weights: HashMap<String, f64>,
219}
220
221impl ResultAggregator {
222    #[must_use]
223    pub fn new(strategy: AggregationStrategy) -> Self {
224        Self {
225            strategy,
226            confidence_threshold: 0.5,
227            detector_weights: HashMap::new(),
228        }
229    }
230
231    #[must_use]
232    pub fn with_weights(strategy: AggregationStrategy, weights: HashMap<String, f64>) -> Self {
233        Self {
234            strategy,
235            confidence_threshold: 0.5,
236            detector_weights: weights,
237        }
238    }
239
240    #[must_use]
241    pub fn with_threshold(mut self, threshold: f64) -> Self {
242        self.confidence_threshold = threshold;
243        self
244    }
245
246    /// Dispatch to the appropriate strategy implementation.
247    #[must_use]
248    pub fn aggregate(&self, results: &[DetectorResult]) -> AggregatedResult {
249        match self.strategy {
250            AggregationStrategy::MajorityVote => self.aggregate_majority_vote(results),
251            AggregationStrategy::WeightedVote => self.aggregate_weighted_vote(results),
252            AggregationStrategy::Conservative => self.aggregate_conservative(results),
253            AggregationStrategy::Permissive => self.aggregate_permissive(results),
254            AggregationStrategy::Cascade => self.aggregate_cascade(results),
255        }
256    }
257
258    /// Majority vote: the most commonly reported threat category wins.
259    /// Only categories above `confidence_threshold` participate.
260    #[must_use]
261    pub fn aggregate_majority_vote(&self, results: &[DetectorResult]) -> AggregatedResult {
262        if results.is_empty() {
263            return benign_aggregated(AggregationStrategy::MajorityVote);
264        }
265
266        let eligible: Vec<&DetectorResult> = results
267            .iter()
268            .filter(|r| r.confidence >= self.confidence_threshold)
269            .collect();
270
271        if eligible.is_empty() {
272            return benign_aggregated(AggregationStrategy::MajorityVote);
273        }
274
275        let mut votes: HashMap<&ThreatCategory, (usize, f64, &SecuritySeverity, Vec<&str>)> =
276            HashMap::new();
277        for r in &eligible {
278            let entry = votes.entry(&r.threat_category).or_insert((
279                0,
280                0.0,
281                &SecuritySeverity::Info,
282                Vec::new(),
283            ));
284            entry.0 += 1;
285            entry.1 += r.confidence;
286            if r.severity > *entry.2 {
287                entry.2 = &r.severity;
288            }
289            entry.3.push(&r.detector_name);
290        }
291
292        let winner = votes
293            .iter()
294            .max_by_key(|(_, (count, _, _, _))| *count)
295            .unwrap();
296
297        let (cat, (count, conf_sum, sev, names)) = winner;
298        AggregatedResult {
299            threat_category: (*cat).clone(),
300            confidence: conf_sum / *count as f64,
301            severity: (*sev).clone(),
302            contributing_detectors: names.iter().map(|s| (*s).to_string()).collect(),
303            strategy_used: AggregationStrategy::MajorityVote,
304        }
305    }
306
307    /// Weighted vote: each detector's confidence is scaled by its weight.
308    /// The threat category with the highest total weighted score wins.
309    #[must_use]
310    pub fn aggregate_weighted_vote(&self, results: &[DetectorResult]) -> AggregatedResult {
311        if results.is_empty() {
312            return benign_aggregated(AggregationStrategy::WeightedVote);
313        }
314
315        let mut scores: HashMap<&ThreatCategory, (f64, &SecuritySeverity, Vec<&str>)> =
316            HashMap::new();
317        for r in results {
318            let w = self
319                .detector_weights
320                .get(&r.detector_name)
321                .copied()
322                .unwrap_or(1.0);
323            let entry = scores.entry(&r.threat_category).or_insert((
324                0.0,
325                &SecuritySeverity::Info,
326                Vec::new(),
327            ));
328            entry.0 += r.confidence * w;
329            if r.severity > *entry.1 {
330                entry.1 = &r.severity;
331            }
332            entry.2.push(&r.detector_name);
333        }
334
335        let winner = scores
336            .iter()
337            .max_by(|a, b| a.1 .0.partial_cmp(&b.1 .0).unwrap())
338            .unwrap();
339
340        let (cat, (score, sev, names)) = winner;
341        let total_weight: f64 = names
342            .iter()
343            .map(|n| self.detector_weights.get(*n).copied().unwrap_or(1.0))
344            .sum();
345        let avg_conf = if total_weight > 0.0 {
346            score / total_weight
347        } else {
348            0.0
349        };
350
351        AggregatedResult {
352            threat_category: (*cat).clone(),
353            confidence: avg_conf,
354            severity: (*sev).clone(),
355            contributing_detectors: names.iter().map(|s| (*s).to_string()).collect(),
356            strategy_used: AggregationStrategy::WeightedVote,
357        }
358    }
359
360    /// Conservative: any detector flagging a threat above threshold -> threat.
361    /// Picks the highest-severity threat found.
362    #[must_use]
363    pub fn aggregate_conservative(&self, results: &[DetectorResult]) -> AggregatedResult {
364        if results.is_empty() {
365            return benign_aggregated(AggregationStrategy::Conservative);
366        }
367
368        let threats: Vec<&DetectorResult> = results
369            .iter()
370            .filter(|r| r.threat_category.is_threat() && r.confidence >= self.confidence_threshold)
371            .collect();
372
373        if threats.is_empty() {
374            return AggregatedResult {
375                threat_category: ThreatCategory::Benign,
376                confidence: results.iter().map(|r| 1.0 - r.confidence).product(),
377                severity: SecuritySeverity::Info,
378                contributing_detectors: Vec::new(),
379                strategy_used: AggregationStrategy::Conservative,
380            };
381        }
382
383        let worst = threats
384            .iter()
385            .max_by(|a, b| a.severity.cmp(&b.severity))
386            .unwrap();
387
388        AggregatedResult {
389            threat_category: worst.threat_category.clone(),
390            confidence: threats.iter().map(|r| r.confidence).fold(0.0_f64, f64::max),
391            severity: worst.severity.clone(),
392            contributing_detectors: threats.iter().map(|r| r.detector_name.clone()).collect(),
393            strategy_used: AggregationStrategy::Conservative,
394        }
395    }
396
397    /// Permissive: all detectors above threshold must agree on the same
398    /// threat category for it to be flagged.
399    #[must_use]
400    pub fn aggregate_permissive(&self, results: &[DetectorResult]) -> AggregatedResult {
401        if results.is_empty() {
402            return benign_aggregated(AggregationStrategy::Permissive);
403        }
404
405        let eligible: Vec<&DetectorResult> = results
406            .iter()
407            .filter(|r| r.confidence >= self.confidence_threshold)
408            .collect();
409
410        if eligible.is_empty() {
411            return benign_aggregated(AggregationStrategy::Permissive);
412        }
413
414        let first_cat = &eligible[0].threat_category;
415        let all_agree = eligible.iter().all(|r| r.threat_category == *first_cat);
416
417        if !all_agree || !first_cat.is_threat() {
418            return AggregatedResult {
419                threat_category: ThreatCategory::Benign,
420                confidence: eligible.iter().map(|r| r.confidence).sum::<f64>()
421                    / eligible.len() as f64,
422                severity: SecuritySeverity::Info,
423                contributing_detectors: Vec::new(),
424                strategy_used: AggregationStrategy::Permissive,
425            };
426        }
427
428        let max_sev = eligible.iter().map(|r| &r.severity).max().unwrap();
429
430        AggregatedResult {
431            threat_category: first_cat.clone(),
432            confidence: eligible.iter().map(|r| r.confidence).sum::<f64>() / eligible.len() as f64,
433            severity: max_sev.clone(),
434            contributing_detectors: eligible.iter().map(|r| r.detector_name.clone()).collect(),
435            strategy_used: AggregationStrategy::Permissive,
436        }
437    }
438
439    /// Cascade: the first result whose confidence exceeds the threshold wins.
440    /// Order is determined by input slice order.
441    #[must_use]
442    pub fn aggregate_cascade(&self, results: &[DetectorResult]) -> AggregatedResult {
443        for r in results {
444            if r.confidence >= self.confidence_threshold {
445                return AggregatedResult {
446                    threat_category: r.threat_category.clone(),
447                    confidence: r.confidence,
448                    severity: r.severity.clone(),
449                    contributing_detectors: vec![r.detector_name.clone()],
450                    strategy_used: AggregationStrategy::Cascade,
451                };
452            }
453        }
454        benign_aggregated(AggregationStrategy::Cascade)
455    }
456}
457
458/// Helper: returns a Benign aggregated result for the given strategy.
459#[must_use]
460fn benign_aggregated(strategy: AggregationStrategy) -> AggregatedResult {
461    AggregatedResult {
462        threat_category: ThreatCategory::Benign,
463        confidence: 0.0,
464        severity: SecuritySeverity::Info,
465        contributing_detectors: Vec::new(),
466        strategy_used: strategy,
467    }
468}
469
470// ---------------------------------------------------------------------------
471// ScanResult + builder
472// ---------------------------------------------------------------------------
473
474/// Final scan output combining all detector results and their aggregation.
475#[derive(Debug, Clone, Serialize, Deserialize)]
476pub struct ScanResult {
477    pub input_hash: String,
478    pub detector_results: Vec<DetectorResult>,
479    pub aggregate_threat: ThreatCategory,
480    pub aggregate_confidence: f64,
481    pub aggregate_severity: SecuritySeverity,
482    pub scan_duration_ms: u64,
483    pub timestamp: DateTime<Utc>,
484}
485
486impl ScanResult {
487    /// Convert every detector result into a `SecurityFinding`.
488    /// Only non-benign results are included.
489    #[must_use]
490    pub fn to_security_findings(&self) -> Vec<SecurityFinding> {
491        self.detector_results
492            .iter()
493            .filter(|r| r.threat_category.is_threat())
494            .map(SecurityFinding::from)
495            .collect()
496    }
497}
498
499/// Incrementally builds a [`ScanResult`] from individual detector results.
500pub struct ScanResultBuilder {
501    input_hash: String,
502    results: Vec<DetectorResult>,
503    start: std::time::Instant,
504}
505
506impl ScanResultBuilder {
507    /// Create a new builder, computing the input hash immediately.
508    #[must_use]
509    pub fn new(input_text: &str) -> Self {
510        Self {
511            input_hash: compute_input_hash(input_text),
512            results: Vec::new(),
513            start: std::time::Instant::now(),
514        }
515    }
516
517    /// Append a detector result.
518    pub fn add_result(&mut self, result: DetectorResult) -> &mut Self {
519        self.results.push(result);
520        self
521    }
522
523    /// Finalise the scan using the provided aggregator.
524    #[must_use]
525    pub fn build(self, aggregator: &ResultAggregator) -> ScanResult {
526        let agg = aggregator.aggregate(&self.results);
527        let duration = self.start.elapsed().as_millis() as u64;
528        ScanResult {
529            input_hash: self.input_hash,
530            detector_results: self.results,
531            aggregate_threat: agg.threat_category,
532            aggregate_confidence: agg.confidence,
533            aggregate_severity: agg.severity,
534            scan_duration_ms: duration,
535            timestamp: Utc::now(),
536        }
537    }
538}
539
540// ---------------------------------------------------------------------------
541// Helpers
542// ---------------------------------------------------------------------------
543
/// Deterministic, non-cryptographic hash of `input`.
///
/// Folds the UTF-8 bytes into a `u64` with a wrapping multiply-by-31-and-add
/// step and renders the value as 16 zero-padded lowercase hex digits.
#[must_use]
pub fn compute_input_hash(input: &str) -> String {
    let mut digest: u64 = 0;
    for &byte in input.as_bytes() {
        digest = digest.wrapping_mul(31).wrapping_add(u64::from(byte));
    }
    format!("{:016x}", digest)
}
553
554// ---------------------------------------------------------------------------
555// Tests
556// ---------------------------------------------------------------------------
557
558#[cfg(test)]
559mod tests {
560    use super::*;
561
562    // -- ThreatCategory --
563
564    #[test]
565    fn threat_category_serde_roundtrip() {
566        let cats = vec![
567            ThreatCategory::InjectionDirect,
568            ThreatCategory::InjectionIndirect,
569            ThreatCategory::Jailbreak,
570            ThreatCategory::PiiLeak,
571            ThreatCategory::ToxicContent,
572            ThreatCategory::DataExfiltration,
573            ThreatCategory::PromptExtraction,
574            ThreatCategory::CodeExecution,
575            ThreatCategory::Benign,
576        ];
577        for cat in cats {
578            let json = serde_json::to_string(&cat).unwrap();
579            let back: ThreatCategory = serde_json::from_str(&json).unwrap();
580            assert_eq!(cat, back);
581        }
582    }
583
584    #[test]
585    fn threat_category_is_threat() {
586        assert!(ThreatCategory::InjectionDirect.is_threat());
587        assert!(ThreatCategory::Jailbreak.is_threat());
588        assert!(!ThreatCategory::Benign.is_threat());
589    }
590
591    #[test]
592    fn threat_category_finding_type_mapping() {
593        assert_eq!(
594            ThreatCategory::InjectionDirect.as_finding_type(),
595            "prompt_injection_direct"
596        );
597        assert_eq!(ThreatCategory::Benign.as_finding_type(), "benign");
598    }
599
600    // -- DetectorResult --
601
602    #[test]
603    fn detector_result_construction_and_metadata() {
604        let r = DetectorResult::new("test-det", DetectorType::Classifier)
605            .with_threat(ThreatCategory::Jailbreak)
606            .with_confidence(0.95)
607            .with_severity(SecuritySeverity::High)
608            .with_raw_output("raw output".to_string())
609            .with_metadata("model", "v2")
610            .with_latency_ms(42);
611
612        assert_eq!(r.detector_name, "test-det");
613        assert_eq!(r.detector_type, DetectorType::Classifier);
614        assert_eq!(r.threat_category, ThreatCategory::Jailbreak);
615        assert_eq!(r.severity, SecuritySeverity::High);
616        assert!((r.confidence - 0.95).abs() < f64::EPSILON);
617        assert_eq!(r.raw_output.as_deref(), Some("raw output"));
618        assert_eq!(r.metadata.get("model").unwrap(), "v2");
619        assert_eq!(r.latency_ms, Some(42));
620    }
621
622    #[test]
623    fn detector_result_serde_roundtrip() {
624        let r = DetectorResult::new("d1", DetectorType::Semantic)
625            .with_threat(ThreatCategory::PiiLeak)
626            .with_confidence(0.77)
627            .with_severity(SecuritySeverity::Medium);
628
629        let json = serde_json::to_string(&r).unwrap();
630        let back: DetectorResult = serde_json::from_str(&json).unwrap();
631        assert_eq!(back.detector_name, "d1");
632        assert_eq!(back.threat_category, ThreatCategory::PiiLeak);
633        assert!((back.confidence - 0.77).abs() < f64::EPSILON);
634    }
635
636    // -- SecurityFinding conversion --
637
638    #[test]
639    fn detector_result_to_security_finding() {
640        let r = DetectorResult::new("ig", DetectorType::Classifier)
641            .with_threat(ThreatCategory::InjectionDirect)
642            .with_confidence(0.88)
643            .with_severity(SecuritySeverity::Critical)
644            .with_metadata("source", "unit-test");
645
646        let finding: SecurityFinding = SecurityFinding::from(&r);
647        assert_eq!(finding.severity, SecuritySeverity::Critical);
648        assert_eq!(finding.finding_type, "prompt_injection_direct");
649        assert!((finding.confidence_score - 0.88).abs() < f64::EPSILON);
650        assert_eq!(finding.metadata.get("detector_name").unwrap(), "ig");
651        assert_eq!(finding.metadata.get("source").unwrap(), "unit-test");
652        assert!(finding.requires_alert);
653    }
654
655    // -- Input hashing --
656
657    #[test]
658    fn input_hash_deterministic() {
659        let h1 = compute_input_hash("hello world");
660        let h2 = compute_input_hash("hello world");
661        assert_eq!(h1, h2);
662        assert_eq!(h1.len(), 16);
663    }
664
665    #[test]
666    fn input_hash_differs_for_different_inputs() {
667        let h1 = compute_input_hash("input A");
668        let h2 = compute_input_hash("input B");
669        assert_ne!(h1, h2);
670    }
671
672    #[test]
673    fn input_hash_empty_string() {
674        let h = compute_input_hash("");
675        assert_eq!(h.len(), 16);
676        assert_eq!(h, "0000000000000000");
677    }
678
679    // -- MajorityVote aggregation --
680
681    #[test]
682    fn majority_vote_single_threat_wins() {
683        let results = vec![
684            DetectorResult::new("a", DetectorType::Classifier)
685                .with_threat(ThreatCategory::Jailbreak)
686                .with_confidence(0.9)
687                .with_severity(SecuritySeverity::High),
688            DetectorResult::new("b", DetectorType::Heuristic)
689                .with_threat(ThreatCategory::Jailbreak)
690                .with_confidence(0.8)
691                .with_severity(SecuritySeverity::Medium),
692            DetectorResult::new("c", DetectorType::Semantic)
693                .with_threat(ThreatCategory::Benign)
694                .with_confidence(0.7)
695                .with_severity(SecuritySeverity::Info),
696        ];
697
698        let agg = ResultAggregator::new(AggregationStrategy::MajorityVote);
699        let out = agg.aggregate(&results);
700
701        assert_eq!(out.threat_category, ThreatCategory::Jailbreak);
702        assert_eq!(out.strategy_used, AggregationStrategy::MajorityVote);
703        assert_eq!(out.contributing_detectors.len(), 2);
704        assert_eq!(out.severity, SecuritySeverity::High);
705    }
706
707    #[test]
708    fn majority_vote_all_benign() {
709        let results = vec![
710            DetectorResult::new("a", DetectorType::Classifier)
711                .with_threat(ThreatCategory::Benign)
712                .with_confidence(0.9)
713                .with_severity(SecuritySeverity::Info),
714            DetectorResult::new("b", DetectorType::Heuristic)
715                .with_threat(ThreatCategory::Benign)
716                .with_confidence(0.85)
717                .with_severity(SecuritySeverity::Info),
718        ];
719
720        let out = ResultAggregator::new(AggregationStrategy::MajorityVote).aggregate(&results);
721        assert_eq!(out.threat_category, ThreatCategory::Benign);
722    }
723
724    #[test]
725    fn majority_vote_below_threshold_ignored() {
726        let results = vec![
727            DetectorResult::new("a", DetectorType::Classifier)
728                .with_threat(ThreatCategory::Jailbreak)
729                .with_confidence(0.3)
730                .with_severity(SecuritySeverity::High),
731            DetectorResult::new("b", DetectorType::Heuristic)
732                .with_threat(ThreatCategory::Benign)
733                .with_confidence(0.8)
734                .with_severity(SecuritySeverity::Info),
735        ];
736
737        let agg = ResultAggregator::new(AggregationStrategy::MajorityVote).with_threshold(0.5);
738        let out = agg.aggregate(&results);
739        // Only the benign result is above threshold -> benign wins
740        assert_eq!(out.threat_category, ThreatCategory::Benign);
741    }
742
743    #[test]
744    fn majority_vote_mixed_threats() {
745        let results = vec![
746            DetectorResult::new("a", DetectorType::Classifier)
747                .with_threat(ThreatCategory::InjectionDirect)
748                .with_confidence(0.9)
749                .with_severity(SecuritySeverity::High),
750            DetectorResult::new("b", DetectorType::Heuristic)
751                .with_threat(ThreatCategory::Jailbreak)
752                .with_confidence(0.8)
753                .with_severity(SecuritySeverity::Medium),
754            DetectorResult::new("c", DetectorType::LlmJudge)
755                .with_threat(ThreatCategory::InjectionDirect)
756                .with_confidence(0.85)
757                .with_severity(SecuritySeverity::Critical),
758        ];
759
760        let out = ResultAggregator::new(AggregationStrategy::MajorityVote).aggregate(&results);
761        assert_eq!(out.threat_category, ThreatCategory::InjectionDirect);
762        assert_eq!(out.severity, SecuritySeverity::Critical);
763        assert_eq!(out.contributing_detectors.len(), 2);
764    }
765
766    // -- WeightedVote aggregation --
767
768    #[test]
769    fn weighted_vote_respects_weights() {
770        let mut weights = HashMap::new();
771        weights.insert("trusted".to_string(), 5.0);
772        weights.insert("weak".to_string(), 1.0);
773
774        let results = vec![
775            DetectorResult::new("trusted", DetectorType::Ensemble)
776                .with_threat(ThreatCategory::InjectionDirect)
777                .with_confidence(0.7)
778                .with_severity(SecuritySeverity::High),
779            DetectorResult::new("weak", DetectorType::Heuristic)
780                .with_threat(ThreatCategory::Benign)
781                .with_confidence(0.9)
782                .with_severity(SecuritySeverity::Info),
783        ];
784
785        let agg = ResultAggregator::with_weights(AggregationStrategy::WeightedVote, weights);
786        let out = agg.aggregate(&results);
787
788        // trusted: 0.7 * 5.0 = 3.5  vs  weak: 0.9 * 1.0 = 0.9
789        assert_eq!(out.threat_category, ThreatCategory::InjectionDirect);
790    }
791
792    #[test]
793    fn weighted_vote_default_weight_is_one() {
794        // No explicit weights -> all detectors get weight 1.0
795        let results = vec![
796            DetectorResult::new("a", DetectorType::Classifier)
797                .with_threat(ThreatCategory::Jailbreak)
798                .with_confidence(0.8)
799                .with_severity(SecuritySeverity::High),
800            DetectorResult::new("b", DetectorType::Heuristic)
801                .with_threat(ThreatCategory::Jailbreak)
802                .with_confidence(0.7)
803                .with_severity(SecuritySeverity::Medium),
804        ];
805
806        let agg = ResultAggregator::new(AggregationStrategy::WeightedVote);
807        let out = agg.aggregate(&results);
808        assert_eq!(out.threat_category, ThreatCategory::Jailbreak);
809        // avg confidence = (0.8 + 0.7) / 2.0 = 0.75
810        assert!((out.confidence - 0.75).abs() < f64::EPSILON);
811    }
812
813    // -- Conservative aggregation --
814
815    #[test]
816    fn conservative_any_threat_flags() {
817        let results = vec![
818            DetectorResult::new("a", DetectorType::Classifier)
819                .with_threat(ThreatCategory::Benign)
820                .with_confidence(0.95)
821                .with_severity(SecuritySeverity::Info),
822            DetectorResult::new("b", DetectorType::Canary)
823                .with_threat(ThreatCategory::PromptExtraction)
824                .with_confidence(0.6)
825                .with_severity(SecuritySeverity::Critical),
826        ];
827
828        let out = ResultAggregator::new(AggregationStrategy::Conservative).aggregate(&results);
829        assert_eq!(out.threat_category, ThreatCategory::PromptExtraction);
830        assert_eq!(out.severity, SecuritySeverity::Critical);
831        assert_eq!(out.contributing_detectors, vec!["b"]);
832    }
833
834    #[test]
835    fn conservative_all_benign_returns_benign() {
836        let results = vec![DetectorResult::new("a", DetectorType::Classifier)
837            .with_threat(ThreatCategory::Benign)
838            .with_confidence(0.99)
839            .with_severity(SecuritySeverity::Info)];
840
841        let out = ResultAggregator::new(AggregationStrategy::Conservative).aggregate(&results);
842        assert_eq!(out.threat_category, ThreatCategory::Benign);
843    }
844
845    #[test]
846    fn conservative_below_threshold_ignored() {
847        let results = vec![DetectorResult::new("a", DetectorType::Classifier)
848            .with_threat(ThreatCategory::Jailbreak)
849            .with_confidence(0.3)
850            .with_severity(SecuritySeverity::High)];
851
852        let agg = ResultAggregator::new(AggregationStrategy::Conservative).with_threshold(0.5);
853        let out = agg.aggregate(&results);
854        assert_eq!(out.threat_category, ThreatCategory::Benign);
855    }
856
857    // -- Permissive aggregation --
858
859    #[test]
860    fn permissive_all_agree_on_threat() {
861        let results = vec![
862            DetectorResult::new("a", DetectorType::Classifier)
863                .with_threat(ThreatCategory::DataExfiltration)
864                .with_confidence(0.8)
865                .with_severity(SecuritySeverity::High),
866            DetectorResult::new("b", DetectorType::Semantic)
867                .with_threat(ThreatCategory::DataExfiltration)
868                .with_confidence(0.75)
869                .with_severity(SecuritySeverity::Medium),
870        ];
871
872        let out = ResultAggregator::new(AggregationStrategy::Permissive).aggregate(&results);
873        assert_eq!(out.threat_category, ThreatCategory::DataExfiltration);
874        assert_eq!(out.contributing_detectors.len(), 2);
875    }
876
877    #[test]
878    fn permissive_disagreement_returns_benign() {
879        let results = vec![
880            DetectorResult::new("a", DetectorType::Classifier)
881                .with_threat(ThreatCategory::Jailbreak)
882                .with_confidence(0.9)
883                .with_severity(SecuritySeverity::High),
884            DetectorResult::new("b", DetectorType::Heuristic)
885                .with_threat(ThreatCategory::InjectionDirect)
886                .with_confidence(0.85)
887                .with_severity(SecuritySeverity::High),
888        ];
889
890        let out = ResultAggregator::new(AggregationStrategy::Permissive).aggregate(&results);
891        assert_eq!(out.threat_category, ThreatCategory::Benign);
892        assert!(out.contributing_detectors.is_empty());
893    }
894
895    #[test]
896    fn permissive_all_agree_benign() {
897        let results = vec![
898            DetectorResult::new("a", DetectorType::Classifier)
899                .with_threat(ThreatCategory::Benign)
900                .with_confidence(0.9)
901                .with_severity(SecuritySeverity::Info),
902            DetectorResult::new("b", DetectorType::Heuristic)
903                .with_threat(ThreatCategory::Benign)
904                .with_confidence(0.85)
905                .with_severity(SecuritySeverity::Info),
906        ];
907
908        let out = ResultAggregator::new(AggregationStrategy::Permissive).aggregate(&results);
909        assert_eq!(out.threat_category, ThreatCategory::Benign);
910    }
911
912    // -- Cascade aggregation --
913
914    #[test]
915    fn cascade_first_high_confidence_wins() {
916        let results = vec![
917            DetectorResult::new("fast", DetectorType::Heuristic)
918                .with_threat(ThreatCategory::ToxicContent)
919                .with_confidence(0.3)
920                .with_severity(SecuritySeverity::Low),
921            DetectorResult::new("accurate", DetectorType::Classifier)
922                .with_threat(ThreatCategory::InjectionDirect)
923                .with_confidence(0.85)
924                .with_severity(SecuritySeverity::High),
925            DetectorResult::new("slow", DetectorType::LlmJudge)
926                .with_threat(ThreatCategory::Jailbreak)
927                .with_confidence(0.99)
928                .with_severity(SecuritySeverity::Critical),
929        ];
930
931        let agg = ResultAggregator::new(AggregationStrategy::Cascade).with_threshold(0.5);
932        let out = agg.aggregate(&results);
933        // First result above 0.5 is "accurate"
934        assert_eq!(out.threat_category, ThreatCategory::InjectionDirect);
935        assert_eq!(out.contributing_detectors, vec!["accurate"]);
936    }
937
938    #[test]
939    fn cascade_none_above_threshold_returns_benign() {
940        let results = vec![DetectorResult::new("a", DetectorType::Heuristic)
941            .with_threat(ThreatCategory::Jailbreak)
942            .with_confidence(0.2)
943            .with_severity(SecuritySeverity::High)];
944
945        let agg = ResultAggregator::new(AggregationStrategy::Cascade).with_threshold(0.5);
946        let out = agg.aggregate(&results);
947        assert_eq!(out.threat_category, ThreatCategory::Benign);
948    }
949
950    // -- Empty results --
951
952    #[test]
953    fn empty_results_majority_vote() {
954        let out = ResultAggregator::new(AggregationStrategy::MajorityVote).aggregate(&[]);
955        assert_eq!(out.threat_category, ThreatCategory::Benign);
956        assert!((out.confidence - 0.0).abs() < f64::EPSILON);
957    }
958
959    #[test]
960    fn empty_results_conservative() {
961        let out = ResultAggregator::new(AggregationStrategy::Conservative).aggregate(&[]);
962        assert_eq!(out.threat_category, ThreatCategory::Benign);
963    }
964
965    #[test]
966    fn empty_results_permissive() {
967        let out = ResultAggregator::new(AggregationStrategy::Permissive).aggregate(&[]);
968        assert_eq!(out.threat_category, ThreatCategory::Benign);
969    }
970
971    #[test]
972    fn empty_results_cascade() {
973        let out = ResultAggregator::new(AggregationStrategy::Cascade).aggregate(&[]);
974        assert_eq!(out.threat_category, ThreatCategory::Benign);
975    }
976
977    #[test]
978    fn empty_results_weighted_vote() {
979        let out = ResultAggregator::new(AggregationStrategy::WeightedVote).aggregate(&[]);
980        assert_eq!(out.threat_category, ThreatCategory::Benign);
981    }
982
983    // -- ScanResultBuilder --
984
985    #[test]
986    fn scan_result_builder_basic() {
987        let mut builder = ScanResultBuilder::new("test input");
988        builder.add_result(
989            DetectorResult::new("d1", DetectorType::Classifier)
990                .with_threat(ThreatCategory::Jailbreak)
991                .with_confidence(0.9)
992                .with_severity(SecuritySeverity::High),
993        );
994        builder.add_result(
995            DetectorResult::new("d2", DetectorType::Heuristic)
996                .with_threat(ThreatCategory::Benign)
997                .with_confidence(0.7)
998                .with_severity(SecuritySeverity::Info),
999        );
1000
1001        let aggregator = ResultAggregator::new(AggregationStrategy::MajorityVote);
1002        let scan = builder.build(&aggregator);
1003
1004        assert_eq!(scan.input_hash, compute_input_hash("test input"));
1005        assert_eq!(scan.detector_results.len(), 2);
1006        assert!(scan.timestamp <= Utc::now());
1007    }
1008
1009    #[test]
1010    fn scan_result_to_security_findings_excludes_benign() {
1011        let mut builder = ScanResultBuilder::new("probe");
1012        builder.add_result(
1013            DetectorResult::new("d1", DetectorType::Classifier)
1014                .with_threat(ThreatCategory::InjectionDirect)
1015                .with_confidence(0.88)
1016                .with_severity(SecuritySeverity::High),
1017        );
1018        builder.add_result(
1019            DetectorResult::new("d2", DetectorType::Heuristic)
1020                .with_threat(ThreatCategory::Benign)
1021                .with_confidence(0.7)
1022                .with_severity(SecuritySeverity::Info),
1023        );
1024
1025        let aggregator = ResultAggregator::new(AggregationStrategy::Conservative);
1026        let scan = builder.build(&aggregator);
1027        let findings = scan.to_security_findings();
1028
1029        assert_eq!(findings.len(), 1);
1030        assert_eq!(findings[0].finding_type, "prompt_injection_direct");
1031    }
1032
1033    // -- Single detector --
1034
1035    #[test]
1036    fn single_detector_all_strategies() {
1037        let result = DetectorResult::new("sole", DetectorType::Classifier)
1038            .with_threat(ThreatCategory::CodeExecution)
1039            .with_confidence(0.91)
1040            .with_severity(SecuritySeverity::Critical);
1041
1042        let strategies = vec![
1043            AggregationStrategy::MajorityVote,
1044            AggregationStrategy::WeightedVote,
1045            AggregationStrategy::Conservative,
1046            AggregationStrategy::Permissive,
1047            AggregationStrategy::Cascade,
1048        ];
1049
1050        for s in strategies {
1051            let agg = ResultAggregator::new(s.clone());
1052            let out = agg.aggregate(std::slice::from_ref(&result));
1053            assert_eq!(out.threat_category, ThreatCategory::CodeExecution);
1054            assert_eq!(out.severity, SecuritySeverity::Critical);
1055        }
1056    }
1057
1058    // -- Confidence threshold enforcement --
1059
1060    #[test]
1061    fn custom_threshold_enforcement() {
1062        let result = DetectorResult::new("d", DetectorType::Classifier)
1063            .with_threat(ThreatCategory::Jailbreak)
1064            .with_confidence(0.6)
1065            .with_severity(SecuritySeverity::High);
1066
1067        // threshold 0.7 -> should be ignored as benign
1068        let agg_high = ResultAggregator::new(AggregationStrategy::Cascade).with_threshold(0.7);
1069        let out_high = agg_high.aggregate(std::slice::from_ref(&result));
1070        assert_eq!(out_high.threat_category, ThreatCategory::Benign);
1071
1072        // threshold 0.5 -> should be detected
1073        let agg_low = ResultAggregator::new(AggregationStrategy::Cascade).with_threshold(0.5);
1074        let out_low = agg_low.aggregate(&[result]);
1075        assert_eq!(out_low.threat_category, ThreatCategory::Jailbreak);
1076    }
1077
1078    // -- ScanResult serialization --
1079
1080    #[test]
1081    fn scan_result_serde_roundtrip() {
1082        let mut builder = ScanResultBuilder::new("serialize me");
1083        builder.add_result(
1084            DetectorResult::new("d1", DetectorType::Canary)
1085                .with_threat(ThreatCategory::PromptExtraction)
1086                .with_confidence(0.99)
1087                .with_severity(SecuritySeverity::Critical),
1088        );
1089        let aggregator = ResultAggregator::new(AggregationStrategy::Conservative);
1090        let scan = builder.build(&aggregator);
1091
1092        let json = serde_json::to_string(&scan).unwrap();
1093        let back: ScanResult = serde_json::from_str(&json).unwrap();
1094
1095        assert_eq!(back.input_hash, scan.input_hash);
1096        assert_eq!(back.detector_results.len(), 1);
1097        assert_eq!(back.aggregate_threat, ThreatCategory::PromptExtraction);
1098    }
1099
1100    // -- AggregatedResult fields --
1101
1102    #[test]
1103    fn aggregated_result_has_correct_strategy() {
1104        let result = DetectorResult::new("d", DetectorType::Classifier)
1105            .with_threat(ThreatCategory::Benign)
1106            .with_confidence(0.9)
1107            .with_severity(SecuritySeverity::Info);
1108
1109        let strategies_and_expected = vec![
1110            AggregationStrategy::MajorityVote,
1111            AggregationStrategy::WeightedVote,
1112            AggregationStrategy::Conservative,
1113            AggregationStrategy::Permissive,
1114            AggregationStrategy::Cascade,
1115        ];
1116
1117        for strategy in strategies_and_expected {
1118            let agg = ResultAggregator::new(strategy.clone());
1119            let out = agg.aggregate(std::slice::from_ref(&result));
1120            assert_eq!(out.strategy_used, strategy);
1121        }
1122    }
1123}