1use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13
14#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
23pub struct ScanResult {
24 pub sanitized_text: String,
26
27 pub is_valid: bool,
29
30 pub risk_score: f32,
32
33 pub entities: Vec<Entity>,
35
36 pub risk_factors: Vec<RiskFactor>,
38
39 pub metadata: HashMap<String, serde_json::Value>,
41}
42
43impl ScanResult {
44 pub fn new(sanitized_text: String, is_valid: bool, risk_score: f32) -> Self {
46 Self {
47 sanitized_text,
48 is_valid,
49 risk_score,
50 entities: Vec::new(),
51 risk_factors: Vec::new(),
52 metadata: HashMap::new(),
53 }
54 }
55
56 pub fn pass(text: String) -> Self {
58 Self::new(text, true, 0.0)
59 }
60
61 pub fn fail(text: String, risk_score: f32) -> Self {
63 Self::new(text, false, risk_score)
64 }
65
66 pub fn with_entity(mut self, entity: Entity) -> Self {
68 self.entities.push(entity);
69 self
70 }
71
72 pub fn with_risk_factor(mut self, factor: RiskFactor) -> Self {
74 self.risk_factors.push(factor);
75 self
76 }
77
78 pub fn with_metadata<K: Into<String>, V: Serialize>(
80 mut self,
81 key: K,
82 value: V,
83 ) -> Self {
84 if let Ok(json_value) = serde_json::to_value(value) {
85 self.metadata.insert(key.into(), json_value);
86 }
87 self
88 }
89
90 pub fn severity(&self) -> Severity {
92 if self.risk_score >= 0.9 {
93 Severity::Critical
94 } else if self.risk_score >= 0.7 {
95 Severity::High
96 } else if self.risk_score >= 0.4 {
97 Severity::Medium
98 } else if self.risk_score > 0.0 {
99 Severity::Low
100 } else {
101 Severity::None
102 }
103 }
104
105 pub fn combine(results: Vec<ScanResult>) -> Self {
109 if results.is_empty() {
110 return Self::pass(String::new());
111 }
112
113 let max_risk = results
114 .iter()
115 .map(|r| r.risk_score)
116 .fold(0.0f32, f32::max);
117
118 let is_valid = results.iter().all(|r| r.is_valid);
119
120 let mut combined = Self::new(
121 results[0].sanitized_text.clone(),
122 is_valid,
123 max_risk,
124 );
125
126 for result in results {
127 combined.entities.extend(result.entities);
128 combined.risk_factors.extend(result.risk_factors);
129 for (k, v) in result.metadata {
130 combined.metadata.insert(k, v);
131 }
132 }
133
134 combined
135 }
136}
137
138#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
140pub struct Entity {
141 pub entity_type: String,
143
144 pub text: String,
146
147 pub start: usize,
149
150 pub end: usize,
152
153 pub confidence: f32,
155
156 pub metadata: HashMap<String, String>,
158}
159
160impl Entity {
161 pub fn new<S: Into<String>>(
163 entity_type: S,
164 text: S,
165 start: usize,
166 end: usize,
167 confidence: f32,
168 ) -> Self {
169 Self {
170 entity_type: entity_type.into(),
171 text: text.into(),
172 start,
173 end,
174 confidence,
175 metadata: HashMap::new(),
176 }
177 }
178}
179
180#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
182pub struct RiskFactor {
183 pub factor_type: String,
185
186 pub description: String,
188
189 pub severity: Severity,
191
192 pub score_contribution: f32,
194}
195
196impl RiskFactor {
197 pub fn new<S: Into<String>>(
199 factor_type: S,
200 description: S,
201 severity: Severity,
202 score_contribution: f32,
203 ) -> Self {
204 Self {
205 factor_type: factor_type.into(),
206 description: description.into(),
207 severity,
208 score_contribution,
209 }
210 }
211}
212
213#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
215#[serde(rename_all = "lowercase")]
216pub enum Severity {
217 None,
219 Low,
221 Medium,
223 High,
225 Critical,
227}
228
229impl Severity {
230 pub fn threshold(&self) -> f32 {
232 match self {
233 Severity::None => 0.0,
234 Severity::Low => 0.01,
235 Severity::Medium => 0.4,
236 Severity::High => 0.7,
237 Severity::Critical => 0.9,
238 }
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scan_result_creation() {
        let result = ScanResult::pass("test text".to_string());
        assert!(result.is_valid);
        assert_eq!(result.risk_score, 0.0);
        assert_eq!(result.severity(), Severity::None);
        assert!(result.entities.is_empty());
        assert!(result.risk_factors.is_empty());
        assert!(result.metadata.is_empty());
    }

    #[test]
    fn test_scan_result_fail() {
        let result = ScanResult::fail("bad text".to_string(), 0.85);
        assert!(!result.is_valid);
        assert_eq!(result.risk_score, 0.85);
        assert_eq!(result.severity(), Severity::High);
    }

    #[test]
    fn test_scan_result_builder() {
        let entity = Entity::new("email", "test@example.com", 0, 16, 0.95);
        let factor = RiskFactor::new(
            "banned_content",
            "Email address detected",
            Severity::Low,
            0.2,
        );

        let result = ScanResult::pass("text".to_string())
            .with_entity(entity)
            .with_risk_factor(factor)
            .with_metadata("scanner", "test");

        assert_eq!(result.entities.len(), 1);
        assert_eq!(result.risk_factors.len(), 1);
        assert!(result.metadata.contains_key("scanner"));
    }

    #[test]
    fn test_combine_results() {
        let r1 = ScanResult::fail("text1".to_string(), 0.3);
        let r2 = ScanResult::fail("text2".to_string(), 0.7);
        let r3 = ScanResult::pass("text3".to_string());

        let combined = ScanResult::combine(vec![r1, r2, r3]);
        assert_eq!(combined.risk_score, 0.7);
        assert!(!combined.is_valid);
    }

    #[test]
    fn test_combine_empty_is_pass() {
        let combined = ScanResult::combine(Vec::new());
        assert!(combined.is_valid);
        assert_eq!(combined.risk_score, 0.0);
        assert!(combined.sanitized_text.is_empty());
        assert!(combined.entities.is_empty());
    }

    #[test]
    fn test_combine_merges_collections() {
        let r1 = ScanResult::fail("first".to_string(), 0.25)
            .with_entity(Entity::new("email", "a@b.c", 0, 5, 0.75))
            .with_metadata("scanner", "one");
        let r2 = ScanResult::pass("second".to_string())
            .with_entity(Entity::new("phone", "555", 6, 9, 0.5))
            .with_risk_factor(RiskFactor::new("pii", "Phone number", Severity::Low, 0.125))
            .with_metadata("scanner", "two");

        let combined = ScanResult::combine(vec![r1, r2]);

        // Text comes from the first result; validity and risk are aggregated.
        assert_eq!(combined.sanitized_text, "first");
        assert!(!combined.is_valid);
        assert_eq!(combined.risk_score, 0.25);

        // Entities keep input order; later metadata overwrites earlier keys.
        let types: Vec<&str> = combined
            .entities
            .iter()
            .map(|e| e.entity_type.as_str())
            .collect();
        assert_eq!(types, ["email", "phone"]);
        assert_eq!(combined.risk_factors.len(), 1);
        assert_eq!(
            combined.metadata.get("scanner"),
            Some(&serde_json::Value::from("two"))
        );
    }

    #[test]
    fn test_severity_boundaries() {
        let cases: [(f32, Severity); 9] = [
            (0.0, Severity::None),
            (0.01, Severity::Low),
            (0.39, Severity::Low),
            (0.4, Severity::Medium),
            (0.69, Severity::Medium),
            (0.7, Severity::High),
            (0.89, Severity::High),
            (0.9, Severity::Critical),
            (1.0, Severity::Critical),
        ];
        for &(score, expected) in &cases {
            let result = ScanResult::fail(String::new(), score);
            assert_eq!(result.severity(), expected, "score {}", score);
        }
    }

    #[test]
    fn test_severity_threshold_matches_classification() {
        let levels = [
            Severity::None,
            Severity::Low,
            Severity::Medium,
            Severity::High,
            Severity::Critical,
        ];
        for &level in &levels {
            let result = ScanResult::fail(String::new(), level.threshold());
            assert_eq!(result.severity(), level, "level {:?}", level);
        }
    }

    #[test]
    fn test_severity_serializes_lowercase() {
        assert_eq!(
            serde_json::to_string(&Severity::Critical).unwrap(),
            "\"critical\""
        );
        assert_eq!(
            serde_json::from_str::<Severity>("\"medium\"").unwrap(),
            Severity::Medium
        );
    }

    #[test]
    fn test_scan_result_json_round_trip() {
        let original = ScanResult::fail("x".to_string(), 0.5)
            .with_entity(Entity::new("email", "a@b.c", 0, 5, 0.75))
            .with_risk_factor(RiskFactor::new("pii", "Email", Severity::Medium, 0.5))
            .with_metadata("scanner", "test");

        let json = serde_json::to_string(&original).unwrap();
        let restored: ScanResult = serde_json::from_str(&json).unwrap();
        assert_eq!(restored, original);
    }

    #[test]
    fn test_severity_ordering() {
        assert!(Severity::Critical > Severity::High);
        assert!(Severity::High > Severity::Medium);
        assert!(Severity::Medium > Severity::Low);
        assert!(Severity::Low > Severity::None);
    }
}