1use std::collections::HashMap;
46use std::path::{Path, PathBuf};
47use std::process::Stdio;
48use std::time::Duration;
49
50use async_trait::async_trait;
51use serde::{Deserialize, Serialize};
52use tokio::io::AsyncWriteExt;
53use tokio::process::Command;
54
55use super::handler::{JobCtx, JobHandler, JobOutcome};
56
/// Job kind identifier reported by `CmdExecHandler::kind`.
pub const CMD_EXEC_KIND: &str = "cmd_exec";

/// Timeout, in seconds, applied when the payload does not set `timeout_secs`.
const DEFAULT_TIMEOUT_SECS: u64 = 600;

/// Byte budget handed to `tail_chars` when quoting child output in a
/// non-zero-exit failure message.
const OUTPUT_TAIL_BYTES: usize = 256;
67
/// JSON payload accepted by the `cmd_exec` job kind.
///
/// Only `argv` is required; every other field carries `#[serde(default)]`
/// and may be omitted.
#[derive(Debug, Deserialize, Serialize)]
pub struct CmdExecPayload {
    /// Program to run (first element) followed by its arguments.
    /// An empty list makes the job fail with "payload.argv is empty".
    pub argv: Vec<String>,
    /// Environment variables handed to the child through `Command::envs`.
    #[serde(default)]
    pub env: HashMap<String, String>,
    /// Working directory for the child. When present it must pass
    /// `CmdExecHandler::validate_cwd`; when absent no working directory is
    /// set on the command.
    #[serde(default)]
    pub cwd: Option<String>,
    /// Text written to the child's stdin. When absent the child's stdin is
    /// `Stdio::null()`.
    #[serde(default)]
    pub stdin: Option<String>,
    /// Timeout in seconds. Falls back to `DEFAULT_TIMEOUT_SECS` when absent
    /// and is raised to at least 1.
    #[serde(default)]
    pub timeout_secs: Option<u64>,
}
80
/// Handler for `cmd_exec` jobs.
///
/// Holds the optional directory that a payload's `cwd` must stay inside.
#[derive(Debug)]
pub struct CmdExecHandler {
    /// When `Some`, every payload `cwd` must canonicalize to a path under
    /// this directory. When `None`, any existing directory is accepted.
    cwd_root: Option<PathBuf>,
}

impl CmdExecHandler {
    /// Creates a handler that places no restriction on the payload `cwd`.
    #[must_use]
    pub const fn new_unrestricted() -> Self {
        Self { cwd_root: None }
    }

    /// Creates a handler that only accepts a payload `cwd` located under
    /// `root` (compared after canonicalizing both paths).
    #[must_use]
    pub fn with_cwd_root(root: impl Into<PathBuf>) -> Self {
        Self {
            cwd_root: Some(root.into()),
        }
    }

