#![allow(unused_imports)]
use std::{
io::{Read, Seek, Write},
process::{Child, Command, ExitStatus, Stdio},
sync::{
Arc, Condvar, Mutex, MutexGuard,
mpsc::{
Receiver, RecvTimeoutError, Sender, SyncSender, TryRecvError, channel, sync_channel,
},
},
thread::{JoinHandle, sleep, spawn},
time::{Duration, Instant},
};
#[cfg(unix)]
use std::os::unix::process::CommandExt;
use tempfile::tempfile;
use crate::logger;
/// Bytes in one kibibyte.
const KIB: usize = 1 << 10;
/// Bytes in one mebibyte.
const MIB: usize = KIB << 10;
/// Upper bound on captured stderr; exceeding it is reported as an error.
const MAX_OUTPUT_BYTES: usize = MIB * 10;
/// Configuration for running a command under a time limit.
///
/// Holds the bytes to feed on stdin and a shared [`RealtimeLines`] handle
/// that receives the child's stdout. Every child spawned from the same value
/// writes into the same line buffer.
pub struct TimeoutCommand {
    max_duration: Duration,
    stdin_data: Vec<u8>,
    stdout: RealtimeLines,
}

impl TimeoutCommand {
    /// Create a configuration whose children may run for at most `max_duration`.
    ///
    /// The stdout line buffer is created here, so its read budget starts now
    /// rather than at spawn time.
    pub fn new(max_duration: Duration) -> Self {
        let stdout = RealtimeLines::new(max_duration);
        Self {
            max_duration,
            stdin_data: Vec::new(),
            stdout,
        }
    }

    /// Set the bytes written to the child's stdin, replacing earlier data.
    pub fn feed_stdin(&mut self, data: &[u8]) {
        self.stdin_data.clear();
        self.stdin_data.extend_from_slice(data);
    }

    /// Return a handle to the shared stdout line buffer.
    ///
    /// Handles share their state, so this may be taken before or after `spawn`.
    pub fn stdout(&self) -> RealtimeLines {
        RealtimeLines::clone(&self.stdout)
    }

    /// Spawn `command` with the configured stdin, stdout buffer and time limit.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`ChildProcess::new`].
    pub fn spawn(&self, command: Command) -> Result<ChildProcess, TimeoutError> {
        let lines = self.stdout();
        ChildProcess::new(command, &self.stdin_data, lines, self.max_duration)
    }
}
/// Outcome of a child process that terminated, either on its own or after
/// being killed.
#[derive(Debug)]
pub struct FinishedProcess {
    exit: ExitStatus,
    stderr: Vec<u8>,
}

impl FinishedProcess {
    /// The child's exit status.
    ///
    /// Despite the name this is the full [`ExitStatus`]; call
    /// [`ExitStatus::code`] for the numeric code, which on Unix is `None`
    /// when the child was terminated by a signal.
    pub fn exit_code(&self) -> ExitStatus {
        let FinishedProcess { exit, .. } = self;
        *exit
    }

