Skip to main content

cersei_tools/tool_primitives/
process.rs

//! Async command execution primitives.
//!
//! Stateless: there is no persistent working directory or environment, and
//! each call is independent. Persisting shell state across calls (as coding
//! agents need) is a higher-level concern.
5
6use std::collections::HashMap;
7use std::path::PathBuf;
8use std::time::Duration;
9use tokio::io::{AsyncBufReadExt, BufReader};
10use tokio::sync::mpsc;
11
/// Result of executing a command.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecOutput {
    /// Captured standard output of the command.
    pub stdout: String,
    /// Captured standard error of the command. When [`exec`] times out this
    /// holds a short "Command timed out" message instead.
    pub stderr: String,
    /// Process exit code, or `-1` when none is available (for example when the
    /// command timed out or was terminated by a signal).
    pub exit_code: i32,
    /// `true` if the command did not finish within its timeout.
    pub timed_out: bool,
}
20
21/// Options for command execution.
22#[derive(Debug, Clone)]
23pub struct ExecOptions {
24    pub cwd: Option<PathBuf>,
25    pub env: HashMap<String, String>,
26    pub timeout: Option<Duration>,
27    pub shell: Shell,
28}
29
30impl Default for ExecOptions {
31    fn default() -> Self {
32        Self {
33            cwd: None,
34            env: HashMap::new(),
35            timeout: Some(Duration::from_secs(120)),
36            shell: Shell::Sh,
37        }
38    }
39}
40
/// Which shell to use for command execution.
///
/// Each variant determines the program launched and the arguments placed
/// before the command string.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Shell {
    /// `sh -c <command>`
    Sh,
    /// `bash -c <command>`
    Bash,
    /// `zsh -c <command>`
    Zsh,
    /// `pwsh -NoProfile -NonInteractive -Command <command>`
    PowerShell,
    /// `cmd /C <command>`
    Cmd,
    /// `<program> <args...> <command>`: the command string is appended as the
    /// final argument.
    Custom { program: String, args: Vec<String> },
}
51
/// A line of output from a streaming command, without its trailing newline.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OutputLine {
    /// A line the command wrote to standard output.
    Stdout(String),
    /// A line the command wrote to standard error.
    Stderr(String),
}
58
59/// Execute a command through a shell. Returns when the command completes or times out.
60pub async fn exec(command: &str, opts: ExecOptions) -> Result<ExecOutput, std::io::Error> {
61    let (program, args) = shell_args(&opts.shell, command);
62
63    let mut cmd = tokio::process::Command::new(&program);
64    cmd.args(&args)
65        .stdout(std::process::Stdio::piped())
66        .stderr(std::process::Stdio::piped());
67
68    if let Some(cwd) = &opts.cwd {
69        cmd.current_dir(cwd);
70    }
71
72    for (k, v) in &opts.env {
73        cmd.env(k, v);
74    }
75
76    let child = cmd.spawn()?;
77
78    let timeout = opts.timeout.unwrap_or(Duration::from_secs(120));
79
80    match tokio::time::timeout(timeout, child.wait_with_output()).await {
81        Ok(Ok(output)) => Ok(ExecOutput {
82            stdout: String::from_utf8_lossy(&output.stdout).to_string(),
83            stderr: String::from_utf8_lossy(&output.stderr).to_string(),
84            exit_code: output.status.code().unwrap_or(-1),
85            timed_out: false,
86        }),
87        Ok(Err(e)) => Err(e),
88        Err(_) => {
89            // Timeout — process is dropped (killed automatically)
90            Ok(ExecOutput {
91                stdout: String::new(),
92                stderr: format!("Command timed out after {}s", timeout.as_secs()),
93                exit_code: -1,
94                timed_out: true,
95            })
96        }
97    }
98}
99
100/// Execute a command and stream output lines through a channel.
101///
102/// Returns a receiver for output lines and a join handle that resolves
103/// to the final `ExecOutput` when the command completes.
104pub fn exec_streaming(
105    command: &str,
106    opts: ExecOptions,
107) -> Result<
108    (
109        mpsc::Receiver<OutputLine>,
110        tokio::task::JoinHandle<ExecOutput>,
111    ),
112    std::io::Error,
113> {
114    let (program, args) = shell_args(&opts.shell, command);
115
116    let mut cmd = tokio::process::Command::new(&program);
117    cmd.args(&args)
118        .stdout(std::process::Stdio::piped())
119        .stderr(std::process::Stdio::piped());
120
121    if let Some(cwd) = &opts.cwd {
122        cmd.current_dir(cwd);
123    }
124
125    for (k, v) in &opts.env {
126        cmd.env(k, v);
127    }
128
129    let mut child = cmd.spawn()?;
130    let (tx, rx) = mpsc::channel(256);
131
132    let stdout = child.stdout.take();
133    let stderr = child.stderr.take();
134    let timeout = opts.timeout.unwrap_or(Duration::from_secs(120));
135
136    let handle = tokio::spawn(async move {
137        let mut full_stdout = String::new();
138        let mut full_stderr = String::new();
139
140        let tx_out = tx.clone();
141        let stdout_task = tokio::spawn(async move {
142            let mut collected = String::new();
143            if let Some(stdout) = stdout {
144                let mut lines = BufReader::new(stdout).lines();
145                while let Ok(Some(line)) = lines.next_line().await {
146                    collected.push_str(&line);
147                    collected.push('\n');
148                    let _ = tx_out.send(OutputLine::Stdout(line)).await;
149                }
150            }
151            collected
152        });
153
154        let tx_err = tx;
155        let stderr_task = tokio::spawn(async move {
156            let mut collected = String::new();
157            if let Some(stderr) = stderr {
158                let mut lines = BufReader::new(stderr).lines();
159                while let Ok(Some(line)) = lines.next_line().await {
160                    collected.push_str(&line);
161                    collected.push('\n');
162                    let _ = tx_err.send(OutputLine::Stderr(line)).await;
163                }
164            }
165            collected
166        });
167
168        let result = tokio::time::timeout(timeout, child.wait()).await;
169
170        full_stdout = stdout_task.await.unwrap_or_default();
171        full_stderr = stderr_task.await.unwrap_or_default();
172
173        match result {
174            Ok(Ok(status)) => ExecOutput {
175                stdout: full_stdout,
176                stderr: full_stderr,
177                exit_code: status.code().unwrap_or(-1),
178                timed_out: false,
179            },
180            _ => {
181                let _ = child.kill().await;
182                ExecOutput {
183                    stdout: full_stdout,
184                    stderr: full_stderr,
185                    exit_code: -1,
186                    timed_out: true,
187                }
188            }
189        }
190    });
191
192    Ok((rx, handle))
193}
194
195fn shell_args(shell: &Shell, command: &str) -> (String, Vec<String>) {
196    match shell {
197        Shell::Sh => ("sh".into(), vec!["-c".into(), command.into()]),
198        Shell::Bash => ("bash".into(), vec!["-c".into(), command.into()]),
199        Shell::Zsh => ("zsh".into(), vec!["-c".into(), command.into()]),
200        Shell::PowerShell => (
201            "pwsh".into(),
202            vec![
203                "-NoProfile".into(),
204                "-NonInteractive".into(),
205                "-Command".into(),
206                command.into(),
207            ],
208        ),
209        Shell::Cmd => ("cmd".into(), vec!["/C".into(), command.into()]),
210        Shell::Custom { program, args } => {
211            let mut a = args.clone();
212            a.push(command.into());
213            (program.clone(), a)
214        }
215    }
216}
217
// ─── Tests ─────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `command` via `exec_streaming`, drains every streamed line, and
    /// returns (stdout lines, stderr lines, final output).
    async fn run_streaming(
        command: &str,
        opts: ExecOptions,
    ) -> (Vec<String>, Vec<String>, ExecOutput) {
        let (mut rx, handle) = exec_streaming(command, opts).unwrap();
        let mut out_lines = Vec::new();
        let mut err_lines = Vec::new();
        while let Some(line) = rx.recv().await {
            match line {
                OutputLine::Stdout(s) => out_lines.push(s),
                OutputLine::Stderr(s) => err_lines.push(s),
            }
        }
        (out_lines, err_lines, handle.await.unwrap())
    }

    #[tokio::test]
    async fn test_exec_echo() {
        let out = exec("echo hello", ExecOptions::default()).await.unwrap();
        assert_eq!(out.exit_code, 0);
        assert_eq!(out.stdout.trim(), "hello");
        assert!(!out.timed_out);
    }

    #[tokio::test]
    async fn test_exec_exit_code() {
        let out = exec("exit 42", ExecOptions::default()).await.unwrap();
        assert_eq!(out.exit_code, 42);
    }

    #[tokio::test]
    async fn test_exec_with_cwd() {
        let dir = std::env::temp_dir();
        let out = exec(
            "pwd",
            ExecOptions {
                cwd: Some(dir.clone()),
                ..Default::default()
            },
        )
        .await
        .unwrap();
        // Compare canonical paths: the temp dir may be reached through a
        // symlink (e.g. /tmp -> /private/tmp on macOS).
        let reported = std::fs::canonicalize(out.stdout.trim()).unwrap();
        assert_eq!(reported, std::fs::canonicalize(&dir).unwrap());
    }

    #[tokio::test]
    async fn test_exec_with_env() {
        let mut env = HashMap::new();
        env.insert("MY_VAR".into(), "hello_world".into());
        let out = exec(
            "echo $MY_VAR",
            ExecOptions {
                env,
                ..Default::default()
            },
        )
        .await
        .unwrap();
        assert!(out.stdout.contains("hello_world"));
    }

    #[tokio::test]
    async fn test_exec_timeout() {
        let out = exec(
            "sleep 10",
            ExecOptions {
                timeout: Some(Duration::from_millis(100)),
                ..Default::default()
            },
        )
        .await
        .unwrap();
        assert!(out.timed_out);
        assert_eq!(out.exit_code, -1);
    }

    #[tokio::test]
    async fn test_exec_streaming_lines() {
        let (out_lines, err_lines, out) =
            run_streaming("echo one; echo two >&2; echo three", ExecOptions::default()).await;
        assert_eq!(out_lines, vec!["one", "three"]);
        assert_eq!(err_lines, vec!["two"]);
        assert_eq!(out.stdout, "one\nthree\n");
        assert_eq!(out.stderr, "two\n");
        assert_eq!(out.exit_code, 0);
        assert!(!out.timed_out);
    }

    #[tokio::test]
    async fn test_exec_streaming_exit_code() {
        let (_, _, out) = run_streaming("exit 7", ExecOptions::default()).await;
        assert_eq!(out.exit_code, 7);
        assert!(!out.timed_out);
    }

    #[tokio::test]
    async fn test_exec_streaming_timeout() {
        let (_, _, out) = run_streaming(
            "sleep 5",
            ExecOptions {
                timeout: Some(Duration::from_millis(100)),
                ..Default::default()
            },
        )
        .await;
        assert!(out.timed_out);
        assert_eq!(out.exit_code, -1);
    }
}