codetether_agent/tool/
batch.rs1use super::{Tool, ToolRegistry, ToolResult};
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use serde::Deserialize;
7use serde_json::{Value, json};
8use std::sync::{Arc, RwLock, Weak};
9
10pub struct BatchTool {
13 registry: Arc<RwLock<Option<Weak<ToolRegistry>>>>,
14}
15
16impl Default for BatchTool {
17 fn default() -> Self {
18 Self::new()
19 }
20}
21
22impl BatchTool {
23 pub fn new() -> Self {
24 Self {
25 registry: Arc::new(RwLock::new(None)),
26 }
27 }
28
29 pub fn set_registry(&self, registry: Weak<ToolRegistry>) {
31 let mut guard = self.registry.write().unwrap();
32 *guard = Some(registry);
33 }
34}
35
36#[derive(Deserialize)]
37struct Params {
38 calls: Vec<BatchCall>,
39}
40
41#[derive(Deserialize)]
42struct BatchCall {
43 #[serde(alias = "name")]
44 tool: String,
45 #[serde(default, alias = "arguments", alias = "params")]
46 args: Value,
47}
48
49#[async_trait]
50impl Tool for BatchTool {
51 fn id(&self) -> &str {
52 "batch"
53 }
54 fn name(&self) -> &str {
55 "Batch Execute"
56 }
57 fn description(&self) -> &str {
58 "Execute multiple tool calls in parallel. Each call specifies a tool name and arguments."
59 }
60 fn parameters(&self) -> Value {
61 json!({
62 "type": "object",
63 "properties": {
64 "calls": {
65 "type": "array",
66 "description": "Array of tool calls to execute. Preferred keys are `tool` + `args`; aliases `name` + `arguments` are also accepted for compatibility.",
67 "items": {
68 "type": "object",
69 "properties": {
70 "tool": {"type": "string", "description": "Tool ID to call (alias: `name`)"},
71 "args": {"type": "object", "description": "Arguments for the tool (alias: `arguments`)"},
72 "name": {"type": "string", "description": "Alias for `tool`"},
73 "arguments": {"type": "object", "description": "Alias for `args`"}
74 },
75 "anyOf": [
76 { "required": ["tool", "args"] },
77 { "required": ["name", "arguments"] }
78 ]
79 }
80 }
81 },
82 "required": ["calls"]
83 })
84 }
85
86 async fn execute(&self, params: Value) -> Result<ToolResult> {
87 let p: Params = serde_json::from_value(params).context("Invalid params")?;
88
89 if p.calls.is_empty() {
90 return Ok(ToolResult::error("No calls provided"));
91 }
92
93 let registry = {
95 let guard = self.registry.read().unwrap();
96 match guard.as_ref() {
97 Some(weak) => match weak.upgrade() {
98 Some(arc) => arc,
99 None => return Ok(ToolResult::error("Registry no longer available")),
100 },
101 None => return Ok(ToolResult::error("Registry not initialized")),
102 }
103 };
104
105 let futures: Vec<_> = p
107 .calls
108 .iter()
109 .enumerate()
110 .map(|(i, call)| {
111 let tool_id = call.tool.clone();
112 let args = call.args.clone();
113 let registry = Arc::clone(®istry);
114
115 async move {
116 if tool_id == "batch" {
118 return (
119 i,
120 tool_id,
121 ToolResult::error("Cannot call batch from within batch"),
122 );
123 }
124
125 match registry.get(&tool_id) {
126 Some(tool) => match tool.execute(args).await {
127 Ok(result) => (i, tool_id, result),
128 Err(e) => (i, tool_id, ToolResult::error(format!("Error: {}", e))),
129 },
130 None => {
131 let available_tools =
133 registry.list().iter().map(|s| s.to_string()).collect();
134 let invalid_tool = super::invalid::InvalidTool::with_context(
135 tool_id.clone(),
136 available_tools,
137 );
138 let invalid_args = serde_json::json!({
139 "requested_tool": tool_id,
140 "args": args
141 });
142 match invalid_tool.execute(invalid_args).await {
143 Ok(result) => (i, tool_id.clone(), result),
144 Err(e) => (
145 i,
146 tool_id.clone(),
147 ToolResult::error(format!(
148 "Unknown tool: {}. Error: {}",
149 tool_id, e
150 )),
151 ),
152 }
153 }
154 }
155 }
156 })
157 .collect();
158
159 let results = futures::future::join_all(futures).await;
160
161 let mut output_parts = Vec::new();
162 let mut success_count = 0;
163 let mut error_count = 0;
164
165 for (idx, tool_id, result) in results {
166 if result.success {
167 success_count += 1;
168 output_parts.push(format!("[{}] ✓ {}:\n{}", idx + 1, tool_id, result.output));
169 } else {
170 error_count += 1;
171 output_parts.push(format!("[{}] ✗ {}:\n{}", idx + 1, tool_id, result.output));
172 }
173 }
174
175 let summary = format!(
176 "Batch complete: {} succeeded, {} failed\n\n{}",
177 success_count,
178 error_count,
179 output_parts.join("\n\n")
180 );
181
182 let overall_success = error_count == 0;
183 if overall_success {
184 Ok(ToolResult::success(summary).with_metadata("success_count", json!(success_count)))
185 } else {
186 Ok(ToolResult::error(summary).with_metadata("error_count", json!(error_count)))
187 }
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::Params;

    /// Asserts the shape both accepted spellings must produce: exactly one `read` call
    /// whose `path` argument is `src/main.rs`.
    fn assert_single_read_of_main(params: &Params) {
        assert_eq!(params.calls.len(), 1);
        let call = &params.calls[0];
        assert_eq!(call.tool, "read");
        assert_eq!(call.args["path"], "src/main.rs");
    }

    #[test]
    fn batch_call_accepts_name_arguments_aliases() {
        // Compatibility spelling: `name` + `arguments`.
        let input = serde_json::json!({
            "calls": [{ "name": "read", "arguments": { "path": "src/main.rs" } }]
        });
        let params: Params = serde_json::from_value(input).expect("should parse alias form");
        assert_single_read_of_main(&params);
    }

    #[test]
    fn batch_call_accepts_tool_args_primary_form() {
        // Preferred spelling: `tool` + `args`.
        let input = serde_json::json!({
            "calls": [{ "tool": "read", "args": { "path": "src/main.rs" } }]
        });
        let params: Params = serde_json::from_value(input).expect("should parse primary form");
        assert_single_read_of_main(&params);
    }
}