clitest_lib/
command.rs

1use std::{
2    collections::HashMap,
3    io::{BufRead, BufReader},
4    process::{Command, ExitStatus, Stdio},
5    thread,
6    time::Duration,
7};
8
9use serde::Serialize;
10use shellish_parse::ParseOptions;
11use termcolor::Color;
12
13use crate::{
14    cwrite, cwriteln,
15    output::Lines,
16    script::{ScriptKillReceiver, ScriptKillSender, ScriptLocation},
17};
18
19#[derive(Clone, Debug, Serialize)]
20#[serde(transparent)]
21pub struct CommandLine {
22    pub command: String,
23    #[serde(skip)]
24    pub location: ScriptLocation,
25    #[serde(skip)]
26    pub line_count: usize,
27}
28
29impl CommandLine {
30    pub fn new(command: String, location: ScriptLocation, line_count: usize) -> Self {
31        Self {
32            command,
33            location,
34            line_count,
35        }
36    }
37
38    pub fn run(
39        &self,
40        writer: &mut dyn termcolor::WriteColor,
41        show_line_numbers: bool,
42        runner: Option<String>,
43        timeout: Duration,
44        envs: &HashMap<String, String>,
45        kill_receiver: &ScriptKillReceiver,
46        kill_sender: &ScriptKillSender,
47    ) -> Result<(Lines, ExitStatus), std::io::Error> {
48        // This fails to exit if the command hangs....
49        thread::scope(|s| {
50            let mut command = if let Some(runner) = runner {
51                let bits = shellish_parse::parse(&runner, ParseOptions::default())
52                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e))?;
53                let mut cmd = Command::new(&bits[0]);
54                cmd.args(&bits[1..]);
55                cmd
56            } else {
57                let mut cmd = Command::new("sh");
58                cmd.arg("-c");
59                cmd
60            };
61            command.arg(&self.command);
62            command.envs(envs);
63            if let Some(pwd) = envs.get("PWD") {
64                command.current_dir(pwd);
65            }
66            #[cfg(unix)]
67            {
68                use std::os::unix::process::CommandExt;
69                command.process_group(0);
70            }
71            #[cfg(windows)]
72            {
73                use std::os::windows::process::CommandExt;
74                const CREATE_SUSPENDED: u32 = 0x00000004;
75                command.creation_flags(CREATE_SUSPENDED);
76            }
77            command.stdout(Stdio::piped());
78            command.stderr(Stdio::piped());
79            command.stdin(Stdio::null());
80            let mut output = command.spawn().map_err(|e| {
81                std::io::Error::new(
82                    e.kind(),
83                    format!("failed to spawn command {command:?}: {e}"),
84                )
85            })?;
86            let (tx, rx) = std::sync::mpsc::channel();
87
88            // Spawn a thread for stdout and stderr and collect each line we read into a buffer
89            let stdout_lines = tx.clone();
90            let stdout = output.stdout.take().unwrap();
91            let stdout = s.spawn(move || {
92                let mut reader = BufReader::new(stdout);
93                let mut line = String::new();
94                while reader.read_line(&mut line).unwrap() > 0 {
95                    if line.is_empty() {
96                        continue;
97                    }
98                    if line.ends_with('\n') {
99                        line.pop();
100                    }
101                    _ = stdout_lines.send((true, std::mem::take(&mut line)));
102                }
103            });
104
105            let stderr_lines = tx;
106            let stderr = output.stderr.take().unwrap();
107            let stderr = s.spawn(move || {
108                let mut reader = BufReader::new(stderr);
109                let mut line = String::new();
110                while reader.read_line(&mut line).unwrap() > 0 {
111                    if line.is_empty() {
112                        continue;
113                    }
114                    if line.ends_with('\n') {
115                        line.pop();
116                    }
117                    _ = stderr_lines.send((false, std::mem::take(&mut line)));
118                }
119            });
120
121            let runner = s.spawn(move || kill_receiver.run_cmd(output));
122
123            let mut line_number = 1;
124            let mut output_lines = vec![];
125
126            while let Ok((is_stdout, line)) = rx.recv_timeout(timeout) {
127                if show_line_numbers {
128                    cwrite!(
129                        writer,
130                        fg = Color::White,
131                        dimmed = true,
132                        "{line_number:>3} "
133                    );
134                }
135                if is_stdout {
136                    cwriteln!(writer, fg = Color::White, "{line}");
137                } else {
138                    cwriteln!(writer, fg = Color::Yellow, "{line}");
139                }
140
141                output_lines.push(line);
142                line_number += 1;
143            }
144
145            let join_start = std::time::Instant::now();
146            let mut handles = vec![stdout, stderr];
147            while !handles.is_empty() {
148                if join_start.elapsed() > timeout {
149                    cwriteln!(writer, fg = Color::Yellow, "Process took too long!");
150                    kill_sender.kill();
151
152                    return Err(std::io::Error::new(
153                        std::io::ErrorKind::TimedOut,
154                        "process took too long to join",
155                    ));
156                }
157
158                let mut new_handles = vec![];
159                for handle in handles.drain(..) {
160                    if handle.is_finished() {
161                        handle
162                            .join()
163                            .map_err(|_| std::io::Error::other("thread panicked"))?;
164                    } else {
165                        new_handles.push(handle);
166                    }
167                }
168                handles = new_handles;
169                std::thread::sleep(std::time::Duration::from_millis(10));
170            }
171
172            Ok((Lines::new(output_lines), runner.join().unwrap()?))
173        })
174    }
175}