use std::sync::Arc;
use tokio::{
sync::broadcast,
time::{interval, Duration},
};
use tracing::{error, info};
#[derive(Debug)]
pub enum Error {
Tokio(broadcast::error::SendError<()>),
}
/// Cooperative shutdown coordinator built on a [`broadcast`] channel.
///
/// Create one with [`Shutdown::new`] and clone it into every task that must
/// react to shutdown. [`Shutdown::signal`] (on any clone) wakes every pending
/// [`Shutdown::recv`]; [`Shutdown::wait`] then polls until every other clone
/// has been dropped or a deadline passes.
#[derive(Debug)]
pub struct Shutdown {
    /// Shared sender; each clone subscribes its own receiver from it.
    sender: Arc<broadcast::Sender<()>>,
    /// This handle's receiver. `wait` counts live receivers via
    /// `receiver_count`, so a handle is counted until it is dropped.
    notify: broadcast::Receiver<()>,
    /// Shared flag set by `signal` *before* the broadcast is sent. A clone
    /// created after the broadcast subscribes too late to receive it, so
    /// `recv` consults this flag to avoid waiting forever.
    shutdown: Arc<std::sync::atomic::AtomicBool>,
}

impl Default for Shutdown {
    fn default() -> Self {
        Self::new()
    }
}

impl Shutdown {
    /// Creates a new coordinator; no shutdown has been signalled yet.
    #[must_use]
    pub fn new() -> Self {
        // Capacity 1 suffices: the only message ever sent is the unit notice,
        // and any receive outcome (including lag) is treated as shutdown.
        let (sender, notify) = broadcast::channel(1);
        Self {
            sender: Arc::new(sender),
            notify,
            shutdown: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }

    /// Returns `true` once [`Shutdown::signal`] has been called on any clone.
    #[must_use]
    pub fn is_shutdown(&self) -> bool {
        self.shutdown.load(std::sync::atomic::Ordering::SeqCst)
    }

    /// Waits until shutdown has been signalled.
    ///
    /// Returns immediately if a signal was already sent, including one sent
    /// before this handle was cloned.
    pub async fn recv(&mut self) {
        if self.is_shutdown() {
            return;
        }
        // `notify` was subscribed before the flag check above, and `signal`
        // sets the flag before sending, so a signal not yet visible in the
        // flag is guaranteed to be delivered to this receiver. Any receive
        // outcome (value, lag, close) means shutdown.
        let _ = self.notify.recv().await;
    }

    /// Broadcasts the shutdown notice to every handle.
    ///
    /// On success returns the value reported by the broadcast send (the number
    /// of receivers the notice was sent to).
    ///
    /// # Errors
    ///
    /// Returns [`Error::Tokio`] if the broadcast send fails.
    pub fn signal(&self) -> Result<usize, Error> {
        // Publish the flag first so handles cloned after the send still see it.
        self.shutdown
            .store(true, std::sync::atomic::Ordering::SeqCst);
        self.sender.send(()).map_err(Error::Tokio)
    }

    /// Waits for every other handle to be dropped, giving up after `max_delay`.
    ///
    /// Consumes `self` and drops this handle's own receiver so it is not
    /// counted. The remaining receiver count is checked immediately and then
    /// once per second. Progress is logged at `info`. Hitting the deadline with
    /// handles still alive is logged at `error`. A zero `max_delay` returns
    /// after an immediate check instead of panicking.
    pub async fn wait(self, max_delay: Duration) {
        drop(self.notify);
        let mut check_pulse = interval(Duration::from_secs(1));
        // One-shot deadline: `interval` panics on a zero period, whereas
        // `sleep(Duration::ZERO)` simply completes.
        let deadline = tokio::time::sleep(max_delay);
        tokio::pin!(deadline);
        loop {
            tokio::select! {
                _ = check_pulse.tick() => {
                    let remaining: usize = self.sender.receiver_count();
                    if remaining == 0 {
                        info!("all tasks shut down");
                        return;
                    }
                    info!("waiting for {} tasks to shutdown", remaining);
                }
                _ = &mut deadline => {
                    let remaining: usize = self.sender.receiver_count();
                    // Both branches can be ready together; don't report a
                    // failure when everything has in fact shut down.
                    if remaining == 0 {
                        info!("all tasks shut down");
                    } else {
                        error!("shutdown wait completing with {} remaining tasks", remaining);
                    }
                    return;
                }
            }
        }
    }
}

impl Clone for Shutdown {
    /// Creates another handle subscribed to the same shutdown broadcast.
    ///
    /// The new receiver only sees notices sent after this call; notices sent
    /// earlier are still observed through the shared shutdown flag.
    fn clone(&self) -> Self {
        Self {
            sender: Arc::clone(&self.sender),
            notify: self.sender.subscribe(),
            shutdown: Arc::clone(&self.shutdown),
        }
    }
}