use tokio::{process::Child, task::JoinSet};
use tokio_util::sync::CancellationToken;
/// Owns a [`CancelTaskSet`] and cancels its shared token when dropped.
pub struct InProcess {
    /// The task set whose cancellation token is cancelled when this wrapper is dropped.
    pub task_set: CancelTaskSet,
}

impl InProcess {
    /// Wraps `task_set`; its token will be cancelled when the returned value is dropped.
    pub fn new(task_set: CancelTaskSet) -> Self {
        InProcess { task_set }
    }
}

impl Drop for InProcess {
    fn drop(&mut self) {
        // Runs before the `task_set` field itself is dropped, so cancellation is signalled
        // explicitly here first; the field's own `Drop` then cancels again and aborts tasks.
        let token = self.task_set.cancellation_token();
        token.cancel();
    }
}
/// Owns a spawned child process and requests its termination when dropped.
pub struct Subprocess {
    /// Handle to the spawned child process.
    pub child: Child,
}

impl Subprocess {
    /// Takes ownership of `child`; a kill request is sent when the returned value is dropped.
    pub fn new(child: Child) -> Self {
        Subprocess { child }
    }
}

impl Drop for Subprocess {
    fn drop(&mut self) {
        let child = &mut self.child;
        // Best effort: errors cannot be reported from `drop`, so a failed kill request is ignored.
        _ = child.start_kill();
        // NOTE(review): `try_wait` does not block, so a child that has not exited yet by this
        // point is not reaped here; any error is likewise ignored.
        _ = child.try_wait();
    }
}
/// A [`JoinSet`] of I/O tasks tied to a shared [`CancellationToken`].
///
/// Dropping the set cancels the token and aborts any tasks still in the set.
#[derive(Debug)]
pub struct CancelTaskSet {
    /// Tasks owned by this set; each yields `Ok(())` on success or an `io::Error` on failure.
    pub join_set: JoinSet<Result<(), std::io::Error>>,
    /// Token shared with tasks spawned via `spawn_cancellable_task` and the shutdown handler.
    cancellation_token: CancellationToken,
}
impl CancelTaskSet {
    /// Creates an empty task set with a fresh cancellation token and no signal handler.
    // NOTE(review): `Default` is not implemented for this type, which is why this lint is
    // silenced; confirm whether omitting `Default` is deliberate.
    #[allow(clippy::new_without_default)]
    pub fn new() -> Self {
        Self::from_cancel_token(false, CancellationToken::new())
    }

    /// Creates a task set with a fresh cancellation token and a shutdown-handler task that
    /// cancels the token on SIGINT/SIGTERM (Unix) or CTRL-C/CTRL-BREAK (Windows).
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, because the handler task is spawned
    /// immediately.
    pub fn new_with_signal_handler() -> Self {
        Self::from_cancel_token(true, CancellationToken::new())
    }

    /// Creates a task set that uses `cancellation_token` as its cancellation signal.
    ///
    /// The caller may keep its own clone of `cancellation_token` to observe or trigger
    /// cancellation. When `register_signal_handler` is `true`, a shutdown-handler task is
    /// spawned into [`join_set`](Self::join_set) (see
    /// [`new_with_signal_handler`](Self::new_with_signal_handler)).
    ///
    /// # Panics
    ///
    /// Panics if `register_signal_handler` is `true` and this is called outside a Tokio
    /// runtime.
    pub fn from_cancel_token(
        register_signal_handler: bool,
        cancellation_token: CancellationToken,
    ) -> Self {
        let mut join_set = JoinSet::new();
        if register_signal_handler {
            Self::spawn_shutdown_handler(&mut join_set, cancellation_token.clone());
        }
        Self {
            join_set,
            cancellation_token,
        }
    }

    /// Returns a clone of the token shared with every task spawned through
    /// [`spawn_cancellable_task`](Self::spawn_cancellable_task).
    pub fn cancellation_token(&self) -> CancellationToken {
        self.cancellation_token.clone()
    }

