use std::time::Duration;
use tokio::signal;
use tokio::sync::{broadcast, mpsc};
use tokio::time::timeout;
use tracing::debug;
/// A handle held by one unit of work that a [`TaskManager`] coordinates.
///
/// While any `Task` (or clone of one) is alive, the manager's completion
/// channel stays open. The manager treats shutdown as finished only after
/// every handle has been dropped.
pub struct Task {
    // This handle never sends on it; holding it keeps the manager's
    // `mpsc::Receiver` from reporting "all senders gone".
    tx: mpsc::Sender<()>,
    // Kept so that clones can open their own broadcast subscription.
    btx: broadcast::Sender<()>,
    // This handle's own subscription to the shutdown broadcast.
    rx: broadcast::Receiver<()>,
}

impl Task {
    /// Assembles a handle that shares `tx` and `btx` and has a brand-new
    /// broadcast subscription.
    fn from_parts(tx: &mpsc::Sender<()>, btx: &broadcast::Sender<()>) -> Self {
        let tx = tx.clone();
        let shared_btx = btx.clone();
        let rx = btx.subscribe();
        Self {
            tx,
            btx: shared_btx,
            rx,
        }
    }

    /// Suspends until the owning manager broadcasts its shutdown signal.
    ///
    /// Any receive error is ignored.
    ///
    /// Note: the manager sends the signal once. This handle holds a
    /// broadcast sender itself, so the channel cannot close while the handle
    /// is alive. As a result, calling `wait` again on the same handle after
    /// the signal was received does not complete.
    pub async fn wait(&mut self) {
        let _ = self.rx.recv().await;
        debug!("task received shutdown signal");
    }
}

impl Clone for Task {
    /// Produces another handle for the same manager, with its own broadcast
    /// subscription.
    ///
    /// A broadcast subscriber only sees values sent after it subscribed. A
    /// clone created after shutdown was broadcast therefore misses that
    /// signal.
    fn clone(&self) -> Self {
        Self::from_parts(&self.tx, &self.btx)
    }
}
/// Coordinates graceful shutdown of a group of [`Task`]s.
///
/// Give each unit of work a handle from [`TaskManager::task`]. Then call
/// [`TaskManager::wait`] or [`TaskManager::shutdown_gracefully_on_ctrl_c`]
/// to broadcast shutdown and wait, up to `wait_timeout`, for every handle to
/// be dropped.
pub struct TaskManager {
    // Upper bound on how long to wait for tasks after shutdown is broadcast.
    wait_timeout: Duration,
    // Shutdown broadcast; every `Task` subscribes to it.
    btx: broadcast::Sender<()>,
    // Held until the shutdown signal is sent. A broadcast `send` fails when
    // there are no receivers, so this guarantees one even when no tasks exist.
    rtx: broadcast::Receiver<()>,
    // Cloned into every `Task`. Dropped at shutdown so that `rx` yields
    // `None` once all task handles are gone.
    tx: mpsc::Sender<()>,
    rx: mpsc::Receiver<()>,
}

impl TaskManager {
    /// Creates a manager that waits at most `wait_timeout` for tasks to
    /// finish once shutdown has been broadcast.
    pub fn new(wait_timeout: Duration) -> Self {
        let (btx, rtx) = broadcast::channel(1);
        let (tx, rx) = mpsc::channel(1);
        Self {
            wait_timeout,
            btx,
            rtx,
            tx,
            rx,
        }
    }

    /// Returns a new [`Task`] handle tied to this manager.
    pub fn task(&self) -> Task {
        Task {
            tx: self.tx.clone(),
            btx: self.btx.clone(),
            rx: self.btx.subscribe(),
        }
    }

    /// Broadcasts the shutdown signal immediately, then waits for all tasks.
    ///
    /// Returns `true` if every [`Task`] handle was dropped within the
    /// configured timeout, and `false` if the timeout elapsed first.
    pub async fn wait(self) -> bool {
        self.shutdown(false).await
    }

    /// Waits for Ctrl+C, or for all tasks to finish on their own, and then
    /// shuts down.
    ///
    /// If every [`Task`] handle is dropped before Ctrl+C arrives, returns
    /// `true` without broadcasting. Otherwise it broadcasts shutdown and
    /// waits exactly as [`TaskManager::wait`] does. If the Ctrl+C listener
    /// cannot be installed, the error is logged and shutdown begins right
    /// away.
    pub async fn shutdown_gracefully_on_ctrl_c(self) -> bool {
        self.shutdown(true).await
    }

