#![cfg(unix)]
#![cfg_attr(docsrs, doc(cfg(all(unix, feature = "signal"))))]
use crate::runtime::scheduler;
use crate::runtime::signal::Handle;
use crate::signal::registry::{globals, EventId, EventInfo, Globals, Storage};
use crate::signal::RxFuture;
use crate::sync::watch;
use mio::net::UnixStream;
use std::io::{self, Error, ErrorKind, Write};
use std::sync::OnceLock;
use std::task::{Context, Poll};
/// Per-signal bookkeeping, one `SignalInfo` slot per signal number.
///
/// On targets other than Linux and illumos this is a fixed array of 33 slots.
#[cfg(not(any(target_os = "linux", target_os = "illumos")))]
pub(crate) struct OsStorage([SignalInfo; 33]);
/// Per-signal bookkeeping, one `SignalInfo` slot per signal number.
///
/// On Linux and illumos the slots live in a heap-allocated slice whose length
/// is chosen at construction time (see the `Default` impl, which sizes it from
/// `libc::SIGRTMAX()`).
#[cfg(any(target_os = "linux", target_os = "illumos"))]
pub(crate) struct OsStorage(Box<[SignalInfo]>);
impl OsStorage {
    /// Looks up the slot belonging to signal number `id`.
    ///
    /// Slots are stored zero-based while ids are one-based, so signal `n`
    /// lives at index `n - 1`.
    ///
    /// Returns `None` when `id` is `0` or when it lies beyond the last slot.
    fn get(&self, id: EventId) -> Option<&SignalInfo> {
        // `checked_sub` keeps `id == 0` from underflowing: a plain `id - 1`
        // panics in overflow-checked builds instead of reporting "no slot".
        id.checked_sub(1).and_then(|index| self.0.get(index))
    }
}
impl Default for OsStorage {
    /// Builds storage in which every slot holds a default `SignalInfo`.
    ///
    /// The slot count is 33 on most targets; on Linux and illumos it is
    /// `libc::SIGRTMAX()`.
    fn default() -> Self {
        // Fixed-size flavour: the array length is dictated by the field type.
        #[cfg(not(any(target_os = "linux", target_os = "illumos")))]
        let slots = std::array::from_fn(|_| SignalInfo::default());

        // Dynamically sized flavour: one slot for each signal up to SIGRTMAX.
        #[cfg(any(target_os = "linux", target_os = "illumos"))]
        let slots = (0..libc::SIGRTMAX() as usize)
            .map(|_| SignalInfo::default())
            .collect();

        Self(slots)
    }
}
impl Storage for OsStorage {
    /// Returns the `EventInfo` of the slot for `id`, or `None` if there is no
    /// such slot.
    fn event_info(&self, id: EventId) -> Option<&EventInfo> {
        let slot = self.get(id)?;
        Some(&slot.event_info)
    }

