Skip to main content

aster/
tool_inspection.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use std::collections::HashMap;
4
5use crate::config::AsterMode;
6use crate::conversation::message::{Message, ToolRequest};
7use crate::permission::permission_inspector::PermissionInspector;
8use crate::permission::permission_judge::PermissionCheckResult;
9
10/// Result of inspecting a tool call
11#[derive(Debug, Clone)]
12pub struct InspectionResult {
13    pub tool_request_id: String,
14    pub action: InspectionAction,
15    pub reason: String,
16    pub confidence: f32,
17    pub inspector_name: String,
18    pub finding_id: Option<String>,
19}
20
/// Decision an inspector makes about a single tool request.
#[derive(Clone, Debug, PartialEq)]
pub enum InspectionAction {
    /// Let the tool run without asking the user.
    Allow,
    /// Refuse to run the tool at all.
    Deny,
    /// Ask the user before running the tool; may carry a warning message to show them.
    RequireApproval(Option<String>),
}
31
32/// Trait for all tool inspectors
33#[async_trait]
34pub trait ToolInspector: Send + Sync {
35    /// Name of this inspector (for logging/debugging)
36    fn name(&self) -> &'static str;
37
38    /// Inspect tool requests and return results
39    async fn inspect(
40        &self,
41        tool_requests: &[ToolRequest],
42        messages: &[Message],
43    ) -> Result<Vec<InspectionResult>>;
44
45    /// Whether this inspector is enabled
46    fn is_enabled(&self) -> bool {
47        true
48    }
49
50    /// Allow downcasting to concrete types
51    fn as_any(&self) -> &dyn std::any::Any;
52}
53
54/// Manages all tool inspectors and coordinates their results
55pub struct ToolInspectionManager {
56    inspectors: Vec<Box<dyn ToolInspector>>,
57}
58
59impl ToolInspectionManager {
60    pub fn new() -> Self {
61        Self {
62            inspectors: Vec::new(),
63        }
64    }
65
66    /// Add an inspector to the manager
67    /// Inspectors run in the order they are added
68    pub fn add_inspector(&mut self, inspector: Box<dyn ToolInspector>) {
69        self.inspectors.push(inspector);
70    }
71
72    /// Run all inspectors on the tool requests
73    pub async fn inspect_tools(
74        &self,
75        tool_requests: &[ToolRequest],
76        messages: &[Message],
77    ) -> Result<Vec<InspectionResult>> {
78        let mut all_results = Vec::new();
79
80        for inspector in &self.inspectors {
81            if !inspector.is_enabled() {
82                continue;
83            }
84
85            tracing::debug!(
86                inspector_name = inspector.name(),
87                tool_count = tool_requests.len(),
88                "Running tool inspector"
89            );
90
91            match inspector.inspect(tool_requests, messages).await {
92                Ok(results) => {
93                    tracing::debug!(
94                        inspector_name = inspector.name(),
95                        result_count = results.len(),
96                        "Tool inspector completed"
97                    );
98                    all_results.extend(results);
99                }
100                Err(e) => {
101                    tracing::error!(
102                        inspector_name = inspector.name(),
103                        error = %e,
104                        "Tool inspector failed"
105                    );
106                    // Continue with other inspectors even if one fails
107                }
108            }
109        }
110
111        Ok(all_results)
112    }
113
114    /// Get list of registered inspector names
115    pub fn inspector_names(&self) -> Vec<&'static str> {
116        self.inspectors.iter().map(|i| i.name()).collect()
117    }
118
119    /// Update the permission inspector's mode
120    pub async fn update_permission_inspector_mode(&self, mode: AsterMode) {
121        for inspector in &self.inspectors {
122            if inspector.name() == "permission" {
123                // Downcast to PermissionInspector to access update_mode method
124                if let Some(permission_inspector) =
125                    inspector.as_any().downcast_ref::<PermissionInspector>()
126                {
127                    permission_inspector.update_mode(mode).await;
128                    return;
129                }
130            }
131        }
132        tracing::warn!("Permission inspector not found for mode update");
133    }
134
135    /// Update the permission manager for a specific tool
136    pub async fn update_permission_manager(
137        &self,
138        tool_name: &str,
139        permission_level: crate::config::permission::PermissionLevel,
140    ) {
141        for inspector in &self.inspectors {
142            if inspector.name() == "permission" {
143                // Downcast to PermissionInspector to access permission manager
144                if let Some(permission_inspector) =
145                    inspector.as_any().downcast_ref::<PermissionInspector>()
146                {
147                    let mut permission_manager =
148                        permission_inspector.permission_manager.lock().await;
149                    permission_manager.update_user_permission(tool_name, permission_level);
150                    return;
151                }
152            }
153        }
154        tracing::warn!("Permission inspector not found for permission manager update");
155    }
156
157    /// Get the integrated permission manager from the permission inspector
158    ///
159    /// Returns the integrated permission manager if the permission inspector
160    /// has been configured with one.
161    ///
162    /// Requirements: 11.1, 11.4
163    pub fn get_integrated_permission_manager(
164        &self,
165    ) -> Option<
166        std::sync::Arc<
167            tokio::sync::Mutex<crate::permission::integration::IntegratedPermissionManager>,
168        >,
169    > {
170        for inspector in &self.inspectors {
171            if inspector.name() == "permission" {
172                if let Some(permission_inspector) =
173                    inspector.as_any().downcast_ref::<PermissionInspector>()
174                {
175                    return permission_inspector.integrated_manager().cloned();
176                }
177            }
178        }
179        None
180    }
181
182    /// Process inspection results using the permission inspector
183    /// This delegates to the permission inspector's process_inspection_results method
184    pub fn process_inspection_results_with_permission_inspector(
185        &self,
186        remaining_requests: &[ToolRequest],
187        inspection_results: &[InspectionResult],
188    ) -> Option<PermissionCheckResult> {
189        for inspector in &self.inspectors {
190            if inspector.name() == "permission" {
191                if let Some(permission_inspector) =
192                    inspector.as_any().downcast_ref::<PermissionInspector>()
193                {
194                    return Some(
195                        permission_inspector
196                            .process_inspection_results(remaining_requests, inspection_results),
197                    );
198                }
199            }
200        }
201        tracing::warn!("Permission inspector not found for processing inspection results");
202        None
203    }
204}
205
206impl Default for ToolInspectionManager {
207    fn default() -> Self {
208        Self::new()
209    }
210}
211
212/// Apply inspection results to permission check results
213/// This is the generic permission-mixing logic that works for all inspector types
214pub fn apply_inspection_results_to_permissions(
215    mut permission_result: PermissionCheckResult,
216    inspection_results: &[InspectionResult],
217) -> PermissionCheckResult {
218    if inspection_results.is_empty() {
219        return permission_result;
220    }
221
222    // Create a map of tool requests by ID for easy lookup
223    let mut all_requests: HashMap<String, ToolRequest> = HashMap::new();
224
225    // Collect all tool requests
226    for req in &permission_result.approved {
227        all_requests.insert(req.id.clone(), req.clone());
228    }
229    for req in &permission_result.needs_approval {
230        all_requests.insert(req.id.clone(), req.clone());
231    }
232    for req in &permission_result.denied {
233        all_requests.insert(req.id.clone(), req.clone());
234    }
235
236    // Process inspection results
237    for result in inspection_results {
238        let request_id = &result.tool_request_id;
239
240        tracing::info!(
241            inspector_name = result.inspector_name,
242            tool_request_id = %request_id,
243            action = ?result.action,
244            confidence = result.confidence,
245            reason = %result.reason,
246            finding_id = ?result.finding_id,
247            "Applying inspection result"
248        );
249
250        match result.action {
251            InspectionAction::Deny => {
252                // Remove from approved and needs_approval, add to denied
253                permission_result
254                    .approved
255                    .retain(|req| req.id != *request_id);
256                permission_result
257                    .needs_approval
258                    .retain(|req| req.id != *request_id);
259
260                if let Some(request) = all_requests.get(request_id) {
261                    if !permission_result
262                        .denied
263                        .iter()
264                        .any(|req| req.id == *request_id)
265                    {
266                        permission_result.denied.push(request.clone());
267                    }
268                }
269            }
270            InspectionAction::RequireApproval(_) => {
271                // Remove from approved, add to needs_approval if not already there
272                permission_result
273                    .approved
274                    .retain(|req| req.id != *request_id);
275
276                if let Some(request) = all_requests.get(request_id) {
277                    if !permission_result
278                        .needs_approval
279                        .iter()
280                        .any(|req| req.id == *request_id)
281                    {
282                        permission_result.needs_approval.push(request.clone());
283                    }
284                }
285            }
286            InspectionAction::Allow => {
287                // This inspector allows it, but don't override other inspectors' decisions
288                // If it's already denied or needs approval, leave it that way
289            }
290        }
291    }
292
293    permission_result
294}
295
296pub fn get_security_finding_id_from_results(
297    tool_request_id: &str,
298    inspection_results: &[InspectionResult],
299) -> Option<String> {
300    inspection_results
301        .iter()
302        .find(|result| {
303            result.tool_request_id == tool_request_id && result.inspector_name == "security"
304        })
305        .and_then(|result| result.finding_id.clone())
306}
307
#[cfg(test)]
mod tests {
    use super::*;
    use crate::conversation::message::ToolRequest;
    use rmcp::model::CallToolRequestParam;
    use rmcp::object;

    /// Build a request for `test_tool` with empty arguments and the given ID.
    fn request_with_id(id: &str) -> ToolRequest {
        ToolRequest {
            id: id.to_string(),
            tool_call: Ok(CallToolRequestParam {
                name: "test_tool".into(),
                arguments: Some(object!({})),
            }),
            metadata: None,
            tool_meta: None,
        }
    }

    #[test]
    fn test_apply_inspection_results() {
        let permission_result = PermissionCheckResult {
            approved: vec![request_with_id("req_1")],
            needs_approval: Vec::new(),
            denied: Vec::new(),
        };

        let denial = InspectionResult {
            tool_request_id: "req_1".to_string(),
            action: InspectionAction::Deny,
            reason: "Test denial".to_string(),
            confidence: 0.9,
            inspector_name: "test_inspector".to_string(),
            finding_id: Some("TEST-001".to_string()),
        };

        let updated = apply_inspection_results_to_permissions(permission_result, &[denial]);

        assert!(updated.approved.is_empty());
        assert_eq!(updated.denied.len(), 1);
        assert_eq!(updated.denied[0].id, "req_1");
    }
}