1use crate::command_spec::CommandSpec;
2use crate::config::Config;
3use crate::util::{map_io_err, normalize_exit_codes, runtime_error};
4use crate::{RhaiArray, RhaiResult};
5use duct::{self, Expression};
6use os_pipe::PipeReader;
7use rhai::{Dynamic, FnPtr, ImmutableString, Map as RhaiMap, NativeCallContext, INT};
8use std::collections::HashSet;
9use std::io::{self, ErrorKind, Read, Write};
10use std::path::PathBuf;
11use std::sync::mpsc::{self, RecvTimeoutError, Sender};
12use std::sync::Arc;
13use std::thread;
14use std::time::{Duration, Instant};
15
/// Configurable command pipeline whose builder methods return `RhaiResult`
/// (presumably registered for use from Rhai scripts — confirm at registration site).
///
/// Settings are applied with the consuming methods `cwd`, `timeout` and
/// `allow_exit_codes`; the pipeline is executed with `run` or `run_stream`.
#[derive(Clone, Debug)]
pub struct PipelineExecutor {
    // Shared configuration; its `default_timeout_ms` applies when no override is set.
    pub(crate) config: Arc<Config>,
    // Commands chained in order with duct's `pipe`.
    pub(crate) commands: Vec<CommandSpec>,
    // Per-pipeline timeout in milliseconds, set by `timeout`.
    pub(crate) timeout_override_ms: Option<u64>,
    // Extra exit codes treated as success (result of `normalize_exit_codes`).
    pub(crate) allowed_exit_codes: Option<HashSet<i64>>,
    // Working directory for every command; `None` leaves duct's default (inherited).
    pub(crate) cwd: Option<PathBuf>,
}
24
25impl PipelineExecutor {
26 pub(crate) fn new(config: Arc<Config>, commands: Vec<CommandSpec>) -> Self {
27 Self {
28 config,
29 commands,
30 timeout_override_ms: None,
31 allowed_exit_codes: None,
32 cwd: None,
33 }
34 }
35
36 pub fn cwd(mut self, path: String) -> RhaiResult<Self> {
37 if path.is_empty() {
38 self.cwd = None;
39 } else {
40 self.cwd = Some(PathBuf::from(path));
41 }
42 Ok(self)
43 }
44
45 pub fn timeout(mut self, timeout: INT) -> RhaiResult<Self> {
46 if timeout <= 0 {
47 return Err(runtime_error("timeout must be a positive integer"));
48 }
49 self.timeout_override_ms = Some(timeout as u64);
50 Ok(self)
51 }
52
53 pub fn allow_exit_codes(mut self, codes: RhaiArray) -> RhaiResult<Self> {
54 let mut set = HashSet::new();
55 for code in codes {
56 let value = code
57 .clone()
58 .try_cast::<INT>()
59 .ok_or_else(|| runtime_error("allow_exit_codes expects integers"))?;
60 set.insert(value);
61 }
62 self.allowed_exit_codes = normalize_exit_codes(set);
63 Ok(self)
64 }
65
66 pub fn run(self) -> RhaiResult<RhaiMap> {
67 let timeout = self.timeout_override_ms.or(self.config.default_timeout_ms);
68 let result = run_pipeline(
69 &self.commands,
70 timeout,
71 self.allowed_exit_codes.clone(),
72 self.cwd,
73 )?;
74 Ok(result.into_map())
75 }
76
77 pub fn run_stream(
78 self,
79 context: &NativeCallContext,
80 stdout_cb: Option<FnPtr>,
81 stderr_cb: Option<FnPtr>,
82 ) -> RhaiResult<RhaiMap> {
83 let timeout = self.timeout_override_ms.or(self.config.default_timeout_ms);
84 let result = run_pipeline_stream(
85 &self.commands,
86 timeout,
87 self.allowed_exit_codes.clone(),
88 self.cwd,
89 context,
90 stdout_cb,
91 stderr_cb,
92 )?;
93 Ok(result.into_map())
94 }
95}
96
/// Outcome of a pipeline run; converted to a Rhai map by `into_map`.
#[derive(Debug)]
struct ProcessResult {
    // True when the exit status is successful or the exit code is in the allowed set.
    success: bool,
    // Exit code, or -1 when the status carries no code (e.g. killed by a signal on Unix).
    status: i64,
    // Captured stdout, lossily decoded as UTF-8; empty for streaming runs.
    stdout: String,
    // Captured stderr, lossily decoded as UTF-8; empty for streaming runs.
    stderr: String,
    // Elapsed wall-clock time measured by the runner, in milliseconds.
    duration_ms: u64,
}
105
106impl ProcessResult {
107 fn into_map(self) -> RhaiMap {
108 let mut map = RhaiMap::new();
109 map.insert("success".into(), Dynamic::from_bool(self.success));
110 map.insert("status".into(), Dynamic::from_int(self.status as INT));
111 map.insert("stdout".into(), Dynamic::from(self.stdout));
112 map.insert("stderr".into(), Dynamic::from(self.stderr));
113 let duration_int: INT = self.duration_ms.try_into().unwrap_or(i64::MAX);
114 map.insert("duration_ms".into(), Dynamic::from_int(duration_int));
115 map
116 }
117}
118
119fn run_pipeline(
120 commands: &[CommandSpec],
121 timeout_ms: Option<u64>,
122 allowed_exit_codes: Option<HashSet<i64>>,
123 cwd: Option<PathBuf>,
124) -> RhaiResult<ProcessResult> {
125 if commands.is_empty() {
126 return Err(runtime_error("no command specified"));
127 }
128 let mut expression = build_expression(commands, cwd.as_ref())?;
129 expression = expression.stdout_capture().stderr_capture().unchecked();
130 let start = Instant::now();
131 let output = match timeout_ms {
132 Some(ms) => run_with_timeout(expression, Duration::from_millis(ms)).map_err(map_io_err)?,
133 None => expression.run().map_err(map_io_err)?,
134 };
135 let duration = start.elapsed();
136 let exit_code = output.status.code().map(|c| c as i64).unwrap_or(-1);
137 let mut success = output.status.success();
138 if !success {
139 if let Some(allowed) = allowed_exit_codes.as_ref() {
140 if allowed.contains(&exit_code) {
141 success = true;
142 }
143 }
144 }
145
146 Ok(ProcessResult {
147 success,
148 status: exit_code,
149 stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
150 stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
151 duration_ms: duration.as_millis().try_into().unwrap_or(u64::MAX),
152 })
153}
154
155fn run_pipeline_stream(
156 commands: &[CommandSpec],
157 timeout_ms: Option<u64>,
158 allowed_exit_codes: Option<HashSet<i64>>,
159 cwd: Option<PathBuf>,
160 context: &NativeCallContext,
161 stdout_cb: Option<FnPtr>,
162 stderr_cb: Option<FnPtr>,
163) -> RhaiResult<ProcessResult> {
164 if commands.is_empty() {
165 return Err(runtime_error("no command specified"));
166 }
167
168 let mut expression = build_expression(commands, cwd.as_ref())?;
169 let (stdout_reader, stdout_writer) = os_pipe::pipe().map_err(map_io_err)?;
170 let (stderr_reader, stderr_writer) = os_pipe::pipe().map_err(map_io_err)?;
171 expression = expression
172 .stdout_file(stdout_writer)
173 .stderr_file(stderr_writer)
174 .unchecked();
175
176 let handle = expression.start().map_err(map_io_err)?;
177 drop(expression);
178 let start = Instant::now();
179 let (tx, rx) = mpsc::channel();
180 spawn_stream_reader(stdout_reader, tx.clone(), StreamKind::Stdout);
181 spawn_stream_reader(stderr_reader, tx, StreamKind::Stderr);
182
183 let mut stdout_open = true;
184 let mut stderr_open = true;
185 let mut process_finished = false;
186
187 while stdout_open || stderr_open {
188 if let Some(limit) = timeout_ms {
189 if start.elapsed() >= Duration::from_millis(limit) {
190 handle.kill().ok();
191 return Err(map_io_err(io::Error::new(
192 ErrorKind::TimedOut,
193 "process execution timed out",
194 )));
195 }
196 }
197
198 match rx.recv_timeout(Duration::from_millis(50)) {
199 Ok(StreamMessage::Data(kind, chunk)) => {
200 dispatch_stream_chunk(
201 kind,
202 &chunk,
203 context,
204 stdout_cb.as_ref(),
205 stderr_cb.as_ref(),
206 )?;
207 }
208 Ok(StreamMessage::Eof(kind)) => match kind {
209 StreamKind::Stdout => stdout_open = false,
210 StreamKind::Stderr => stderr_open = false,
211 },
212 Ok(StreamMessage::Error(err)) => {
213 handle.kill().ok();
214 return Err(map_io_err(err));
215 }
216 Err(RecvTimeoutError::Timeout) => {
217 if !process_finished && handle.try_wait().map_err(map_io_err)?.is_some() {
218 process_finished = true;
219 }
220 continue;
221 }
222 Err(RecvTimeoutError::Disconnected) => break,
223 }
224 }
225
226 let duration = start.elapsed();
227 let output = handle.wait().map_err(map_io_err)?;
228 let exit_code = output.status.code().map(|c| c as i64).unwrap_or(-1);
229 let mut success = output.status.success();
230 if !success {
231 if let Some(allowed) = allowed_exit_codes.as_ref() {
232 if allowed.contains(&exit_code) {
233 success = true;
234 }
235 }
236 }
237
238 Ok(ProcessResult {
239 success,
240 status: exit_code,
241 stdout: String::new(),
242 stderr: String::new(),
243 duration_ms: duration.as_millis().try_into().unwrap_or(u64::MAX),
244 })
245}
246
247fn build_expression(commands: &[CommandSpec], cwd: Option<&PathBuf>) -> RhaiResult<Expression> {
248 let mut iter = commands.iter();
249 let first = iter
250 .next()
251 .ok_or_else(|| runtime_error("no command specified"))?;
252 let mut expression = expression_from_spec(first, cwd);
253 for command in iter {
254 let next_expr = expression_from_spec(command, cwd);
255 expression = expression.pipe(next_expr);
256 }
257 Ok(expression)
258}
259
260fn expression_from_spec(spec: &CommandSpec, cwd: Option<&PathBuf>) -> Expression {
261 let mut expr = duct::cmd(spec.program.clone(), spec.args.clone());
262 if let Some(dir) = cwd {
263 expr = expr.dir(dir.clone());
264 }
265 for (key, value) in &spec.env {
266 expr = expr.env(key, value);
267 }
268 expr
269}
270
271fn run_with_timeout(expr: Expression, limit: Duration) -> io::Result<std::process::Output> {
272 let handle = Arc::new(expr.start()?);
273 drop(expr);
274
275 let wait_handle = Arc::clone(&handle);
276 let (tx, rx) = mpsc::channel();
277 thread::spawn(move || {
278 let result = wait_handle
279 .wait()
280 .map(|output| std::process::Output {
281 status: output.status,
282 stdout: output.stdout.clone(),
283 stderr: output.stderr.clone(),
284 });
285 let _ = tx.send(result);
286 });
287
288 match rx.recv_timeout(limit) {
289 Ok(result) => result,
290 Err(RecvTimeoutError::Timeout) => {
291 handle.kill()?;
292 Err(io::Error::new(
293 io::ErrorKind::TimedOut,
294 "process execution timed out",
295 ))
296 }
297 Err(RecvTimeoutError::Disconnected) => Err(io::Error::new(
298 io::ErrorKind::Other,
299 "process execution failed",
300 )),
301 }
302}
303
/// Identifies which output stream of the pipeline a message or chunk belongs to.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum StreamKind {
    /// The pipeline's standard output.
    Stdout,
    /// The pipeline's standard error.
    Stderr,
}
309
/// Message sent from a pipe reader thread (`spawn_stream_reader`) to the
/// streaming loop.
enum StreamMessage {
    // Raw bytes from one read of the given stream (at most 8 KiB per read).
    Data(StreamKind, Vec<u8>),
    // The given stream reached end of file.
    Eof(StreamKind),
    // A read failed with an error other than `Interrupted`.
    Error(io::Error),
}
315
316fn spawn_stream_reader(reader: PipeReader, sender: Sender<StreamMessage>, kind: StreamKind) {
317 thread::spawn(move || {
318 let mut reader = reader;
319 let mut buffer = [0u8; 8 * 1024];
320 loop {
321 match reader.read(&mut buffer) {
322 Ok(0) => {
323 let _ = sender.send(StreamMessage::Eof(kind));
324 break;
325 }
326 Ok(n) => {
327 if sender
328 .send(StreamMessage::Data(kind, buffer[..n].to_vec()))
329 .is_err()
330 {
331 break;
332 }
333 }
334 Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
335 Err(err) => {
336 let _ = sender.send(StreamMessage::Error(err));
337 break;
338 }
339 }
340 }
341 });
342}
343
344fn dispatch_stream_chunk(
345 kind: StreamKind,
346 chunk: &[u8],
347 context: &NativeCallContext,
348 stdout_cb: Option<&FnPtr>,
349 stderr_cb: Option<&FnPtr>,
350) -> RhaiResult<()> {
351 let text = String::from_utf8_lossy(chunk).to_string();
352 let value: ImmutableString = text.clone().into();
353
354 let target = match kind {
355 StreamKind::Stdout => stdout_cb,
356 StreamKind::Stderr => stderr_cb,
357 };
358
359 if let Some(callback) = target {
360 let _ = callback.call_within_context::<Dynamic>(context, (value,))?;
361 } else {
362 match kind {
363 StreamKind::Stdout => {
364 print!("{}", text);
365 let _ = io::stdout().flush();
366 }
367 StreamKind::Stderr => {
368 eprint!("{}", text);
369 let _ = io::stderr().flush();
370 }
371 }
372 }
373
374 Ok(())
375}