llm_shield_core/
scanner.rs

//! Scanner trait and types
//!
//! ## SPARC Specification - Scanner Abstraction
//!
//! Defines the core trait that all security scanners must implement:
//! - Async-first design for scalability
//! - Composable architecture
//! - Type-safe configuration
//! - Built-in observability
10
11use crate::{Error, Result, ScanResult, Vault};
12use async_trait::async_trait;
13use std::sync::Arc;
14
15/// Core scanner trait
16///
17/// All security scanners implement this trait to provide consistent interface.
18///
19/// ## London School TDD Design
20///
21/// This trait is designed to be mockable for testing:
22/// - Pure async interface
23/// - No internal state mutations
24/// - Testable with mock vault
25///
26/// ## Example
27///
28/// ```rust,ignore
29/// use llm_shield_core::{Scanner, ScanResult, Vault};
30/// use async_trait::async_trait;
31///
32/// struct MyScanner;
33///
34/// #[async_trait]
35/// impl Scanner for MyScanner {
36///     fn name(&self) -> &str {
37///         "my_scanner"
38///     }
39///
40///     async fn scan(&self, input: &str, vault: &Vault) -> Result<ScanResult> {
41///         // Implement scanning logic
42///         Ok(ScanResult::pass(input.to_string()))
43///     }
44///
45///     fn scanner_type(&self) -> ScannerType {
46///         ScannerType::Input
47///     }
48/// }
49/// ```
50#[async_trait]
51pub trait Scanner: Send + Sync {
52    /// Scanner name for identification
53    fn name(&self) -> &str;
54
55    /// Scan input text and return result
56    ///
57    /// ## Parameters
58    ///
59    /// - `input`: Text to scan
60    /// - `vault`: State storage for cross-scanner communication
61    ///
62    /// ## Returns
63    ///
64    /// - `Ok(ScanResult)`: Scan completed successfully
65    /// - `Err(Error)`: Scan failed
66    async fn scan(&self, input: &str, vault: &Vault) -> Result<ScanResult>;
67
68    /// Type of scanner (input/output/bidirectional)
69    fn scanner_type(&self) -> ScannerType {
70        ScannerType::Input
71    }
72
73    /// Scanner version
74    fn version(&self) -> &str {
75        "1.0.0"
76    }
77
78    /// Scanner description
79    fn description(&self) -> &str {
80        "No description provided"
81    }
82
83    /// Whether this scanner requires async execution
84    ///
85    /// Some scanners (e.g., URL checking) must be async.
86    /// Simple scanners can be sync for better performance.
87    fn requires_async(&self) -> bool {
88        false
89    }
90
91    /// Validate scanner configuration
92    fn validate_config(&self) -> Result<()> {
93        Ok(())
94    }
95}
96
97/// Input scanner specialization
98///
99/// Scans LLM prompts/inputs before they're sent to the model
100#[async_trait]
101pub trait InputScanner: Scanner {
102    /// Scan a prompt before sending to LLM
103    async fn scan_prompt(&self, prompt: &str, vault: &Vault) -> Result<ScanResult> {
104        self.scan(prompt, vault).await
105    }
106}
107
108/// Output scanner specialization
109///
110/// Scans LLM responses/outputs before returning to user
111#[async_trait]
112pub trait OutputScanner: Scanner {
113    /// Scan LLM output with context of original prompt
114    async fn scan_output(
115        &self,
116        prompt: &str,
117        output: &str,
118        vault: &Vault,
119    ) -> Result<ScanResult>;
120}
121
/// Scanner type classification.
///
/// Describes which side of an LLM exchange a scanner inspects. `Hash` is derived
/// alongside `Eq` so the type can key a `HashMap` or be stored in a `HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScannerType {
    /// Scans inputs (prompts)
    Input,
    /// Scans outputs (responses)
    Output,
    /// Scans both inputs and outputs
    Bidirectional,
}
132
133/// Scanner pipeline for composing multiple scanners
134///
135/// ## Enterprise Pattern
136///
137/// Provides:
138/// - Sequential execution
139/// - Parallel execution option
140/// - Short-circuit on high risk
141/// - Result aggregation
142pub struct ScannerPipeline {
143    scanners: Vec<Arc<dyn Scanner>>,
144    short_circuit: bool,
145    short_circuit_threshold: f32,
146}
147
148impl ScannerPipeline {
149    /// Create a new scanner pipeline
150    pub fn new() -> Self {
151        Self {
152            scanners: Vec::new(),
153            short_circuit: false,
154            short_circuit_threshold: 0.9,
155        }
156    }
157
158    /// Add a scanner to the pipeline
159    pub fn add(mut self, scanner: Arc<dyn Scanner>) -> Self {
160        self.scanners.push(scanner);
161        self
162    }
163
164    /// Enable short-circuit evaluation
165    ///
166    /// If any scanner returns risk >= threshold, stop execution
167    pub fn with_short_circuit(mut self, threshold: f32) -> Self {
168        self.short_circuit = true;
169        self.short_circuit_threshold = threshold;
170        self
171    }
172
173    /// Execute pipeline sequentially
174    pub async fn execute(&self, input: &str, vault: &Vault) -> Result<Vec<ScanResult>> {
175        let mut results = Vec::new();
176
177        for scanner in &self.scanners {
178            let result = scanner.scan(input, vault).await?;
179
180            if self.short_circuit && result.risk_score >= self.short_circuit_threshold {
181                results.push(result);
182                break;
183            }
184
185            results.push(result);
186        }
187
188        Ok(results)
189    }
190
191    /// Execute pipeline in parallel
192    ///
193    /// All scanners run concurrently. Useful for I/O-bound scanners.
194    pub async fn execute_parallel(&self, input: &str, vault: &Vault) -> Result<Vec<ScanResult>> {
195        use futures::future::join_all;
196
197        let futures: Vec<_> = self
198            .scanners
199            .iter()
200            .map(|scanner| {
201                let input = input.to_string();
202                let vault = vault.clone();
203                let scanner = Arc::clone(scanner);
204                async move { scanner.scan(&input, &vault).await }
205            })
206            .collect();
207
208        let results: Vec<Result<ScanResult>> = join_all(futures).await;
209
210        // Collect results, propagating first error
211        results.into_iter().collect()
212    }
213
214    /// Get aggregated result from pipeline
215    pub async fn execute_aggregated(&self, input: &str, vault: &Vault) -> Result<ScanResult> {
216        let results = self.execute(input, vault).await?;
217        Ok(ScanResult::combine(results))
218    }
219}
220
221impl Default for ScannerPipeline {
222    fn default() -> Self {
223        Self::new()
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Mock scanner that passes its input through with a fixed risk score.
    ///
    /// The result is marked valid when `risk_score < 0.5`.
    struct MockScanner {
        name: String,
        risk_score: f32,
    }

    #[async_trait]
    impl Scanner for MockScanner {
        fn name(&self) -> &str {
            &self.name
        }

        async fn scan(&self, input: &str, _vault: &Vault) -> Result<ScanResult> {
            Ok(ScanResult::new(
                input.to_string(),
                self.risk_score < 0.5,
                self.risk_score,
            ))
        }

        fn scanner_type(&self) -> ScannerType {
            ScannerType::Input
        }
    }

    /// Build a shared mock scanner with the given name and fixed risk score.
    fn mock(name: &str, risk_score: f32) -> Arc<dyn Scanner> {
        Arc::new(MockScanner {
            name: name.to_string(),
            risk_score,
        })
    }

    #[tokio::test]
    async fn test_scanner_pipeline_sequential() {
        let vault = Vault::new();
        let pipeline = ScannerPipeline::new()
            .add(mock("test1", 0.3))
            .add(mock("test2", 0.5));

        let results = pipeline.execute("test input", &vault).await.unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].risk_score, 0.3);
        assert_eq!(results[1].risk_score, 0.5);
    }

    #[tokio::test]
    async fn test_scanner_pipeline_short_circuit() {
        let vault = Vault::new();
        let pipeline = ScannerPipeline::new()
            .add(mock("test1", 0.95))
            .add(mock("test2", 0.2))
            .with_short_circuit(0.9);

        let results = pipeline.execute("test input", &vault).await.unwrap();

        // Should stop after first scanner due to high risk
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].risk_score, 0.95);
    }

    #[tokio::test]
    async fn test_scanner_pipeline_short_circuit_threshold_is_inclusive() {
        let vault = Vault::new();
        let pipeline = ScannerPipeline::new()
            .add(mock("test1", 0.9))
            .add(mock("test2", 0.1))
            .with_short_circuit(0.9);

        let results = pipeline.execute("test input", &vault).await.unwrap();

        // A score exactly equal to the threshold stops the pipeline
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].risk_score, 0.9);
    }

    #[tokio::test]
    async fn test_scanner_pipeline_short_circuit_below_threshold_runs_all() {
        let vault = Vault::new();
        let pipeline = ScannerPipeline::new()
            .add(mock("test1", 0.3))
            .add(mock("test2", 0.5))
            .with_short_circuit(0.9);

        let results = pipeline.execute("test input", &vault).await.unwrap();

        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn test_scanner_pipeline_parallel_keeps_order_and_ignores_short_circuit() {
        let vault = Vault::new();
        let pipeline = ScannerPipeline::new()
            .add(mock("test1", 0.95))
            .add(mock("test2", 0.2))
            .with_short_circuit(0.9);

        let results = pipeline
            .execute_parallel("test input", &vault)
            .await
            .unwrap();

        // Every scanner runs, and results come back in pipeline order
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].risk_score, 0.95);
        assert_eq!(results[1].risk_score, 0.2);
    }

    #[tokio::test]
    async fn test_scanner_pipeline_empty() {
        let vault = Vault::new();
        let pipeline = ScannerPipeline::default();

        assert!(pipeline.execute("test input", &vault).await.unwrap().is_empty());
        assert!(pipeline
            .execute_parallel("test input", &vault)
            .await
            .unwrap()
            .is_empty());
    }

    #[tokio::test]
    async fn test_scanner_pipeline_aggregated() {
        let vault = Vault::new();
        let pipeline = ScannerPipeline::new()
            .add(mock("test1", 0.3))
            .add(mock("test2", 0.7));

        let result = pipeline
            .execute_aggregated("test input", &vault)
            .await
            .unwrap();

        // Should take maximum risk score
        assert_eq!(result.risk_score, 0.7);
    }
}