llm_shield_core/
scanner.rs1use crate::{Error, Result, ScanResult, Vault};
12use async_trait::async_trait;
13use std::sync::Arc;
14
15#[async_trait]
51pub trait Scanner: Send + Sync {
52 fn name(&self) -> &str;
54
55 async fn scan(&self, input: &str, vault: &Vault) -> Result<ScanResult>;
67
68 fn scanner_type(&self) -> ScannerType {
70 ScannerType::Input
71 }
72
73 fn version(&self) -> &str {
75 "1.0.0"
76 }
77
78 fn description(&self) -> &str {
80 "No description provided"
81 }
82
83 fn requires_async(&self) -> bool {
88 false
89 }
90
91 fn validate_config(&self) -> Result<()> {
93 Ok(())
94 }
95}
96
97#[async_trait]
101pub trait InputScanner: Scanner {
102 async fn scan_prompt(&self, prompt: &str, vault: &Vault) -> Result<ScanResult> {
104 self.scan(prompt, vault).await
105 }
106}
107
108#[async_trait]
112pub trait OutputScanner: Scanner {
113 async fn scan_output(
115 &self,
116 prompt: &str,
117 output: &str,
118 vault: &Vault,
119 ) -> Result<ScanResult>;
120}
121
/// Classifies a scanner by the direction of traffic it inspects.
///
/// The type is a plain field-less enum, so besides `Copy` and `Eq` it also
/// derives `Hash`, which lets it be used as a `HashMap`/`HashSet` key (for
/// example to group scanners by type).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScannerType {
    /// Scanner applied to input.
    Input,
    /// Scanner applied to output.
    Output,
    /// Scanner applicable in both directions.
    Bidirectional,
}
132
133pub struct ScannerPipeline {
143 scanners: Vec<Arc<dyn Scanner>>,
144 short_circuit: bool,
145 short_circuit_threshold: f32,
146}
147
148impl ScannerPipeline {
149 pub fn new() -> Self {
151 Self {
152 scanners: Vec::new(),
153 short_circuit: false,
154 short_circuit_threshold: 0.9,
155 }
156 }
157
158 pub fn add(mut self, scanner: Arc<dyn Scanner>) -> Self {
160 self.scanners.push(scanner);
161 self
162 }
163
164 pub fn with_short_circuit(mut self, threshold: f32) -> Self {
168 self.short_circuit = true;
169 self.short_circuit_threshold = threshold;
170 self
171 }
172
173 pub async fn execute(&self, input: &str, vault: &Vault) -> Result<Vec<ScanResult>> {
175 let mut results = Vec::new();
176
177 for scanner in &self.scanners {
178 let result = scanner.scan(input, vault).await?;
179
180 if self.short_circuit && result.risk_score >= self.short_circuit_threshold {
181 results.push(result);
182 break;
183 }
184
185 results.push(result);
186 }
187
188 Ok(results)
189 }
190
191 pub async fn execute_parallel(&self, input: &str, vault: &Vault) -> Result<Vec<ScanResult>> {
195 use futures::future::join_all;
196
197 let futures: Vec<_> = self
198 .scanners
199 .iter()
200 .map(|scanner| {
201 let input = input.to_string();
202 let vault = vault.clone();
203 let scanner = Arc::clone(scanner);
204 async move { scanner.scan(&input, &vault).await }
205 })
206 .collect();
207
208 let results: Vec<Result<ScanResult>> = join_all(futures).await;
209
210 results.into_iter().collect()
212 }
213
214 pub async fn execute_aggregated(&self, input: &str, vault: &Vault) -> Result<ScanResult> {
216 let results = self.execute(input, vault).await?;
217 Ok(ScanResult::combine(results))
218 }
219}
220
221impl Default for ScannerPipeline {
222 fn default() -> Self {
223 Self::new()
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that echoes its input and reports a fixed risk score.
    struct MockScanner {
        name: String,
        risk_score: f32,
    }

    #[async_trait]
    impl Scanner for MockScanner {
        fn name(&self) -> &str {
            &self.name
        }

        async fn scan(&self, input: &str, _vault: &Vault) -> Result<ScanResult> {
            Ok(ScanResult::new(
                input.to_string(),
                self.risk_score < 0.5,
                self.risk_score,
            ))
        }

        fn scanner_type(&self) -> ScannerType {
            ScannerType::Input
        }
    }

    /// Builds a shareable mock scanner with the given name and fixed score.
    fn mock(name: &str, risk_score: f32) -> Arc<dyn Scanner> {
        Arc::new(MockScanner {
            name: name.to_string(),
            risk_score,
        })
    }

    #[tokio::test]
    async fn test_scanner_pipeline_sequential() {
        let vault = Vault::new();

        let pipeline = ScannerPipeline::new()
            .add(mock("test1", 0.3))
            .add(mock("test2", 0.5));

        let results = pipeline.execute("test input", &vault).await.unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].risk_score, 0.3);
        assert_eq!(results[1].risk_score, 0.5);
    }

    #[tokio::test]
    async fn test_scanner_pipeline_empty() {
        let vault = Vault::new();

        let results = ScannerPipeline::new()
            .execute("test input", &vault)
            .await
            .unwrap();

        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_scanner_pipeline_short_circuit() {
        let vault = Vault::new();

        let pipeline = ScannerPipeline::new()
            .add(mock("test1", 0.95))
            .add(mock("test2", 0.2))
            .with_short_circuit(0.9);

        let results = pipeline.execute("test input", &vault).await.unwrap();

        // The triggering result is kept; the second scanner never runs.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].risk_score, 0.95);
    }

    #[tokio::test]
    async fn test_scanner_pipeline_short_circuit_below_threshold() {
        let vault = Vault::new();

        let pipeline = ScannerPipeline::new()
            .add(mock("test1", 0.3))
            .add(mock("test2", 0.2))
            .with_short_circuit(0.9);

        let results = pipeline.execute("test input", &vault).await.unwrap();

        // No score reaches the threshold, so every scanner runs.
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].risk_score, 0.3);
        assert_eq!(results[1].risk_score, 0.2);
    }

    #[tokio::test]
    async fn test_scanner_pipeline_parallel() {
        let vault = Vault::new();

        // A score above the threshold must not cut the parallel run short.
        let pipeline = ScannerPipeline::new()
            .add(mock("test1", 0.95))
            .add(mock("test2", 0.2))
            .with_short_circuit(0.9);

        let results = pipeline
            .execute_parallel("test input", &vault)
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].risk_score, 0.95);
        assert_eq!(results[1].risk_score, 0.2);
    }

    #[tokio::test]
    async fn test_scanner_pipeline_aggregated() {
        let vault = Vault::new();

        let pipeline = ScannerPipeline::new()
            .add(mock("test1", 0.3))
            .add(mock("test2", 0.7));

        let result = pipeline
            .execute_aggregated("test input", &vault)
            .await
            .unwrap();

        assert_eq!(result.risk_score, 0.7);
    }
}