1use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
41use std::sync::Arc;
42use std::time::{Duration, Instant};
43use tokio::sync::{broadcast, watch, Notify};
44use tracing::{debug, info, warn};
45
/// Default overall shutdown timeout, in seconds.
///
/// NOTE(review): nothing in the visible part of this file reads this constant;
/// presumably it is consumed by callers elsewhere — confirm before removing.
pub const DEFAULT_SHUTDOWN_TIMEOUT_SECS: u64 = 30;

/// Default drain timeout, in seconds, used by `ShutdownController::new`.
pub const DEFAULT_DRAIN_TIMEOUT_SECS: u64 = 30;
51
/// Lifecycle phase of the server with respect to shutdown.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShutdownState {
    /// Normal operation; no shutdown has been requested.
    Running,
    /// Shutdown was requested and in-flight connections are being drained.
    Draining,
    /// Draining has finished (or timed out) and shutdown is complete.
    Stopped,
}

impl std::fmt::Display for ShutdownState {
    /// Writes the lowercase name of the state (`running`, `draining`, `stopped`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            ShutdownState::Running => "running",
            ShutdownState::Draining => "draining",
            ShutdownState::Stopped => "stopped",
        };
        f.write_str(label)
    }
}
72
/// Cheaply cloneable handle that coordinates graceful shutdown.
///
/// All clones share the same state through an `Arc`, so any clone can
/// initiate shutdown or observe its progress.
#[derive(Clone)]
pub struct ShutdownController {
    /// Shared state behind the handle.
    inner: Arc<ShutdownControllerInner>,
}
82
/// State shared by every clone of `ShutdownController`.
struct ShutdownControllerInner {
    /// Set to `true` exactly once, by the first caller of `initiate_shutdown`.
    is_shutting_down: AtomicBool,

    /// Current lifecycle state; written by `set_state`.
    state: std::sync::RwLock<ShutdownState>,

    /// Wakes tasks parked in `wait_for_shutdown`.
    shutdown_notify: Notify,

    /// Sending half of the state watch channel; updated by `set_state`.
    state_tx: watch::Sender<ShutdownState>,
    /// Receiving half of the state watch channel; cloned by `state_receiver`.
    state_rx: watch::Receiver<ShutdownState>,

    /// Broadcast channel on which `()` is sent when shutdown is initiated.
    shutdown_tx: broadcast::Sender<()>,

    /// Number of connections currently registered as in flight.
    active_connections: AtomicU64,

    /// Maximum time `wait_for_drain` waits for connections to finish.
    drain_timeout: Duration,

