1use crate::tools::{Tool, ToolRegistry};
11use async_trait::async_trait;
12use futures::future::join_all;
13use serde::de::{self, MapAccess, Visitor};
14use serde::{Deserialize, Deserializer};
15use serde_json::{json, Value};
16use std::path::PathBuf;
17use std::sync::{
18 atomic::{AtomicUsize, Ordering},
19 Arc,
20};
21
/// Maximum number of calls a single `batch` invocation will execute.
/// Calls past this limit are not executed. The tool description spells out the
/// same number literally, so keep the two in sync.
const MAX_BATCH_SIZE: usize = 25;
23
/// One entry of the `calls` array: the tool to run and the JSON arguments to pass it.
///
/// Built by the hand-written `Deserialize` impl for this type, which accepts both a
/// nested (`input` / `parameters`) and a flat argument form.
#[derive(Debug, Clone)]
struct BatchCall {
    /// Name of the tool to dispatch to.
    tool: String,
    /// Arguments forwarded verbatim to the tool.
    input: Value,
}
29
30impl<'de> Deserialize<'de> for BatchCall {
38 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
39 struct V;
40
41 impl<'de> Visitor<'de> for V {
42 type Value = BatchCall;
43
44 fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
45 f.write_str("a batch call with 'tool' and either 'input'/'parameters' or flat args")
46 }
47
48 fn visit_map<M: MapAccess<'de>>(self, mut map: M) -> Result<BatchCall, M::Error> {
49 let mut tool: Option<String> = None;
50 let mut nested: Option<Value> = None;
51 let mut rest = serde_json::Map::new();
52
53 while let Some(key) = map.next_key::<String>()? {
54 match key.as_str() {
55 "tool" => tool = Some(map.next_value()?),
56 "input" | "parameters" => nested = Some(map.next_value()?),
57 _ => {
58 rest.insert(key, map.next_value()?);
59 }
60 }
61 }
62
63 let tool = tool.ok_or_else(|| de::Error::missing_field("tool"))?;
64 let input = match nested {
65 Some(v) if rest.is_empty() => v,
66 Some(Value::Object(mut obj)) => {
67 for (k, v) in rest {
68 if obj.contains_key(&k) {
69 return Err(de::Error::custom(format_args!(
70 "duplicate parameter '{k}' in both nested input and flat fields"
71 )));
72 }
73 obj.insert(k, v);
74 }
75 Value::Object(obj)
76 }
77 Some(_) => {
78 return Err(de::Error::custom(
79 "'input'/'parameters' must be an object when flat fields are also present",
80 ));
81 }
82 None if !rest.is_empty() => Value::Object(rest),
83 None => return Err(de::Error::missing_field("input")),
84 };
85
86 Ok(BatchCall { tool, input })
87 }
88 }
89
90 deserializer.deserialize_map(V)
91 }
92}
93
/// Top-level arguments of the `batch` tool: `{"calls": [...]}`.
///
/// Because this uses the derived `Deserialize`, a single malformed entry in `calls`
/// fails deserialization of the whole argument object.
#[derive(Debug, Clone, Deserialize)]
struct BatchArgs {
    /// Calls to run; only the first `MAX_BATCH_SIZE` are executed.
    calls: Vec<BatchCall>,
}
98
/// Tool that runs a list of other tool calls concurrently and collects their results.
#[derive(Clone)]
pub struct BatchTool {
    /// Workspace directory used to build the registry that executes each sub-call.
    workspace_root: PathBuf,
}

