use std::{future::Future, pin::Pin, sync::Arc, time::Duration};
use tokio::sync::{broadcast, watch};
/// The reason a shutdown was initiated.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ShutdownSignal {
    /// SIGINT on Unix, or Ctrl+C on other platforms.
    Interrupt,
    /// SIGTERM (only produced on Unix).
    Terminate,
    /// Shutdown requested programmatically via [`GracefulShutdown::shutdown`].
    Manual,
}

impl std::fmt::Display for ShutdownSignal {
    /// Formats the signal as its conventional name (`SIGINT`, `SIGTERM`) or `Manual`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            ShutdownSignal::Interrupt => "SIGINT",
            ShutdownSignal::Terminate => "SIGTERM",
            ShutdownSignal::Manual => "Manual",
        };
        f.write_str(name)
    }
}
/// A cloneable handle that observes whether shutdown has been requested.
///
/// Obtain one from [`GracefulShutdown::token`].
#[derive(Clone)]
pub struct ShutdownToken {
    /// Receiving half of the coordinator's shutdown-flag channel.
    receiver: watch::Receiver<bool>,
}

impl ShutdownToken {
    /// Returns whether shutdown has been requested, without waiting.
    pub fn is_shutdown(&self) -> bool {
        *self.receiver.borrow()
    }

    /// Completes once shutdown has been requested (immediately if it already was).
    ///
    /// It also completes if the owning [`GracefulShutdown`] is dropped before shutdown
    /// is requested; in that case [`ShutdownToken::is_shutdown`] still reports `false`.
    pub async fn cancelled(&mut self) {
        // `wait_for` only errs once the sender is gone; no shutdown can arrive after
        // that, so the error is deliberately discarded and we simply return.
        self.receiver.wait_for(|&requested| requested).await.ok();
    }
}
/// Configures and constructs a [`GracefulShutdown`].
pub struct GracefulShutdownBuilder {
    /// Duration later used by `GracefulShutdown::wait_with_timeout`.
    timeout: Duration,
    /// Optional callback invoked with the signal whenever shutdown is triggered.
    on_signal: Option<Box<dyn Fn(ShutdownSignal) + Send + Sync>>,
}

impl Default for GracefulShutdownBuilder {
    /// A 30 second timeout and no signal callback.
    fn default() -> Self {
        const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
        GracefulShutdownBuilder {
            timeout: DEFAULT_TIMEOUT,
            on_signal: None,
        }
    }
}

impl GracefulShutdownBuilder {
    /// Creates a builder with the default settings; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the duration used by `GracefulShutdown::wait_with_timeout`.
    pub fn timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }

    /// Registers a callback invoked with the signal each time shutdown is triggered.
    ///
    /// Replaces any previously registered callback.
    pub fn on_signal<F>(self, callback: F) -> Self
    where
        F: Fn(ShutdownSignal) + Send + Sync + 'static,
    {
        Self {
            on_signal: Some(Box::new(callback)),
            ..self
        }
    }

    /// Consumes the builder and creates the coordinator.
    ///
    /// Only the sending halves of the channels are kept; receivers are created on
    /// demand by `GracefulShutdown::token` and `GracefulShutdown::subscribe`.
    pub fn build(self) -> GracefulShutdown {
        let GracefulShutdownBuilder { timeout, on_signal } = self;
        let (shutdown_tx, _) = watch::channel(false);
        let (signal_tx, _) = broadcast::channel(1);
        GracefulShutdown {
            timeout,
            on_signal: on_signal.map(Arc::<dyn Fn(ShutdownSignal) + Send + Sync>::from),
            shutdown_tx,
            signal_tx,
        }
    }
}
/// Coordinates graceful shutdown for an application.
///
/// Holds a persistent shutdown flag (observed through [`ShutdownToken`]s), a broadcast
/// channel announcing which [`ShutdownSignal`] triggered shutdown, and an optional
/// callback invoked every time shutdown is triggered.
pub struct GracefulShutdown {
    /// Duration used by [`GracefulShutdown::wait_with_timeout`].
    timeout: Duration,
    /// Optional callback run each time shutdown is triggered.
    on_signal: Option<Arc<dyn Fn(ShutdownSignal) + Send + Sync>>,
    /// Persistent shutdown flag; `true` once shutdown has been triggered.
    shutdown_tx: watch::Sender<bool>,
    /// Announces the triggering signal to [`GracefulShutdown::subscribe`] receivers.
    signal_tx: broadcast::Sender<ShutdownSignal>,
}