    /// Canonicalizes `cwd` and, if a root is configured, checks containment.
    ///
    /// Returns the canonical path on success.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when `cwd` cannot be canonicalized,
    /// when the configured root cannot be canonicalized, or when the
    /// canonical `cwd` does not start with the canonical root.
    fn validate_cwd(&self, cwd: &str) -> Result<PathBuf, String> {
        let resolved = Path::new(cwd)
            .canonicalize()
            .map_err(|e| format!("cwd canonicalize {cwd}: {e}"))?;

        // No configured root: any directory that resolves is acceptable.
        let Some(root) = &self.cwd_root else {
            return Ok(resolved);
        };

        let root_resolved = root
            .canonicalize()
            .map_err(|e| format!("cwd_root canonicalize {}: {e}", root.display()))?;

        // Both sides are canonical, so a component-wise prefix test is not
        // fooled by `..` segments or symlinks in the requested path.
        if resolved.starts_with(&root_resolved) {
            Ok(resolved)
        } else {
            Err(format!(
                "cwd {} escapes configured cwd_root {}",
                resolved.display(),
                root_resolved.display()
            ))
        }
    }
}
133
134#[async_trait]
135impl JobHandler for CmdExecHandler {
136 fn kind(&self) -> &'static str {
137 CMD_EXEC_KIND
138 }
139
140 #[tracing::instrument(skip(self, ctx, payload), fields(kind = CMD_EXEC_KIND))]
141 async fn run(&self, ctx: JobCtx<'_>, payload: serde_json::Value) -> JobOutcome {
142 execute(payload, ctx.cancel.clone(), ctx.job_id.as_str(), self).await
143 }
144}
145
146async fn execute(
152 payload: serde_json::Value,
153 cancel: tokio_util::sync::CancellationToken,
154 job_id_label: &str,
155 handler: &CmdExecHandler,
156) -> JobOutcome {
157 let parsed: CmdExecPayload = match serde_json::from_value(payload) {
158 Ok(p) => p,
159 Err(e) => return JobOutcome::Failed(format!("payload: {e}")),
160 };
161 let Some((program, args)) = parsed.argv.split_first() else {
162 return JobOutcome::Failed("payload.argv is empty".into());
163 };
164
165 let mut cmd = Command::new(program);
166 cmd.args(args)
167 .envs(&parsed.env)
168 .stdout(Stdio::piped())
169 .stderr(Stdio::piped())
170 .stdin(if parsed.stdin.is_some() {
171 Stdio::piped()
172 } else {
173 Stdio::null()
174 });
175
176 if let Some(cwd_raw) = parsed.cwd.as_deref() {
177 match handler.validate_cwd(cwd_raw) {
178 Ok(resolved) => {
179 cmd.current_dir(resolved);
180 }
181 Err(e) => return JobOutcome::Failed(format!("cwd rejected: {e}")),
182 }
183 }
184
185 let timeout = Duration::from_secs(parsed.timeout_secs.unwrap_or(DEFAULT_TIMEOUT_SECS).max(1));
186
187 tracing::info!(
188 program,
189 arg_count = args.len(),
190 env_count = parsed.env.len(),
191 cwd = ?parsed.cwd,
192 timeout_secs = timeout.as_secs(),
193 job_id = %job_id_label,
194 "cmd_exec: spawning"
195 );
196
197 let mut child = match cmd.spawn() {
198 Ok(c) => c,
199 Err(e) => return JobOutcome::Failed(format!("spawn {program:?}: {e}")),
200 };
201
202 if let Some(stdin_data) = parsed.stdin.as_deref()
203 && let Some(mut stdin) = child.stdin.take()
204 {
205 if let Err(e) = stdin.write_all(stdin_data.as_bytes()).await {
206 tracing::warn!(?e, "cmd_exec: failed writing stdin (continuing)");
207 }
208 drop(stdin);
209 }
210
211 let wait = child.wait_with_output();
212 let output = tokio::select! {
213 () = cancel.cancelled() => {
214 tracing::info!(job_id = %job_id_label, "cmd_exec: cancelled; child orphaned");
215 return JobOutcome::Failed("cancelled by supervisor".into());
216 }
217 res = tokio::time::timeout(timeout, wait) => match res {
218 Ok(Ok(out)) => out,
219 Ok(Err(e)) => return JobOutcome::Failed(format!("wait child: {e}")),
220 Err(_) => return JobOutcome::Failed(format!("timeout after {}s", timeout.as_secs())),
221 }
222 };
223
224 let stdout = String::from_utf8_lossy(&output.stdout);
225 let stderr = String::from_utf8_lossy(&output.stderr);
226 if !stdout.is_empty() {
227 tracing::info!(stream = "stdout", body = %stdout, "cmd_exec: child output");
228 }
229 if !stderr.is_empty() {
230 tracing::info!(stream = "stderr", body = %stderr, "cmd_exec: child output");
231 }
232
233 match output.status.code() {
234 Some(0) => JobOutcome::Done,
235 Some(code) => {
236 let tail = if stderr.is_empty() { stdout } else { stderr };
237 let summary = tail_chars(&tail, OUTPUT_TAIL_BYTES);
238 JobOutcome::Failed(format!("exit {code}: {summary}"))
239 }
240 None => JobOutcome::Failed("killed by signal".into()),
241 }
242}
243
/// Returns the trailing part of `s`, prefixed with `…` when anything was cut.
///
/// Strings of at most `max_bytes` bytes are returned unchanged. Otherwise the
/// cut starts `max_bytes` bytes from the end and is moved back to the nearest
/// UTF-8 character boundary, so the kept tail may be slightly longer than
/// `max_bytes` but never splits a character.
fn tail_chars(s: &str, max_bytes: usize) -> String {
    // `None` covers both "shorter than the budget" and "exactly the budget".
    let Some(raw_cut) = s.len().checked_sub(max_bytes).filter(|&cut| cut > 0) else {
        return s.to_owned();
    };

    // Index 0 is always a boundary, so the search cannot come up empty.
    let cut = (0..=raw_cut)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);

    let kept = &s[cut..];
    let mut out = String::with_capacity('…'.len_utf8() + kept.len());
    out.push('…');
    out.push_str(kept);
    out
}
257
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    reason = "tests crash loudly on setup or assertion failure; that's the point"
)]
mod tests {
    use super::*;
    use tokio_util::sync::CancellationToken;

