Skip to main content

pawan/tools/
batch.rs

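//! Batch tool: execute multiple tool calls concurrently.
//!
//! Input: {"calls":[{"tool":"read_file","input":{...}}, ...]}
//!
//! Notes:
//! - Executes at most 25 calls per batch; calls beyond the limit are not run.
//! - Rejects nested batch calls (each such call yields an `error` entry).
//! - Accepts both nested (`input`/`parameters`) and flat parameter formats.
//! - Results are returned in call order; a failing call yields an
//!   `{"error": ...}` entry instead of failing the whole batch.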
9
10use crate::tools::{Tool, ToolRegistry};
11use async_trait::async_trait;
12use futures::future::join_all;
13use serde::de::{self, MapAccess, Visitor};
14use serde::{Deserialize, Deserializer};
15use serde_json::{json, Value};
16use std::path::PathBuf;
17use std::sync::{
18    atomic::{AtomicUsize, Ordering},
19    Arc,
20};
21
/// Upper bound on how many calls a single `batch` invocation will execute.
const MAX_BATCH_SIZE: usize = 25;
23
24#[derive(Debug, Clone)]
25struct BatchCall {
26    tool: String,
27    input: Value,
28}
29
30/// Flexible entry deserializer.
31///
32/// Accepts either:
33/// - { tool: "read_file", input: { path: "..." } }
34/// - { tool: "read_file", parameters: { path: "..." } }   (alias)
35/// - { tool: "read_file", path: "..." }                  (flat)
36/// - { tool: "read_file", input: {...}, path: "..." }    (merge; duplicates rejected)
37impl<'de> Deserialize<'de> for BatchCall {
38    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
39        struct V;
40
41        impl<'de> Visitor<'de> for V {
42            type Value = BatchCall;
43
44            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
45                f.write_str("a batch call with 'tool' and either 'input'/'parameters' or flat args")
46            }
47
48            fn visit_map<M: MapAccess<'de>>(self, mut map: M) -> Result<BatchCall, M::Error> {
49                let mut tool: Option<String> = None;
50                let mut nested: Option<Value> = None;
51                let mut rest = serde_json::Map::new();
52
53                while let Some(key) = map.next_key::<String>()? {
54                    match key.as_str() {
55                        "tool" => tool = Some(map.next_value()?),
56                        "input" | "parameters" => nested = Some(map.next_value()?),
57                        _ => {
58                            rest.insert(key, map.next_value()?);
59                        }
60                    }
61                }
62
63                let tool = tool.ok_or_else(|| de::Error::missing_field("tool"))?;
64                let input = match nested {
65                    Some(v) if rest.is_empty() => v,
66                    Some(Value::Object(mut obj)) => {
67                        for (k, v) in rest {
68                            if obj.contains_key(&k) {
69                                return Err(de::Error::custom(format_args!(
70                                    "duplicate parameter '{k}' in both nested input and flat fields"
71                                )));
72                            }
73                            obj.insert(k, v);
74                        }
75                        Value::Object(obj)
76                    }
77                    Some(_) => {
78                        return Err(de::Error::custom(
79                            "'input'/'parameters' must be an object when flat fields are also present",
80                        ));
81                    }
82                    None if !rest.is_empty() => Value::Object(rest),
83                    None => return Err(de::Error::missing_field("input")),
84                };
85
86                Ok(BatchCall { tool, input })
87            }
88        }
89
90        deserializer.deserialize_map(V)
91    }
92}
93
94#[derive(Debug, Clone, Deserialize)]
95struct BatchArgs {
96    calls: Vec<BatchCall>,
97}
98
/// Tool that fans a list of tool calls out concurrently and gathers their
/// results into one JSON array.
#[derive(Clone)]
pub struct BatchTool {
    /// Root directory used to build the inner tool registry.
    workspace_root: PathBuf,
}

impl BatchTool {
    /// Creates a batch tool whose inner registry is built from `workspace_root`.
    pub fn new(workspace_root: PathBuf) -> Self {
        BatchTool { workspace_root }
    }
}
109
110#[async_trait]
111impl Tool for BatchTool {
112    fn name(&self) -> &str {
113        "batch"
114    }
115
116    fn description(&self) -> &str {
117        "Execute up to 25 tool calls concurrently and return an array of results."
118    }
119
120    fn mutating(&self) -> bool {
121        false
122    }
123
124    fn parameters_schema(&self) -> Value {
125        json!({
126            "type": "object",
127            "properties": {
128                "calls": {
129                    "type": "array",
130                    "description": "Array of tool calls: {tool: string, input: object} or flat {tool: string, ...args}",
131                    "items": { "type": "object" }
132                }
133            },
134            "required": ["calls"]
135        })
136    }
137
138    async fn execute(&self, args: Value) -> crate::Result<Value> {
139        let parsed: BatchArgs = serde_json::from_value(args)
140            .map_err(|e| crate::PawanError::Tool(format!("invalid batch args: {e}")))?;
141
142        if parsed.calls.is_empty() {
143            return Ok(Value::Array(vec![]));
144        }
145
146        let active_len = parsed.calls.len().min(MAX_BATCH_SIZE);
147        if parsed.calls.len() > MAX_BATCH_SIZE {
148            tracing::warn!(
149                total = parsed.calls.len(),
150                used = active_len,
151                limit = MAX_BATCH_SIZE,
152                "batch: truncating calls over limit"
153            );
154        }
155
156        let calls = parsed
157            .calls
158            .into_iter()
159            .take(active_len)
160            .collect::<Vec<_>>();
161        let total = calls.len();
162        let completed = Arc::new(AtomicUsize::new(0));
163
164        let registry = Arc::new(ToolRegistry::with_defaults(self.workspace_root.clone()));
165
166        let futs = calls.into_iter().map(|call| {
167            let registry = Arc::clone(&registry);
168            let completed = Arc::clone(&completed);
169            async move {
170                let out = if call.tool == "batch" {
171                    json!({"error": "cannot nest batch inside batch"})
172                } else {
173                    match registry.execute(&call.tool, call.input).await {
174                        Ok(v) => v,
175                        Err(e) => json!({"error": e.to_string()}),
176                    }
177                };
178
179                let done = completed.fetch_add(1, Ordering::Relaxed) + 1;
180                tracing::info!(completed = done, total, "BatchProgress");
181
182                out
183            }
184        });
185
186        let results = join_all(futs).await;
187        Ok(Value::Array(results))
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Returns the `content` string of a `read_file` result, panicking with the
    /// whole value when it is absent so failures are easy to diagnose.
    fn content_of(v: &Value) -> &str {
        v.get("content")
            .and_then(Value::as_str)
            .unwrap_or_else(|| panic!("no string `content` in {v}"))
    }

    #[tokio::test]
    async fn batch_three_reads_returns_all_contents() {
        let dir = tempfile::tempdir().unwrap();
        for (name, body) in [("a.txt", "A"), ("b.txt", "B"), ("c.txt", "C")] {
            std::fs::write(dir.path().join(name), body).unwrap();
        }

        let tool = BatchTool::new(dir.path().to_path_buf());
        let out = tool
            .execute(json!({
                "calls": [
                    {"tool": "read_file", "input": {"path": "a.txt"}},
                    {"tool": "read_file", "input": {"path": "b.txt"}},
                    {"tool": "read_file", "input": {"path": "c.txt"}}
                ]
            }))
            .await
            .unwrap();

        let arr = out.as_array().unwrap();
        assert_eq!(arr.len(), 3);
        assert!(content_of(&arr[0]).contains("A"));
        assert!(content_of(&arr[1]).contains("B"));
        assert!(content_of(&arr[2]).contains("C"));
    }

    #[tokio::test]
    async fn batch_unknown_tool_is_partial_success() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("ok.txt"), "OK").unwrap();

        let tool = BatchTool::new(dir.path().to_path_buf());
        let out = tool
            .execute(json!({
                "calls": [
                    {"tool": "read_file", "input": {"path": "ok.txt"}},
                    {"tool": "no_such_tool", "input": {}}
                ]
            }))
            .await
            .unwrap();

        let arr = out.as_array().unwrap();
        assert_eq!(arr.len(), 2);
        assert!(arr[0].get("content").is_some());
        let err = arr[1].get("error").and_then(|v| v.as_str()).unwrap();
        assert!(!err.is_empty());
    }

    #[tokio::test]
    async fn nested_batch_is_rejected() {
        let dir = tempfile::tempdir().unwrap();
        let tool = BatchTool::new(dir.path().to_path_buf());

        let out = tool
            .execute(json!({
                "calls": [
                    {"tool": "batch", "input": {"calls": []}}
                ]
            }))
            .await
            .unwrap();

        let arr = out.as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert_eq!(
            arr[0].get("error").and_then(|v| v.as_str()).unwrap(),
            "cannot nest batch inside batch"
        );
    }

    #[tokio::test]
    async fn accepts_flat_call_format() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("x.txt"), "X").unwrap();

        let tool = BatchTool::new(dir.path().to_path_buf());
        let out = tool
            .execute(json!({
                "calls": [
                    {"tool": "read_file", "path": "x.txt"}
                ]
            }))
            .await
            .unwrap();

        let arr = out.as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert!(content_of(&arr[0]).contains("X"));
    }

    #[tokio::test]
    async fn accepts_parameters_alias() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("p.txt"), "P").unwrap();

        let tool = BatchTool::new(dir.path().to_path_buf());
        let out = tool
            .execute(json!({
                "calls": [
                    {"tool": "read_file", "parameters": {"path": "p.txt"}}
                ]
            }))
            .await
            .unwrap();

        let arr = out.as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert!(content_of(&arr[0]).contains("P"));
    }

    #[tokio::test]
    async fn rejects_key_in_both_nested_and_flat() {
        let dir = tempfile::tempdir().unwrap();
        let tool = BatchTool::new(dir.path().to_path_buf());

        let res = tool
            .execute(json!({
                "calls": [
                    {"tool": "read_file", "input": {"path": "a.txt"}, "path": "b.txt"}
                ]
            }))
            .await;

        assert!(res.is_err());
    }

    #[tokio::test]
    async fn rejects_call_without_tool() {
        let dir = tempfile::tempdir().unwrap();
        let tool = BatchTool::new(dir.path().to_path_buf());

        let res = tool
            .execute(json!({
                "calls": [
                    {"input": {"path": "a.txt"}}
                ]
            }))
            .await;

        assert!(res.is_err());
    }
}