Skip to main content

codetether_agent/tool/
batch.rs

1//! Batch Tool - Execute multiple tool calls in parallel.
2
3use super::{Tool, ToolRegistry, ToolResult};
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use serde::Deserialize;
7use serde_json::{Value, json};
8use std::sync::{Arc, RwLock, Weak};
9
10/// BatchTool executes multiple tool calls in parallel.
11/// Uses lazy registry initialization to break circular dependency.
12pub struct BatchTool {
13    registry: Arc<RwLock<Option<Weak<ToolRegistry>>>>,
14}
15
16impl Default for BatchTool {
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
22impl BatchTool {
23    pub fn new() -> Self {
24        Self {
25            registry: Arc::new(RwLock::new(None)),
26        }
27    }
28
29    /// Set the registry after construction to break circular dependency.
30    pub fn set_registry(&self, registry: Weak<ToolRegistry>) {
31        let mut guard = self.registry.write().unwrap();
32        *guard = Some(registry);
33    }
34}
35
36#[derive(Deserialize)]
37struct Params {
38    calls: Vec<BatchCall>,
39}
40
41#[derive(Deserialize)]
42struct BatchCall {
43    #[serde(alias = "name")]
44    tool: String,
45    #[serde(default, alias = "arguments", alias = "params")]
46    args: Value,
47}
48
49#[async_trait]
50impl Tool for BatchTool {
51    fn id(&self) -> &str {
52        "batch"
53    }
54    fn name(&self) -> &str {
55        "Batch Execute"
56    }
57    fn description(&self) -> &str {
58        "Execute multiple tool calls in parallel. Each call specifies a tool name and arguments."
59    }
60    fn parameters(&self) -> Value {
61        json!({
62            "type": "object",
63            "properties": {
64                "calls": {
65                    "type": "array",
66                    "description": "Array of tool calls to execute. Preferred keys are `tool` + `args`; aliases `name` + `arguments` are also accepted for compatibility.",
67                    "items": {
68                        "type": "object",
69                        "properties": {
70                            "tool": {"type": "string", "description": "Tool ID to call (alias: `name`)"},
71                            "args": {"type": "object", "description": "Arguments for the tool (alias: `arguments`)"},
72                            "name": {"type": "string", "description": "Alias for `tool`"},
73                            "arguments": {"type": "object", "description": "Alias for `args`"}
74                        },
75                        "anyOf": [
76                            { "required": ["tool", "args"] },
77                            { "required": ["name", "arguments"] }
78                        ]
79                    }
80                }
81            },
82            "required": ["calls"]
83        })
84    }
85
86    async fn execute(&self, params: Value) -> Result<ToolResult> {
87        let p: Params = serde_json::from_value(params).context("Invalid params")?;
88
89        if p.calls.is_empty() {
90            return Ok(ToolResult::error("No calls provided"));
91        }
92
93        // Get registry from weak reference
94        let registry = {
95            let guard = self.registry.read().unwrap();
96            match guard.as_ref() {
97                Some(weak) => match weak.upgrade() {
98                    Some(arc) => arc,
99                    None => return Ok(ToolResult::error("Registry no longer available")),
100                },
101                None => return Ok(ToolResult::error("Registry not initialized")),
102            }
103        };
104
105        // Execute all calls in parallel
106        let futures: Vec<_> = p
107            .calls
108            .iter()
109            .enumerate()
110            .map(|(i, call)| {
111                let tool_id = call.tool.clone();
112                let args = call.args.clone();
113                let registry = Arc::clone(&registry);
114
115                async move {
116                    // Prevent recursive batch calls
117                    if tool_id == "batch" {
118                        return (
119                            i,
120                            tool_id,
121                            ToolResult::error("Cannot call batch from within batch"),
122                        );
123                    }
124
125                    match registry.get(&tool_id) {
126                        Some(tool) => match tool.execute(args).await {
127                            Ok(result) => (i, tool_id, result),
128                            Err(e) => (i, tool_id, ToolResult::error(format!("Error: {}", e))),
129                        },
130                        None => {
131                            // Use the invalid tool handler for better error messages
132                            let available_tools =
133                                registry.list().iter().map(|s| s.to_string()).collect();
134                            let invalid_tool = super::invalid::InvalidTool::with_context(
135                                tool_id.clone(),
136                                available_tools,
137                            );
138                            let invalid_args = serde_json::json!({
139                                "requested_tool": tool_id,
140                                "args": args
141                            });
142                            match invalid_tool.execute(invalid_args).await {
143                                Ok(result) => (i, tool_id.clone(), result),
144                                Err(e) => (
145                                    i,
146                                    tool_id.clone(),
147                                    ToolResult::error(format!(
148                                        "Unknown tool: {}. Error: {}",
149                                        tool_id, e
150                                    )),
151                                ),
152                            }
153                        }
154                    }
155                }
156            })
157            .collect();
158
159        let results = futures::future::join_all(futures).await;
160
161        let mut output_parts = Vec::new();
162        let mut success_count = 0;
163        let mut error_count = 0;
164
165        for (idx, tool_id, result) in results {
166            if result.success {
167                success_count += 1;
168                output_parts.push(format!("[{}] ✓ {}:\n{}", idx + 1, tool_id, result.output));
169            } else {
170                error_count += 1;
171                output_parts.push(format!("[{}] ✗ {}:\n{}", idx + 1, tool_id, result.output));
172            }
173        }
174
175        let summary = format!(
176            "Batch complete: {} succeeded, {} failed\n\n{}",
177            success_count,
178            error_count,
179            output_parts.join("\n\n")
180        );
181
182        let overall_success = error_count == 0;
183        if overall_success {
184            Ok(ToolResult::success(summary).with_metadata("success_count", json!(success_count)))
185        } else {
186            Ok(ToolResult::error(summary).with_metadata("error_count", json!(error_count)))
187        }
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::Params;
    use serde_json::{Value, json};

    /// Deserializes a batch payload, panicking with `what` if it is rejected.
    fn parse(payload: Value, what: &str) -> Params {
        serde_json::from_value(payload).expect(what)
    }

    /// Checks that `params` holds exactly one `read` call whose `path` is `src/main.rs`.
    fn assert_single_read(params: &Params) {
        assert_eq!(params.calls.len(), 1);
        let only = &params.calls[0];
        assert_eq!(only.tool, "read");
        assert_eq!(only.args["path"], "src/main.rs");
    }

    #[test]
    fn batch_call_accepts_name_arguments_aliases() {
        let payload = json!({
            "calls": [{ "name": "read", "arguments": { "path": "src/main.rs" } }]
        });
        assert_single_read(&parse(payload, "should parse alias form"));
    }

    #[test]
    fn batch_call_accepts_tool_args_primary_form() {
        let payload = json!({
            "calls": [{ "tool": "read", "args": { "path": "src/main.rs" } }]
        });
        assert_single_read(&parse(payload, "should parse primary form"));
    }
}