    /// Builds a payload that carries only `argv`.
    fn payload(argv: &[&str]) -> serde_json::Value {
        serde_json::json!({ "argv": argv })
    }

    /// Handler with no `cwd` restriction.
    fn unrestricted() -> CmdExecHandler {
        CmdExecHandler::new_unrestricted()
    }

    /// Drives `execute` with a fresh, never-cancelled token.
    async fn run(handler: &CmdExecHandler, payload: serde_json::Value) -> JobOutcome {
        let never_cancelled = CancellationToken::new();
        execute(payload, never_cancelled, "test-job", handler).await
    }

    /// Asserts that `outcome` is `Failed` and that its message contains `needle`.
    #[track_caller]
    fn expect_failed(outcome: JobOutcome, needle: &str) {
        match outcome {
            JobOutcome::Failed(msg) => assert!(msg.contains(needle), "got: {msg}"),
            other => panic!("expected Failed, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn empty_argv_rejected() {
        let request = serde_json::json!({ "argv": [] });
        expect_failed(run(&unrestricted(), request).await, "argv");
    }

    #[tokio::test]
    async fn successful_command_returns_done() {
        let out = run(&unrestricted(), payload(&["true"])).await;
        assert!(matches!(out, JobOutcome::Done), "got: {out:?}");
    }

    #[tokio::test]
    async fn non_zero_exit_returns_failed_with_code() {
        let out = run(&unrestricted(), payload(&["false"])).await;
        expect_failed(out, "exit 1");
    }

    #[tokio::test]
    async fn nonexistent_program_returns_failed() {
        let request = payload(&["this-program-does-not-exist-xyz"]);
        expect_failed(run(&unrestricted(), request).await, "spawn");
    }

    #[tokio::test]
    async fn timeout_exceeded_returns_failed() {
        let request = serde_json::json!({ "argv": ["sleep", "5"], "timeout_secs": 1 });
        expect_failed(run(&unrestricted(), request).await, "timeout");
    }

    #[tokio::test]
    async fn cwd_outside_root_rejected() {
        let handler = CmdExecHandler::with_cwd_root("/tmp");
        let request = serde_json::json!({ "argv": ["true"], "cwd": "/etc" });
        expect_failed(run(&handler, request).await, "cwd");
    }

    #[tokio::test]
    async fn cancel_orphans_child_and_returns_failed() {
        let cancel = CancellationToken::new();
        let task_token = cancel.clone();
        let handler = unrestricted();
        let task = tokio::spawn(async move {
            let request = serde_json::json!({ "argv": ["sleep", "5"] });
            execute(request, task_token, "cancel-test", &handler).await
        });
        // Give the spawned task a moment to start the child before cancelling.
        tokio::time::sleep(Duration::from_millis(50)).await;
        cancel.cancel();
        expect_failed(task.await.unwrap(), "cancelled");
    }
}