    /// Everything the child wrote to its stderr.
    pub fn stderr(&self) -> &[u8] {
        self.stderr.as_slice()
    }
}
pub struct ChildProcess {
deadline: Instant,
child: Child,
stdout: RealtimeLines,
stdout_rx: Receiver<()>,
stdout_thread: JoinHandle<Result<(), TimeoutError>>,
stderr_thread: JoinHandle<Result<Vec<u8>, TimeoutError>>,
arc: Arc<Mutex<RealtimeLines>>,
#[allow(dead_code)]
timeout_thread: JoinHandle<()>,
}
impl ChildProcess {
pub fn new(
mut cmd: Command,
stdin: &[u8],
stdout_lines: RealtimeLines,
timeout: Duration,
) -> Result<Self, TimeoutError> {
let mut file = tempfile::tempfile()?;
file.write_all(stdin)?;
file.rewind()?;
let mut child = cmd
.stdin(file)
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.process_group(0)
.spawn()
.map_err(|err| TimeoutError::Spawn(cmd, err))?;
let (stdout_tx, stdout_rx) = channel::<()>();
let arc = Arc::new(Mutex::new(stdout_lines.clone()));
let stdout = child
.stdout
.take()
.ok_or(TimeoutError::TakeHandle("stdout"))?;
let stdout_thread = {
let arc = arc.clone();
spawn(move || Self::line_reader(arc, stdout_tx, stdout))
};
let stderr = child
.stderr
.take()
.ok_or(TimeoutError::TakeHandle("stderr"))?;
let stderr_thread = Self::capture(stderr)?;
let timeout_thread = {
let mut lines = stdout_lines.clone();
spawn(move || {
sleep(timeout);
lines.finish();
})
};
Ok(Self {
deadline: Instant::now()
.checked_add(timeout)
.ok_or(TimeoutError::Deadline)?,
child,
stdout: stdout_lines,
stdout_rx,
stdout_thread,
stderr_thread,
arc,
timeout_thread,
})
}
pub fn id(&self) -> u32 {
self.child.id()
}
pub fn kill(mut self) -> Result<FinishedProcess, TimeoutError> {
self.kill_helper();
self.wait()
}
fn kill_helper(&mut self) {
unsafe {
libc::killpg(self.child.id() as i32, libc::SIGKILL);
}
self.stdout.finish();
}
pub fn wait(mut self) -> Result<FinishedProcess, TimeoutError> {
self.stdout.finish();
let max_wait = self.deadline - Instant::now();
let _guardlock = &*self.arc;
match self.stdout_rx.recv_timeout(max_wait) {
Ok(_) | Err(RecvTimeoutError::Disconnected) => {}
Err(RecvTimeoutError::Timeout) => {
self.kill_helper();
return Err(TimeoutError::TimedOut);
}
}
self.stdout_thread
.join()
.map_err(|_| TimeoutError::Thread)??;
let stderr = self
.stderr_thread
.join()
.map_err(|_| TimeoutError::Thread)?;
let stderr = stderr?;
match self.child.wait() {
Ok(exit) => Ok(FinishedProcess { exit, stderr }),
Err(err) => Err(TimeoutError::Wait(err)),
}
}
fn line_reader(
arc: Arc<Mutex<RealtimeLines>>,
tx: Sender<()>,
mut stream: impl Read,
) -> Result<(), TimeoutError> {
let mutex = &*arc;
loop {
let mut bytes = vec![0; 1024];
let n = stream.read(&mut bytes)?;
if n == 0 {
break;
} else {
let mut buf = mutex.lock().map_err(|_| TimeoutError::Lock)?;
buf.push(bytes[..n].to_vec());
}
}
let mut buf = mutex.lock().map_err(|_| TimeoutError::Lock)?;
buf.finish();
tx.send(()).ok();
Ok(())
}
fn capture(
mut stream: impl Read + Send + 'static,
) -> Result<JoinHandle<Result<Vec<u8>, TimeoutError>>, TimeoutError> {
let thread = spawn(move || {
let mut buf = vec![];
loop {
let mut chunk = vec![0; MIB];
let n = stream.read(&mut chunk).map_err(TimeoutError::Io)?;
if n == 0 {
return Ok(buf);
} else {
buf.append(&mut chunk[..n].to_vec());
if buf.len() > MAX_OUTPUT_BYTES {
return Err(TimeoutError::TooMuch);
}
}
}
});
Ok(thread)
}
}
/// A cloneable handle to a shared, growing byte buffer that is consumed one
/// line at a time.
///
/// All clones share the same buffer. Writers call [`push`](Self::push) and
/// [`finish`](Self::finish); readers call [`line`](Self::line). Reading is
/// bounded by `max_duration`, measured from when the original handle was
/// created (clones copy the start time).
#[derive(Clone)]
pub struct RealtimeLines {
    data: Arc<(Mutex<UnlockedBuf>, Condvar)>,
    started: Instant,
    max_duration: Duration,
}

impl RealtimeLines {
    /// Create an empty buffer whose read budget of `max_duration` starts now.
    fn new(max_duration: Duration) -> Self {
        Self {
            data: Arc::new((Mutex::new(UnlockedBuf::default()), Condvar::default())),
            started: Instant::now(),
            max_duration,
        }
    }

    /// Append bytes and wake every reader waiting in [`line`](Self::line).
    ///
    /// # Panics
    ///
    /// Panics if the buffer's mutex is poisoned.
    pub fn push(&mut self, more_data: Vec<u8>) {
        let (mutex, var) = &*self.data;
        let mut buf = mutex.lock().expect("lock for push");
        buf.push(more_data);
        var.notify_all();
    }

    /// Mark the stream as ended and wake every waiting reader.
    ///
    /// After this, `line` also returns a trailing chunk that has no newline.
    ///
    /// # Panics
    ///
    /// Panics if the buffer's mutex is poisoned.
    pub fn finish(&mut self) {
        let (mutex, var) = &*self.data;
        let mut buf = mutex.lock().expect("lock for finish");
        buf.finish();
        var.notify_all();
    }

