// Everything below depends on Unix-only APIs (`std::os::unix`, `libc::poll`, `fcntl`, and the
// `SIGCHLD` signal), so fail non-Unix builds up front with an actionable message rather than a
// cascade of unresolved-import errors.
#[cfg(not(unix))]
compile_error!(
"This crate is Unix-only. Put it in [target.'cfg(unix)'.dependencies] in Cargo.toml."
);
use std::io::{self, ErrorKind, Read};
use std::os::raw::c_int;
use std::os::unix::io::AsRawFd;
use std::time::{Duration, Instant};
#[cfg(feature = "os_pipe")]
use os_pipe::{pipe, PipeReader};
#[cfg(not(feature = "os_pipe"))]
use std::io::{pipe, PipeReader};
// Crate-wide result alias. Test builds use `anyhow::Result` (the tests return it); normal builds
// use plain `io::Result`, so the public API exposes `io::Error`.
// Because the error type differs between the two configurations, `.into()` on an `io::Error` is a
// real conversion under `cfg(test)` but an identity conversion otherwise. That is why the code
// carries `#[allow(clippy::useless_conversion)]` on those conversions.
#[cfg(test)]
type Result<T> = anyhow::Result<T>;
#[cfg(not(test))]
type Result<T> = io::Result<T>;
/// Waits for `SIGCHLD`, optionally bounded by a timeout or deadline.
///
/// Each `Waiter` owns the read end of a non-blocking self-pipe. The write end is registered
/// with `signal_hook` for `SIGCHLD`. Dropping the `Waiter` unregisters that action.
#[derive(Debug)]
pub struct Waiter {
    /// Read end of the pipe; it has bytes available once a `SIGCHLD` has been recorded and not
    /// yet drained by a wait call.
    reader: PipeReader,
    /// Handle returned by `signal_hook` at registration, used to unregister in `Drop`.
    sig_id: signal_hook::SigId,
}
impl Waiter {
    /// Creates a `Waiter` and registers it to be notified of `SIGCHLD`.
    ///
    /// Only signals that arrive after registration are observed. Create the `Waiter` before the
    /// child you care about can exit (normally: before spawning it). Each `Waiter` has its own
    /// pipe, so every live `Waiter` is notified independently.
    ///
    /// # Errors
    ///
    /// Returns an error if the pipe cannot be created, made non-blocking, or registered with
    /// `signal_hook`.
    pub fn new() -> Result<Self> {
        let (reader, writer) = pipe()?;
        // A non-blocking read end lets `wait_inner` drain the pipe without hanging. A
        // non-blocking write end means a full pipe can never block the signal handler.
        set_nonblocking(&reader)?;
        set_nonblocking(&writer)?;
        let sig_id = signal_hook::low_level::pipe::register(libc::SIGCHLD, writer)?;
        Ok(Self { reader, sig_id })
    }

    /// Blocks until a `SIGCHLD` is received.
    ///
    /// Returns immediately if a signal already arrived since this `Waiter` was created, or since
    /// the last wait call that reported one. Several signals may be coalesced into a single
    /// wakeup. `SIGCHLD` is process-wide and does not say which child changed state, so callers
    /// should re-check their child afterwards (e.g. with a non-blocking `try_wait`).
    ///
    /// # Errors
    ///
    /// Returns an error if reading the pipe or `poll(2)` fails with anything other than `EINTR`.
    pub fn wait(&mut self) -> Result<()> {
        let signaled = self.wait_inner(None)?;
        debug_assert!(signaled, "timeout shouldn't be possible");
        Ok(())
    }

    /// Like [`wait`](Self::wait), but gives up once `timeout` has elapsed.
    ///
    /// Returns `Ok(true)` if a `SIGCHLD` was received and `Ok(false)` on timeout. A `timeout` so
    /// large that the resulting deadline cannot be represented as an [`Instant`] (for example
    /// `Duration::MAX`) is treated as "wait forever" instead of panicking.
    ///
    /// # Errors
    ///
    /// Same as [`wait`](Self::wait).
    pub fn wait_timeout(&mut self, timeout: Duration) -> Result<bool> {
        // `Instant + Duration` panics on overflow. An unrepresentable deadline is effectively
        // infinite, so fall back to an unbounded wait.
        let maybe_deadline = Instant::now().checked_add(timeout);
        self.wait_inner(maybe_deadline)
    }

