// aster/security/scanner.rs

1use crate::config::Config;
2use crate::conversation::message::Message;
3use crate::security::classification_client::ClassificationClient;
4use crate::security::patterns::{PatternMatch, PatternMatcher};
5use anyhow::Result;
6use futures::stream::{self, StreamExt};
7use rmcp::model::CallToolRequestParam;
8
/// Upper bound on how many recent user messages are sent to the classifier per scan.
const USER_SCAN_LIMIT: usize = 10;
/// Upper bound on concurrent classifier requests while scanning the conversation.
const ML_SCAN_CONCURRENCY: usize = 3;
11
/// Final verdict produced by `PromptInjectionScanner` for a tool call.
#[derive(Clone, Debug)]
pub struct ScanResult {
    /// `true` when `confidence` reached the configured threshold.
    pub is_malicious: bool,
    /// Highest confidence observed across the scans that were run.
    pub confidence: f32,
    /// Human-readable summary of the verdict.
    pub explanation: String,
}
18
19struct DetailedScanResult {
20    confidence: f32,
21    pattern_matches: Vec<PatternMatch>,
22    ml_confidence: Option<f32>,
23}
24
25pub struct PromptInjectionScanner {
26    pattern_matcher: PatternMatcher,
27    classifier_client: Option<ClassificationClient>,
28}
29
30impl PromptInjectionScanner {
31    pub fn new() -> Self {
32        Self {
33            pattern_matcher: PatternMatcher::new(),
34            classifier_client: None,
35        }
36    }
37
38    pub fn with_ml_detection() -> Result<Self> {
39        let classifier_client = Self::create_classifier_from_config()?;
40        Ok(Self {
41            pattern_matcher: PatternMatcher::new(),
42            classifier_client: Some(classifier_client),
43        })
44    }
45
46    fn create_classifier_from_config() -> Result<ClassificationClient> {
47        let config = Config::global();
48
49        let model_name = config
50            .get_param::<String>("SECURITY_PROMPT_CLASSIFIER_MODEL")
51            .ok()
52            .filter(|s| !s.trim().is_empty());
53        let endpoint = config
54            .get_param::<String>("SECURITY_PROMPT_CLASSIFIER_ENDPOINT")
55            .ok()
56            .filter(|s| !s.trim().is_empty());
57        let token = config
58            .get_secret::<String>("SECURITY_PROMPT_CLASSIFIER_TOKEN")
59            .ok()
60            .filter(|s| !s.trim().is_empty());
61
62        tracing::debug!(
63            model_name = ?model_name,
64            has_endpoint = endpoint.is_some(),
65            has_token = token.is_some(),
66            "Initializing classifier from config"
67        );
68
69        if let Some(model) = model_name {
70            tracing::info!(model_name = %model, "Using model-based configuration (internal)");
71            return ClassificationClient::from_model_name(&model, None);
72        }
73
74        if let Some(endpoint_url) = endpoint {
75            tracing::info!(endpoint = %endpoint_url, "Using endpoint-based configuration (external)");
76            return ClassificationClient::from_endpoint(endpoint_url, None, token);
77        }
78
79        anyhow::bail!(
80            "ML detection requires either SECURITY_PROMPT_CLASSIFIER_MODEL (for model mapping) \
81             or SECURITY_PROMPT_CLASSIFIER_ENDPOINT (for direct endpoint configuration)"
82        )
83    }
84
85    pub fn get_threshold_from_config(&self) -> f32 {
86        Config::global()
87            .get_param::<f64>("SECURITY_PROMPT_THRESHOLD")
88            .unwrap_or(0.8) as f32
89    }
90
91    pub async fn analyze_tool_call_with_context(
92        &self,
93        tool_call: &CallToolRequestParam,
94        messages: &[Message],
95    ) -> Result<ScanResult> {
96        let tool_content = self.extract_tool_content(tool_call);
97
98        tracing::info!(
99            "🔍 Scanning tool call: {} ({} chars)",
100            tool_call.name,
101            tool_content.len()
102        );
103
104        let (tool_result, context_result) = tokio::join!(
105            self.analyze_text(&tool_content),
106            self.scan_conversation(messages)
107        );
108
109        let highest_confidence_result =
110            self.select_highest_confidence_result(tool_result?, context_result?);
111        let threshold = self.get_threshold_from_config();
112
113        tracing::info!(
114            "✅ Security analysis complete: confidence={:.3}, malicious={}",
115            highest_confidence_result.confidence,
116            highest_confidence_result.confidence >= threshold
117        );
118
119        Ok(ScanResult {
120            is_malicious: highest_confidence_result.confidence >= threshold,
121            confidence: highest_confidence_result.confidence,
122            explanation: self.build_explanation(&highest_confidence_result, threshold),
123        })
124    }
125
126    async fn analyze_text(&self, text: &str) -> Result<DetailedScanResult> {
127        let (pattern_confidence, pattern_matches) = self.pattern_based_scanning(text);
128        let ml_confidence = self.scan_with_classifier(text).await;
129        let confidence = ml_confidence.unwrap_or(0.0).max(pattern_confidence);
130
131        Ok(DetailedScanResult {
132            confidence,
133            pattern_matches,
134            ml_confidence,
135        })
136    }
137
138    async fn scan_conversation(&self, messages: &[Message]) -> Result<DetailedScanResult> {
139        let user_messages = self.extract_user_messages(messages, USER_SCAN_LIMIT);
140
141        if user_messages.is_empty() || self.classifier_client.is_none() {
142            tracing::debug!("Skipping conversation scan - no classifier or messages");
143            return Ok(DetailedScanResult {
144                confidence: 0.0,
145                pattern_matches: Vec::new(),
146                ml_confidence: None,
147            });
148        }
149
150        tracing::debug!(
151            "Scanning {} user messages ({} chars) with concurrency limit of {}",
152            user_messages.len(),
153            user_messages.iter().map(|m| m.len()).sum::<usize>(),
154            ML_SCAN_CONCURRENCY
155        );
156
157        let max_confidence = stream::iter(user_messages)
158            .map(|msg| async move { self.scan_with_classifier(&msg).await })
159            .buffer_unordered(ML_SCAN_CONCURRENCY)
160            .fold(0.0_f32, |acc, result| async move {
161                result.unwrap_or(0.0).max(acc)
162            })
163            .await;
164
165        Ok(DetailedScanResult {
166            confidence: max_confidence,
167            pattern_matches: Vec::new(),
168            ml_confidence: Some(max_confidence),
169        })
170    }
171
172    fn select_highest_confidence_result(
173        &self,
174        tool_result: DetailedScanResult,
175        context_result: DetailedScanResult,
176    ) -> DetailedScanResult {
177        if tool_result.confidence >= context_result.confidence {
178            tool_result
179        } else {
180            context_result
181        }
182    }
183
184    async fn scan_with_classifier(&self, text: &str) -> Option<f32> {
185        let classifier = self.classifier_client.as_ref()?;
186
187        tracing::debug!("🤖 Running classifier scan ({} chars)", text.len());
188        let start = std::time::Instant::now();
189
190        match classifier.classify(text).await {
191            Ok(conf) => {
192                tracing::debug!(
193                    "✅ Classifier scan: confidence={:.3}, duration={:.0}ms",
194                    conf,
195                    start.elapsed().as_secs_f64() * 1000.0
196                );
197                Some(conf)
198            }
199            Err(e) => {
200                tracing::warn!("Classifier scan failed: {:#}", e);
201                None
202            }
203        }
204    }
205
206    fn pattern_based_scanning(&self, text: &str) -> (f32, Vec<PatternMatch>) {
207        let matches = self.pattern_matcher.scan_for_patterns(text);
208        let confidence = self
209            .pattern_matcher
210            .get_max_risk_level(&matches)
211            .map_or(0.0, |r| r.confidence_score());
212
213        (confidence, matches)
214    }
215
216    fn build_explanation(&self, result: &DetailedScanResult, threshold: f32) -> String {
217        if result.confidence < threshold {
218            return "No security threats detected".to_string();
219        }
220
221        if let Some(top_match) = result.pattern_matches.first() {
222            let preview = top_match.matched_text.chars().take(50).collect::<String>();
223            return format!(
224                "Security threat detected: {} (Risk: {:?}) - Found: '{}'",
225                top_match.threat.description, top_match.threat.risk_level, preview
226            );
227        }
228
229        if let Some(ml_conf) = result.ml_confidence {
230            format!("Security threat detected (ML confidence: {:.2})", ml_conf)
231        } else {
232            "Security threat detected".to_string()
233        }
234    }
235
236    fn extract_user_messages(&self, messages: &[Message], limit: usize) -> Vec<String> {
237        messages
238            .iter()
239            .rev()
240            .filter(|m| crate::conversation::effective_role(m) == "user")
241            .take(limit)
242            .map(|m| {
243                m.content
244                    .iter()
245                    .filter_map(|c| match c {
246                        crate::conversation::message::MessageContent::Text(t) => {
247                            Some(t.text.clone())
248                        }
249                        _ => None,
250                    })
251                    .collect::<Vec<_>>()
252                    .join("\n")
253            })
254            .filter(|s| !s.is_empty())
255            .collect()
256    }
257
258    fn extract_tool_content(&self, tool_call: &CallToolRequestParam) -> String {
259        let mut s = format!("Tool: {}", tool_call.name);
260        if let Some(args) = &tool_call.arguments {
261            if let Ok(json) = serde_json::to_string_pretty(args) {
262                s.push('\n');
263                s.push_str(&json);
264            }
265        }
266        s
267    }
268}
269
270impl Default for PromptInjectionScanner {
271    fn default() -> Self {
272        Self::new()
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::object;

    /// A destructive shell command must be flagged by pattern matching alone.
    #[tokio::test]
    async fn test_text_pattern_detection() {
        let scanner = PromptInjectionScanner::new();
        let detailed = scanner
            .analyze_text("rm -rf /")
            .await
            .expect("analyze_text should not fail");

        // High risk level = 0.75 confidence
        assert!(detailed.confidence >= 0.75);
        assert!(!detailed.pattern_matches.is_empty());
    }

    /// Without a classifier, an empty conversation contributes no confidence.
    #[tokio::test]
    async fn test_conversation_scan_without_ml() {
        let scanner = PromptInjectionScanner::new();
        let detailed = scanner
            .scan_conversation(&[])
            .await
            .expect("scan_conversation should not fail");

        assert_eq!(detailed.confidence, 0.0);
    }

    /// A malicious tool call is reported as a threat even with no conversation context.
    #[tokio::test]
    async fn test_tool_call_analysis() {
        let scanner = PromptInjectionScanner::new();
        let shell_call = CallToolRequestParam {
            name: "shell".into(),
            arguments: Some(object!({ "command": "rm -rf /tmp/malicious" })),
        };

        let verdict = scanner
            .analyze_tool_call_with_context(&shell_call, &[])
            .await
            .expect("tool call analysis should not fail");

        assert!(verdict.is_malicious);
        assert!(verdict.explanation.contains("Security threat"));
    }
}