#[cfg(feature = "locktick")]
use locktick::parking_lot::{Mutex, RwLock};
#[cfg(not(feature = "locktick"))]
use parking_lot::{Mutex, RwLock};
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering},
};
use tokio::sync::oneshot;
use tracing::{debug, error, trace};
/// A component that can be told to stop and later queried for whether it has stopped.
///
/// The `Send + Sync` supertraits allow implementors to be shared across threads
/// (for example behind an `Arc`, as the constructors in this module return).
pub trait Stoppable: Send + Sync {
    /// Signals that this instance should stop.
    fn stop(&self);
    /// Returns `true` once this instance has been stopped.
    fn is_stopped(&self) -> bool;
}
/// The simplest [`Stoppable`]: a single atomic flag that flips from `false` to `true`.
pub struct SimpleStoppable {
    /// Becomes `true` on the first call to [`Stoppable::stop`] and never goes back.
    state: AtomicBool,
}

impl SimpleStoppable {
    /// Creates a fresh, not-yet-stopped instance, wrapped in an [`Arc`] so it can be shared.
    pub fn new() -> Arc<Self> {
        let state = AtomicBool::from(false);
        Arc::new(Self { state })
    }
}

impl Stoppable for SimpleStoppable {
    /// Marks this instance as stopped. Calling it again is harmless.
    fn stop(&self) {
        // Sequentially-consistent so the flag is totally ordered with other SeqCst accesses.
        let stopped = true;
        self.state.store(stopped, Ordering::SeqCst);
    }

    /// Reports whether [`Stoppable::stop`] has been called on this instance.
    fn is_stopped(&self) -> bool {
        let stopped = self.state.load(Ordering::SeqCst);
        stopped
    }
}
/// A [`Stoppable`] that is stopped either explicitly or by a background task reacting to
/// process signals, and that lets one caller await that moment via `wait_for_signals`.
pub struct SignalHandler {
    /// Sending half of the one-shot stop notification. It is taken (left as `None`) by the
    /// first `stop`, so its absence is also what `is_stopped` reports.
    stopped_sender: RwLock<Option<oneshot::Sender<()>>>,
    /// Receiving half of the stop notification; taken by the single allowed
    /// `wait_for_signals` call.
    stopped_receiver: Mutex<Option<oneshot::Receiver<()>>>,
}
impl SignalHandler {
pub fn new() -> Arc<Self> {
let (stopped_sender, stopped_receiver) = oneshot::channel();
let obj = Arc::new(Self {
stopped_sender: RwLock::new(Some(stopped_sender)),
stopped_receiver: Mutex::new(Some(stopped_receiver)),
});
{
let obj = obj.clone();
tokio::spawn(async move {
obj.handle_signals().await;
});
}
obj
}
async fn handle_signals(&self) {
#[cfg(target_family = "unix")]
let signal_listener = async move {
use tokio::signal::unix::{SignalKind, signal};
let mut s_int = signal(SignalKind::interrupt())?;
let mut s_term = signal(SignalKind::terminate())?;
let mut s_quit = signal(SignalKind::quit())?;
let mut s_hup = signal(SignalKind::hangup())?;
tokio::select!(
_ = s_int.recv() => trace!("Received SIGINT"),
_ = s_term.recv() => trace!("Received SIGTERM"),
_ = s_quit.recv() => trace!("Received SIGQUIT"),
_ = s_hup.recv() => trace!("Received SIGHUP"),
);
std::io::Result::<()>::Ok(())
};
#[cfg(not(target_family = "unix"))]
let signal_listener = async move {
tokio::signal::ctrl_c().await?;
std::io::Result::<()>::Ok(())
};
match signal_listener.await {
Ok(()) => debug!("Received signal, shutting down..."),
Err(error) => error!("tokio::signal encountered an error: {error}"),
}
self.stop();
}
pub async fn wait_for_signals(&self) {
let Some(receiver) = self.stopped_receiver.lock().take() else {
panic!("wait_for_signals must be called at most once");
};
if let Err(err) = receiver.await {
error!("wait_for_signals encountered an error: {err}");
}
}
}
impl Stoppable for SignalHandler {
    /// Stops the handler and wakes the `wait_for_signals` caller, if any.
    ///
    /// Only the first call has an effect; later calls find the sender already taken.
    fn stop(&self) {
        let Some(sender) = self.stopped_sender.write().take() else {
            return;
        };
        // A send error only means the receiver is gone, so nobody needs waking.
        let _ = sender.send(());
    }

    /// Reports whether `stop` has already consumed the one-shot sender.
    fn is_stopped(&self) -> bool {
        let sender = self.stopped_sender.read();
        sender.is_none()
    }
}