Skip to main content

ai_agent/tools/
orchestration.rs

1// Source: /data/home/swei/claudecode/openclaudecode/src/services/tools/toolOrchestration.ts
2//! Tool orchestration module for running tools with concurrency control.
3//!
4//! Translated from TypeScript toolOrchestration.ts
5
6use crate::constants::env::ai;
7use crate::types::{
8    Message, MessageRole, ToolAnnotations, ToolCall, ToolDefinition, ToolInputSchema,
9};
10use crate::AgentError;
11use futures_util::stream::{self, StreamExt};
12
/// Default cap on how many tool calls may execute at the same time.
///
/// This is the fallback used by `get_max_tool_use_concurrency` when the
/// environment override is absent or cannot be parsed. It mirrors the
/// TypeScript default.
pub const MAX_TOOL_USE_CONCURRENCY: usize = 10;
15
16/// Get max tool use concurrency from environment variable
17pub fn get_max_tool_use_concurrency() -> usize {
18    std::env::var(ai::MAX_TOOL_USE_CONCURRENCY)
19        .ok()
20        .and_then(|v| v.parse::<usize>().ok())
21        .unwrap_or(MAX_TOOL_USE_CONCURRENCY)
22}
23
24/// A batch of tool calls that can be executed together
25#[derive(Debug, Clone)]
26pub struct ToolBatch {
27    /// Whether this batch is concurrency-safe (can run in parallel)
28    pub is_concurrency_safe: bool,
29    /// The tool calls in this batch
30    pub blocks: Vec<ToolCall>,
31}
32
33/// Message update for tool orchestration
34#[derive(Debug, Clone)]
35pub struct ToolMessageUpdate {
36    /// The message to add to the conversation
37    pub message: Option<Message>,
38    /// Updated context after this tool (for serial execution)
39    pub new_context: Option<crate::types::ToolContext>,
40}
41
42/// Partition tool calls into batches where each batch is either:
43/// 1. A single non-concurrency-safe tool, or
44/// 2. Multiple consecutive concurrency-safe tools
45pub fn partition_tool_calls(tool_calls: &[ToolCall], tools: &[ToolDefinition]) -> Vec<ToolBatch> {
46    let mut batches: Vec<ToolBatch> = Vec::new();
47
48    for tool_use in tool_calls {
49        // Find the tool definition
50        let tool = tools.iter().find(|t| t.name == tool_use.name);
51
52        // Check concurrency safety
53        // Matches TypeScript: use the tool's isConcurrencySafe method
54        // If tool not found or isConcurrencySafe throws, treat as not concurrency-safe
55        let is_concurrency_safe = tool
56            .map(|t| t.is_concurrency_safe(&tool_use.arguments))
57            .unwrap_or(false);
58
59        // Check if we can add to the last batch
60        if is_concurrency_safe {
61            if let Some(last) = batches.last_mut() {
62                if last.is_concurrency_safe {
63                    // Add to existing concurrency-safe batch
64                    last.blocks.push(tool_use.clone());
65                    continue;
66                }
67            }
68        }
69
70        // Create new batch (either non-concurrency-safe or first in a concurrency-safe group)
71        batches.push(ToolBatch {
72            is_concurrency_safe,
73            blocks: vec![tool_use.clone()],
74        });
75    }
76
77    batches
78}
79
/// Remove `tool_use_id` from the set of tool uses that are still in progress.
///
/// Removing an ID that is not in the set leaves the set unchanged.
pub fn mark_tool_use_as_complete(
    in_progress_ids: &mut std::collections::HashSet<String>,
    tool_use_id: &str,
) {
    // `remove` reports whether the ID was present; that answer is not needed here.
    let _was_in_progress = in_progress_ids.remove(tool_use_id);
}
87
88/// Run tools serially (for non-concurrency-safe tools)
89/// This matches TypeScript's runToolsSerially function
90pub async fn run_tools_serially<F, Fut>(
91    tool_calls: Vec<ToolCall>,
92    tool_context: crate::types::ToolContext,
93    mut executor: F,
94) -> Vec<ToolMessageUpdate>
95where
96    F: FnMut(String, serde_json::Value, String) -> Fut + Send,
97    Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
98{
99    let mut updates = Vec::new();
100    let mut current_context = tool_context;
101    let mut in_progress_ids = std::collections::HashSet::new();
102
103    for tool_call in tool_calls {
104        let tool_name = tool_call.name.clone();
105        let tool_args = tool_call.arguments.clone();
106        let tool_call_id = tool_call.id.clone();
107
108        // Mark tool as in progress
109        in_progress_ids.insert(tool_call_id.clone());
110
111        // Execute the tool (pass tool_call_id)
112        match executor(tool_name.clone(), tool_args.clone(), tool_call_id.clone()).await {
113            Ok(result) => {
114                // Create tool result message
115                let message = Message {
116                    role: MessageRole::Tool,
117                    content: result.content,
118                    tool_call_id: Some(tool_call_id.clone()),
119                    ..Default::default()
120                };
121
122                updates.push(ToolMessageUpdate {
123                    message: Some(message),
124                    new_context: Some(current_context.clone()),
125                });
126            }
127            Err(e) => {
128                // Add error as tool result message
129                let error_content = format!("<tool_use_error>Error: {}</tool_use_error>", e);
130                let message = Message {
131                    role: MessageRole::Tool,
132                    content: error_content,
133                    tool_call_id: Some(tool_call_id.clone()),
134                    is_error: Some(true),
135                    ..Default::default()
136                };
137
138                updates.push(ToolMessageUpdate {
139                    message: Some(message),
140                    new_context: Some(current_context.clone()),
141                });
142            }
143        }
144
145        // Mark tool as complete
146        mark_tool_use_as_complete(&mut in_progress_ids, &tool_call_id);
147    }
148
149    updates
150}
151
152/// Run tools concurrently (for concurrency-safe tools)
153/// Uses the all() generator pattern from TypeScript with concurrency limit
154pub async fn run_tools_concurrently<F, Fut>(
155    tool_calls: Vec<ToolCall>,
156    tool_context: crate::types::ToolContext,
157    mut executor: F,
158) -> Vec<ToolMessageUpdate>
159where
160    F: FnMut(String, serde_json::Value, String) -> Fut + Send + Clone + 'static,
161    Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
162{
163    let max_concurrency = get_max_tool_use_concurrency();
164    let mut updates = Vec::new();
165
166    // Create a stream of tool executions
167    let executions: Vec<_> = tool_calls
168        .into_iter()
169        .map(|tool_call| {
170            let mut exec = executor.clone();
171            let tool_name = tool_call.name.clone();
172            let tool_args = tool_call.arguments.clone();
173            let tool_call_id = tool_call.id.clone();
174
175            async move {
176                let result = exec(tool_name, tool_args, tool_call_id.clone()).await;
177                (tool_call_id, result)
178            }
179        })
180        .collect();
181
182    // Run with bounded concurrency using buffer_unordered
183    let mut stream = stream::iter(executions).buffer_unordered(max_concurrency);
184
185    while let Some((tool_call_id, result)) = stream.next().await {
186        match result {
187            Ok(tool_result) => {
188                let message = Message {
189                    role: MessageRole::Tool,
190                    content: tool_result.content,
191                    tool_call_id: Some(tool_call_id),
192                    ..Default::default()
193                };
194
195                updates.push(ToolMessageUpdate {
196                    message: Some(message),
197                    new_context: None,
198                });
199            }
200            Err(e) => {
201                let error_content = format!("<tool_use_error>Error: {}</tool_use_error>", e);
202                let message = Message {
203                    role: MessageRole::Tool,
204                    content: error_content,
205                    tool_call_id: Some(tool_call_id),
206                    is_error: Some(true),
207                    ..Default::default()
208                };
209
210                updates.push(ToolMessageUpdate {
211                    message: Some(message),
212                    new_context: None,
213                });
214            }
215        }
216    }
217
218    updates
219}
220
221/// Run all tools with proper partitioning and concurrency
222/// This is the main entry point that matches TypeScript's runTools()
223pub async fn run_tools<F, Fut>(
224    tool_calls: Vec<ToolCall>,
225    tools: Vec<ToolDefinition>,
226    tool_context: crate::types::ToolContext,
227    mut executor: F,
228) -> Vec<ToolMessageUpdate>
229where
230    F: FnMut(String, serde_json::Value, String) -> Fut + Send + Clone + 'static,
231    Fut: Future<Output = Result<crate::types::ToolResult, AgentError>> + Send,
232{
233    let batches = partition_tool_calls(&tool_calls, &tools);
234    let mut all_updates = Vec::new();
235    let mut current_context = tool_context;
236
237    for batch in batches {
238        if batch.is_concurrency_safe {
239            // Run concurrency-safe batch concurrently
240            let updates =
241                run_tools_concurrently(batch.blocks, current_context.clone(), executor.clone())
242                    .await;
243            all_updates.extend(updates);
244        } else {
245            // Run non-concurrency-safe batch serially
246            let updates =
247                run_tools_serially(batch.blocks, current_context.clone(), executor.clone()).await;
248
249            // Update context after serial execution
250            if let Some(last_update) = updates.last() {
251                if let Some(ctx) = &last_update.new_context {
252                    current_context = ctx.clone();
253                }
254            }
255
256            all_updates.extend(updates);
257        }
258    }
259
260    all_updates
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ToolInputSchema;

    /// Build a minimal tool definition named `name`.
    ///
    /// When `concurrency_safe` is true, its annotations set
    /// `concurrency_safe: Some(true)`; otherwise annotations are `None`.
    fn create_test_tool(name: &str, concurrency_safe: bool) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: format!("Test tool {}", name),
            input_schema: ToolInputSchema {
                schema_type: "object".to_string(),
                properties: serde_json::json!({}),
                required: None,
            },
            annotations: if concurrency_safe {
                Some(ToolAnnotations {
                    concurrency_safe: Some(true),
                    ..Default::default()
                })
            } else {
                None
            },
        }
    }

    #[test]
    fn test_get_max_tool_use_concurrency_default() {
        // The default is only observable when no override is set. The test
        // process can inherit the variable from the shell or CI environment,
        // in which case there is nothing meaningful to assert here.
        if std::env::var_os(ai::MAX_TOOL_USE_CONCURRENCY).is_none() {
            assert_eq!(get_max_tool_use_concurrency(), MAX_TOOL_USE_CONCURRENCY);
        }
    }

    #[test]
    fn test_get_max_tool_use_concurrency_value() {
        // Just test that the function returns a value
        let result = get_max_tool_use_concurrency();
        assert!(result > 0);
    }

    #[test]
    fn test_partition_tool_calls_all_non_safe() {
        let tool_calls = vec![
            ToolCall {
                id: "1".to_string(),
                name: "Bash".to_string(),
                arguments: serde_json::json!({}),
            },
            ToolCall {
                id: "2".to_string(),
                name: "Edit".to_string(),
                arguments: serde_json::json!({}),
            },
        ];
        let tools = vec![
            create_test_tool("Bash", false),
            create_test_tool("Edit", false),
        ];

        let batches = partition_tool_calls(&tool_calls, &tools);
        assert_eq!(batches.len(), 2);
        assert!(!batches[0].is_concurrency_safe);
        assert!(!batches[1].is_concurrency_safe);
    }

    #[test]
    fn test_partition_tool_calls_mixed() {
        let tool_calls = vec![
            ToolCall {
                id: "1".to_string(),
                name: "Read".to_string(),
                arguments: serde_json::json!({}),
            },
            ToolCall {
                id: "2".to_string(),
                name: "Glob".to_string(),
                arguments: serde_json::json!({}),
            },
            ToolCall {
                id: "3".to_string(),
                name: "Bash".to_string(),
                arguments: serde_json::json!({}),
            },
            ToolCall {
                id: "4".to_string(),
                name: "Grep".to_string(),
                arguments: serde_json::json!({}),
            },
        ];
        let tools = vec![
            create_test_tool("Read", true),
            create_test_tool("Glob", true),
            create_test_tool("Bash", false),
            create_test_tool("Grep", true),
        ];

        let batches = partition_tool_calls(&tool_calls, &tools);
        // Should be: [Read,Glob] (concurrency safe), [Bash] (non-safe), [Grep] (concurrency safe)
        assert_eq!(batches.len(), 3);
        assert!(batches[0].is_concurrency_safe);
        assert_eq!(batches[0].blocks.len(), 2);
        assert!(!batches[1].is_concurrency_safe);
        assert!(batches[2].is_concurrency_safe);
    }

    #[test]
    fn test_partition_tool_calls_with_unknown_tool() {
        let tool_calls = vec![ToolCall {
            id: "1".to_string(),
            name: "UnknownTool".to_string(),
            arguments: serde_json::json!({}),
        }];
        let tools = vec![];

        let batches = partition_tool_calls(&tool_calls, &tools);
        assert_eq!(batches.len(), 1);
        // Unknown tools should be treated as not concurrency-safe
        assert!(!batches[0].is_concurrency_safe);
    }

    #[tokio::test]
    async fn test_run_tools_serially() {
        let tool_calls = vec![ToolCall {
            id: "1".to_string(),
            name: "test".to_string(),
            arguments: serde_json::json!({}),
        }];

        let tool_context = crate::types::ToolContext::default();

        let executor = |_name: String, _args: serde_json::Value, _tool_call_id: String| async {
            Ok(crate::types::ToolResult {
                result_type: "tool_result".to_string(),
                tool_use_id: "1".to_string(),
                content: "success".to_string(),
                is_error: Some(false),
            })
        };

        let updates = run_tools_serially(tool_calls, tool_context, executor).await;
        assert_eq!(updates.len(), 1);
        assert!(updates[0].message.is_some());
    }

    #[tokio::test]
    async fn test_run_tools_concurrently() {
        let tool_calls = vec![
            ToolCall {
                id: "1".to_string(),
                name: "test1".to_string(),
                arguments: serde_json::json!({}),
            },
            ToolCall {
                id: "2".to_string(),
                name: "test2".to_string(),
                arguments: serde_json::json!({}),
            },
        ];

        let tool_context = crate::types::ToolContext::default();

        let executor = |_name: String, _args: serde_json::Value, _tool_call_id: String| async {
            Ok(crate::types::ToolResult {
                result_type: "tool_result".to_string(),
                tool_use_id: "1".to_string(),
                content: "success".to_string(),
                is_error: Some(false),
            })
        };

        let updates = run_tools_concurrently(tool_calls, tool_context, executor).await;
        assert_eq!(updates.len(), 2);
    }

    #[tokio::test]
    async fn test_run_tools_with_partitioning() {
        let tool_calls = vec![
            ToolCall {
                id: "1".to_string(),
                name: "Read".to_string(),
                arguments: serde_json::json!({}),
            },
            ToolCall {
                id: "2".to_string(),
                name: "Glob".to_string(),
                arguments: serde_json::json!({}),
            },
            ToolCall {
                id: "3".to_string(),
                name: "Bash".to_string(),
                arguments: serde_json::json!({}),
            },
        ];
        let tools = vec![
            create_test_tool("Read", true),
            create_test_tool("Glob", true),
            create_test_tool("Bash", false),
        ];

        let tool_context = crate::types::ToolContext::default();

        let executor = |_name: String, _args: serde_json::Value, _tool_call_id: String| async {
            Ok(crate::types::ToolResult {
                result_type: "tool_result".to_string(),
                tool_use_id: "1".to_string(),
                content: "success".to_string(),
                is_error: Some(false),
            })
        };

        let updates = run_tools(tool_calls, tools, tool_context, executor).await;
        assert_eq!(updates.len(), 3);
    }

    #[test]
    fn test_mark_tool_use_as_complete() {
        let mut in_progress = std::collections::HashSet::new();
        in_progress.insert("tool1".to_string());
        in_progress.insert("tool2".to_string());

        mark_tool_use_as_complete(&mut in_progress, "tool1");

        assert!(!in_progress.contains("tool1"));
        assert!(in_progress.contains("tool2"));
    }
}