#![allow(rustdoc::private_intra_doc_links)]
use super::TaskId;
use crate::{
future::IntoFutureWithArgs,
sync::once::{OnceTrigger, once_event},
};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use std::{
marker::PhantomData,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
};
use tokio::{
sync::watch::{Receiver, channel},
task::{JoinError, JoinHandle},
};
#[cfg(feature = "signal")]
use tokio::signal::ctrl_c;
/// Identifies what requested the graceful shutdown of a task.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum GracefulKind {
/// The Ctrl-C future raced by the supervising task completed first.
CtrlC,
/// The once-event behind a [`ShutdownTrigger`] fired first.
Explicit,
}
/// Describes how the supervising task saw the wrapped task come to an end.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum FinishMode {
/// The wrapped task finished before any shutdown request was observed.
Complete,
/// A shutdown request of the given kind was observed before the wrapped task finished.
Shutdown(GracefulKind),
}
/// Result of awaiting a [`GracefulTask`].
#[derive(Debug)]
pub struct TaskOutput<T> {
/// Whether the wrapped task ran to completion on its own or after a shutdown request.
pub finish_mode: FinishMode,
/// Outcome of joining the wrapped task: its return value, or the tokio
/// [`JoinError`] produced by awaiting its join handle.
pub join_result: Result<T, JoinError>,
}
/// Cloneable handle used to request a graceful shutdown.
///
/// Every clone shares the same one-shot [`OnceTrigger`] slot, so the request can
/// be fired at most once no matter how many handles exist.
#[derive(Debug, Clone)]
pub struct ShutdownTrigger(Arc<Mutex<Option<OnceTrigger>>>);

impl ShutdownTrigger {
    /// Attempts to fire the shutdown request.
    ///
    /// Returns the result of [`OnceTrigger::trigger`] when this call is the one
    /// that takes the trigger out of the shared slot. Returns `false` when the
    /// slot is already empty, or when the lock cannot be acquired right now
    /// (`try_lock` is used, so this method never blocks; a contended or poisoned
    /// lock is reported as `false`).
    pub fn trigger(&self) -> bool {
        let Ok(mut slot) = self.0.try_lock() else {
            return false;
        };
        // The guard stays alive while the one-shot trigger fires, exactly as long
        // as it takes to consume the slot's content.
        slot.take().is_some_and(|once| once.trigger())
    }
}
/// Receiving side of the shutdown notification; it is passed to the wrapped
/// task's future when the task is spawned.
#[derive(Debug, Clone)]
pub struct ShutdownReceiver(RecvInner);
/// Internal state of a [`ShutdownReceiver`].
#[derive(Debug, Clone)]
enum RecvInner {
/// No shutdown kind has been observed yet; the watch channel starts out holding `None`.
Pending(Receiver<Option<GracefulKind>>),
/// A shutdown kind was observed and is cached so later calls can return it directly.
Shutdown(GracefulKind),
}
impl ShutdownReceiver {
    /// Waits until a graceful shutdown has been requested and returns what
    /// requested it.
    ///
    /// The first observed kind is cached in the receiver, so every later call
    /// returns it immediately without consulting the channel again.
    ///
    /// If the sending half of the channel is dropped without ever publishing a
    /// kind, no request can arrive any more. In that case the returned future
    /// never resolves (instead of panicking on the missing value), which keeps
    /// it safe to race inside `select!` from a receiver clone that outlives the
    /// supervising task.
    pub async fn recv(&mut self) -> GracefulKind {
        let receiver = match &mut self.0 {
            RecvInner::Shutdown(kind) => return *kind,
            RecvInner::Pending(receiver) => receiver,
        };

        // Copy the value out in its own statement so the channel's borrow guard
        // is released before any `.await`.
        let mut observed = *receiver.borrow_and_update();
        if observed.is_none() {
            // An error only means the sender is gone; re-read the value either
            // way so a kind published right before the drop is still seen.
            let _ = receiver.changed().await;
            observed = *receiver.borrow_and_update();
        }

        let Some(kind) = observed else {
            // Sender dropped and nothing was ever published: a shutdown request
            // can no longer happen.
            return std::future::pending().await;
        };
        self.0 = RecvInner::Shutdown(kind);
        kind
    }
}
/// Lets a [`ShutdownReceiver`] be awaited directly (`shutdown.await`); the
/// resulting boxed future resolves to the value of [`ShutdownReceiver::recv`].
impl IntoFuture for ShutdownReceiver {
    type Output = GracefulKind;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send>>;

