batuta/agent/tool/
compute.rs1use std::time::Duration;
14
15use async_trait::async_trait;
16
17use crate::agent::capability::Capability;
18use crate::agent::driver::ToolDefinition;
19
20use super::{Tool, ToolResult};
21
/// Upper limit (16 KiB) on the number of bytes of command output handed back to the agent;
/// anything longer is cut and marked as truncated.
const MAX_TASK_OUTPUT_BYTES: usize = 16 * 1024;
24
/// Agent tool that executes shell commands (via `sh -c`) in a fixed working directory,
/// either a single command or several at once.
pub struct ComputeTool {
    /// Directory each command is started in.
    working_dir: String,
    /// How many commands a single `parallel` action is allowed to execute.
    max_concurrent: usize,
    /// Timeout value reported through `Tool::timeout`.
    task_timeout: Duration,
}
40
41impl ComputeTool {
42 pub fn new(working_dir: String) -> Self {
44 Self { max_concurrent: 4, task_timeout: Duration::from_secs(300), working_dir }
45 }
46
47 #[must_use]
49 pub fn with_max_concurrent(mut self, max: usize) -> Self {
50 self.max_concurrent = max;
51 self
52 }
53
54 #[must_use]
56 pub fn with_timeout(mut self, timeout: Duration) -> Self {
57 self.task_timeout = timeout;
58 self
59 }
60
61 fn truncate_output(output: &str) -> String {
63 if output.len() <= MAX_TASK_OUTPUT_BYTES {
64 return output.to_string();
65 }
66 let truncated = &output[..MAX_TASK_OUTPUT_BYTES];
67 format!(
68 "{truncated}\n\n[output truncated at \
69 {MAX_TASK_OUTPUT_BYTES} bytes]"
70 )
71 }
72
73 async fn execute_task(&self, command: &str) -> ToolResult {
75 let output = tokio::process::Command::new("sh")
76 .arg("-c")
77 .arg(command)
78 .current_dir(&self.working_dir)
79 .output()
80 .await;
81
82 match output {
83 Ok(out) => {
84 let stdout = String::from_utf8_lossy(&out.stdout);
85 let stderr = String::from_utf8_lossy(&out.stderr);
86 let exit = out.status.code().unwrap_or(-1);
87
88 if out.status.success() {
89 let result = if stderr.is_empty() {
90 Self::truncate_output(&stdout)
91 } else {
92 Self::truncate_output(&format!("{stdout}\nstderr:\n{stderr}"))
93 };
94 ToolResult::success(result)
95 } else {
96 ToolResult::error(format!(
97 "exit code {exit}:\n{}",
98 Self::truncate_output(&format!("{stdout}{stderr}"))
99 ))
100 }
101 }
102 Err(e) => ToolResult::error(format!("task exec failed: {e}")),
103 }
104 }
105
106 async fn execute_parallel(&self, commands: &[String]) -> ToolResult {
108 use std::fmt::Write;
109 let limited = if commands.len() > self.max_concurrent {
110 &commands[..self.max_concurrent]
111 } else {
112 commands
113 };
114
115 let working_dir = self.working_dir.clone();
116 let mut join_set = tokio::task::JoinSet::new();
117
118 for (i, cmd) in limited.iter().enumerate() {
119 let cmd = cmd.clone();
120 let wd = working_dir.clone();
121 join_set.spawn(async move {
122 let output = tokio::process::Command::new("sh")
123 .arg("-c")
124 .arg(&cmd)
125 .current_dir(&wd)
126 .output()
127 .await;
128 (i, output)
129 });
130 }
131
132 let mut results: Vec<(usize, ToolResult)> = Vec::with_capacity(limited.len());
133
134 while let Some(res) = join_set.join_next().await {
135 match res {
136 Ok((i, Ok(out))) => {
137 let stdout = String::from_utf8_lossy(&out.stdout);
138 let stderr = String::from_utf8_lossy(&out.stderr);
139 if out.status.success() {
140 results.push((i, ToolResult::success(stdout.to_string())));
141 } else {
142 let exit = out.status.code().unwrap_or(-1);
143 results
144 .push((i, ToolResult::error(format!("exit {exit}: {stdout}{stderr}"))));
145 }
146 }
147 Ok((i, Err(e))) => {
148 results.push((i, ToolResult::error(format!("spawn failed: {e}"))));
149 }
150 Err(e) => {
151 results.push((results.len(), ToolResult::error(format!("join failed: {e}"))));
152 }
153 }
154 }
155
156 results.sort_by_key(|(i, _)| *i);
157
158 let mut output = String::new();
159 for (i, result) in &results {
160 let _ = write!(
161 output,
162 "=== Task {} ===\n{}\n\n",
163 i + 1,
164 if result.is_error {
165 format!("ERROR: {}", result.content)
166 } else {
167 result.content.clone()
168 }
169 );
170 }
171
172 let any_error = results.iter().any(|(_, r)| r.is_error);
173 if any_error {
174 ToolResult::error(Self::truncate_output(&output))
175 } else {
176 ToolResult::success(Self::truncate_output(&output))
177 }
178 }
179}
180
181#[async_trait]
182impl Tool for ComputeTool {
183 fn name(&self) -> &'static str {
184 "compute"
185 }
186
187 fn definition(&self) -> ToolDefinition {
188 ToolDefinition {
189 name: "compute".into(),
190 description: format!(
191 "Execute compute tasks in parallel \
192 (max {} concurrent). Runs shell commands \
193 on available workers.",
194 self.max_concurrent
195 ),
196 input_schema: serde_json::json!({
197 "type": "object",
198 "required": ["action"],
199 "properties": {
200 "action": {
201 "type": "string",
202 "enum": ["run", "parallel"],
203 "description": "Action: 'run' for single task, 'parallel' for multiple"
204 },
205 "command": {
206 "type": "string",
207 "description": "Shell command for 'run' action"
208 },
209 "commands": {
210 "type": "array",
211 "items": {"type": "string"},
212 "description": "Shell commands for 'parallel' action"
213 }
214 }
215 }),
216 }
217 }
218
219 async fn execute(&self, input: serde_json::Value) -> ToolResult {
220 let action = match input.get("action").and_then(|v| v.as_str()) {
221 Some(a) => a.to_string(),
222 None => {
223 return ToolResult::error("missing required field 'action'");
224 }
225 };
226
227 match action.as_str() {
228 "run" => {
229 let Some(command) = input.get("command").and_then(|v| v.as_str()) else {
230 return ToolResult::error("missing 'command' for 'run'");
231 };
232 self.execute_task(command).await
233 }
234 "parallel" => {
235 let commands = match input.get("commands").and_then(|v| v.as_array()) {
236 Some(arr) => {
237 arr.iter().filter_map(|v| v.as_str().map(String::from)).collect::<Vec<_>>()
238 }
239 None => {
240 return ToolResult::error("missing 'commands' for 'parallel'");
241 }
242 };
243 if commands.is_empty() {
244 return ToolResult::error("'commands' array is empty");
245 }
246 self.execute_parallel(&commands).await
247 }
248 other => {
249 ToolResult::error(format!("unknown action '{other}', use 'run' or 'parallel'"))
250 }
251 }
252 }
253
254 fn required_capability(&self) -> Capability {
255 Capability::Compute
256 }
257
258 fn timeout(&self) -> Duration {
259 self.task_timeout
260 }
261}
262
263#[cfg(test)]
264#[path = "compute_tests.rs"]
265mod tests;