llm_shield_core/
result.rs

//! Scan result types
//!
//! ## SPARC Specification
//!
//! Standardized result format for all scanners:
//! - `ScanResult`: Main result structure
//! - `Entity`: Detected entities (PII, secrets, etc.)
//! - `RiskFactor`: Individual risk factors
//! - `Severity`: Risk severity levels
10
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14/// Result of a security scan
15///
16/// ## Enterprise Design
17///
18/// - **Immutable**: Once created, cannot be modified
19/// - **Serializable**: Can be sent over network or stored
20/// - **Rich Metadata**: Includes detailed information for debugging
21/// - **Composable**: Multiple results can be combined
22#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
23pub struct ScanResult {
24    /// The sanitized/modified text (if applicable)
25    pub sanitized_text: String,
26
27    /// Whether the input passed validation
28    pub is_valid: bool,
29
30    /// Risk score from 0.0 (no risk) to 1.0 (maximum risk)
31    pub risk_score: f32,
32
33    /// Detected entities (PII, secrets, banned content, etc.)
34    pub entities: Vec<Entity>,
35
36    /// Risk factors that contributed to the score
37    pub risk_factors: Vec<RiskFactor>,
38
39    /// Additional scanner-specific metadata
40    pub metadata: HashMap<String, serde_json::Value>,
41}
42
43impl ScanResult {
44    /// Create a new scan result
45    pub fn new(sanitized_text: String, is_valid: bool, risk_score: f32) -> Self {
46        Self {
47            sanitized_text,
48            is_valid,
49            risk_score,
50            entities: Vec::new(),
51            risk_factors: Vec::new(),
52            metadata: HashMap::new(),
53        }
54    }
55
56    /// Create a passing scan result (no risks detected)
57    pub fn pass(text: String) -> Self {
58        Self::new(text, true, 0.0)
59    }
60
61    /// Create a failing scan result with risk score
62    pub fn fail(text: String, risk_score: f32) -> Self {
63        Self::new(text, false, risk_score)
64    }
65
66    /// Add an entity to the result
67    pub fn with_entity(mut self, entity: Entity) -> Self {
68        self.entities.push(entity);
69        self
70    }
71
72    /// Add a risk factor to the result
73    pub fn with_risk_factor(mut self, factor: RiskFactor) -> Self {
74        self.risk_factors.push(factor);
75        self
76    }
77
78    /// Add metadata to the result
79    pub fn with_metadata<K: Into<String>, V: Serialize>(
80        mut self,
81        key: K,
82        value: V,
83    ) -> Self {
84        if let Ok(json_value) = serde_json::to_value(value) {
85            self.metadata.insert(key.into(), json_value);
86        }
87        self
88    }
89
90    /// Get the overall severity level
91    pub fn severity(&self) -> Severity {
92        if self.risk_score >= 0.9 {
93            Severity::Critical
94        } else if self.risk_score >= 0.7 {
95            Severity::High
96        } else if self.risk_score >= 0.4 {
97            Severity::Medium
98        } else if self.risk_score > 0.0 {
99            Severity::Low
100        } else {
101            Severity::None
102        }
103    }
104
105    /// Combine multiple scan results
106    ///
107    /// Takes the maximum risk score and merges entities
108    pub fn combine(results: Vec<ScanResult>) -> Self {
109        if results.is_empty() {
110            return Self::pass(String::new());
111        }
112
113        let max_risk = results
114            .iter()
115            .map(|r| r.risk_score)
116            .fold(0.0f32, f32::max);
117
118        let is_valid = results.iter().all(|r| r.is_valid);
119
120        let mut combined = Self::new(
121            results[0].sanitized_text.clone(),
122            is_valid,
123            max_risk,
124        );
125
126        for result in results {
127            combined.entities.extend(result.entities);
128            combined.risk_factors.extend(result.risk_factors);
129            for (k, v) in result.metadata {
130                combined.metadata.insert(k, v);
131            }
132        }
133
134        combined
135    }
136}
137
/// A detected entity in the scanned text
///
/// Stored in `ScanResult::entities`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Entity {
    /// Type of entity (e.g., "email", "ssn", "api_key")
    pub entity_type: String,

    /// The detected text
    pub text: String,

    /// Start position in original text
    ///
    /// NOTE(review): the unit (byte offset vs. character index) is not fixed
    /// by this type — confirm against the scanners that populate it.
    pub start: usize,

    /// End position in original text
    ///
    /// NOTE(review): presumably exclusive (this file's tests record the
    /// 16-character match "test@example.com" as `0..16`) — confirm.
    /// `start <= end` is not enforced.
    pub end: usize,

    /// Confidence score (0.0 to 1.0)
    ///
    /// The range is a convention; `Entity::new` stores the value unchecked.
    pub confidence: f32,

    /// Additional entity-specific data, as string key/value pairs
    pub metadata: HashMap<String, String>,
}
159
160impl Entity {
161    /// Create a new entity
162    pub fn new<S: Into<String>>(
163        entity_type: S,
164        text: S,
165        start: usize,
166        end: usize,
167        confidence: f32,
168    ) -> Self {
169        Self {
170            entity_type: entity_type.into(),
171            text: text.into(),
172            start,
173            end,
174            confidence,
175            metadata: HashMap::new(),
176        }
177    }
178}
179
/// A risk factor contributing to the overall risk score
///
/// Stored in `ScanResult::risk_factors`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct RiskFactor {
    /// Factor type (e.g., "prompt_injection", "toxicity")
    pub factor_type: String,

    /// Human-readable description
    pub description: String,

    /// Severity level (set independently of `score_contribution`)
    pub severity: Severity,

    /// Contribution to overall risk score
    ///
    /// NOTE(review): nothing in this module derives `ScanResult::risk_score`
    /// from these contributions (`ScanResult::combine` uses each result's
    /// `risk_score` directly) — confirm how scanners are expected to use it.
    pub score_contribution: f32,
}
195
196impl RiskFactor {
197    /// Create a new risk factor
198    pub fn new<S: Into<String>>(
199        factor_type: S,
200        description: S,
201        severity: Severity,
202        score_contribution: f32,
203    ) -> Self {
204        Self {
205            factor_type: factor_type.into(),
206            description: description.into(),
207            severity,
208            score_contribution,
209        }
210    }
211}
212
/// Severity levels for risks
///
/// Variants are declared from least to most severe, so the derived `Ord`
/// ranks `None < Low < Medium < High < Critical`; do not reorder them.
/// Serde uses the lowercase variant names (`"none"`, `"low"`, `"medium"`,
/// `"high"`, `"critical"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Severity {
    /// No risk detected
    None,
    /// Low severity (0.0 < score < 0.4)
    Low,
    /// Medium severity (0.4 <= score < 0.7)
    Medium,
    /// High severity (0.7 <= score < 0.9)
    High,
    /// Critical severity (score >= 0.9)
    Critical,
}
228
229impl Severity {
230    /// Get numeric threshold for this severity
231    pub fn threshold(&self) -> f32 {
232        match self {
233            Severity::None => 0.0,
234            Severity::Low => 0.01,
235            Severity::Medium => 0.4,
236            Severity::High => 0.7,
237            Severity::Critical => 0.9,
238        }
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scan_result_creation() {
        let result = ScanResult::pass("test text".to_string());
        assert!(result.is_valid);
        assert_eq!(result.risk_score, 0.0);
        assert_eq!(result.severity(), Severity::None);
    }

    #[test]
    fn test_scan_result_fail() {
        let result = ScanResult::fail("bad text".to_string(), 0.85);
        assert!(!result.is_valid);
        assert_eq!(result.risk_score, 0.85);
        assert_eq!(result.severity(), Severity::High);
    }

    #[test]
    fn test_scan_result_builder() {
        let entity = Entity::new("email", "test@example.com", 0, 16, 0.95);
        let factor = RiskFactor::new(
            "banned_content",
            "Email address detected",
            Severity::Low,
            0.2,
        );

        let result = ScanResult::pass("text".to_string())
            .with_entity(entity)
            .with_risk_factor(factor)
            .with_metadata("scanner", "test");

        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.risk_factors.len(), 1);
        assert!(result.metadata.contains_key("scanner"));
    }

    /// Each band's boundary is inclusive on the lower side; any positive
    /// score is at least Low.
    #[test]
    fn test_severity_boundaries() {
        let severity_of = |score: f32| ScanResult::fail(String::new(), score).severity();
        assert_eq!(severity_of(1.0), Severity::Critical);
        assert_eq!(severity_of(0.9), Severity::Critical);
        assert_eq!(severity_of(0.89), Severity::High);
        assert_eq!(severity_of(0.7), Severity::High);
        assert_eq!(severity_of(0.69), Severity::Medium);
        assert_eq!(severity_of(0.4), Severity::Medium);
        assert_eq!(severity_of(0.39), Severity::Low);
        assert_eq!(severity_of(0.001), Severity::Low);
        assert_eq!(severity_of(0.0), Severity::None);
    }

    #[test]
    fn test_combine_results() {
        let r1 = ScanResult::fail("text1".to_string(), 0.3);
        let r2 = ScanResult::fail("text2".to_string(), 0.7);
        let r3 = ScanResult::pass("text3".to_string());

        let combined = ScanResult::combine(vec![r1, r2, r3]);
        assert_eq!(combined.risk_score, 0.7);
        assert!(!combined.is_valid);
    }

    #[test]
    fn test_combine_empty() {
        let combined = ScanResult::combine(Vec::new());
        assert!(combined.is_valid);
        assert_eq!(combined.risk_score, 0.0);
        assert!(combined.sanitized_text.is_empty());
        assert!(combined.entities.is_empty());
        assert!(combined.risk_factors.is_empty());
        assert!(combined.metadata.is_empty());
    }

    /// Entities and risk factors are concatenated, the first text is kept,
    /// and later metadata values win on key collisions.
    #[test]
    fn test_combine_merges_findings() {
        let r1 = ScanResult::fail("first".to_string(), 0.5)
            .with_entity(Entity::new("email", "a@b.co", 0, 6, 0.9))
            .with_metadata("scanner", "one");
        let r2 = ScanResult::pass("second".to_string())
            .with_risk_factor(RiskFactor::new(
                "toxicity",
                "Mild language",
                Severity::Low,
                0.1,
            ))
            .with_metadata("scanner", "two");

        let combined = ScanResult::combine(vec![r1, r2]);
        assert_eq!(combined.sanitized_text, "first");
        assert_eq!(combined.risk_score, 0.5);
        assert!(!combined.is_valid);
        assert_eq!(combined.entities.len(), 1);
        assert_eq!(combined.risk_factors.len(), 1);
        assert_eq!(
            combined.metadata.get("scanner"),
            Some(&serde_json::Value::from("two"))
        );
    }

    #[test]
    fn test_severity_ordering() {
        assert!(Severity::Critical > Severity::High);
        assert!(Severity::High > Severity::Medium);
        assert!(Severity::Medium > Severity::Low);
        assert!(Severity::Low > Severity::None);
    }
}
300}