use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, info_span};
use crate::tool::{AgentTool, AgentToolResult, validate_tool_arguments, validation_error_result};
use crate::tool_execution_policy::{ToolCallSummary, ToolExecutionPolicy};
use crate::types::{AgentMessage, ToolResultMessage};
use crate::util::now_timestamp;
use super::shared::{emit_error_result, emit_tool_execution_start, forward_tool_updates};
use super::{AgentEvent, AgentLoopConfig, PreparedToolCall, ToolCallInfo, emit};
/// How [`dispatch_single_tool`] handled a single tool call.
pub(super) enum DispatchResult {
/// The tool is running in a spawned task; the handle completes once that task has
/// finished recording its result.
Spawned(tokio::task::JoinHandle<()>),
/// The call was settled without spawning a task (e.g. an unknown tool name, whose error
/// result was handed to `emit_error_result` before returning).
Inline,
/// Emitting the `ToolExecutionStart` event failed, so the tool was never spawned.
ChannelClosed,
}
/// Splits the prepared tool calls into ordered execution groups.
///
/// Each returned inner `Vec` holds indices into `prepared` (not into `tool_calls`).
/// NOTE(review): presumably the caller runs the members of one group together and the groups
/// one after another — confirm against the caller.
///
/// Per policy:
/// - `Concurrent`: a single group containing every index, in order.
/// - `Sequential`: one single-index group per call, in order.
/// - `Priority`: each call is scored with the policy's function; calls with equal scores share
///   a group, groups are ordered highest score first, and the original order is kept inside a
///   group (the sort is stable).
/// - `Custom`: the strategy's `partition` is returned once `validate_custom_execution_groups`
///   accepts it.
///
/// An empty `prepared` slice yields no groups without consulting the policy at all.
///
/// # Errors
///
/// Only `Custom` can fail: the validation message is returned when the strategy's partition
/// references an out-of-range index, repeats one, or omits one.
///
/// # Panics
///
/// For `Priority` and `Custom`, if some `prepared[i].idx` is out of bounds for `tool_calls`.
pub(super) async fn compute_execution_groups(
    policy: &ToolExecutionPolicy,
    tool_calls: &[ToolCallInfo],
    prepared: &[PreparedToolCall],
) -> Result<Vec<Vec<usize>>, String> {
    let count = prepared.len();
    if count == 0 {
        return Ok(Vec::new());
    }

    match policy {
        ToolExecutionPolicy::Concurrent => {
            let everyone: Vec<usize> = (0..count).collect();
            Ok(vec![everyone])
        }
        ToolExecutionPolicy::Sequential => Ok((0..count).map(|solo| vec![solo]).collect()),
        ToolExecutionPolicy::Priority(priority_fn) => {
            // (position in `prepared`, score) for every call.
            let mut ranked: Vec<(usize, i32)> = Vec::with_capacity(count);
            for (position, entry) in prepared.iter().enumerate() {
                let call = &tool_calls[entry.idx];
                let summary = ToolCallSummary {
                    id: &call.id,
                    name: &call.name,
                    arguments: &entry.effective_arguments,
                };
                ranked.push((position, priority_fn(&summary)));
            }

            // Highest score first; `sort_by` is stable, so ties keep their original order.
            ranked.sort_by(|left, right| right.1.cmp(&left.1));

            // After sorting, equal scores are adjacent, so each run becomes one group.
            let groups: Vec<Vec<usize>> = ranked
                .chunk_by(|left, right| left.1 == right.1)
                .map(|run| run.iter().map(|&(position, _)| position).collect::<Vec<usize>>())
                .collect();
            Ok(groups)
        }
        ToolExecutionPolicy::Custom(strategy) => {
            let mut summaries: Vec<ToolCallSummary<'_>> = Vec::with_capacity(count);
            for entry in prepared {
                let call = &tool_calls[entry.idx];
                summaries.push(ToolCallSummary {
                    id: &call.id,
                    name: &call.name,
                    arguments: &entry.effective_arguments,
                });
            }

            let partition = strategy.partition(&summaries).await;
            validate_custom_execution_groups(&partition, count).map(|()| partition)
        }
    }
}
/// Checks that a custom strategy's partition assigns every prepared call exactly once.
///
/// `groups` is valid when every index in `0..prepared_len` appears exactly once across all
/// groups. Empty groups are not rejected.
///
/// # Errors
///
/// Groups are walked in order, position by position. The first out-of-range index
/// (`>= prepared_len`) or repeated index encountered produces an error describing where it
/// occurred. If the walk finishes cleanly, any indices that were never assigned are reported
/// together, in ascending order.
fn validate_custom_execution_groups(
    groups: &[Vec<usize>],
    prepared_len: usize,
) -> Result<(), String> {
    // For each prepared index: the (group, position) where it was first assigned.
    let mut assigned_at: Vec<Option<(usize, usize)>> = vec![None; prepared_len];

    for (group_idx, group) in groups.iter().enumerate() {
        for (position, &prepared_idx) in group.iter().enumerate() {
            // `get_mut` returning `None` is exactly the out-of-range case.
            let Some(slot) = assigned_at.get_mut(prepared_idx) else {
                return Err(format!(
                    "group {group_idx} position {position} referenced prepared index \
                     {prepared_idx}, but the prepared tool-call slice has length {prepared_len}"
                ));
            };
            if let Some((previous_group, previous_position)) = *slot {
                return Err(format!(
                    "group {group_idx} position {position} repeated prepared index \
                     {prepared_idx}, which was already assigned at group {previous_group} \
                     position {previous_position}"
                ));
            }
            *slot = Some((group_idx, position));
        }
    }

    let mut omitted: Vec<String> = Vec::new();
    for (prepared_idx, slot) in assigned_at.iter().enumerate() {
        if slot.is_none() {
            omitted.push(prepared_idx.to_string());
        }
    }

    if omitted.is_empty() {
        Ok(())
    } else {
        Err(format!(
            "the custom strategy omitted prepared indices {}",
            omitted.join(", ")
        ))
    }
}
/// Dispatches one tool call and reports how it was handled.
///
/// Steps, in order:
/// 1. Looks up `tc.name` in `tool_map`. If it is absent, an unknown-tool error result is passed
///    to `emit_error_result` (together with `idx`, `results` and `tx`) and
///    [`DispatchResult::Inline`] is returned without spawning anything.
/// 2. Validates `effective_arguments` against the tool's `parameters_schema()`. A
///    `ToolExecutionStart` event is emitted only when validation succeeds. If that emit returns
///    `false`, [`DispatchResult::ChannelClosed`] is returned and nothing is spawned.
/// 3. Spawns a task, instrumented with an `agent.tool` span, and returns
///    [`DispatchResult::Spawned`] with its handle. The task does the following:
///    - produces a validation-error result if step 2 failed; otherwise resolves the tool's
///      credential and, on success, calls `tool.execute` with a child of `batch_token`, an
///      update callback, the config's `session_state` and the credential (a credential failure
///      becomes an error result instead);
///    - pushes a `ToolExecMetrics` entry into `tool_timings`;
///    - if the result is a transfer, stores its signal in `transfer_signal` (only when that slot
///      is still empty) and sets `transfer_flag`;
///    - emits `ToolExecutionEnd` and ignores the outcome of that send;
///    - pushes `(idx, ToolResultMessage)` into `results`;
///    - sets `steering_flag` when the config's message provider reports `has_steering()`.
///
/// `_steering_messages` is not used by this function.
// The lint allows below exist because every piece of shared loop state is passed as its own
// handle, and the spawned task body is long.
#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
pub(super) async fn dispatch_single_tool(
tool_map: &HashMap<&str, &Arc<dyn AgentTool>>,
config: &Arc<AgentLoopConfig>,
tc: &ToolCallInfo,
effective_arguments: &serde_json::Value,
idx: usize,
batch_token: &CancellationToken,
results: &Arc<tokio::sync::Mutex<Vec<(usize, ToolResultMessage)>>>,
tool_timings: &Arc<tokio::sync::Mutex<Vec<crate::metrics::ToolExecMetrics>>>,
_steering_messages: &Arc<tokio::sync::Mutex<Vec<AgentMessage>>>,
steering_flag: &Arc<std::sync::atomic::AtomicBool>,
transfer_flag: &Arc<std::sync::atomic::AtomicBool>,
transfer_signal: &Arc<tokio::sync::Mutex<Option<crate::transfer::TransferSignal>>>,
tx: &mpsc::Sender<AgentEvent>,
) -> DispatchResult {
let tool = tool_map.get(tc.name.as_str()).copied();
// Owned copies: these are moved into the spawned task, which `tokio::spawn` requires to be 'static.
let tool_call_id = tc.id.clone();
let tool_name = tc.name.clone();
let arguments = effective_arguments.clone();
let Some(tool) = tool else {
let error_result = crate::tool::unknown_tool_result(&tool_name);
emit_error_result(&tool_name, &tool_call_id, error_result, idx, results, tx).await;
return DispatchResult::Inline;
};
let tool = Arc::clone(tool);
// Child token: cancelling `batch_token` also cancels this call's token, but not vice versa.
let child_token = batch_token.child_token();
let results_clone = Arc::clone(results);
let timings_clone = Arc::clone(tool_timings);
let steering_clone = Arc::clone(steering_flag);
let transfer_flag_clone = Arc::clone(transfer_flag);
let transfer_clone = Arc::clone(transfer_signal);
let config_clone = Arc::clone(config);
let tx_clone = tx.clone();
// Validate before spawning so the start event is only emitted for calls that passed;
// the outcome itself is moved into the task below.
let validation = validate_tool_arguments(tool.parameters_schema(), &arguments);
// Calls that fail validation get no ToolExecutionStart, but the task still emits
// ToolExecutionEnd for them.
if validation.is_ok()
&& !emit_tool_execution_start(&tool_call_id, &tool_name, &arguments, tx).await
{
return DispatchResult::ChannelClosed;
}
let tool_span = info_span!(
"agent.tool",
agent.tool.name = %tool_name,
tool_call_id = %tool_call_id,
);
let handle = tokio::spawn(
async move {
debug!(tool = %tool_name, id = %tool_call_id, "tool execution starting");
let exec_start = std::time::Instant::now();
let (result, is_error) = if let Err(errors) = validation {
(validation_error_result(&errors), true)
} else {
match resolve_credential(&tool, &config_clone, &tool_call_id).await {
Err(cred_error) => (AgentToolResult::error(format!("{cred_error}")), true),
Ok(credential) => {
// Partial results from the tool go over an unbounded channel whose receiver is
// handed, together with the event sender, to `forward_tool_updates` in its own task.
let (update_tx, update_rx) = mpsc::unbounded_channel();
let updates_tx = tx_clone.clone();
let updates_tool_call_id = tool_call_id.clone();
let updates_tool_name = tool_name.clone();
let update_forwarder = tokio::spawn(async move {
forward_tool_updates(
&updates_tool_call_id,
&updates_tool_name,
update_rx,
&updates_tx,
)
.await;
});
let result = {
let on_update = Box::new(move |partial: AgentToolResult| {
// Send failures (receiver already gone) are deliberately ignored.
let _ = update_tx.send(partial);
});
tool.execute(
&tool_call_id,
arguments,
child_token,
Some(on_update),
config_clone.session_state.clone(),
credential,
)
.await
};
// Wait for the forwarder before emitting ToolExecutionEnd so pending updates go out first.
// NOTE(review): presumably `forward_tool_updates` returns once `update_tx` is dropped, i.e.
// when `execute` releases `on_update`; a tool that keeps the callback alive beyond its own
// future would stall here — confirm. A panic in the forwarder is ignored (`let _`).
let _ = update_forwarder.await;
let is_error = result.is_error;
(result, is_error)
}
}
};
let exec_duration = exec_start.elapsed();
debug!(tool = %tool_name, id = %tool_call_id, is_error, "tool execution finished");
// `tool_name` is moved into the metrics entry below; keep a copy for the end event.
let event_tool_name = tool_name.clone();
timings_clone
.lock()
.await
.push(crate::metrics::ToolExecMetrics {
tool_name,
duration: exec_duration,
success: !is_error,
});
if result.is_transfer() {
// First transfer signal stored in the shared slot wins; later ones leave it untouched.
let mut guard = transfer_clone.lock().await;
if guard.is_none() {
(*guard).clone_from(&result.transfer_signal);
}
// Release the lock before raising the flag.
drop(guard);
transfer_flag_clone.store(true, std::sync::atomic::Ordering::SeqCst);
}
// A failed send of the end event is ignored; the result is still recorded below.
let _ = emit(
&tx_clone,
AgentEvent::ToolExecutionEnd {
id: tool_call_id.clone(),
name: event_tool_name,
result: result.clone(),
is_error,
},
)
.await;
let tool_result_msg = ToolResultMessage {
tool_call_id,
content: result.content,
is_error,
timestamp: now_timestamp(),
details: result.details,
cache_hint: None,
};
// Tagged with `idx`, presumably so the caller can restore the original call order.
results_clone.lock().await.push((idx, tool_result_msg));
// Steering is flagged only after this call's result has been recorded.
if let Some(ref provider) = config_clone.message_provider
&& provider.has_steering()
{
steering_clone.store(true, std::sync::atomic::Ordering::SeqCst);
}
}
.instrument(tool_span),
);
DispatchResult::Spawned(handle)
}
/// Resolves the credential a tool declares it needs, if any.
///
/// Returns `Ok(None)` when `tool.auth_config()` is `None`. Otherwise the credential named by
/// the auth config's `credential_key` is fetched through `config.credential_resolver`, bounded
/// by a 30-second timeout. The kind of the returned credential is then checked against the
/// declared `credential_type`.
///
/// # Errors
///
/// - `CredentialError::NotFound` when the config has no credential resolver.
/// - `CredentialError::Timeout` when resolution takes longer than 30 seconds.
/// - Any error the resolver itself returns, propagated with `?`.
/// - `CredentialError::TypeMismatch` when the resolved credential's kind differs from the one
///   declared in the tool's auth config.
///
/// `_tool_call_id` is currently unused.
async fn resolve_credential(
    tool: &Arc<dyn AgentTool>,
    config: &Arc<AgentLoopConfig>,
    _tool_call_id: &str,
) -> Result<Option<crate::credential::ResolvedCredential>, crate::credential::CredentialError> {
    // Upper bound on how long a single credential lookup may take.
    const RESOLVE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

    let auth_config = match tool.auth_config() {
        Some(auth_config) => auth_config,
        None => return Ok(None),
    };

    let Some(resolver) = config.credential_resolver.as_ref() else {
        return Err(crate::credential::CredentialError::NotFound {
            key: auth_config.credential_key.clone(),
        });
    };

    let outcome =
        tokio::time::timeout(RESOLVE_TIMEOUT, resolver.resolve(&auth_config.credential_key)).await;
    let credential = match outcome {
        Ok(resolved) => resolved?,
        Err(_elapsed) => {
            return Err(crate::credential::CredentialError::Timeout {
                key: auth_config.credential_key.clone(),
            });
        }
    };

    let actual_type = match &credential {
        crate::credential::ResolvedCredential::ApiKey(_) => {
            crate::credential::CredentialType::ApiKey
        }
        crate::credential::ResolvedCredential::Bearer(_) => {
            crate::credential::CredentialType::Bearer
        }
        crate::credential::ResolvedCredential::OAuth2AccessToken(_) => {
            crate::credential::CredentialType::OAuth2
        }
    };

    if actual_type == auth_config.credential_type {
        return Ok(Some(credential));
    }
    Err(crate::credential::CredentialError::TypeMismatch {
        key: auth_config.credential_key,
        expected: auth_config.credential_type,
        actual: actual_type,
    })
}