Skip to main content

aster/agents/
tool_execution.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::Instant;
5
6use async_stream::try_stream;
7use futures::stream::{self, BoxStream};
8use futures::{Stream, StreamExt};
9use tokio::sync::Mutex;
10use tokio_util::sync::CancellationToken;
11
12use crate::config::permission::PermissionLevel;
13use crate::mcp_utils::ToolResult;
14use crate::permission::{
15    AuditLogEntry, AuditLogLevel, AuditLogger, Permission, PermissionContext, ToolPermissionManager,
16};
17use crate::tools::{ToolContext, ToolRegistry};
18use rmcp::model::{Content, ServerNotification};
19
20// ToolCallResult combines the result of a tool call with an optional notification stream that
21// can be used to receive notifications from the tool.
22pub struct ToolCallResult {
23    pub result: Box<dyn Future<Output = ToolResult<rmcp::model::CallToolResult>> + Send + Unpin>,
24    pub notification_stream: Option<Box<dyn Stream<Item = ServerNotification> + Send + Unpin>>,
25}
26
27impl From<ToolResult<rmcp::model::CallToolResult>> for ToolCallResult {
28    fn from(result: ToolResult<rmcp::model::CallToolResult>) -> Self {
29        Self {
30            result: Box::new(futures::future::ready(result)),
31            notification_stream: None,
32        }
33    }
34}
35
36use super::agent::{tool_stream, ToolStream};
37use crate::agents::Agent;
38use crate::conversation::message::{Message, ToolRequest};
39use crate::session::Session;
40use crate::tool_inspection::get_security_finding_id_from_results;
41
/// Tool response text recorded when the user declines a tool call.
pub const DECLINED_RESPONSE: &str = concat!(
    "The user has declined to run this tool. ",
    "DO NOT attempt to call this tool again. ",
    "If there are no alternative methods to proceed, clearly explain the situation and STOP."
);
45
/// Instruction text used when a tool call is skipped in chat mode.
///
/// The embedded `\n ` sequences start each line of the example plan.
pub const CHAT_MODE_TOOL_SKIPPED_RESPONSE: &str = concat!(
    "Let the user know the tool call was skipped in aster chat mode. ",
    "DO NOT apologize for skipping the tool call. DO NOT say sorry. ",
    "Provide an explanation of what the tool call would do, structured as a ",
    "plan for the user. Again, DO NOT apologize. ",
    "**Example Plan:**\n ",
    "1. **Identify Task Scope** - Determine the purpose and expected outcome.\n ",
    "2. **Outline Steps** - Break down the steps.\n ",
    "If needed, adjust the explanation based on user preferences or questions."
);
54
55impl Agent {
56    pub(crate) fn handle_approval_tool_requests<'a>(
57        &'a self,
58        tool_requests: &'a [ToolRequest],
59        tool_futures: Arc<Mutex<Vec<(String, ToolStream)>>>,
60        request_to_response_map: &'a HashMap<String, Arc<Mutex<Message>>>,
61        cancellation_token: Option<CancellationToken>,
62        session: &'a Session,
63        inspection_results: &'a [crate::tool_inspection::InspectionResult],
64    ) -> BoxStream<'a, anyhow::Result<Message>> {
65        try_stream! {
66        for request in tool_requests.iter() {
67            if let Ok(tool_call) = request.tool_call.clone() {
68                // Find the corresponding inspection result for this tool request
69                let security_message = inspection_results.iter()
70                    .find(|result| result.tool_request_id == request.id)
71                    .and_then(|result| {
72                        if let crate::tool_inspection::InspectionAction::RequireApproval(Some(message)) = &result.action {
73                            Some(message.clone())
74                        } else {
75                            None
76                        }
77                    });
78
79                let confirmation = Message::assistant()
80                    .with_action_required(
81                        request.id.clone(),
82                        tool_call.name.to_string().clone(),
83                        tool_call.arguments.clone().unwrap_or_default(),
84                        security_message,
85                    )
86                    .user_only();
87                yield confirmation;
88
89                let mut rx = self.confirmation_rx.lock().await;
90                while let Some((req_id, confirmation)) = rx.recv().await {
91                    if req_id == request.id {
92                        // Log user decision if this was a security alert
93                        if let Some(finding_id) = get_security_finding_id_from_results(&request.id, inspection_results) {
94                            tracing::info!(
95                                counter.aster.prompt_injection_user_decisions = 1,
96                                decision = ?confirmation.permission,
97                                finding_id = %finding_id,
98                                "User security decision"
99                            );
100                        }
101
102                        if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow {
103                            let (req_id, tool_result) = self.dispatch_tool_call(tool_call.clone(), request.id.clone(), cancellation_token.clone(), session).await;
104                            let mut futures = tool_futures.lock().await;
105
106                            futures.push((req_id, match tool_result {
107                                Ok(result) => tool_stream(
108                                    result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
109                                    result.result,
110                                ),
111                                Err(e) => tool_stream(
112                                    Box::new(stream::empty()),
113                                    futures::future::ready(Err(e)),
114                                ),
115                            }));
116
117                            // Update the shared permission manager when user selects "Always Allow"
118                            if confirmation.permission == Permission::AlwaysAllow {
119                                self.tool_inspection_manager
120                                    .update_permission_manager(&tool_call.name, PermissionLevel::AlwaysAllow)
121                                    .await;
122                            }
123                        } else {
124                            // User declined - update the specific response message for this request
125                            if let Some(response_msg) = request_to_response_map.get(&request.id) {
126                                let mut response = response_msg.lock().await;
127                                *response = response.clone().with_tool_response_with_metadata(
128                                    request.id.clone(),
129                                    Ok(rmcp::model::CallToolResult {
130                                        content: vec![Content::text(DECLINED_RESPONSE)],
131                                        structured_content: None,
132                                        is_error: Some(true),
133                                        meta: None,
134                                    }),
135                                    request.metadata.as_ref(),
136                                );
137                            }
138                        }
139                        break; // Exit the loop once the matching `req_id` is found
140                    }
141                }
142            }
143        }
144    }.boxed()
145    }
146
147    pub(crate) fn handle_frontend_tool_request<'a>(
148        &'a self,
149        tool_request: &'a ToolRequest,
150        message_tool_response: Arc<Mutex<Message>>,
151    ) -> BoxStream<'a, anyhow::Result<Message>> {
152        try_stream! {
153                if let Ok(tool_call) = tool_request.tool_call.clone() {
154                    if self.is_frontend_tool(&tool_call.name).await {
155                        // Send frontend tool request and wait for response
156                        yield Message::assistant().with_frontend_tool_request(
157                            tool_request.id.clone(),
158                            Ok(tool_call.clone())
159                        );
160
161                        if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await {
162                            let mut response = message_tool_response.lock().await;
163                            *response = response.clone().with_tool_response_with_metadata(
164                                id,
165                                result,
166                                tool_request.metadata.as_ref(),
167                            );
168                        }
169                    }
170            }
171        }
172        .boxed()
173    }
174
175    // =============================================================================
176    // ToolRegistry Integration (Requirements: 8.1, 8.2, 8.3, 8.4, 8.5)
177    // =============================================================================
178
179    /// Create a ToolContext from a Session
180    ///
181    /// This helper function creates a ToolContext suitable for use with the
182    /// ToolRegistry from the current session information.
183    ///
184    /// Requirements: 8.4
185    pub fn create_tool_context(
186        session: &Session,
187        cancellation_token: Option<CancellationToken>,
188    ) -> ToolContext {
189        let mut ctx = ToolContext::new(session.working_dir.clone()).with_session_id(&session.id);
190
191        if let Some(token) = cancellation_token {
192            ctx = ctx.with_cancellation_token(token);
193        }
194
195        ctx
196    }
197
198    /// Create a PermissionContext from a Session
199    ///
200    /// This helper function creates a PermissionContext suitable for use with
201    /// the ToolPermissionManager from the current session information.
202    ///
203    /// Requirements: 8.1, 8.2
204    pub fn create_permission_context(session: &Session) -> PermissionContext {
205        PermissionContext {
206            working_directory: session.working_dir.clone(),
207            session_id: session.id.clone(),
208            timestamp: chrono::Utc::now().timestamp(),
209            user: None,
210            environment: HashMap::new(),
211            metadata: HashMap::new(),
212        }
213    }
214
215    /// Execute a tool through the ToolRegistry with permission checking and audit logging
216    ///
217    /// This method provides a unified interface for executing tools through the
218    /// ToolRegistry, integrating:
219    /// - Permission checking via ToolPermissionManager
220    /// - Audit logging via AuditLogger
221    /// - User confirmation handling for 'ask' permission behavior
222    ///
223    /// # Arguments
224    /// * `registry` - The ToolRegistry containing registered tools
225    /// * `tool_name` - Name of the tool to execute
226    /// * `params` - Tool parameters as JSON
227    /// * `session` - Current session
228    /// * `cancellation_token` - Optional cancellation token
229    /// * `on_permission_request` - Optional callback for permission requests
230    ///
231    /// # Returns
232    /// * `Ok(ToolResult)` - The tool execution result
233    /// * `Err(ToolError)` - If permission denied or execution fails
234    ///
235    /// Requirements: 8.1, 8.2, 8.3, 8.4, 8.5
236    pub async fn execute_tool_with_registry(
237        registry: &ToolRegistry,
238        tool_name: &str,
239        params: serde_json::Value,
240        session: &Session,
241        cancellation_token: Option<CancellationToken>,
242        on_permission_request: Option<crate::tools::PermissionRequestCallback>,
243    ) -> Result<crate::tools::ToolResult, crate::tools::ToolError> {
244        let context = Self::create_tool_context(session, cancellation_token);
245        registry
246            .execute(tool_name, params, &context, on_permission_request)
247            .await
248    }
249
250    /// Log a tool execution to the audit logger
251    ///
252    /// This helper function logs tool execution events to the audit logger,
253    /// including success/failure status, duration, and relevant metadata.
254    ///
255    /// Requirements: 8.5
256    pub fn log_tool_execution(
257        audit_logger: &AuditLogger,
258        tool_name: &str,
259        params: &serde_json::Value,
260        session: &Session,
261        success: bool,
262        duration: std::time::Duration,
263        error_message: Option<&str>,
264    ) {
265        let level = if success {
266            AuditLogLevel::Info
267        } else {
268            AuditLogLevel::Warn
269        };
270
271        let perm_context = Self::create_permission_context(session);
272        let params_map = Self::params_to_hashmap(params);
273
274        let mut entry = AuditLogEntry::new("tool_execution", tool_name)
275            .with_level(level)
276            .with_parameters(params_map)
277            .with_context(perm_context)
278            .with_duration_ms(duration.as_millis() as u64)
279            .add_metadata("success", serde_json::json!(success));
280
281        if let Some(err) = error_message {
282            entry = entry.add_metadata("error", serde_json::json!(err));
283        }
284
285        audit_logger.log_tool_execution(entry);
286    }
287
288    /// Log a permission denial to the audit logger
289    ///
290    /// This helper function logs permission denial events to the audit logger.
291    ///
292    /// Requirements: 8.5
293    pub fn log_permission_denied(
294        audit_logger: &AuditLogger,
295        tool_name: &str,
296        params: &serde_json::Value,
297        session: &Session,
298        reason: &str,
299    ) {
300        let perm_context = Self::create_permission_context(session);
301        let params_map = Self::params_to_hashmap(params);
302
303        let entry = AuditLogEntry::new("permission_denied", tool_name)
304            .with_level(AuditLogLevel::Warn)
305            .with_parameters(params_map)
306            .with_context(perm_context)
307            .add_metadata("reason", serde_json::json!(reason));
308
309        audit_logger.log(entry);
310    }
311
312    /// Convert JSON params to HashMap for permission checking
313    fn params_to_hashmap(params: &serde_json::Value) -> HashMap<String, serde_json::Value> {
314        match params {
315            serde_json::Value::Object(map) => {
316                map.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
317            }
318            _ => HashMap::new(),
319        }
320    }
321
322    /// Check tool permissions using ToolPermissionManager
323    ///
324    /// This method checks if a tool execution is allowed based on the
325    /// configured permission rules.
326    ///
327    /// # Arguments
328    /// * `permission_manager` - The ToolPermissionManager to use
329    /// * `tool_name` - Name of the tool to check
330    /// * `params` - Tool parameters as JSON
331    /// * `session` - Current session
332    ///
333    /// # Returns
334    /// * `Ok(())` - If permission is granted
335    /// * `Err(reason)` - If permission is denied, with the denial reason
336    ///
337    /// Requirements: 8.1, 8.2, 8.3
338    pub fn check_tool_permission(
339        permission_manager: &ToolPermissionManager,
340        tool_name: &str,
341        params: &serde_json::Value,
342        session: &Session,
343    ) -> Result<(), String> {
344        let perm_context = Self::create_permission_context(session);
345        let params_map = Self::params_to_hashmap(params);
346
347        let result = permission_manager.is_allowed(tool_name, &params_map, &perm_context);
348
349        if result.allowed {
350            Ok(())
351        } else {
352            Err(result
353                .reason
354                .unwrap_or_else(|| format!("Permission denied for tool '{}'", tool_name)))
355        }
356    }
357
358    /// Execute a tool call with integrated permission checking and audit logging
359    ///
360    /// This is a higher-level wrapper that combines permission checking,
361    /// tool execution, and audit logging into a single operation.
362    ///
363    /// # Arguments
364    /// * `registry` - The ToolRegistry containing registered tools
365    /// * `permission_manager` - Optional ToolPermissionManager for permission checks
366    /// * `audit_logger` - Optional AuditLogger for logging
367    /// * `tool_name` - Name of the tool to execute
368    /// * `params` - Tool parameters as JSON
369    /// * `session` - Current session
370    /// * `cancellation_token` - Optional cancellation token
371    ///
372    /// # Returns
373    /// * `Ok(ToolResult)` - The tool execution result
374    /// * `Err(String)` - Error message if permission denied or execution fails
375    ///
376    /// Requirements: 8.1, 8.2, 8.3, 8.4, 8.5
377    pub async fn execute_tool_with_checks(
378        registry: &ToolRegistry,
379        permission_manager: Option<&ToolPermissionManager>,
380        audit_logger: Option<&AuditLogger>,
381        tool_name: &str,
382        params: serde_json::Value,
383        session: &Session,
384        cancellation_token: Option<CancellationToken>,
385    ) -> Result<crate::tools::ToolResult, String> {
386        let start_time = Instant::now();
387
388        // Step 1: Check permissions if permission manager is provided
389        if let Some(pm) = permission_manager {
390            if let Err(reason) = Self::check_tool_permission(pm, tool_name, &params, session) {
391                // Log permission denial
392                if let Some(logger) = audit_logger {
393                    Self::log_permission_denied(logger, tool_name, &params, session, &reason);
394                }
395                return Err(reason);
396            }
397        }
398
399        // Step 2: Execute the tool
400        let context = Self::create_tool_context(session, cancellation_token);
401        let result = registry
402            .execute(tool_name, params.clone(), &context, None)
403            .await;
404
405        // Step 3: Log the execution
406        let duration = start_time.elapsed();
407        if let Some(logger) = audit_logger {
408            match &result {
409                Ok(tool_result) => {
410                    Self::log_tool_execution(
411                        logger,
412                        tool_name,
413                        &params,
414                        session,
415                        tool_result.is_success(),
416                        duration,
417                        tool_result.error.as_deref(),
418                    );
419                }
420                Err(err) => {
421                    Self::log_tool_execution(
422                        logger,
423                        tool_name,
424                        &params,
425                        session,
426                        false,
427                        duration,
428                        Some(&err.to_string()),
429                    );
430                }
431            }
432        }
433
434        result.map_err(|e| e.to_string())
435    }
436
437    /// Execute a tool call with user confirmation support for 'ask' permission behavior
438    ///
439    /// This method extends `execute_tool_with_checks` to support the 'ask' permission
440    /// behavior, where the user is prompted to confirm tool execution.
441    ///
442    /// # Arguments
443    /// * `registry` - The ToolRegistry containing registered tools
444    /// * `permission_manager` - Optional ToolPermissionManager for permission checks
445    /// * `audit_logger` - Optional AuditLogger for logging
446    /// * `tool_name` - Name of the tool to execute
447    /// * `params` - Tool parameters as JSON
448    /// * `session` - Current session
449    /// * `cancellation_token` - Optional cancellation token
450    /// * `on_permission_request` - Callback for handling 'ask' permission behavior
451    ///
452    /// # Returns
453    /// * `Ok(ToolResult)` - The tool execution result
454    /// * `Err(String)` - Error message if permission denied or execution fails
455    ///
456    /// Requirements: 8.1, 8.2, 8.3, 8.4
457    #[allow(clippy::too_many_arguments)]
458    pub async fn execute_tool_with_user_confirmation(
459        registry: &ToolRegistry,
460        permission_manager: Option<&ToolPermissionManager>,
461        audit_logger: Option<&AuditLogger>,
462        tool_name: &str,
463        params: serde_json::Value,
464        session: &Session,
465        cancellation_token: Option<CancellationToken>,
466        on_permission_request: Option<crate::tools::PermissionRequestCallback>,
467    ) -> Result<crate::tools::ToolResult, String> {
468        let start_time = Instant::now();
469
470        // Step 1: Check permissions if permission manager is provided
471        if let Some(pm) = permission_manager {
472            if let Err(reason) = Self::check_tool_permission(pm, tool_name, &params, session) {
473                // Log permission denial
474                if let Some(logger) = audit_logger {
475                    Self::log_permission_denied(logger, tool_name, &params, session, &reason);
476                }
477                return Err(reason);
478            }
479        }
480
481        // Step 2: Execute the tool with permission request callback
482        let context = Self::create_tool_context(session, cancellation_token);
483        let result = registry
484            .execute(tool_name, params.clone(), &context, on_permission_request)
485            .await;
486
487        // Step 3: Log the execution
488        let duration = start_time.elapsed();
489        if let Some(logger) = audit_logger {
490            match &result {
491                Ok(tool_result) => {
492                    Self::log_tool_execution(
493                        logger,
494                        tool_name,
495                        &params,
496                        session,
497                        tool_result.is_success(),
498                        duration,
499                        tool_result.error.as_deref(),
500                    );
501                }
502                Err(err) => {
503                    Self::log_tool_execution(
504                        logger,
505                        tool_name,
506                        &params,
507                        session,
508                        false,
509                        duration,
510                        Some(&err.to_string()),
511                    );
512                }
513            }
514        }
515
516        result.map_err(|e| e.to_string())
517    }
518
519    /// Create a permission request callback that uses the Agent's confirmation channel
520    ///
521    /// This method creates a callback that can be used with `execute_tool_with_user_confirmation`
522    /// to handle 'ask' permission behavior by sending confirmation requests through the
523    /// Agent's existing confirmation channel.
524    ///
525    /// # Arguments
526    /// * `request_id` - The tool request ID for tracking
527    /// * `confirmation_tx` - The confirmation sender channel
528    ///
529    /// # Returns
530    /// A callback that sends permission requests and waits for user confirmation
531    ///
532    /// Requirements: 8.2, 8.3
533    pub fn create_permission_callback(
534        request_id: String,
535        _confirmation_tx: tokio::sync::mpsc::Sender<(
536            String,
537            crate::permission::PermissionConfirmation,
538        )>,
539    ) -> crate::tools::PermissionRequestCallback {
540        Box::new(move |tool_name: String, message: String| {
541            let req_id = request_id.clone();
542            Box::pin(async move {
543                // Log the permission request
544                tracing::info!(
545                    tool_name = %tool_name,
546                    message = %message,
547                    request_id = %req_id,
548                    "Permission request for tool execution"
549                );
550
551                // For now, we return false (deny) as the actual confirmation
552                // would need to be handled through the UI flow
553                // The existing handle_approval_tool_requests handles this flow
554                false
555            })
556        })
557    }
558
559    /// Log a permission check result to the audit logger
560    ///
561    /// This helper function logs permission check events to the audit logger,
562    /// including the result and any relevant metadata.
563    ///
564    /// Requirements: 8.5
565    pub fn log_permission_check(
566        audit_logger: &AuditLogger,
567        tool_name: &str,
568        params: &serde_json::Value,
569        session: &Session,
570        allowed: bool,
571        reason: Option<&str>,
572    ) {
573        let level = if allowed {
574            AuditLogLevel::Debug
575        } else {
576            AuditLogLevel::Warn
577        };
578
579        let perm_context = Self::create_permission_context(session);
580        let params_map = Self::params_to_hashmap(params);
581
582        let mut entry = AuditLogEntry::new("permission_check", tool_name)
583            .with_level(level)
584            .with_parameters(params_map)
585            .with_context(perm_context)
586            .add_metadata("allowed", serde_json::json!(allowed));
587
588        if let Some(r) = reason {
589            entry = entry.add_metadata("reason", serde_json::json!(r));
590        }
591
592        audit_logger.log_permission_check(entry);
593    }
594}