    /// Invokes `f` once for the `EventInfo` of every slot, in slot order.
    fn for_each<'a, F>(&'a self, mut f: F)
    where
        F: FnMut(&'a EventInfo),
    {
        for slot in self.0.iter() {
            f(&slot.event_info);
        }
    }
}
/// The two ends of the `UnixStream` pair created by the `Default` impl.
#[derive(Debug)]
pub(crate) struct OsExtraData {
// End that the signal `action` writes a byte into.
sender: UnixStream,
// Opposite end of the pair; crate-visible.
// NOTE(review): presumably read by the signal driver - confirm against its user.
pub(crate) receiver: UnixStream,
}
impl Default for OsExtraData {
    /// Creates a connected `UnixStream` pair and stores its two ends.
    ///
    /// # Panics
    ///
    /// Panics if the pair cannot be created.
    fn default() -> Self {
        let streams = UnixStream::pair().expect("failed to create UnixStream");
        // The first stream of the pair is the receiver, the second the sender.
        Self {
            receiver: streams.0,
            sender: streams.1,
        }
    }
}
/// A Unix signal identified by its raw signal number.
///
/// This is a thin wrapper around a `libc::c_int`; it is cheap to copy and can
/// be compared and hashed.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct SignalKind(libc::c_int);
impl SignalKind {
pub const fn from_raw(signum: std::os::raw::c_int) -> Self {
Self(signum as libc::c_int)
}
pub const fn as_raw_value(&self) -> std::os::raw::c_int {
self.0
}
pub const fn alarm() -> Self {
Self(libc::SIGALRM)
}
pub const fn child() -> Self {
Self(libc::SIGCHLD)
}
pub const fn hangup() -> Self {
Self(libc::SIGHUP)
}
#[cfg(any(
target_os = "dragonfly",
target_os = "freebsd",
target_os = "macos",
target_os = "netbsd",
target_os = "openbsd",
target_os = "illumos"
))]
pub const fn info() -> Self {
Self(libc::SIGINFO)
}
pub const fn interrupt() -> Self {
Self(libc::SIGINT)
}
#[cfg(target_os = "haiku")]
pub const fn io() -> Self {
Self(libc::SIGPOLL)
}
#[cfg(not(target_os = "haiku"))]
pub const fn io() -> Self {
Self(libc::SIGIO)
}
pub const fn pipe() -> Self {
Self(libc::SIGPIPE)
}
pub const fn quit() -> Self {
Self(libc::SIGQUIT)
}
pub const fn terminate() -> Self {
Self(libc::SIGTERM)
}
pub const fn user_defined1() -> Self {
Self(libc::SIGUSR1)
}
pub const fn user_defined2() -> Self {
Self(libc::SIGUSR2)
}
pub const fn window_change() -> Self {
Self(libc::SIGWINCH)
}
}
impl From<std::os::raw::c_int> for SignalKind {
fn from(signum: std::os::raw::c_int) -> Self {
Self::from_raw(signum as libc::c_int)
}
}
impl From<SignalKind> for std::os::raw::c_int {
fn from(kind: SignalKind) -> Self {
kind.as_raw_value()
}
}
/// State kept for a single signal number.
#[derive(Default)]
pub(crate) struct SignalInfo {
// Event state handed out through the `Storage` implementation.
event_info: EventInfo,
// Outcome of the one-time OS handler registration done in `signal_enable`.
// `Err` carries the raw OS error code of the failure when one was available.
init: OnceLock<Result<(), Option<i32>>>,
}
/// Handler body installed for each registered signal.
///
/// Records the signal as an event in `globals`, then writes a single byte to
/// the sender end of the stream pair.
fn action(globals: &'static Globals, signal: libc::c_int) {
    let event = signal as EventId;
    globals.record_event(event);

    // The write result is discarded on purpose; nothing is reported from here.
    // NOTE(review): presumably the byte only serves to wake whoever reads the
    // receiver end, so a failed or short write is acceptable - confirm.
    let mut wakeup = &globals.sender;
    let _ = wakeup.write(&[1]);
}
/// Makes sure an OS-level handler is installed for `signal`.
///
/// The handler is registered at most once per signal number; the outcome of
/// that first attempt is cached in the signal's slot and reported again on
/// every later call.
///
/// # Errors
///
/// Returns an error when
/// - the signal number is not positive or is listed in
///   `signal_hook_registry::FORBIDDEN`,
/// - `handle.check_inner()` fails,
/// - the signal number has no slot in the global storage, or
/// - registering the handler failed (the raw OS error is used when known).
fn signal_enable(signal: SignalKind, handle: &Handle) -> io::Result<()> {
    let signum = signal.0;

    if signum <= 0 || signal_hook_registry::FORBIDDEN.contains(&signum) {
        let message = format!("Refusing to register signal {signum}");
        return Err(Error::new(ErrorKind::Other, message));
    }

    // Verify the driver handle before touching any global state.
    handle.check_inner()?;

    let registry = globals();
    let Some(slot) = registry.storage().get(signum as EventId) else {
        return Err(io::Error::new(io::ErrorKind::Other, "signal too large"));
    };

    let outcome = slot.init.get_or_init(|| {
        // SAFETY: NOTE(review) - `register` is unsafe, presumably because the
        // handler must be async-signal-safe; `action` only calls
        // `record_event` and writes one byte to a stream. Confirm
        // `record_event` meets that requirement.
        let registered =
            unsafe { signal_hook_registry::register(signum, move || action(registry, signum)) };
        match registered {
            Ok(_) => Ok(()),
            // Only the raw OS code is kept so the result can be cached and copied.
            Err(err) => Err(err.raw_os_error()),
        }
    });

    match *outcome {
        Ok(()) => Ok(()),
        Err(Some(code)) => Err(Error::from_raw_os_error(code)),
        Err(None) => Err(Error::new(
            ErrorKind::Other,
            "registering signal handler failed",
        )),
    }
}
/// A listener for one kind of Unix signal, created by [`signal`].
///
/// Notifications are obtained through `recv` or `poll_recv`; the value does
/// nothing unless it is polled.
#[must_use = "streams do nothing unless polled"]
#[derive(Debug)]
pub struct Signal {
// Receiving side built from the listener registered for the signal.
inner: RxFuture,
}
/// Creates a [`Signal`] listener for `kind`, using the signal driver of the
/// current scheduler handle.
///
/// # Errors
///
/// Propagates any error from enabling the signal (see `signal_with_handle`).
///
/// # Panics
///
/// NOTE(review): presumably panics when no runtime context is available, since
/// it calls `scheduler::Handle::current()` - confirm.
#[track_caller]
pub fn signal(kind: SignalKind) -> io::Result<Signal> {
    let handle = scheduler::Handle::current();
    let signal_handle = handle.driver().signal();
    signal_with_handle(kind, signal_handle).map(|rx| Signal {
        inner: RxFuture::new(rx),
    })
}
/// Enables `kind` through `handle` and registers a new listener for it.
///
/// # Errors
///
/// Returns the error from `signal_enable` without registering a listener.
pub(crate) fn signal_with_handle(
    kind: SignalKind,
    handle: &Handle,
) -> io::Result<watch::Receiver<()>> {
    signal_enable(kind, handle).map(|()| {
        let event = kind.0 as EventId;
        globals().register_listener(event)
    })
}
impl Signal {
    /// Waits for the next notification from the underlying receiver.
    ///
    /// Always resolves to `Some(())`; this method never yields `None`.
    pub async fn recv(&mut self) -> Option<()> {
        let inner = &mut self.inner;
        inner.recv().await;
        Some(())
    }