    fn into_future(self) -> Self::IntoFuture {
        // The receiver is moved into the future so the future owns all its state.
        let mut receiver = self;
        Box::pin(async move { receiver.recv().await })
    }
}
/// Builder for a [`GracefulTask`] whose future produces a `T`.
///
/// `Debug`, `Clone`, `Copy`, `PartialEq`, `Eq` and `Hash` are implemented by
/// hand rather than derived: a derive would require `T` itself to implement
/// each trait, although the builder only stores a `PhantomData<T>` and never a
/// `T`. The hand-written impls hold for every `T` and otherwise behave like the
/// derived ones.
pub struct GracefulTaskBuilder<T> {
    /// Whether the supervising task also races the Ctrl-C future.
    ctrlc_shutdown: bool,
    /// Ties the builder to the task's output type without storing a value of it.
    _phantom: PhantomData<T>,
}

impl<T> std::fmt::Debug for GracefulTaskBuilder<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GracefulTaskBuilder")
            .field("ctrlc_shutdown", &self.ctrlc_shutdown)
            .field("_phantom", &self._phantom)
            .finish()
    }
}

impl<T> Clone for GracefulTaskBuilder<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for GracefulTaskBuilder<T> {}

impl<T> PartialEq for GracefulTaskBuilder<T> {
    fn eq(&self, other: &Self) -> bool {
        // `PhantomData` values always compare equal, so only the flag matters.
        self.ctrlc_shutdown == other.ctrlc_shutdown
    }
}

impl<T> Eq for GracefulTaskBuilder<T> {}

