use std::{
error::Error,
io::{self, BufRead, BufReader, Read, Write},
mem,
path::Path,
process::{Child, ChildStdin, ChildStdout, Command, Output, Stdio},
sync::{Arc, Condvar, Mutex},
thread::{self, JoinHandle},
time::{Duration, Instant},
};
use tempfile::TempPath;
use crate::helper::format_command;
/// Name of the environment variable that, when set to any value, makes
/// successful runs call `keep()` on their temporary files instead of simply
/// dropping them.
const PRESERVE_TMP_ENV_VAR: &str = "CRITERION_PRESERVE_TMP";
/// Runs `command` to completion and returns its captured [`Output`].
///
/// On success the temporary files are only kept when the
/// `PRESERVE_TMP_ENV_VAR` environment variable is set; otherwise they are
/// dropped with the iterator.
///
/// # Panics
///
/// * If the command cannot be executed at all.
/// * If the command exits unsuccessfully. In that case the temporary files
///   are kept and a failure report is printed via `panic_with_output`.
pub fn run_command(
    description: &str,
    command: &mut Command,
    temp_files: impl IntoIterator<Item = TempPath>,
) -> Output {
    let output = command.output().unwrap();

    if output.status.success() {
        let preserve = std::env::var_os(PRESERVE_TMP_ENV_VAR).is_some();
        if preserve {
            for file in temp_files {
                // Best effort: a file that cannot be kept is silently skipped.
                let _ = file.keep();
            }
        }
        return output;
    }

    // Failure: keep whatever temporary files can be kept so they can be inspected.
    let kept = temp_files.into_iter().filter_map(|file| file.keep().ok());
    let failure = io::Error::new(io::ErrorKind::Other, "command failed");
    panic_with_output(description, command, &output, failure, kept)
}
/// Prints a failure report for `command` to stderr and then panics.
///
/// The report lists, in order: the exit status and the formatted command,
/// the paths yielded by `temp_files` (section omitted when there are none),
/// the captured stdout (omitted when empty) and the captured stderr.
///
/// # Panics
///
/// Always. The panic message contains the crate name, `description` and
/// `error`.
fn panic_with_output<P>(
    description: &str,
    command: &Command,
    output: &Output,
    error: impl Error + 'static,
    temp_files: impl IntoIterator<Item = P>,
) -> !
where
    P: AsRef<Path>,
{
    eprintln!(
        "\n\n# Command (failed with {})\n {}",
        output.status,
        format_command(command)
    );

    let preserved: Vec<String> = temp_files
        .into_iter()
        .map(|path| path.as_ref().to_string_lossy().into_owned())
        .collect();
    if !preserved.is_empty() {
        eprint!("\n# Temporary files preserved");
        for file in &preserved {
            eprint!("\n - {file}");
        }
        // Terminate the last list entry.
        eprintln!();
    }

    if !output.stdout.is_empty() {
        let stdout_text = String::from_utf8_lossy(&output.stdout);
        eprintln!("\n# Process Standard Output\n{}\n\n", stdout_text);
    }

    let stderr_text = String::from_utf8_lossy(&output.stderr);
    eprintln!("\n# Process Standard Error\n{}\n\n", stderr_text);

    panic!(
        "{}: {} benchmark failed with error \"{}\"",
        env!("CARGO_PKG_NAME"),
        description,
        error
    );
}
/// Handle to a spawned benchmark child process that is driven over its
/// standard streams: an iteration count is written to its stdin and a line
/// containing `<nanoseconds> nsec` is read back from its stdout.
pub struct CommunicatingBenchmark {
/// The child process. `None` only while `terminate` has moved it into its
/// timeout thread.
child: Option<Child>,
/// Pipe used to send iteration counts to the child.
stdin: ChildStdin,
/// Buffered reader over the child's stdout, read one line per request.
stdout: BufReader<ChildStdout>,
/// Every stdout byte read so far; handed out by `get_remaining_output`.
stdout_buf: Vec<u8>,
/// Background thread that reads the child's stderr to the end. Taken (and
/// joined) by `get_remaining_output`.
stderr_task: Option<JoinHandle<Vec<u8>>>,
/// The command that was spawned, retained for failure reports.
command: Command,
/// Human-readable benchmark name used in panic messages.
description: String,
/// Temporary files tied to this benchmark; kept on failure or when
/// `PRESERVE_TMP_ENV_VAR` is set, otherwise dropped with the struct
/// (presumably deleting them — see `tempfile::TempPath`).
temp_files: Vec<TempPath>,
}
impl CommunicatingBenchmark {
    /// Text that must follow the elapsed time in every reply line.
    const TIME_SUFFIX: &'static str = " nsec";