    /// Polls the underlying receiver for the next notification.
    ///
    /// Returns `Poll::Ready(Some(()))` once a notification is available and
    /// `Poll::Pending` otherwise; `Poll::Ready(None)` is never produced.
    pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
        match self.inner.poll_recv(cx) {
            Poll::Ready(()) => Poll::Ready(Some(())),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Crate-internal polling interface, compiled only with the `process` feature.
// NOTE(review): presumably exists so process code can be generic over the
// signal source - confirm against its users.
#[cfg(feature = "process")]
pub(crate) trait InternalStream {
/// Polls for the next notification; mirrors the signature of `Signal::poll_recv`.
fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>>;
}
#[cfg(feature = "process")]
impl InternalStream for Signal {
    /// Forwards to the inherent `Signal::poll_recv`.
    fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
        // The path form resolves to the inherent method, not back to this
        // trait method, so there is no recursion.
        Signal::poll_recv(self, cx)
    }
}
/// Creates a listener for the interrupt signal (`SIGINT`).
///
/// Equivalent to calling [`signal`] with `SignalKind::interrupt()`.
pub(crate) fn ctrl_c() -> io::Result<Signal> {
    let kind = SignalKind::interrupt();
    signal(kind)
}
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;

    /// Tries to enable the raw signal number `raw` against a default handle
    /// and returns the kind of the error that must result.
    #[track_caller]
    fn rejection_kind(raw: libc::c_int) -> ErrorKind {
        signal_enable(SignalKind::from_raw(raw), &Handle::default())
            .unwrap_err()
            .kind()
    }

    // Non-positive signal numbers are refused.
    #[test]
    fn signal_enable_error_on_invalid_input() {
        for raw in [-1, 0] {
            assert_eq!(rejection_kind(raw), ErrorKind::Other);
        }
    }

    // Every signal listed as forbidden by the registry is refused.
    #[test]
    fn signal_enable_error_on_forbidden_input() {
        for raw in signal_hook_registry::FORBIDDEN.iter().copied() {
            assert_eq!(rejection_kind(raw), ErrorKind::Other);
        }
    }

    // Converting from a raw number yields the matching kind.
    #[test]
    fn from_c_int() {
        let converted = SignalKind::from(2);
        assert_eq!(converted, SignalKind::interrupt());
    }

    // Converting a kind back yields its raw number.
    #[test]
    fn into_c_int() {
        let kind = SignalKind::interrupt();
        let value: std::os::raw::c_int = kind.into();
        assert_eq!(value, libc::SIGINT as _);
    }
}