impl BatchTool {
    /// Creates a batch tool whose sub-calls operate inside `workspace_root`.
    pub fn new(workspace_root: PathBuf) -> Self {
        BatchTool { workspace_root }
    }
}
109
110#[async_trait]
111impl Tool for BatchTool {
112 fn name(&self) -> &str {
113 "batch"
114 }
115
116 fn description(&self) -> &str {
117 "Execute up to 25 tool calls concurrently and return an array of results."
118 }
119
120 fn mutating(&self) -> bool {
121 false
122 }
123
124 fn parameters_schema(&self) -> Value {
125 json!({
126 "type": "object",
127 "properties": {
128 "calls": {
129 "type": "array",
130 "description": "Array of tool calls: {tool: string, input: object} or flat {tool: string, ...args}",
131 "items": { "type": "object" }
132 }
133 },
134 "required": ["calls"]
135 })
136 }
137
138 async fn execute(&self, args: Value) -> crate::Result<Value> {
139 let parsed: BatchArgs = serde_json::from_value(args)
140 .map_err(|e| crate::PawanError::Tool(format!("invalid batch args: {e}")))?;
141
142 if parsed.calls.is_empty() {
143 return Ok(Value::Array(vec![]));
144 }
145
146 let active_len = parsed.calls.len().min(MAX_BATCH_SIZE);
147 if parsed.calls.len() > MAX_BATCH_SIZE {
148 tracing::warn!(
149 total = parsed.calls.len(),
150 used = active_len,
151 limit = MAX_BATCH_SIZE,
152 "batch: truncating calls over limit"
153 );
154 }
155
156 let calls = parsed
157 .calls
158 .into_iter()
159 .take(active_len)
160 .collect::<Vec<_>>();
161 let total = calls.len();
162 let completed = Arc::new(AtomicUsize::new(0));
163
164 let registry = Arc::new(ToolRegistry::with_defaults(self.workspace_root.clone()));
165
166 let futs = calls.into_iter().map(|call| {
167 let registry = Arc::clone(®istry);
168 let completed = Arc::clone(&completed);
169 async move {
170 let out = if call.tool == "batch" {
171 json!({"error": "cannot nest batch inside batch"})
172 } else {
173 match registry.execute(&call.tool, call.input).await {
174 Ok(v) => v,
175 Err(e) => json!({"error": e.to_string()}),
176 }
177 };
178
179 let done = completed.fetch_add(1, Ordering::Relaxed) + 1;
180 tracing::info!(completed = done, total, "BatchProgress");
181
182 out
183 }
184 });
185
186 let results = join_all(futs).await;
187 Ok(Value::Array(results))
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Creates a temporary workspace containing the given `(file name, contents)` pairs.
    fn workspace_with(files: &[(&str, &str)]) -> tempfile::TempDir {
        let dir = tempfile::tempdir().unwrap();
        for (name, body) in files {
            std::fs::write(dir.path().join(name), body).unwrap();
        }
        dir
    }

    /// Runs the batch tool rooted at `dir` with the given `calls` array and returns
    /// the result entries, panicking if the call fails or does not yield an array.
    async fn run_batch(dir: &tempfile::TempDir, calls: Value) -> Vec<Value> {
        let out = BatchTool::new(dir.path().to_path_buf())
            .execute(json!({ "calls": calls }))
            .await
            .unwrap();
        match out {
            Value::Array(items) => items,
            other => panic!("batch must return an array, got {other}"),
        }
    }

    /// Returns string field `key` of result entry `idx`, panicking if it is absent.
    fn str_field<'a>(results: &'a [Value], idx: usize, key: &str) -> &'a str {
        results[idx].get(key).and_then(Value::as_str).unwrap()
    }

    #[tokio::test]
    async fn batch_three_reads_returns_all_contents() {
        let dir = workspace_with(&[("a.txt", "A"), ("b.txt", "B"), ("c.txt", "C")]);
        let calls = json!([
            {"tool": "read_file", "input": {"path": "a.txt"}},
            {"tool": "read_file", "input": {"path": "b.txt"}},
            {"tool": "read_file", "input": {"path": "c.txt"}}
        ]);

        let results = run_batch(&dir, calls).await;

        assert_eq!(results.len(), 3);
        for (idx, expected) in ["A", "B", "C"].iter().enumerate() {
            assert!(str_field(&results, idx, "content").contains(*expected));
        }
    }

    #[tokio::test]
    async fn batch_unknown_tool_is_partial_success() {
        let dir = workspace_with(&[("ok.txt", "OK")]);
        let calls = json!([
            {"tool": "read_file", "input": {"path": "ok.txt"}},
            {"tool": "no_such_tool", "input": {}}
        ]);

        let results = run_batch(&dir, calls).await;

        assert_eq!(results.len(), 2);
        assert!(results[0].get("content").is_some());
        assert!(!str_field(&results, 1, "error").is_empty());
    }

    #[tokio::test]
    async fn nested_batch_is_rejected() {
        let dir = workspace_with(&[]);
        let calls = json!([
            {"tool": "batch", "input": {"calls": []}}
        ]);

        let results = run_batch(&dir, calls).await;

        assert_eq!(results.len(), 1);
        assert_eq!(
            str_field(&results, 0, "error"),
            "cannot nest batch inside batch"
        );
    }

    #[tokio::test]
    async fn accepts_flat_call_format() {
        let dir = workspace_with(&[("x.txt", "X")]);
        let calls = json!([
            {"tool": "read_file", "path": "x.txt"}
        ]);

        let results = run_batch(&dir, calls).await;

        assert_eq!(results.len(), 1);
        assert!(str_field(&results, 0, "content").contains("X"));
    }
}