use std::io;
use std::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command, ExitStatus};
use std::sync::{Condvar, Mutex, MutexGuard};
#[cfg(feature = "timeout")]
use std::time::{Duration, Instant};
mod sys;
#[cfg(unix)]
pub mod unix;
/// Wait-coordination state for a `SharedChild`.
///
/// Records whether some thread is currently blocked waiting on the child, and caches the exit
/// status once the child has been reaped so later callers can return it without touching the OS.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ChildState {
    /// No thread is blocked waiting on the child, and it has not been reaped yet.
    NotWaiting,
    /// One thread is blocked waiting for the child to exit (with the mutex released). Other
    /// would-be waiters sleep on the condition variable instead of waiting on the child too.
    Waiting,
    /// The child has been reaped; this is its exit status.
    Exited(ExitStatus),
}
use crate::ChildState::{Exited, NotWaiting, Waiting};
/// Everything guarded by the mutex inside `SharedChild`: the wrapped process handle together
/// with the bookkeeping used to coordinate concurrent waiters.
#[derive(Debug)]
struct SharedChildInner {
    /// The standard-library handle for the process. It is only reaped (through `Child::wait`
    /// or `Child::try_wait`) while the surrounding mutex is held.
    child: Child,
    /// Whether a thread is currently blocked on the child, or the child's cached exit status.
    state: ChildState,
}
/// A [`Child`] process handle that several threads can wait on, poll, and kill at the same
/// time through a shared reference (`&self`).
///
/// All state lives behind a single mutex. The condition variable wakes threads that are
/// waiting for another thread's blocking wait on the child to finish.
#[derive(Debug)]
pub struct SharedChild {
    inner: Mutex<SharedChildInner>,
    condvar: Condvar,
}
impl SharedChild {
    /// Spawns a new child process from `command` and wraps it.
    ///
    /// The new `SharedChild` starts in the not-waiting state; no check is made for whether the
    /// process has already exited.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Command::spawn`].
    pub fn spawn(command: &mut Command) -> io::Result<Self> {
        let child = command.spawn()?;
        Ok(SharedChild {
            inner: Mutex::new(SharedChildInner {
                child,
                state: NotWaiting,
            }),
            condvar: Condvar::new(),
        })
    }

    /// Wraps an already-spawned [`Child`].
    ///
    /// The child is polled once with [`Child::try_wait`]. If it has already exited, its exit
    /// status is cached immediately.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Child::try_wait`].
    pub fn new(mut child: Child) -> io::Result<Self> {
        let state = match child.try_wait()? {
            Some(exit_status) => Exited(exit_status),
            None => NotWaiting,
        };
        Ok(SharedChild {
            inner: Mutex::new(SharedChildInner { child, state }),
            condvar: Condvar::new(),
        })
    }

    /// Returns the OS process identifier of the child (see [`Child::id`]).
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn id(&self) -> u32 {
        self.inner.lock().unwrap().child.id()
    }

    /// Blocks until the child exits and returns its exit status.
    ///
    /// Any number of threads may call this concurrently. At most one of them at a time does the
    /// blocking wait on the child, with the lock released so other methods stay usable; the
    /// rest sleep on the condition variable until that wait finishes. Once the child has been
    /// reaped its status is cached, and every later call returns it immediately.
    ///
    /// # Errors
    ///
    /// Propagates errors from the platform wait or from [`Child::wait`]. The state is reset to
    /// not-waiting before the error is returned, so a later call may retry.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn wait(&self) -> io::Result<ExitStatus> {
        let mut inner_guard = self.inner.lock().unwrap();
        loop {
            match inner_guard.state {
                Exited(exit_status) => return Ok(exit_status),
                // Another thread is already blocked on the child; let it finish first.
                Waiting => inner_guard = self.condvar.wait(inner_guard).unwrap(),
                NotWaiting => break,
            }
        }

        // We are the blocking thread. Mark that, then release the lock before blocking.
        inner_guard.state = Waiting;
        let handle = sys::get_handle(&inner_guard.child);
        drop(inner_guard);

        // Wait for exit without reaping; the actual reap happens below, under the lock.
        let wait_result = sys::wait_noreap(handle);

