use super::config::*;
use crate::types::*;
use std::sync::Arc;
use tokio::sync::mpsc;
/// Outcome of executing a set of tool calls with one of the strategies in this module.
pub(super) struct ToolExecutionResult {
/// One tool-result `Message` per requested tool call, in request order: results of calls that
/// actually ran come first, followed by placeholder results for calls skipped because steering
/// messages arrived.
pub(super) tool_results: Vec<Message>,
/// Non-empty steering messages polled after execution, or `None` when no steering callback was
/// supplied or it returned nothing.
pub(super) steering_messages: Option<Vec<AgentMessage>>,
}
/// Executes the model-requested `tool_calls` according to `strategy`.
///
/// * `Sequential` runs calls one at a time and polls steering after each call.
/// * `Parallel` runs all calls concurrently and polls steering once at the end.
/// * `Batched { size }` runs consecutive chunks of `size` calls concurrently and polls steering
///   between chunks. When steering yields messages, every call in later chunks is recorded as
///   skipped and execution stops. A `size` of 0 is treated as 1 (one call per batch) instead
///   of panicking.
///
/// The returned `tool_results` always contains one message per entry of `tool_calls`, in
/// request order.
#[allow(clippy::too_many_arguments)]
pub(super) async fn execute_tool_calls(
    tools: &[Arc<dyn AgentTool>],
    tool_calls: &[(String, String, serde_json::Value)],
    tx: &mpsc::UnboundedSender<AgentEvent>,
    cancel: &tokio_util::sync::CancellationToken,
    get_steering: Option<&GetMessagesFn>,
    strategy: &ToolExecutionStrategy,
    config: &AgentLoopConfig,
    loop_id: &str,
) -> ToolExecutionResult {
    match strategy {
        ToolExecutionStrategy::Sequential => {
            execute_sequential(tools, tool_calls, tx, cancel, get_steering, config, loop_id).await
        }
        ToolExecutionStrategy::Parallel => {
            execute_batch(tools, tool_calls, tx, cancel, get_steering, config, loop_id).await
        }
        ToolExecutionStrategy::Batched { size } => {
            // `slice::chunks` panics on a chunk size of 0; degrade to one call per batch.
            let size = (*size).max(1);
            let mut results: Vec<Message> = Vec::with_capacity(tool_calls.len());
            let mut steering_messages: Option<Vec<AgentMessage>> = None;
            for (batch_idx, batch) in tool_calls.chunks(size).enumerate() {
                // Steering is polled here, between batches, rather than inside `execute_batch`.
                let batch_result =
                    execute_batch(tools, batch, tx, cancel, None, config, loop_id).await;
                results.extend(batch_result.tool_results);
                if let Some(get_steering_fn) = get_steering {
                    let steering = get_steering_fn();
                    if !steering.is_empty() {
                        steering_messages = Some(steering);
                        // The final chunk may be short, so clamp to keep the slice in range.
                        let executed = ((batch_idx + 1) * size).min(tool_calls.len());
                        for (skip_id, skip_name, _) in &tool_calls[executed..] {
                            results.push(skip_tool_call(skip_id, skip_name, tx, loop_id));
                        }
                        break;
                    }
                }
            }
            ToolExecutionResult {
                tool_results: results,
                steering_messages,
            }
        }
    }
}
/// Runs the tool calls strictly one after another, in request order.
///
/// After each completed call the steering callback (when provided) is polled. As soon as it
/// returns a non-empty list, every call that has not started yet is recorded through
/// `skip_tool_call` and the queued messages are returned alongside the results.
pub(super) async fn execute_sequential(
    tools: &[Arc<dyn AgentTool>],
    tool_calls: &[(String, String, serde_json::Value)],
    tx: &mpsc::UnboundedSender<AgentEvent>,
    cancel: &tokio_util::sync::CancellationToken,
    get_steering: Option<&GetMessagesFn>,
    config: &AgentLoopConfig,
    loop_id: &str,
) -> ToolExecutionResult {
    let mut tool_results: Vec<Message> = Vec::new();
    let mut remaining = tool_calls;
    while let Some(((id, name, args), rest)) = remaining.split_first() {
        remaining = rest;
        let (message, _is_error) =
            execute_single_tool(tools, id, name, args, tx, cancel, config, loop_id).await;
        tool_results.push(message);

        let queued = match get_steering {
            Some(poll) => poll(),
            None => continue,
        };
        if queued.is_empty() {
            continue;
        }

        // A user message is waiting: mark every remaining call as skipped and stop.
        tool_results.extend(
            rest.iter()
                .map(|(skip_id, skip_name, _)| skip_tool_call(skip_id, skip_name, tx, loop_id)),
        );
        return ToolExecutionResult {
            tool_results,
            steering_messages: Some(queued),
        };
    }
    ToolExecutionResult {
        tool_results,
        steering_messages: None,
    }
}
/// Runs every tool call in `tool_calls` concurrently and waits for all of them.
///
/// Results are returned in the same order as `tool_calls`. The steering callback (when
/// provided) is polled once, after the whole batch has finished; an empty queue is reported
/// as `None`.
pub(super) async fn execute_batch(
    tools: &[Arc<dyn AgentTool>],
    tool_calls: &[(String, String, serde_json::Value)],
    tx: &mpsc::UnboundedSender<AgentEvent>,
    cancel: &tokio_util::sync::CancellationToken,
    get_steering: Option<&GetMessagesFn>,
    config: &AgentLoopConfig,
    loop_id: &str,
) -> ToolExecutionResult {
    let pending = tool_calls.iter().map(|(id, name, args)| {
        execute_single_tool(tools, id, name, args, tx, cancel, config, loop_id)
    });
    // `join_all` yields outputs in input order, so results line up with `tool_calls`.
    let tool_results = futures::future::join_all(pending)
        .await
        .into_iter()
        .map(|(message, _is_error)| message)
        .collect::<Vec<Message>>();
    let steering_messages = get_steering
        .map(|poll| poll())
        .filter(|queued| !queued.is_empty());
    ToolExecutionResult {
        tool_results,
        steering_messages,
    }
}
/// Executes a single tool call and returns the resulting `Message::ToolResult` together with
/// its `is_error` flag.
///
/// Flow:
/// 1. If `config.before_tool_execution` is set and resolves to `false`, the tool is not run.
///    An error result ("Tool execution skipped by before_tool_execution hook.") is emitted as
///    `MessageStart`/`MessageEnd` and returned immediately. No `ToolExecutionStart`/`End`
///    events are sent and `after_tool_execution` is not called on this path.
/// 2. Otherwise `ToolExecutionStart` is emitted and a `ToolContext` is built with a child of
///    `cancel` plus update/progress callbacks that forward events to `tx`.
/// 3. The tool named `name` is looked up in `tools`. If found, it runs under `tool.timeout()`,
///    falling back to `config.tool_timeout`. Otherwise the result is a "Tool {name} not found"
///    error. Tool errors and timeouts become error results instead of being propagated.
/// 4. `ToolExecutionEnd` is emitted, `after_tool_execution` is awaited, and finally the
///    tool-result message is emitted as `MessageStart`/`MessageEnd`.
///
/// Failures to send on `tx` are ignored (`.ok()`), so a closed receiver does not abort
/// execution.
#[allow(clippy::too_many_arguments)]
pub(super) async fn execute_single_tool(
tools: &[Arc<dyn AgentTool>], id: &str, name: &str, args: &serde_json::Value, tx: &mpsc::UnboundedSender<AgentEvent>, cancel: &tokio_util::sync::CancellationToken, config: &AgentLoopConfig, loop_id: &str,
) -> (Message, bool) {
// Resolved up front; only consulted if the before-hook lets the call proceed.
let tool = tools.iter().find(|t| t.name() == name);
// Optional veto hook: returning `false` short-circuits with an error result.
if let Some(ref hook) = config.before_tool_execution {
if !hook(name, id, args).await {
let skipped_result = ToolResult {
content: vec![Content::Text {
text: "Tool execution skipped by before_tool_execution hook.".to_string(),
}],
details: serde_json::Value::Null,
child_loop_id: None,
};
let tool_result_msg = Message::ToolResult {
tool_call_id: id.to_string(),
tool_name: name.to_string(),
content: skipped_result.content,
is_error: true,
timestamp: now_ms(),
};
// NOTE(review): unlike `skip_tool_call`, this path emits no ToolExecutionStart/End
// events; confirm listeners do not expect them for vetoed calls.
tx.send(AgentEvent::MessageStart {
loop_id: loop_id.to_string(),
message: tool_result_msg.clone().into(),
})
.ok();
tx.send(AgentEvent::MessageEnd {
loop_id: loop_id.to_string(),
message: tool_result_msg.clone().into(),
})
.ok();
return (tool_result_msg, true);
}
}
tx.send(AgentEvent::ToolExecutionStart {
loop_id: loop_id.to_string(),
tool_call_id: id.to_string(),
tool_name: name.to_string(),
args: args.clone(),
})
.ok();
// Partial-result callback handed to the tool. The optional before-update hook may suppress
// forwarding (by returning false); the after-update hook runs only if the update was sent.
let on_update: Option<ToolUpdateFn> = {
let tx = tx.clone();
let id = id.to_string();
let name = name.to_string();
let loop_id_owned = loop_id.to_string();
let before_update = config.before_tool_execution_update.clone();
let after_update = config.after_tool_execution_update.clone();
Some(Arc::new(move |partial: ToolResult| {
// Only `Content::Text` parts are passed to the hooks, joined with newlines.
let content_str: String = partial
.content
.iter()
.filter_map(|c| {
if let Content::Text { text } = c {
Some(text.as_str())
} else {
None
}
})
.collect::<Vec<_>>()
.join("\n");
// NOTE(review): the hooks are async but this callback is synchronous, so they are
// driven with `futures::executor::block_on`. If a tool invokes `on_update` from a
// tokio task and a hook needs that runtime to make progress, this may block or
// deadlock; confirm hooks are runtime-independent.
let emit = before_update.as_ref().map_or(true, |h| {
futures::executor::block_on(h(&name, &id, &content_str))
});
if emit {
tx.send(AgentEvent::ToolExecutionUpdate {
loop_id: loop_id_owned.clone(),
tool_call_id: id.clone(),
tool_name: name.clone(),
partial_result: partial,
})
.ok();
if let Some(ref hook) = after_update {
futures::executor::block_on(hook(&name, &id, &content_str));
}
}
}))
};
// Progress callback: forwards free-form text to listeners as `ProgressMessage` events.
let on_progress: Option<ProgressFn> = {
let tx = tx.clone();
let id = id.to_string();
let name = name.to_string();
let loop_id_owned = loop_id.to_string();
Some(Arc::new(move |text: String| {
tx.send(AgentEvent::ProgressMessage {
loop_id: loop_id_owned.clone(),
tool_call_id: id.clone(),
tool_name: name.clone(),
text,
})
.ok();
}))
};
// A child token: cancelling `cancel` reaches the tool, while cancelling the child (e.g. on
// timeout below) does not propagate back to the parent.
let ctx = ToolContext {
tool_call_id: id.to_string(),
tool_name: name.to_string(),
cancel: cancel.child_token(),
on_update,
on_progress,
};
let (result, is_error) = match tool {
Some(tool) => {
// Per-tool timeout takes precedence over the loop-wide default.
let effective_timeout = tool.timeout().or(config.tool_timeout);
// Kept because `ctx` is moved into `execute`; needed to cancel after a timeout.
let tool_cancel = ctx.cancel.clone();
// Publish the running tool to the optional shared slot (poisoned locks are skipped).
// NOTE(review): this is a single slot. Under `execute_batch`, concurrent tools overwrite
// each other and the first to finish clears it. It is also left set if `execute` panics.
if let Some(ref slot) = config.current_tool {
if let Ok(mut guard) = slot.lock() {
*guard = Some(crate::context::CurrentToolExecution {
name: name.to_string(),
timeout: effective_timeout,
});
}
}
let exec_result = match effective_timeout {
None => tool.execute(args.clone(), ctx).await,
Some(d) => match tokio::time::timeout(d, tool.execute(args.clone(), ctx)).await {
Ok(r) => r,
Err(_) => {
// The `execute` future has already been dropped here. Cancelling the token
// presumably signals any work that kept a clone of it (TODO confirm).
tool_cancel.cancel();
Err(ToolError::Timeout { duration: d })
}
},
};
if let Some(ref slot) = config.current_tool {
if let Ok(mut guard) = slot.lock() {
*guard = None;
}
}
// Tool failures are reported to the model as error results, not propagated.
match exec_result {
Ok(r) => (r, false),
Err(e) => (
ToolResult {
content: vec![Content::Text {
text: e.to_string(), }],
details: serde_json::Value::Null,
child_loop_id: None,
},
true,
),
}
}
None => (
ToolResult {
content: vec![Content::Text {
text: format!("Tool {} not found", name),
}],
details: serde_json::Value::Null,
child_loop_id: None,
},
true,
),
};
tx.send(AgentEvent::ToolExecutionEnd {
loop_id: loop_id.to_string(),
tool_call_id: id.to_string(),
tool_name: name.to_string(),
result: result.clone(),
is_error,
child_loop_id: result.child_loop_id.clone(), })
.ok();
// The after-hook sees only the name, id and error flag, not the result content.
if let Some(ref hook) = config.after_tool_execution {
hook(name, id, is_error).await;
}
// Only the content survives into the transcript message; `details` and `child_loop_id` were
// delivered via `ToolExecutionEnd` above.
let tool_result_msg = Message::ToolResult {
tool_call_id: id.to_string(),
tool_name: name.to_string(),
content: result.content,
is_error,
timestamp: now_ms(),
};
tx.send(AgentEvent::MessageStart {
loop_id: loop_id.to_string(),
message: tool_result_msg.clone().into(),
})
.ok();
tx.send(AgentEvent::MessageEnd {
loop_id: loop_id.to_string(),
message: tool_result_msg.clone().into(),
})
.ok();
(tool_result_msg, is_error)
}
/// Records a tool call that was never run because a user message was queued first.
///
/// Emits the same event sequence a real call produces (`ToolExecutionStart`,
/// `ToolExecutionEnd`, then `MessageStart`/`MessageEnd`), so listeners see every requested
/// call resolved. Returns the error `Message::ToolResult` to append to the transcript.
/// Send failures on `tx` are ignored.
pub(super) fn skip_tool_call(
    tool_call_id: &str,
    tool_name: &str,
    tx: &mpsc::UnboundedSender<AgentEvent>,
    loop_id: &str,
) -> Message {
    // A closed channel is not treated as an error: the events are informational only.
    let emit = |event: AgentEvent| {
        let _ = tx.send(event);
    };

    let skipped = ToolResult {
        content: vec![Content::Text {
            text: "Skipped due to queued user message.".into(),
        }],
        details: serde_json::Value::Null,
        child_loop_id: None,
    };

    emit(AgentEvent::ToolExecutionStart {
        loop_id: loop_id.to_string(),
        tool_call_id: tool_call_id.into(),
        tool_name: tool_name.into(),
        args: serde_json::Value::Null,
    });
    emit(AgentEvent::ToolExecutionEnd {
        loop_id: loop_id.to_string(),
        tool_call_id: tool_call_id.into(),
        tool_name: tool_name.into(),
        result: skipped.clone(),
        is_error: true,
        child_loop_id: None,
    });

    let ToolResult { content, .. } = skipped;
    let message = Message::ToolResult {
        tool_call_id: tool_call_id.into(),
        tool_name: tool_name.into(),
        content,
        is_error: true,
        timestamp: now_ms(),
    };

    emit(AgentEvent::MessageStart {
        loop_id: loop_id.to_string(),
        message: message.clone().into(),
    });
    emit(AgentEvent::MessageEnd {
        loop_id: loop_id.to_string(),
        message: message.clone().into(),
    });
    message
}