use super::Disposition;
use super::Errno;
use super::Result;
use super::SigmaskOp;
use super::signal;
#[cfg(doc)]
use super::{Concurrent, SharedSystem};
use crate::io::Fd;
use crate::system::{CaughtSignals, Clock, Sigaction, Sigmask, Signals};
use std::cell::RefCell;
use std::cmp::Ordering;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::collections::binary_heap::PeekMut;
use std::ffi::c_int;
use std::future::Future;
use std::ops::Deref;
use std::ops::DerefMut;
use std::rc::Rc;
use std::rc::Weak;
use std::task::Waker;
use std::time::Duration;
use std::time::Instant;
/// System primitive for waiting on file descriptors, a timeout, and signals.
///
/// [`SelectSystem::select`] is built on top of this method.
pub trait Select: Signals {
/// Waits until one of the conditions described by the arguments is met.
///
/// `readers` and `writers` list the file descriptors to watch.
/// [`SelectSystem::select`] treats the FDs still present in these vectors
/// after a successful call as ready. Implementations are therefore presumably
/// expected to remove the FDs that are not ready (TODO confirm against the
/// implementations).
///
/// `timeout` bounds the wait. `SelectSystem::select` passes `None` when no
/// timeout is registered and `Some(Duration::ZERO)` when polling.
///
/// `signal_mask`, when given, is presumably the signal mask to apply during
/// the wait, like `pselect` (TODO confirm).
///
/// The returned future may borrow `readers` and `writers` (lifetime `'a`).
/// It captures neither the `&self` borrow nor `signal_mask`.
///
/// The `c_int` result is presumably the number of ready FDs;
/// `SelectSystem::select` ignores it.
fn select<'a>(
&self,
readers: &'a mut Vec<Fd>,
writers: &'a mut Vec<Fd>,
timeout: Option<Duration>,
signal_mask: Option<&[signal::Number]>,
) -> impl Future<Output = Result<c_int>> + use<'a, Self>;
}
/// Lets a shared (`Rc`) system be used wherever a [`Select`] is expected by
/// forwarding to the system behind the pointer.
impl<S: Select> Select for Rc<S> {
    #[inline]
    fn select<'a>(
        &self,
        readers: &'a mut Vec<Fd>,
        writers: &'a mut Vec<Fd>,
        timeout: Option<Duration>,
        signal_mask: Option<&[signal::Number]>,
    ) -> impl Future<Output = Result<c_int>> + use<'a, S> {
        // Deref-coerce `&Rc<S>` to `&S` so `S`'s own implementation is called.
        let inner: &S = self;
        inner.select(readers, writers, timeout, signal_mask)
    }
}
/// Wraps a system `S` and multiplexes asynchronous waits for file
/// descriptors, timeouts, and caught signals on top of its `select`.
///
/// Tasks register through [`SelectSystem::add_reader`],
/// [`SelectSystem::add_writer`], [`SelectSystem::add_timeout`], and
/// [`SelectSystem::add_signal_waker`]. [`SelectSystem::select`] then waits in
/// the wrapped system and wakes the tasks whose conditions are met.
///
/// The wrapped system is also reachable through `Deref`/`DerefMut`.
#[derive(Debug)]
pub struct SelectSystem<S> {
/// The wrapped system.
system: S,
/// Tasks waiting for file descriptors.
io: AsyncIo,
/// Tasks waiting for points in time.
time: AsyncTime,
/// Tasks waiting for caught signals.
signal: AsyncSignal,
/// Signal mask passed to the wrapped system's `select` while waiting.
/// `None` until a disposition is first changed with `set_disposition`.
wait_mask: Option<Vec<signal::Number>>,
}
/// Exposes the wrapped system's methods directly on the `SelectSystem`.
impl<S> Deref for SelectSystem<S> {
    type Target = S;

