use std::{
pin::Pin,
sync::{
Arc, Mutex,
atomic::{AtomicBool, Ordering},
},
task::{Context, Poll, Waker},
};
use crate::{Error, set_handler};
/// A future that completes once the Ctrl-C handler installed by
/// [`AsyncCtrlC::new`] has run and the user callback reported the signal as
/// handled.
///
/// Both fields are shared (via `Arc`) with the handler closure created in
/// [`AsyncCtrlC::new`].
#[derive(Debug)]
pub struct AsyncCtrlC {
    /// Waker of the task that polled this future most recently; the handler
    /// takes it out and wakes it when the signal is handled.
    waker: Arc<Mutex<Option<Waker>>>,
    /// Set to `true` by the handler once the user callback returns `true`;
    /// `poll` returns `Ready` as soon as it observes this.
    active: Arc<AtomicBool>,
}
impl Future for AsyncCtrlC {
    type Output = std::io::Result<()>;

    /// Resolves to `Ok(())` once the `active` flag has been set.
    ///
    /// While the flag is unset, the current task's waker is stored in the
    /// shared slot and `Poll::Pending` is returned.
    ///
    /// # Errors
    ///
    /// Resolves to an `std::io::Error` if the waker mutex is poisoned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = &*self;

        // Fast path: the flag was already set before this poll.
        if this.active.load(Ordering::SeqCst) {
            log::trace!("AsyncCtrlC: signal already activated, returning ready");
            return Poll::Ready(Ok(()));
        }

        // Publish the waker. The guard lives only for the duration of the
        // matching arm, so the lock is released before the re-check below.
        match this.waker.lock() {
            Ok(mut slot) => *slot = Some(cx.waker().clone()),
            Err(e) => {
                return Poll::Ready(Err(std::io::Error::other(format!("acquire lock: {e}"))));
            }
        }

        // Re-check: the flag may have been set between the first load and the
        // waker being stored, in which case nobody saw the waker we just
        // published and we must not go to sleep.
        if this.active.load(Ordering::SeqCst) {
            log::trace!("AsyncCtrlC: signal activated while setting waker, returning ready");
            return Poll::Ready(Ok(()));
        }

        log::trace!("AsyncCtrlC: signal not activated, returning pending");
        Poll::Pending
    }
}
/// Process-wide guard enforcing a single [`AsyncCtrlC`] at a time.
///
/// Claimed (set to `true`) by [`AsyncCtrlC::new`] and released (set to
/// `false`) when the instance is dropped or when construction fails.
static INSTANCE_CREATED: AtomicBool = AtomicBool::new(false);

impl AsyncCtrlC {
    /// Installs `user_handler` through `set_handler` and returns a future that
    /// resolves after the handler has run and returned `true`.
    ///
    /// `user_handler` is invoked every time the installed handler fires. Its
    /// return value is forwarded to `set_handler`'s caller unchanged; only a
    /// `true` result marks the future as ready and wakes the waiting task.
    ///
    /// # Errors
    ///
    /// * [`Error::MultipleHandlers`] (converted into `std::io::Error`) if
    ///   another `AsyncCtrlC` currently holds the single-instance guard.
    /// * Any error returned by `set_handler`; the guard is released again in
    ///   that case so a later call can retry.
    pub fn new<F>(mut user_handler: F) -> std::io::Result<Self>
    where
        F: FnMut() -> bool + 'static + Send,
    {
        // Claim the guard in one atomic step. A separate load followed by a
        // store would let two threads both observe `false` and both proceed.
        if INSTANCE_CREATED.swap(true, Ordering::SeqCst) {
            return Err(Error::MultipleHandlers.into());
        }

        let waker: Arc<Mutex<Option<Waker>>> = Arc::new(Mutex::new(None));
        let active = Arc::new(AtomicBool::new(false));
        let waker_clone = Arc::clone(&waker);
        let active_clone = Arc::clone(&active);

        let registration = set_handler(move || {
            let handled = user_handler();
            if handled {
                log::trace!("AsyncCtrlC: user handler returned true, waking up waker");
                // Set the flag before touching the waker so that a concurrent
                // `poll` which stores its waker too late still sees the flag
                // on its re-check.
                active_clone.store(true, Ordering::SeqCst);

                // Take the waker out while holding the lock, but call `wake`
                // only after the guard is gone: a waker that polls inline
                // would otherwise deadlock on this mutex, and a panicking one
                // would poison it.
                let pending = waker_clone.lock().ok().and_then(|mut slot| slot.take());
                match pending {
                    Some(waker) => {
                        waker.wake();
                        log::trace!("AsyncCtrlC: waker has been woken up");
                    }
                    None => {
                        log::debug!("AsyncCtrlC: waker was not set, cannot wake up");
                    }
                }
            }
            handled
        });

        if let Err(err) = registration {
            // No instance will exist to run `Drop`, so release the guard here;
            // otherwise every later call would fail with `MultipleHandlers`.
            INSTANCE_CREATED.store(false, Ordering::SeqCst);
            return Err(err.into());
        }

        Ok(AsyncCtrlC { waker, active })
    }
}
/// Releases the single-instance guard when the future goes away.
impl Drop for AsyncCtrlC {
    /// Clears `INSTANCE_CREATED` so that `AsyncCtrlC::new` no longer rejects
    /// the call with `Error::MultipleHandlers`.
    fn drop(&mut self) {
        // NOTE(review): only the flag is reset here; whether the handler
        // registered through `set_handler` can be installed a second time is
        // not visible from this file — confirm.
        INSTANCE_CREATED.store(false, Ordering::SeqCst);
    }
}
#[cfg(test)]
mod tests {
    /// End-to-end check: installs the handler, sends an interrupt to the
    /// current process from a background task, and expects both the
    /// `AsyncCtrlC` future and a worker waiting on a cancellation token to
    /// finish.
    #[tokio::test]
    async fn test_async_ctrlc() {
        // Skipped on Windows when the `CI` environment variable is set.
        let on_windows_ci = cfg!(windows) && std::env::var("CI").is_ok();
        if on_windows_ci {
            println!("Skipping test_async_ctrlc in CI environment on Windows");
            return;
        }

        let token = tokio_util::sync::CancellationToken::new();
        let handler_token = token.clone();

        // The user callback cancels the token and reports the signal as handled.
        let ctrlc = crate::AsyncCtrlC::new(move || {
            println!("Ctrl+C received, cancelling...");
            handler_token.cancel();
            true
        })
        .unwrap();

        // Background task that raises the interrupt after a short delay, so
        // the future below is already being awaited when it arrives.
        let sender = tokio::spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

            #[cfg(unix)]
            nix::sys::signal::kill(nix::unistd::Pid::this(), nix::sys::signal::Signal::SIGINT).unwrap();

            #[cfg(windows)]
            {
                use windows_sys::Win32::System::Console::{CTRL_C_EVENT, GenerateConsoleCtrlEvent};
                // SAFETY: FFI call that takes two plain integer arguments; no
                // pointers or borrowed data cross the boundary.
                unsafe { GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0) };
            }

            println!("Ctrl+C signal sent");
        });

        // Worker that only finishes once the callback has cancelled the token.
        let worker = tokio::spawn(async move {
            println!("[Main worker] started, till the cancellation token is received...");
            token.cancelled().await;
            println!("[Main worker] cancelled, exiting...");
        });

        ctrlc.await.unwrap();
        sender.await.unwrap();
        worker.await.unwrap();
        println!("Test completed successfully.");
    }
}