// security/scanner.rs
use crate::config::Config;
use crate::conversation::message::Message;
use crate::security::classification_client::ClassificationClient;
use crate::security::patterns::{PatternMatch, PatternMatcher};
use anyhow::Result;
use futures::stream::{self, StreamExt};
use rmcp::model::CallToolRequestParam;
/// Maximum number of recent user messages scanned per conversation pass.
const USER_SCAN_LIMIT: usize = 10;
/// Upper bound on concurrent ML classification requests during a
/// conversation scan.
const ML_SCAN_CONCURRENCY: usize = 3;
11
/// Outcome of a security scan, as exposed to callers.
#[derive(Debug, Clone)]
pub struct ScanResult {
    /// True when `confidence` met or exceeded the configured threshold.
    pub is_malicious: bool,
    /// Highest confidence observed across the pattern and ML scans.
    /// (Presumably normalized to 0.0..=1.0 — confirm against the
    /// classifier/pattern-matcher contracts, which are not visible here.)
    pub confidence: f32,
    /// Human-readable summary of what, if anything, was detected.
    pub explanation: String,
}
18
/// Internal scan result carrying per-detector detail prior to thresholding.
struct DetailedScanResult {
    /// Maximum of the pattern-based and ML confidences.
    confidence: f32,
    /// Pattern-matcher hits (always empty for conversation scans, which
    /// only run the ML classifier).
    pattern_matches: Vec<PatternMatch>,
    /// ML classifier confidence; `None` when no classifier ran.
    ml_confidence: Option<f32>,
}
24
/// Scans tool calls and conversation history for prompt-injection attempts,
/// combining regex/pattern matching with an optional ML classifier.
pub struct PromptInjectionScanner {
    pattern_matcher: PatternMatcher,
    /// `None` when constructed via `new()`; ML scans are skipped in that case.
    classifier_client: Option<ClassificationClient>,
}
29
impl PromptInjectionScanner {
    /// Creates a pattern-matching-only scanner (no ML classifier).
    pub fn new() -> Self {
        Self {
            pattern_matcher: PatternMatcher::new(),
            classifier_client: None,
        }
    }

    /// Creates a scanner with ML detection enabled, configuring the
    /// classifier from global config.
    ///
    /// # Errors
    /// Fails when neither a classifier model nor an endpoint is configured
    /// (see [`Self::create_classifier_from_config`]).
    pub fn with_ml_detection() -> Result<Self> {
        let classifier_client = Self::create_classifier_from_config()?;
        Ok(Self {
            pattern_matcher: PatternMatcher::new(),
            classifier_client: Some(classifier_client),
        })
    }

    /// Builds a `ClassificationClient` from global config.
    ///
    /// Precedence: `SECURITY_PROMPT_CLASSIFIER_MODEL` (internal model
    /// mapping) wins over `SECURITY_PROMPT_CLASSIFIER_ENDPOINT` (direct
    /// endpoint, optionally authenticated with
    /// `SECURITY_PROMPT_CLASSIFIER_TOKEN`).
    ///
    /// # Errors
    /// Bails when neither setting is present.
    fn create_classifier_from_config() -> Result<ClassificationClient> {
        let config = Config::global();

        // Blank / whitespace-only values are treated the same as unset.
        let model_name = config
            .get_param::<String>("SECURITY_PROMPT_CLASSIFIER_MODEL")
            .ok()
            .filter(|s| !s.trim().is_empty());
        let endpoint = config
            .get_param::<String>("SECURITY_PROMPT_CLASSIFIER_ENDPOINT")
            .ok()
            .filter(|s| !s.trim().is_empty());
        let token = config
            .get_secret::<String>("SECURITY_PROMPT_CLASSIFIER_TOKEN")
            .ok()
            .filter(|s| !s.trim().is_empty());

        tracing::debug!(
            model_name = ?model_name,
            has_endpoint = endpoint.is_some(),
            has_token = token.is_some(),
            "Initializing classifier from config"
        );

        if let Some(model) = model_name {
            tracing::info!(model_name = %model, "Using model-based configuration (internal)");
            return ClassificationClient::from_model_name(&model, None);
        }

        if let Some(endpoint_url) = endpoint {
            tracing::info!(endpoint = %endpoint_url, "Using endpoint-based configuration (external)");
            return ClassificationClient::from_endpoint(endpoint_url, None, token);
        }

        anyhow::bail!(
            "ML detection requires either SECURITY_PROMPT_CLASSIFIER_MODEL (for model mapping) \
             or SECURITY_PROMPT_CLASSIFIER_ENDPOINT (for direct endpoint configuration)"
        )
    }