    fn deref(&self) -> &Self::Target {
        let Self { system, .. } = self;
        system
    }
}
/// Gives mutable access to the wrapped system through the `SelectSystem`.
impl<S> DerefMut for SelectSystem<S> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Self { system, .. } = self;
        system
    }
}
impl<S> SelectSystem<S> {
    /// Wraps `system` with empty file-descriptor, timeout, and signal waiter
    /// registries.
    ///
    /// The wait mask starts unset. It is initialized the first time a signal
    /// disposition is changed through [`Self::set_disposition`].
    pub fn new(system: S) -> Self {
        SelectSystem {
            system,
            io: AsyncIo::new(),
            time: AsyncTime::new(),
            signal: AsyncSignal::new(),
            wait_mask: None,
        }
    }

    /// Applies `op` for `signal` to the system's signal mask and keeps the wait
    /// mask in sync.
    ///
    /// While the wait mask is still unset, the mask in effect before this
    /// change is fetched from the system and becomes the wait mask. In every
    /// case `signal` is then removed from the wait mask, so it is not blocked
    /// while [`Self::select`] waits.
    ///
    /// # Errors
    ///
    /// Returns the error of the underlying `sigmask` call. The wait mask is
    /// left untouched in that case.
    async fn sigmask_async(
        this: &RefCell<SelectSystem<S>>,
        op: SigmaskOp,
        signal: signal::Number,
    ) -> Result<()>
    where
        S: Sigmask,
    {
        // Only ask for the previous mask while the wait mask is still unset.
        let mut old_mask = this.borrow().wait_mask.is_none().then(Vec::new);
        let future = this
            .borrow_mut()
            .system
            .sigmask(Some((op, &[signal])), old_mask.as_mut());
        future.await?;

        let mut me = this.borrow_mut();
        // Another call may have initialized the wait mask while `future` was
        // pending. Keep that mask rather than overwriting it, which would
        // discard the other call's update.
        if me.wait_mask.is_none() {
            me.wait_mask = old_mask;
        }
        if let Some(wait_mask) = &mut me.wait_mask {
            wait_mask.retain(|&s| s != signal);
        }
        Ok(())
    }

    /// Returns the current disposition of `signal` as reported by the wrapped
    /// system's `get_sigaction`.
    ///
    /// # Errors
    ///
    /// Returns any error from `get_sigaction`.
    #[inline]
    pub fn get_disposition(&self, signal: signal::Number) -> Result<Disposition>
    where
        S: Sigaction,
    {
        self.system.get_sigaction(signal)
    }

    /// Changes the disposition of `signal` and returns the previous one.
    ///
    /// - For `Default`/`Ignore`, the disposition is changed first, and then the
    ///   signal is removed from the blocking mask.
    /// - For `Catch`, the signal is added to the blocking mask first, and then
    ///   the disposition is changed.
    ///
    /// Either way, the signal is removed from the wait mask used by
    /// [`Self::select`].
    ///
    /// # Errors
    ///
    /// Errors from `sigaction` or `sigmask` are returned as is.
    /// NOTE(review): a failure in the second step does not undo the first.
    pub async fn set_disposition(
        this: &RefCell<SelectSystem<S>>,
        signal: signal::Number,
        handling: Disposition,
    ) -> Result<Disposition>
    where
        S: Sigaction + Sigmask,
    {
        match handling {
            Disposition::Default | Disposition::Ignore => {
                let old_handling = this.borrow_mut().system.sigaction(signal, handling)?;
                Self::sigmask_async(this, SigmaskOp::Remove, signal).await?;
                Ok(old_handling)
            }
            Disposition::Catch => {
                Self::sigmask_async(this, SigmaskOp::Add, signal).await?;
                this.borrow_mut().system.sigaction(signal, handling)
            }
        }
    }

    /// Registers `waker` to be woken once [`Self::select`] reports `fd` in its
    /// readers list.
    pub fn add_reader(&mut self, fd: Fd, waker: Weak<RefCell<Option<Waker>>>) {
        self.io.wait_for_reading(fd, waker)
    }

    /// Registers `waker` to be woken once [`Self::select`] reports `fd` in its
    /// writers list.
    pub fn add_writer(&mut self, fd: Fd, waker: Weak<RefCell<Option<Waker>>>) {
        self.io.wait_for_writing(fd, waker)
    }

