use anyhow::{Context, Result};
use std::future::Future;
use tokio::runtime::Handle;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
/// Unsized trait-object type describing a one-shot critical-task entry point:
/// a `Send + 'static` closure that is called once with a [`CancellationToken`]
/// and returns a value of type `Fut`.
///
/// Because this is a bare `dyn` type it can only be used behind a pointer
/// (for example `Box<CriticalTaskHandler<Fut>>`). The constructors in this file
/// accept `impl FnOnce(CancellationToken) -> Fut + Send + 'static` directly
/// rather than naming this alias.
pub type CriticalTaskHandler<Fut> = dyn FnOnce(CancellationToken) -> Fut + Send + 'static;
/// Handle to a spawned "critical" task: a task whose failure (an `Err` result
/// or a panic) cancels the parent [`CancellationToken`] it was created with.
///
/// A handle must be consumed with `join()` or `detach()`; the `Drop`
/// implementation checks the `detached` flag to detect handles that were
/// dropped without either.
pub struct CriticalTaskExecutionHandle {
/// Join handle of the monitor task, which awaits the main task and forwards
/// its outcome through the oneshot channel. Used by `is_finished()`.
monitor_task: JoinHandle<()>,
/// Child of the parent token; a clone of it is passed to the task function.
/// `cancel()` and `is_cancelled()` operate on this token.
graceful_shutdown_token: CancellationToken,
/// Receiving end of the oneshot channel carrying the task's final result.
/// Wrapped in `Option` so `join()` can move the receiver out of the handle.
result_receiver: Option<oneshot::Receiver<Result<()>>>,
/// Set to `true` by `join()` and `detach()`; inspected by `Drop`.
detached: bool,
}
impl CriticalTaskExecutionHandle {
pub fn new<Fut>(
task_fn: impl FnOnce(CancellationToken) -> Fut + Send + 'static,
parent_token: CancellationToken,
description: &str,
) -> Result<Self>
where
Fut: Future<Output = Result<()>> + Send + 'static,
{
Self::new_with_runtime(task_fn, parent_token, description, &Handle::try_current()?)
}
pub fn new_with_runtime<Fut>(
task_fn: impl FnOnce(CancellationToken) -> Fut + Send + 'static,
parent_token: CancellationToken,
description: &str,
runtime: &Handle,
) -> Result<Self>
where
Fut: Future<Output = Result<()>> + Send + 'static,
{
let graceful_shutdown_token = parent_token.child_token();
let description = description.to_string();
let parent_token_clone = parent_token.clone();
let (result_sender, result_receiver) = oneshot::channel();
let graceful_shutdown_token_clone = graceful_shutdown_token.clone();
let description_clone = description.to_string();
let task = runtime.spawn(async move {
let future = task_fn(graceful_shutdown_token_clone);
match future.await {
Ok(()) => {
tracing::debug!(
"Critical task '{}' completed successfully",
description_clone
);
Ok(())
}
Err(e) => {
tracing::error!("Critical task '{}' failed: {:#}", description_clone, e);
Err(e.context(format!("Critical task '{}' failed", description_clone)))
}
}
});
let monitor_task = {
let main_task_handle = task;
let parent_token_monitor = parent_token_clone.clone();
let description_monitor = description.clone();
runtime.spawn(async move {
let result = match main_task_handle.await {
Ok(task_result) => {
if task_result.is_err() {
parent_token_monitor.cancel();
}
task_result
}
Err(join_error) => {
if join_error.is_panic() {
let panic_msg = if let Ok(reason) = join_error.try_into_panic() {
if let Some(s) = reason.downcast_ref::<String>() {
s.clone()
} else if let Some(s) = reason.downcast_ref::<&str>() {
s.to_string()
} else {
"Unknown panic".to_string()
}
} else {
"Panic occurred but reason unavailable".to_string()
};
tracing::error!(
"Critical task '{}' panicked: {}",
description_monitor,
panic_msg
);
parent_token_monitor.cancel(); Err(anyhow::anyhow!(
"Critical task '{}' panicked: {}",
description_monitor,
panic_msg
))
} else {
parent_token_monitor.cancel();
Err(anyhow::anyhow!(
"Failed to join critical task '{}': {}",
description_monitor,
join_error
))
}
}
};
let _ = result_sender.send(result);
})
};
Ok(Self {
monitor_task,
graceful_shutdown_token,
result_receiver: Some(result_receiver),
detached: false,
})
}
pub fn is_finished(&self) -> bool {
self.monitor_task.is_finished()
}
pub fn is_cancelled(&self) -> bool {
self.graceful_shutdown_token.is_cancelled()
}
pub fn cancel(&self) {
self.graceful_shutdown_token.cancel();
}
pub async fn join(mut self) -> Result<()> {
self.detached = true;
match self.result_receiver.take().unwrap().await {
Ok(task_result) => task_result,
Err(_) => {
Err(anyhow::anyhow!("Critical task monitor was cancelled"))
}
}
}
pub fn detach(mut self) {
self.detached = true;
}
}
/// Enforces that every handle is explicitly consumed with `join()` or
/// `detach()`.
impl Drop for CriticalTaskExecutionHandle {
    /// Panics when the handle is dropped without having been joined or
    /// detached.
    ///
    /// The check is skipped while the thread is already unwinding from another
    /// panic: panicking again inside `drop` would abort the whole process and
    /// hide the original failure (for example a failed assertion that runs
    /// before `join()` is reached).
    fn drop(&mut self) {
        if !self.detached && !std::thread::panicking() {
            panic!("Critical task was not detached prior to drop!");
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
    use std::time::Duration;
    use tokio::time::timeout;

    /// A task that returns `Ok` is reported as such and leaves the parent
    /// token untouched.
    #[tokio::test]
    async fn test_successful_task_completion() {
        let parent_token = CancellationToken::new();
        let completed = Arc::new(AtomicBool::new(false));
        let completed_clone = completed.clone();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                completed_clone.store(true, Ordering::SeqCst);
                Ok(())
            },
            parent_token.clone(),
            "test-success-task",
        )
        .unwrap();

        let result = handle.join().await;
        assert!(result.is_ok());
        assert!(completed.load(Ordering::SeqCst));
        assert!(!parent_token.is_cancelled());
    }