    /// Like [`wait`](Self::wait), but gives up once `deadline` has passed.
    ///
    /// Returns `Ok(true)` if a `SIGCHLD` was received and `Ok(false)` otherwise. A deadline that
    /// is already in the past still reports a pending signal, without blocking.
    ///
    /// # Errors
    ///
    /// Same as [`wait`](Self::wait).
    pub fn wait_deadline(&mut self, deadline: Instant) -> Result<bool> {
        self.wait_inner(Some(deadline))
    }

    /// Shared implementation of the wait methods; `None` means "no deadline".
    ///
    /// Each iteration drains the pipe, returns `true` if anything was in it, returns `false` if
    /// the deadline has passed, and otherwise sleeps in `poll(2)` until the pipe becomes
    /// readable, the deadline arrives, or a signal interrupts the call.
    // In non-test builds `.into()` on `io::Error` is an identity conversion (see the `Result`
    // alias); under `cfg(test)` it converts to `anyhow::Error`.
    #[allow(clippy::useless_conversion)]
    fn wait_inner(&mut self, maybe_deadline: Option<Instant>) -> Result<bool> {
        let mut buf = [0u8; 1024];
        loop {
            // Drain everything currently in the pipe. Any bytes at all mean at least one
            // SIGCHLD arrived since the last drain.
            let mut signaled = false;
            loop {
                match self.reader.read(&mut buf) {
                    Ok(0) => unreachable!("this pipe should never close"),
                    Ok(_) => signaled = true,
                    Err(e) if e.kind() == ErrorKind::WouldBlock => break,
                    // Unlikely on a non-blocking read, but retrying is always correct and
                    // matches how `poll`'s EINTR is handled below.
                    Err(e) if e.kind() == ErrorKind::Interrupted => continue,
                    Err(e) => return Err(e.into()),
                }
            }
            if signaled {
                return Ok(true);
            }

            if let Some(deadline) = maybe_deadline {
                if Instant::now() >= deadline {
                    return Ok(false);
                }
            }

            let mut poll_fd = libc::pollfd {
                fd: self.reader.as_raw_fd(),
                events: libc::POLLIN,
                revents: 0,
            };
            // Round up to whole milliseconds so we never wake before the deadline. Clamp to
            // `c_int::MAX` (~24.8 days); if that elapses, the loop simply polls again. -1 means
            // "block indefinitely".
            let timeout_ms: c_int = match maybe_deadline {
                Some(deadline) => {
                    let remaining = deadline.saturating_duration_since(Instant::now());
                    (remaining.as_nanos().saturating_add(999_999) / 1_000_000)
                        .try_into()
                        .unwrap_or(c_int::MAX)
                }
                None => -1,
            };
            // SAFETY: `poll_fd` is a fully initialized `pollfd` that outlives the call, and we
            // pass a count of exactly 1. Its `fd` belongs to `self.reader`, which stays open for
            // the duration of the call.
            let poll_return_code = unsafe { libc::poll(&mut poll_fd, 1, timeout_ms) };
            if poll_return_code < 0 {
                let last_error = io::Error::last_os_error();
                // A signal (possibly the very SIGCHLD we're waiting for) interrupted `poll`.
                // Go around again and check the pipe.
                if last_error.kind() != ErrorKind::Interrupted {
                    return Err(last_error.into());
                }
            }
            // Readable, timed out, or interrupted: in every case the top of the loop decides
            // what to do next.
        }
    }
}
impl Drop for Waiter {
fn drop(&mut self) {
let existed = signal_hook::low_level::unregister(self.sig_id);
debug_assert!(existed, "should've existed");
}
}
/// Puts `fd` into non-blocking mode, preserving its other file status flags.
///
/// Uses the `F_GETFL` / `F_SETFL` read-modify-write idiom. Passing bare `O_NONBLOCK` to
/// `F_SETFL` would clear any other status flags (such as `O_APPEND`) already set on the
/// descriptor.
///
/// # Errors
///
/// Returns the OS error if either `fcntl` call fails.
// In non-test builds `.into()` on `io::Error` is an identity conversion (see the `Result`
// alias); under `cfg(test)` it converts to `anyhow::Error`.
#[allow(clippy::useless_conversion)]
fn set_nonblocking(fd: &impl AsRawFd) -> Result<()> {
    let raw_fd = fd.as_raw_fd();

    // SAFETY: `F_GETFL` takes no extra argument and only reads the status flags of `raw_fd`,
    // which the borrowed `fd` keeps open for the duration of this call.
    let flags = unsafe { libc::fcntl(raw_fd, libc::F_GETFL) };
    if flags == -1 {
        return Err(io::Error::last_os_error().into());
    }

    // SAFETY: `F_SETFL` takes a single `int` argument; we pass the current flags with
    // `O_NONBLOCK` added. `raw_fd` is still open (see above).
    let return_code = unsafe { libc::fcntl(raw_fd, libc::F_SETFL, flags | libc::O_NONBLOCK) };
    if return_code == -1 {
        return Err(io::Error::last_os_error().into());
    }
    Ok(())
}
// Wall-clock tests: each one spawns real `sleep` child processes and checks elapsed time against
// the expected duration with a 10% tolerance, so they assume a reasonably idle machine.
#[cfg(test)]
mod test {
    use super::*;
    use duct::cmd;
    use std::sync::{Arc, Mutex, MutexGuard};
    use std::time::{Duration, Instant};

