Skip to main content

agent_code_lib/tools/
executor.rs

1//! Tool executor: manages concurrent and serial tool execution.
2//!
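//! The executor partitions tool calls into batches:
//! - Runs of consecutive concurrency-safe tools are spawned and run in parallel
//! - All other tools run serially, one at a time
//!
//! Results are always returned in the same order as the input calls.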
6//!
//! This mirrors the streaming tool executor pattern, in which tools begin
//! execution as soon as their input is fully parsed from the stream to
//! maximize throughput; this module itself operates on an already-extracted
//! list of calls.
10
11use std::sync::Arc;
12
13use crate::llm::message::ContentBlock;
14use crate::permissions::{PermissionChecker, PermissionDecision};
15
16use super::{Tool, ToolContext, ToolResult};
17
18/// A pending tool call extracted from the model's response.
19#[derive(Debug, Clone)]
20pub struct PendingToolCall {
21    pub id: String,
22    pub name: String,
23    pub input: serde_json::Value,
24}
25
26/// Result of executing a tool call.
27#[derive(Debug)]
28pub struct ToolCallResult {
29    pub tool_use_id: String,
30    pub tool_name: String,
31    pub result: ToolResult,
32}
33
34impl ToolCallResult {
35    /// Convert to a content block for sending back to the API.
36    pub fn to_content_block(&self) -> ContentBlock {
37        ContentBlock::ToolResult {
38            tool_use_id: self.tool_use_id.clone(),
39            content: self.result.content.clone(),
40            is_error: self.result.is_error,
41            extra_content: vec![],
42        }
43    }
44}
45
46/// Extract pending tool calls from assistant content blocks.
47pub fn extract_tool_calls(content: &[ContentBlock]) -> Vec<PendingToolCall> {
48    content
49        .iter()
50        .filter_map(|block| {
51            if let ContentBlock::ToolUse { id, name, input } = block {
52                Some(PendingToolCall {
53                    id: id.clone(),
54                    name: name.clone(),
55                    input: input.clone(),
56                })
57            } else {
58                None
59            }
60        })
61        .collect()
62}
63
64/// Execute a batch of tool calls, respecting concurrency constraints.
65///
66/// Tools that are concurrency-safe run in parallel. Other tools run
67/// serially. Results are returned in the same order as the input.
68pub async fn execute_tool_calls(
69    calls: &[PendingToolCall],
70    tools: &[Arc<dyn Tool>],
71    ctx: &ToolContext,
72    permission_checker: &PermissionChecker,
73) -> Vec<ToolCallResult> {
74    // Partition into concurrent and serial batches.
75    let mut results = Vec::with_capacity(calls.len());
76
77    // Group consecutive concurrency-safe calls together.
78    let mut i = 0;
79    while i < calls.len() {
80        let call = &calls[i];
81        let tool = tools.iter().find(|t| t.name() == call.name);
82
83        match tool {
84            None => {
85                results.push(ToolCallResult {
86                    tool_use_id: call.id.clone(),
87                    tool_name: call.name.clone(),
88                    result: ToolResult::error(format!("Tool '{}' not found", call.name)),
89                });
90                i += 1;
91            }
92            Some(tool) => {
93                if tool.is_concurrency_safe() {
94                    // Collect consecutive concurrency-safe calls.
95                    let batch_start = i;
96                    while i < calls.len() {
97                        let t = tools.iter().find(|t| t.name() == calls[i].name);
98                        if t.is_some_and(|t| t.is_concurrency_safe()) {
99                            i += 1;
100                        } else {
101                            break;
102                        }
103                    }
104
105                    // Execute batch concurrently.
106                    let batch = &calls[batch_start..i];
107                    let mut handles = Vec::new();
108
109                    for call in batch {
110                        let tool = tools
111                            .iter()
112                            .find(|t| t.name() == call.name)
113                            .unwrap()
114                            .clone();
115                        let call = call.clone();
116                        let ctx_cwd = ctx.cwd.clone();
117                        let ctx_cancel = ctx.cancel.clone();
118                        let ctx_verbose = ctx.verbose;
119                        let perm_checker = ctx.permission_checker.clone();
120
121                        let ctx_plan_mode = ctx.plan_mode;
122                        let ctx_file_cache = ctx.file_cache.clone();
123                        // Read-only tools still go through permission checks.
124                        handles.push(tokio::spawn(async move {
125                            execute_single_tool(
126                                &call,
127                                &*tool,
128                                &ToolContext {
129                                    cwd: ctx_cwd,
130                                    cancel: ctx_cancel,
131                                    permission_checker: perm_checker.clone(),
132                                    verbose: ctx_verbose,
133                                    plan_mode: ctx_plan_mode,
134                                    file_cache: ctx_file_cache,
135                                    denial_tracker: None,
136                                    task_manager: None,
137                                    session_allows: None,
138                                    permission_prompter: None,
139                                    // Parallel branch only runs read-only, concurrency-safe
140                                    // tools — none of them spawn subprocesses, so the
141                                    // sandbox would be inert here anyway.
142                                    sandbox: None,
143                                },
144                                &perm_checker,
145                            )
146                            .await
147                        }));
148                    }
149
150                    for handle in handles {
151                        match handle.await {
152                            Ok(result) => results.push(result),
153                            Err(e) => {
154                                results.push(ToolCallResult {
155                                    tool_use_id: String::new(),
156                                    tool_name: String::new(),
157                                    result: ToolResult::error(format!("Task join error: {e}")),
158                                });
159                            }
160                        }
161                    }
162                } else {
163                    // Execute serially.
164                    let result = execute_single_tool(call, &**tool, ctx, permission_checker).await;
165                    results.push(result);
166                    i += 1;
167                }
168            }
169        }
170    }
171
172    results
173}
174
175/// Execute a single tool call with permission checking.
176async fn execute_single_tool(
177    call: &PendingToolCall,
178    tool: &dyn Tool,
179    ctx: &ToolContext,
180    permission_checker: &PermissionChecker,
181) -> ToolCallResult {
182    // Block non-read-only tools in plan mode.
183    if ctx.plan_mode && !tool.is_read_only() {
184        return ToolCallResult {
185            tool_use_id: call.id.clone(),
186            tool_name: call.name.clone(),
187            result: ToolResult::error(
188                "Plan mode active: only read-only tools are allowed. \
189                 Use ExitPlanMode to enable mutations."
190                    .to_string(),
191            ),
192        };
193    }
194
195    // Check permissions.
196    let decision = tool
197        .check_permissions(&call.input, permission_checker)
198        .await;
199    match decision {
200        PermissionDecision::Allow => {}
201        PermissionDecision::Deny(reason) => {
202            if let Some(ref tracker) = ctx.denial_tracker {
203                tracker
204                    .lock()
205                    .await
206                    .record(&call.name, &call.id, &reason, &call.input);
207            }
208            return ToolCallResult {
209                tool_use_id: call.id.clone(),
210                tool_name: call.name.clone(),
211                result: ToolResult::error(format!("Permission denied: {reason}")),
212            };
213        }
214        PermissionDecision::Ask(prompt) => {
215            // Check session-level allows first (user previously chose "Allow for session").
216            if let Some(ref allows) = ctx.session_allows
217                && allows.lock().await.contains(call.name.as_str())
218            {
219                // Already allowed for this session — skip prompt.
220            } else {
221                // Prompt the user for permission via the prompter trait.
222                let description = format!("{}: {}", call.name, prompt);
223                let input_preview = serde_json::to_string_pretty(&call.input).ok();
224
225                let response = if let Some(ref prompter) = ctx.permission_prompter {
226                    prompter.ask(&call.name, &description, input_preview.as_deref())
227                } else {
228                    // No prompter = auto-allow (non-interactive mode).
229                    super::PermissionResponse::AllowOnce
230                };
231
232                match response {
233                    super::PermissionResponse::AllowOnce => {
234                        // Continue to execution.
235                    }
236                    super::PermissionResponse::AllowSession => {
237                        // Record session-level allow so future calls skip the prompt.
238                        if let Some(ref allows) = ctx.session_allows {
239                            allows.lock().await.insert(call.name.clone());
240                        }
241                    }
242                    super::PermissionResponse::Deny => {
243                        if let Some(ref tracker) = ctx.denial_tracker {
244                            tracker.lock().await.record(
245                                &call.name,
246                                &call.id,
247                                "user denied",
248                                &call.input,
249                            );
250                        }
251                        return ToolCallResult {
252                            tool_use_id: call.id.clone(),
253                            tool_name: call.name.clone(),
254                            result: ToolResult::error("Permission denied by user".to_string()),
255                        };
256                    }
257                }
258            } // close else block
259        }
260    }
261
262    // Validate input.
263    if let Err(msg) = tool.validate_input(&call.input) {
264        return ToolCallResult {
265            tool_use_id: call.id.clone(),
266            tool_name: call.name.clone(),
267            result: ToolResult::error(format!("Invalid input: {msg}")),
268        };
269    }
270
271    // Execute.
272    match tool.call(call.input.clone(), ctx).await {
273        Ok(mut result) => {
274            // Persist large outputs to disk, replace with truncated + path reference.
275            result.content = crate::services::output_store::persist_if_large(
276                &result.content,
277                tool.name(),
278                &call.id,
279            );
280
281            // Additional truncation if still over the tool's limit.
282            let max = tool.max_result_size_chars();
283            if result.content.len() > max {
284                result.content.truncate(max);
285                result.content.push_str("\n\n(output truncated)");
286            }
287            ToolCallResult {
288                tool_use_id: call.id.clone(),
289                tool_name: call.name.clone(),
290                result,
291            }
292        }
293        Err(e) => ToolCallResult {
294            tool_use_id: call.id.clone(),
295            tool_name: call.name.clone(),
296            result: ToolResult::error(e.to_string()),
297        },
298    }
299}