    /// A task that returns `Err` surfaces the error, wrapped with the task
    /// description, and cancels the parent token.
    #[tokio::test]
    async fn test_task_failure_cancels_parent_token() {
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                anyhow::bail!("Critical task failed!");
            },
            parent_token.clone(),
            "test-failure-task",
        )
        .unwrap();

        let result = handle.join().await;
        assert!(result.is_err());

        // `{:#}` renders the whole context chain, so both the wrapper added by
        // the handle and the task's own message must be present.
        let error_chain = format!("{:#}", result.unwrap_err());
        assert!(
            error_chain.contains("Critical task 'test-failure-task' failed"),
            "Error message should contain failure context: {}",
            error_chain
        );
        assert!(
            error_chain.contains("Critical task failed!"),
            "Error message should contain the original error: {}",
            error_chain
        );

        // The monitor cancels the parent before it publishes the result, so no
        // extra wait is needed after `join()`.
        assert!(parent_token.is_cancelled());
    }

    /// A panic inside the task is converted into an error and cancels the
    /// parent token.
    #[tokio::test]
    async fn test_task_panic_is_caught_and_reported() {
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                panic!("Something went terribly wrong!");
            },
            parent_token.clone(),
            "test-panic-task",
        )
        .unwrap();

        let result = handle.join().await;
        assert!(result.is_err());
        let error_msg = result.unwrap_err().to_string();
        assert!(
            error_msg.contains("panicked") || error_msg.contains("panic"),
            "Error message should indicate a panic occurred: {}",
            error_msg
        );
        assert!(parent_token.is_cancelled());
    }

    /// `cancel()` stops a cooperative task early without cancelling the
    /// parent token.
    #[tokio::test]
    async fn test_graceful_shutdown_via_cancellation_token() {
        let parent_token = CancellationToken::new();
        let work_done = Arc::new(AtomicU32::new(0));
        let work_done_clone = work_done.clone();

        let handle = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                for i in 0..100 {
                    if cancel_token.is_cancelled() {
                        break;
                    }
                    work_done_clone.store(i, Ordering::SeqCst);
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
                Ok(())
            },
            parent_token.clone(),
            "test-graceful-shutdown",
        )
        .unwrap();

        tokio::time::sleep(Duration::from_millis(50)).await;
        handle.cancel();

        let result = handle.join().await;
        assert!(result.is_ok());

        // Some work happened, but the loop did not run to completion.
        let final_work = work_done.load(Ordering::SeqCst);
        assert!(final_work > 0);
        assert!(final_work < 99);
        assert!(!parent_token.is_cancelled());
    }

    /// When one of several critical tasks fails, the shared parent token is
    /// cancelled and the sibling task shuts down gracefully.
    #[tokio::test]
    async fn test_multiple_critical_tasks_one_failure() {
        let parent_token = CancellationToken::new();
        let task1_completed = Arc::new(AtomicBool::new(false));
        let task2_completed = Arc::new(AtomicBool::new(false));
        let task1_completed_clone = task1_completed.clone();
        let task2_completed_clone = task2_completed.clone();

        let handle1 = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                for _ in 0..50 {
                    if cancel_token.is_cancelled() {
                        return Ok(());
                    }
                    tokio::time::sleep(Duration::from_millis(10)).await;
                }
                task1_completed_clone.store(true, Ordering::SeqCst);
                Ok(())
            },
            parent_token.clone(),
            "long-running-task",
        )
        .unwrap();

        let handle2 = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(100)).await;
                task2_completed_clone.store(true, Ordering::SeqCst);
                anyhow::bail!("Task 2 failed!");
            },
            parent_token.clone(),
            "failing-task",
        )
        .unwrap();

        let result2 = handle2.join().await;
        assert!(result2.is_err());
        assert!(parent_token.is_cancelled());

        let result1 = handle1.join().await;
        assert!(result1.is_ok());
        // Task 1 exited through its cancellation check, not by finishing.
        assert!(!task1_completed.load(Ordering::SeqCst));
    }

    /// `is_finished()` / `is_cancelled()` reflect the task state before and
    /// after `cancel()`.
    #[tokio::test]
    async fn test_status_checking_methods() {
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(100)).await;
                if cancel_token.is_cancelled() {
                    return Ok(());
                }
                tokio::time::sleep(Duration::from_millis(100)).await;
                Ok(())
            },
            parent_token,
            "status-test-task",
        )
        .unwrap();

        assert!(!handle.is_finished());
        assert!(!handle.is_cancelled());

        handle.cancel();
        assert!(handle.is_cancelled());

        let result = handle.join().await;
        assert!(result.is_ok());
    }

    /// A task built around `tokio::select!` reacts to cancellation without
    /// finishing its work branch.
    #[tokio::test]
    async fn test_task_with_select_pattern() {
        let parent_token = CancellationToken::new();
        let work_completed = Arc::new(AtomicBool::new(false));
        let work_completed_clone = work_completed.clone();

        let handle = CriticalTaskExecutionHandle::new(
            |cancel_token| async move {
                tokio::select! {
                    _ = tokio::time::sleep(Duration::from_millis(200)) => {
                        work_completed_clone.store(true, Ordering::SeqCst);
                        Ok(())
                    }
                    _ = cancel_token.cancelled() => {
                        Ok(())
                    }
                }
            },
            parent_token,
            "select-pattern-task",
        )
        .unwrap();

        tokio::time::sleep(Duration::from_millis(50)).await;
        handle.cancel();

        let result = handle.join().await;
        assert!(result.is_ok());
        assert!(!work_completed.load(Ordering::SeqCst));
    }

    /// `join()` can be wrapped in a timeout; the join future has been polled by
    /// then, so dropping it does not trip the drop check.
    #[tokio::test]
    async fn test_timeout_behavior() {
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                tokio::time::sleep(Duration::from_secs(10)).await;
                Ok(())
            },
            parent_token,
            "long-task",
        )
        .unwrap();

        let result = timeout(Duration::from_millis(100), handle.join()).await;
        assert!(result.is_err());
    }

    /// The parent token is cancelled by the monitor as soon as the task
    /// panics, before anyone calls `join()`.
    #[tokio::test]
    async fn test_panic_triggers_immediate_parent_cancellation() {
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(50)).await;
                panic!("Critical failure!");
            },
            parent_token.clone(),
            "immediate-panic-task",
        )
        .unwrap();

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(
            parent_token.is_cancelled(),
            "Parent token should be cancelled immediately when critical task panics"
        );
        assert!(handle.join().await.is_err());
    }

    /// The parent token is cancelled by the monitor as soon as the task
    /// returns an error, before anyone calls `join()`.
    #[tokio::test]
    async fn test_error_triggers_immediate_parent_cancellation() {
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move {
                tokio::time::sleep(Duration::from_millis(50)).await;
                anyhow::bail!("Critical error!");
            },
            parent_token.clone(),
            "immediate-error-task",
        )
        .unwrap();

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(
            parent_token.is_cancelled(),
            "Parent token should be cancelled immediately when critical task errors"
        );
        assert!(handle.join().await.is_err());
    }

    /// Dropping a handle that was neither joined nor detached panics with the
    /// dedicated message (and not, say, because construction failed).
    #[tokio::test]
    #[should_panic(expected = "Critical task was not detached prior to drop!")]
    async fn test_drop_without_detach_panics() {
        let parent_token = CancellationToken::new();

        let _handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move { Ok(()) },
            parent_token,
            "test-detach-task",
        )
        .unwrap();
    }

    /// `detach()` consumes the handle so that dropping it is silent.
    #[tokio::test]
    async fn test_detach_allows_drop() {
        let parent_token = CancellationToken::new();

        let handle = CriticalTaskExecutionHandle::new(
            |_cancel_token| async move { Ok(()) },
            parent_token.clone(),
            "test-detached-task",
        )
        .unwrap();

        handle.detach();
        assert!(!parent_token.is_cancelled());
    }
}