Skip to main content

codetether_agent/tool/
batch.rs

1//! Batch Tool - Execute multiple tool calls in parallel.
2
3use anyhow::{Context, Result};
4use async_trait::async_trait;
5use serde::Deserialize;
6use serde_json::{json, Value};
7use super::{Tool, ToolResult, ToolRegistry};
8use std::sync::{Arc, RwLock, Weak};
9
10/// BatchTool executes multiple tool calls in parallel.
11/// Uses lazy registry initialization to break circular dependency.
12pub struct BatchTool {
13    registry: Arc<RwLock<Option<Weak<ToolRegistry>>>>,
14}
15
16impl BatchTool {
17    pub fn new() -> Self {
18        Self {
19            registry: Arc::new(RwLock::new(None)),
20        }
21    }
22
23    /// Set the registry after construction to break circular dependency.
24    pub fn set_registry(&self, registry: Weak<ToolRegistry>) {
25        let mut guard = self.registry.write().unwrap();
26        *guard = Some(registry);
27    }
28}
29
30#[derive(Deserialize)]
31struct Params {
32    calls: Vec<BatchCall>,
33}
34
35#[derive(Deserialize)]
36struct BatchCall {
37    tool: String,
38    args: Value,
39}
40
41#[async_trait]
42impl Tool for BatchTool {
43    fn id(&self) -> &str { "batch" }
44    fn name(&self) -> &str { "Batch Execute" }
45    fn description(&self) -> &str { "Execute multiple tool calls in parallel. Each call specifies a tool name and arguments." }
46    fn parameters(&self) -> Value {
47        json!({
48            "type": "object",
49            "properties": {
50                "calls": {
51                    "type": "array",
52                    "description": "Array of tool calls to execute",
53                    "items": {
54                        "type": "object",
55                        "properties": {
56                            "tool": {"type": "string", "description": "Tool ID to call"},
57                            "args": {"type": "object", "description": "Arguments for the tool"}
58                        },
59                        "required": ["tool", "args"]
60                    }
61                }
62            },
63            "required": ["calls"]
64        })
65    }
66
67    async fn execute(&self, params: Value) -> Result<ToolResult> {
68        let p: Params = serde_json::from_value(params).context("Invalid params")?;
69        
70        if p.calls.is_empty() {
71            return Ok(ToolResult::error("No calls provided"));
72        }
73
74        // Get registry from weak reference
75        let registry = {
76            let guard = self.registry.read().unwrap();
77            match guard.as_ref() {
78                Some(weak) => match weak.upgrade() {
79                    Some(arc) => arc,
80                    None => return Ok(ToolResult::error("Registry no longer available")),
81                },
82                None => return Ok(ToolResult::error("Registry not initialized")),
83            }
84        };
85        
86        // Execute all calls in parallel
87        let futures: Vec<_> = p.calls.iter().enumerate().map(|(i, call)| {
88            let tool_id = call.tool.clone();
89            let args = call.args.clone();
90            let registry = Arc::clone(&registry);
91            
92            async move {
93                // Prevent recursive batch calls
94                if tool_id == "batch" {
95                    return (i, tool_id, ToolResult::error("Cannot call batch from within batch"));
96                }
97                
98                match registry.get(&tool_id) {
99                    Some(tool) => {
100                        match tool.execute(args).await {
101                            Ok(result) => (i, tool_id, result),
102                            Err(e) => (i, tool_id, ToolResult::error(format!("Error: {}", e))),
103                        }
104                    }
105                    None => {
106                        // Use the invalid tool handler for better error messages
107                        let available_tools = registry.list().iter().map(|s| s.to_string()).collect();
108                        let invalid_tool = super::invalid::InvalidTool::with_context(tool_id.clone(), available_tools);
109                        let invalid_args = serde_json::json!({
110                            "requested_tool": tool_id,
111                            "args": args
112                        });
113                        match invalid_tool.execute(invalid_args).await {
114                            Ok(result) => (i, tool_id.clone(), result),
115                            Err(e) => (i, tool_id.clone(), ToolResult::error(format!("Unknown tool: {}. Error: {}", tool_id, e))),
116                        }
117                    }
118                }
119            }
120        }).collect();
121        
122        let results = futures::future::join_all(futures).await;
123        
124        let mut output_parts = Vec::new();
125        let mut success_count = 0;
126        let mut error_count = 0;
127        
128        for (idx, tool_id, result) in results {
129            if result.success {
130                success_count += 1;
131                output_parts.push(format!("[{}] ✓ {}:\n{}", idx + 1, tool_id, result.output));
132            } else {
133                error_count += 1;
134                output_parts.push(format!("[{}] ✗ {}:\n{}", idx + 1, tool_id, result.output));
135            }
136        }
137        
138        let summary = format!("Batch complete: {} succeeded, {} failed\n\n{}", 
139            success_count, error_count, output_parts.join("\n\n"));
140        
141        let overall_success = error_count == 0;
142        if overall_success {
143            Ok(ToolResult::success(summary).with_metadata("success_count", json!(success_count)))
144        } else {
145            Ok(ToolResult::error(summary).with_metadata("error_count", json!(error_count)))
146        }
147    }
148}