    /// Registers `waker` to be woken once the clock reaches `target`.
    pub fn add_timeout(&mut self, target: Instant, waker: Weak<RefCell<Option<Waker>>>) {
        self.time.push(Timeout { target, waker })
    }

    /// Registers a signal waiter and returns the status it should poll.
    ///
    /// The status is already `Caught` if signals were caught while no waiter
    /// was alive.
    pub fn add_signal_waker(&mut self) -> Rc<RefCell<SignalStatus>> {
        self.signal.wait_for_signals()
    }

    /// Wakes every timeout whose target is at or before the current time, and
    /// then discards timeouts whose waiter is gone.
    fn wake_timeouts(&mut self)
    where
        S: Clock,
    {
        // Only read the clock when there is something to compare against.
        if !self.time.is_empty() {
            let now = self.now();
            self.time.wake_if_passed(now);
        }
        self.time.gc();
    }

    /// Fetches the caught signals from the wrapped system and delivers them to
    /// signal waiters.
    ///
    /// If no signal was caught, dead waiters are discarded instead.
    fn wake_on_signals(&mut self)
    where
        S: CaughtSignals,
    {
        let signals = self.system.caught_signals();
        if signals.is_empty() {
            self.signal.gc()
        } else {
            self.signal.wake(signals)
        }
    }

