1use crate::constants::env::ai;
7use crate::types::{
8 Message, MessageRole, ToolAnnotations, ToolCall, ToolDefinition, ToolInputSchema,
9};
10use crate::AgentError;
11use futures_util::stream::{self, StreamExt};
12
13pub const MAX_TOOL_USE_CONCURRENCY: usize = 10;
15
16pub fn get_max_tool_use_concurrency() -> usize {
18 std::env::var(ai::MAX_TOOL_USE_CONCURRENCY)
19 .ok()
20 .and_then(|v| v.parse::<usize>().ok())
21 .unwrap_or(MAX_TOOL_USE_CONCURRENCY)
22}
23
/// A group of consecutive tool calls that share one execution strategy.
///
/// Built by `partition_tool_calls`, which keeps the original call order both across and
/// within batches.
#[derive(Debug, Clone)]
pub struct ToolBatch {
    /// Whether the calls in `blocks` may run concurrently with each other. As produced by
    /// `partition_tool_calls`, a `false` batch always holds exactly one call.
    pub is_concurrency_safe: bool,
    /// The tool calls in this batch, in their original order.
    pub blocks: Vec<ToolCall>,
}
32
/// The result of running a single tool call, as returned by the `run_tools*` functions.
#[derive(Debug, Clone)]
pub struct ToolMessageUpdate {
    /// The `MessageRole::Tool` message carrying the tool's output or its error text.
    pub message: Option<Message>,
    /// The tool context after this call, when the runner reports one. The serial runner
    /// fills this in; the concurrent runner leaves it `None`. `run_tools` carries contexts
    /// reported by serial batches forward to later batches.
    pub new_context: Option<crate::types::ToolContext>,
}
41
42pub fn partition_tool_calls(tool_calls: &[ToolCall], tools: &[ToolDefinition]) -> Vec<ToolBatch> {
46 let mut batches: Vec<ToolBatch> = Vec::new();
47
48 for tool_use in tool_calls {
49 let tool = tools.iter().find(|t| t.name == tool_use.name);
51
52 let is_concurrency_safe = tool
56 .map(|t| t.is_concurrency_safe(&tool_use.arguments))
57 .unwrap_or(false);
58
59 if is_concurrency_safe {
61 if let Some(last) = batches.last_mut() {
62 if last.is_concurrency_safe {
63 last.blocks.push(tool_use.clone());
65 continue;
66 }
67 }
68 }
69
70 batches.push(ToolBatch {
72 is_concurrency_safe,
73 blocks: vec![tool_use.clone()],
74 });
75 }
76
77 batches
78}
79
/// Removes `tool_use_id` from the set of tool calls that are still in progress.
///
/// Ids that are not present are ignored, so completing the same id twice is harmless.
pub fn mark_tool_use_as_complete(
    in_progress_ids: &mut std::collections::HashSet<String>,
    tool_use_id: &str,
) {
    // `remove` reports whether the id was present; that answer is not needed here.
    let _was_in_progress = in_progress_ids.remove(tool_use_id);
}
87
88pub async fn run_tools_serially<F, Fut>(
91 tool_calls: Vec<ToolCall>,
92 tool_context: crate::types::ToolContext,
93 mut executor: F,
94) -> Vec<ToolMessageUpdate>
95where
96 F: FnMut(String, serde_json::Value, String) -> Fut + Send,
97 Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
98{
99 let mut updates = Vec::new();
100 let mut current_context = tool_context;
101 let mut in_progress_ids = std::collections::HashSet::new();
102
103 for tool_call in tool_calls {
104 let tool_name = tool_call.name.clone();
105 let tool_args = tool_call.arguments.clone();
106 let tool_call_id = tool_call.id.clone();
107
108 in_progress_ids.insert(tool_call_id.clone());
110
111 match executor(tool_name.clone(), tool_args.clone(), tool_call_id.clone()).await {
113 Ok(result) => {
114 let message = Message {
116 role: MessageRole::Tool,
117 content: result.content,
118 tool_call_id: Some(tool_call_id.clone()),
119 ..Default::default()
120 };
121
122 updates.push(ToolMessageUpdate {
123 message: Some(message),
124 new_context: Some(current_context.clone()),
125 });
126 }
127 Err(e) => {
128 let error_content = format!("<tool_use_error>Error: {}</tool_use_error>", e);
130 let message = Message {
131 role: MessageRole::Tool,
132 content: error_content,
133 tool_call_id: Some(tool_call_id.clone()),
134 is_error: Some(true),
135 ..Default::default()
136 };
137
138 updates.push(ToolMessageUpdate {
139 message: Some(message),
140 new_context: Some(current_context.clone()),
141 });
142 }
143 }
144
145 mark_tool_use_as_complete(&mut in_progress_ids, &tool_call_id);
147 }
148
149 updates
150}
151
152pub async fn run_tools_concurrently<F, Fut>(
155 tool_calls: Vec<ToolCall>,
156 tool_context: crate::types::ToolContext,
157 mut executor: F,
158) -> Vec<ToolMessageUpdate>
159where
160 F: FnMut(String, serde_json::Value, String) -> Fut + Send + Clone + 'static,
161 Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
162{
163 let max_concurrency = get_max_tool_use_concurrency();
164 let mut updates = Vec::new();
165
166 let executions: Vec<_> = tool_calls
168 .into_iter()
169 .map(|tool_call| {
170 let mut exec = executor.clone();
171 let tool_name = tool_call.name.clone();
172 let tool_args = tool_call.arguments.clone();
173 let tool_call_id = tool_call.id.clone();
174
175 async move {
176 let result = exec(tool_name, tool_args, tool_call_id.clone()).await;
177 (tool_call_id, result)
178 }
179 })
180 .collect();
181
182 let mut stream = stream::iter(executions).buffer_unordered(max_concurrency);
184
185 while let Some((tool_call_id, result)) = stream.next().await {
186 match result {
187 Ok(tool_result) => {
188 let message = Message {
189 role: MessageRole::Tool,
190 content: tool_result.content,
191 tool_call_id: Some(tool_call_id),
192 ..Default::default()
193 };
194
195 updates.push(ToolMessageUpdate {
196 message: Some(message),
197 new_context: None,
198 });
199 }
200 Err(e) => {
201 let error_content = format!("<tool_use_error>Error: {}</tool_use_error>", e);
202 let message = Message {
203 role: MessageRole::Tool,
204 content: error_content,
205 tool_call_id: Some(tool_call_id),
206 is_error: Some(true),
207 ..Default::default()
208 };
209
210 updates.push(ToolMessageUpdate {
211 message: Some(message),
212 new_context: None,
213 });
214 }
215 }
216 }
217
218 updates
219}
220
221pub async fn run_tools<F, Fut>(
224 tool_calls: Vec<ToolCall>,
225 tools: Vec<ToolDefinition>,
226 tool_context: crate::types::ToolContext,
227 mut executor: F,
228) -> Vec<ToolMessageUpdate>
229where
230 F: FnMut(String, serde_json::Value, String) -> Fut + Send + Clone + 'static,
231 Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
232{
233 let batches = partition_tool_calls(&tool_calls, &tools);
234 let mut all_updates = Vec::new();
235 let mut current_context = tool_context;
236
237 for batch in batches {
238 if batch.is_concurrency_safe {
239 let updates =
241 run_tools_concurrently(batch.blocks, current_context.clone(), executor.clone())
242 .await;
243 all_updates.extend(updates);
244 } else {
245 let updates =
247 run_tools_serially(batch.blocks, current_context.clone(), executor.clone()).await;
248
249 if let Some(last_update) = updates.last() {
251 if let Some(ctx) = &last_update.new_context {
252 current_context = ctx.clone();
253 }
254 }
255
256 all_updates.extend(updates);
257 }
258 }
259
260 all_updates
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ToolInputSchema;

    /// Builds a minimal tool definition. When `concurrency_safe` is true the tool carries a
    /// `concurrency_safe: Some(true)` annotation; otherwise it has no annotations.
    fn create_test_tool(name: &str, concurrency_safe: bool) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: format!("Test tool {}", name),
            input_schema: ToolInputSchema {
                schema_type: "object".to_string(),
                properties: serde_json::json!({}),
                required: None,
            },
            annotations: if concurrency_safe {
                Some(ToolAnnotations {
                    concurrency_safe: Some(true),
                    ..Default::default()
                })
            } else {
                None
            },
        }
    }

    /// Builds a tool call with empty JSON-object arguments.
    fn create_test_call(id: &str, name: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            name: name.to_string(),
            arguments: serde_json::json!({}),
        }
    }

    /// Executor stub that succeeds for every call and echoes the call id back.
    async fn succeed(
        _name: String,
        _args: serde_json::Value,
        tool_call_id: String,
    ) -> Result<crate::types::ToolResult, AgentError> {
        Ok(crate::types::ToolResult {
            result_type: "tool_result".to_string(),
            tool_use_id: tool_call_id,
            content: "success".to_string(),
            is_error: Some(false),
        })
    }

    /// Collects the tool-call ids of all emitted messages, sorted so that assertions do
    /// not depend on completion order.
    fn sorted_message_ids(updates: &[ToolMessageUpdate]) -> Vec<&str> {
        let mut ids: Vec<&str> = updates
            .iter()
            .filter_map(|update| update.message.as_ref())
            .filter_map(|message| message.tool_call_id.as_deref())
            .collect();
        ids.sort_unstable();
        ids
    }

    #[test]
    fn test_get_max_tool_use_concurrency_default() {
        // The built-in default is only observable when no override is present in the
        // environment; otherwise skip the comparison so the test is environment-independent.
        if std::env::var_os(crate::constants::env::ai::MAX_TOOL_USE_CONCURRENCY).is_none() {
            assert_eq!(get_max_tool_use_concurrency(), MAX_TOOL_USE_CONCURRENCY);
        }
    }

    #[test]
    fn test_get_max_tool_use_concurrency_value() {
        let result = get_max_tool_use_concurrency();
        assert!(result > 0);
    }

    #[test]
    fn test_partition_tool_calls_empty() {
        assert!(partition_tool_calls(&[], &[]).is_empty());
    }

    #[test]
    fn test_partition_tool_calls_all_non_safe() {
        let tool_calls = vec![create_test_call("1", "Bash"), create_test_call("2", "Edit")];
        let tools = vec![
            create_test_tool("Bash", false),
            create_test_tool("Edit", false),
        ];

        let batches = partition_tool_calls(&tool_calls, &tools);
        assert_eq!(batches.len(), 2);
        assert!(!batches[0].is_concurrency_safe);
        assert!(!batches[1].is_concurrency_safe);
    }

    #[test]
    fn test_partition_tool_calls_mixed() {
        let tool_calls = vec![
            create_test_call("1", "Read"),
            create_test_call("2", "Glob"),
            create_test_call("3", "Bash"),
            create_test_call("4", "Grep"),
        ];
        let tools = vec![
            create_test_tool("Read", true),
            create_test_tool("Glob", true),
            create_test_tool("Bash", false),
            create_test_tool("Grep", true),
        ];

        let batches = partition_tool_calls(&tool_calls, &tools);
        // Read + Glob merge, Bash stands alone, Grep opens a new safe batch after it.
        assert_eq!(batches.len(), 3);
        assert!(batches[0].is_concurrency_safe);
        assert_eq!(batches[0].blocks.len(), 2);
        assert!(!batches[1].is_concurrency_safe);
        assert_eq!(batches[1].blocks.len(), 1);
        assert!(batches[2].is_concurrency_safe);
        assert_eq!(batches[2].blocks.len(), 1);

        // Partitioning must not reorder calls.
        let ids: Vec<&str> = batches
            .iter()
            .flat_map(|batch| batch.blocks.iter())
            .map(|call| call.id.as_str())
            .collect();
        assert_eq!(ids, ["1", "2", "3", "4"]);
    }

    #[test]
    fn test_partition_tool_calls_with_unknown_tool() {
        let tool_calls = vec![create_test_call("1", "UnknownTool")];
        let tools = vec![];

        let batches = partition_tool_calls(&tool_calls, &tools);
        assert_eq!(batches.len(), 1);
        assert!(!batches[0].is_concurrency_safe);
    }

    #[tokio::test]
    async fn test_run_tools_serially() {
        let tool_calls = vec![create_test_call("1", "test")];
        let tool_context = crate::types::ToolContext::default();

        let updates = run_tools_serially(tool_calls, tool_context, succeed).await;
        assert_eq!(updates.len(), 1);
        let message = updates[0]
            .message
            .as_ref()
            .expect("a serial run emits one message per call");
        assert_eq!(message.content, "success");
        assert_eq!(message.tool_call_id.as_deref(), Some("1"));
        assert!(updates[0].new_context.is_some());
    }

    #[tokio::test]
    async fn test_run_tools_concurrently() {
        let tool_calls = vec![create_test_call("1", "test1"), create_test_call("2", "test2")];
        let tool_context = crate::types::ToolContext::default();

        let updates = run_tools_concurrently(tool_calls, tool_context, succeed).await;
        assert_eq!(updates.len(), 2);
        assert!(updates.iter().all(|update| update.new_context.is_none()));
        assert_eq!(sorted_message_ids(&updates), ["1", "2"]);
    }

    #[tokio::test]
    async fn test_run_tools_with_partitioning() {
        let tool_calls = vec![
            create_test_call("1", "Read"),
            create_test_call("2", "Glob"),
            create_test_call("3", "Bash"),
        ];
        let tools = vec![
            create_test_tool("Read", true),
            create_test_tool("Glob", true),
            create_test_tool("Bash", false),
        ];
        let tool_context = crate::types::ToolContext::default();

        let updates = run_tools(tool_calls, tools, tool_context, succeed).await;
        assert_eq!(updates.len(), 3);
        assert_eq!(sorted_message_ids(&updates), ["1", "2", "3"]);
    }

    #[test]
    fn test_mark_tool_use_as_complete() {
        let mut in_progress = std::collections::HashSet::new();
        in_progress.insert("tool1".to_string());
        in_progress.insert("tool2".to_string());

        mark_tool_use_as_complete(&mut in_progress, "tool1");

        assert!(!in_progress.contains("tool1"));
        assert!(in_progress.contains("tool2"));
    }
}