impl GracefulShutdown {
    /// Creates a coordinator with the default builder settings.
    pub fn new() -> Self {
        GracefulShutdownBuilder::new().build()
    }

    /// Returns a builder for customising the timeout and the signal callback.
    pub fn builder() -> GracefulShutdownBuilder {
        GracefulShutdownBuilder::new()
    }

    /// The duration used by [`GracefulShutdown::wait_with_timeout`].
    pub fn timeout(&self) -> Duration {
        self.timeout
    }

    /// Returns a new [`ShutdownToken`] observing this coordinator's shutdown flag.
    ///
    /// A token created after shutdown was already triggered observes it immediately.
    pub fn token(&self) -> ShutdownToken {
        ShutdownToken {
            receiver: self.shutdown_tx.subscribe(),
        }
    }

    /// Subscribes to the broadcast of triggering signals.
    ///
    /// Only signals sent after this call are delivered. To observe a shutdown that may
    /// already have happened, use [`GracefulShutdown::token`] instead.
    pub fn subscribe(&self) -> broadcast::Receiver<ShutdownSignal> {
        self.signal_tx.subscribe()
    }

    /// Triggers shutdown manually, reporting [`ShutdownSignal::Manual`].
    ///
    /// Not idempotent: every call re-broadcasts the signal and re-invokes the
    /// `on_signal` callback.
    pub fn shutdown(&self) {
        self.trigger(ShutdownSignal::Manual);
    }

    /// Waits for an OS shutdown signal, then triggers shutdown with it.
    ///
    /// The signal is SIGINT or SIGTERM on Unix, and Ctrl+C elsewhere.
    ///
    /// # Panics
    ///
    /// Panics if the OS signal handlers cannot be registered.
    pub async fn wait(&self) -> ShutdownSignal {
        let signal = wait_for_signal().await;
        self.trigger(signal);
        signal
    }

    /// Like [`GracefulShutdown::wait`], but gives up after the configured timeout.
    ///
    /// Returns `None` without triggering shutdown if no signal arrives in time.
    pub async fn wait_with_timeout(&self) -> Option<ShutdownSignal> {
        tokio::select! {
            signal = self.wait() => Some(signal),
            _ = tokio::time::sleep(self.timeout) => None,
        }
    }

    /// Drives `future` until it completes (`Some(output)`) or shutdown is requested.
    ///
    /// Returns `None` on shutdown; `future` is dropped in that case.
    pub async fn run_until_shutdown<F, T>(&self, future: F) -> Option<T>
    where
        F: Future<Output = T>,
    {
        let mut token = self.token();
        tokio::select! {
            result = future => Some(result),
            _ = token.cancelled() => None,
        }
    }

    /// Spawns `future` on the tokio runtime, racing it against shutdown.
    ///
    /// The handle yields `Some(())` if the future completed, or `None` if shutdown was
    /// requested first (the future is then dropped).
    pub fn spawn<F>(&self, name: &str, future: F) -> tokio::task::JoinHandle<Option<()>>
    where
        F: Future<Output = ()> + Send + 'static,
    {
        let mut token = self.token();
        let name = name.to_string();
        tokio::spawn(async move {
            tokio::select! {
                _ = future => {
                    Some(())
                }
                _ = token.cancelled() => {
                    #[cfg(feature = "otel")]
                    tracing::info!(task = %name, "Task cancelled due to shutdown");
                    #[cfg(not(feature = "otel"))]
                    let _ = name;
                    None
                }
            }
        })
    }

