// codetether_agent/tool/batch.rs

use std::sync::{Arc, PoisonError, RwLock, Weak};

use anyhow::{Context, Result};
use async_trait::async_trait;
use serde::Deserialize;
use serde_json::{Value, json};

use super::{Tool, ToolRegistry, ToolResult};
10pub struct BatchTool {
13 registry: Arc<RwLock<Option<Weak<ToolRegistry>>>>,
14}
15
16impl BatchTool {
17 pub fn new() -> Self {
18 Self {
19 registry: Arc::new(RwLock::new(None)),
20 }
21 }
22
23 pub fn set_registry(&self, registry: Weak<ToolRegistry>) {
25 let mut guard = self.registry.write().unwrap();
26 *guard = Some(registry);
27 }
28}
29
30#[derive(Deserialize)]
31struct Params {
32 calls: Vec<BatchCall>,
33}
34
35#[derive(Deserialize)]
36struct BatchCall {
37 tool: String,
38 args: Value,
39}
40
41#[async_trait]
42impl Tool for BatchTool {
43 fn id(&self) -> &str {
44 "batch"
45 }
46 fn name(&self) -> &str {
47 "Batch Execute"
48 }
49 fn description(&self) -> &str {
50 "Execute multiple tool calls in parallel. Each call specifies a tool name and arguments."
51 }
52 fn parameters(&self) -> Value {
53 json!({
54 "type": "object",
55 "properties": {
56 "calls": {
57 "type": "array",
58 "description": "Array of tool calls to execute",
59 "items": {
60 "type": "object",
61 "properties": {
62 "tool": {"type": "string", "description": "Tool ID to call"},
63 "args": {"type": "object", "description": "Arguments for the tool"}
64 },
65 "required": ["tool", "args"]
66 }
67 }
68 },
69 "required": ["calls"]
70 })
71 }
72
73 async fn execute(&self, params: Value) -> Result<ToolResult> {
74 let p: Params = serde_json::from_value(params).context("Invalid params")?;
75
76 if p.calls.is_empty() {
77 return Ok(ToolResult::error("No calls provided"));
78 }
79
80 let registry = {
82 let guard = self.registry.read().unwrap();
83 match guard.as_ref() {
84 Some(weak) => match weak.upgrade() {
85 Some(arc) => arc,
86 None => return Ok(ToolResult::error("Registry no longer available")),
87 },
88 None => return Ok(ToolResult::error("Registry not initialized")),
89 }
90 };
91
92 let futures: Vec<_> = p
94 .calls
95 .iter()
96 .enumerate()
97 .map(|(i, call)| {
98 let tool_id = call.tool.clone();
99 let args = call.args.clone();
100 let registry = Arc::clone(®istry);
101
102 async move {
103 if tool_id == "batch" {
105 return (
106 i,
107 tool_id,
108 ToolResult::error("Cannot call batch from within batch"),
109 );
110 }
111
112 match registry.get(&tool_id) {
113 Some(tool) => match tool.execute(args).await {
114 Ok(result) => (i, tool_id, result),
115 Err(e) => (i, tool_id, ToolResult::error(format!("Error: {}", e))),
116 },
117 None => {
118 let available_tools =
120 registry.list().iter().map(|s| s.to_string()).collect();
121 let invalid_tool = super::invalid::InvalidTool::with_context(
122 tool_id.clone(),
123 available_tools,
124 );
125 let invalid_args = serde_json::json!({
126 "requested_tool": tool_id,
127 "args": args
128 });
129 match invalid_tool.execute(invalid_args).await {
130 Ok(result) => (i, tool_id.clone(), result),
131 Err(e) => (
132 i,
133 tool_id.clone(),
134 ToolResult::error(format!(
135 "Unknown tool: {}. Error: {}",
136 tool_id, e
137 )),
138 ),
139 }
140 }
141 }
142 }
143 })
144 .collect();
145
146 let results = futures::future::join_all(futures).await;
147
148 let mut output_parts = Vec::new();
149 let mut success_count = 0;
150 let mut error_count = 0;
151
152 for (idx, tool_id, result) in results {
153 if result.success {
154 success_count += 1;
155 output_parts.push(format!("[{}] ✓ {}:\n{}", idx + 1, tool_id, result.output));
156 } else {
157 error_count += 1;
158 output_parts.push(format!("[{}] ✗ {}:\n{}", idx + 1, tool_id, result.output));
159 }
160 }
161
162 let summary = format!(
163 "Batch complete: {} succeeded, {} failed\n\n{}",
164 success_count,
165 error_count,
166 output_parts.join("\n\n")
167 );
168
169 let overall_success = error_count == 0;
170 if overall_success {
171 Ok(ToolResult::success(summary).with_metadata("success_count", json!(success_count)))
172 } else {
173 Ok(ToolResult::error(summary).with_metadata("error_count", json!(error_count)))
174 }
175 }
176}