1use asupersync::{Budget, CancelReason};
46use std::future::Future;
47use std::pin::Pin;
48use std::sync::Arc;
49use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
50use std::task::{Context, Poll, Waker};
51use std::time::Duration;
52
/// Ordered stages of a shutdown sequence.
///
/// Discriminants are explicit and strictly increasing so the phase can be stored in an
/// `AtomicU8` and compared numerically ("at or past phase X").
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShutdownPhase {
    /// Normal operation; no shutdown has been requested.
    Running = 0,
    /// First shutdown stage; new connections are rejected from here on.
    StopAccepting = 1,
    /// New requests are rejected from here on.
    ShutdownFlagged = 2,
    /// Grace period for outstanding work (presumably bounded by the configured grace period).
    GracePeriod = 3,
    /// Outstanding work is being cancelled; also the phase requested by a forced shutdown.
    Cancelling = 4,
    /// Registered shutdown hooks are being executed.
    RunningHooks = 5,
    /// Terminal phase.
    Stopped = 6,
}
76
77impl ShutdownPhase {
78 #[must_use]
80 pub fn should_reject_connections(self) -> bool {
81 self as u8 >= Self::StopAccepting as u8
82 }
83
84 #[must_use]
86 pub fn should_reject_requests(self) -> bool {
87 self as u8 >= Self::ShutdownFlagged as u8
88 }
89
90 #[must_use]
92 pub fn is_shutting_down(self) -> bool {
93 self as u8 >= Self::StopAccepting as u8
94 }
95
96 #[must_use]
98 pub fn is_stopped(self) -> bool {
99 self == Self::Stopped
100 }
101}
102
103impl From<u8> for ShutdownPhase {
104 fn from(value: u8) -> Self {
105 match value {
106 0 => Self::Running,
107 1 => Self::StopAccepting,
108 2 => Self::ShutdownFlagged,
109 3 => Self::GracePeriod,
110 4 => Self::Cancelling,
111 5 => Self::RunningHooks,
112 _ => Self::Stopped,
113 }
114 }
115}
116
117struct ShutdownState {
123 phase: AtomicU8,
125 forced: AtomicBool,
127 wakers: parking_lot::Mutex<Vec<Waker>>,
129 hooks: parking_lot::Mutex<Vec<ShutdownHook>>,
131 in_flight: std::sync::atomic::AtomicUsize,
133}
134
135impl ShutdownState {
136 fn new() -> Self {
137 Self {
138 phase: AtomicU8::new(ShutdownPhase::Running as u8),
139 forced: AtomicBool::new(false),
140 wakers: parking_lot::Mutex::new(Vec::new()),
141 hooks: parking_lot::Mutex::new(Vec::new()),
142 in_flight: std::sync::atomic::AtomicUsize::new(0),
143 }
144 }
145
146 fn phase(&self) -> ShutdownPhase {
147 ShutdownPhase::from(self.phase.load(Ordering::Acquire))
148 }
149
150 fn set_phase(&self, phase: ShutdownPhase) {
151 self.phase.store(phase as u8, Ordering::Release);
152 self.wake_all();
153 }
154
155 fn try_advance_phase(&self, from: ShutdownPhase, to: ShutdownPhase) -> bool {
156 self.phase
157 .compare_exchange(from as u8, to as u8, Ordering::AcqRel, Ordering::Acquire)
158 .is_ok()
159 }
160
161 fn is_forced(&self) -> bool {
162 self.forced.load(Ordering::Acquire)
163 }
164
165 fn set_forced(&self) {
166 self.forced.store(true, Ordering::Release);
167 self.wake_all();
168 }
169
170 fn wake_all(&self) {
171 let wakers = std::mem::take(&mut *self.wakers.lock());
172 for waker in wakers {
173 waker.wake();
174 }
175 }
176
177 fn register_waker(&self, waker: &Waker) {
178 let mut wakers = self.wakers.lock();
179 if !wakers.iter().any(|w| w.will_wake(waker)) {
180 wakers.push(waker.clone());
181 }
182 }
183
184 fn increment_in_flight(&self) -> usize {
185 self.in_flight.fetch_add(1, Ordering::AcqRel) + 1
186 }
187
188 fn decrement_in_flight(&self) -> usize {
189 self.in_flight.fetch_sub(1, Ordering::AcqRel) - 1
190 }
191
192 fn in_flight_count(&self) -> usize {
193 self.in_flight.load(Ordering::Acquire)
194 }
195}
196
197#[derive(Clone)]
205pub struct ShutdownController {
206 state: Arc<ShutdownState>,
207}
208
209impl ShutdownController {
210 #[must_use]
212 pub fn new() -> Self {
213 Self {
214 state: Arc::new(ShutdownState::new()),
215 }
216 }
217
218 #[must_use]
220 pub fn subscribe(&self) -> ShutdownReceiver {
221 ShutdownReceiver {
222 state: Arc::clone(&self.state),
223 }
224 }
225
226 #[must_use]
228 pub fn phase(&self) -> ShutdownPhase {
229 self.state.phase()
230 }
231
232 #[must_use]
234 pub fn is_shutting_down(&self) -> bool {
235 self.state.phase().is_shutting_down()
236 }
237
238 #[must_use]
240 pub fn is_forced(&self) -> bool {
241 self.state.is_forced()
242 }
243
244 pub fn shutdown(&self) {
249 let current = self.state.phase();
250 if current == ShutdownPhase::Running {
251 self.state.set_phase(ShutdownPhase::StopAccepting);
252 } else if !self.state.is_forced() {
253 self.state.set_forced();
255 }
256 }
257
258 pub fn force_shutdown(&self) {
260 self.state.set_forced();
261 self.state.set_phase(ShutdownPhase::Cancelling);
262 }
263
264 pub fn advance_phase(&self) -> bool {
268 let current = self.state.phase();
269 let next = match current {
270 ShutdownPhase::Running => ShutdownPhase::StopAccepting,
271 ShutdownPhase::StopAccepting => ShutdownPhase::ShutdownFlagged,
272 ShutdownPhase::ShutdownFlagged => ShutdownPhase::GracePeriod,
273 ShutdownPhase::GracePeriod => ShutdownPhase::Cancelling,
274 ShutdownPhase::Cancelling => ShutdownPhase::RunningHooks,
275 ShutdownPhase::RunningHooks => ShutdownPhase::Stopped,
276 ShutdownPhase::Stopped => return false,
277 };
278 self.state.try_advance_phase(current, next)
279 }
280
281 pub fn register_hook<F>(&self, hook: F)
285 where
286 F: FnOnce() + Send + 'static,
287 {
288 let mut hooks = self.state.hooks.lock();
289 hooks.push(ShutdownHook::Sync(Box::new(hook)));
290 }
291
292 pub fn register_async_hook<F, Fut>(&self, hook: F)
294 where
295 F: FnOnce() -> Fut + Send + 'static,
296 Fut: Future<Output = ()> + Send + 'static,
297 {
298 let mut hooks = self.state.hooks.lock();
299 hooks.push(ShutdownHook::AsyncFactory(Box::new(move || {
300 Box::pin(hook())
301 })));
302 }
303
304 pub fn pop_hook(&self) -> Option<ShutdownHook> {
308 let mut hooks = self.state.hooks.lock();
309 hooks.pop()
310 }
311
312 #[must_use]
314 pub fn hook_count(&self) -> usize {
315 self.state.hooks.lock().len()
316 }
317
318 #[must_use]
322 pub fn track_request(&self) -> InFlightGuard {
323 self.state.increment_in_flight();
324 InFlightGuard {
325 state: Arc::clone(&self.state),
326 }
327 }
328
329 #[must_use]
331 pub fn in_flight_count(&self) -> usize {
332 self.state.in_flight_count()
333 }
334}
335
336impl Default for ShutdownController {
337 fn default() -> Self {
338 Self::new()
339 }
340}
341
342#[derive(Clone)]
350pub struct ShutdownReceiver {
351 state: Arc<ShutdownState>,
352}
353
354impl ShutdownReceiver {
355 pub async fn wait(&self) {
359 ShutdownWaitFuture { state: &self.state }.await
360 }
361
362 #[must_use]
364 pub fn phase(&self) -> ShutdownPhase {
365 self.state.phase()
366 }
367
368 #[must_use]
370 pub fn is_shutting_down(&self) -> bool {
371 self.state.phase().is_shutting_down()
372 }
373
374 #[must_use]
376 pub fn is_forced(&self) -> bool {
377 self.state.is_forced()
378 }
379}
380
381struct ShutdownWaitFuture<'a> {
383 state: &'a ShutdownState,
384}
385
386impl Future for ShutdownWaitFuture<'_> {
387 type Output = ();
388
389 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
390 if self.state.phase().is_shutting_down() {
391 Poll::Ready(())
392 } else {
393 self.state.register_waker(cx.waker());
394 if self.state.phase().is_shutting_down() {
396 Poll::Ready(())
397 } else {
398 Poll::Pending
399 }
400 }
401 }
402}
403
404pub struct InFlightGuard {
412 state: Arc<ShutdownState>,
413}
414
415impl Drop for InFlightGuard {
416 fn drop(&mut self) {
417 self.state.decrement_in_flight();
418 }
419}
420
/// A cleanup action registered with a shutdown controller.
pub enum ShutdownHook {
    /// Closure executed synchronously when the hook is run.
    Sync(Box<dyn FnOnce() + Send>),
    /// Closure that, when the hook is run, produces a boxed future for the caller to drive.
    AsyncFactory(Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send>),
}
432
433impl ShutdownHook {
434 pub fn run(self) -> Option<Pin<Box<dyn Future<Output = ()> + Send>>> {
438 match self {
439 Self::Sync(f) => {
440 f();
441 None
442 }
443 Self::AsyncFactory(f) => Some(f()),
444 }
445 }
446}
447
448#[derive(Clone)]
454pub struct GracefulConfig {
455 pub grace_period: Duration,
457 pub cleanup_budget: Budget,
459 pub log_events: bool,
461}
462
463impl Default for GracefulConfig {
464 fn default() -> Self {
465 Self {
466 grace_period: Duration::from_secs(30),
467 cleanup_budget: Budget::new()
468 .with_poll_quota(500)
469 .with_deadline(asupersync::Time::from_secs(5)),
470 log_events: true,
471 }
472 }
473}
474
475pub struct GracefulShutdown {
477 receiver: ShutdownReceiver,
478 config: GracefulConfig,
479}
480
481impl GracefulShutdown {
482 #[must_use]
484 pub fn new(receiver: ShutdownReceiver) -> Self {
485 Self {
486 receiver,
487 config: GracefulConfig::default(),
488 }
489 }
490
491 #[must_use]
493 pub fn grace_period(mut self, duration: Duration) -> Self {
494 self.config.grace_period = duration;
495 self
496 }
497
498 #[must_use]
500 pub fn cleanup_budget(mut self, budget: Budget) -> Self {
501 self.config.cleanup_budget = budget;
502 self
503 }
504
505 #[must_use]
507 pub fn log_events(mut self, enabled: bool) -> Self {
508 self.config.log_events = enabled;
509 self
510 }
511
512 pub async fn run<F, T>(self, fut: F) -> ShutdownOutcome<T>
519 where
520 F: Future<Output = T>,
521 {
522 use std::pin::pin;
523 use std::task::Poll;
524
525 let mut fut = pin!(fut);
526
527 std::future::poll_fn(|cx| {
528 if let Poll::Ready(v) = fut.as_mut().poll(cx) {
530 return Poll::Ready(ShutdownOutcome::Completed(v));
531 }
532
533 if self.receiver.state.is_forced() {
537 return Poll::Ready(ShutdownOutcome::ForcedShutdown);
538 }
539 if self.receiver.state.phase().is_shutting_down() {
540 return Poll::Ready(ShutdownOutcome::GracefulShutdown);
541 }
542
543 self.receiver.state.register_waker(cx.waker());
545 if self.receiver.state.is_forced() {
547 Poll::Ready(ShutdownOutcome::ForcedShutdown)
548 } else if self.receiver.state.phase().is_shutting_down() {
549 Poll::Ready(ShutdownOutcome::GracefulShutdown)
550 } else {
551 Poll::Pending
552 }
553 })
554 .await
555 }
556
557 #[must_use]
559 pub fn config(&self) -> &GracefulConfig {
560 &self.config
561 }
562}
563
/// Result of [`GracefulShutdown::run`].
#[derive(Debug)]
pub enum ShutdownOutcome<T> {
    /// The wrapped future finished and produced this value.
    Completed(T),
    /// Shutdown was observed before the future finished.
    GracefulShutdown,
    /// A forced shutdown was observed before the future finished.
    ForcedShutdown,
}
574
575impl<T> ShutdownOutcome<T> {
576 #[must_use]
578 pub fn is_completed(&self) -> bool {
579 matches!(self, Self::Completed(_))
580 }
581
582 #[must_use]
584 pub fn is_shutdown(&self) -> bool {
585 matches!(self, Self::GracefulShutdown | Self::ForcedShutdown)
586 }
587
588 #[must_use]
590 pub fn is_forced(&self) -> bool {
591 matches!(self, Self::ForcedShutdown)
592 }
593
594 #[must_use]
596 pub fn into_completed(self) -> Option<T> {
597 match self {
598 Self::Completed(v) => Some(v),
599 _ => None,
600 }
601 }
602}
603
604#[must_use]
625pub fn subdivide_grace_budget(
626 grace_remaining: Duration,
627 in_flight_count: usize,
628 original_budget: Option<Budget>,
629) -> Budget {
630 use asupersync::Time;
631
632 let count = in_flight_count.max(1);
633 let per_request = grace_remaining / count as u32;
634
635 let deadline_nanos = per_request.as_nanos() as u64;
637 let grace_budget = Budget::new().with_deadline(Time::from_nanos(deadline_nanos));
638
639 match original_budget {
640 Some(original) => original.meet(grace_budget),
641 None => grace_budget,
642 }
643}
644
645#[must_use]
651pub fn shutdown_cancel_reason() -> CancelReason {
652 CancelReason::shutdown()
653}
654
655#[must_use]
657pub fn grace_expired_cancel_reason() -> CancelReason {
658 CancelReason::timeout()
659}
660
661pub trait ShutdownAware {
667 fn is_shutting_down(&self) -> bool;
669
670 fn shutdown_phase(&self) -> Option<ShutdownPhase>;
672}
673
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn shutdown_phase_transitions() {
        let running = ShutdownPhase::Running;
        assert!(!running.should_reject_connections());
        assert!(ShutdownPhase::StopAccepting.should_reject_connections());
        assert!(ShutdownPhase::ShutdownFlagged.should_reject_requests());
        assert!(ShutdownPhase::GracePeriod.is_shutting_down());
        assert!(ShutdownPhase::Stopped.is_stopped());
    }

    #[test]
    fn controller_basic() {
        let ctl = ShutdownController::new();
        assert_eq!(ctl.phase(), ShutdownPhase::Running);
        assert!(!ctl.is_shutting_down());

        ctl.shutdown();
        assert_eq!(ctl.phase(), ShutdownPhase::StopAccepting);
        assert!(ctl.is_shutting_down());
    }

    #[test]
    fn controller_double_shutdown_forces() {
        let ctl = ShutdownController::new();

        ctl.shutdown();
        assert!(!ctl.is_forced());

        ctl.shutdown();
        assert!(ctl.is_forced());
    }

    #[test]
    fn controller_advance_phase() {
        let ctl = ShutdownController::new();
        let expected = [
            ShutdownPhase::StopAccepting,
            ShutdownPhase::ShutdownFlagged,
            ShutdownPhase::GracePeriod,
            ShutdownPhase::Cancelling,
            ShutdownPhase::RunningHooks,
            ShutdownPhase::Stopped,
        ];

        for phase in expected.iter().copied() {
            assert!(ctl.advance_phase());
            assert_eq!(ctl.phase(), phase);
        }

        // `Stopped` is terminal.
        assert!(!ctl.advance_phase());
    }

    #[test]
    fn in_flight_tracking() {
        let ctl = ShutdownController::new();
        assert_eq!(ctl.in_flight_count(), 0);

        let mut guards = Vec::new();
        for expected in 1..=2 {
            guards.push(ctl.track_request());
            assert_eq!(ctl.in_flight_count(), expected);
        }

        // Release in acquisition order.
        drop(guards.remove(0));
        assert_eq!(ctl.in_flight_count(), 1);

        drop(guards.remove(0));
        assert_eq!(ctl.in_flight_count(), 0);
    }

    #[test]
    fn shutdown_hooks_lifo() {
        let ctl = ShutdownController::new();
        let order: Arc<parking_lot::Mutex<Vec<i32>>> =
            Arc::new(parking_lot::Mutex::new(Vec::new()));

        for n in 1..=3 {
            let sink = Arc::clone(&order);
            ctl.register_hook(move || sink.lock().push(n));
        }
        assert_eq!(ctl.hook_count(), 3);

        while let Some(hook) = ctl.pop_hook() {
            // Synchronous hooks run inline and never hand back a future.
            assert!(hook.run().is_none());
        }

        assert_eq!(*order.lock(), vec![3, 2, 1]);
    }

    #[test]
    fn subdivide_grace_budget_basic() {
        let budget = subdivide_grace_budget(Duration::from_secs(30), 3, None);
        assert!(budget.deadline.is_some());
    }

    #[test]
    fn subdivide_grace_budget_respects_original() {
        use asupersync::Time;

        let original = Budget::new().with_deadline(Time::from_secs(5));
        let budget = subdivide_grace_budget(Duration::from_secs(30), 3, Some(original));
        assert!(budget.deadline.is_some());
    }

    #[test]
    fn receiver_is_shutting_down() {
        let ctl = ShutdownController::new();
        let rx = ctl.subscribe();
        assert!(!rx.is_shutting_down());

        ctl.shutdown();
        assert!(rx.is_shutting_down());
    }

    #[test]
    fn graceful_shutdown_run_completed() {
        let ctl = ShutdownController::new();
        let runner = GracefulShutdown::new(ctl.subscribe());

        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("runtime must build");
        let outcome = rt.block_on(async { runner.run(async { 42i32 }).await });

        assert!(matches!(outcome, ShutdownOutcome::Completed(42)));
    }

    #[test]
    fn graceful_shutdown_run_graceful_shutdown() {
        let ctl = ShutdownController::new();
        ctl.shutdown();
        let runner = GracefulShutdown::new(ctl.subscribe());

        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("runtime must build");
        let outcome = rt.block_on(async { runner.run(std::future::pending::<i32>()).await });

        assert!(matches!(outcome, ShutdownOutcome::GracefulShutdown));
    }

    #[test]
    fn graceful_shutdown_run_forced_shutdown() {
        let ctl = ShutdownController::new();
        ctl.force_shutdown();
        let runner = GracefulShutdown::new(ctl.subscribe());

        let rt = asupersync::runtime::RuntimeBuilder::current_thread()
            .build()
            .expect("runtime must build");
        let outcome = rt.block_on(async { runner.run(std::future::pending::<i32>()).await });

        assert!(matches!(outcome, ShutdownOutcome::ForcedShutdown));
    }

    #[test]
    fn shutdown_outcome_accessors() {
        let completed: ShutdownOutcome<i32> = ShutdownOutcome::Completed(42);
        assert!(completed.is_completed());
        assert!(!completed.is_shutdown());
        assert_eq!(completed.into_completed(), Some(42));

        let graceful: ShutdownOutcome<i32> = ShutdownOutcome::GracefulShutdown;
        assert!(!graceful.is_completed());
        assert!(graceful.is_shutdown());
        assert!(!graceful.is_forced());

        let forced: ShutdownOutcome<i32> = ShutdownOutcome::ForcedShutdown;
        assert!(forced.is_shutdown());
        assert!(forced.is_forced());
    }

    #[test]
    fn cancel_reasons() {
        assert!(shutdown_cancel_reason().is_shutdown());
        assert!(!grace_expired_cancel_reason().is_shutdown());
    }
}