    /// Block until the next line is available and return it, including its
    /// trailing `\n`.
    ///
    /// Returns `None` when the stream is finished and fully drained, or once
    /// the time budget is spent. In the latter case `None` is returned even
    /// if data is still buffered, so a caller looping on `line()` terminates
    /// even when output never stops.
    ///
    /// # Panics
    ///
    /// Panics if the buffer's mutex is poisoned.
    pub fn line(&mut self) -> Option<String> {
        let (mutex, var) = &*self.data;
        let mut buf = mutex.lock().expect("lock to wait for line");
        loop {
            let remaining = self.max_duration.saturating_sub(self.started.elapsed());
            if remaining.as_millis() == 0 {
                return None;
            }
            if let Some(line) = buf.line() {
                return Some(line);
            }
            // Once finished, `UnlockedBuf::line` already returns any
            // unterminated tail, so a finished buffer that yields nothing is empty.
            if buf.is_finished() {
                return None;
            }
            // Spurious wake-ups are fine: the loop re-checks everything.
            buf = var.wait_timeout(buf, remaining).expect("wait for line").0;
        }
    }
}
/// A byte buffer plus an end-of-stream flag.
///
/// It does no locking of its own; `RealtimeLines` keeps it inside a `Mutex`.
#[derive(Default, Debug)]
struct UnlockedBuf {
    data: Vec<u8>,
    finished: bool,
}

impl UnlockedBuf {
    /// Record that no more data will be pushed.
    fn finish(&mut self) {
        self.finished = true;
    }

    /// Whether [`finish`](Self::finish) has been called.
    fn is_finished(&self) -> bool {
        self.finished
    }

    /// Append `more_data` to the end of the buffer.
    fn push(&mut self, more_data: Vec<u8>) {
        self.data.extend(more_data);
    }

