1pub mod classification_client;
2pub mod patterns;
3pub mod scanner;
4pub mod security_inspector;
5
6use crate::config::Config;
7use crate::conversation::message::{Message, ToolRequest};
8use crate::permission::permission_judge::PermissionCheckResult;
9use anyhow::Result;
10use scanner::PromptInjectionScanner;
11use std::sync::OnceLock;
12use uuid::Uuid;
13
14pub struct SecurityManager {
15 scanner: OnceLock<PromptInjectionScanner>,
16}
17
/// A single security finding reported for one tool request.
#[derive(Debug, Clone)]
pub struct SecurityResult {
    /// Whether the scanner classified the tool call as malicious.
    pub is_malicious: bool,
    /// Confidence score reported by the scanner for this finding.
    pub confidence: f32,
    /// Scanner explanation, stored as produced (newlines are not stripped).
    pub explanation: String,
    /// Whether the user should be prompted before the tool call proceeds.
    pub should_ask_user: bool,
    /// Unique identifier for this finding (formatted as `SEC-<uuid>`).
    pub finding_id: String,
    /// Identifier of the tool request that produced this finding.
    pub tool_request_id: String,
}
27
28impl SecurityManager {
29 pub fn new() -> Self {
30 Self {
31 scanner: OnceLock::new(),
32 }
33 }
34
35 pub fn is_prompt_injection_detection_enabled(&self) -> bool {
36 let config = Config::global();
37
38 config
39 .get_param::<bool>("SECURITY_PROMPT_ENABLED")
40 .unwrap_or(false)
41 }
42
43 fn is_ml_scanning_enabled(&self) -> bool {
44 let config = Config::global();
45
46 config
47 .get_param::<bool>("SECURITY_PROMPT_CLASSIFIER_ENABLED")
48 .unwrap_or(false)
49 }
50
51 pub async fn analyze_tool_requests(
52 &self,
53 tool_requests: &[ToolRequest],
54 messages: &[Message],
55 ) -> Result<Vec<SecurityResult>> {
56 if !self.is_prompt_injection_detection_enabled() {
57 tracing::debug!(
58 counter.aster.prompt_injection_scanner_disabled = 1,
59 "Security scanning disabled"
60 );
61 return Ok(vec![]);
62 }
63
64 let scanner = self.scanner.get_or_init(|| {
65 let ml_enabled = self.is_ml_scanning_enabled();
66
67 let scanner = if ml_enabled {
68 match PromptInjectionScanner::with_ml_detection() {
69 Ok(s) => {
70 tracing::info!(
71 counter.aster.prompt_injection_scanner_enabled = 1,
72 "🔓 Security scanner initialized with ML-based detection"
73 );
74 s
75 }
76 Err(e) => {
77 let error_chain = format!("{:#}", e);
78 tracing::warn!(
79 "⚠️ ML scanning requested but failed to initialize. Falling back to pattern-only scanning.\n\nError details:\n{}",
80 error_chain
81 );
82 PromptInjectionScanner::new()
83 }
84 }
85 } else {
86 tracing::info!(
87 counter.aster.prompt_injection_scanner_enabled = 1,
88 "🔓 Security scanner initialized with pattern-based detection only"
89 );
90 PromptInjectionScanner::new()
91 };
92
93 scanner
94 });
95
96 let mut results = Vec::new();
97
98 tracing::info!(
99 "🔍 Starting security analysis - {} tool requests, {} messages",
100 tool_requests.len(),
101 messages.len()
102 );
103
104 for tool_request in tool_requests.iter() {
105 if let Ok(tool_call) = &tool_request.tool_call {
106 let analysis_result = scanner
107 .analyze_tool_call_with_context(tool_call, messages)
108 .await?;
109
110 let config_threshold = scanner.get_threshold_from_config();
111 let sanitized_explanation = analysis_result.explanation.replace('\n', " | ");
112
113 if analysis_result.is_malicious {
114 let above_threshold = analysis_result.confidence > config_threshold;
115 let finding_id = format!("SEC-{}", Uuid::new_v4().simple());
116
117 tracing::warn!(
118 counter.aster.prompt_injection_finding = 1,
119 above_threshold = above_threshold,
120 tool_name = %tool_call.name,
121 tool_request_id = %tool_request.id,
122 confidence = analysis_result.confidence,
123 explanation = %sanitized_explanation,
124 finding_id = %finding_id,
125 threshold = config_threshold,
126 "{}",
127 if above_threshold {
128 "Current tool call flagged as malicious after security analysis (above threshold)"
129 } else {
130 "Security finding below threshold - logged but not blocking execution"
131 }
132 );
133 if above_threshold {
134 results.push(SecurityResult {
135 is_malicious: analysis_result.is_malicious,
136 confidence: analysis_result.confidence,
137 explanation: analysis_result.explanation,
138 should_ask_user: true, finding_id,
140 tool_request_id: tool_request.id.clone(),
141 });
142 }
143 } else {
144 tracing::info!(
145 tool_name = %tool_call.name,
146 tool_request_id = %tool_request.id,
147 confidence = analysis_result.confidence,
148 explanation = %sanitized_explanation,
149 "✅ Current tool call passed security analysis"
150 );
151 }
152 }
153 }
154
155 tracing::info!(
156 counter.aster.prompt_injection_analysis_performed = 1,
157 security_issues_found = results.len(),
158 "Security analysis complete"
159 );
160 Ok(results)
161 }
162
163 pub async fn filter_malicious_tool_calls(
164 &self,
165 messages: &[Message],
166 permission_check_result: &PermissionCheckResult,
167 _system_prompt: Option<&str>,
168 ) -> Result<Vec<SecurityResult>> {
169 let tool_requests: Vec<_> = permission_check_result
170 .approved
171 .iter()
172 .chain(permission_check_result.needs_approval.iter())
173 .cloned()
174 .collect();
175
176 self.analyze_tool_requests(&tool_requests, messages).await
177 }
178}
179
180impl Default for SecurityManager {
181 fn default() -> Self {
182 Self::new()
183 }
184}