use tokio::sync::oneshot;
/// Acknowledgement handle handed to the worker when shutdown is requested.
///
/// The worker calls [`ShutdownSignal::ack`] once it is done; the receiver returned
/// alongside the signal by [`ShutdownSignal::new`] then resolves.
#[derive(Debug)]
pub(crate) struct ShutdownSignal {
    /// Fires the paired receiver when `ack` is called.
    ack_tx: oneshot::Sender<()>,
}

impl ShutdownSignal {
    /// Builds a signal and the receiver that observes its acknowledgement.
    pub(crate) fn new() -> (Self, oneshot::Receiver<()>) {
        let (sender, receiver) = oneshot::channel::<()>();
        let signal = Self { ack_tx: sender };
        (signal, receiver)
    }

    /// Acknowledges the shutdown request, consuming the signal.
    ///
    /// A send error only means the receiving side is gone, so it is ignored.
    pub(crate) fn ack(self) {
        let Self { ack_tx } = self;
        let _ = ack_tx.send(());
    }
}
/// Owns the one-time shutdown channel to a background worker.
///
/// Call [`CloudWatchWorkerGuard::shutdown`] and await it to send the shutdown request and
/// wait for the worker's acknowledgement. Dropping the guard instead only sends the request;
/// it does not wait for the acknowledgement.
#[must_use = "dropping the guard immediately sends the worker its shutdown signal"]
#[derive(Debug)]
pub struct CloudWatchWorkerGuard {
    /// Sender for the shutdown request; `None` once it has been taken and used.
    shutdown_tx: Option<oneshot::Sender<ShutdownSignal>>,
}
impl CloudWatchWorkerGuard {
    /// Wraps the sender used to deliver a [`ShutdownSignal`] to the worker.
    pub(crate) fn new(shutdown_tx: oneshot::Sender<ShutdownSignal>) -> Self {
        let shutdown_tx = Some(shutdown_tx);
        Self { shutdown_tx }
    }

    /// Moves the shutdown sender out, leaving `None` so it is used at most once.
    fn take_shutdown_tx(&mut self) -> Option<oneshot::Sender<ShutdownSignal>> {
        Option::take(&mut self.shutdown_tx)
    }

    /// Requests shutdown and waits until the worker acknowledges it.
    ///
    /// Returns without waiting if the request cannot be delivered (the worker's receiver
    /// has been dropped). If the worker drops the signal without calling `ack`, the wait
    /// ends when the signal is dropped. There is no timeout.
    pub async fn shutdown(mut self) {
        let Some(sender) = self.take_shutdown_tx() else {
            return;
        };
        let (signal, ack_rx) = ShutdownSignal::new();
        if sender.send(signal).is_ok() {
            // Resolves on `ack` or when the signal is dropped; either outcome ends the wait.
            let _ = ack_rx.await;
        }
        // `self` now holds `None`, so the `Drop` impl that runs next does nothing.
    }
}
/// Best-effort shutdown for guards that were not shut down explicitly.
///
/// `drop` cannot await, so the worker is sent its signal but nobody waits for the
/// acknowledgement. After an explicit `shutdown` the sender is already gone and this
/// is a no-op.
impl Drop for CloudWatchWorkerGuard {
    fn drop(&mut self) {
        if let Some(sender) = self.take_shutdown_tx() {
            // Kept bound until after the send; it is discarded when `drop` returns.
            let (signal, _pending_ack) = ShutdownSignal::new();
            let _ = sender.send(signal);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicBool, Ordering};
    use tokio::time::{Duration, sleep};

    #[tokio::test(flavor = "current_thread")]
    async fn shutdown_waits_for_ack() {
        let (shutdown_tx, shutdown_rx) = oneshot::channel::<ShutdownSignal>();
        let guard = CloudWatchWorkerGuard::new(shutdown_tx);
        let acked = Arc::new(AtomicBool::new(false));
        let worker_acked = Arc::clone(&acked);
        let worker = tokio::spawn(async move {
            let signal = shutdown_rx.await.unwrap();
            sleep(Duration::from_millis(20)).await;
            worker_acked.store(true, Ordering::SeqCst);
            signal.ack();
        });
        guard.shutdown().await;
        // The flag is set just before `ack`, so it must be visible once `shutdown` returns.
        // Without this check the test would pass even if `shutdown` did not wait, because
        // `worker.await` below waits for the worker regardless.
        assert!(
            acked.load(Ordering::SeqCst),
            "shutdown returned before the worker acknowledged"
        );
        worker.await.unwrap();
    }

    #[tokio::test(flavor = "current_thread")]
    async fn shutdown_returns_when_worker_is_gone() {
        let (shutdown_tx, shutdown_rx) = oneshot::channel::<ShutdownSignal>();
        drop(shutdown_rx);
        CloudWatchWorkerGuard::new(shutdown_tx).shutdown().await;
    }

    #[tokio::test(flavor = "current_thread")]
    async fn shutdown_returns_when_signal_dropped_without_ack() {
        let (shutdown_tx, shutdown_rx) = oneshot::channel::<ShutdownSignal>();
        let guard = CloudWatchWorkerGuard::new(shutdown_tx);
        let worker = tokio::spawn(async move {
            // Drop the signal without acknowledging it.
            drop(shutdown_rx.await.unwrap());
        });
        guard.shutdown().await;
        worker.await.unwrap();
    }

    #[tokio::test(flavor = "current_thread")]
    async fn drop_sends_shutdown_signal() {
        let (shutdown_tx, shutdown_rx) = oneshot::channel::<ShutdownSignal>();
        drop(CloudWatchWorkerGuard::new(shutdown_tx));
        let signal = shutdown_rx
            .await
            .expect("dropping the guard should send a shutdown signal");
        // Nobody waits for this ack; it must be a silent no-op.
        signal.ack();
    }
}