cersei_tools/tool_primitives/
process.rs1use std::collections::HashMap;
7use std::path::PathBuf;
8use std::time::Duration;
9use tokio::io::{AsyncBufReadExt, BufReader};
10use tokio::sync::mpsc;
11
/// Outcome of running a shell command through this module.
///
/// Equality is derived so results can be compared directly (e.g. in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecOutput {
    /// Text the command wrote to stdout.
    pub stdout: String,
    /// Text the command wrote to stderr (or a diagnostic message produced by this module).
    pub stderr: String,
    /// Exit status code, or `-1` when no code is available (e.g. the process was killed
    /// or terminated by a signal).
    pub exit_code: i32,
    /// `true` when the command was stopped because it exceeded its timeout.
    pub timed_out: bool,
}
20
21#[derive(Debug, Clone)]
23pub struct ExecOptions {
24 pub cwd: Option<PathBuf>,
25 pub env: HashMap<String, String>,
26 pub timeout: Option<Duration>,
27 pub shell: Shell,
28}
29
30impl Default for ExecOptions {
31 fn default() -> Self {
32 Self {
33 cwd: None,
34 env: HashMap::new(),
35 timeout: Some(Duration::from_secs(120)),
36 shell: Shell::Sh,
37 }
38 }
39}
40
/// Interpreter used to run a command string.
///
/// Equality is derived so configurations can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Shell {
    /// `sh -c <command>`
    Sh,
    /// `bash -c <command>`
    Bash,
    /// `zsh -c <command>`
    Zsh,
    /// `pwsh -NoProfile -NonInteractive -Command <command>`
    PowerShell,
    /// `cmd /C <command>`
    Cmd,
    /// `program` invoked with `args`, followed by the command as the final argument.
    Custom { program: String, args: Vec<String> },
}
51
/// One line of output forwarded by [`exec_streaming`], tagged with the stream it came from.
///
/// The line terminator is not included. Equality is derived so consumers can compare
/// received lines directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OutputLine {
    /// A line read from the child's stdout.
    Stdout(String),
    /// A line read from the child's stderr.
    Stderr(String),
}
58
59pub async fn exec(command: &str, opts: ExecOptions) -> Result<ExecOutput, std::io::Error> {
61 let (program, args) = shell_args(&opts.shell, command);
62
63 let mut cmd = tokio::process::Command::new(&program);
64 cmd.args(&args)
65 .stdout(std::process::Stdio::piped())
66 .stderr(std::process::Stdio::piped());
67
68 if let Some(cwd) = &opts.cwd {
69 cmd.current_dir(cwd);
70 }
71
72 for (k, v) in &opts.env {
73 cmd.env(k, v);
74 }
75
76 let child = cmd.spawn()?;
77
78 let timeout = opts.timeout.unwrap_or(Duration::from_secs(120));
79
80 match tokio::time::timeout(timeout, child.wait_with_output()).await {
81 Ok(Ok(output)) => Ok(ExecOutput {
82 stdout: String::from_utf8_lossy(&output.stdout).to_string(),
83 stderr: String::from_utf8_lossy(&output.stderr).to_string(),
84 exit_code: output.status.code().unwrap_or(-1),
85 timed_out: false,
86 }),
87 Ok(Err(e)) => Err(e),
88 Err(_) => {
89 Ok(ExecOutput {
91 stdout: String::new(),
92 stderr: format!("Command timed out after {}s", timeout.as_secs()),
93 exit_code: -1,
94 timed_out: true,
95 })
96 }
97 }
98}
99
100pub fn exec_streaming(
105 command: &str,
106 opts: ExecOptions,
107) -> Result<
108 (
109 mpsc::Receiver<OutputLine>,
110 tokio::task::JoinHandle<ExecOutput>,
111 ),
112 std::io::Error,
113> {
114 let (program, args) = shell_args(&opts.shell, command);
115
116 let mut cmd = tokio::process::Command::new(&program);
117 cmd.args(&args)
118 .stdout(std::process::Stdio::piped())
119 .stderr(std::process::Stdio::piped());
120
121 if let Some(cwd) = &opts.cwd {
122 cmd.current_dir(cwd);
123 }
124
125 for (k, v) in &opts.env {
126 cmd.env(k, v);
127 }
128
129 let mut child = cmd.spawn()?;
130 let (tx, rx) = mpsc::channel(256);
131
132 let stdout = child.stdout.take();
133 let stderr = child.stderr.take();
134 let timeout = opts.timeout.unwrap_or(Duration::from_secs(120));
135
136 let handle = tokio::spawn(async move {
137 let mut full_stdout = String::new();
138 let mut full_stderr = String::new();
139
140 let tx_out = tx.clone();
141 let stdout_task = tokio::spawn(async move {
142 let mut collected = String::new();
143 if let Some(stdout) = stdout {
144 let mut lines = BufReader::new(stdout).lines();
145 while let Ok(Some(line)) = lines.next_line().await {
146 collected.push_str(&line);
147 collected.push('\n');
148 let _ = tx_out.send(OutputLine::Stdout(line)).await;
149 }
150 }
151 collected
152 });
153
154 let tx_err = tx;
155 let stderr_task = tokio::spawn(async move {
156 let mut collected = String::new();
157 if let Some(stderr) = stderr {
158 let mut lines = BufReader::new(stderr).lines();
159 while let Ok(Some(line)) = lines.next_line().await {
160 collected.push_str(&line);
161 collected.push('\n');
162 let _ = tx_err.send(OutputLine::Stderr(line)).await;
163 }
164 }
165 collected
166 });
167
168 let result = tokio::time::timeout(timeout, child.wait()).await;
169
170 full_stdout = stdout_task.await.unwrap_or_default();
171 full_stderr = stderr_task.await.unwrap_or_default();
172
173 match result {
174 Ok(Ok(status)) => ExecOutput {
175 stdout: full_stdout,
176 stderr: full_stderr,
177 exit_code: status.code().unwrap_or(-1),
178 timed_out: false,
179 },
180 _ => {
181 let _ = child.kill().await;
182 ExecOutput {
183 stdout: full_stdout,
184 stderr: full_stderr,
185 exit_code: -1,
186 timed_out: true,
187 }
188 }
189 }
190 });
191
192 Ok((rx, handle))
193}
194
195fn shell_args(shell: &Shell, command: &str) -> (String, Vec<String>) {
196 match shell {
197 Shell::Sh => ("sh".into(), vec!["-c".into(), command.into()]),
198 Shell::Bash => ("bash".into(), vec!["-c".into(), command.into()]),
199 Shell::Zsh => ("zsh".into(), vec!["-c".into(), command.into()]),
200 Shell::PowerShell => (
201 "pwsh".into(),
202 vec![
203 "-NoProfile".into(),
204 "-NonInteractive".into(),
205 "-Command".into(),
206 command.into(),
207 ],
208 ),
209 Shell::Cmd => ("cmd".into(), vec!["/C".into(), command.into()]),
210 Shell::Custom { program, args } => {
211 let mut a = args.clone();
212 a.push(command.into());
213 (program.clone(), a)
214 }
215 }
216}
217
// These tests drive a POSIX `sh`, rely on `$VAR` expansion and `/tmp`, so they only run on Unix.
#[cfg(all(test, unix))]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_exec_echo() {
        let out = exec("echo hello", ExecOptions::default()).await.unwrap();
        assert_eq!(out.exit_code, 0);
        assert_eq!(out.stdout.trim(), "hello");
        assert!(!out.timed_out);
    }

    #[tokio::test]
    async fn test_exec_exit_code() {
        let out = exec("exit 42", ExecOptions::default()).await.unwrap();
        assert_eq!(out.exit_code, 42);
    }

    #[tokio::test]
    async fn test_exec_with_cwd() {
        let out = exec(
            "pwd",
            ExecOptions {
                cwd: Some("/tmp".into()),
                ..Default::default()
            },
        )
        .await
        .unwrap();
        assert!(out.stdout.contains("tmp"));
    }

    #[tokio::test]
    async fn test_exec_with_env() {
        let mut env = HashMap::new();
        env.insert("MY_VAR".into(), "hello_world".into());
        let out = exec(
            "echo $MY_VAR",
            ExecOptions {
                env,
                ..Default::default()
            },
        )
        .await
        .unwrap();
        assert!(out.stdout.contains("hello_world"));
    }

    #[tokio::test]
    async fn test_exec_timeout() {
        let out = exec(
            "sleep 10",
            ExecOptions {
                timeout: Some(Duration::from_millis(100)),
                ..Default::default()
            },
        )
        .await
        .unwrap();
        assert!(out.timed_out);
    }

    #[tokio::test]
    async fn test_exec_streaming_collects_and_forwards_lines() {
        let (mut rx, handle) =
            exec_streaming("echo out; echo err 1>&2", ExecOptions::default()).unwrap();

        // The channel closes once both reader tasks have hit EOF and dropped their senders.
        let mut lines = Vec::new();
        while let Some(line) = rx.recv().await {
            lines.push(line);
        }

        let out = handle.await.unwrap();
        assert_eq!(out.exit_code, 0);
        assert!(!out.timed_out);
        assert_eq!(out.stdout, "out\n");
        assert_eq!(out.stderr, "err\n");
        assert!(lines
            .iter()
            .any(|l| matches!(l, OutputLine::Stdout(s) if s == "out")));
        assert!(lines
            .iter()
            .any(|l| matches!(l, OutputLine::Stderr(s) if s == "err")));
    }
}