    // SIGCHLD is process-wide, so a child spawned by one test would wake waiters belonging to
    // another. Every test holds this lock for its whole body to serialize them.
    static ONE_TEST_AT_A_TIME: Mutex<()> = Mutex::new(());

    /// Locks `mutex`, ignoring poisoning, so that one test panicking while holding the lock does
    /// not cause every later test to fail as well.
    fn lock_no_poison<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
        match mutex.lock() {
            Ok(guard) => guard,
            Err(e) => e.into_inner(),
        }
    }

    /// Asserts that the ratio `dur1 / dur2` lies strictly between 0.9 and 1.1.
    #[track_caller]
    fn assert_approx_eq(dur1: Duration, dur2: Duration) {
        const CLOSE_ENOUGH: f64 = 0.1;
        let lower_bound = 1.0 - CLOSE_ENOUGH;
        let upper_bound = 1.0 + CLOSE_ENOUGH;
        let ratio = dur1.as_secs_f64() / dur2.as_secs_f64();
        assert!(
            lower_bound < ratio && ratio < upper_bound,
            "{dur1:?} and {dur2:?} are not close enough",
        );
    }

    // `wait` returns once the child exits (~250ms).
    #[test]
    fn test_wait() -> Result<()> {
        let _test_guard = lock_no_poison(&ONE_TEST_AT_A_TIME);
        let start = Instant::now();
        // The waiter is created before the child is spawned, so the child's SIGCHLD can't be
        // missed.
        let mut waiter = Waiter::new()?;
        cmd!("sleep", "0.25").start()?;
        waiter.wait()?;
        let dur = Instant::now() - start;
        assert_approx_eq(Duration::from_millis(250), dur);
        Ok(())
    }

    // Phase 1: the child exits (~250ms) before the 500ms deadline, so the result is `true`.
    // Phase 2: a fresh waiter with no new child runs out the full 500ms, for ~750ms total, and
    // the result is `false`.
    #[test]
    fn test_wait_deadline() -> Result<()> {
        let _test_guard = lock_no_poison(&ONE_TEST_AT_A_TIME);
        let start = Instant::now();
        let timeout = Duration::from_millis(500);
        let mut waiter = Waiter::new()?;
        cmd!("sleep", "0.25").start()?;
        let signaled = waiter.wait_deadline(Instant::now() + timeout)?;
        let dur = Instant::now() - start;
        assert_approx_eq(Duration::from_millis(250), dur);
        assert!(signaled);
        // `waiter2` is registered after phase 1's SIGCHLD, so it must not see that signal.
        let mut waiter2 = Waiter::new()?;
        let signaled2 = waiter2.wait_deadline(Instant::now() + timeout)?;
        let dur2 = Instant::now() - start;
        assert_approx_eq(Duration::from_millis(750), dur2);
        assert!(!signaled2);
        Ok(())
    }

    // Same two phases as `test_wait_deadline`, but through `wait_timeout`.
    #[test]
    fn test_wait_timeout() -> Result<()> {
        let _test_guard = lock_no_poison(&ONE_TEST_AT_A_TIME);
        let start = Instant::now();
        let timeout = Duration::from_millis(500);
        let mut waiter = Waiter::new()?;
        cmd!("sleep", "0.25").start()?;
        let signaled = waiter.wait_timeout(timeout)?;
        let dur = Instant::now() - start;
        assert_approx_eq(Duration::from_millis(250), dur);
        assert!(signaled);
        let mut waiter2 = Waiter::new()?;
        let signaled2 = waiter2.wait_timeout(timeout)?;
        let dur2 = Instant::now() - start;
        assert_approx_eq(Duration::from_millis(750), dur2);
        assert!(!signaled2);
        Ok(())
    }

    // Nine waiters on separate threads all watch one 1s child:
    //   - 3 untimed waits wake at ~1000ms, after the child has exited;
    //   - 3 waits with a 500ms timeout give up while the child is still running;
    //   - 3 waits with a 1500ms timeout are signaled by the child's exit.
    // Each waiter is created on the main thread before its thread is spawned. The child is
    // already running at that point, but it sleeps for 1s, leaving time to register them all.
    #[test]
    fn test_wait_many_threads() -> Result<()> {
        let _test_guard = lock_no_poison(&ONE_TEST_AT_A_TIME);
        let start = Instant::now();
        let handle = Arc::new(cmd!("sleep", "1").start()?);
        let mut wait_threads = Vec::new();
        let mut short_timeout_threads = Vec::new();
        let mut long_timeout_threads = Vec::new();
        for _ in 0..3 {
            let handle_clone = handle.clone();
            let mut waiter = Waiter::new()?;
            wait_threads.push(std::thread::spawn(move || -> Result<Duration> {
                waiter.wait()?;
                let dur = Instant::now() - start;
                assert!(handle_clone.try_wait()?.is_some(), "should've exited");
                Ok(dur)
            }));
            let handle_clone = handle.clone();
            let mut waiter = Waiter::new()?;
            short_timeout_threads.push(std::thread::spawn(move || -> Result<bool> {
                let signaled = waiter.wait_timeout(Duration::from_millis(500))?;
                assert!(handle_clone.try_wait()?.is_none(), "shouldn't have exited");
                Ok(signaled)
            }));
            let handle_clone = handle.clone();
            let mut waiter = Waiter::new()?;
            long_timeout_threads.push(std::thread::spawn(move || -> Result<bool> {
                let signaled = waiter.wait_timeout(Duration::from_millis(1500))?;
                assert!(handle_clone.try_wait()?.is_some(), "should've exited");
                Ok(signaled)
            }));
        }
        // `join().unwrap()` re-raises any assertion panic from a worker thread.
        for thread in wait_threads {
            let dur = thread.join().unwrap()?;
            assert_approx_eq(Duration::from_millis(1000), dur);
        }
        for thread in short_timeout_threads {
            assert!(!thread.join().unwrap()?, "should not be signaled");
        }
        for thread in long_timeout_threads {
            assert!(thread.join().unwrap()?, "should be signaled");
        }
        Ok(())
    }
}