// agentzero_tools/process_tool.rs

1use agentzero_core::{Tool, ToolContext, ToolResult};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use serde::Deserialize;
5use std::sync::Mutex;
6use tokio::io::AsyncReadExt;
7use tokio::process::Command;
8
/// Maximum number of simultaneously active background processes accepted by `spawn`.
const MAX_CONCURRENT: usize = 8;
/// Cap on captured output per stream (stdout and stderr are limited separately): 512 KiB.
const MAX_OUTPUT_BYTES: usize = 1 << 19;
11
12#[derive(Debug, Deserialize)]
13#[serde(tag = "action")]
14#[serde(rename_all = "snake_case")]
15enum ProcessAction {
16    Spawn { command: String },
17    List,
18    Output { id: usize },
19    Kill { id: usize },
20}
21
22struct ProcessEntry {
23    id: usize,
24    command: String,
25    handle: Option<tokio::task::JoinHandle<ProcessOutput>>,
26    result: Option<ProcessOutput>,
27}
28
/// What a background process left behind once it stopped.
struct ProcessOutput {
    /// Captured standard output (possibly truncated).
    stdout: String,
    /// Captured standard error, or a placeholder such as "(killed)" explaining why
    /// there is no real output.
    stderr: String,
    /// Exit code; `None` when no code is available (spawn failure, kill, task panic,
    /// or the OS reported none).
    exit_code: Option<i32>,
}
34
35pub struct ProcessTool {
36    entries: Mutex<Vec<ProcessEntry>>,
37}
38
39impl Default for ProcessTool {
40    fn default() -> Self {
41        Self {
42            entries: Mutex::new(Vec::new()),
43        }
44    }
45}
46
47impl ProcessTool {
48    /// Extract finished handles from the entries (under lock) so we can await them
49    /// outside the lock. Returns (index, handle) pairs for finished tasks.
50    fn take_finished_handles(
51        entries: &mut [ProcessEntry],
52    ) -> Vec<(usize, tokio::task::JoinHandle<ProcessOutput>)> {
53        let mut finished = Vec::new();
54        for entry in entries.iter_mut() {
55            if entry.result.is_some() {
56                continue;
57            }
58            let is_finished = entry.handle.as_ref().is_some_and(|h| h.is_finished());
59            if is_finished {
60                if let Some(handle) = entry.handle.take() {
61                    finished.push((entry.id, handle));
62                }
63            }
64        }
65        finished
66    }
67
68    /// Store results back into entries after awaiting handles.
69    fn store_results(entries: &mut [ProcessEntry], results: Vec<(usize, ProcessOutput)>) {
70        for (id, output) in results {
71            if let Some(entry) = entries.iter_mut().find(|e| e.id == id) {
72                entry.result = Some(output);
73            }
74        }
75    }
76
77    /// Collect all finished process outputs. Must be called from async context.
78    async fn collect_finished(&self) {
79        let finished = {
80            let mut entries = match self.entries.lock() {
81                Ok(e) => e,
82                Err(_) => return,
83            };
84            Self::take_finished_handles(&mut entries)
85        };
86
87        if finished.is_empty() {
88            return;
89        }
90
91        let mut results = Vec::new();
92        for (id, handle) in finished {
93            let output = match handle.await {
94                Ok(o) => o,
95                Err(_) => ProcessOutput {
96                    exit_code: None,
97                    stdout: String::new(),
98                    stderr: "(task panicked)".to_string(),
99                },
100            };
101            results.push((id, output));
102        }
103
104        if let Ok(mut entries) = self.entries.lock() {
105            Self::store_results(&mut entries, results);
106        }
107    }
108}
109
110#[async_trait]
111impl Tool for ProcessTool {
112    fn name(&self) -> &'static str {
113        "process"
114    }
115
116    fn description(&self) -> &'static str {
117        "Manage long-running background processes: start, stop, list, or read output."
118    }
119
120    fn input_schema(&self) -> Option<serde_json::Value> {
121        Some(serde_json::json!({
122            "type": "object",
123            "properties": {
124                "action": { "type": "string", "enum": ["spawn", "list", "output", "kill"], "description": "The process action to perform" },
125                "command": { "type": "string", "description": "Shell command to run (for spawn)" },
126                "id": { "type": "integer", "description": "Process ID (for output/kill)" }
127            },
128            "required": ["action"],
129            "additionalProperties": false
130        }))
131    }
132
133    async fn execute(&self, input: &str, ctx: &ToolContext) -> anyhow::Result<ToolResult> {
134        let action: ProcessAction =
135            serde_json::from_str(input).context("process expects JSON with \"action\" field")?;
136
137        match action {
138            ProcessAction::Spawn { command } => {
139                if command.trim().is_empty() {
140                    return Err(anyhow!("command must not be empty"));
141                }
142                let mut entries = self.entries.lock().map_err(|_| anyhow!("lock poisoned"))?;
143                let active = entries.iter().filter(|e| e.result.is_none()).count();
144                if active >= MAX_CONCURRENT {
145                    return Err(anyhow!(
146                        "max concurrent processes reached ({MAX_CONCURRENT})"
147                    ));
148                }
149                let id = entries.len();
150                let workspace_root = ctx.workspace_root.clone();
151                let cmd = command.clone();
152
153                let handle = tokio::spawn(async move { run_process(&cmd, &workspace_root).await });
154
155                entries.push(ProcessEntry {
156                    id,
157                    command: command.clone(),
158                    handle: Some(handle),
159                    result: None,
160                });
161
162                Ok(ToolResult {
163                    output: format!("spawned process {id}: {command}"),
164                })
165            }
166
167            ProcessAction::List => {
168                self.collect_finished().await;
169                let entries = self.entries.lock().map_err(|_| anyhow!("lock poisoned"))?;
170
171                if entries.is_empty() {
172                    return Ok(ToolResult {
173                        output: "no processes".to_string(),
174                    });
175                }
176
177                let lines: Vec<String> = entries
178                    .iter()
179                    .map(|e| {
180                        let status = if e.result.is_some()
181                            || e.handle.as_ref().is_some_and(|h| h.is_finished())
182                        {
183                            "finished"
184                        } else {
185                            "running"
186                        };
187                        format!("id={} status={} command={}", e.id, status, e.command)
188                    })
189                    .collect();
190
191                Ok(ToolResult {
192                    output: lines.join("\n"),
193                })
194            }
195
196            ProcessAction::Output { id } => {
197                self.collect_finished().await;
198                let entries = self.entries.lock().map_err(|_| anyhow!("lock poisoned"))?;
199
200                let entry = entries
201                    .iter()
202                    .find(|e| e.id == id)
203                    .ok_or_else(|| anyhow!("process {id} not found"))?;
204
205                if let Some(ref result) = entry.result {
206                    let mut output = format!("exit={}\n", result.exit_code.unwrap_or(-1));
207                    if !result.stdout.is_empty() {
208                        output.push_str(&result.stdout);
209                    }
210                    if !result.stderr.is_empty() {
211                        output.push_str("\nstderr:\n");
212                        output.push_str(&result.stderr);
213                    }
214                    Ok(ToolResult { output })
215                } else {
216                    Ok(ToolResult {
217                        output: format!("process {id} is still running"),
218                    })
219                }
220            }
221
222            ProcessAction::Kill { id } => {
223                let mut entries = self.entries.lock().map_err(|_| anyhow!("lock poisoned"))?;
224                let entry = entries
225                    .iter_mut()
226                    .find(|e| e.id == id)
227                    .ok_or_else(|| anyhow!("process {id} not found"))?;
228
229                if let Some(handle) = entry.handle.take() {
230                    handle.abort();
231                    entry.result = Some(ProcessOutput {
232                        exit_code: None,
233                        stdout: String::new(),
234                        stderr: "(killed)".to_string(),
235                    });
236                }
237
238                Ok(ToolResult {
239                    output: format!("killed process {id}"),
240                })
241            }
242        }
243    }
244}
245
246async fn run_process(command: &str, workspace_root: &str) -> ProcessOutput {
247    let result = Command::new("sh")
248        .arg("-c")
249        .arg(command)
250        .current_dir(workspace_root)
251        .stdout(std::process::Stdio::piped())
252        .stderr(std::process::Stdio::piped())
253        .spawn();
254
255    let mut child = match result {
256        Ok(c) => c,
257        Err(e) => {
258            return ProcessOutput {
259                exit_code: None,
260                stdout: String::new(),
261                stderr: format!("failed to spawn: {e}"),
262            };
263        }
264    };
265
266    let stdout_handle = match child.stdout.take() {
267        Some(h) => h,
268        None => {
269            return ProcessOutput {
270                exit_code: None,
271                stdout: String::new(),
272                stderr: "stdout not piped on spawned child".to_string(),
273            };
274        }
275    };
276    let stderr_handle = match child.stderr.take() {
277        Some(h) => h,
278        None => {
279            return ProcessOutput {
280                exit_code: None,
281                stdout: String::new(),
282                stderr: "stderr not piped on spawned child".to_string(),
283            };
284        }
285    };
286
287    let stdout_task = tokio::spawn(read_limited(stdout_handle));
288    let stderr_task = tokio::spawn(read_limited(stderr_handle));
289
290    let status = child.wait().await;
291    let stdout = stdout_task
292        .await
293        .unwrap_or_else(|_| Ok(String::new()))
294        .unwrap_or_default();
295    let stderr = stderr_task
296        .await
297        .unwrap_or_else(|_| Ok(String::new()))
298        .unwrap_or_default();
299
300    ProcessOutput {
301        exit_code: status.ok().and_then(|s| s.code()),
302        stdout,
303        stderr,
304    }
305}
306
307async fn read_limited<R: tokio::io::AsyncRead + Unpin>(mut reader: R) -> anyhow::Result<String> {
308    let mut buf = Vec::new();
309    let mut limited = (&mut reader).take((MAX_OUTPUT_BYTES + 1) as u64);
310    limited.read_to_end(&mut buf).await?;
311    let truncated = buf.len() > MAX_OUTPUT_BYTES;
312    if truncated {
313        buf.truncate(MAX_OUTPUT_BYTES);
314    }
315    let mut s = String::from_utf8_lossy(&buf).to_string();
316    if truncated {
317        s.push_str(&format!("\n<truncated at {} bytes>", MAX_OUTPUT_BYTES));
318    }
319    Ok(s)
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Tool context rooted at the current working directory.
    fn here() -> ToolContext {
        ToolContext::new(".".to_string())
    }

    /// Send one JSON request to `tool` and return its output, panicking with `what` on error.
    async fn run(tool: &ProcessTool, ctx: &ToolContext, request: &str, what: &str) -> String {
        tool.execute(request, ctx).await.expect(what).output
    }

    #[tokio::test]
    async fn process_spawn_list_output() {
        let tool = ProcessTool::default();
        let ctx = here();

        let spawned = run(
            &tool,
            &ctx,
            r#"{"action": "spawn", "command": "echo hello"}"#,
            "spawn should succeed",
        )
        .await;
        assert!(spawned.contains("spawned process 0"));

        // `echo` should exit well within this window.
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        let listing = run(&tool, &ctx, r#"{"action": "list"}"#, "list should succeed").await;
        assert!(listing.contains("echo hello"));

        let output = run(
            &tool,
            &ctx,
            r#"{"action": "output", "id": 0}"#,
            "output should succeed",
        )
        .await;
        // On a slow machine the process may not have been reaped yet.
        assert!(output.contains("hello") || output.contains("still running"));
    }

    #[tokio::test]
    async fn process_rejects_empty_command() {
        let err = ProcessTool::default()
            .execute(r#"{"action": "spawn", "command": ""}"#, &here())
            .await
            .expect_err("empty command should fail");
        assert!(err.to_string().contains("command must not be empty"));
    }

    #[tokio::test]
    async fn process_kill_running() {
        let tool = ProcessTool::default();
        let ctx = here();

        run(
            &tool,
            &ctx,
            r#"{"action": "spawn", "command": "sleep 60"}"#,
            "spawn should succeed",
        )
        .await;

        let killed = run(&tool, &ctx, r#"{"action": "kill", "id": 0}"#, "kill should succeed").await;
        assert!(killed.contains("killed process 0"));
    }

    #[tokio::test]
    async fn process_nonexistent_id_fails() {
        let err = ProcessTool::default()
            .execute(r#"{"action": "output", "id": 99}"#, &here())
            .await
            .expect_err("nonexistent should fail");
        assert!(err.to_string().contains("not found"));
    }
}