        // Reacquire the lock and leave the Waiting state before inspecting the result. Woken
        // threads cannot proceed until we release the lock, and on success the state will be
        // Exited by then, so they never observe a spurious NotWaiting.
        inner_guard = self.inner.lock().unwrap();
        inner_guard.state = NotWaiting;
        self.condvar.notify_all();
        wait_result?;

        // The child has already exited, so this reap returns without blocking.
        let exit_status = inner_guard.child.wait()?;
        inner_guard.state = Exited(exit_status);
        Ok(exit_status)
    }

    /// Like [`wait`](Self::wait), but gives up after `timeout`, returning `Ok(None)` if the
    /// child is still running.
    ///
    /// A `timeout` too large to add to the current [`Instant`] (for example `Duration::MAX`)
    /// can never elapse, so it behaves exactly like `wait`.
    ///
    /// # Errors
    ///
    /// Same as [`wait_deadline`](Self::wait_deadline).
    #[cfg(feature = "timeout")]
    pub fn wait_timeout(&self, timeout: Duration) -> io::Result<Option<ExitStatus>> {
        match Instant::now().checked_add(timeout) {
            Some(deadline) => self.wait_deadline(deadline),
            // `Instant + Duration` would panic here. An unrepresentable deadline is never
            // reached, so an untimed wait is equivalent.
            None => self.wait().map(Some),
        }
    }

    /// Like [`wait`](Self::wait), but gives up at `deadline`, returning `Ok(None)` if the child
    /// is still running then.
    ///
    /// Once the deadline has passed, this falls back to the same non-blocking check as
    /// [`try_wait`](Self::try_wait).
    ///
    /// # Errors
    ///
    /// Propagates errors from the platform wait or from [`Child::wait`].
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[cfg(feature = "timeout")]
    pub fn wait_deadline(&self, deadline: Instant) -> io::Result<Option<ExitStatus>> {
        let mut inner_guard = self.inner.lock().unwrap();
        loop {
            match inner_guard.state {
                Exited(exit_status) => return Ok(Some(exit_status)),
                // Out of time: do one non-blocking poll instead of waiting any longer.
                _ if deadline < Instant::now() => {
                    return self.try_wait_inner(inner_guard);
                }
                // Another thread is blocked on the child; sleep on the condvar, but no longer
                // than our own deadline allows.
                Waiting => {
                    let timeout = deadline.saturating_duration_since(Instant::now());
                    inner_guard = self.condvar.wait_timeout(inner_guard, timeout).unwrap().0;
                }
                NotWaiting => break,
            }
        }

        // We are the blocking thread. Mark that, then release the lock before blocking.
        inner_guard.state = Waiting;
        let handle = sys::get_handle(&inner_guard.child);
        drop(inner_guard);

        let wait_result = sys::wait_deadline_noreap(handle, deadline);

        // Same hand-off protocol as `wait`: leave the Waiting state and wake other waiters
        // before looking at the result.
        inner_guard = self.inner.lock().unwrap();
        inner_guard.state = NotWaiting;
        self.condvar.notify_all();
        let exited = wait_result?;
        if exited {
            // The child has already exited, so this reap returns without blocking.
            let exit_status = inner_guard.child.wait()?;
            inner_guard.state = Exited(exit_status);
            Ok(Some(exit_status))
        } else {
            Ok(None)
        }
    }

    /// Returns the child's exit status if it has exited, or `Ok(None)` if it is still running.
    ///
    /// This never blocks on a running child. If another thread is blocked waiting on a child
    /// that has already exited, this briefly waits for that thread to reap it so the status can
    /// be returned.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Child::try_wait`] or the platform exit check.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn try_wait(&self) -> io::Result<Option<ExitStatus>> {
        let inner_guard = self.inner.lock().unwrap();
        self.try_wait_inner(inner_guard)
    }

    /// Shared body of [`try_wait`](Self::try_wait). It takes an already-held guard so that
    /// `wait_deadline` can fall back to it without unlocking first.
    fn try_wait_inner(
        &self,
        mut inner_guard: MutexGuard<SharedChildInner>,
    ) -> io::Result<Option<ExitStatus>> {
        match inner_guard.state {
            Exited(exit_status) => Ok(Some(exit_status)),
            // Nobody else is waiting, so it is safe to poll and, if exited, reap directly.
            NotWaiting => {
                let maybe_status = inner_guard.child.try_wait()?;
                if let Some(status) = maybe_status {
                    inner_guard.state = Exited(status);
                }
                Ok(maybe_status)
            }
            // Another thread is blocked in a wait. Check for exit without reaping, so we do not
            // reap the child out from under that thread.
            Waiting => {
                if sys::try_wait_noreap(sys::get_handle(&inner_guard.child))? {
                    // The child has exited but the waiting thread hasn't reaped it yet.
                    // Returning None would be wrong, so join the waiters: this should not block
                    // for long, and it leaves the state Exited.
                    drop(inner_guard);
                    self.wait().map(Some)
                } else {
                    Ok(None)
                }
            }
        }
    }

    /// Kills the child via [`Child::kill`].
    ///
    /// The lock is held during the call, and the child is only reaped while that same lock is
    /// held, so this cannot interleave with reaping.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Child::kill`].
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn kill(&self) -> io::Result<()> {
        self.inner.lock().unwrap().child.kill()
    }

    /// Consumes the `SharedChild` and returns the underlying [`Child`].
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn into_inner(self) -> Child {
        self.inner.into_inner().unwrap().child
    }

    /// Takes the child's stdin handle, if it was piped and has not been taken yet.
    pub fn take_stdin(&self) -> Option<ChildStdin> {
        self.inner.lock().unwrap().child.stdin.take()
    }

    /// Takes the child's stdout handle, if it was piped and has not been taken yet.
    pub fn take_stdout(&self) -> Option<ChildStdout> {
        self.inner.lock().unwrap().child.stdout.take()
    }

    /// Takes the child's stderr handle, if it was piped and has not been taken yet.
    pub fn take_stderr(&self) -> Option<ChildStderr> {
        self.inner.lock().unwrap().child.stderr.take()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error;
    use std::process::{Command, Stdio};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    /// A command that exits successfully right away.
    #[cfg(unix)]
    pub fn true_cmd() -> Command {
        Command::new("true")
    }

    /// A command that exits successfully right away.
    #[cfg(not(unix))]
    pub fn true_cmd() -> Command {
        let mut cmd = Command::new("python");
        cmd.arg("-c").arg("");
        cmd
    }

    /// A command that sleeps for `duration` and then exits successfully.
    #[cfg(unix)]
    pub fn sleep_cmd(duration: Duration) -> Command {
        let mut cmd = Command::new("sleep");
        cmd.arg(format!("{}", duration.as_secs_f32()));
        cmd
    }

    /// A command that sleeps for `duration` and then exits successfully.
    #[cfg(not(unix))]
    pub fn sleep_cmd(duration: Duration) -> Command {
        let mut cmd = Command::new("python");
        cmd.arg("-c").arg(format!(
            "import time; time.sleep({})",
            duration.as_secs_f32()
        ));
        cmd
    }

    /// A command that, for the purposes of these tests, never exits on its own.
    pub fn sleep_forever_cmd() -> Command {
        sleep_cmd(Duration::from_secs(1000000))
    }

    /// A command that copies stdin to stdout and exits only once stdin reaches EOF.
    #[cfg(unix)]
    pub fn cat_cmd() -> Command {
        Command::new("cat")
    }

    /// A command that copies stdin to stdout and exits only once stdin reaches EOF.
    #[cfg(not(unix))]
    pub fn cat_cmd() -> Command {
        // This must really block on stdin like `cat`: `test_new` relies on the child staying
        // alive until its stdin is closed.
        let mut cmd = Command::new("python");
        cmd.arg("-c")
            .arg("import sys; sys.stdout.write(sys.stdin.read())");
        cmd
    }

    #[test]
    fn test_wait() {
        let child = SharedChild::spawn(&mut true_cmd()).unwrap();
        let id = child.id();
        assert!(id > 0);
        let status = child.wait().unwrap();
        assert_eq!(status.code().unwrap(), 0);
    }

    /// Spawns a child and waits for it to exit without reaping it, so the `SharedChild` is
    /// still in the `NotWaiting` state while the process is already gone.
    #[cfg(feature = "timeout")]
    fn exited_but_unawaited_child() -> SharedChild {
        let child = SharedChild::spawn(&mut true_cmd()).unwrap();
        let handle = sys::get_handle(&child.inner.lock().unwrap().child);
        sys::wait_noreap(handle).unwrap();
        child
    }

    #[test]
    #[cfg(feature = "timeout")]
    fn test_wait_timeout() {
        let exited_child = exited_but_unawaited_child();
        assert!(exited_child
            .wait_timeout(Duration::from_secs(0))
            .expect("no IO error")
            .expect("did not time out")
            .success());
        assert!(exited_child
            .wait_timeout(Duration::from_secs(0))
            .expect("no IO error")
            .expect("did not time out")
            .success());

        let long_child = Arc::new(SharedChild::spawn(&mut sleep_forever_cmd()).unwrap());
        let status = long_child
            .wait_timeout(Duration::from_millis(10))
            .expect("no IO error");
        assert!(status.is_none(), "timed out");

        // Put another thread in the Waiting state, then time out on the condvar path.
        let long_child_clone = Arc::clone(&long_child);
        std::thread::spawn(move || long_child_clone.wait().unwrap());
        std::thread::sleep(Duration::from_millis(10));
        let status = long_child
            .wait_timeout(Duration::from_millis(10))
            .expect("no IO error");
        assert!(status.is_none(), "timed out");

        long_child.kill().unwrap();
        long_child
            .wait_timeout(Duration::from_millis(100))
            .expect("no IO error")
            .expect("did not time out");
    }

    #[test]
    #[cfg(feature = "timeout")]
    fn test_wait_deadline() {
        let exited_child = exited_but_unawaited_child();
        assert!(exited_child
            .wait_deadline(Instant::now() + Duration::from_secs(0))
            .expect("no IO error")
            .expect("did not time out")
            .success());
        assert!(exited_child
            .wait_deadline(Instant::now() + Duration::from_secs(0))
            .expect("no IO error")
            .expect("did not time out")
            .success());

        let long_child = Arc::new(SharedChild::spawn(&mut sleep_forever_cmd()).unwrap());
        let status = long_child
            .wait_deadline(Instant::now() + Duration::from_millis(10))
            .expect("no IO error");
        assert!(status.is_none(), "timed out");

        // Put another thread in the Waiting state, then time out on the condvar path.
        let long_child_clone = Arc::clone(&long_child);
        std::thread::spawn(move || long_child_clone.wait().unwrap());
        std::thread::sleep(Duration::from_millis(10));
        let status = long_child
            .wait_deadline(Instant::now() + Duration::from_millis(10))
            .expect("no IO error");
        assert!(status.is_none(), "timed out");

        long_child.kill().unwrap();
        long_child
            .wait_deadline(Instant::now() + Duration::from_millis(100))
            .expect("no IO error")
            .expect("did not time out");
    }

    #[test]
    fn test_kill() {
        let child = SharedChild::spawn(&mut sleep_forever_cmd()).unwrap();
        child.kill().unwrap();
        let status = child.wait().unwrap();
        assert!(!status.success());
    }

    #[test]
    fn test_try_wait() {
        let child = SharedChild::spawn(&mut sleep_forever_cmd()).unwrap();
        assert_eq!(child.try_wait().unwrap(), None);
        child.kill().unwrap();
        // Spin until try_wait observes the exit caused by the kill.
        let status = loop {
            if let Some(status) = child.try_wait().unwrap() {
                break status;
            }
        };
        assert!(!status.success());
    }

    #[test]
    fn test_many_waiters() {
        let child = Arc::new(SharedChild::spawn(&mut sleep_forever_cmd()).unwrap());
        let threads: Vec<_> = (0..10)
            .map(|_| {
                let clone = Arc::clone(&child);
                std::thread::spawn(move || clone.wait())
            })
            .collect();
        child.kill().unwrap();
        for thread in threads {
            thread.join().unwrap().unwrap();
        }
    }

    #[test]
    fn test_waitid_after_exit_doesnt_hang() {
        let mut child = true_cmd().spawn().unwrap();
        sys::wait_noreap(sys::get_handle(&child)).unwrap();
        sys::wait_noreap(sys::get_handle(&child)).unwrap();
        child.wait().unwrap();
    }

    #[test]
    fn test_into_inner_before_wait() {
        let shared_child = SharedChild::spawn(&mut sleep_forever_cmd()).unwrap();
        let mut child = shared_child.into_inner();
        child.kill().unwrap();
        child.wait().unwrap();
    }

    #[test]
    fn test_into_inner_after_wait() {
        let shared_child = SharedChild::spawn(&mut sleep_forever_cmd()).unwrap();
        shared_child.kill().unwrap();
        shared_child.wait().unwrap();
        let mut child = shared_child.into_inner();
        child.wait().unwrap();
    }

    #[test]
    fn test_new() -> Result<(), Box<dyn Error>> {
        let mut command = cat_cmd();
        command.stdin(Stdio::piped());
        command.stdout(Stdio::null());
        let mut child = command.spawn()?;
        let child_stdin = child.stdin.take().unwrap();

        // The child blocks reading stdin, so it cannot have exited yet.
        let mut shared_child = SharedChild::new(child).unwrap();
        assert!(matches!(
            shared_child.inner.lock().unwrap().state,
            NotWaiting,
        ));

        // Closing stdin lets the child exit; keep re-wrapping until `new` observes that.
        drop(child_stdin);
        loop {
            shared_child = SharedChild::new(shared_child.into_inner())?;
            if let Exited(status) = shared_child.inner.lock().unwrap().state {
                assert!(status.success());
                return Ok(());
            }
        }
    }

    #[test]
    fn test_takes() -> Result<(), Box<dyn Error>> {
        let mut command = true_cmd();
        command.stdin(Stdio::piped());
        command.stdout(Stdio::piped());
        command.stderr(Stdio::piped());
        let shared_child = SharedChild::spawn(&mut command)?;
        assert!(shared_child.take_stdin().is_some());
        assert!(shared_child.take_stdout().is_some());
        assert!(shared_child.take_stderr().is_some());
        assert!(shared_child.take_stdin().is_none());
        assert!(shared_child.take_stdout().is_none());
        assert!(shared_child.take_stderr().is_none());
        shared_child.wait()?;
        Ok(())
    }

    /// Races `wait` against `try_wait` on a child that has exited but not been reaped.
    /// `try_wait` must never report `None` for such a child. The run time can be raised with
    /// the `SHARED_CHILD_RACE_TEST_SECONDS` environment variable.
    #[test]
    fn test_wait_try_wait_race() -> Result<(), Box<dyn Error>> {
        let test_duration_secs: u64 = match std::env::var("SHARED_CHILD_RACE_TEST_SECONDS") {
            Ok(secs) => secs.parse().expect("invalid u64"),
            Err(_) => 1,
        };
        let test_duration = Duration::from_secs(test_duration_secs);
        let test_start = Instant::now();
        let mut iterations = 1u64;
        loop {
            let child = SharedChild::spawn(&mut true_cmd())?;
            let handle = sys::get_handle(&child.inner.lock().unwrap().child);
            sys::wait_noreap(handle)?;
            let barrier = std::sync::Barrier::new(2);
            let try_wait_ret = std::thread::scope(|scope| {
                scope.spawn(|| {
                    barrier.wait();
                    child.wait().unwrap();
                });
                scope
                    .spawn(|| {
                        barrier.wait();
                        child.try_wait().unwrap()
                    })
                    .join()
                    .unwrap()
            });
            let test_time_so_far = Instant::now().saturating_duration_since(test_start);
            assert!(
                try_wait_ret.is_some(),
                "encountered the race condition after {test_time_so_far:?} ({iterations} iterations)",
            );
            iterations += 1;
            if test_time_so_far >= test_duration {
                return Ok(());
            }
        }
    }
}