Skip to main content

batuta/agent/tool/
compute.rs

1//! Compute tool for distributed task submission.
2//!
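//! Intended to wrap `repartir::Pool` for parallel task execution across
//! CPU, GPU, or remote workers. The code in this file currently runs each
//! task as a local `sh -c` subprocess via tokio; the agent submits
//! shell-based tasks and receives their output.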
6//!
7//! Phase 3: Requires `Capability::Compute` and the `distributed`
8//! feature (which pulls in repartir).
9//!
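//! Security: commands are handed verbatim to `sh -c`. No allow-list
//! check (Poka-Yoke) is performed in this file, so tasks must be
//! validated before they reach `Tool::execute` — TODO confirm where
//! that validation lives.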
12
13use std::time::Duration;
14
15use async_trait::async_trait;
16
17use crate::agent::capability::Capability;
18use crate::agent::driver::ToolDefinition;
19
20use super::{Tool, ToolResult};
21
/// Upper bound, in bytes, on the task output handed back to the caller (16 KiB).
const MAX_TASK_OUTPUT_BYTES: usize = 16 * 1024;
24
/// Agent tool that runs shell commands as compute tasks.
///
/// Holds the settings used when tasks are executed: how many may run
/// in parallel, the timeout reported for the tool, and the directory
/// the commands are started in.
///
/// Requires `Capability::Compute` — the agent manifest must
/// explicitly grant compute access.
#[derive(Debug, Clone)]
pub struct ComputeTool {
    /// Limit on parallel tasks for a `parallel` request.
    max_concurrent: usize,
    /// Timeout reported for this tool.
    task_timeout: Duration,
    /// Directory in which each task's subprocess is started.
    working_dir: String,
}
40
41impl ComputeTool {
42    /// Create a new compute tool.
43    pub fn new(working_dir: String) -> Self {
44        Self { max_concurrent: 4, task_timeout: Duration::from_secs(300), working_dir }
45    }
46
47    /// Set maximum concurrent tasks.
48    #[must_use]
49    pub fn with_max_concurrent(mut self, max: usize) -> Self {
50        self.max_concurrent = max;
51        self
52    }
53
54    /// Set task timeout.
55    #[must_use]
56    pub fn with_timeout(mut self, timeout: Duration) -> Self {
57        self.task_timeout = timeout;
58        self
59    }
60
61    /// Truncate output to prevent context overflow.
62    fn truncate_output(output: &str) -> String {
63        if output.len() <= MAX_TASK_OUTPUT_BYTES {
64            return output.to_string();
65        }
66        let truncated = &output[..MAX_TASK_OUTPUT_BYTES];
67        format!(
68            "{truncated}\n\n[output truncated at \
69             {MAX_TASK_OUTPUT_BYTES} bytes]"
70        )
71    }
72
73    /// Execute a single task via tokio subprocess.
74    async fn execute_task(&self, command: &str) -> ToolResult {
75        let output = tokio::process::Command::new("sh")
76            .arg("-c")
77            .arg(command)
78            .current_dir(&self.working_dir)
79            .output()
80            .await;
81
82        match output {
83            Ok(out) => {
84                let stdout = String::from_utf8_lossy(&out.stdout);
85                let stderr = String::from_utf8_lossy(&out.stderr);
86                let exit = out.status.code().unwrap_or(-1);
87
88                if out.status.success() {
89                    let result = if stderr.is_empty() {
90                        Self::truncate_output(&stdout)
91                    } else {
92                        Self::truncate_output(&format!("{stdout}\nstderr:\n{stderr}"))
93                    };
94                    ToolResult::success(result)
95                } else {
96                    ToolResult::error(format!(
97                        "exit code {exit}:\n{}",
98                        Self::truncate_output(&format!("{stdout}{stderr}"))
99                    ))
100                }
101            }
102            Err(e) => ToolResult::error(format!("task exec failed: {e}")),
103        }
104    }
105
106    /// Execute multiple tasks in parallel using `JoinSet`.
107    async fn execute_parallel(&self, commands: &[String]) -> ToolResult {
108        use std::fmt::Write;
109        let limited = if commands.len() > self.max_concurrent {
110            &commands[..self.max_concurrent]
111        } else {
112            commands
113        };
114
115        let working_dir = self.working_dir.clone();
116        let mut join_set = tokio::task::JoinSet::new();
117
118        for (i, cmd) in limited.iter().enumerate() {
119            let cmd = cmd.clone();
120            let wd = working_dir.clone();
121            join_set.spawn(async move {
122                let output = tokio::process::Command::new("sh")
123                    .arg("-c")
124                    .arg(&cmd)
125                    .current_dir(&wd)
126                    .output()
127                    .await;
128                (i, output)
129            });
130        }
131
132        let mut results: Vec<(usize, ToolResult)> = Vec::with_capacity(limited.len());
133
134        while let Some(res) = join_set.join_next().await {
135            match res {
136                Ok((i, Ok(out))) => {
137                    let stdout = String::from_utf8_lossy(&out.stdout);
138                    let stderr = String::from_utf8_lossy(&out.stderr);
139                    if out.status.success() {
140                        results.push((i, ToolResult::success(stdout.to_string())));
141                    } else {
142                        let exit = out.status.code().unwrap_or(-1);
143                        results
144                            .push((i, ToolResult::error(format!("exit {exit}: {stdout}{stderr}"))));
145                    }
146                }
147                Ok((i, Err(e))) => {
148                    results.push((i, ToolResult::error(format!("spawn failed: {e}"))));
149                }
150                Err(e) => {
151                    results.push((results.len(), ToolResult::error(format!("join failed: {e}"))));
152                }
153            }
154        }
155
156        results.sort_by_key(|(i, _)| *i);
157
158        let mut output = String::new();
159        for (i, result) in &results {
160            let _ = write!(
161                output,
162                "=== Task {} ===\n{}\n\n",
163                i + 1,
164                if result.is_error {
165                    format!("ERROR: {}", result.content)
166                } else {
167                    result.content.clone()
168                }
169            );
170        }
171
172        let any_error = results.iter().any(|(_, r)| r.is_error);
173        if any_error {
174            ToolResult::error(Self::truncate_output(&output))
175        } else {
176            ToolResult::success(Self::truncate_output(&output))
177        }
178    }
179}
180
181#[async_trait]
182impl Tool for ComputeTool {
183    fn name(&self) -> &'static str {
184        "compute"
185    }
186
187    fn definition(&self) -> ToolDefinition {
188        ToolDefinition {
189            name: "compute".into(),
190            description: format!(
191                "Execute compute tasks in parallel \
192                 (max {} concurrent). Runs shell commands \
193                 on available workers.",
194                self.max_concurrent
195            ),
196            input_schema: serde_json::json!({
197                "type": "object",
198                "required": ["action"],
199                "properties": {
200                    "action": {
201                        "type": "string",
202                        "enum": ["run", "parallel"],
203                        "description": "Action: 'run' for single task, 'parallel' for multiple"
204                    },
205                    "command": {
206                        "type": "string",
207                        "description": "Shell command for 'run' action"
208                    },
209                    "commands": {
210                        "type": "array",
211                        "items": {"type": "string"},
212                        "description": "Shell commands for 'parallel' action"
213                    }
214                }
215            }),
216        }
217    }
218
219    async fn execute(&self, input: serde_json::Value) -> ToolResult {
220        let action = match input.get("action").and_then(|v| v.as_str()) {
221            Some(a) => a.to_string(),
222            None => {
223                return ToolResult::error("missing required field 'action'");
224            }
225        };
226
227        match action.as_str() {
228            "run" => {
229                let Some(command) = input.get("command").and_then(|v| v.as_str()) else {
230                    return ToolResult::error("missing 'command' for 'run'");
231                };
232                self.execute_task(command).await
233            }
234            "parallel" => {
235                let commands = match input.get("commands").and_then(|v| v.as_array()) {
236                    Some(arr) => {
237                        arr.iter().filter_map(|v| v.as_str().map(String::from)).collect::<Vec<_>>()
238                    }
239                    None => {
240                        return ToolResult::error("missing 'commands' for 'parallel'");
241                    }
242                };
243                if commands.is_empty() {
244                    return ToolResult::error("'commands' array is empty");
245                }
246                self.execute_parallel(&commands).await
247            }
248            other => {
249                ToolResult::error(format!("unknown action '{other}', use 'run' or 'parallel'"))
250            }
251        }
252    }
253
254    fn required_capability(&self) -> Capability {
255        Capability::Compute
256    }
257
258    fn timeout(&self) -> Duration {
259        self.task_timeout
260    }
261}
262
263#[cfg(test)]
264#[path = "compute_tests.rs"]
265mod tests;