Skip to main content

codetether_agent/tool/
batch.rs

1//! Batch Tool - Execute multiple tool calls in parallel.
2
3use super::{Tool, ToolRegistry, ToolResult};
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use serde::Deserialize;
7use serde_json::{Value, json};
8use std::sync::{Arc, RwLock, Weak};
9
/// BatchTool executes multiple tool calls in parallel.
/// Uses lazy registry initialization to break circular dependency.
///
/// The registry is not supplied at construction time. It is attached later via
/// `set_registry` and stored only as a `Weak` handle. `execute` reports an
/// error result if it was never attached or has since been dropped.
pub struct BatchTool {
    // `None` until `set_registry` is called; the lock lets `set_registry` write
    // through `&self`.
    registry: Arc<RwLock<Option<Weak<ToolRegistry>>>>,
}
15
16impl BatchTool {
17    pub fn new() -> Self {
18        Self {
19            registry: Arc::new(RwLock::new(None)),
20        }
21    }
22
23    /// Set the registry after construction to break circular dependency.
24    pub fn set_registry(&self, registry: Weak<ToolRegistry>) {
25        let mut guard = self.registry.write().unwrap();
26        *guard = Some(registry);
27    }
28}
29
/// Top-level arguments accepted by the `batch` tool; mirrors the schema
/// returned by `parameters`.
#[derive(Deserialize)]
struct Params {
    // Calls to run. An empty list deserializes fine but is rejected by `execute`
    // with a "No calls provided" error result.
    calls: Vec<BatchCall>,
}
34
/// One entry of `Params::calls`: a registry tool id plus the JSON arguments for it.
#[derive(Deserialize)]
struct BatchCall {
    // Registry id of the tool to invoke; the id "batch" itself is refused to
    // prevent recursion.
    tool: String,
    // Arguments passed to the target tool's `execute`; for an unknown tool id they
    // are wrapped under "args" and handed to the invalid-tool handler instead.
    args: Value,
}
40
41#[async_trait]
42impl Tool for BatchTool {
43    fn id(&self) -> &str {
44        "batch"
45    }
46    fn name(&self) -> &str {
47        "Batch Execute"
48    }
49    fn description(&self) -> &str {
50        "Execute multiple tool calls in parallel. Each call specifies a tool name and arguments."
51    }
52    fn parameters(&self) -> Value {
53        json!({
54            "type": "object",
55            "properties": {
56                "calls": {
57                    "type": "array",
58                    "description": "Array of tool calls to execute",
59                    "items": {
60                        "type": "object",
61                        "properties": {
62                            "tool": {"type": "string", "description": "Tool ID to call"},
63                            "args": {"type": "object", "description": "Arguments for the tool"}
64                        },
65                        "required": ["tool", "args"]
66                    }
67                }
68            },
69            "required": ["calls"]
70        })
71    }
72
73    async fn execute(&self, params: Value) -> Result<ToolResult> {
74        let p: Params = serde_json::from_value(params).context("Invalid params")?;
75
76        if p.calls.is_empty() {
77            return Ok(ToolResult::error("No calls provided"));
78        }
79
80        // Get registry from weak reference
81        let registry = {
82            let guard = self.registry.read().unwrap();
83            match guard.as_ref() {
84                Some(weak) => match weak.upgrade() {
85                    Some(arc) => arc,
86                    None => return Ok(ToolResult::error("Registry no longer available")),
87                },
88                None => return Ok(ToolResult::error("Registry not initialized")),
89            }
90        };
91
92        // Execute all calls in parallel
93        let futures: Vec<_> = p
94            .calls
95            .iter()
96            .enumerate()
97            .map(|(i, call)| {
98                let tool_id = call.tool.clone();
99                let args = call.args.clone();
100                let registry = Arc::clone(&registry);
101
102                async move {
103                    // Prevent recursive batch calls
104                    if tool_id == "batch" {
105                        return (
106                            i,
107                            tool_id,
108                            ToolResult::error("Cannot call batch from within batch"),
109                        );
110                    }
111
112                    match registry.get(&tool_id) {
113                        Some(tool) => match tool.execute(args).await {
114                            Ok(result) => (i, tool_id, result),
115                            Err(e) => (i, tool_id, ToolResult::error(format!("Error: {}", e))),
116                        },
117                        None => {
118                            // Use the invalid tool handler for better error messages
119                            let available_tools =
120                                registry.list().iter().map(|s| s.to_string()).collect();
121                            let invalid_tool = super::invalid::InvalidTool::with_context(
122                                tool_id.clone(),
123                                available_tools,
124                            );
125                            let invalid_args = serde_json::json!({
126                                "requested_tool": tool_id,
127                                "args": args
128                            });
129                            match invalid_tool.execute(invalid_args).await {
130                                Ok(result) => (i, tool_id.clone(), result),
131                                Err(e) => (
132                                    i,
133                                    tool_id.clone(),
134                                    ToolResult::error(format!(
135                                        "Unknown tool: {}. Error: {}",
136                                        tool_id, e
137                                    )),
138                                ),
139                            }
140                        }
141                    }
142                }
143            })
144            .collect();
145
146        let results = futures::future::join_all(futures).await;
147
148        let mut output_parts = Vec::new();
149        let mut success_count = 0;
150        let mut error_count = 0;
151
152        for (idx, tool_id, result) in results {
153            if result.success {
154                success_count += 1;
155                output_parts.push(format!("[{}] ✓ {}:\n{}", idx + 1, tool_id, result.output));
156            } else {
157                error_count += 1;
158                output_parts.push(format!("[{}] ✗ {}:\n{}", idx + 1, tool_id, result.output));
159            }
160        }
161
162        let summary = format!(
163            "Batch complete: {} succeeded, {} failed\n\n{}",
164            success_count,
165            error_count,
166            output_parts.join("\n\n")
167        );
168
169        let overall_success = error_count == 0;
170        if overall_success {
171            Ok(ToolResult::success(summary).with_metadata("success_count", json!(success_count)))
172        } else {
173            Ok(ToolResult::error(summary).with_metadata("error_count", json!(error_count)))
174        }
175    }
176}