//! aster/security — security scanning for tool calls (prompt-injection detection).
1pub mod classification_client;
2pub mod patterns;
3pub mod scanner;
4pub mod security_inspector;
5
6use crate::config::Config;
7use crate::conversation::message::{Message, ToolRequest};
8use crate::permission::permission_judge::PermissionCheckResult;
9use anyhow::Result;
10use scanner::PromptInjectionScanner;
11use std::sync::OnceLock;
12use uuid::Uuid;
13
/// Coordinates prompt-injection scanning of tool call requests.
///
/// The underlying [`PromptInjectionScanner`] is built lazily, at most once,
/// on the first analysis request — so no scanner cost is paid while the
/// feature is disabled.
pub struct SecurityManager {
    // Lazily-initialized scanner; `OnceLock` guarantees exactly one
    // initialization even if the first calls race across threads.
    scanner: OnceLock<PromptInjectionScanner>,
}
17
/// Outcome of security analysis for a single tool request.
#[derive(Debug, Clone)]
pub struct SecurityResult {
    /// Whether the scanner judged the tool call malicious.
    pub is_malicious: bool,
    /// Scanner confidence in the verdict; compared against the configured threshold.
    pub confidence: f32,
    /// Human-readable explanation of the finding.
    pub explanation: String,
    /// Whether the user should be asked before the tool call proceeds
    /// (always `true` for findings above the threshold).
    pub should_ask_user: bool,
    /// Unique identifier of this finding, formatted as `SEC-<uuid>`.
    pub finding_id: String,
    /// Id of the tool request this result refers to.
    pub tool_request_id: String,
}
27
28impl SecurityManager {
29    pub fn new() -> Self {
30        Self {
31            scanner: OnceLock::new(),
32        }
33    }
34
35    pub fn is_prompt_injection_detection_enabled(&self) -> bool {
36        let config = Config::global();
37
38        config
39            .get_param::<bool>("SECURITY_PROMPT_ENABLED")
40            .unwrap_or(false)
41    }
42
43    fn is_ml_scanning_enabled(&self) -> bool {
44        let config = Config::global();
45
46        config
47            .get_param::<bool>("SECURITY_PROMPT_CLASSIFIER_ENABLED")
48            .unwrap_or(false)
49    }
50
51    pub async fn analyze_tool_requests(
52        &self,
53        tool_requests: &[ToolRequest],
54        messages: &[Message],
55    ) -> Result<Vec<SecurityResult>> {
56        if !self.is_prompt_injection_detection_enabled() {
57            tracing::debug!(
58                counter.aster.prompt_injection_scanner_disabled = 1,
59                "Security scanning disabled"
60            );
61            return Ok(vec![]);
62        }
63
64        let scanner = self.scanner.get_or_init(|| {
65            let ml_enabled = self.is_ml_scanning_enabled();
66
67            let scanner = if ml_enabled {
68                match PromptInjectionScanner::with_ml_detection() {
69                    Ok(s) => {
70                        tracing::info!(
71                            counter.aster.prompt_injection_scanner_enabled = 1,
72                            "🔓 Security scanner initialized with ML-based detection"
73                        );
74                        s
75                    }
76                    Err(e) => {
77                        let error_chain = format!("{:#}", e);
78                        tracing::warn!(
79                            "⚠️ ML scanning requested but failed to initialize. Falling back to pattern-only scanning.\n\nError details:\n{}",
80                            error_chain
81                        );
82                        PromptInjectionScanner::new()
83                    }
84                }
85            } else {
86                tracing::info!(
87                    counter.aster.prompt_injection_scanner_enabled = 1,
88                    "🔓 Security scanner initialized with pattern-based detection only"
89                );
90                PromptInjectionScanner::new()
91            };
92
93            scanner
94        });
95
96        let mut results = Vec::new();
97
98        tracing::info!(
99            "🔍 Starting security analysis - {} tool requests, {} messages",
100            tool_requests.len(),
101            messages.len()
102        );
103
104        for tool_request in tool_requests.iter() {
105            if let Ok(tool_call) = &tool_request.tool_call {
106                let analysis_result = scanner
107                    .analyze_tool_call_with_context(tool_call, messages)
108                    .await?;
109
110                let config_threshold = scanner.get_threshold_from_config();
111                let sanitized_explanation = analysis_result.explanation.replace('\n', " | ");
112
113                if analysis_result.is_malicious {
114                    let above_threshold = analysis_result.confidence > config_threshold;
115                    let finding_id = format!("SEC-{}", Uuid::new_v4().simple());
116
117                    tracing::warn!(
118                        counter.aster.prompt_injection_finding = 1,
119                        above_threshold = above_threshold,
120                        tool_name = %tool_call.name,
121                        tool_request_id = %tool_request.id,
122                        confidence = analysis_result.confidence,
123                        explanation = %sanitized_explanation,
124                        finding_id = %finding_id,
125                        threshold = config_threshold,
126                        "{}",
127                        if above_threshold {
128                            "Current tool call flagged as malicious after security analysis (above threshold)"
129                        } else {
130                            "Security finding below threshold - logged but not blocking execution"
131                        }
132                    );
133                    if above_threshold {
134                        results.push(SecurityResult {
135                            is_malicious: analysis_result.is_malicious,
136                            confidence: analysis_result.confidence,
137                            explanation: analysis_result.explanation,
138                            should_ask_user: true, // Always ask user for threats above threshold
139                            finding_id,
140                            tool_request_id: tool_request.id.clone(),
141                        });
142                    }
143                } else {
144                    tracing::info!(
145                        tool_name = %tool_call.name,
146                        tool_request_id = %tool_request.id,
147                        confidence = analysis_result.confidence,
148                        explanation = %sanitized_explanation,
149                        "✅ Current tool call passed security analysis"
150                    );
151                }
152            }
153        }
154
155        tracing::info!(
156            counter.aster.prompt_injection_analysis_performed = 1,
157            security_issues_found = results.len(),
158            "Security analysis complete"
159        );
160        Ok(results)
161    }
162
163    pub async fn filter_malicious_tool_calls(
164        &self,
165        messages: &[Message],
166        permission_check_result: &PermissionCheckResult,
167        _system_prompt: Option<&str>,
168    ) -> Result<Vec<SecurityResult>> {
169        let tool_requests: Vec<_> = permission_check_result
170            .approved
171            .iter()
172            .chain(permission_check_result.needs_approval.iter())
173            .cloned()
174            .collect();
175
176        self.analyze_tool_requests(&tool_requests, messages).await
177    }
178}
179
180impl Default for SecurityManager {
181    fn default() -> Self {
182        Self::new()
183    }
184}