    /// Instant at which shutdown was initiated, or `None` before that.
    shutdown_started: std::sync::RwLock<Option<Instant>>,
}
109
110impl ShutdownController {
111 pub fn new() -> Self {
113 Self::with_timeout(Duration::from_secs(DEFAULT_DRAIN_TIMEOUT_SECS))
114 }
115
116 pub fn with_timeout(drain_timeout: Duration) -> Self {
118 let (state_tx, state_rx) = watch::channel(ShutdownState::Running);
119 let (shutdown_tx, _) = broadcast::channel(16);
120
121 Self {
122 inner: Arc::new(ShutdownControllerInner {
123 is_shutting_down: AtomicBool::new(false),
124 state: std::sync::RwLock::new(ShutdownState::Running),
125 shutdown_notify: Notify::new(),
126 state_tx,
127 state_rx,
128 shutdown_tx,
129 active_connections: AtomicU64::new(0),
130 drain_timeout,
131 shutdown_started: std::sync::RwLock::new(None),
132 }),
133 }
134 }
135
136 pub fn is_shutting_down(&self) -> bool {
138 self.inner.is_shutting_down.load(Ordering::SeqCst)
139 }
140
141 pub fn state(&self) -> ShutdownState {
143 *self.inner.state.read().unwrap()
144 }
145
146 pub fn state_receiver(&self) -> watch::Receiver<ShutdownState> {
148 self.inner.state_rx.clone()
149 }
150
151 pub fn subscribe(&self) -> broadcast::Receiver<()> {
153 self.inner.shutdown_tx.subscribe()
154 }
155
156 pub async fn initiate_shutdown(&self) {
165 if self
167 .inner
168 .is_shutting_down
169 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
170 .is_err()
171 {
172 debug!("Shutdown already in progress");
173 return;
174 }
175
176 info!("Initiating graceful shutdown");
177
178 *self.inner.shutdown_started.write().unwrap() = Some(Instant::now());
180
181 self.set_state(ShutdownState::Draining);
183
184 #[cfg(target_os = "linux")]
186 systemd_notify_stopping();
187
188 let _ = self.inner.shutdown_tx.send(());
190 self.inner.shutdown_notify.notify_waiters();
191
192 self.wait_for_drain().await;
194
195 self.flush_logs();
197
198 self.set_state(ShutdownState::Stopped);
200
201 info!("Graceful shutdown complete");
202 }
203
204 pub async fn wait_for_shutdown(&self) {
206 if self.is_shutting_down() {
207 return;
208 }
209 self.inner.shutdown_notify.notified().await;
210 }
211
212 pub fn connection_start(&self) {
214 self.inner.active_connections.fetch_add(1, Ordering::SeqCst);
215 }
216
217 pub fn connection_end(&self) {
219 self.inner.active_connections.fetch_sub(1, Ordering::SeqCst);
220 }
221
222 pub fn active_connections(&self) -> u64 {
224 self.inner.active_connections.load(Ordering::SeqCst)
225 }
226
227 pub fn connection_guard(&self) -> ConnectionGuard {
229 self.connection_start();
230 ConnectionGuard {
231 controller: self.clone(),
232 }
233 }
234
235 pub fn drain_timeout(&self) -> Duration {
237 self.inner.drain_timeout
238 }
239
240 pub fn shutdown_elapsed(&self) -> Option<Duration> {
242 self.inner
243 .shutdown_started
244 .read()
245 .unwrap()
246 .map(|started| started.elapsed())
247 }
248
249 pub fn retry_after_secs(&self) -> u64 {
251 match self.shutdown_elapsed() {
252 Some(elapsed) => {
253 let remaining = self.inner.drain_timeout.saturating_sub(elapsed);
254 remaining.as_secs().saturating_add(5) }
256 None => DEFAULT_DRAIN_TIMEOUT_SECS + 5,
257 }
258 }
259
260 fn set_state(&self, state: ShutdownState) {
263 *self.inner.state.write().unwrap() = state;
264 let _ = self.inner.state_tx.send(state);
265 info!("Shutdown state changed to: {}", state);
266 }
267
268 async fn wait_for_drain(&self) {
269 let timeout = self.inner.drain_timeout;
270 let start = Instant::now();
271
272 info!(
273 "Waiting for {} active connections to drain (timeout: {:?})",
274 self.active_connections(),
275 timeout
276 );
277
278 loop {
279 let active = self.active_connections();
280
281 if active == 0 {
282 info!("All connections drained successfully");
283 return;
284 }
285
286 if start.elapsed() >= timeout {
287 warn!(
288 "Drain timeout reached with {} active connections remaining",
289 active
290 );
291 return;
292 }
293
294 tokio::time::sleep(Duration::from_millis(100)).await;
296 }
297 }
298
299 fn flush_logs(&self) {
300 debug!("Flushing logs before shutdown");
303
304 std::thread::sleep(Duration::from_millis(50));
306 }
307}
308
309impl Default for ShutdownController {
310 fn default() -> Self {
311 Self::new()
312 }
313}
314
315pub struct ConnectionGuard {
319 controller: ShutdownController,
320}
321
322impl Drop for ConnectionGuard {
323 fn drop(&mut self) {
324 self.controller.connection_end();
325 }
326}
327
328pub async fn shutdown_signal() {
349 let ctrl_c = async {
350 tokio::signal::ctrl_c()
351 .await
352 .expect("Failed to install Ctrl+C handler");
353 };
354
355 #[cfg(unix)]
356 let terminate = async {
357 use tokio::signal::unix::{signal, SignalKind};
358
359 let mut sigterm =
360 signal(SignalKind::terminate()).expect("Failed to install SIGTERM handler");
361
362 let mut sigint = signal(SignalKind::interrupt()).expect("Failed to install SIGINT handler");
363
364 let mut sighup = signal(SignalKind::hangup()).expect("Failed to install SIGHUP handler");
365
366 tokio::select! {
367 _ = sigterm.recv() => {
368 info!("Received SIGTERM");
369 }
370 _ = sigint.recv() => {
371 info!("Received SIGINT");
372 }
373 _ = sighup.recv() => {
374 info!("Received SIGHUP");
375 }
376 }
377 };
378
379 #[cfg(not(unix))]
380 let terminate = std::future::pending::<()>();
381
382 tokio::select! {
383 _ = ctrl_c => {
384 info!("Received Ctrl+C");
385 }
386 _ = terminate => {}
387 }
388}
389
/// Waits for a termination signal, then runs the controller's graceful
/// shutdown sequence to completion.
///
/// Returns once `controller.initiate_shutdown()` has returned.
pub async fn shutdown_signal_with_controller(controller: ShutdownController) {
    shutdown_signal().await;
    controller.initiate_shutdown().await;
}
417
418#[cfg(target_os = "linux")]
425pub fn systemd_notify_ready() {
426 if let Err(e) = sd_notify("READY=1") {
427 debug!(
428 "Failed to notify systemd ready (may not be running under systemd): {}",
429 e
430 );
431 } else {
432 info!("Notified systemd: READY");
433 }
434}
435
436#[cfg(target_os = "linux")]
440pub fn systemd_notify_stopping() {
441 if let Err(e) = sd_notify("STOPPING=1") {
442 debug!("Failed to notify systemd stopping: {}", e);
443 } else {
444 info!("Notified systemd: STOPPING");
445 }
446}
447
448#[cfg(target_os = "linux")]
452pub fn systemd_notify_status(status: &str) {
453 if let Err(e) = sd_notify(&format!("STATUS={}", status)) {
454 debug!("Failed to notify systemd status: {}", e);
455 }
456}
457
458#[cfg(target_os = "linux")]
463pub fn systemd_watchdog_ping() {
464 if let Err(e) = sd_notify("WATCHDOG=1") {
465 debug!("Failed to send watchdog ping: {}", e);
466 }
467}
468
/// Sends one notification datagram to the socket named by `NOTIFY_SOCKET`.
///
/// `state` is the raw message, e.g. `READY=1`.
///
/// * If `NOTIFY_SOCKET` is unset, nothing is sent and `Ok(())` is returned.
/// * If the value starts with `@`, the rest names a socket in the Linux
///   abstract namespace.
/// * Otherwise the value is treated as a filesystem socket path.
///
/// # Errors
///
/// Returns the underlying I/O error if the socket cannot be created, the
/// address is invalid, or the datagram cannot be sent.
#[cfg(target_os = "linux")]
fn sd_notify(state: &str) -> std::io::Result<()> {
    use std::os::linux::net::SocketAddrExt;
    use std::os::unix::ffi::OsStrExt;
    use std::os::unix::net::{SocketAddr, UnixDatagram};

    // `var_os` keeps socket paths that are not valid UTF-8 usable.
    let target = match std::env::var_os("NOTIFY_SOCKET") {
        Some(value) => value,
        None => return Ok(()),
    };

    let socket = UnixDatagram::unbound()?;
    let payload = state.as_bytes();

    match target.as_bytes().strip_prefix(b"@") {
        Some(name) => {
            // Abstract sockets have no filesystem entry. Building a pathname
            // address from the name would target a file called `name` instead,
            // so the address must be built as an abstract one.
            let addr = SocketAddr::from_abstract_name(name)?;
            socket.send_to_addr(payload, &addr)?;
        }
        None => {
            socket.send_to(payload, &target)?;
        }
    }

    Ok(())
}
506
/// Non-Linux stand-in for the systemd readiness notification: only logs at
/// debug level that the call was skipped.
#[cfg(not(target_os = "linux"))]
pub fn systemd_notify_ready() {
    debug!("systemd_notify_ready: not on Linux, skipping");
}
514
/// Non-Linux stand-in for the systemd stopping notification: only logs at
/// debug level that the call was skipped.
#[cfg(not(target_os = "linux"))]
pub fn systemd_notify_stopping() {
    debug!("systemd_notify_stopping: not on Linux, skipping");
}
520
/// Non-Linux stand-in for the systemd status notification: ignores `_status`
/// and only logs at debug level that the call was skipped.
#[cfg(not(target_os = "linux"))]
pub fn systemd_notify_status(_status: &str) {
    debug!("systemd_notify_status: not on Linux, skipping");
}
526
/// Non-Linux stand-in for the systemd watchdog ping: only logs at debug level
/// that the call was skipped.
#[cfg(not(target_os = "linux"))]
pub fn systemd_watchdog_ping() {
    debug!("systemd_watchdog_ping: not on Linux, skipping");
}
532
533pub async fn watchdog_task(interval: Duration, mut shutdown_rx: broadcast::Receiver<()>) {
538 info!(
539 "Starting systemd watchdog task with {:?} interval",
540 interval
541 );
542
543 loop {
544 tokio::select! {
545 _ = tokio::time::sleep(interval) => {
546 systemd_watchdog_ping();
547 }
548 _ = shutdown_rx.recv() => {
549 info!("Watchdog task stopping due to shutdown");
550 break;
551 }
552 }
553 }
554}
555
/// Serializable health report produced by `ShutdownController::health_status`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct HealthStatus {
    /// Short status label: `ok`, `draining` or `stopped`.
    pub status: String,
    /// `true` only while the server is running normally.
    pub healthy: bool,
    /// Display form of the current `ShutdownState`.
    pub shutdown_state: String,
    /// Number of connections registered as in flight when the report was built.
    pub active_connections: u64,
    /// Whole seconds left in the drain window; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub drain_remaining_secs: Option<u64>,
    /// Suggested client retry delay in seconds; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retry_after_secs: Option<u64>,
}
576
577impl ShutdownController {
578 pub fn health_status(&self) -> HealthStatus {
582 let state = self.state();
583 let active = self.active_connections();
584
585 match state {
586 ShutdownState::Running => HealthStatus {
587 status: "ok".to_string(),
588 healthy: true,
589 shutdown_state: state.to_string(),
590 active_connections: active,
591 drain_remaining_secs: None,
592 retry_after_secs: None,
593 },
594 ShutdownState::Draining => {
595 let drain_remaining = self
596 .shutdown_elapsed()
597 .map(|elapsed| self.drain_timeout().saturating_sub(elapsed).as_secs());
598
599 HealthStatus {
600 status: "draining".to_string(),
601 healthy: false,
602 shutdown_state: state.to_string(),
603 active_connections: active,
604 drain_remaining_secs: drain_remaining,
605 retry_after_secs: Some(self.retry_after_secs()),
606 }
607 }
608 ShutdownState::Stopped => HealthStatus {
609 status: "stopped".to_string(),
610 healthy: false,
611 shutdown_state: state.to_string(),
612 active_connections: active,
613 drain_remaining_secs: Some(0),
614 retry_after_secs: Some(self.retry_after_secs()),
615 },
616 }
617 }
618}
619
620pub mod axum_integration {
627 use super::*;
628 use axum::{
629 body::Body,
630 http::{header, Request, Response, StatusCode},
631 };
632 use std::task::{Context, Poll};
633 use tower::{Layer, Service};
634
635 #[derive(Clone)]
637 pub struct ShutdownLayer {
638 controller: ShutdownController,
639 }
640
641 impl ShutdownLayer {
642 pub fn new(controller: ShutdownController) -> Self {
644 Self { controller }
645 }
646 }
647
648 impl<S> Layer<S> for ShutdownLayer {
649 type Service = ShutdownService<S>;
650
651 fn layer(&self, inner: S) -> Self::Service {
652 ShutdownService {
653 inner,
654 controller: self.controller.clone(),
655 }
656 }
657 }
658
659 #[derive(Clone)]
661 pub struct ShutdownService<S> {
662 inner: S,
663 controller: ShutdownController,
664 }
665
666 impl<S, ReqBody> Service<Request<ReqBody>> for ShutdownService<S>
667 where
668 S: Service<Request<ReqBody>, Response = Response<Body>> + Clone + Send + 'static,
669 S::Future: Send,
670 ReqBody: Send + 'static,
671 {
672 type Response = Response<Body>;
673 type Error = S::Error;
674 type Future = std::pin::Pin<
675 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
676 >;
677
678 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
679 self.inner.poll_ready(cx)
680 }
681
682 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
683 let controller = self.controller.clone();
684 let mut inner = self.inner.clone();
685
686 Box::pin(async move {
687 if controller.is_shutting_down() {
689 let retry_after = controller.retry_after_secs().to_string();
690 let health = controller.health_status();
691 let body = serde_json::to_string(&health).unwrap_or_else(|_| {
692 r#"{"status":"unavailable","healthy":false}"#.to_string()
693 });
694
695 let response = Response::builder()
696 .status(StatusCode::SERVICE_UNAVAILABLE)
697 .header(header::RETRY_AFTER, retry_after)
698 .header(header::CONTENT_TYPE, "application/json")
699 .body(Body::from(body))
700 .unwrap();
701
702 return Ok(response);
703 }
704
705 let _guard = controller.connection_guard();
707
708 inner.call(req).await
710 })
711 }
712 }
713}
714
// Unit tests for the shutdown controller, its connection tracking, and the
// health report.
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh controller is running, not shutting down, with no connections.
    #[tokio::test]
    async fn test_shutdown_controller_new() {
        let controller = ShutdownController::new();
        assert!(!controller.is_shutting_down());
        assert_eq!(controller.state(), ShutdownState::Running);
        assert_eq!(controller.active_connections(), 0);
    }

    // A custom drain timeout is stored and reported back unchanged.
    #[tokio::test]
    async fn test_shutdown_controller_with_timeout() {
        let controller = ShutdownController::with_timeout(Duration::from_secs(60));
        assert_eq!(controller.drain_timeout(), Duration::from_secs(60));
    }

    // Manual start/end calls move the counter up and down by one.
    #[tokio::test]
    async fn test_connection_tracking() {
        let controller = ShutdownController::new();

        controller.connection_start();
        assert_eq!(controller.active_connections(), 1);

        controller.connection_start();
        assert_eq!(controller.active_connections(), 2);

        controller.connection_end();
        assert_eq!(controller.active_connections(), 1);

        controller.connection_end();
        assert_eq!(controller.active_connections(), 0);
    }

    // Guards count a connection while alive and release it when dropped.
    #[tokio::test]
    async fn test_connection_guard() {
        let controller = ShutdownController::new();

        {
            let _guard = controller.connection_guard();
            assert_eq!(controller.active_connections(), 1);

            {
                let _guard2 = controller.connection_guard();
                assert_eq!(controller.active_connections(), 2);
            }

            assert_eq!(controller.active_connections(), 1);
        }

        assert_eq!(controller.active_connections(), 0);
    }

    // With no active connections, shutdown runs straight through to `Stopped`.
    #[tokio::test]
    async fn test_shutdown_initiation() {
        let controller = ShutdownController::with_timeout(Duration::from_millis(100));

        assert!(!controller.is_shutting_down());
        assert_eq!(controller.state(), ShutdownState::Running);

        controller.initiate_shutdown().await;

        assert!(controller.is_shutting_down());
        assert_eq!(controller.state(), ShutdownState::Stopped);
    }

    // Two concurrent shutdown requests must both complete without panicking.
    #[tokio::test]
    async fn test_shutdown_only_once() {
        let controller = ShutdownController::with_timeout(Duration::from_millis(100));

        let controller2 = controller.clone();

        let handle1 = tokio::spawn(async move {
            controller.initiate_shutdown().await;
        });

        let handle2 = tokio::spawn(async move {
            controller2.initiate_shutdown().await;
        });

        let (r1, r2) = tokio::join!(handle1, handle2);
        r1.unwrap();
        r2.unwrap();
    }

    // While running, the report is healthy and carries no retry hint.
    #[tokio::test]
    async fn test_health_status_running() {
        let controller = ShutdownController::new();
        let health = controller.health_status();

        assert!(health.healthy);
        assert_eq!(health.status, "ok");
        assert_eq!(health.shutdown_state, "running");
        assert!(health.retry_after_secs.is_none());
    }

    // A receiver subscribed before shutdown gets the broadcast signal.
    #[tokio::test]
    async fn test_subscribe_and_notify() {
        let controller = ShutdownController::with_timeout(Duration::from_millis(100));
        let mut rx = controller.subscribe();

        let controller2 = controller.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            controller2.initiate_shutdown().await;
        });

        // The outer `Result` is the timeout; `Ok` means `recv` finished in time.
        let result = tokio::time::timeout(Duration::from_secs(1), rx.recv()).await;
        assert!(result.is_ok());
    }

    // The watch receiver starts at `Running` and observes the shutdown change.
    #[tokio::test]
    async fn test_state_receiver() {
        let controller = ShutdownController::with_timeout(Duration::from_millis(100));
        let mut rx = controller.state_receiver();

        assert_eq!(*rx.borrow(), ShutdownState::Running);

        let controller2 = controller.clone();
        tokio::spawn(async move {
            controller2.initiate_shutdown().await;
        });

        rx.changed().await.unwrap();

        // A watch channel keeps only the latest value, so either state may be seen.
        let state = *rx.borrow();
        assert!(state == ShutdownState::Draining || state == ShutdownState::Stopped);
    }

    // Shutdown waits for an active connection and finishes once it is released.
    #[tokio::test]
    async fn test_drain_with_active_connections() {
        let controller = ShutdownController::with_timeout(Duration::from_millis(500));

        let guard = controller.connection_guard();

        let controller2 = controller.clone();
        let shutdown_handle = tokio::spawn(async move {
            controller2.initiate_shutdown().await;
        });

        // Release the connection well before the 500 ms drain timeout.
        tokio::time::sleep(Duration::from_millis(100)).await;
        drop(guard);

        tokio::time::timeout(Duration::from_secs(1), shutdown_handle)
            .await
            .unwrap()
            .unwrap();

        assert_eq!(controller.state(), ShutdownState::Stopped);
    }

    // `Display` yields the lowercase state names.
    #[test]
    fn test_shutdown_state_display() {
        assert_eq!(ShutdownState::Running.to_string(), "running");
        assert_eq!(ShutdownState::Draining.to_string(), "draining");
        assert_eq!(ShutdownState::Stopped.to_string(), "stopped");
    }

    // Before shutdown starts, the retry hint is the 30 s timeout plus the 5 s margin.
    #[test]
    fn test_retry_after_secs() {
        let controller = ShutdownController::with_timeout(Duration::from_secs(30));
        assert_eq!(controller.retry_after_secs(), 35);
    }

    // `None` optional fields are left out of the serialized JSON.
    #[test]
    fn test_health_status_serialization() {
        let status = HealthStatus {
            status: "ok".to_string(),
            healthy: true,
            shutdown_state: "running".to_string(),
            active_connections: 5,
            drain_remaining_secs: None,
            retry_after_secs: None,
        };

        let json = serde_json::to_string(&status).unwrap();
        assert!(json.contains("\"status\":\"ok\""));
        assert!(json.contains("\"healthy\":true"));
        assert!(!json.contains("drain_remaining_secs"));
        assert!(!json.contains("retry_after_secs"));
    }

    // `Default` builds the same controller as `new`.
    #[tokio::test]
    async fn test_default_trait() {
        let controller = ShutdownController::default();
        assert!(!controller.is_shutting_down());
        assert_eq!(
            controller.drain_timeout(),
            Duration::from_secs(DEFAULT_DRAIN_TIMEOUT_SECS)
        );
    }
}