    /// Spawns `command` with all three standard streams piped and returns a
    /// handle for driving it.
    ///
    /// A background thread is started that reads the child's stderr to the
    /// end so it can be reported later.
    ///
    /// # Panics
    ///
    /// If the child process cannot be spawned.
    pub fn start(
        description: &str,
        mut command: Command,
        temp_files: impl IntoIterator<Item = TempPath>,
    ) -> Self {
        command
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped());
        let mut child = command
            .spawn()
            .expect("could not spawn benchmark child process");

        let stdin = child.stdin.take().expect("child stdin was piped");
        let stdout = BufReader::new(child.stdout.take().expect("child stdout was piped"));
        let mut stderr = child.stderr.take().expect("child stderr was piped");

        let stderr_task = thread::spawn(move || {
            let mut captured = Vec::new();
            if let Err(e) = stderr.read_to_end(&mut captured) {
                // Record the read failure in the captured text itself.
                write!(captured, "Failed reading stderr: {}", e).unwrap();
            }
            captured
        });

        Self {
            child: Some(child),
            stdin,
            stdout,
            stdout_buf: Vec::new(),
            stderr_task: Some(stderr_task),
            command,
            description: description.to_string(),
            temp_files: temp_files.into_iter().collect(),
        }
    }

    /// Asks the child to run `n` iterations and returns the duration it
    /// reports.
    ///
    /// # Panics
    ///
    /// If the request cannot be written or the reply cannot be read or
    /// parsed; the child is killed and a failure report is printed first.
    pub fn run(&mut self, n: u64) -> Duration {
        match self.request_run(n) {
            Ok(elapsed) => elapsed,
            Err(e) => self.panic_on_error(e),
        }
    }

    /// Writes `n` followed by a newline to the child's stdin, flushes, and
    /// parses the next reply line from its stdout.
    fn request_run(&mut self, n: u64) -> io::Result<Duration> {
        writeln!(self.stdin, "{}", n)?;
        self.stdin.flush()?;
        self.parse_elapsed_line()
    }

    /// Asks the child to exit by requesting `0` iterations, waits for it to
    /// exit and returns its exit status together with everything it wrote to
    /// stdout and stderr.
    ///
    /// Temporary files are kept when `PRESERVE_TMP_ENV_VAR` is set.
    ///
    /// # Panics
    ///
    /// * If the child has not exited five seconds after termination began
    ///   (it is killed first).
    /// * If the final request fails or the child answers it with a non-zero
    ///   duration.
    pub fn terminate(mut self) -> Output {
        const TIMEOUT: Duration = Duration::from_secs(5);

        let child = self.child.take().unwrap();
        let signal = Arc::new((Mutex::new(false), Condvar::new()));

        // The watcher owns the child while the main thread may be blocked
        // reading the reply, so that it can kill the child on timeout.
        let watcher = {
            let signal = Arc::clone(&signal);
            thread::spawn(move || Self::wait_for_exit(child, &signal, TIMEOUT))
        };

        let run_result = self.request_run(0).and_then(|elapsed| {
            if elapsed.is_zero() {
                Ok(elapsed)
            } else {
                Err(io::Error::new(
                    io::ErrorKind::Other,
                    "benchmark exited with a non-zero duration",
                ))
            }
        });

        // Tell the watcher that the final request has been answered (or failed).
        {
            let (lock, cvar) = &*signal;
            *lock.lock().unwrap() = true;
            cvar.notify_one();
        }

        let (child, exit_result) = watcher.join().unwrap();
        // Put the child back before any error handling, which needs it.
        self.child = Some(child);

        let status = match exit_result {
            Ok(status) => status,
            Err(e) => self.panic_on_error(e),
        };
        if let Err(e) = run_result {
            self.panic_on_error(e);
        }

        if std::env::var_os(PRESERVE_TMP_ENV_VAR).is_some() {
            for file in self.temp_files.drain(..) {
                let _ = file.keep();
            }
        }

        let (stdout, stderr) = self.get_remaining_output();
        Output {
            status,
            stdout,
            stderr,
        }
    }

    /// Waits until `child` has exited or `timeout` has elapsed, whichever
    /// comes first, and hands the child back together with the outcome.
    ///
    /// `signal` holds a flag that the caller sets (and notifies) once it is
    /// no longer blocked on the child. Before the flag is set this only
    /// sleeps on the condition variable; afterwards it polls the child's
    /// exit status. The deadline applies to both phases: when it passes
    /// without the child having exited, the child is killed and a timeout
    /// error is returned.
    fn wait_for_exit(
        mut child: Child,
        signal: &(Mutex<bool>, Condvar),
        timeout: Duration,
    ) -> (Child, io::Result<std::process::ExitStatus>) {
        let began = Instant::now();

        // Phase 1: wait for the caller's signal, bounded by the deadline.
        let signalled = {
            let (lock, cvar) = signal;
            let mut exited = lock.lock().unwrap();
            loop {
                if *exited {
                    break true;
                }
                // Saturating: `Duration` subtraction panics on underflow.
                let remaining = timeout.saturating_sub(began.elapsed());
                if remaining.is_zero() {
                    break false;
                }
                exited = cvar.wait_timeout(exited, remaining).unwrap().0;
            }
        };

        // Phase 2: the child was asked to exit; poll until it does or the
        // deadline passes. The status is checked at least once even when the
        // signal arrived after the deadline.
        if signalled {
            loop {
                if let Some(status) = child.try_wait().unwrap() {
                    return (child, Ok(status));
                }
                if began.elapsed() >= timeout {
                    break;
                }
                thread::yield_now();
            }
        }

        child.kill().expect("child already killed");
        let err = io::Error::new(
            io::ErrorKind::Other,
            "timed out waiting on benchmark to exit",
        );
        (child, Err(err))
    }

    /// Kills the child, collects its remaining output, keeps the temporary
    /// files and panics with a failure report for `err`.
    ///
    /// The "Exiting due to error" line is suppressed for broken-pipe I/O
    /// errors.
    fn panic_on_error(&mut self, err: impl Error + 'static) -> ! {
        let is_broken_pipe = (&err as &dyn Error)
            .downcast_ref::<io::Error>()
            .map_or(false, |e| e.kind() == io::ErrorKind::BrokenPipe);
        if !is_broken_pipe {
            eprintln!("\n\nExiting due to error: {}", err);
        }

        // The child may already be gone; a failed kill is not an error here.
        let _ = self.child.as_mut().unwrap().kill();
        let (stdout, stderr) = self.get_remaining_output();
        let status = self.child.as_mut().unwrap().wait().unwrap();
        let output = Output {
            status,
            stdout,
            stderr,
        };

        let kept = self.temp_files.drain(..).filter_map(|f| f.keep().ok());
        panic_with_output(&self.description, &self.command, &output, err, kept);
    }

    /// Returns everything the child wrote to stdout (already-read bytes plus
    /// whatever is still readable until EOF) and to stderr.
    ///
    /// Joins the stderr reader thread, so it must be called at most once.
    fn get_remaining_output(&mut self) -> (Vec<u8>, Vec<u8>) {
        let mut stdout = mem::take(&mut self.stdout_buf);
        if let Err(e) = self.stdout.read_to_end(&mut stdout) {
            write!(stdout, "Failed reading stdout: {}", e).unwrap();
        }
        let stderr = self.stderr_task.take().unwrap().join().unwrap();
        (stdout, stderr)
    }

    /// Reads one line from the child's stdout, appends it to `stdout_buf`
    /// and parses it as an elapsed time.
    ///
    /// The line must contain `<integer> nsec` followed by a whitespace
    /// character; the integer is the last ASCII-whitespace-separated token
    /// before the first occurrence of the suffix.
    ///
    /// # Errors
    ///
    /// * `BrokenPipe` ("unexpected EOF") when stdout is at end of file.
    /// * `Other` ("failed parsing time string: ...") when the line is not
    ///   valid UTF-8 or does not match the expected format.
    fn parse_elapsed_line(&mut self) -> io::Result<Duration> {
        let read_start = self.stdout_buf.len();
        let read_len = self.stdout.read_until(b'\n', &mut self.stdout_buf)?;
        if read_len == 0 {
            return Err(io::Error::new(io::ErrorKind::BrokenPipe, "unexpected EOF"));
        }

        let raw = &self.stdout_buf[read_start..read_start + read_len];
        let parse_error = || {
            io::Error::new(
                io::ErrorKind::Other,
                format!(
                    "failed parsing time string: {:?}",
                    String::from_utf8_lossy(raw)
                ),
            )
        };

        let line = std::str::from_utf8(raw).map_err(|_| parse_error())?;
        let time_end = line.find(Self::TIME_SUFFIX).ok_or_else(parse_error)?;

        // Reject units that merely start with the suffix, e.g. "nsectional".
        let after_suffix = &line[time_end + Self::TIME_SUFFIX.len()..];
        if !after_suffix.starts_with(char::is_whitespace) {
            return Err(parse_error());
        }

        let nanos: u64 = line[..time_end]
            .rsplit(|c: char| c.is_ascii_whitespace())
            .next()
            .ok_or_else(parse_error)?
            .parse()
            .map_err(|_| parse_error())?;

        Ok(Duration::from_nanos(nanos))
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// Iteration list containing a single one-iteration request.
    const ONCE: [u64; 1] = [1];

    // Builds the `rhai_expect` helper binary; presumably also provides
    // `path_to_rhai_expect()` used below — see the `test_binary` crate.
    test_binary::build_test_binary_once!(rhai_expect, "test_binaries");

    /// Returns a command running the `rhai_expect` binary with `script`
    /// passed through `-c`.
    fn make_rhai_command(script: &str) -> Command {
        let mut command = Command::new(path_to_rhai_expect());
        command.args(["-c", script]);
        command
    }

    /// Starts a benchmark named `name` running `script`, issues one `run`
    /// request per entry of `iterations` and returns the still-running
    /// benchmark together with the reported durations.
    fn run_rhai_benchmark<I>(
        name: &str,
        iterations: I,
        script: &str,
    ) -> (CommunicatingBenchmark, Vec<Duration>)
    where
        I: IntoIterator<Item = u64>,
    {
        let command = make_rhai_command(script);
        let mut benchmark = CommunicatingBenchmark::start(name, command, None);
        let responses: Vec<_> = iterations
            .into_iter()
            .map(|n| benchmark.run(n))
            .collect();
        (benchmark, responses)
    }

    /// Returns a script that prints `response` for every line it reads
    /// from stdin.
    fn rhai_static_response(response: &str) -> String {
        let response = response.replace('"', "\\\"");
        format!(r#"for line in stdin_lines() {{ print("{response}"); }}"#)
    }

    /// A child that answers the `0` request with `0 nsec` and exits
    /// terminates cleanly.
    #[test]
    fn terminate_normal() {
        let script = r#"
for line in stdin_lines() {
if line == "0" {
print("0 nsec");
return;
}
}
"#;
        let (benchmark, _) = run_rhai_benchmark("terminate_normal", None, script);
        let output = benchmark.terminate();
        assert!(output.status.success());
        assert_eq!(&output.stdout, b"0 nsec\n");
        assert!(output.stderr.is_empty());
    }

    /// Reported durations are parsed and returned in request order.
    #[test]
    fn run_normal() {
        let script = r#"
for line in stdin_lines() {
let n = parse_int(line);
if n > 0 {
print(`${n} nsec`);
}
}
"#;
        let iterations = [1, 54321];
        let (benchmark, responses) = run_rhai_benchmark("run_normal", iterations, script);
        assert_eq!(responses, iterations.map(Duration::from_nanos));
        benchmark.child.unwrap().kill().unwrap();
    }

    /// Many requests followed by termination; stdout must contain every
    /// reply, including the final `0 nsec`.
    #[test]
    fn run_and_terminate_normal() {
        use std::fmt::Write;

        let command = make_rhai_command(
            r#"
for line in stdin_lines() {
let n = parse_int(line);
if n == 0 {
print("0 nsec");
return;
} else if n > 0 {
print(`${n} nsec`);
} else {
throw;
}
}
"#,
        );
        // Use this test's own name so a failure report is attributed correctly.
        let mut benchmark =
            CommunicatingBenchmark::start("run_and_terminate_normal", command, None);
        let mut expected_stdout = String::new();
        for i in 1..=100 {
            benchmark.run(i);
            writeln!(expected_stdout, "{} nsec", i).unwrap();
        }
        let output = benchmark.terminate();
        writeln!(expected_stdout, "0 nsec").unwrap();
        assert!(output.status.success());
        assert_eq!(&output.stdout, expected_stdout.as_bytes());
        assert!(output.stderr.is_empty());
    }

    /// A child that never answers is killed once the termination timeout
    /// expires. Ignored by default.
    #[test]
    #[ignore]
    #[should_panic(expected = "timed out waiting on benchmark to exit")]
    fn benchmark_timeout() {
        let script = r#"sleep(180)"#;
        let (benchmark, _) = run_rhai_benchmark("benchmark_timeout", None, script);
        benchmark.terminate();
    }

    /// Answering the `0` request with a non-zero duration is an error.
    #[test]
    #[should_panic(expected = "benchmark exited with a non-zero duration")]
    fn terminate_invalid_reply() {
        let (benchmark, _) = run_rhai_benchmark("terminate_invalid_reply", None, r#"
for line in stdin_lines() {
if line == "0" {
print("555 nsec");
return;
}
}
"#);
        benchmark.terminate();
    }

    /// The child exits successfully without ever replying.
    #[test]
    #[should_panic(expected = "unexpected EOF")]
    fn premature_exit() {
        run_rhai_benchmark("premature_exit", ONCE, r#"return"#);
    }

    /// The child fails immediately without replying.
    #[test]
    #[should_panic(expected = "unexpected EOF")]
    fn exit_with_error() {
        run_rhai_benchmark("exit_with_error", ONCE, r#"throw"#);
    }

    /// The child replies a few times and then fails.
    #[test]
    #[should_panic(expected = "unexpected EOF")]
    fn later_error() {
        run_rhai_benchmark("later_error", 1..=5, r#"
let i = 1;
let fail_at = 4;
for line in stdin_lines() {
let n = parse_int(line);
if n < 1 || n >= fail_at {
throw;
} else {
print(`${n} nsec`);
}
i += 1;
}
"#);
    }

    /// A reply without any time string is rejected.
    #[test]
    #[should_panic(expected = "failed parsing time string")]
    fn garbage_stdout() {
        run_rhai_benchmark(
            "garbage_stdout",
            ONCE,
            &rhai_static_response("Ceci n'est pas une pipe"),
        );
    }

    /// Fractional nanoseconds are rejected.
    #[test]
    #[should_panic(expected = "failed parsing time string")]
    fn float_as_elapsed() {
        run_rhai_benchmark(
            "float_as_elapsed",
            ONCE,
            &rhai_static_response("0.123456789 nsec"),
        );
    }

    /// Units other than `nsec` are rejected.
    #[test]
    #[should_panic(expected = "failed parsing time string")]
    fn wrong_units_elapsed() {
        run_rhai_benchmark(
            "wrong_units_elapsed",
            ONCE,
            &rhai_static_response("1 usec"),
        );
    }

    /// A value too large for `u64` is rejected.
    #[test]
    #[should_panic(expected = "failed parsing time string")]
    fn integer_overflow_elapsed() {
        run_rhai_benchmark(
            "integer_overflow_elapsed",
            ONCE,
            &rhai_static_response("99999999999999999999999999999 nsec"),
        );
    }

    /// The space between value and unit is mandatory.
    #[test]
    #[should_panic(expected = "failed parsing time string")]
    fn no_space_elapsed() {
        run_rhai_benchmark("no_space_elapsed", ONCE, &rhai_static_response("1nsec"));
    }

    /// Extra text after the unit is accepted when separated by whitespace.
    #[test]
    fn text_after_elapsed() {
        let (benchmark, responses) = run_rhai_benchmark(
            "text_after_elapsed",
            ONCE,
            &rhai_static_response("1 nsec commentary"),
        );
        assert_eq!(responses, ONCE.map(Duration::from_nanos));
        benchmark.child.unwrap().kill().unwrap();
    }

    /// A longer word that merely starts with `nsec` is rejected.
    #[test]
    #[should_panic(expected = "failed parsing time string")]
    fn units_trailing_chars() {
        run_rhai_benchmark(
            "units_trailing_chars",
            ONCE,
            &rhai_static_response("1 nsectionalsofa"),
        );
    }
}