    /// Reads the malicious-confidence threshold from config, defaulting to
    /// 0.8 when `SECURITY_PROMPT_THRESHOLD` is unset or unreadable.
    pub fn get_threshold_from_config(&self) -> f32 {
        Config::global()
            .get_param::<f64>("SECURITY_PROMPT_THRESHOLD")
            .unwrap_or(0.8) as f32
    }

    /// Scans a tool call together with recent conversation context and
    /// returns the thresholded verdict.
    ///
    /// The tool-call content and the conversation history are scanned
    /// concurrently; the higher-confidence result determines the verdict.
    ///
    /// # Errors
    /// Propagates failures from either underlying scan.
    pub async fn analyze_tool_call_with_context(
        &self,
        tool_call: &CallToolRequestParam,
        messages: &[Message],
    ) -> Result<ScanResult> {
        let tool_content = self.extract_tool_content(tool_call);

        tracing::info!(
            "🔍 Scanning tool call: {} ({} chars)",
            tool_call.name,
            tool_content.len()
        );

        // Run both scans concurrently; neither depends on the other.
        let (tool_result, context_result) = tokio::join!(
            self.analyze_text(&tool_content),
            self.scan_conversation(messages)
        );

        let highest_confidence_result =
            self.select_highest_confidence_result(tool_result?, context_result?);
        let threshold = self.get_threshold_from_config();

        tracing::info!(
            "✅ Security analysis complete: confidence={:.3}, malicious={}",
            highest_confidence_result.confidence,
            highest_confidence_result.confidence >= threshold
        );

        Ok(ScanResult {
            is_malicious: highest_confidence_result.confidence >= threshold,
            confidence: highest_confidence_result.confidence,
            explanation: self.build_explanation(&highest_confidence_result, threshold),
        })
    }

    /// Scans one piece of text with both detectors, keeping the higher of
    /// the pattern-based and ML confidences.
    async fn analyze_text(&self, text: &str) -> Result<DetailedScanResult> {
        let (pattern_confidence, pattern_matches) = self.pattern_based_scanning(text);
        let ml_confidence = self.scan_with_classifier(text).await;
        // A missing or failed ML scan contributes 0.0, so patterns alone
        // can still flag the text.
        let confidence = ml_confidence.unwrap_or(0.0).max(pattern_confidence);

        Ok(DetailedScanResult {
            confidence,
            pattern_matches,
            ml_confidence,
        })
    }

    /// ML-scans the most recent user messages (up to `USER_SCAN_LIMIT`) and
    /// returns the maximum confidence observed.
    ///
    /// Pattern matching is not applied here. Without a classifier, or with
    /// no user messages, this short-circuits to a zero-confidence result.
    async fn scan_conversation(&self, messages: &[Message]) -> Result<DetailedScanResult> {
        let user_messages = self.extract_user_messages(messages, USER_SCAN_LIMIT);

        if user_messages.is_empty() || self.classifier_client.is_none() {
            tracing::debug!("Skipping conversation scan - no classifier or messages");
            return Ok(DetailedScanResult {
                confidence: 0.0,
                pattern_matches: Vec::new(),
                ml_confidence: None,
            });
        }

        tracing::debug!(
            "Scanning {} user messages ({} chars) with concurrency limit of {}",
            user_messages.len(),
            user_messages.iter().map(|m| m.len()).sum::<usize>(),
            ML_SCAN_CONCURRENCY
        );

        // Classify messages concurrently (bounded by ML_SCAN_CONCURRENCY)
        // and fold down to the single highest confidence; individual scan
        // failures count as 0.0 rather than aborting the whole pass.
        let max_confidence = stream::iter(user_messages)
            .map(|msg| async move { self.scan_with_classifier(&msg).await })
            .buffer_unordered(ML_SCAN_CONCURRENCY)
            .fold(0.0_f32, |acc, result| async move {
                result.unwrap_or(0.0).max(acc)
            })
            .await;

        Ok(DetailedScanResult {
            confidence: max_confidence,
            pattern_matches: Vec::new(),
            ml_confidence: Some(max_confidence),
        })
    }

    /// Returns whichever result has the higher confidence; ties go to the
    /// tool-call result.
    fn select_highest_confidence_result(
        &self,
        tool_result: DetailedScanResult,
        context_result: DetailedScanResult,
    ) -> DetailedScanResult {
        if tool_result.confidence >= context_result.confidence {
            tool_result
        } else {
            context_result
        }
    }

