use crate::AgentError;
use crate::constants::env::ai;
use crate::types::{
Message, MessageRole, ToolAnnotations, ToolCall, ToolDefinition, ToolInputSchema, ToolResult,
};
use futures_util::stream::{self, StreamExt};
use serde::Serialize;
use crate::tool_errors::format_tool_error;
use crate::tool_result_storage::process_tool_result;
use crate::tool_validation::validate_tool_input;
/// Default upper bound on how many concurrency-safe tool calls are run at the same time.
pub const MAX_TOOL_USE_CONCURRENCY: usize = 10;

/// Returns the concurrency limit used for concurrency-safe tool batches.
///
/// The limit is read from the environment variable named by `ai::MAX_TOOL_USE_CONCURRENCY`.
/// Surrounding whitespace in the value is ignored. It falls back to
/// [`MAX_TOOL_USE_CONCURRENCY`] when the variable:
/// - is unset or not valid Unicode,
/// - is not a non-negative integer, or
/// - is `0`.
///
/// The result is therefore always at least 1.
pub fn get_max_tool_use_concurrency() -> usize {
    std::env::var(ai::MAX_TOOL_USE_CONCURRENCY)
        .ok()
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        // A limit of 0 is not a meaningful bound for `buffer_unordered`, so treat it as
        // "not configured" instead of forwarding it.
        .filter(|&limit| limit > 0)
        .unwrap_or(MAX_TOOL_USE_CONCURRENCY)
}
/// A run of consecutive tool calls that are executed with the same strategy.
///
/// Produced by [`partition_tool_calls`]. Adjacent concurrency-safe calls share one batch.
/// Every other call gets a batch containing only itself.
#[derive(Debug, Clone)]
pub struct ToolBatch {
/// `true` when every call in `blocks` was judged concurrency-safe, so the batch may run
/// concurrently. `false` for a single call that must run serially.
pub is_concurrency_safe: bool,
/// The tool calls in this batch, in their original order.
pub blocks: Vec<ToolCall>,
}
/// Pairs a tool-use id with a function that transforms the tool context.
///
/// NOTE(review): nothing in this file constructs one. Every `ToolMessageUpdate` built here
/// sets `context_modifier: None`. Presumably it is produced or consumed elsewhere — confirm.
#[derive(Debug, Clone)]
pub struct ContextModifier {
/// Id of the tool call this modifier belongs to (presumably matches `ToolCall::id`).
pub tool_use_id: String,
/// Plain function pointer that takes the context by value and returns the new context.
pub modify_context: fn(crate::types::ToolContext) -> crate::types::ToolContext,
}
/// One result record produced for a tool call by the runners in this module.
#[derive(Debug, Clone)]
pub struct ToolMessageUpdate {
/// The tool-role message for the call. The runners in this file always set it.
pub message: Option<Message>,
/// Context snapshot after the call.
///
/// - `run_tools_serially` sets it to a clone of the context it was given.
/// - `run_tools_concurrently` leaves it `None`.
/// - `run_tools` carries the last serial update's value into the next batch.
pub new_context: Option<crate::types::ToolContext>,
/// Optional deferred context change. Always `None` for updates built in this file.
pub context_modifier: Option<ContextModifier>,
}
/// Splits `tool_calls` into consecutive execution batches, preserving call order.
///
/// A call counts as concurrency-safe only when both hold:
/// - a tool with the same name exists in `tools`, and
/// - that tool's `is_concurrency_safe(&arguments)` returns `true`.
///
/// Calls to unknown tools are treated as unsafe. Adjacent safe calls are merged into one batch.
/// Each unsafe call is placed in a batch of its own.
pub fn partition_tool_calls(tool_calls: &[ToolCall], tools: &[ToolDefinition]) -> Vec<ToolBatch> {
    tool_calls
        .iter()
        .fold(Vec::new(), |mut batches: Vec<ToolBatch>, call| {
            let safe = tools
                .iter()
                .find(|def| def.name == call.name)
                .is_some_and(|def| def.is_concurrency_safe(&call.arguments));
            match batches.last_mut() {
                // Extend the currently open run of concurrency-safe calls.
                Some(open) if safe && open.is_concurrency_safe => open.blocks.push(call.clone()),
                // Start a new batch: an unsafe call, or the first safe call after an unsafe one.
                _ => batches.push(ToolBatch {
                    is_concurrency_safe: safe,
                    blocks: vec![call.clone()],
                }),
            }
            batches
        })
}
/// Drops `tool_use_id` from the set of in-flight tool-use ids.
///
/// Removing an id that is not present leaves the set unchanged.
pub fn mark_tool_use_as_complete(
    in_progress_ids: &mut std::collections::HashSet<String>,
    tool_use_id: &str,
) {
    // `HashSet::remove` reports whether the id was present; no caller needs that answer.
    let _was_in_progress = in_progress_ids.remove(tool_use_id);
}
pub async fn run_tools_serially<F, Fut>(
tool_calls: Vec<ToolCall>,
tool_context: crate::types::ToolContext,
tools: Vec<ToolDefinition>,
mut executor: F,
project_dir: Option<String>,
session_id: Option<String>,
) -> Vec<ToolMessageUpdate>
where
F: FnMut(String, serde_json::Value, String) -> Fut + Send,
Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
{
let mut updates = Vec::new();
let mut current_context = tool_context;
let mut in_progress_ids = std::collections::HashSet::new();
for tool_call in tool_calls {
let tool_name = tool_call.name.clone();
let tool_args = tool_call.arguments.clone();
let tool_call_id = tool_call.id.clone();
in_progress_ids.insert(tool_call_id.clone());
let tool_def = tools.iter().find(|t| t.name == tool_name);
let interrupt_behavior = tool_def.map(|t| t.interrupt_behavior()).unwrap_or_default();
if !matches!(interrupt_behavior, crate::tools::types::InterruptBehavior::Block)
&& current_context.abort_signal.is_aborted()
{
let error_content =
"<tool_use_error>Tool execution aborted by user interrupt</tool_use_error>"
.to_string();
updates.push(ToolMessageUpdate {
message: Some(Message {
role: MessageRole::Tool,
content: error_content,
tool_call_id: Some(tool_call_id.clone()),
is_error: Some(true),
..Default::default()
}),
new_context: Some(current_context.clone()),
context_modifier: None,
});
mark_tool_use_as_complete(&mut in_progress_ids, &tool_call_id);
continue;
}
if let Err(validation_err) = validate_tool_input(&tool_name, &tool_args, &tools) {
let error_content = format!(
"<tool_use_error>InputValidationError: {}</tool_use_error>",
validation_err
);
updates.push(ToolMessageUpdate {
message: Some(Message {
role: MessageRole::Tool,
content: error_content,
tool_call_id: Some(tool_call_id.clone()),
is_error: Some(true),
..Default::default()
}),
new_context: Some(current_context.clone()),
context_modifier: None,
});
mark_tool_use_as_complete(&mut in_progress_ids, &tool_call_id);
continue;
}
match executor(tool_name.clone(), tool_args.clone(), tool_call_id.clone()).await {
Ok(mut result) => {
let persisted = process_tool_result(
&result.content,
&tool_name,
&tool_call_id,
project_dir.as_deref(),
session_id.as_deref(),
None, );
result.content = persisted.0;
result.was_persisted = Some(persisted.1);
let message = Message {
role: MessageRole::Tool,
content: result.content,
tool_call_id: Some(tool_call_id.clone()),
is_error: result.is_error,
..Default::default()
};
updates.push(ToolMessageUpdate {
message: Some(message),
new_context: Some(current_context.clone()),
context_modifier: None,
});
}
Err(e) => {
let error_content = format!(
"<tool_use_error>Error: {}</tool_use_error>",
format_tool_error(&e)
);
let message = Message {
role: MessageRole::Tool,
content: error_content,
tool_call_id: Some(tool_call_id.clone()),
is_error: Some(true),
..Default::default()
};
updates.push(ToolMessageUpdate {
message: Some(message),
new_context: Some(current_context.clone()),
context_modifier: None,
});
}
}
mark_tool_use_as_complete(&mut in_progress_ids, &tool_call_id);
}
updates
}
pub async fn run_tools_concurrently<F, Fut>(
tool_calls: Vec<ToolCall>,
tool_context: crate::types::ToolContext,
tools: Vec<ToolDefinition>,
mut executor: F,
project_dir: Option<String>,
session_id: Option<String>,
) -> Vec<ToolMessageUpdate>
where
F: FnMut(String, serde_json::Value, String) -> Fut + Send + Clone + 'static,
Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
{
let max_concurrency = get_max_tool_use_concurrency();
let mut updates = Vec::new();
let executions: Vec<_> = tool_calls
.into_iter()
.map(|tool_call| {
let mut exec = executor.clone();
let tool_name = tool_call.name.clone();
let tool_args = tool_call.arguments.clone();
let tool_call_id = tool_call.id.clone();
let tools = tools.clone();
let project_dir = project_dir.clone();
let session_id = session_id.clone();
let abort_signal = tool_context.abort_signal.clone();
async move {
let tool_def = tools.iter().find(|t| t.name == tool_name);
let interrupt_behavior =
tool_def.map(|t| t.interrupt_behavior()).unwrap_or_default();
if !matches!(interrupt_behavior, crate::tools::types::InterruptBehavior::Block)
&& abort_signal.is_aborted()
{
return (
tool_call_id,
Err(AgentError::Tool("Tool execution aborted by user interrupt".to_string())),
);
}
if let Err(validation_err) = validate_tool_input(&tool_name, &tool_args, &tools) {
let error_content = format!(
"<tool_use_error>InputValidationError: {}</tool_use_error>",
validation_err
);
return (
tool_call_id,
Err(AgentError::Tool(format!(
"InputValidationError: {}",
validation_err
))),
);
}
let result = exec(tool_name.clone(), tool_args, tool_call_id.clone()).await;
(tool_call_id, result)
}
})
.collect();
let mut stream = stream::iter(executions).buffer_unordered(max_concurrency);
while let Some((tool_call_id, result)) = stream.next().await {
match result {
Ok(tool_result) => {
let (content, _) = process_tool_result(
&tool_result.content,
"", &tool_call_id,
project_dir.as_deref(),
session_id.as_deref(),
None,
);
let message = Message {
role: MessageRole::Tool,
content,
tool_call_id: Some(tool_call_id),
..Default::default()
};
updates.push(ToolMessageUpdate {
message: Some(message),
new_context: None,
context_modifier: None,
});
}
Err(e) => {
let error_content = format!(
"<tool_use_error>Error: {}</tool_use_error>",
format_tool_error(&e)
);
let message = Message {
role: MessageRole::Tool,
content: error_content,
tool_call_id: Some(tool_call_id),
is_error: Some(true),
..Default::default()
};
updates.push(ToolMessageUpdate {
message: Some(message),
new_context: None,
context_modifier: None,
});
}
}
}
updates
}
pub async fn run_tools<F, Fut>(
tool_calls: Vec<ToolCall>,
tools: Vec<ToolDefinition>,
tool_context: crate::types::ToolContext,
executor: F,
project_dir: Option<String>,
session_id: Option<String>,
) -> Vec<ToolMessageUpdate>
where
F: FnMut(String, serde_json::Value, String) -> Fut + Send + Clone + 'static,
Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
{
let batches = partition_tool_calls(&tool_calls, &tools);
let mut all_updates = Vec::new();
let mut current_context = tool_context;
for batch in batches {
let tools_clone = tools.clone();
let project_dir_clone = project_dir.clone();
let session_id_clone = session_id.clone();
if batch.is_concurrency_safe {
let updates = run_tools_concurrently(
batch.blocks,
current_context.clone(),
tools_clone,
executor.clone(),
project_dir_clone,
session_id_clone,
)
.await;
all_updates.extend(updates);
} else {
let updates = run_tools_serially(
batch.blocks,
current_context.clone(),
tools_clone,
executor.clone(),
project_dir_clone,
session_id_clone,
)
.await;
if let Some(last_update) = updates.last() {
if let Some(ctx) = &last_update.new_context {
current_context = ctx.clone();
}
}
all_updates.extend(updates);
}
}
all_updates
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a stub tool definition with an empty object schema.
    fn make_tool(
        name: &str,
        annotations: Option<ToolAnnotations>,
        interrupt_behavior: Option<String>,
    ) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: format!("Test tool {}", name),
            input_schema: ToolInputSchema {
                schema_type: "object".to_string(),
                properties: serde_json::json!({}),
                required: None,
            },
            annotations,
            should_defer: None,
            always_load: None,
            is_mcp: None,
            search_hint: None,
            aliases: None,
            user_facing_name: None,
            interrupt_behavior,
        }
    }

    /// Stub tool whose annotations mark it concurrency-safe when `concurrency_safe` is true.
    fn create_test_tool(name: &str, concurrency_safe: bool) -> ToolDefinition {
        let annotations = concurrency_safe.then(|| ToolAnnotations {
            concurrency_safe: Some(true),
            ..Default::default()
        });
        make_tool(name, annotations, None)
    }

    /// Stub tool without annotations, carrying the given raw interrupt-behavior string.
    fn create_tool_with_interrupt(name: &str, interrupt: Option<String>) -> ToolDefinition {
        make_tool(name, None, interrupt)
    }

    /// A `function` tool call with empty JSON arguments.
    fn call(id: &str, name: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            r#type: "function".to_string(),
            name: name.to_string(),
            arguments: serde_json::json!({}),
        }
    }

    /// The canned successful result returned by every stub executor.
    fn ok_result(content: &str) -> Result<crate::types::ToolResult, AgentError> {
        Ok(crate::types::ToolResult {
            result_type: "tool_result".to_string(),
            tool_use_id: "1".to_string(),
            content: content.to_string(),
            is_error: Some(false),
            was_persisted: Some(false),
        })
    }

    /// A tool context whose abort signal has already fired.
    ///
    /// The controller is returned as well, so callers can keep it alive for the whole test.
    fn aborted_context() -> (impl Sized, crate::types::ToolContext) {
        use crate::utils::abort_controller::create_abort_controller_default;
        let controller = create_abort_controller_default();
        controller.abort(None);
        let abort_signal = controller.signal().clone();
        (
            controller,
            crate::types::ToolContext {
                cwd: "/tmp".to_string(),
                abort_signal,
            },
        )
    }

    #[test]
    fn test_get_max_tool_use_concurrency_default() {
        assert_eq!(get_max_tool_use_concurrency(), MAX_TOOL_USE_CONCURRENCY);
    }

    #[test]
    fn test_get_max_tool_use_concurrency_value() {
        assert!(get_max_tool_use_concurrency() > 0);
    }

    #[test]
    fn test_partition_tool_calls_all_non_safe() {
        let tool_calls = vec![call("1", "Bash"), call("2", "Edit")];
        let tools = vec![
            create_test_tool("Bash", false),
            create_test_tool("Edit", false),
        ];
        let batches = partition_tool_calls(&tool_calls, &tools);
        assert_eq!(batches.len(), 2);
        assert!(batches.iter().all(|batch| !batch.is_concurrency_safe));
    }

    #[test]
    fn test_partition_tool_calls_mixed() {
        let tool_calls = vec![
            call("1", "Read"),
            call("2", "Glob"),
            call("3", "Bash"),
            call("4", "Grep"),
        ];
        let tools = vec![
            create_test_tool("Read", true),
            create_test_tool("Glob", true),
            create_test_tool("Bash", false),
            create_test_tool("Grep", true),
        ];
        let batches = partition_tool_calls(&tool_calls, &tools);
        assert_eq!(batches.len(), 3);
        assert!(batches[0].is_concurrency_safe);
        assert_eq!(batches[0].blocks.len(), 2);
        assert!(!batches[1].is_concurrency_safe);
        assert!(batches[2].is_concurrency_safe);
    }

    #[test]
    fn test_partition_tool_calls_with_unknown_tool() {
        let tool_calls = vec![call("1", "UnknownTool")];
        let tools: Vec<ToolDefinition> = Vec::new();
        let batches = partition_tool_calls(&tool_calls, &tools);
        assert_eq!(batches.len(), 1);
        assert!(!batches[0].is_concurrency_safe);
    }

    #[tokio::test]
    async fn test_run_tools_serially() {
        let executor =
            |_name: String, _args: serde_json::Value, _id: String| async { ok_result("success") };
        let updates = run_tools_serially(
            vec![call("1", "test")],
            crate::types::ToolContext::default(),
            vec![create_test_tool("test", false)],
            executor,
            None,
            None,
        )
        .await;
        assert_eq!(updates.len(), 1);
        assert!(updates[0].message.is_some());
    }

    #[tokio::test]
    async fn test_run_tools_concurrently() {
        let executor =
            |_name: String, _args: serde_json::Value, _id: String| async { ok_result("success") };
        let updates = run_tools_concurrently(
            vec![call("1", "test1"), call("2", "test2")],
            crate::types::ToolContext::default(),
            vec![
                create_test_tool("test1", true),
                create_test_tool("test2", true),
            ],
            executor,
            None,
            None,
        )
        .await;
        assert_eq!(updates.len(), 2);
    }

    #[tokio::test]
    async fn test_run_tools_with_partitioning() {
        let tool_calls = vec![call("1", "Read"), call("2", "Glob"), call("3", "Bash")];
        let tools = vec![
            create_test_tool("Read", true),
            create_test_tool("Glob", true),
            create_test_tool("Bash", false),
        ];
        let executor =
            |_name: String, _args: serde_json::Value, _id: String| async { ok_result("success") };
        let updates = run_tools(
            tool_calls,
            tools,
            crate::types::ToolContext::default(),
            executor,
            None,
            None,
        )
        .await;
        assert_eq!(updates.len(), 3);
    }

    #[test]
    fn test_mark_tool_use_as_complete() {
        let mut in_progress: std::collections::HashSet<String> =
            ["tool1", "tool2"].iter().map(|id| id.to_string()).collect();
        mark_tool_use_as_complete(&mut in_progress, "tool1");
        assert!(!in_progress.contains("tool1"));
        assert!(in_progress.contains("tool2"));
    }

    #[tokio::test]
    async fn test_run_tools_serially_aborted() {
        let (_controller, tool_context) = aborted_context();
        let tools = vec![create_tool_with_interrupt("test", Some("cancel".into()))];
        let executor = |_name: String, _args: serde_json::Value, _id: String| async {
            ok_result("should not reach")
        };
        let updates =
            run_tools_serially(vec![call("1", "test")], tool_context, tools, executor, None, None)
                .await;
        assert_eq!(updates.len(), 1);
        let msg = updates[0].message.as_ref().unwrap();
        assert_eq!(msg.is_error, Some(true));
        assert!(msg.content.contains("aborted"));
    }

    #[tokio::test]
    async fn test_run_tools_concurrently_aborted() {
        let (_controller, tool_context) = aborted_context();
        let tools = vec![create_tool_with_interrupt("Read", Some("cancel".into()))];
        let executor = |_name: String, _args: serde_json::Value, _id: String| async {
            ok_result("should not reach")
        };
        let updates =
            run_tools_concurrently(vec![call("1", "Read")], tool_context, tools, executor, None, None)
                .await;
        assert_eq!(updates.len(), 1);
        let msg = updates[0].message.as_ref().unwrap();
        assert_eq!(msg.is_error, Some(true));
    }

    #[tokio::test]
    async fn test_interrupt_cancel_tool_aborted() {
        let (_controller, tool_context) = aborted_context();
        let tools = vec![create_tool_with_interrupt("CancelTool", Some("cancel".into()))];
        let executor = |_name: String, _args: serde_json::Value, _id: String| async {
            ok_result("should not reach")
        };
        let updates = run_tools_serially(
            vec![call("1", "CancelTool")],
            tool_context,
            tools,
            executor,
            None,
            None,
        )
        .await;
        assert_eq!(updates.len(), 1);
        let msg = updates[0].message.as_ref().unwrap();
        assert_eq!(msg.is_error, Some(true));
        assert!(msg.content.contains("aborted"));
    }

    #[tokio::test]
    async fn test_interrupt_block_tool_ignores_abort() {
        let (_controller, tool_context) = aborted_context();
        let tools = vec![create_tool_with_interrupt("BlockTool", Some("block".into()))];
        let executor = |_name: String, _args: serde_json::Value, _id: String| async {
            ok_result("block tool completed")
        };
        let updates = run_tools_serially(
            vec![call("1", "BlockTool")],
            tool_context,
            tools,
            executor,
            None,
            None,
        )
        .await;
        assert_eq!(updates.len(), 1);
        let msg = updates[0].message.as_ref().unwrap();
        assert_ne!(msg.is_error, Some(true));
        assert!(msg.content.contains("block tool completed"));
    }

    #[tokio::test]
    async fn test_interrupt_default_treated_as_block() {
        let (_controller, tool_context) = aborted_context();
        let tools = vec![create_tool_with_interrupt("DefaultTool", None)];
        let executor = |_name: String, _args: serde_json::Value, _id: String| async {
            ok_result("default completed")
        };
        let updates = run_tools_serially(
            vec![call("1", "DefaultTool")],
            tool_context,
            tools,
            executor,
            None,
            None,
        )
        .await;
        assert_eq!(updates.len(), 1);
        let msg = updates[0].message.as_ref().unwrap();
        assert_ne!(msg.is_error, Some(true));
    }

    #[tokio::test]
    async fn test_interrupt_concurrently_block_ignores_abort() {
        let (_controller, tool_context) = aborted_context();
        let tools = vec![create_tool_with_interrupt("BlockTool", Some("block".into()))];
        let executor = |_name: String, _args: serde_json::Value, _id: String| async {
            ok_result("concurrent block done")
        };
        let updates = run_tools_concurrently(
            vec![call("1", "BlockTool")],
            tool_context,
            tools,
            executor,
            None,
            None,
        )
        .await;
        assert_eq!(updates.len(), 1);
        let msg = updates[0].message.as_ref().unwrap();
        assert!(msg.content.contains("concurrent block done"));
    }
}