    /// Sets the shutdown flag, broadcasts `signal`, and runs the `on_signal` callback.
    fn trigger(&self, signal: ShutdownSignal) {
        // `send_replace` always stores the value. Plain `send` returns an error and
        // leaves the flag `false` whenever no receiver exists at that moment (the
        // builder drops the initial one), so tokens created afterwards would never
        // observe the shutdown.
        self.shutdown_tx.send_replace(true);
        // An error only means nobody is subscribed to the broadcast right now.
        let _ = self.signal_tx.send(signal);
        if let Some(callback) = &self.on_signal {
            callback(signal);
        }
    }
}
impl Default for GracefulShutdown {
    /// Equivalent to [`GracefulShutdown::new`]: a coordinator built with default settings.
    fn default() -> Self {
        Self::builder().build()
    }
}
/// Waits for the first OS shutdown signal and reports which one arrived.
///
/// On Unix this is whichever of SIGINT or SIGTERM comes first; elsewhere it is Ctrl+C,
/// reported as [`ShutdownSignal::Interrupt`].
///
/// # Panics
///
/// Panics if the signal handlers cannot be registered.
async fn wait_for_signal() -> ShutdownSignal {
    #[cfg(unix)]
    {
        use tokio::signal::unix::{signal, SignalKind};
        let mut interrupt_stream =
            signal(SignalKind::interrupt()).expect("Failed to register SIGINT handler");
        let mut terminate_stream =
            signal(SignalKind::terminate()).expect("Failed to register SIGTERM handler");
        // A closed stream (`recv` returning `None`) is also treated as the signal.
        tokio::select! {
            _ = interrupt_stream.recv() => ShutdownSignal::Interrupt,
            _ = terminate_stream.recv() => ShutdownSignal::Terminate,
        }
    }
    #[cfg(not(unix))]
    {
        let outcome = tokio::signal::ctrl_c().await;
        outcome.expect("Failed to register Ctrl+C handler");
        ShutdownSignal::Interrupt
    }
}
/// Triggers [`GracefulShutdown::shutdown`] when dropped.
///
/// Bind the guard to a named variable (e.g. `_guard`). Writing
/// `let _ = ShutdownGuard::new(..)` drops it, and therefore triggers shutdown,
/// immediately.
#[must_use = "dropping a ShutdownGuard triggers shutdown immediately; bind it to a named variable such as `_guard`"]
pub struct ShutdownGuard {
    shutdown: Arc<GracefulShutdown>,
}

impl ShutdownGuard {
    /// Creates a guard that shuts `shutdown` down when the guard goes out of scope.
    pub fn new(shutdown: Arc<GracefulShutdown>) -> Self {
        Self { shutdown }
    }
}

impl Drop for ShutdownGuard {
    /// Calls [`GracefulShutdown::shutdown`] unconditionally, even if shutdown was
    /// already triggered by other means.
    fn drop(&mut self) {
        self.shutdown.shutdown();
    }
}
/// Extension for racing any future against a [`GracefulShutdown`].
pub trait ShutdownExt: Future + Sized {
    /// Races `self` against shutdown.
    ///
    /// Resolves to `Some(output)` if `self` completes first, or to `None` once shutdown
    /// has been requested (dropping `self`). The returned future owns its own
    /// [`ShutdownToken`] instead of borrowing `shutdown`, so it is `'static`. It can be
    /// passed to `tokio::spawn` or stored independently of the coordinator's borrow.
    fn with_shutdown(
        self,
        shutdown: &GracefulShutdown,
    ) -> Pin<Box<dyn Future<Output = Option<Self::Output>> + Send + 'static>>
    where
        Self: Send + 'static,
        Self::Output: Send,
    {
        let future = self;
        let mut token = shutdown.token();
        Box::pin(async move {
            tokio::select! {
                result = future => Some(result),
                _ = token.cancelled() => None,
            }
        })
    }
}