    /// Shared implementation of [`TaskManager::wait`] and
    /// [`TaskManager::shutdown_gracefully_on_ctrl_c`].
    async fn shutdown(self, block_until_signal: bool) -> bool {
        let TaskManager {
            wait_timeout,
            btx,
            rtx,
            tx,
            mut rx,
        } = self;

        // Drop our own sender so `rx.recv()` returns `None` once every Task is gone.
        drop(tx);

        if block_until_signal {
            tokio::select! {
                _ = rx.recv() => {
                    debug!("task manager has been shut down due to no active tasks");
                    return true;
                },
                result = signal::ctrl_c() => {
                    match result {
                        Ok(()) => debug!(
                            "ctrl+c signal received: starting graceful shutdown of task manager"
                        ),
                        // Previously this case was logged as if ctrl+c had been
                        // received. Report the real cause instead.
                        Err(err) => debug!(
                            "task manager failed to listen for ctrl+c signal: {}; starting graceful shutdown of task manager",
                            err
                        ),
                    }
                }
            };
        }

        if let Err(err) = btx.send(()) {
            debug!(
                "task manager received error while sending broadcast shutdown signal: {}",
                err
            );
        }
        // `rtx` existed only so the `send` above has a receiver when no tasks exist.
        drop(rtx);

        if let Err(err) = timeout(wait_timeout, rx.recv()).await {
            debug!("task manager received error while waiting: {}", err);
            return false;
        }
        true
    }
}
// Compiled only under `cfg(doctest)`. Its doc attribute is the contents of
// README.md, so rustdoc runs the README's code blocks as doctests. The
// struct itself has no runtime role. A plain `//` comment is used here
// because a `///` comment would be merged into the included documentation.
#[doc = include_str!("../README.md")]
#[cfg(doctest)]
pub struct ReadmeDocTests;
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn task_manager_zero_task_wait() {
        let tm = TaskManager::new(Duration::from_secs(1));
        assert!(tm.wait().await);
    }

    #[tokio::test]
    async fn task_manager_zero_task_shutdown_gracefully_on_ctrl_c() {
        let tm = TaskManager::new(Duration::from_secs(1));
        assert!(tm.shutdown_gracefully_on_ctrl_c().await);
    }

    #[tokio::test]
    async fn task_manager_graceful_shutdown() {
        let tm = TaskManager::new(Duration::from_millis(200));
        let (tx, mut rx) = tokio::sync::mpsc::channel(20);

        // Each of the 10 tasks (and its child) sleeps at most 90ms before
        // sending. The 100ms pause before `tm.wait()` below gives them time to
        // send before shutdown is broadcast. This margin is timing-sensitive.
        for n in 0..10 {
            let tx = tx.clone();
            let mut task = tm.task();
            tokio::spawn(async move {
                let mut child_task = task.clone();
                let child_tx = tx.clone();
                tokio::spawn(async move {
                    tokio::time::sleep(Duration::from_millis(n * 10)).await;
                    tokio::select! {
                        result = child_tx.send((n + 1) * 10) => assert!(result.is_ok()),
                        _ = child_task.wait() => (),
                    }
                });
                tokio::time::sleep(Duration::from_millis(n * 10)).await;
                tokio::select! {
                    result = tx.send(n) => assert!(result.is_ok()),
                    _ = task.wait() => (),
                }
            });
        }

        // A long-running task and child that exit only because of the shutdown signal.
        let mut task = tm.task();
        tokio::spawn(async move {
            let mut child_task = task.clone();
            tokio::spawn(async move {
                tokio::select! {
                    _ = child_task.wait() => (),
                    _ = tokio::time::sleep(Duration::from_secs(60)) => (),
                }
            });
            tokio::select! {
                _ = task.wait() => (),
                _ = tokio::time::sleep(Duration::from_secs(60)) => (),
            }
        });

        tokio::time::sleep(Duration::from_millis(100)).await;
        drop(tx);
        assert!(tm.wait().await);

        let mut results = Vec::with_capacity(20);
        while let Some(n) = rx.recv().await {
            results.push(n);
        }
        results.sort_unstable();
        assert_eq!(
            &results,
            &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
        );
    }

    #[tokio::test]
    async fn task_manager_shutdown_timeout() {
        let tm = TaskManager::new(Duration::from_millis(10));
        let task = tm.task();
        tokio::spawn(async move {
            // Keep the handle alive while deliberately ignoring the shutdown
            // signal, so the manager must give up after its timeout. The old
            // `let _ = task.wait();` built a future and dropped it unpolled.
            let _task = task;
            tokio::time::sleep(Duration::from_secs(60)).await;
            panic!("never should reach here");
        });
        assert!(!tm.wait().await);
    }
}