    /// Waits in the wrapped system's `select` and wakes the tasks whose
    /// conditions are met.
    ///
    /// The FDs of all registered readers and writers are passed to the system,
    /// together with the wait mask. If `poll` is true, the wait uses a zero
    /// timeout. Otherwise it lasts until the earliest registered timeout, or is
    /// unbounded when none is registered.
    ///
    /// After the wait:
    /// - on success, awaiters whose FD is still in the respective list are woken;
    /// - on `EBADF`, all FD awaiters are woken and `EBADF` is returned;
    /// - `EINTR` counts as success without waking any FD awaiter;
    /// - any other error is returned.
    ///
    /// In every case, dead FD awaiters are discarded, due timeouts are woken,
    /// and caught signals are delivered.
    ///
    /// # Errors
    ///
    /// Returns the system's `select` error, except for `EINTR`.
    // NOTE(review): `me` is dropped before the only await. The allow presumably
    // guards against the lint not recognizing the explicit `drop(me)`.
    #[allow(clippy::await_holding_refcell_ref)]
    pub async fn select(this: &RefCell<SelectSystem<S>>, poll: bool) -> Result<()>
    where
        S: Select + CaughtSignals + Clock,
    {
        let me = this.borrow();
        let mut readers = me.io.readers();
        let mut writers = me.io.writers();
        let timeout = if poll {
            Some(Duration::ZERO)
        } else {
            me.time
                .first_target()
                .map(|instant| instant.saturating_duration_since(me.now()))
        };
        let future = me
            .system
            .select(&mut readers, &mut writers, timeout, me.wait_mask.as_deref());
        // Release the shared borrow before suspending so that other tasks can
        // register awaiters while we wait.
        drop(me);
        let inner_result = future.await;

        let mut me = this.borrow_mut();
        let final_result = match inner_result {
            Ok(_) => {
                me.io.wake(&readers, &writers);
                Ok(())
            }
            Err(Errno::EBADF) => {
                // We cannot tell which FD is bad, so let every awaiter retry.
                me.io.wake_all();
                Err(Errno::EBADF)
            }
            Err(Errno::EINTR) => Ok(()),
            Err(error) => Err(error),
        };
        me.io.gc();
        me.wake_timeouts();
        me.wake_on_signals();
        final_result
    }
}
/// Registry of tasks waiting for file descriptors.
///
/// Removing an entry drops its `FdAwaiter`, which wakes the task.
#[derive(Clone, Debug, Default)]
struct AsyncIo {
/// Awaiters whose FD is passed to `select` as a reader.
readers: Vec<FdAwaiter>,
/// Awaiters whose FD is passed to `select` as a writer.
writers: Vec<FdAwaiter>,
}
/// A task waiting for a file descriptor.
///
/// Dropping an `FdAwaiter` wakes the task if its waker slot is still alive
/// and holds a waker.
#[derive(Clone, Debug)]
struct FdAwaiter {
/// File descriptor being waited for.
fd: Fd,
/// Slot holding the waiting task's waker. `AsyncIo::gc` discards awaiters
/// whose slot has no strong references left.
waker: Weak<RefCell<Option<Waker>>>,
}
/// Wakes the waiting task, if it is still around, when the awaiter is removed.
impl Drop for FdAwaiter {
    fn drop(&mut self) {
        let Some(slot) = self.waker.upgrade() else {
            return;
        };
        // Take the waker in its own statement so the `RefMut` is released
        // before `wake` runs. Inside an `if let` scrutinee the borrow would live
        // through the call, and a waker that touches the slot again would panic
        // (inside `drop`).
        let waker = slot.borrow_mut().take();
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}
impl AsyncIo {
pub fn new() -> Self {
Self::default()
}
pub fn readers(&self) -> Vec<Fd> {
self.readers.iter().map(|awaiter| awaiter.fd).collect()
}
pub fn writers(&self) -> Vec<Fd> {
self.writers.iter().map(|awaiter| awaiter.fd).collect()
}
pub fn wait_for_reading(&mut self, fd: Fd, waker: Weak<RefCell<Option<Waker>>>) {
self.readers.push(FdAwaiter { fd, waker });
}
pub fn wait_for_writing(&mut self, fd: Fd, waker: Weak<RefCell<Option<Waker>>>) {
self.writers.push(FdAwaiter { fd, waker });
}
pub fn wake(&mut self, readers: &[Fd], writers: &[Fd]) {
self.readers
.retain(|awaiter| !readers.contains(&awaiter.fd));
self.writers
.retain(|awaiter| !writers.contains(&awaiter.fd));
}
pub fn wake_all(&mut self) {
self.readers.clear();
self.writers.clear();
}
pub fn gc(&mut self) {
let is_alive = |awaiter: &FdAwaiter| awaiter.waker.strong_count() > 0;
self.readers.retain(is_alive);
self.writers.retain(is_alive);
}
}
/// Queue of pending timeouts with the earliest target on top.
#[derive(Clone, Debug, Default)]
struct AsyncTime {
/// `BinaryHeap` is a max-heap, so `Reverse` puts the timeout with the
/// smallest `target` at the top.
timeouts: BinaryHeap<Reverse<Timeout>>,
}
/// A pending wake-up scheduled for a point in time.
///
/// Equality and ordering consider `target` only; the waker is ignored.
/// Dropping a `Timeout` wakes the task whose waker is still stored in the
/// slot, if any.
#[derive(Clone, Debug)]
struct Timeout {
    /// Instant at which the timeout is due.
    target: Instant,
    /// Slot holding the waiting task's waker. `AsyncTime::gc` discards
    /// timeouts whose slot has no strong references left.
    waker: Weak<RefCell<Option<Waker>>>,
}

impl PartialEq for Timeout {
    fn eq(&self, rhs: &Self) -> bool {
        self.target == rhs.target
    }
}

impl Eq for Timeout {}

impl PartialOrd for Timeout {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(self.cmp(rhs))
    }
}

impl Ord for Timeout {
    fn cmp(&self, rhs: &Self) -> Ordering {
        self.target.cmp(&rhs.target)
    }
}

