1use std::collections::HashMap;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::Instant;
5
6use async_stream::try_stream;
7use futures::stream::{self, BoxStream};
8use futures::{Stream, StreamExt};
9use tokio::sync::Mutex;
10use tokio_util::sync::CancellationToken;
11
12use crate::config::permission::PermissionLevel;
13use crate::mcp_utils::ToolResult;
14use crate::permission::{
15 AuditLogEntry, AuditLogLevel, AuditLogger, Permission, PermissionContext, ToolPermissionManager,
16};
17use crate::tools::{ToolContext, ToolRegistry};
18use rmcp::model::{Content, ServerNotification};
19
20pub struct ToolCallResult {
23 pub result: Box<dyn Future<Output = ToolResult<rmcp::model::CallToolResult>> + Send + Unpin>,
24 pub notification_stream: Option<Box<dyn Stream<Item = ServerNotification> + Send + Unpin>>,
25}
26
27impl From<ToolResult<rmcp::model::CallToolResult>> for ToolCallResult {
28 fn from(result: ToolResult<rmcp::model::CallToolResult>) -> Self {
29 Self {
30 result: Box::new(futures::future::ready(result)),
31 notification_stream: None,
32 }
33 }
34}
35
36use super::agent::{tool_stream, ToolStream};
37use crate::agents::Agent;
38use crate::conversation::message::{Message, ToolRequest};
39use crate::session::Session;
40use crate::tool_inspection::get_security_finding_id_from_results;
41
/// Tool-response text recorded when the user refuses to let a tool run.
///
/// Phrased as an instruction to the model: do not retry the tool, and stop
/// if there is no alternative way to proceed.
pub const DECLINED_RESPONSE: &str = concat!(
    "The user has declined to run this tool. ",
    "DO NOT attempt to call this tool again. ",
    "If there are no alternative methods to proceed, clearly explain the situation and STOP.",
);
45
/// Instruction text for a tool call that was skipped in chat mode.
///
/// Tells the model to explain what the tool call would have done, as a plan,
/// without apologising. Each embedded `\n` is followed by a single space,
/// which is part of the text.
pub const CHAT_MODE_TOOL_SKIPPED_RESPONSE: &str = concat!(
    "Let the user know the tool call was skipped in aster chat mode. ",
    "DO NOT apologize for skipping the tool call. DO NOT say sorry. ",
    "Provide an explanation of what the tool call would do, structured as a ",
    "plan for the user. Again, DO NOT apologize. ",
    "**Example Plan:**\n ",
    "1. **Identify Task Scope** - Determine the purpose and expected outcome.\n ",
    "2. **Outline Steps** - Break down the steps.\n ",
    "If needed, adjust the explanation based on user preferences or questions.",
);
54
55impl Agent {
56 pub(crate) fn handle_approval_tool_requests<'a>(
57 &'a self,
58 tool_requests: &'a [ToolRequest],
59 tool_futures: Arc<Mutex<Vec<(String, ToolStream)>>>,
60 request_to_response_map: &'a HashMap<String, Arc<Mutex<Message>>>,
61 cancellation_token: Option<CancellationToken>,
62 session: &'a Session,
63 inspection_results: &'a [crate::tool_inspection::InspectionResult],
64 ) -> BoxStream<'a, anyhow::Result<Message>> {
65 try_stream! {
66 for request in tool_requests.iter() {
67 if let Ok(tool_call) = request.tool_call.clone() {
68 let security_message = inspection_results.iter()
70 .find(|result| result.tool_request_id == request.id)
71 .and_then(|result| {
72 if let crate::tool_inspection::InspectionAction::RequireApproval(Some(message)) = &result.action {
73 Some(message.clone())
74 } else {
75 None
76 }
77 });
78
79 let confirmation = Message::assistant()
80 .with_action_required(
81 request.id.clone(),
82 tool_call.name.to_string().clone(),
83 tool_call.arguments.clone().unwrap_or_default(),
84 security_message,
85 )
86 .user_only();
87 yield confirmation;
88
89 let mut rx = self.confirmation_rx.lock().await;
90 while let Some((req_id, confirmation)) = rx.recv().await {
91 if req_id == request.id {
92 if let Some(finding_id) = get_security_finding_id_from_results(&request.id, inspection_results) {
94 tracing::info!(
95 counter.aster.prompt_injection_user_decisions = 1,
96 decision = ?confirmation.permission,
97 finding_id = %finding_id,
98 "User security decision"
99 );
100 }
101
102 if confirmation.permission == Permission::AllowOnce || confirmation.permission == Permission::AlwaysAllow {
103 let (req_id, tool_result) = self.dispatch_tool_call(tool_call.clone(), request.id.clone(), cancellation_token.clone(), session).await;
104 let mut futures = tool_futures.lock().await;
105
106 futures.push((req_id, match tool_result {
107 Ok(result) => tool_stream(
108 result.notification_stream.unwrap_or_else(|| Box::new(stream::empty())),
109 result.result,
110 ),
111 Err(e) => tool_stream(
112 Box::new(stream::empty()),
113 futures::future::ready(Err(e)),
114 ),
115 }));
116
117 if confirmation.permission == Permission::AlwaysAllow {
119 self.tool_inspection_manager
120 .update_permission_manager(&tool_call.name, PermissionLevel::AlwaysAllow)
121 .await;
122 }
123 } else {
124 if let Some(response_msg) = request_to_response_map.get(&request.id) {
126 let mut response = response_msg.lock().await;
127 *response = response.clone().with_tool_response_with_metadata(
128 request.id.clone(),
129 Ok(rmcp::model::CallToolResult {
130 content: vec![Content::text(DECLINED_RESPONSE)],
131 structured_content: None,
132 is_error: Some(true),
133 meta: None,
134 }),
135 request.metadata.as_ref(),
136 );
137 }
138 }
139 break; }
141 }
142 }
143 }
144 }.boxed()
145 }
146
147 pub(crate) fn handle_frontend_tool_request<'a>(
148 &'a self,
149 tool_request: &'a ToolRequest,
150 message_tool_response: Arc<Mutex<Message>>,
151 ) -> BoxStream<'a, anyhow::Result<Message>> {
152 try_stream! {
153 if let Ok(tool_call) = tool_request.tool_call.clone() {
154 if self.is_frontend_tool(&tool_call.name).await {
155 yield Message::assistant().with_frontend_tool_request(
157 tool_request.id.clone(),
158 Ok(tool_call.clone())
159 );
160
161 if let Some((id, result)) = self.tool_result_rx.lock().await.recv().await {
162 let mut response = message_tool_response.lock().await;
163 *response = response.clone().with_tool_response_with_metadata(
164 id,
165 result,
166 tool_request.metadata.as_ref(),
167 );
168 }
169 }
170 }
171 }
172 .boxed()
173 }
174
175 pub fn create_tool_context(
186 session: &Session,
187 cancellation_token: Option<CancellationToken>,
188 ) -> ToolContext {
189 let mut ctx = ToolContext::new(session.working_dir.clone()).with_session_id(&session.id);
190
191 if let Some(token) = cancellation_token {
192 ctx = ctx.with_cancellation_token(token);
193 }
194
195 ctx
196 }
197
198 pub fn create_permission_context(session: &Session) -> PermissionContext {
205 PermissionContext {
206 working_directory: session.working_dir.clone(),
207 session_id: session.id.clone(),
208 timestamp: chrono::Utc::now().timestamp(),
209 user: None,
210 environment: HashMap::new(),
211 metadata: HashMap::new(),
212 }
213 }
214
215 pub async fn execute_tool_with_registry(
237 registry: &ToolRegistry,
238 tool_name: &str,
239 params: serde_json::Value,
240 session: &Session,
241 cancellation_token: Option<CancellationToken>,
242 on_permission_request: Option<crate::tools::PermissionRequestCallback>,
243 ) -> Result<crate::tools::ToolResult, crate::tools::ToolError> {
244 let context = Self::create_tool_context(session, cancellation_token);
245 registry
246 .execute(tool_name, params, &context, on_permission_request)
247 .await
248 }
249
250 pub fn log_tool_execution(
257 audit_logger: &AuditLogger,
258 tool_name: &str,
259 params: &serde_json::Value,
260 session: &Session,
261 success: bool,
262 duration: std::time::Duration,
263 error_message: Option<&str>,
264 ) {
265 let level = if success {
266 AuditLogLevel::Info
267 } else {
268 AuditLogLevel::Warn
269 };
270
271 let perm_context = Self::create_permission_context(session);
272 let params_map = Self::params_to_hashmap(params);
273
274 let mut entry = AuditLogEntry::new("tool_execution", tool_name)
275 .with_level(level)
276 .with_parameters(params_map)
277 .with_context(perm_context)
278 .with_duration_ms(duration.as_millis() as u64)
279 .add_metadata("success", serde_json::json!(success));
280
281 if let Some(err) = error_message {
282 entry = entry.add_metadata("error", serde_json::json!(err));
283 }
284
285 audit_logger.log_tool_execution(entry);
286 }
287
288 pub fn log_permission_denied(
294 audit_logger: &AuditLogger,
295 tool_name: &str,
296 params: &serde_json::Value,
297 session: &Session,
298 reason: &str,
299 ) {
300 let perm_context = Self::create_permission_context(session);
301 let params_map = Self::params_to_hashmap(params);
302
303 let entry = AuditLogEntry::new("permission_denied", tool_name)
304 .with_level(AuditLogLevel::Warn)
305 .with_parameters(params_map)
306 .with_context(perm_context)
307 .add_metadata("reason", serde_json::json!(reason));
308
309 audit_logger.log(entry);
310 }
311
312 fn params_to_hashmap(params: &serde_json::Value) -> HashMap<String, serde_json::Value> {
314 match params {
315 serde_json::Value::Object(map) => {
316 map.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
317 }
318 _ => HashMap::new(),
319 }
320 }
321
322 pub fn check_tool_permission(
339 permission_manager: &ToolPermissionManager,
340 tool_name: &str,
341 params: &serde_json::Value,
342 session: &Session,
343 ) -> Result<(), String> {
344 let perm_context = Self::create_permission_context(session);
345 let params_map = Self::params_to_hashmap(params);
346
347 let result = permission_manager.is_allowed(tool_name, ¶ms_map, &perm_context);
348
349 if result.allowed {
350 Ok(())
351 } else {
352 Err(result
353 .reason
354 .unwrap_or_else(|| format!("Permission denied for tool '{}'", tool_name)))
355 }
356 }
357
358 pub async fn execute_tool_with_checks(
378 registry: &ToolRegistry,
379 permission_manager: Option<&ToolPermissionManager>,
380 audit_logger: Option<&AuditLogger>,
381 tool_name: &str,
382 params: serde_json::Value,
383 session: &Session,
384 cancellation_token: Option<CancellationToken>,
385 ) -> Result<crate::tools::ToolResult, String> {
386 let start_time = Instant::now();
387
388 if let Some(pm) = permission_manager {
390 if let Err(reason) = Self::check_tool_permission(pm, tool_name, ¶ms, session) {
391 if let Some(logger) = audit_logger {
393 Self::log_permission_denied(logger, tool_name, ¶ms, session, &reason);
394 }
395 return Err(reason);
396 }
397 }
398
399 let context = Self::create_tool_context(session, cancellation_token);
401 let result = registry
402 .execute(tool_name, params.clone(), &context, None)
403 .await;
404
405 let duration = start_time.elapsed();
407 if let Some(logger) = audit_logger {
408 match &result {
409 Ok(tool_result) => {
410 Self::log_tool_execution(
411 logger,
412 tool_name,
413 ¶ms,
414 session,
415 tool_result.is_success(),
416 duration,
417 tool_result.error.as_deref(),
418 );
419 }
420 Err(err) => {
421 Self::log_tool_execution(
422 logger,
423 tool_name,
424 ¶ms,
425 session,
426 false,
427 duration,
428 Some(&err.to_string()),
429 );
430 }
431 }
432 }
433
434 result.map_err(|e| e.to_string())
435 }
436
437 #[allow(clippy::too_many_arguments)]
458 pub async fn execute_tool_with_user_confirmation(
459 registry: &ToolRegistry,
460 permission_manager: Option<&ToolPermissionManager>,
461 audit_logger: Option<&AuditLogger>,
462 tool_name: &str,
463 params: serde_json::Value,
464 session: &Session,
465 cancellation_token: Option<CancellationToken>,
466 on_permission_request: Option<crate::tools::PermissionRequestCallback>,
467 ) -> Result<crate::tools::ToolResult, String> {
468 let start_time = Instant::now();
469
470 if let Some(pm) = permission_manager {
472 if let Err(reason) = Self::check_tool_permission(pm, tool_name, ¶ms, session) {
473 if let Some(logger) = audit_logger {
475 Self::log_permission_denied(logger, tool_name, ¶ms, session, &reason);
476 }
477 return Err(reason);
478 }
479 }
480
481 let context = Self::create_tool_context(session, cancellation_token);
483 let result = registry
484 .execute(tool_name, params.clone(), &context, on_permission_request)
485 .await;
486
487 let duration = start_time.elapsed();
489 if let Some(logger) = audit_logger {
490 match &result {
491 Ok(tool_result) => {
492 Self::log_tool_execution(
493 logger,
494 tool_name,
495 ¶ms,
496 session,
497 tool_result.is_success(),
498 duration,
499 tool_result.error.as_deref(),
500 );
501 }
502 Err(err) => {
503 Self::log_tool_execution(
504 logger,
505 tool_name,
506 ¶ms,
507 session,
508 false,
509 duration,
510 Some(&err.to_string()),
511 );
512 }
513 }
514 }
515
516 result.map_err(|e| e.to_string())
517 }
518
519 pub fn create_permission_callback(
534 request_id: String,
535 _confirmation_tx: tokio::sync::mpsc::Sender<(
536 String,
537 crate::permission::PermissionConfirmation,
538 )>,
539 ) -> crate::tools::PermissionRequestCallback {
540 Box::new(move |tool_name: String, message: String| {
541 let req_id = request_id.clone();
542 Box::pin(async move {
543 tracing::info!(
545 tool_name = %tool_name,
546 message = %message,
547 request_id = %req_id,
548 "Permission request for tool execution"
549 );
550
551 false
555 })
556 })
557 }
558
559 pub fn log_permission_check(
566 audit_logger: &AuditLogger,
567 tool_name: &str,
568 params: &serde_json::Value,
569 session: &Session,
570 allowed: bool,
571 reason: Option<&str>,
572 ) {
573 let level = if allowed {
574 AuditLogLevel::Debug
575 } else {
576 AuditLogLevel::Warn
577 };
578
579 let perm_context = Self::create_permission_context(session);
580 let params_map = Self::params_to_hashmap(params);
581
582 let mut entry = AuditLogEntry::new("permission_check", tool_name)
583 .with_level(level)
584 .with_parameters(params_map)
585 .with_context(perm_context)
586 .add_metadata("allowed", serde_json::json!(allowed));
587
588 if let Some(r) = reason {
589 entry = entry.add_metadata("reason", serde_json::json!(r));
590 }
591
592 audit_logger.log_permission_check(entry);
593 }
594}