    /// Spawns into `join_set` a task that cancels `cancellation_token` when a termination
    /// signal arrives, or returns quietly once the token has been cancelled elsewhere.
    ///
    /// If a signal listener cannot be registered, the task returns that `io::Error`.
    /// [`join_all`](Self::join_all) then logs it and cancels the token, rather than the task
    /// panicking. On targets that are neither Unix nor Windows the task completes immediately.
    fn spawn_shutdown_handler(
        join_set: &mut JoinSet<Result<(), std::io::Error>>,
        cancellation_token: CancellationToken,
    ) {
        join_set.spawn(async move {
            #[cfg(target_family = "unix")]
            {
                use tokio::signal::unix::{SignalKind, signal};
                let mut sigint = signal(SignalKind::interrupt())
                    .map_err(|e| Self::registration_error("SIGINT", e))?;
                let mut sigterm = signal(SignalKind::terminate())
                    .map_err(|e| Self::registration_error("SIGTERM", e))?;
                tokio::select! {
                    _ = sigint.recv() => {
                        tracing::debug!("Received SIGINT, cancelling token");
                        cancellation_token.cancel();
                    },
                    _ = sigterm.recv() => {
                        tracing::debug!("Received SIGTERM, cancelling token");
                        cancellation_token.cancel();
                    },
                    _ = cancellation_token.cancelled() => {
                        tracing::debug!("Cancellation token cancelled, exiting shutdown handler");
                    },
                }
            }
            #[cfg(target_family = "windows")]
            {
                use tokio::signal::windows;
                let mut ctrl_c =
                    windows::ctrl_c().map_err(|e| Self::registration_error("CTRL-C", e))?;
                let mut ctrl_break = windows::ctrl_break()
                    .map_err(|e| Self::registration_error("CTRL-BREAK", e))?;
                tokio::select! {
                    _ = ctrl_c.recv() => {
                        tracing::debug!("Received CTRL-C, cancelling token");
                        cancellation_token.cancel();
                    },
                    _ = ctrl_break.recv() => {
                        tracing::debug!("Received CTRL-BREAK, cancelling token");
                        cancellation_token.cancel();
                    },
                    _ = cancellation_token.cancelled() => {
                        tracing::debug!("Cancellation token cancelled, exiting shutdown handler");
                    },
                }
            }
            // Annotated so the `?` operators above have a concrete error type to convert into.
            Ok::<(), std::io::Error>(())
        });
    }

    /// Wraps a signal-registration failure with the name of the listener that failed,
    /// preserving the original error kind.
    #[cfg(any(target_family = "unix", target_family = "windows"))]
    fn registration_error(what: &str, err: std::io::Error) -> std::io::Error {
        std::io::Error::new(err.kind(), format!("failed to register {what} handler: {err}"))
    }

    /// Spawns `task` into the set so that it stops when the shared token is cancelled.
    ///
    /// If the token is cancelled before `task` finishes, `task` is dropped and the spawned
    /// task reports `Ok(())`. Otherwise the spawned task reports `task`'s own result.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime.
    pub fn spawn_cancellable_task<Fut>(&mut self, task: Fut)
    where
        Fut: Future<Output = Result<(), std::io::Error>> + Send + 'static,
    {
        let token = self.cancellation_token();
        self.join_set.spawn(async move {
            // `None` means cancellation won the race; treat that as a clean stop.
            token.run_until_cancelled(task).await.unwrap_or(Ok(()))
        });
    }

    /// Waits until every task in [`join_set`](Self::join_set) has finished.
    ///
    /// Each task that returns an error, or fails to join (e.g. it panicked), is logged and
    /// causes the shared token to be cancelled so cooperating tasks wind down. Waiting
    /// continues until the set is empty.
    ///
    /// If a shutdown handler was registered, it stays in the set until a signal arrives or
    /// the token is cancelled. This call therefore does not return before one of those
    /// happens, even if every other task has already completed.
    pub async fn join_all(&mut self) {
        while let Some(result) = self.join_set.join_next().await {
            match result {
                Ok(Ok(())) => {}
                Ok(Err(e)) => {
                    tracing::error!(error=%e, "Task failed");
                    self.cancellation_token.cancel();
                }
                Err(e) => {
                    tracing::error!(error=%e, "Task join failed");
                    self.cancellation_token.cancel();
                }
            }
        }
    }
}
impl Drop for CancelTaskSet {
    fn drop(&mut self) {
        let Self {
            join_set,
            cancellation_token,
        } = self;
        // Signal cooperative shutdown first, so anything holding a clone of the token sees it...
        cancellation_token.cancel();
        // ...then forcibly abort whatever is still running in the set.
        join_set.abort_all();
    }
}