impl Drop for Timeout {
    fn drop(&mut self) {
        let Some(slot) = self.waker.upgrade() else {
            return;
        };
        // Take the waker in its own statement so the `RefMut` is released
        // before `wake` runs. Inside an `if let` scrutinee the borrow would live
        // through the call, and a waker that touches the slot again would panic
        // (inside `drop`).
        let waker = slot.borrow_mut().take();
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}
impl AsyncTime {
#[must_use]
fn new() -> Self {
Self::default()
}
#[must_use]
fn is_empty(&self) -> bool {
self.timeouts.is_empty()
}
fn push(&mut self, timeout: Timeout) {
self.timeouts.push(Reverse(timeout))
}
#[must_use]
fn first_target(&self) -> Option<Instant> {
self.timeouts.peek().map(|timeout| timeout.0.target)
}
fn wake_if_passed(&mut self, now: Instant) {
while let Some(timeout) = self.timeouts.peek_mut() {
if !timeout.0.passed(now) {
break;
}
PeekMut::pop(timeout);
}
}
fn gc(&mut self) {
self.timeouts.retain(|t| t.0.waker.strong_count() > 0);
}
}
impl Timeout {
    /// Returns whether this timeout is due at `now`, that is, whether `now`
    /// is not earlier than the target.
    fn passed(&self, now: Instant) -> bool {
        now >= self.target
    }
}
/// State of the tasks waiting for caught signals.
#[derive(Clone, Debug)]
enum AsyncSignal {
/// Weak handles to the statuses of registered waiters (possibly none).
Awaiting(Vec<Weak<RefCell<SignalStatus>>>),
/// Signals caught while no live waiter existed, kept for the next
/// `wait_for_signals` call.
Caught(Vec<signal::Number>),
}
/// Status shared between one signal waiter and `AsyncSignal`.
#[derive(Clone, Debug)]
pub enum SignalStatus {
/// No signal delivered yet. Holds the waker to call when signals arrive.
Expected(Option<Waker>),
/// The signals delivered to this waiter.
Caught(Rc<[signal::Number]>),
}
impl AsyncSignal {
    /// Creates the initial state: no waiters and no undelivered signals.
    pub fn new() -> Self {
        AsyncSignal::Awaiting(Vec::new())
    }

    /// Forgets waiters whose status has been dropped.
    ///
    /// Undelivered signals (the `Caught` state) are left untouched.
    pub fn gc(&mut self) {
        if let AsyncSignal::Awaiting(waiters) = self {
            waiters.retain(|waiter| waiter.strong_count() > 0);
        }
    }

    /// Registers a new waiter and returns the status it should poll.
    ///
    /// If signals were stored because nobody was waiting, the returned status
    /// is already `Caught` with those signals, and the state is reset to an
    /// empty `Awaiting`. Otherwise the status starts as `Expected(None)` and a
    /// weak handle to it is recorded for [`Self::wake`].
    pub fn wait_for_signals(&mut self) -> Rc<RefCell<SignalStatus>> {
        match self {
            AsyncSignal::Awaiting(waiters) => {
                let status = Rc::new(RefCell::new(SignalStatus::Expected(None)));
                waiters.push(Rc::downgrade(&status));
                status
            }
            AsyncSignal::Caught(pending) => {
                let pending = std::mem::take(pending);
                *self = AsyncSignal::Awaiting(Vec::new());
                debug_assert!(!pending.is_empty());
                Rc::new(RefCell::new(SignalStatus::Caught(pending.into())))
            }
        }
    }