/// Every future gets [`ShutdownExt::with_shutdown`].
impl<F: Future> ShutdownExt for F {}
/// Spawns tokio tasks that are abandoned once the shared [`GracefulShutdown`] fires.
#[derive(Clone)]
pub struct ShutdownAwareTaskSpawner {
    shutdown: Arc<GracefulShutdown>,
}

impl ShutdownAwareTaskSpawner {
    /// Wraps a shared shutdown coordinator.
    pub fn new(shutdown: Arc<GracefulShutdown>) -> Self {
        ShutdownAwareTaskSpawner { shutdown }
    }

    /// The coordinator that tasks spawned here listen to.
    pub fn shutdown(&self) -> &Arc<GracefulShutdown> {
        &self.shutdown
    }

    /// Spawns the future produced by `future`, dropping it if shutdown is requested first.
    ///
    /// The factory is invoked inside the spawned task. The handle resolves to `()`
    /// whether the task completed or was cancelled.
    pub fn spawn<F, Fut>(&self, task_name: &str, future: F) -> tokio::task::JoinHandle<()>
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send,
    {
        let name = String::from(task_name);
        let mut shutdown_token = self.shutdown.token();
        tokio::spawn(async move {
            #[cfg(feature = "otel")]
            tracing::info!(task = %name, "Starting task");
            let work = future();
            tokio::select! {
                _ = work => {
                    #[cfg(feature = "otel")]
                    tracing::info!(task = %name, "Task completed normally");
                }
                _ = shutdown_token.cancelled() => {
                    #[cfg(feature = "otel")]
                    tracing::info!(task = %name, "Task cancelled due to shutdown");
                }
            }
            #[cfg(feature = "otel")]
            tracing::info!(task = %name, "Task finished");
            #[cfg(not(feature = "otel"))]
            let _ = name;
        })
    }

    /// Alias for [`ShutdownAwareTaskSpawner::spawn`].
    pub fn spawn_background<F, Fut>(
        &self,
        task_name: &str,
        future: F,
    ) -> tokio::task::JoinHandle<()>
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send,
    {
        Self::spawn(self, task_name, future)
    }

    /// Like [`ShutdownAwareTaskSpawner::spawn`], but keeps the task's output.
    ///
    /// The handle yields `Some(output)` on completion, or `None` if shutdown was
    /// requested first.
    pub fn spawn_with_result<F, Fut, T>(
        &self,
        task_name: &str,
        future: F,
    ) -> tokio::task::JoinHandle<Option<T>>
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send,
        T: Send + 'static,
    {
        let name = String::from(task_name);
        let mut shutdown_token = self.shutdown.token();
        tokio::spawn(async move {
            #[cfg(feature = "otel")]
            tracing::info!(task = %name, "Starting task");
            let work = future();
            let outcome = tokio::select! {
                value = work => {
                    #[cfg(feature = "otel")]
                    tracing::info!(task = %name, "Task completed normally");
                    Some(value)
                }
                _ = shutdown_token.cancelled() => {
                    #[cfg(feature = "otel")]
                    tracing::info!(task = %name, "Task cancelled due to shutdown");
                    None
                }
            };
            #[cfg(feature = "otel")]
            tracing::info!(task = %name, "Task finished");
            #[cfg(not(feature = "otel"))]
            let _ = name;
            outcome
        })
    }
}
/// Runs an application-supplied cleanup routine as part of shutting down.
///
/// NOTE(review): the coordinator's `timeout` is not applied to the cleanup here, so a
/// cleanup future that never completes will block indefinitely.
pub trait GracefulShutdownExt {
    /// Awaits `cleanup_fn()` and returns its result unchanged.
    ///
    /// # Errors
    ///
    /// Returns whatever error the cleanup future resolves to.
    fn perform_shutdown<F, Fut, E>(
        &self,
        cleanup_fn: F,
    ) -> impl Future<Output = Result<(), E>> + Send
    where
        F: FnOnce() -> Fut + Send,
        Fut: Future<Output = Result<(), E>> + Send,
        E: std::fmt::Display + Send;
}

