use std::time::Duration;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
/// Describes how a supervised task ended.
///
/// `Eq` is derived alongside `PartialEq` because the enum carries no data, so
/// equality is total; this also lets the type be used where `Eq` is required.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TaskExit {
    /// The task finished before any cancellation was observed.
    Normal,
    /// Cancellation was requested and the task finished within the timeout.
    Cancelled,
    /// The task was forcibly aborted after the cancellation timeout elapsed,
    /// or it ended without completing (for example, its join handle reported
    /// an error).
    Aborted,
}
/// Runs `task` on a spawned Tokio task and supervises it, signalling
/// completion through two oneshot channels.
///
/// The task receives a fresh, independent [`CancellationToken`]. When
/// `cancel_token` fires, that inner token is cancelled and the task is given
/// `cancel_task_timeout` to finish before it is aborted.
///
/// Two channels are needed because the outer `select!` takes `rx` by value:
/// once the cancellation branch wins, `rx` is gone, so the nested `select!`
/// listens on `rx2` instead.
///
/// # Returns
///
/// * [`TaskExit::Normal`] – the task signalled completion before cancellation
///   was observed.
/// * [`TaskExit::Cancelled`] – the task signalled completion after
///   cancellation, within the timeout.
/// * [`TaskExit::Aborted`] – the timeout elapsed and the task was aborted, or
///   the task dropped its senders without sending (e.g. it panicked).
async fn spawn_task_dual_channel<Task, Fut>(
    task: Task,
    cancel_token: CancellationToken,
    cancel_task_timeout: Duration,
) -> TaskExit
where
    Task: FnOnce(CancellationToken) -> Fut + Send + 'static,
    Fut: std::future::Future<Output = ()> + Send + 'static,
{
    let task_token = CancellationToken::new();
    let (tx, rx) = oneshot::channel();
    let (tx2, rx2) = oneshot::channel();
    let task_handle = tokio::spawn({
        let task_token = task_token.clone();
        async move {
            task(task_token).await;
            // Either receiver may already be dropped; a failed send is expected
            // and harmless.
            let _ = tx.send(());
            let _ = tx2.send(());
        }
    });
    tokio::select! {
        done = rx => {
            match done {
                Ok(()) => TaskExit::Normal,
                // The sender was dropped without sending: the task never reached
                // its completion signal (it panicked), so do not report success.
                Err(_) => TaskExit::Aborted,
            }
        }
        _ = cancel_token.cancelled() => {
            // Ask the task to stop cooperatively, then bound how long we wait.
            task_token.cancel();
            tokio::select! {
                _ = tokio::time::sleep(cancel_task_timeout) => {
                    task_handle.abort();
                    TaskExit::Aborted
                }
                done = rx2 => {
                    match done {
                        Ok(()) => TaskExit::Cancelled,
                        // Sender dropped without sending: the task died instead
                        // of shutting down gracefully.
                        Err(_) => TaskExit::Aborted,
                    }
                }
            }
        }
    }
}
async fn spawn_task_single_channel<Task, Fut>(
task: Task,
cancel_token: CancellationToken,
cancel_task_timeout: Duration,
) -> TaskExit
where
Task: FnOnce(CancellationToken) -> Fut + Send + 'static,
Fut: std::future::Future<Output = ()> + Send + 'static,
{
let task_token = CancellationToken::new();
let task_handle = tokio::spawn({
let task_token = task_token.clone();
async move {
task(task_token).await;
}
});
tokio::pin!(task_handle);
tokio::select! {
result = &mut task_handle => {
match result {
Ok(()) => TaskExit::Normal,
Err(_) => TaskExit::Aborted,
}
}
_ = cancel_token.cancelled() => {
task_token.cancel();
tokio::select! {
_ = tokio::time::sleep(cancel_task_timeout) => {
task_handle.abort();
TaskExit::Aborted
}
result = &mut task_handle => {
match result {
Ok(()) => TaskExit::Cancelled,
Err(_) => TaskExit::Aborted,
}
}
}
}
}
}
#[tokio::main]
async fn main() {
println!("Testing dual channel vs single channel implementations\n");
println!("=== Test 1: Normal completion ===");
{
let cancel_token = CancellationToken::new();
let result = spawn_task_dual_channel(
|_token| async { tokio::time::sleep(Duration::from_millis(10)).await },
cancel_token,
Duration::from_millis(100),
)
.await;
println!("Dual channel: {:?}", result);
assert_eq!(result, TaskExit::Normal);
let cancel_token = CancellationToken::new();
let result = spawn_task_single_channel(
|_token| async { tokio::time::sleep(Duration::from_millis(10)).await },
cancel_token,
Duration::from_millis(100),
)
.await;
println!("Single channel: {:?}", result);
assert_eq!(result, TaskExit::Normal);
}
println!("\n=== Test 2: Graceful cancellation ===");
{
let cancel_token = CancellationToken::new();
let cancel_token_clone = cancel_token.clone();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(50)).await;
cancel_token_clone.cancel();
});
let result = spawn_task_dual_channel(
|token| async move {
tokio::select! {
_ = token.cancelled() => {}
_ = tokio::time::sleep(Duration::from_secs(10)) => {}
}
},
cancel_token,
Duration::from_millis(100),
)
.await;
println!("Dual channel: {:?}", result);
assert_eq!(result, TaskExit::Cancelled);
let cancel_token = CancellationToken::new();
let cancel_token_clone = cancel_token.clone();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(50)).await;
cancel_token_clone.cancel();
});
let result = spawn_task_single_channel(
|token| async move {
tokio::select! {
_ = token.cancelled() => {}
_ = tokio::time::sleep(Duration::from_secs(10)) => {}
}
},
cancel_token,
Duration::from_millis(100),
)
.await;
println!("Single channel: {:?}", result);
assert_eq!(result, TaskExit::Cancelled);
}
println!("\n=== Test 3: Zombie task (abort) ===");
{
let cancel_token = CancellationToken::new();
let cancel_token_clone = cancel_token.clone();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(10)).await;
cancel_token_clone.cancel();
});
let result = spawn_task_dual_channel(
|token| async move {
tokio::select! {
_ = token.cancelled() => {
loop { tokio::time::sleep(Duration::from_secs(1)).await; }
}
_ = tokio::time::sleep(Duration::from_secs(10)) => {}
}
},
cancel_token,
Duration::from_millis(50),
)
.await;
println!("Dual channel: {:?}", result);
assert_eq!(result, TaskExit::Aborted);
let cancel_token = CancellationToken::new();
let cancel_token_clone = cancel_token.clone();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(10)).await;
cancel_token_clone.cancel();
});
let result = spawn_task_single_channel(
|token| async move {
tokio::select! {
_ = token.cancelled() => {
loop { tokio::time::sleep(Duration::from_secs(1)).await; }
}
_ = tokio::time::sleep(Duration::from_secs(10)) => {}
}
},
cancel_token,
Duration::from_millis(50),
)
.await;
println!("Single channel: {:?}", result);
assert_eq!(result, TaskExit::Aborted);
}
println!("\n=== All tests passed! ===");
}