    /// Delivers `signals` to every live waiter, or stores them if there is none.
    ///
    /// - Does nothing when `signals` is empty.
    /// - If signals are already stored, the new ones are appended.
    /// - Otherwise, every live waiter's status becomes `Caught`, all sharing
    ///   one slice, and its waker (if any) is called. The waiter list ends up
    ///   empty.
    /// - If no waiter is alive, the signals are stored for the next waiter.
    pub fn wake(&mut self, signals: Vec<signal::Number>) {
        if signals.is_empty() {
            return;
        }
        let waiters = match self {
            AsyncSignal::Caught(pending) => {
                pending.extend(signals);
                return;
            }
            AsyncSignal::Awaiting(waiters) => waiters,
        };

        // Draining removes every handle; dead ones are skipped lazily.
        let mut live = waiters.drain(..).filter_map(|waiter| waiter.upgrade());
        let Some(first) = live.next() else {
            // Nobody can receive the signals now; keep them for the next waiter.
            drop(live);
            *self = AsyncSignal::Caught(signals);
            return;
        };

        let caught: Rc<[signal::Number]> = Rc::from(signals);
        for status in std::iter::once(first).chain(live) {
            // Swap the status in its own statement so the `RefMut` is released
            // before the waker runs.
            let previous = std::mem::replace(
                &mut *status.borrow_mut(),
                SignalStatus::Caught(Rc::clone(&caught)),
            );
            if let SignalStatus::Expected(Some(waker)) = previous {
                waker.wake();
            }
        }
    }
}
// Unit tests for the private waiter registries in this file: `AsyncIo`,
// `AsyncTime`, and `AsyncSignal`.
#[cfg(test)]
mod tests {
use super::super::r#virtual::{SIGCHLD, SIGINT, SIGUSR1};
use super::*;
use crate::test_helper::WakeFlag;
use assert_matches::assert_matches;
use std::sync::Arc;
use std::sync::atomic::Ordering;
// A fresh `AsyncIo` reports no readers and no writers.
#[test]
fn async_io_has_no_default_readers_or_writers() {
let async_io = AsyncIo::new();
assert_eq!(async_io.readers(), []);
assert_eq!(async_io.writers(), []);
}
// Registered FDs are reported by `readers`/`writers`.
#[test]
fn async_io_non_empty_readers_and_writers() {
let mut async_io = AsyncIo::new();
let waker = Rc::new(RefCell::new(Some(Waker::noop().clone())));
async_io.wait_for_reading(Fd::STDIN, Rc::downgrade(&waker));
async_io.wait_for_writing(Fd::STDOUT, Rc::downgrade(&waker));
async_io.wait_for_writing(Fd::STDERR, Rc::downgrade(&waker));
assert_eq!(async_io.readers(), [Fd::STDIN]);
// Sort so the assertion does not depend on registration order.
let mut writers = async_io.writers();
writers.sort();
assert_eq!(writers, [Fd::STDOUT, Fd::STDERR]);
}
// `wake` removes only the awaiters whose FD is listed as ready.
#[test]
fn async_io_wake() {
let mut async_io = AsyncIo::new();
let waker = Rc::new(RefCell::new(Some(Waker::noop().clone())));
async_io.wait_for_reading(Fd(3), Rc::downgrade(&waker));
async_io.wait_for_reading(Fd(4), Rc::downgrade(&waker));
async_io.wait_for_writing(Fd(4), Rc::downgrade(&waker));
async_io.wake(&[Fd(4)], &[Fd(4)]);
assert_eq!(async_io.readers(), [Fd(3)]);
assert_eq!(async_io.writers(), []);
}
// `wake_all` removes every reader and writer.
#[test]
fn async_io_wake_all() {
let mut async_io = AsyncIo::new();
let waker = Rc::new(RefCell::new(Some(Waker::noop().clone())));
async_io.wait_for_reading(Fd::STDIN, Rc::downgrade(&waker));
async_io.wait_for_writing(Fd::STDOUT, Rc::downgrade(&waker));
async_io.wait_for_writing(Fd::STDERR, Rc::downgrade(&waker));
async_io.wake_all();
assert_eq!(async_io.readers(), []);
assert_eq!(async_io.writers(), []);
}
// `first_target` is `None` when empty, and otherwise the earliest target
// regardless of push order.
#[test]
fn async_time_first_target() {
let mut async_time = AsyncTime::new();
let now = Instant::now();
assert_eq!(async_time.first_target(), None);
async_time.push(Timeout {
target: now + Duration::from_secs(2),
waker: Weak::default(),
});
async_time.push(Timeout {
target: now + Duration::from_secs(1),
waker: Weak::default(),
});
async_time.push(Timeout {
target: now + Duration::from_secs(3),
waker: Weak::default(),
});
assert_eq!(
async_time.first_target(),
Some(now + Duration::from_secs(1))
);
}
// `wake_if_passed` pops timeouts whose target is at or before the given
// instant (inclusive) and keeps the later ones in order.
#[test]
fn async_time_wake_if_passed() {
let mut async_time = AsyncTime::new();
let now = Instant::now();
let waker = Rc::new(RefCell::new(Some(Waker::noop().clone())));
async_time.push(Timeout {
target: now,
waker: Rc::downgrade(&waker),
});
async_time.push(Timeout {
target: now + Duration::new(1, 0),
waker: Rc::downgrade(&waker),
});
async_time.push(Timeout {
target: now + Duration::new(1, 1),
waker: Rc::downgrade(&waker),
});
async_time.push(Timeout {
target: now + Duration::new(2, 0),
waker: Rc::downgrade(&waker),
});
assert_eq!(async_time.timeouts.len(), 4);
async_time.wake_if_passed(now + Duration::new(1, 0));
assert_eq!(
async_time.timeouts.pop().unwrap().0.target,
now + Duration::new(1, 1)
);
assert_eq!(
async_time.timeouts.pop().unwrap().0.target,
now + Duration::new(2, 0)
);
assert!(async_time.timeouts.is_empty(), "{:?}", async_time.timeouts);
}
// All live waiters registered before `wake` are woken and receive the same
// signals.
#[test]
fn async_signal_wait_and_wake() {
let mut async_signal = AsyncSignal::new();
let status_1 = async_signal.wait_for_signals();
let status_2 = async_signal.wait_for_signals();
let wake_flag_1 = Arc::new(WakeFlag::new());
let wake_flag_2 = Arc::new(WakeFlag::new());
assert_matches!(&mut *status_1.borrow_mut(), SignalStatus::Expected(waker) => {
assert!(waker.is_none());
*waker = Some(wake_flag_1.clone().into());
});
assert_matches!(&mut *status_2.borrow_mut(), SignalStatus::Expected(waker) => {
assert!(waker.is_none());
*waker = Some(wake_flag_2.clone().into());
});
async_signal.wake(vec![SIGCHLD, SIGUSR1]);
assert!(wake_flag_1.0.load(Ordering::Relaxed));
assert!(wake_flag_2.0.load(Ordering::Relaxed));
assert_matches!(&*status_1.borrow(), SignalStatus::Caught(signals) => {
assert_eq!(**signals, [SIGCHLD, SIGUSR1]);
});
assert_matches!(&*status_2.borrow(), SignalStatus::Caught(signals) => {
assert_eq!(**signals, [SIGCHLD, SIGUSR1]);
});
}
// Signals caught with no waiter are handed to the next waiter immediately.
#[test]
fn async_signal_wake_and_wait() {
let mut async_signal = AsyncSignal::new();
async_signal.wake(vec![SIGINT, SIGCHLD]);
let status = async_signal.wait_for_signals();
assert_matches!(&*status.borrow(), SignalStatus::Caught(signals) => {
assert_eq!(**signals, [SIGINT, SIGCHLD]);
});
}
// Signals from consecutive wakes without a waiter are accumulated in order.
#[test]
fn async_signal_wake_twice_and_wait() {
let mut async_signal = AsyncSignal::new();
async_signal.wake(vec![SIGINT]);
async_signal.wake(vec![SIGUSR1]);
let status = async_signal.wait_for_signals();
assert_matches!(&*status.borrow(), SignalStatus::Caught(signals) => {
assert_eq!(**signals, [SIGINT, SIGUSR1]);
});
}
// Waking with an empty signal list neither wakes nor alters the waiter.
#[test]
fn async_signal_empty_wake() {
let mut async_signal = AsyncSignal::new();
let status = async_signal.wait_for_signals();
let wake_flag = Arc::new(WakeFlag::new());
assert_matches!(&mut *status.borrow_mut(), SignalStatus::Expected(waker) => {
assert!(waker.is_none());
*waker = Some(wake_flag.clone().into());
});
async_signal.wake(vec![]);
assert!(!wake_flag.is_woken());
// The stored waker is still the test's flag: waking it by hand sets it.
assert_matches!(&*status.borrow(), SignalStatus::Expected(Some(waker)) => {
waker.wake_by_ref();
});
assert!(wake_flag.is_woken());
}
// A waiter dropped before `wake` does not consume the signals; they are
// kept for the next waiter.
#[test]
fn async_signal_phantom_wake() {
let mut async_signal = AsyncSignal::new();
let status_1 = async_signal.wait_for_signals();
drop(status_1);
async_signal.wake(vec![SIGINT]);
let status_2 = async_signal.wait_for_signals();
assert_matches!(&*status_2.borrow(), SignalStatus::Caught(signals) => {
assert_eq!(**signals, [SIGINT]);
});
}
}