1use crate::{AgentEvent, AgentToolResult};
3use anyhow::Result;
4use oxi_ai::{progress_callback, AssistantMessage, Message, ToolCall, ToolResultMessage};
5use std::pin::Pin;
6use std::sync::Arc;
7
8use super::config::{AfterToolCallHook, ToolExecutionMode};
9use super::helpers::{create_tool_result_message, should_terminate_batch, FinalizedToolCall};
10use crate::tools::ToolContext;
11
12pub(crate) struct ExecutedToolCallBatch {
13 pub messages: Vec<ToolResultMessage>,
14 pub terminate: bool,
15}
16
17enum FinalizedToolCallEntry {
18 Immediate(FinalizedToolCall),
19 Future(Pin<Box<dyn futures::Future<Output = FinalizedToolCall> + Send>>),
20}
21
22pub(crate) struct ExecutedToolCallOutcome {
23 pub result: AgentToolResult,
24 pub is_error: bool,
25}
26
/// Which path `prepare_tool_call` took for a call.
///
/// Recorded in `PreparedToolCallOutcome::_kind`; the executors in this module branch
/// on `immediate_result` instead and do not read it.
enum PreparedToolCallKind {
    /// The call was resolved without running the tool.
    Immediate,
    /// The tool was found and the call is ready to execute.
    Prepared,
}
31
32struct PreparedToolCallOutcome {
33 _kind: PreparedToolCallKind,
34 immediate_result: Option<AgentToolResult>,
35 is_error: bool,
36 tool: Option<Arc<dyn crate::tools::AgentTool>>,
37 tool_call: ToolCall,
38 args: serde_json::Value,
39}
40
41pub(crate) async fn execute_tool_calls(
42 loop_ref: &super::AgentLoop,
43 messages: &mut Vec<Message>,
44 assistant_message: &AssistantMessage,
45 tool_calls: Vec<ToolCall>,
46 emit: &super::EmitFn,
47 ctx: &ToolContext,
48) -> Result<ExecutedToolCallBatch> {
49 if loop_ref.config.tool_execution == ToolExecutionMode::Sequential {
50 execute_tool_calls_sequential(loop_ref, messages, assistant_message, tool_calls, emit, ctx)
51 .await
52 } else {
53 execute_tool_calls_parallel(loop_ref, messages, assistant_message, tool_calls, emit, ctx)
54 .await
55 }
56}
57
58async fn execute_tool_calls_sequential(
59 loop_ref: &super::AgentLoop,
60 _messages: &mut Vec<Message>,
61 _assistant_message: &AssistantMessage,
62 tool_calls: Vec<ToolCall>,
63 emit: &super::EmitFn,
64 ctx: &ToolContext,
65) -> Result<ExecutedToolCallBatch> {
66 let mut finalized_calls = Vec::new();
67 let mut tool_result_messages = Vec::new();
68
69 for tool_call in tool_calls {
70 let tc_id = tool_call.id.clone();
72 let tc_name = tool_call.name.clone();
73 let tc_args = tool_call.arguments.clone();
74
75 emit(AgentEvent::ToolExecutionStart {
76 tool_call_id: tc_id.clone(),
77 tool_name: tc_name.clone(),
78 args: tc_args,
79 });
80
81 let prepared = prepare_tool_call(loop_ref, &tool_call).await;
82
83 let finalized = if let Some(result) = prepared.immediate_result {
84 FinalizedToolCall {
85 tool_call,
86 result,
87 is_error: prepared.is_error,
88 }
89 } else {
90 let executed = execute_prepared_tool_call(loop_ref, &prepared, emit, ctx).await;
91
92 let mut result = executed.result;
93 let mut is_error = executed.is_error;
94
95 if let Some(ref hook) = loop_ref.after_tool_call {
96 if let Some(modified) = hook(&tc_name, &result).await.ok().flatten() {
97 result = modified;
98 is_error = !result.success;
99 }
100 }
101
102 FinalizedToolCall {
103 tool_call,
104 result,
105 is_error,
106 }
107 };
108
109 emit(AgentEvent::ToolExecutionEnd {
110 tool_call_id: finalized.tool_call.id.clone(),
111 tool_name: finalized.tool_call.name.clone(),
112 result: oxi_ai::ToolResult {
113 tool_call_id: finalized.tool_call.id.clone(),
114 content: finalized.result.output.clone(),
115 status: if finalized.is_error {
116 String::from("error")
117 } else {
118 String::from("success")
119 },
120 },
121 is_error: finalized.is_error,
122 });
123
124 let tool_result_message = create_tool_result_message(&finalized);
125 let msg = Message::ToolResult(tool_result_message.clone());
126 emit(AgentEvent::MessageStart {
127 message: msg.clone(),
128 });
129 emit(AgentEvent::MessageEnd { message: msg });
130
131 finalized_calls.push(finalized);
132 tool_result_messages.push(tool_result_message);
133 }
134
135 Ok(ExecutedToolCallBatch {
136 messages: tool_result_messages,
137 terminate: should_terminate_batch(&finalized_calls),
138 })
139}
140
141async fn execute_tool_calls_parallel(
142 loop_ref: &super::AgentLoop,
143 _messages: &mut Vec<Message>,
144 _assistant_message: &AssistantMessage,
145 tool_calls: Vec<ToolCall>,
146 emit: &super::EmitFn,
147 ctx: &ToolContext,
148) -> Result<ExecutedToolCallBatch> {
149 let mut finalized_calls: Vec<FinalizedToolCallEntry> = Vec::new();
150
151 for tool_call in tool_calls {
152 let tc_id = tool_call.id.clone();
154 let tc_name = tool_call.name.clone();
155 let tc_args = tool_call.arguments.clone();
156
157 emit(AgentEvent::ToolExecutionStart {
158 tool_call_id: tc_id.clone(),
159 tool_name: tc_name.clone(),
160 args: tc_args,
161 });
162
163 let prepared = prepare_tool_call(loop_ref, &tool_call).await;
164
165 if let Some(result) = prepared.immediate_result {
166 let finalized = FinalizedToolCall {
167 tool_call,
168 result,
169 is_error: prepared.is_error,
170 };
171
172 emit(AgentEvent::ToolExecutionEnd {
173 tool_call_id: finalized.tool_call.id.clone(),
174 tool_name: finalized.tool_call.name.clone(),
175 result: oxi_ai::ToolResult {
176 tool_call_id: finalized.tool_call.id.clone(),
177 content: finalized.result.output.clone(),
178 status: if finalized.is_error {
179 String::from("error")
180 } else {
181 String::from("success")
182 },
183 },
184 is_error: finalized.is_error,
185 });
186
187 finalized_calls.push(FinalizedToolCallEntry::Immediate(finalized));
188 } else {
189 let tool = prepared.tool.clone();
190 let args = prepared.args.clone();
191 let after_hook = loop_ref.after_tool_call.clone();
192 let emit_clone = emit.clone();
193 let ctx_clone = ctx.clone();
194
195 finalized_calls.push(FinalizedToolCallEntry::Future(Box::pin(async move {
196 let executed = execute_prepared_tool_call_static(
197 tool_call.clone(),
198 tool,
199 args,
200 after_hook.clone(),
201 emit_clone.clone(),
202 &ctx_clone,
203 )
204 .await;
205
206 FinalizedToolCall {
207 tool_call,
208 result: executed.result,
209 is_error: executed.is_error,
210 }
211 })));
212 }
213 }
214
215 let mut slots: Vec<Option<FinalizedToolCall>> = Vec::with_capacity(finalized_calls.len());
216 #[allow(clippy::type_complexity)]
217 let mut pending_futures: Vec<(
218 usize,
219 Pin<Box<dyn futures::Future<Output = FinalizedToolCall> + Send>>,
220 )> = Vec::new();
221
222 for (i, entry) in finalized_calls.into_iter().enumerate() {
223 match entry {
224 FinalizedToolCallEntry::Immediate(f) => slots.push(Some(f)),
225 FinalizedToolCallEntry::Future(f) => {
226 slots.push(None);
227 pending_futures.push((i, f));
228 }
229 }
230 }
231
232 if !pending_futures.is_empty() {
233 let indexed_results: Vec<(usize, FinalizedToolCall)> = futures::future::join_all(
234 pending_futures
235 .into_iter()
236 .map(|(i, f)| async move { (i, f.await) }),
237 )
238 .await;
239
240 for (idx, finalized) in indexed_results {
241 slots[idx] = Some(finalized);
242 }
243 }
244
245 let ordered_finalized_calls: Vec<FinalizedToolCall> = slots
246 .into_iter()
247 .map(|s| s.expect("all slots should be filled after join_all"))
248 .collect();
249
250 let mut tool_result_messages = Vec::new();
251 for finalized in &ordered_finalized_calls {
252 let tool_result_message = create_tool_result_message(finalized);
253 let msg = Message::ToolResult(tool_result_message.clone());
254 emit(AgentEvent::MessageStart {
255 message: msg.clone(),
256 });
257 emit(AgentEvent::MessageEnd { message: msg });
258 tool_result_messages.push(tool_result_message);
259 }
260
261 Ok(ExecutedToolCallBatch {
262 messages: tool_result_messages,
263 terminate: should_terminate_batch(&ordered_finalized_calls),
264 })
265}
266
267pub(crate) async fn execute_prepared_tool_call_static(
268 tool_call: ToolCall,
269 tool: Option<Arc<dyn crate::tools::AgentTool>>,
270 args: serde_json::Value,
271 after_hook: Option<AfterToolCallHook>,
272 emit: Arc<dyn Fn(AgentEvent) + Send + Sync>,
273 ctx: &ToolContext,
274) -> ExecutedToolCallOutcome {
275 let tool_call_id = tool_call.id.clone();
276 let tool_name = tool_call.name.clone();
277
278 let mut result = AgentToolResult::success("");
279 let mut is_error = false;
280
281 if let Some(ref tool) = tool {
282 match tool.execute(&tool_call_id, args, None, ctx).await {
283 Ok(r) => result = r,
284 Err(e) => {
285 result = AgentToolResult::error(e);
286 is_error = true;
287 }
288 }
289 }
290
291 if let Some(ref hook) = after_hook {
292 if let Some(modified) = hook(&tool_call.name, &result).await.ok().flatten() {
293 result = modified;
294 is_error = !result.success;
295 }
296 }
297
298 emit(AgentEvent::ToolExecutionEnd {
299 tool_call_id: tool_call_id.clone(),
300 tool_name: tool_name.clone(),
301 result: oxi_ai::ToolResult {
302 tool_call_id,
303 content: result.output.clone(),
304 status: if is_error {
305 String::from("error")
306 } else {
307 String::from("success")
308 },
309 },
310 is_error,
311 });
312
313 ExecutedToolCallOutcome { result, is_error }
314}
315
316async fn prepare_tool_call(
317 loop_ref: &super::AgentLoop,
318 tool_call: &ToolCall,
319) -> PreparedToolCallOutcome {
320 let tool = match loop_ref.tools.get(&tool_call.name) {
321 Some(t) => t,
322 None => {
323 return PreparedToolCallOutcome {
324 _kind: PreparedToolCallKind::Immediate,
325 immediate_result: Some(AgentToolResult::error(format!(
326 "Tool '{}' not found",
327 tool_call.name
328 ))),
329 is_error: true,
330 tool: None,
331 tool_call: tool_call.clone(),
332 args: tool_call.arguments.clone(),
333 };
334 }
335 };
336
337 let validated_args = tool_call.arguments.clone();
338
339 if let Some(ref hook) = loop_ref.before_tool_call {
340 if let Some(blocked) = hook(&tool_call.name, &validated_args).await.ok().flatten() {
341 return PreparedToolCallOutcome {
342 _kind: PreparedToolCallKind::Immediate,
343 immediate_result: Some(blocked),
344 is_error: true,
345 tool: None,
346 tool_call: tool_call.clone(),
347 args: validated_args,
348 };
349 }
350 }
351
352 PreparedToolCallOutcome {
353 _kind: PreparedToolCallKind::Prepared,
354 immediate_result: None,
355 is_error: false,
356 tool: Some(Arc::clone(&tool)),
357 tool_call: tool_call.clone(),
358 args: validated_args,
359 }
360}
361
362async fn execute_prepared_tool_call(
363 _loop_ref: &super::AgentLoop,
364 prepared: &PreparedToolCallOutcome,
365 emit: &super::EmitFn,
366 ctx: &ToolContext,
367) -> ExecutedToolCallOutcome {
368 let tool_call_id = prepared.tool_call.id.clone();
369 let tool_name = prepared.tool_call.name.clone();
370
371 let mut result = AgentToolResult::success("");
372 let mut is_error = false;
373
374 if let Some(ref tool) = prepared.tool {
375 let tool_call_id_clone = tool_call_id.clone();
376 let emit_clone = emit.clone();
377
378 let progress_cb: Arc<dyn Fn(String) + Send + Sync> = Arc::new(move |msg: String| {
379 emit_clone(AgentEvent::ToolExecutionUpdate {
380 tool_call_id: tool_call_id_clone.clone(),
381 tool_name: tool_name.clone(),
382 partial_result: msg,
383 });
384 });
385
386 tool.on_progress(progress_callback(move |msg: String| {
388 progress_cb(msg);
389 }));
390
391 match tool
392 .execute(&tool_call_id, prepared.args.clone(), None, ctx)
393 .await
394 {
395 Ok(r) => result = r,
396 Err(e) => {
397 result = AgentToolResult::error(e);
398 is_error = true;
399 }
400 }
401 }
402
403 ExecutedToolCallOutcome { result, is_error }
404}