    /// Remove and return the next line, including its trailing `\n`.
    ///
    /// Once finished, any remaining bytes without a newline are returned as a
    /// final line. Returns `None` when no complete line is available, or when
    /// the buffer is finished and empty. Invalid UTF-8 is replaced with U+FFFD.
    fn line(&mut self) -> Option<String> {
        let end = match self.data.iter().position(|&b| b == b'\n') {
            Some(newline) => newline + 1,
            None if self.finished && !self.data.is_empty() => self.data.len(),
            None => return None,
        };
        let taken: Vec<u8> = self.data.drain(..end).collect();
        Some(String::from_utf8_lossy(&taken).into_owned())
    }
}
/// Everything that can go wrong while spawning, supervising, or collecting a
/// child process.
#[derive(Debug, thiserror::Error)]
pub enum TimeoutError {
    // Preparing and starting the child.
    /// The deadline (`now + timeout`) could not be represented.
    #[error("failed to compute deadline")]
    Deadline,
    /// The command could not be started; the command is kept for the message.
    #[error("failed to spawn command: {0:?}")]
    Spawn(Command, #[source] std::io::Error),
    /// A piped stdio handle, named by the payload, was missing after spawning.
    #[error("failed to take child {0} file handle")]
    TakeHandle(&'static str),

    // Running the child and collecting its output.
    /// An I/O error, e.g. while preparing stdin or reading output.
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// The child wrote more than `MAX_OUTPUT_BYTES` to stderr.
    #[error("child process produced too much output")]
    TooMuch,
    /// The child did not finish before its deadline.
    #[error("timed out waiting for child")]
    TimedOut,
    /// Querying the child's exit status failed.
    #[error("failed waiting for child process to terminate")]
    Wait(#[source] std::io::Error),

    // Synchronization failures.
    /// A helper thread panicked, so joining it failed.
    #[error("thread join failed")]
    Thread,
    /// A mutex was poisoned.
    #[error("failed to lock mutex")]
    Lock,

    // NOTE(review): the variants below are not constructed anywhere in this file.
    #[error("failed to lock condition variable")]
    LockVar,
    #[error("failed to kill child process")]
    Kill,
    #[error("child exit code is not known")]
    ExitCode,
    #[error("programming error: failed to get thread to wait on")]
    TakeThread,
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
const LONG_ENOUGH_THAT_SCRIPT_SURELY_FINISHES: Duration = Duration::from_secs(10);
const SHORT_TIMEOUT: Duration = Duration::from_secs(3);
fn setup(
script: &str,
timeout: Duration,
stdin: Option<&'static str>,
) -> Result<(ChildProcess, RealtimeLines), TimeoutError> {
let mut cmd = Command::new("sh");
cmd.arg("-c").arg(script);
let mut to = TimeoutCommand::new(timeout);
if let Some(stdin) = stdin {
to.feed_stdin(stdin.as_bytes());
}
Ok((to.spawn(cmd)?, to.stdout()))
}
#[test]
fn bin_true() -> Result<(), Box<dyn std::error::Error>> {
let (running, _) = setup("exec true", LONG_ENOUGH_THAT_SCRIPT_SURELY_FINISHES, None)?;
let finished = running.wait()?;
assert_eq!(finished.exit_code().code(), Some(0));
Ok(())
}
#[test]
fn bin_false() -> Result<(), Box<dyn std::error::Error>> {
let (running, _) = setup("exec false", LONG_ENOUGH_THAT_SCRIPT_SURELY_FINISHES, None)?;
let result = running.wait();
assert!(matches!(result, Ok(FinishedProcess { exit, .. }) if exit.code() == Some(1)));
Ok(())
}
#[test]
fn sleep_1() -> Result<(), Box<dyn std::error::Error>> {
let (running, _) = setup(
"exec sleep 1",
LONG_ENOUGH_THAT_SCRIPT_SURELY_FINISHES,
None,
)?;
let result = running.wait();
assert!(matches!(result, Ok(FinishedProcess { exit, .. }) if exit.code() == Some(0)));
Ok(())
}
#[test]
fn sleep_for_too_long() -> Result<(), Box<dyn std::error::Error>> {
let (running, _) = setup("exec sleep 1000", SHORT_TIMEOUT, None)?;
let result = running.wait();
assert!(matches!(result, Err(TimeoutError::TimedOut)));
Ok(())
}
#[test]
fn hello_world() -> Result<(), Box<dyn std::error::Error>> {
let (running, mut stdout) = setup(
"exec echo hello, world",
LONG_ENOUGH_THAT_SCRIPT_SURELY_FINISHES,
None,
)?;
assert_eq!(stdout.line(), Some("hello, world\n".into()));
assert_eq!(stdout.line(), None);
let result = running.wait();
eprintln!("result={result:#?}");
let finished = result.unwrap();
assert_eq!(finished.exit_code().code(), Some(0));
assert_eq!(finished.stderr(), b"");
Ok(())
}
#[test]
fn hello_world_to_stderr() -> Result<(), Box<dyn std::error::Error>> {
let (running, mut stdout) = setup(
"exec echo hello, world 1>&2",
LONG_ENOUGH_THAT_SCRIPT_SURELY_FINISHES,
None,
)?;
assert_eq!(stdout.line(), None);
let finished = running.wait().unwrap();
assert_eq!(finished.exit_code().code(), Some(0));
assert_eq!(finished.stderr(), b"hello, world\n");
Ok(())
}
#[test]
fn pipe_through_cat() -> Result<(), Box<dyn std::error::Error>> {
let (running, mut stdout) = setup(
"exec cat",
LONG_ENOUGH_THAT_SCRIPT_SURELY_FINISHES,
Some("hello, world"),
)?;
assert_eq!(stdout.line(), Some("hello, world".into()));
assert_eq!(stdout.line(), None);
let finished = running.wait().unwrap();
assert_eq!(finished.exit_code().code(), Some(0));
assert_eq!(finished.stderr(), b"");
Ok(())
}
#[test]
fn yes_to_stdout() -> Result<(), Box<dyn std::error::Error>> {
let (running, _) = setup("exec yes", SHORT_TIMEOUT, None)?;
assert!(matches!(running.wait(), Err(TimeoutError::TimedOut)));
Ok(())
}
#[test]
fn yes_to_stderr() -> Result<(), Box<dyn std::error::Error>> {
let (running, _) = setup("exec yes 1>&2", SHORT_TIMEOUT, None)?;
assert!(matches!(
running.wait(),
Err(TimeoutError::TimedOut) | Err(TimeoutError::TooMuch)
));
Ok(())
}
#[test]
fn yes_to_stdout_while_reading_with_realtimelines() -> Result<(), Box<dyn std::error::Error>> {
let (running, mut stdout) = setup("exec yes", SHORT_TIMEOUT, None)?;
while stdout.line().is_some() {}
let result = running.wait();
eprintln!("result: {result:#?}");
assert!(matches!(result, Err(TimeoutError::TimedOut)));
Ok(())
}
#[test]
fn sleep_for_too_long_while_reading_with_realtimelines()
-> Result<(), Box<dyn std::error::Error>> {
let (running, mut stdout) = setup("exec sleep 1000", SHORT_TIMEOUT, None)?;
while stdout.line().is_some() {}
let result = running.wait();
eprintln!("result: {result:#?}");
assert!(matches!(result, Err(TimeoutError::TimedOut)));
Ok(())
}
#[test]
fn kill() -> Result<(), Box<dyn std::error::Error>> {
let (running, _) = setup(
"exec sleep 1000",
LONG_ENOUGH_THAT_SCRIPT_SURELY_FINISHES,
None,
)?;
sleep(Duration::from_millis(100));
let finished = running.kill()?;
assert_eq!(finished.exit_code().code(), None);
Ok(())
}
#[test]
fn kill_stderr() -> Result<(), Box<dyn std::error::Error>> {
let (running, _) = setup(
"exec sleep 1000 1>&2",
LONG_ENOUGH_THAT_SCRIPT_SURELY_FINISHES,
None,
)?;
sleep(Duration::from_millis(100));
let finished = running.kill()?;
assert_eq!(finished.exit_code().code(), None);
Ok(())
}
}