rhai_process/
pipeline_executor.rs

1use crate::command_spec::CommandSpec;
2use crate::config::Config;
3use crate::util::{map_io_err, normalize_exit_codes, runtime_error};
4use crate::{RhaiArray, RhaiResult};
5use duct::{self, Expression};
6use os_pipe::PipeReader;
7use rhai::{Dynamic, FnPtr, ImmutableString, Map as RhaiMap, NativeCallContext, INT};
8use std::collections::HashSet;
9use std::io::{self, ErrorKind, Read, Write};
10use std::path::PathBuf;
11use std::sync::mpsc::{self, RecvTimeoutError, Sender};
12use std::sync::Arc;
13use std::thread;
14use std::time::{Duration, Instant};
15
16#[derive(Clone, Debug)]
17pub struct PipelineExecutor {
18    pub(crate) config: Arc<Config>,
19    pub(crate) commands: Vec<CommandSpec>,
20    pub(crate) timeout_override_ms: Option<u64>,
21    pub(crate) allowed_exit_codes: Option<HashSet<i64>>,
22    pub(crate) cwd: Option<PathBuf>,
23}
24
25impl PipelineExecutor {
26    pub(crate) fn new(config: Arc<Config>, commands: Vec<CommandSpec>) -> Self {
27        Self {
28            config,
29            commands,
30            timeout_override_ms: None,
31            allowed_exit_codes: None,
32            cwd: None,
33        }
34    }
35
36    pub fn cwd(mut self, path: String) -> RhaiResult<Self> {
37        if path.is_empty() {
38            self.cwd = None;
39        } else {
40            self.cwd = Some(PathBuf::from(path));
41        }
42        Ok(self)
43    }
44
45    pub fn timeout(mut self, timeout: INT) -> RhaiResult<Self> {
46        if timeout <= 0 {
47            return Err(runtime_error("timeout must be a positive integer"));
48        }
49        self.timeout_override_ms = Some(timeout as u64);
50        Ok(self)
51    }
52
53    pub fn allow_exit_codes(mut self, codes: RhaiArray) -> RhaiResult<Self> {
54        let mut set = HashSet::new();
55        for code in codes {
56            let value = code
57                .clone()
58                .try_cast::<INT>()
59                .ok_or_else(|| runtime_error("allow_exit_codes expects integers"))?;
60            set.insert(value);
61        }
62        self.allowed_exit_codes = normalize_exit_codes(set);
63        Ok(self)
64    }
65
66    pub fn run(self) -> RhaiResult<RhaiMap> {
67        let timeout = self.timeout_override_ms.or(self.config.default_timeout_ms);
68        let result = run_pipeline(
69            &self.commands,
70            timeout,
71            self.allowed_exit_codes.clone(),
72            self.cwd,
73        )?;
74        Ok(result.into_map())
75    }
76
77    pub fn run_stream(
78        self,
79        context: &NativeCallContext,
80        stdout_cb: Option<FnPtr>,
81        stderr_cb: Option<FnPtr>,
82    ) -> RhaiResult<RhaiMap> {
83        let timeout = self.timeout_override_ms.or(self.config.default_timeout_ms);
84        let result = run_pipeline_stream(
85            &self.commands,
86            timeout,
87            self.allowed_exit_codes.clone(),
88            self.cwd,
89            context,
90            stdout_cb,
91            stderr_cb,
92        )?;
93        Ok(result.into_map())
94    }
95}
96
/// Outcome of one pipeline run, later converted into a Rhai map for scripts.
#[derive(Debug)]
struct ProcessResult {
    /// True for a successful exit status or an exit code the caller allowed.
    success: bool,
    /// Exit code reported for the pipeline, or `-1` when no code was available.
    status: i64,
    /// Captured standard output, lossily decoded as UTF-8 (empty for streaming runs).
    stdout: String,
    /// Captured standard error, lossily decoded as UTF-8 (empty for streaming runs).
    stderr: String,
    /// Elapsed wall-clock time in milliseconds, saturated at `u64::MAX`.
    duration_ms: u64,
}
105
106impl ProcessResult {
107    fn into_map(self) -> RhaiMap {
108        let mut map = RhaiMap::new();
109        map.insert("success".into(), Dynamic::from_bool(self.success));
110        map.insert("status".into(), Dynamic::from_int(self.status as INT));
111        map.insert("stdout".into(), Dynamic::from(self.stdout));
112        map.insert("stderr".into(), Dynamic::from(self.stderr));
113        let duration_int: INT = self.duration_ms.try_into().unwrap_or(i64::MAX);
114        map.insert("duration_ms".into(), Dynamic::from_int(duration_int));
115        map
116    }
117}
118
119fn run_pipeline(
120    commands: &[CommandSpec],
121    timeout_ms: Option<u64>,
122    allowed_exit_codes: Option<HashSet<i64>>,
123    cwd: Option<PathBuf>,
124) -> RhaiResult<ProcessResult> {
125    if commands.is_empty() {
126        return Err(runtime_error("no command specified"));
127    }
128    let mut expression = build_expression(commands, cwd.as_ref())?;
129    expression = expression.stdout_capture().stderr_capture().unchecked();
130    let start = Instant::now();
131    let output = match timeout_ms {
132        Some(ms) => run_with_timeout(expression, Duration::from_millis(ms)).map_err(map_io_err)?,
133        None => expression.run().map_err(map_io_err)?,
134    };
135    let duration = start.elapsed();
136    let exit_code = output.status.code().map(|c| c as i64).unwrap_or(-1);
137    let mut success = output.status.success();
138    if !success {
139        if let Some(allowed) = allowed_exit_codes.as_ref() {
140            if allowed.contains(&exit_code) {
141                success = true;
142            }
143        }
144    }
145
146    Ok(ProcessResult {
147        success,
148        status: exit_code,
149        stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
150        stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
151        duration_ms: duration.as_millis().try_into().unwrap_or(u64::MAX),
152    })
153}
154
155fn run_pipeline_stream(
156    commands: &[CommandSpec],
157    timeout_ms: Option<u64>,
158    allowed_exit_codes: Option<HashSet<i64>>,
159    cwd: Option<PathBuf>,
160    context: &NativeCallContext,
161    stdout_cb: Option<FnPtr>,
162    stderr_cb: Option<FnPtr>,
163) -> RhaiResult<ProcessResult> {
164    if commands.is_empty() {
165        return Err(runtime_error("no command specified"));
166    }
167
168    let mut expression = build_expression(commands, cwd.as_ref())?;
169    let (stdout_reader, stdout_writer) = os_pipe::pipe().map_err(map_io_err)?;
170    let (stderr_reader, stderr_writer) = os_pipe::pipe().map_err(map_io_err)?;
171    expression = expression
172        .stdout_file(stdout_writer)
173        .stderr_file(stderr_writer)
174        .unchecked();
175
176    let handle = expression.start().map_err(map_io_err)?;
177    drop(expression);
178    let start = Instant::now();
179    let (tx, rx) = mpsc::channel();
180    spawn_stream_reader(stdout_reader, tx.clone(), StreamKind::Stdout);
181    spawn_stream_reader(stderr_reader, tx, StreamKind::Stderr);
182
183    let mut stdout_open = true;
184    let mut stderr_open = true;
185    let mut process_finished = false;
186
187    while stdout_open || stderr_open {
188        if let Some(limit) = timeout_ms {
189            if start.elapsed() >= Duration::from_millis(limit) {
190                handle.kill().ok();
191                return Err(map_io_err(io::Error::new(
192                    ErrorKind::TimedOut,
193                    "process execution timed out",
194                )));
195            }
196        }
197
198        match rx.recv_timeout(Duration::from_millis(50)) {
199            Ok(StreamMessage::Data(kind, chunk)) => {
200                dispatch_stream_chunk(
201                    kind,
202                    &chunk,
203                    context,
204                    stdout_cb.as_ref(),
205                    stderr_cb.as_ref(),
206                )?;
207            }
208            Ok(StreamMessage::Eof(kind)) => match kind {
209                StreamKind::Stdout => stdout_open = false,
210                StreamKind::Stderr => stderr_open = false,
211            },
212            Ok(StreamMessage::Error(err)) => {
213                handle.kill().ok();
214                return Err(map_io_err(err));
215            }
216            Err(RecvTimeoutError::Timeout) => {
217                if !process_finished && handle.try_wait().map_err(map_io_err)?.is_some() {
218                    process_finished = true;
219                }
220                continue;
221            }
222            Err(RecvTimeoutError::Disconnected) => break,
223        }
224    }
225
226    let duration = start.elapsed();
227    let output = handle.wait().map_err(map_io_err)?;
228    let exit_code = output.status.code().map(|c| c as i64).unwrap_or(-1);
229    let mut success = output.status.success();
230    if !success {
231        if let Some(allowed) = allowed_exit_codes.as_ref() {
232            if allowed.contains(&exit_code) {
233                success = true;
234            }
235        }
236    }
237
238    Ok(ProcessResult {
239        success,
240        status: exit_code,
241        stdout: String::new(),
242        stderr: String::new(),
243        duration_ms: duration.as_millis().try_into().unwrap_or(u64::MAX),
244    })
245}
246
247fn build_expression(commands: &[CommandSpec], cwd: Option<&PathBuf>) -> RhaiResult<Expression> {
248    let mut iter = commands.iter();
249    let first = iter
250        .next()
251        .ok_or_else(|| runtime_error("no command specified"))?;
252    let mut expression = expression_from_spec(first, cwd);
253    for command in iter {
254        let next_expr = expression_from_spec(command, cwd);
255        expression = expression.pipe(next_expr);
256    }
257    Ok(expression)
258}
259
260fn expression_from_spec(spec: &CommandSpec, cwd: Option<&PathBuf>) -> Expression {
261    let mut expr = duct::cmd(spec.program.clone(), spec.args.clone());
262    if let Some(dir) = cwd {
263        expr = expr.dir(dir.clone());
264    }
265    for (key, value) in &spec.env {
266        expr = expr.env(key, value);
267    }
268    expr
269}
270
271fn run_with_timeout(expr: Expression, limit: Duration) -> io::Result<std::process::Output> {
272    let handle = Arc::new(expr.start()?);
273    drop(expr);
274
275    let wait_handle = Arc::clone(&handle);
276    let (tx, rx) = mpsc::channel();
277    thread::spawn(move || {
278        let result = wait_handle
279            .wait()
280            .map(|output| std::process::Output {
281                status: output.status,
282                stdout: output.stdout.clone(),
283                stderr: output.stderr.clone(),
284            });
285        let _ = tx.send(result);
286    });
287
288    match rx.recv_timeout(limit) {
289        Ok(result) => result,
290        Err(RecvTimeoutError::Timeout) => {
291            handle.kill()?;
292            Err(io::Error::new(
293                io::ErrorKind::TimedOut,
294                "process execution timed out",
295            ))
296        }
297        Err(RecvTimeoutError::Disconnected) => Err(io::Error::new(
298            io::ErrorKind::Other,
299            "process execution failed",
300        )),
301    }
302}
303
/// Identifies which of the pipeline's output streams a message belongs to.
#[derive(Clone, Copy)]
enum StreamKind {
    /// The pipeline's standard output.
    Stdout,
    /// The pipeline's standard error.
    Stderr,
}
309
310enum StreamMessage {
311    Data(StreamKind, Vec<u8>),
312    Eof(StreamKind),
313    Error(io::Error),
314}
315
316fn spawn_stream_reader(reader: PipeReader, sender: Sender<StreamMessage>, kind: StreamKind) {
317    thread::spawn(move || {
318        let mut reader = reader;
319        let mut buffer = [0u8; 8 * 1024];
320        loop {
321            match reader.read(&mut buffer) {
322                Ok(0) => {
323                    let _ = sender.send(StreamMessage::Eof(kind));
324                    break;
325                }
326                Ok(n) => {
327                    if sender
328                        .send(StreamMessage::Data(kind, buffer[..n].to_vec()))
329                        .is_err()
330                    {
331                        break;
332                    }
333                }
334                Err(ref err) if err.kind() == ErrorKind::Interrupted => continue,
335                Err(err) => {
336                    let _ = sender.send(StreamMessage::Error(err));
337                    break;
338                }
339            }
340        }
341    });
342}
343
344fn dispatch_stream_chunk(
345    kind: StreamKind,
346    chunk: &[u8],
347    context: &NativeCallContext,
348    stdout_cb: Option<&FnPtr>,
349    stderr_cb: Option<&FnPtr>,
350) -> RhaiResult<()> {
351    let text = String::from_utf8_lossy(chunk).to_string();
352    let value: ImmutableString = text.clone().into();
353
354    let target = match kind {
355        StreamKind::Stdout => stdout_cb,
356        StreamKind::Stderr => stderr_cb,
357    };
358
359    if let Some(callback) = target {
360        let _ = callback.call_within_context::<Dynamic>(context, (value,))?;
361    } else {
362        match kind {
363            StreamKind::Stdout => {
364                print!("{}", text);
365                let _ = io::stdout().flush();
366            }
367            StreamKind::Stderr => {
368                eprint!("{}", text);
369                let _ = io::stderr().flush();
370            }
371        }
372    }
373
374    Ok(())
375}