    /// Runs the ML classifier on `text`.
    ///
    /// Returns `None` when no classifier is configured or when the request
    /// fails (failures are logged as warnings, not propagated).
    async fn scan_with_classifier(&self, text: &str) -> Option<f32> {
        let classifier = self.classifier_client.as_ref()?;

        tracing::debug!("🤖 Running classifier scan ({} chars)", text.len());
        let start = std::time::Instant::now();

        match classifier.classify(text).await {
            Ok(conf) => {
                tracing::debug!(
                    "✅ Classifier scan: confidence={:.3}, duration={:.0}ms",
                    conf,
                    start.elapsed().as_secs_f64() * 1000.0
                );
                Some(conf)
            }
            Err(e) => {
                tracing::warn!("Classifier scan failed: {:#}", e);
                None
            }
        }
    }

    /// Pattern-scans `text`, returning the confidence score of the riskiest
    /// match (0.0 when nothing matched) along with every match found.
    fn pattern_based_scanning(&self, text: &str) -> (f32, Vec<PatternMatch>) {
        let matches = self.pattern_matcher.scan_for_patterns(text);
        let confidence = self
            .pattern_matcher
            .get_max_risk_level(&matches)
            .map_or(0.0, |r| r.confidence_score());

        (confidence, matches)
    }

    /// Builds a human-readable explanation for a scan result.
    ///
    /// Below the threshold this reports no threat. Above it, a pattern
    /// match (with a 50-character preview of the matched text) is preferred
    /// over a bare ML confidence figure.
    fn build_explanation(&self, result: &DetailedScanResult, threshold: f32) -> String {
        if result.confidence < threshold {
            return "No security threats detected".to_string();
        }

        if let Some(top_match) = result.pattern_matches.first() {
            let preview = top_match.matched_text.chars().take(50).collect::<String>();
            return format!(
                "Security threat detected: {} (Risk: {:?}) - Found: '{}'",
                top_match.threat.description, top_match.threat.risk_level, preview
            );
        }

        if let Some(ml_conf) = result.ml_confidence {
            format!("Security threat detected (ML confidence: {:.2})", ml_conf)
        } else {
            "Security threat detected".to_string()
        }
    }

    /// Collects the text of up to `limit` most recent user messages
    /// (iterated newest-first), joining each message's text parts with
    /// newlines and dropping messages that end up empty.
    fn extract_user_messages(&self, messages: &[Message], limit: usize) -> Vec<String> {
        messages
            .iter()
            .rev()
            .filter(|m| crate::conversation::effective_role(m) == "user")
            .take(limit)
            .map(|m| {
                m.content
                    .iter()
                    .filter_map(|c| match c {
                        crate::conversation::message::MessageContent::Text(t) => {
                            Some(t.text.clone())
                        }
                        // Non-text content (images, tool payloads, …) is
                        // not scannable here and is skipped.
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join("\n")
            })
            .filter(|s| !s.is_empty())
            .collect()
    }

    /// Serializes a tool call (name plus pretty-printed JSON arguments)
    /// into a single scannable string.
    fn extract_tool_content(&self, tool_call: &CallToolRequestParam) -> String {
        let mut s = format!("Tool: {}", tool_call.name);
        if let Some(args) = &tool_call.arguments {
            // Serialization failures silently yield the name-only string;
            // best-effort is intentional here.
            if let Ok(json) = serde_json::to_string_pretty(args) {
                s.push('\n');
                s.push_str(&json);
            }
        }
        s
    }
}
269
270impl Default for PromptInjectionScanner {
271 fn default() -> Self {
272 Self::new()
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;
    use rmcp::object;

    #[tokio::test]
    async fn test_text_pattern_detection() {
        // A destructive shell command must trip at least one pattern rule.
        let scanner = PromptInjectionScanner::new();
        let analysis = scanner.analyze_text("rm -rf /").await.unwrap();

        assert!(!analysis.pattern_matches.is_empty());
        assert!(analysis.confidence >= 0.75);
    }

    #[tokio::test]
    async fn test_conversation_scan_without_ml() {
        // With no classifier and no messages the scan short-circuits to zero.
        let scanner = PromptInjectionScanner::new();
        let outcome = scanner.scan_conversation(&[]).await.unwrap();

        assert_eq!(outcome.confidence, 0.0);
    }

    #[tokio::test]
    async fn test_tool_call_analysis() {
        let scanner = PromptInjectionScanner::new();

        let request = CallToolRequestParam {
            name: "shell".into(),
            arguments: Some(object!({
                "command": "rm -rf /tmp/malicious"
            })),
        };

        let verdict = scanner
            .analyze_tool_call_with_context(&request, &[])
            .await
            .unwrap();

        assert!(verdict.is_malicious);
        assert!(verdict.explanation.contains("Security threat"));
    }
}