use crate::constants::env::ai;
use crate::types::{
Message, MessageRole, ToolAnnotations, ToolCall, ToolDefinition, ToolInputSchema,
};
use crate::AgentError;
use futures_util::stream::{self, StreamExt};
/// Default upper bound on how many concurrency-safe tool calls may run at the same time.
pub const MAX_TOOL_USE_CONCURRENCY: usize = 10;

/// Returns the maximum number of tool calls that may execute concurrently.
///
/// The limit is read from the environment variable named by `ai::MAX_TOOL_USE_CONCURRENCY`.
/// The value is trimmed and parsed as a `usize`. A missing, unparsable, or zero value falls
/// back to [`MAX_TOOL_USE_CONCURRENCY`], so the returned limit is always positive. Zero is
/// rejected because a concurrency cap of zero cannot make progress.
pub fn get_max_tool_use_concurrency() -> usize {
    std::env::var(ai::MAX_TOOL_USE_CONCURRENCY)
        .ok()
        .and_then(|raw| raw.trim().parse::<usize>().ok())
        .filter(|&limit| limit > 0)
        .unwrap_or(MAX_TOOL_USE_CONCURRENCY)
}
/// A run of consecutive tool calls that share one concurrency classification.
///
/// Built by `partition_tool_calls` and consumed by `run_tools`, which runs each batch either
/// concurrently or serially depending on `is_concurrency_safe`.
#[derive(Debug, Clone)]
pub struct ToolBatch {
    /// `true` when every call in `blocks` was classified as safe to run concurrently.
    /// As built by `partition_tool_calls`, a `false` batch always holds exactly one call.
    pub is_concurrency_safe: bool,
    /// The tool calls in this batch, in the order they appeared in the input.
    pub blocks: Vec<ToolCall>,
}
/// The outcome of executing one tool call, as produced by the `run_tools*` functions.
#[derive(Debug, Clone)]
pub struct ToolMessageUpdate {
    /// The tool-role message to append to the conversation, if any.
    pub message: Option<Message>,
    /// The context to use for subsequent tool calls, when the runner reports one.
    /// `run_tools_serially` fills this in; `run_tools_concurrently` leaves it `None`.
    pub new_context: Option<crate::types::ToolContext>,
}
/// Groups `tool_calls` into ordered batches for execution.
///
/// Consecutive calls whose tool definition reports them as concurrency-safe (for their
/// particular arguments) are merged into a single batch. Every other call gets a batch of its
/// own. Calls naming a tool that is not in `tools` are treated as not concurrency-safe. The
/// relative order of calls is preserved across and within batches.
pub fn partition_tool_calls(tool_calls: &[ToolCall], tools: &[ToolDefinition]) -> Vec<ToolBatch> {
    let mut grouped: Vec<ToolBatch> = Vec::new();
    for call in tool_calls {
        let safe = tools
            .iter()
            .find(|definition| definition.name == call.name)
            .is_some_and(|definition| definition.is_concurrency_safe(&call.arguments));

        match grouped.last_mut() {
            // Extend the current batch only when both it and this call are concurrency-safe.
            Some(open) if safe && open.is_concurrency_safe => open.blocks.push(call.clone()),
            _ => grouped.push(ToolBatch {
                is_concurrency_safe: safe,
                blocks: vec![call.clone()],
            }),
        }
    }
    grouped
}
/// Drops `tool_use_id` from the set of in-progress tool-use ids.
///
/// Removing an id that is not present leaves the set unchanged.
pub fn mark_tool_use_as_complete(
    in_progress_ids: &mut std::collections::HashSet<String>,
    tool_use_id: &str,
) {
    let _was_in_progress: bool = in_progress_ids.remove(tool_use_id);
}
pub async fn run_tools_serially<F, Fut>(
tool_calls: Vec<ToolCall>,
tool_context: crate::types::ToolContext,
mut executor: F,
) -> Vec<ToolMessageUpdate>
where
F: FnMut(String, serde_json::Value, String) -> Fut + Send,
Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
{
let mut updates = Vec::new();
let mut current_context = tool_context;
let mut in_progress_ids = std::collections::HashSet::new();
for tool_call in tool_calls {
let tool_name = tool_call.name.clone();
let tool_args = tool_call.arguments.clone();
let tool_call_id = tool_call.id.clone();
in_progress_ids.insert(tool_call_id.clone());
match executor(tool_name.clone(), tool_args.clone(), tool_call_id.clone()).await {
Ok(result) => {
let message = Message {
role: MessageRole::Tool,
content: result.content,
tool_call_id: Some(tool_call_id.clone()),
..Default::default()
};
updates.push(ToolMessageUpdate {
message: Some(message),
new_context: Some(current_context.clone()),
});
}
Err(e) => {
let error_content = format!("<tool_use_error>Error: {}</tool_use_error>", e);
let message = Message {
role: MessageRole::Tool,
content: error_content,
tool_call_id: Some(tool_call_id.clone()),
is_error: Some(true),
..Default::default()
};
updates.push(ToolMessageUpdate {
message: Some(message),
new_context: Some(current_context.clone()),
});
}
}
mark_tool_use_as_complete(&mut in_progress_ids, &tool_call_id);
}
updates
}
pub async fn run_tools_concurrently<F, Fut>(
tool_calls: Vec<ToolCall>,
tool_context: crate::types::ToolContext,
mut executor: F,
) -> Vec<ToolMessageUpdate>
where
F: FnMut(String, serde_json::Value, String) -> Fut + Send + Clone + 'static,
Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
{
let max_concurrency = get_max_tool_use_concurrency();
let mut updates = Vec::new();
let executions: Vec<_> = tool_calls
.into_iter()
.map(|tool_call| {
let mut exec = executor.clone();
let tool_name = tool_call.name.clone();
let tool_args = tool_call.arguments.clone();
let tool_call_id = tool_call.id.clone();
async move {
let result = exec(tool_name, tool_args, tool_call_id.clone()).await;
(tool_call_id, result)
}
})
.collect();
let mut stream = stream::iter(executions).buffer_unordered(max_concurrency);
while let Some((tool_call_id, result)) = stream.next().await {
match result {
Ok(tool_result) => {
let message = Message {
role: MessageRole::Tool,
content: tool_result.content,
tool_call_id: Some(tool_call_id),
..Default::default()
};
updates.push(ToolMessageUpdate {
message: Some(message),
new_context: None,
});
}
Err(e) => {
let error_content = format!("<tool_use_error>Error: {}</tool_use_error>", e);
let message = Message {
role: MessageRole::Tool,
content: error_content,
tool_call_id: Some(tool_call_id),
is_error: Some(true),
..Default::default()
};
updates.push(ToolMessageUpdate {
message: Some(message),
new_context: None,
});
}
}
}
updates
}
pub async fn run_tools<F, Fut>(
tool_calls: Vec<ToolCall>,
tools: Vec<ToolDefinition>,
tool_context: crate::types::ToolContext,
mut executor: F,
) -> Vec<ToolMessageUpdate>
where
F: FnMut(String, serde_json::Value, String) -> Fut + Send + Clone + 'static,
Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
{
let batches = partition_tool_calls(&tool_calls, &tools);
let mut all_updates = Vec::new();
let mut current_context = tool_context;
for batch in batches {
if batch.is_concurrency_safe {
let updates =
run_tools_concurrently(batch.blocks, current_context.clone(), executor.clone())
.await;
all_updates.extend(updates);
} else {
let updates =
run_tools_serially(batch.blocks, current_context.clone(), executor.clone()).await;
if let Some(last_update) = updates.last() {
if let Some(ctx) = &last_update.new_context {
current_context = ctx.clone();
}
}
all_updates.extend(updates);
}
}
all_updates
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ToolInputSchema;

    /// Builds a `ToolCall` with the given id and tool name and empty JSON arguments.
    fn call(id: &str, name: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            name: name.to_string(),
            arguments: serde_json::json!({}),
        }
    }

    /// Builds a `ToolDefinition` named `name`. When `concurrency_safe` is true the tool carries
    /// a `concurrency_safe: Some(true)` annotation; otherwise it has no annotations.
    fn create_test_tool(name: &str, concurrency_safe: bool) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: format!("Test tool {}", name),
            input_schema: ToolInputSchema {
                schema_type: "object".to_string(),
                properties: serde_json::json!({}),
                required: None,
            },
            annotations: if concurrency_safe {
                Some(ToolAnnotations {
                    concurrency_safe: Some(true),
                    ..Default::default()
                })
            } else {
                None
            },
        }
    }

    #[test]
    fn test_get_max_tool_use_concurrency_default() {
        // The built-in default only applies when no override is present, so only check it in
        // that case instead of failing whenever the test environment sets the variable.
        if std::env::var(ai::MAX_TOOL_USE_CONCURRENCY).is_err() {
            assert_eq!(get_max_tool_use_concurrency(), MAX_TOOL_USE_CONCURRENCY);
        }
    }

    #[test]
    fn test_get_max_tool_use_concurrency_value() {
        let result = get_max_tool_use_concurrency();
        assert!(result > 0);
    }

    #[test]
    fn test_partition_tool_calls_all_non_safe() {
        let tool_calls = vec![call("1", "Bash"), call("2", "Edit")];
        let tools = vec![
            create_test_tool("Bash", false),
            create_test_tool("Edit", false),
        ];
        let batches = partition_tool_calls(&tool_calls, &tools);
        assert_eq!(batches.len(), 2);
        assert!(!batches[0].is_concurrency_safe);
        assert!(!batches[1].is_concurrency_safe);
        assert_eq!(batches[0].blocks[0].id, "1");
        assert_eq!(batches[1].blocks[0].id, "2");
    }

    #[test]
    fn test_partition_tool_calls_mixed() {
        let tool_calls = vec![
            call("1", "Read"),
            call("2", "Glob"),
            call("3", "Bash"),
            call("4", "Grep"),
        ];
        let tools = vec![
            create_test_tool("Read", true),
            create_test_tool("Glob", true),
            create_test_tool("Bash", false),
            create_test_tool("Grep", true),
        ];
        let batches = partition_tool_calls(&tool_calls, &tools);
        assert_eq!(batches.len(), 3);
        assert!(batches[0].is_concurrency_safe);
        assert_eq!(batches[0].blocks.len(), 2);
        assert!(!batches[1].is_concurrency_safe);
        assert_eq!(batches[1].blocks.len(), 1);
        assert!(batches[2].is_concurrency_safe);
        assert_eq!(batches[2].blocks.len(), 1);
    }

    #[test]
    fn test_partition_tool_calls_with_unknown_tool() {
        let tool_calls = vec![call("1", "UnknownTool")];
        let tools: Vec<ToolDefinition> = Vec::new();
        let batches = partition_tool_calls(&tool_calls, &tools);
        assert_eq!(batches.len(), 1);
        assert!(!batches[0].is_concurrency_safe);
    }

    #[tokio::test]
    async fn test_run_tools_serially() {
        let tool_calls = vec![call("1", "test")];
        let tool_context = crate::types::ToolContext::default();
        let executor = |_name: String, _args: serde_json::Value, _tool_call_id: String| async {
            Ok(crate::types::ToolResult {
                result_type: "tool_result".to_string(),
                tool_use_id: "1".to_string(),
                content: "success".to_string(),
                is_error: Some(false),
            })
        };
        let updates = run_tools_serially(tool_calls, tool_context, executor).await;
        assert_eq!(updates.len(), 1);
        let message = updates[0]
            .message
            .as_ref()
            .expect("serial run yields one message per call");
        assert_eq!(message.content, "success");
        assert_eq!(message.tool_call_id.as_deref(), Some("1"));
    }

    #[tokio::test]
    async fn test_run_tools_concurrently() {
        let tool_calls = vec![call("1", "test1"), call("2", "test2")];
        let tool_context = crate::types::ToolContext::default();
        let executor = |_name: String, _args: serde_json::Value, _tool_call_id: String| async {
            Ok(crate::types::ToolResult {
                result_type: "tool_result".to_string(),
                tool_use_id: "1".to_string(),
                content: "success".to_string(),
                is_error: Some(false),
            })
        };
        let updates = run_tools_concurrently(tool_calls, tool_context, executor).await;
        assert_eq!(updates.len(), 2);
        assert!(updates.iter().all(|update| update.new_context.is_none()));
        let mut ids: Vec<String> = updates
            .iter()
            .filter_map(|update| update.message.as_ref()?.tool_call_id.clone())
            .collect();
        ids.sort();
        assert_eq!(ids, ["1", "2"]);
    }

    #[tokio::test]
    async fn test_run_tools_with_partitioning() {
        let tool_calls = vec![call("1", "Read"), call("2", "Glob"), call("3", "Bash")];
        let tools = vec![
            create_test_tool("Read", true),
            create_test_tool("Glob", true),
            create_test_tool("Bash", false),
        ];
        let tool_context = crate::types::ToolContext::default();
        let executor = |_name: String, _args: serde_json::Value, _tool_call_id: String| async {
            Ok(crate::types::ToolResult {
                result_type: "tool_result".to_string(),
                tool_use_id: "1".to_string(),
                content: "success".to_string(),
                is_error: Some(false),
            })
        };
        let updates = run_tools(tool_calls, tools, tool_context, executor).await;
        assert_eq!(updates.len(), 3);
        // The serial Bash batch runs after the concurrent Read/Glob batch.
        assert_eq!(
            updates[2]
                .message
                .as_ref()
                .and_then(|message| message.tool_call_id.as_deref()),
            Some("3")
        );
    }

    #[test]
    fn test_mark_tool_use_as_complete() {
        let mut in_progress = std::collections::HashSet::new();
        in_progress.insert("tool1".to_string());
        in_progress.insert("tool2".to_string());
        mark_tool_use_as_complete(&mut in_progress, "tool1");
        assert!(!in_progress.contains("tool1"));
        assert!(in_progress.contains("tool2"));
    }
}