impl GracefulShutdownExt for GracefulShutdown {
    /// Runs the cleanup, logging progress and failures when the `otel` feature is on.
    async fn perform_shutdown<F, Fut, E>(&self, cleanup_fn: F) -> Result<(), E>
    where
        F: FnOnce() -> Fut + Send,
        Fut: Future<Output = Result<(), E>> + Send,
        E: std::fmt::Display + Send,
    {
        #[cfg(feature = "otel")]
        {
            tracing::info!("Starting graceful shutdown sequence");
            tracing::info!("Running cleanup functions");
        }
        let cleanup_future = cleanup_fn();
        let outcome = cleanup_future.await;
        if let Err(failure) = &outcome {
            #[cfg(feature = "otel")]
            tracing::error!(error = %failure, "Cleanup function failed");
            #[cfg(not(feature = "otel"))]
            let _ = failure;
        }
        #[cfg(feature = "otel")]
        tracing::info!("Graceful shutdown completed");
        outcome
    }
}

impl GracefulShutdownExt for Arc<GracefulShutdown> {
    /// Delegates to the implementation on the shared [`GracefulShutdown`].
    async fn perform_shutdown<F, Fut, E>(&self, cleanup_fn: F) -> Result<(), E>
    where
        F: FnOnce() -> Fut + Send,
        Fut: Future<Output = Result<(), E>> + Send,
        E: std::fmt::Display + Send,
    {
        let inner: &GracefulShutdown = self;
        inner.perform_shutdown(cleanup_fn).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_shutdown_signal_display() {
        assert_eq!(ShutdownSignal::Interrupt.to_string(), "SIGINT");
        assert_eq!(ShutdownSignal::Terminate.to_string(), "SIGTERM");
        assert_eq!(ShutdownSignal::Manual.to_string(), "Manual");
    }
    #[tokio::test]
    async fn test_shutdown_token() {
        let shutdown = GracefulShutdown::new();
        let token = shutdown.token();
        assert!(!token.is_shutdown());
        shutdown.shutdown();
        assert!(token.is_shutdown());
    }
    #[tokio::test]
    async fn test_manual_shutdown() {
        let shutdown = GracefulShutdown::new();
        let mut rx = shutdown.subscribe();
        shutdown.shutdown();
        let signal = rx.recv().await.unwrap();
        assert_eq!(signal, ShutdownSignal::Manual);
    }
    #[tokio::test]
    async fn test_shutdown_callback() {
        use std::sync::atomic::{AtomicBool, Ordering};
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        let shutdown = GracefulShutdown::builder()
            .on_signal(move |_| {
                called_clone.store(true, Ordering::SeqCst);
            })
            .build();
        shutdown.shutdown();
        assert!(called.load(Ordering::SeqCst));
    }
    #[tokio::test]
    async fn test_run_until_shutdown() {
        let shutdown = GracefulShutdown::new();
        let result = shutdown.run_until_shutdown(async { 42 }).await;
        assert_eq!(result, Some(42));
    }
    #[tokio::test]
    async fn test_run_until_shutdown_cancelled() {
        let shutdown = GracefulShutdown::new();
        let token = shutdown.token();
        shutdown.shutdown();
        assert!(token.is_shutdown());
        // A future that never completes must be abandoned because shutdown was
        // already requested; previously this test never called `run_until_shutdown`.
        let result = shutdown
            .run_until_shutdown(std::future::pending::<u32>())
            .await;
        assert_eq!(result, None);
    }
    #[tokio::test]
    async fn test_builder_timeout() {
        let shutdown = GracefulShutdown::builder()
            .timeout(Duration::from_secs(60))
            .build();
        assert_eq!(shutdown.timeout(), Duration::from_secs(60));
    }
    #[tokio::test]
    async fn test_spawn_task() {
        let shutdown = GracefulShutdown::new();
        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let counter_clone = counter.clone();
        let handle = shutdown.spawn("test_task", async move {
            counter_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        });
        let result = handle.await.unwrap();
        assert_eq!(result, Some(()));
        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 1);
    }
    #[tokio::test]
    async fn test_shutdown_aware_spawner_task_completes() {
        let shutdown = Arc::new(GracefulShutdown::new());
        let spawner = ShutdownAwareTaskSpawner::new(shutdown.clone());
        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let counter_clone = counter.clone();
        let handle = spawner.spawn("test_task", move || async move {
            counter_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        });
        handle.await.unwrap();
        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 1);
    }
    #[tokio::test]
    async fn test_shutdown_aware_spawner_task_cancelled() {
        let shutdown = Arc::new(GracefulShutdown::new());
        let spawner = ShutdownAwareTaskSpawner::new(shutdown.clone());
        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let counter_clone = counter.clone();
        let handle = spawner.spawn("long_task", move || async move {
            tokio::time::sleep(Duration::from_secs(60)).await;
            counter_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        });
        shutdown.shutdown();
        handle.await.unwrap();
        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 0);
    }
    #[tokio::test]
    async fn test_shutdown_aware_spawner_with_result() {
        let shutdown = Arc::new(GracefulShutdown::new());
        let spawner = ShutdownAwareTaskSpawner::new(shutdown.clone());
        let handle = spawner.spawn_with_result("compute_task", || async { 42 });
        let result = handle.await.unwrap();
        assert_eq!(result, Some(42));
    }
    #[tokio::test]
    async fn test_shutdown_aware_spawner_with_result_cancelled() {
        let shutdown = Arc::new(GracefulShutdown::new());
        let spawner = ShutdownAwareTaskSpawner::new(shutdown.clone());
        let handle = spawner.spawn_with_result("long_compute", || async {
            tokio::time::sleep(Duration::from_secs(60)).await;
            42
        });
        shutdown.shutdown();
        let result = handle.await.unwrap();
        assert_eq!(result, None);
    }
    #[tokio::test]
    async fn test_shutdown_aware_spawner_clone() {
        let shutdown = Arc::new(GracefulShutdown::new());
        let spawner = ShutdownAwareTaskSpawner::new(shutdown.clone());
        let spawner2 = spawner.clone();
        assert!(Arc::ptr_eq(spawner.shutdown(), spawner2.shutdown()));
    }
    #[tokio::test]
    async fn test_graceful_shutdown_ext_success() {
        let shutdown = GracefulShutdown::new();
        let result: Result<(), &str> = shutdown.perform_shutdown(|| async { Ok(()) }).await;
        assert!(result.is_ok());
    }
    #[tokio::test]
    async fn test_graceful_shutdown_ext_error() {
        let shutdown = GracefulShutdown::new();
        let result: Result<(), &str> = shutdown
            .perform_shutdown(|| async { Err("cleanup failed") })
            .await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), "cleanup failed");
    }
    #[tokio::test]
    async fn test_graceful_shutdown_ext_with_arc() {
        let shutdown = Arc::new(GracefulShutdown::new());
        let result: Result<(), &str> = shutdown.perform_shutdown(|| async { Ok(()) }).await;
        assert!(result.is_ok());
    }
    #[tokio::test]
    async fn test_shutdown_aware_spawner_background() {
        let shutdown = Arc::new(GracefulShutdown::new());
        let spawner = ShutdownAwareTaskSpawner::new(shutdown.clone());
        let counter = Arc::new(std::sync::atomic::AtomicU32::new(0));
        let counter_clone = counter.clone();
        let handle = spawner.spawn_background("bg_task", move || async move {
            counter_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        });
        handle.await.unwrap();
        assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 1);
    }
}