impl<T> std::hash::Hash for GracefulTaskBuilder<T> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // `PhantomData` feeds nothing into the hasher, so only the flag is hashed.
        std::hash::Hash::hash(&self.ctrlc_shutdown, state);
    }
}
/// The default builder has Ctrl-C shutdown disabled.
impl<T> Default for GracefulTaskBuilder<T> {
    fn default() -> Self {
        let ctrlc_shutdown = false;
        let _phantom = PhantomData;
        Self {
            ctrlc_shutdown,
            _phantom,
        }
    }
}
impl<T> GracefulTaskBuilder<T> {
#[cfg(feature = "signal")]
pub fn ctrlc_shutdown(self) -> Self {
Self {
ctrlc_shutdown: true,
..self
}
}
pub fn spawn<I, F>(self, ifwa: I) -> GracefulTask<T>
where
I: IntoFutureWithArgs<ShutdownReceiver, F>,
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
self.spawn_ctrlc_mocked(ifwa, async move {
#[cfg(feature = "signal")]
ctrl_c().await.ok();
})
}
fn spawn_ctrlc_mocked<I, F, C>(self, ifwa: I, ctrlc: C) -> GracefulTask<T>
where
I: IntoFutureWithArgs<ShutdownReceiver, F>,
F: Future<Output = T> + Send + 'static,
C: Future<Output = ()> + Send + 'static,
T: Send + 'static,
{
let ctrlc_shutdown = self.ctrlc_shutdown;
let ctrlc = if ctrlc_shutdown { Some(ctrlc) } else { None };
let (sender, recver) = channel(None);
let (trigger, waiter) = once_event();
let trigger = ShutdownTrigger(Arc::new(Mutex::new(Some(trigger))));
let mut inner_task =
tokio::spawn(ifwa.into_future_with_args(ShutdownReceiver(RecvInner::Pending(recver))));
let inner = inner_task.id().into();
let graceful = trigger.clone();
let task = tokio::spawn(async move {
let ctrlc = async move { ctrlc.unwrap().await };
let (finish_mode, join_result) = tokio::select! {
_ = ctrlc, if ctrlc_shutdown => {
trigger.trigger();
let kind = GracefulKind::CtrlC;
sender.send(Some(kind)).ok();
(FinishMode::Shutdown(kind), inner_task.await)
},
_ = waiter => {
let kind = GracefulKind::Explicit;
sender.send(Some(kind)).ok();
(FinishMode::Shutdown(kind), inner_task.await)
},
join_result = &mut inner_task => (FinishMode::Complete, join_result),
};
TaskOutput {
finish_mode,
join_result,
}
});
let outer = task.id().into();
GracefulTask {
inner,
outer,
graceful,
task,
}
}
}
/// Handle to a spawned task that can be asked to shut down gracefully.
///
/// Awaiting the handle yields the [`TaskOutput`] produced by the supervising task.
#[derive(Debug)]
pub struct GracefulTask<T> {
/// Id of the task running the user-provided future.
inner: TaskId,
/// Id of the supervising task that races shutdown requests against the inner task.
outer: TaskId,
/// Trigger shared with the supervising task; firing it requests an explicit shutdown.
graceful: ShutdownTrigger,
/// Join handle of the supervising task.
task: JoinHandle<TaskOutput<T>>,
}
impl<T> GracefulTask<T> {
pub fn builder_default() -> GracefulTaskBuilder<T> {
GracefulTaskBuilder::default()
}
pub fn ids(&self) -> (TaskId, TaskId) {
(self.outer, self.inner)
}
pub fn trigger_graceful_shutdown(&self) -> bool {
self.graceful.trigger()
}
pub async fn graceful_shutdown(self) -> TaskOutput<T> {
self.trigger_graceful_shutdown();
self.await
}
pub fn shutdown_handle(&self) -> ShutdownTrigger {
self.graceful.clone()
}
pub fn is_finished(&self) -> bool {
self.task.is_finished()
}
}
/// Awaiting a [`GracefulTask`] waits for the supervising task.
///
/// # Panics
///
/// Panics if joining the supervising task itself yields an error.
impl<T> Future for GracefulTask<T> {
    type Output = TaskOutput<T>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        match Pin::new(&mut this.task).poll(cx) {
            Poll::Ready(joined) => Poll::Ready(joined.unwrap()),
            Poll::Pending => Poll::Pending,
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
/// Builds a stand-in for the Ctrl-C signal: a trigger plus a future that
/// completes once the paired once-event waiter does.
fn ctrlc_mocked() -> (OnceTrigger, impl Future<Output = ()> + Send + 'static) {
    let (fire, fired) = once_event();
    let fake_ctrlc = async move {
        fired.await;
    };
    (fire, fake_ctrlc)
}
/// Pauses the calling test for 100 ms.
async fn sleep() {
    let pause = Duration::from_millis(100);
    tokio::time::sleep(pause).await;
}
/// Pauses for 200 ms, i.e. twice as long as `sleep`.
async fn sleep_double() {
    let pause = Duration::from_millis(200);
    tokio::time::sleep(pause).await;
}
#[tokio::test(flavor = "multi_thread")]
async fn graceful_shutdown() {
let task_output = GracefulTask::builder_default()
.spawn(async |shutdown| shutdown.await)
.graceful_shutdown()
.await;
assert_eq!(
task_output.finish_mode,
FinishMode::Shutdown(GracefulKind::Explicit)
);
assert_eq!(task_output.join_result.unwrap(), GracefulKind::Explicit);
}
// Exercises `GracefulTask::trigger_graceful_shutdown` in a series of
// independent scenarios; each scenario spawns its own task.
#[tokio::test(flavor = "multi_thread")]
async fn trigger_graceful_shutdown() {
// Scenario 1: a single explicit trigger shuts down a task that waits for shutdown.
let mut graceful_task =
GracefulTask::builder_default().spawn(async |shutdown| shutdown.await);
assert!(!graceful_task.is_finished());
assert!(graceful_task.trigger_graceful_shutdown());
let task_output = (&mut graceful_task).await;
assert!(graceful_task.is_finished());
assert_eq!(
task_output.finish_mode,
FinishMode::Shutdown(GracefulKind::Explicit)
);
assert_eq!(task_output.join_result.unwrap(), GracefulKind::Explicit);
// Scenario 2: a second trigger on the same task reports `false`.
let mut graceful_task =
GracefulTask::builder_default().spawn(async |shutdown| shutdown.await);
assert!(!graceful_task.is_finished());
assert!(graceful_task.trigger_graceful_shutdown());
assert!(!graceful_task.trigger_graceful_shutdown());
let task_output = (&mut graceful_task).await;
assert!(graceful_task.is_finished());
assert_eq!(
task_output.finish_mode,
FinishMode::Shutdown(GracefulKind::Explicit)
);
assert_eq!(task_output.join_result.unwrap(), GracefulKind::Explicit);
// Scenarios 3 and 4 need `ctrlc_shutdown()`, which only exists with the `signal` feature.
#[cfg(feature = "signal")]
{
// Scenario 3: Ctrl-C shutdown enabled, but the explicit trigger fires first;
// firing the mocked Ctrl-C afterwards reports `false` and the kind stays `Explicit`.
let (trigger, ctrlc) = ctrlc_mocked();
let mut graceful_task = GracefulTask::builder_default()
.ctrlc_shutdown()
.spawn_ctrlc_mocked(async |shutdown| shutdown.await, ctrlc);
assert!(!graceful_task.is_finished());
assert!(graceful_task.trigger_graceful_shutdown());
sleep().await;
assert!(!trigger.trigger());
let task_output = (&mut graceful_task).await;
assert!(graceful_task.is_finished());
assert_eq!(
task_output.finish_mode,
FinishMode::Shutdown(GracefulKind::Explicit)
);
assert_eq!(task_output.join_result.unwrap(), GracefulKind::Explicit);
// Scenario 4: Ctrl-C shutdown enabled and the mocked Ctrl-C fires first;
// a later explicit trigger reports `false` and the kind is `CtrlC`.
let (trigger, ctrlc) = ctrlc_mocked();
let mut graceful_task = GracefulTask::builder_default()
.ctrlc_shutdown()
.spawn_ctrlc_mocked(async |shutdown| shutdown.await, ctrlc);
assert!(!graceful_task.is_finished());
assert!(trigger.trigger());
sleep().await;
assert!(!graceful_task.trigger_graceful_shutdown());
let task_output = (&mut graceful_task).await;
assert!(graceful_task.is_finished());
assert_eq!(
task_output.finish_mode,
FinishMode::Shutdown(GracefulKind::CtrlC)
);
assert_eq!(task_output.join_result.unwrap(), GracefulKind::CtrlC);
}
// Scenario 5: Ctrl-C shutdown not enabled; firing the mocked Ctrl-C reports
// `false` and the task still reacts to the explicit trigger.
let (trigger, ctrlc) = ctrlc_mocked();
let mut graceful_task = GracefulTask::builder_default()
.spawn_ctrlc_mocked(async |shutdown| shutdown.await, ctrlc);
assert!(!graceful_task.is_finished());
assert!(!trigger.trigger());
sleep().await;
assert!(graceful_task.trigger_graceful_shutdown());
let task_output = (&mut graceful_task).await;
assert!(graceful_task.is_finished());
assert_eq!(
task_output.finish_mode,
FinishMode::Shutdown(GracefulKind::Explicit)
);
assert_eq!(task_output.join_result.unwrap(), GracefulKind::Explicit);
// Scenario 6: the task returns immediately; triggering after the pause
// reports `false` and the finish mode is `Complete`.
let mut graceful_task = GracefulTask::builder_default().spawn(async |_| 42);
sleep().await;
assert!(!graceful_task.trigger_graceful_shutdown());
let task_output = (&mut graceful_task).await;
assert!(graceful_task.is_finished());
assert_eq!(task_output.finish_mode, FinishMode::Complete);
assert_eq!(task_output.join_result.unwrap(), 42);
// Scenario 7: the task ignores the receiver and sleeps 200 ms; a trigger after
// 100 ms succeeds, and the finish mode is `Shutdown(Explicit)` with the task's
// own `()` result.
let mut graceful_task =
GracefulTask::builder_default().spawn(async |_| sleep_double().await);
assert!(!graceful_task.is_finished());
sleep().await;
assert!(graceful_task.trigger_graceful_shutdown());
let task_output = (&mut graceful_task).await;
assert!(graceful_task.is_finished());
assert_eq!(
task_output.finish_mode,
FinishMode::Shutdown(GracefulKind::Explicit)
);
assert_eq!(task_output.join_result.unwrap(), ());
}
// Exercises the `ShutdownTrigger` obtained from `GracefulTask::shutdown_handle`
// in a series of independent scenarios; each scenario spawns its own task.
#[tokio::test(flavor = "multi_thread")]
async fn shutdown_trigger() {
// Scenario 1: a single trigger through a handle shuts down a task that waits for shutdown.
let mut graceful_task =
GracefulTask::builder_default().spawn(async |shutdown| shutdown.await);
assert!(!graceful_task.is_finished());
assert!(graceful_task.shutdown_handle().trigger());
let task_output = (&mut graceful_task).await;
assert!(graceful_task.is_finished());
assert_eq!(
task_output.finish_mode,
FinishMode::Shutdown(GracefulKind::Explicit)
);
assert_eq!(task_output.join_result.unwrap(), GracefulKind::Explicit);
// Scenario 2: two handles are fired from two separate tasks; exactly one of
// the calls must report `true`.
let mut graceful_task =
GracefulTask::builder_default().spawn(async |shutdown| shutdown.await);
assert!(!graceful_task.is_finished());
let trigger1 = graceful_task.shutdown_handle();
let trigger2 = graceful_task.shutdown_handle();
let trigger1 = tokio::spawn(async move { trigger1.trigger() });
let trigger2 = tokio::spawn(async move { trigger2.trigger() });
assert!(trigger1.await.unwrap() ^ trigger2.await.unwrap());
let task_output = (&mut graceful_task).await;
assert!(graceful_task.is_finished());
assert_eq!(
task_output.finish_mode,
FinishMode::Shutdown(GracefulKind::Explicit)
);
assert_eq!(task_output.join_result.unwrap(), GracefulKind::Explicit);
// Scenario 3: firing a second, freshly obtained handle reports `false`.
let mut graceful_task =
GracefulTask::builder_default().spawn(async |shutdown| shutdown.await);
assert!(!graceful_task.is_finished());
assert!(graceful_task.shutdown_handle().trigger());
assert!(!graceful_task.shutdown_handle().trigger());
let task_output = (&mut graceful_task).await;
assert!(graceful_task.is_finished());
assert_eq!(
task_output.finish_mode,
FinishMode::Shutdown(GracefulKind::Explicit)
);
assert_eq!(task_output.join_result.unwrap(), GracefulKind::Explicit);
// Scenarios 4 and 5 need `ctrlc_shutdown()`, which only exists with the `signal` feature.
#[cfg(feature = "signal")]
{
// Scenario 4: Ctrl-C shutdown enabled, but the handle fires first; firing the
// mocked Ctrl-C afterwards reports `false` and the kind stays `Explicit`.
let (trigger, ctrlc) = ctrlc_mocked();
let mut graceful_task = GracefulTask::builder_default()
.ctrlc_shutdown()
.spawn_ctrlc_mocked(async |shutdown| shutdown.await, ctrlc);
assert!(!graceful_task.is_finished());
assert!(graceful_task.shutdown_handle().trigger());
sleep().await;
assert!(!trigger.trigger());
let task_output = (&mut graceful_task).await;
assert!(graceful_task.is_finished());
assert_eq!(
task_output.finish_mode,
FinishMode::Shutdown(GracefulKind::Explicit)
);
assert_eq!(task_output.join_result.unwrap(), GracefulKind::Explicit);
// Scenario 5: Ctrl-C shutdown enabled and the mocked Ctrl-C fires first;
// a handle fired later reports `false` and the kind is `CtrlC`.
let (trigger, ctrlc) = ctrlc_mocked();
let mut graceful_task = GracefulTask::builder_default()
.ctrlc_shutdown()
.spawn_ctrlc_mocked(async |shutdown| shutdown.await, ctrlc);
assert!(!graceful_task.is_finished());
assert!(trigger.trigger());
sleep().await;
assert!(!graceful_task.shutdown_handle().trigger());
let task_output = (&mut graceful_task).await;
assert!(graceful_task.is_finished());
assert_eq!(
task_output.finish_mode,
FinishMode::Shutdown(GracefulKind::CtrlC)
);
assert_eq!(task_output.join_result.unwrap(), GracefulKind::CtrlC);
}
// Scenario 6: Ctrl-C shutdown not enabled; firing the mocked Ctrl-C reports
// `false` and the task still reacts to the handle.
let (trigger, ctrlc) = ctrlc_mocked();
let mut graceful_task = GracefulTask::builder_default()
.spawn_ctrlc_mocked(async |shutdown| shutdown.await, ctrlc);
assert!(!graceful_task.is_finished());
assert!(!trigger.trigger());
sleep().await;
assert!(graceful_task.shutdown_handle().trigger());
let task_output = (&mut graceful_task).await;
assert!(graceful_task.is_finished());
assert_eq!(
task_output.finish_mode,
FinishMode::Shutdown(GracefulKind::Explicit)
);
assert_eq!(task_output.join_result.unwrap(), GracefulKind::Explicit);
// Scenario 7: the task returns immediately; firing a handle after the pause
// reports `false` and the finish mode is `Complete`.
let mut graceful_task = GracefulTask::builder_default().spawn(async |_| 42);
sleep().await;
assert!(!graceful_task.shutdown_handle().trigger());
let task_output = (&mut graceful_task).await;
assert!(graceful_task.is_finished());
assert_eq!(task_output.finish_mode, FinishMode::Complete);
assert_eq!(task_output.join_result.unwrap(), 42);
// Scenario 8: the task ignores the receiver and sleeps 200 ms; a handle fired
// after 100 ms succeeds, and the finish mode is `Shutdown(Explicit)` with the
// task's own `()` result.
let mut graceful_task =
GracefulTask::builder_default().spawn(async |_| sleep_double().await);
assert!(!graceful_task.is_finished());
sleep().await;
assert!(graceful_task.shutdown_handle().trigger());
let task_output = (&mut graceful_task).await;
assert!(graceful_task.is_finished());
assert_eq!(
task_output.finish_mode,
FinishMode::Shutdown(GracefulKind::Explicit)
);
assert_eq!(task_output.join_result.unwrap(), ());
}
}