Skip to main content

oxi_agent/agent_loop/
tool_exec.rs

//! Tool execution logic for the agent loop.
2use crate::{AgentEvent, AgentToolResult};
3use anyhow::Result;
4use oxi_ai::{progress_callback, AssistantMessage, Message, ToolCall, ToolResultMessage};
5use std::pin::Pin;
6use std::sync::Arc;
7
8use super::config::{AfterToolCallHook, ToolExecutionMode};
9use super::helpers::{create_tool_result_message, should_terminate_batch, FinalizedToolCall};
10use crate::tools::ToolContext;
11
12pub(crate) struct ExecutedToolCallBatch {
13    pub messages: Vec<ToolResultMessage>,
14    pub terminate: bool,
15}
16
17enum FinalizedToolCallEntry {
18    Immediate(FinalizedToolCall),
19    Future(Pin<Box<dyn futures::Future<Output = FinalizedToolCall> + Send>>),
20}
21
22pub(crate) struct ExecutedToolCallOutcome {
23    pub result: AgentToolResult,
24    pub is_error: bool,
25}
26
/// How `prepare_tool_call` resolved a tool call.
enum PreparedToolCallKind {
    /// A result is already available; the tool will not be executed.
    Immediate,
    /// The tool was found and not blocked; it is ready to execute.
    Prepared,
}
31
32struct PreparedToolCallOutcome {
33    _kind: PreparedToolCallKind,
34    immediate_result: Option<AgentToolResult>,
35    is_error: bool,
36    tool: Option<Arc<dyn crate::tools::AgentTool>>,
37    tool_call: ToolCall,
38    args: serde_json::Value,
39}
40
41pub(crate) async fn execute_tool_calls(
42    loop_ref: &super::AgentLoop,
43    messages: &mut Vec<Message>,
44    assistant_message: &AssistantMessage,
45    tool_calls: Vec<ToolCall>,
46    emit: &super::EmitFn,
47    ctx: &ToolContext,
48) -> Result<ExecutedToolCallBatch> {
49    if loop_ref.config.tool_execution == ToolExecutionMode::Sequential {
50        execute_tool_calls_sequential(loop_ref, messages, assistant_message, tool_calls, emit, ctx)
51            .await
52    } else {
53        execute_tool_calls_parallel(loop_ref, messages, assistant_message, tool_calls, emit, ctx)
54            .await
55    }
56}
57
58async fn execute_tool_calls_sequential(
59    loop_ref: &super::AgentLoop,
60    _messages: &mut Vec<Message>,
61    _assistant_message: &AssistantMessage,
62    tool_calls: Vec<ToolCall>,
63    emit: &super::EmitFn,
64    ctx: &ToolContext,
65) -> Result<ExecutedToolCallBatch> {
66    let mut finalized_calls = Vec::new();
67    let mut tool_result_messages = Vec::new();
68
69    for tool_call in tool_calls {
70        // Clone tool_call fields once upfront to avoid repeated clones.
71        let tc_id = tool_call.id.clone();
72        let tc_name = tool_call.name.clone();
73        let tc_args = tool_call.arguments.clone();
74
75        emit(AgentEvent::ToolExecutionStart {
76            tool_call_id: tc_id.clone(),
77            tool_name: tc_name.clone(),
78            args: tc_args,
79        });
80
81        let prepared = prepare_tool_call(loop_ref, &tool_call).await;
82
83        let finalized = if let Some(result) = prepared.immediate_result {
84            FinalizedToolCall {
85                tool_call,
86                result,
87                is_error: prepared.is_error,
88            }
89        } else {
90            let executed = execute_prepared_tool_call(loop_ref, &prepared, emit, ctx).await;
91
92            let mut result = executed.result;
93            let mut is_error = executed.is_error;
94
95            if let Some(ref hook) = loop_ref.after_tool_call {
96                if let Some(modified) = hook(&tc_name, &result).await.ok().flatten() {
97                    result = modified;
98                    is_error = !result.success;
99                }
100            }
101
102            FinalizedToolCall {
103                tool_call,
104                result,
105                is_error,
106            }
107        };
108
109        emit(AgentEvent::ToolExecutionEnd {
110            tool_call_id: finalized.tool_call.id.clone(),
111            tool_name: finalized.tool_call.name.clone(),
112            result: oxi_ai::ToolResult {
113                tool_call_id: finalized.tool_call.id.clone(),
114                content: finalized.result.output.clone(),
115                status: if finalized.is_error {
116                    String::from("error")
117                } else {
118                    String::from("success")
119                },
120            },
121            is_error: finalized.is_error,
122        });
123
124        let tool_result_message = create_tool_result_message(&finalized);
125        let msg = Message::ToolResult(tool_result_message.clone());
126        emit(AgentEvent::MessageStart {
127            message: msg.clone(),
128        });
129        emit(AgentEvent::MessageEnd { message: msg });
130
131        finalized_calls.push(finalized);
132        tool_result_messages.push(tool_result_message);
133    }
134
135    Ok(ExecutedToolCallBatch {
136        messages: tool_result_messages,
137        terminate: should_terminate_batch(&finalized_calls),
138    })
139}
140
141async fn execute_tool_calls_parallel(
142    loop_ref: &super::AgentLoop,
143    _messages: &mut Vec<Message>,
144    _assistant_message: &AssistantMessage,
145    tool_calls: Vec<ToolCall>,
146    emit: &super::EmitFn,
147    ctx: &ToolContext,
148) -> Result<ExecutedToolCallBatch> {
149    let mut finalized_calls: Vec<FinalizedToolCallEntry> = Vec::new();
150
151    for tool_call in tool_calls {
152        // Clone tool_call fields once upfront to avoid repeated clones.
153        let tc_id = tool_call.id.clone();
154        let tc_name = tool_call.name.clone();
155        let tc_args = tool_call.arguments.clone();
156
157        emit(AgentEvent::ToolExecutionStart {
158            tool_call_id: tc_id.clone(),
159            tool_name: tc_name.clone(),
160            args: tc_args,
161        });
162
163        let prepared = prepare_tool_call(loop_ref, &tool_call).await;
164
165        if let Some(result) = prepared.immediate_result {
166            let finalized = FinalizedToolCall {
167                tool_call,
168                result,
169                is_error: prepared.is_error,
170            };
171
172            emit(AgentEvent::ToolExecutionEnd {
173                tool_call_id: finalized.tool_call.id.clone(),
174                tool_name: finalized.tool_call.name.clone(),
175                result: oxi_ai::ToolResult {
176                    tool_call_id: finalized.tool_call.id.clone(),
177                    content: finalized.result.output.clone(),
178                    status: if finalized.is_error {
179                        String::from("error")
180                    } else {
181                        String::from("success")
182                    },
183                },
184                is_error: finalized.is_error,
185            });
186
187            finalized_calls.push(FinalizedToolCallEntry::Immediate(finalized));
188        } else {
189            let tool = prepared.tool.clone();
190            let args = prepared.args.clone();
191            let after_hook = loop_ref.after_tool_call.clone();
192            let emit_clone = emit.clone();
193            let ctx_clone = ctx.clone();
194
195            finalized_calls.push(FinalizedToolCallEntry::Future(Box::pin(async move {
196                let executed = execute_prepared_tool_call_static(
197                    tool_call.clone(),
198                    tool,
199                    args,
200                    after_hook.clone(),
201                    emit_clone.clone(),
202                    &ctx_clone,
203                )
204                .await;
205
206                FinalizedToolCall {
207                    tool_call,
208                    result: executed.result,
209                    is_error: executed.is_error,
210                }
211            })));
212        }
213    }
214
215    let mut slots: Vec<Option<FinalizedToolCall>> = Vec::with_capacity(finalized_calls.len());
216    #[allow(clippy::type_complexity)]
217    let mut pending_futures: Vec<(
218        usize,
219        Pin<Box<dyn futures::Future<Output = FinalizedToolCall> + Send>>,
220    )> = Vec::new();
221
222    for (i, entry) in finalized_calls.into_iter().enumerate() {
223        match entry {
224            FinalizedToolCallEntry::Immediate(f) => slots.push(Some(f)),
225            FinalizedToolCallEntry::Future(f) => {
226                slots.push(None);
227                pending_futures.push((i, f));
228            }
229        }
230    }
231
232    if !pending_futures.is_empty() {
233        let indexed_results: Vec<(usize, FinalizedToolCall)> = futures::future::join_all(
234            pending_futures
235                .into_iter()
236                .map(|(i, f)| async move { (i, f.await) }),
237        )
238        .await;
239
240        for (idx, finalized) in indexed_results {
241            slots[idx] = Some(finalized);
242        }
243    }
244
245    let ordered_finalized_calls: Vec<FinalizedToolCall> = slots
246        .into_iter()
247        .map(|s| s.expect("all slots should be filled after join_all"))
248        .collect();
249
250    let mut tool_result_messages = Vec::new();
251    for finalized in &ordered_finalized_calls {
252        let tool_result_message = create_tool_result_message(finalized);
253        let msg = Message::ToolResult(tool_result_message.clone());
254        emit(AgentEvent::MessageStart {
255            message: msg.clone(),
256        });
257        emit(AgentEvent::MessageEnd { message: msg });
258        tool_result_messages.push(tool_result_message);
259    }
260
261    Ok(ExecutedToolCallBatch {
262        messages: tool_result_messages,
263        terminate: should_terminate_batch(&ordered_finalized_calls),
264    })
265}
266
267pub(crate) async fn execute_prepared_tool_call_static(
268    tool_call: ToolCall,
269    tool: Option<Arc<dyn crate::tools::AgentTool>>,
270    args: serde_json::Value,
271    after_hook: Option<AfterToolCallHook>,
272    emit: Arc<dyn Fn(AgentEvent) + Send + Sync>,
273    ctx: &ToolContext,
274) -> ExecutedToolCallOutcome {
275    let tool_call_id = tool_call.id.clone();
276    let tool_name = tool_call.name.clone();
277
278    let mut result = AgentToolResult::success("");
279    let mut is_error = false;
280
281    if let Some(ref tool) = tool {
282        match tool.execute(&tool_call_id, args, None, ctx).await {
283            Ok(r) => result = r,
284            Err(e) => {
285                result = AgentToolResult::error(e);
286                is_error = true;
287            }
288        }
289    }
290
291    if let Some(ref hook) = after_hook {
292        if let Some(modified) = hook(&tool_call.name, &result).await.ok().flatten() {
293            result = modified;
294            is_error = !result.success;
295        }
296    }
297
298    emit(AgentEvent::ToolExecutionEnd {
299        tool_call_id: tool_call_id.clone(),
300        tool_name: tool_name.clone(),
301        result: oxi_ai::ToolResult {
302            tool_call_id,
303            content: result.output.clone(),
304            status: if is_error {
305                String::from("error")
306            } else {
307                String::from("success")
308            },
309        },
310        is_error,
311    });
312
313    ExecutedToolCallOutcome { result, is_error }
314}
315
316async fn prepare_tool_call(
317    loop_ref: &super::AgentLoop,
318    tool_call: &ToolCall,
319) -> PreparedToolCallOutcome {
320    let tool = match loop_ref.tools.get(&tool_call.name) {
321        Some(t) => t,
322        None => {
323            return PreparedToolCallOutcome {
324                _kind: PreparedToolCallKind::Immediate,
325                immediate_result: Some(AgentToolResult::error(format!(
326                    "Tool '{}' not found",
327                    tool_call.name
328                ))),
329                is_error: true,
330                tool: None,
331                tool_call: tool_call.clone(),
332                args: tool_call.arguments.clone(),
333            };
334        }
335    };
336
337    let validated_args = tool_call.arguments.clone();
338
339    if let Some(ref hook) = loop_ref.before_tool_call {
340        if let Some(blocked) = hook(&tool_call.name, &validated_args).await.ok().flatten() {
341            return PreparedToolCallOutcome {
342                _kind: PreparedToolCallKind::Immediate,
343                immediate_result: Some(blocked),
344                is_error: true,
345                tool: None,
346                tool_call: tool_call.clone(),
347                args: validated_args,
348            };
349        }
350    }
351
352    PreparedToolCallOutcome {
353        _kind: PreparedToolCallKind::Prepared,
354        immediate_result: None,
355        is_error: false,
356        tool: Some(Arc::clone(&tool)),
357        tool_call: tool_call.clone(),
358        args: validated_args,
359    }
360}
361
362async fn execute_prepared_tool_call(
363    _loop_ref: &super::AgentLoop,
364    prepared: &PreparedToolCallOutcome,
365    emit: &super::EmitFn,
366    ctx: &ToolContext,
367) -> ExecutedToolCallOutcome {
368    let tool_call_id = prepared.tool_call.id.clone();
369    let tool_name = prepared.tool_call.name.clone();
370
371    let mut result = AgentToolResult::success("");
372    let mut is_error = false;
373
374    if let Some(ref tool) = prepared.tool {
375        let tool_call_id_clone = tool_call_id.clone();
376        let emit_clone = emit.clone();
377
378        let progress_cb: Arc<dyn Fn(String) + Send + Sync> = Arc::new(move |msg: String| {
379            emit_clone(AgentEvent::ToolExecutionUpdate {
380                tool_call_id: tool_call_id_clone.clone(),
381                tool_name: tool_name.clone(),
382                partial_result: msg,
383            });
384        });
385
386        // Wire up progress callback BEFORE execute — pi-mono: tool's onUpdate
387        tool.on_progress(progress_callback(move |msg: String| {
388            progress_cb(msg);
389        }));
390
391        match tool
392            .execute(&tool_call_id, prepared.args.clone(), None, ctx)
393            .await
394        {
395            Ok(r) => result = r,
396            Err(e) => {
397                result = AgentToolResult::error(e);
398                is_error = true;
399            }
400        }
401    }
402
403    ExecutedToolCallOutcome { result, is_error }
404}