codetether_agent/tool/
batch.rs1use anyhow::{Context, Result};
4use async_trait::async_trait;
5use serde::Deserialize;
6use serde_json::{json, Value};
7use super::{Tool, ToolResult, ToolRegistry};
8use std::sync::{Arc, RwLock, Weak};
9
10pub struct BatchTool {
13 registry: Arc<RwLock<Option<Weak<ToolRegistry>>>>,
14}
15
16impl BatchTool {
17 pub fn new() -> Self {
18 Self {
19 registry: Arc::new(RwLock::new(None)),
20 }
21 }
22
23 pub fn set_registry(&self, registry: Weak<ToolRegistry>) {
25 let mut guard = self.registry.write().unwrap();
26 *guard = Some(registry);
27 }
28}
29
30#[derive(Deserialize)]
31struct Params {
32 calls: Vec<BatchCall>,
33}
34
35#[derive(Deserialize)]
36struct BatchCall {
37 tool: String,
38 args: Value,
39}
40
41#[async_trait]
42impl Tool for BatchTool {
43 fn id(&self) -> &str { "batch" }
44 fn name(&self) -> &str { "Batch Execute" }
45 fn description(&self) -> &str { "Execute multiple tool calls in parallel. Each call specifies a tool name and arguments." }
46 fn parameters(&self) -> Value {
47 json!({
48 "type": "object",
49 "properties": {
50 "calls": {
51 "type": "array",
52 "description": "Array of tool calls to execute",
53 "items": {
54 "type": "object",
55 "properties": {
56 "tool": {"type": "string", "description": "Tool ID to call"},
57 "args": {"type": "object", "description": "Arguments for the tool"}
58 },
59 "required": ["tool", "args"]
60 }
61 }
62 },
63 "required": ["calls"]
64 })
65 }
66
67 async fn execute(&self, params: Value) -> Result<ToolResult> {
68 let p: Params = serde_json::from_value(params).context("Invalid params")?;
69
70 if p.calls.is_empty() {
71 return Ok(ToolResult::error("No calls provided"));
72 }
73
74 let registry = {
76 let guard = self.registry.read().unwrap();
77 match guard.as_ref() {
78 Some(weak) => match weak.upgrade() {
79 Some(arc) => arc,
80 None => return Ok(ToolResult::error("Registry no longer available")),
81 },
82 None => return Ok(ToolResult::error("Registry not initialized")),
83 }
84 };
85
86 let futures: Vec<_> = p.calls.iter().enumerate().map(|(i, call)| {
88 let tool_id = call.tool.clone();
89 let args = call.args.clone();
90 let registry = Arc::clone(®istry);
91
92 async move {
93 if tool_id == "batch" {
95 return (i, tool_id, ToolResult::error("Cannot call batch from within batch"));
96 }
97
98 match registry.get(&tool_id) {
99 Some(tool) => {
100 match tool.execute(args).await {
101 Ok(result) => (i, tool_id, result),
102 Err(e) => (i, tool_id, ToolResult::error(format!("Error: {}", e))),
103 }
104 }
105 None => {
106 let available_tools = registry.list().iter().map(|s| s.to_string()).collect();
108 let invalid_tool = super::invalid::InvalidTool::with_context(tool_id.clone(), available_tools);
109 let invalid_args = serde_json::json!({
110 "requested_tool": tool_id,
111 "args": args
112 });
113 match invalid_tool.execute(invalid_args).await {
114 Ok(result) => (i, tool_id.clone(), result),
115 Err(e) => (i, tool_id.clone(), ToolResult::error(format!("Unknown tool: {}. Error: {}", tool_id, e))),
116 }
117 }
118 }
119 }
120 }).collect();
121
122 let results = futures::future::join_all(futures).await;
123
124 let mut output_parts = Vec::new();
125 let mut success_count = 0;
126 let mut error_count = 0;
127
128 for (idx, tool_id, result) in results {
129 if result.success {
130 success_count += 1;
131 output_parts.push(format!("[{}] ✓ {}:\n{}", idx + 1, tool_id, result.output));
132 } else {
133 error_count += 1;
134 output_parts.push(format!("[{}] ✗ {}:\n{}", idx + 1, tool_id, result.output));
135 }
136 }
137
138 let summary = format!("Batch complete: {} succeeded, {} failed\n\n{}",
139 success_count, error_count, output_parts.join("\n\n"));
140
141 let overall_success = error_count == 0;
142 if overall_success {
143 Ok(ToolResult::success(summary).with_metadata("success_count", json!(success_count)))
144 } else {
145 Ok(ToolResult::error(summary).with_metadata("error_count", json!(error_count)))
146 }
147 }
148}