1use asupersync::{Budget, CancelReason};
46use std::future::Future;
47use std::pin::Pin;
48use std::sync::Arc;
49use std::sync::atomic::{AtomicBool, AtomicU8, Ordering};
50use std::task::{Context, Poll, Waker};
51use std::time::Duration;
52
/// Phases of the shutdown sequence, ordered by their `u8` discriminant.
///
/// The numeric order is significant: the phase is stored in an atomic as
/// its raw discriminant, and the predicates below are defined in terms of
/// "this phase or any later one".
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum ShutdownPhase {
    /// Initial phase; none of the shutdown predicates hold.
    Running = 0,
    /// First shutdown phase: from here on connections are rejected.
    StopAccepting = 1,
    /// From here on requests are rejected as well.
    ShutdownFlagged = 2,
    /// NOTE(review): presumably in-flight work is given time to finish here —
    /// confirm against the code that drives the phases.
    GracePeriod = 3,
    /// NOTE(review): presumably remaining work is cancelled here — confirm.
    Cancelling = 4,
    /// NOTE(review): presumably registered shutdown hooks run here — confirm.
    RunningHooks = 5,
    /// Terminal phase; also the fallback for unknown raw values in `From<u8>`.
    Stopped = 6,
}

impl ShutdownPhase {
    /// Returns `true` for `StopAccepting` and every later phase.
    #[must_use]
    pub fn should_reject_connections(self) -> bool {
        !matches!(self, Self::Running)
    }

    /// Returns `true` for `ShutdownFlagged` and every later phase.
    #[must_use]
    pub fn should_reject_requests(self) -> bool {
        !matches!(self, Self::Running | Self::StopAccepting)
    }

    /// Returns `true` once the phase has left `Running`.
    #[must_use]
    pub fn is_shutting_down(self) -> bool {
        self != Self::Running
    }

    /// Returns `true` only for the terminal `Stopped` phase.
    #[must_use]
    pub fn is_stopped(self) -> bool {
        matches!(self, Self::Stopped)
    }
}

impl From<u8> for ShutdownPhase {
    /// Maps a raw discriminant back to its phase.
    ///
    /// Values `0..=5` map to the matching variant; every other value
    /// (including `6`) yields `Stopped`.
    fn from(value: u8) -> Self {
        // Indexed by discriminant; anything past the table is `Stopped`.
        const BEFORE_STOPPED: [ShutdownPhase; 6] = [
            ShutdownPhase::Running,
            ShutdownPhase::StopAccepting,
            ShutdownPhase::ShutdownFlagged,
            ShutdownPhase::GracePeriod,
            ShutdownPhase::Cancelling,
            ShutdownPhase::RunningHooks,
        ];

        BEFORE_STOPPED
            .get(usize::from(value))
            .copied()
            .unwrap_or(ShutdownPhase::Stopped)
    }
}
116
117struct ShutdownState {
123 phase: AtomicU8,
125 forced: AtomicBool,
127 wakers: parking_lot::Mutex<Vec<Waker>>,
129 hooks: parking_lot::Mutex<Vec<ShutdownHook>>,
131 in_flight: std::sync::atomic::AtomicUsize,
133}
134
135impl ShutdownState {
136 fn new() -> Self {
137 Self {
138 phase: AtomicU8::new(ShutdownPhase::Running as u8),
139 forced: AtomicBool::new(false),
140 wakers: parking_lot::Mutex::new(Vec::new()),
141 hooks: parking_lot::Mutex::new(Vec::new()),
142 in_flight: std::sync::atomic::AtomicUsize::new(0),
143 }
144 }
145
146 fn phase(&self) -> ShutdownPhase {
147 ShutdownPhase::from(self.phase.load(Ordering::Acquire))
148 }
149
150 fn set_phase(&self, phase: ShutdownPhase) {
151 self.phase.store(phase as u8, Ordering::Release);
152 self.wake_all();
153 }
154
155 fn try_advance_phase(&self, from: ShutdownPhase, to: ShutdownPhase) -> bool {
156 self.phase
157 .compare_exchange(from as u8, to as u8, Ordering::AcqRel, Ordering::Acquire)
158 .is_ok()
159 }
160
161 fn is_forced(&self) -> bool {
162 self.forced.load(Ordering::Acquire)
163 }
164
165 fn set_forced(&self) {
166 self.forced.store(true, Ordering::Release);
167 self.wake_all();
168 }
169
170 fn wake_all(&self) {
171 let wakers = std::mem::take(&mut *self.wakers.lock());
172 for waker in wakers {
173 waker.wake();
174 }
175 }
176
177 fn register_waker(&self, waker: &Waker) {
178 let mut wakers = self.wakers.lock();
179 if !wakers.iter().any(|w| w.will_wake(waker)) {
180 wakers.push(waker.clone());
181 }
182 }
183
184 fn increment_in_flight(&self) -> usize {
185 self.in_flight.fetch_add(1, Ordering::AcqRel) + 1
186 }
187
188 fn decrement_in_flight(&self) -> usize {
189 self.in_flight.fetch_sub(1, Ordering::AcqRel) - 1
190 }
191
192 fn in_flight_count(&self) -> usize {
193 self.in_flight.load(Ordering::Acquire)
194 }
195}
196
197#[derive(Clone)]
205pub struct ShutdownController {
206 state: Arc<ShutdownState>,
207}
208
209impl ShutdownController {
210 #[must_use]
212 pub fn new() -> Self {
213 Self {
214 state: Arc::new(ShutdownState::new()),
215 }
216 }
217
218 #[must_use]
220 pub fn subscribe(&self) -> ShutdownReceiver {
221 ShutdownReceiver {
222 state: Arc::clone(&self.state),
223 }
224 }
225
226 #[must_use]
228 pub fn phase(&self) -> ShutdownPhase {
229 self.state.phase()
230 }
231
232 #[must_use]
234 pub fn is_shutting_down(&self) -> bool {
235 self.state.phase().is_shutting_down()
236 }
237
238 #[must_use]
240 pub fn is_forced(&self) -> bool {
241 self.state.is_forced()
242 }
243
244 pub fn shutdown(&self) {
249 let current = self.state.phase();
250 if current == ShutdownPhase::Running {
251 self.state.set_phase(ShutdownPhase::StopAccepting);
252 } else if !self.state.is_forced() {
253 self.state.set_forced();
255 }
256 }
257
258 pub fn force_shutdown(&self) {
260 self.state.set_forced();
261 self.state.set_phase(ShutdownPhase::Cancelling);
262 }
263
264 pub fn advance_phase(&self) -> bool {
268 let current = self.state.phase();
269 let next = match current {
270 ShutdownPhase::Running => ShutdownPhase::StopAccepting,
271 ShutdownPhase::StopAccepting => ShutdownPhase::ShutdownFlagged,
272 ShutdownPhase::ShutdownFlagged => ShutdownPhase::GracePeriod,
273 ShutdownPhase::GracePeriod => ShutdownPhase::Cancelling,
274 ShutdownPhase::Cancelling => ShutdownPhase::RunningHooks,
275 ShutdownPhase::RunningHooks => ShutdownPhase::Stopped,
276 ShutdownPhase::Stopped => return false,
277 };
278 self.state.try_advance_phase(current, next)
279 }
280
281 pub fn register_hook<F>(&self, hook: F)
285 where
286 F: FnOnce() + Send + 'static,
287 {
288 let mut hooks = self.state.hooks.lock();
289 hooks.push(ShutdownHook::Sync(Box::new(hook)));
290 }
291
292 pub fn register_async_hook<F, Fut>(&self, hook: F)
294 where
295 F: FnOnce() -> Fut + Send + 'static,
296 Fut: Future<Output = ()> + Send + 'static,
297 {
298 let mut hooks = self.state.hooks.lock();
299 hooks.push(ShutdownHook::AsyncFactory(Box::new(move || {
300 Box::pin(hook())
301 })));
302 }
303
304 pub fn pop_hook(&self) -> Option<ShutdownHook> {
308 let mut hooks = self.state.hooks.lock();
309 hooks.pop()
310 }
311
312 #[must_use]
314 pub fn hook_count(&self) -> usize {
315 self.state.hooks.lock().len()
316 }
317
318 #[must_use]
322 pub fn track_request(&self) -> InFlightGuard {
323 self.state.increment_in_flight();
324 InFlightGuard {
325 state: Arc::clone(&self.state),
326 }
327 }
328
329 #[must_use]
331 pub fn in_flight_count(&self) -> usize {
332 self.state.in_flight_count()
333 }
334}
335
336impl Default for ShutdownController {
337 fn default() -> Self {
338 Self::new()
339 }
340}
341
342#[derive(Clone)]
350pub struct ShutdownReceiver {
351 state: Arc<ShutdownState>,
352}
353
354impl ShutdownReceiver {
355 pub async fn wait(&self) {
359 ShutdownWaitFuture { state: &self.state }.await
360 }
361
362 #[must_use]
364 pub fn phase(&self) -> ShutdownPhase {
365 self.state.phase()
366 }
367
368 #[must_use]
370 pub fn is_shutting_down(&self) -> bool {
371 self.state.phase().is_shutting_down()
372 }
373
374 #[must_use]
376 pub fn is_forced(&self) -> bool {
377 self.state.is_forced()
378 }
379}
380
381struct ShutdownWaitFuture<'a> {
383 state: &'a ShutdownState,
384}
385
386impl Future for ShutdownWaitFuture<'_> {
387 type Output = ();
388
389 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
390 if self.state.phase().is_shutting_down() {
391 Poll::Ready(())
392 } else {
393 self.state.register_waker(cx.waker());
394 if self.state.phase().is_shutting_down() {
396 Poll::Ready(())
397 } else {
398 Poll::Pending
399 }
400 }
401 }
402}
403
/// Guard representing one tracked in-flight request.
///
/// Created by `ShutdownController::track_request`, which increments the
/// shared in-flight counter; dropping the guard decrements it again.
pub struct InFlightGuard {
    /// Shared state whose in-flight counter this guard accounts for.
    state: Arc<ShutdownState>,
}

impl Drop for InFlightGuard {
    /// Decrements the shared in-flight counter.
    fn drop(&mut self) {
        self.state.decrement_in_flight();
    }
}
420
/// A callback registered to run during shutdown.
pub enum ShutdownHook {
    /// A plain closure that does all of its work when called.
    Sync(Box<dyn FnOnce() + Send>),
    /// A closure that, when called, produces a boxed future for the caller
    /// to drive to completion.
    AsyncFactory(Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send>),
}

impl ShutdownHook {
    /// Runs the hook.
    ///
    /// * `Sync` hooks are executed immediately and `None` is returned.
    /// * `AsyncFactory` hooks have their factory called and the resulting
    ///   future is returned in `Some`; the future itself is not polled here.
    pub fn run(self) -> Option<Pin<Box<dyn Future<Output = ()> + Send>>> {
        let make_future = match self {
            Self::AsyncFactory(factory) => factory,
            Self::Sync(callback) => {
                callback();
                return None;
            }
        };
        Some(make_future())
    }
}
447
448#[derive(Clone)]
454pub struct GracefulConfig {
455 pub grace_period: Duration,
457 pub cleanup_budget: Budget,
459 pub log_events: bool,
461}
462
463impl Default for GracefulConfig {
464 fn default() -> Self {
465 Self {
466 grace_period: Duration::from_secs(30),
467 cleanup_budget: Budget::new()
468 .with_poll_quota(500)
469 .with_deadline(asupersync::Time::from_secs(5)),
470 log_events: true,
471 }
472 }
473}
474
475pub struct GracefulShutdown {
477 receiver: ShutdownReceiver,
478 config: GracefulConfig,
479}
480
481impl GracefulShutdown {
482 #[must_use]
484 pub fn new(receiver: ShutdownReceiver) -> Self {
485 Self {
486 receiver,
487 config: GracefulConfig::default(),
488 }
489 }
490
491 #[must_use]
493 pub fn grace_period(mut self, duration: Duration) -> Self {
494 self.config.grace_period = duration;
495 self
496 }
497
498 #[must_use]
500 pub fn cleanup_budget(mut self, budget: Budget) -> Self {
501 self.config.cleanup_budget = budget;
502 self
503 }
504
505 #[must_use]
507 pub fn log_events(mut self, enabled: bool) -> Self {
508 self.config.log_events = enabled;
509 self
510 }
511
512 pub async fn run<F, T>(self, fut: F) -> ShutdownOutcome<T>
523 where
524 F: Future<Output = T>,
525 {
526 use std::pin::pin;
527
528 let mut fut = pin!(fut);
529
530 futures_executor::block_on(async {
533 loop {
534 if self.receiver.is_shutting_down() {
536 if self.receiver.is_forced() {
537 return ShutdownOutcome::ForcedShutdown;
538 }
539 return ShutdownOutcome::GracefulShutdown;
540 }
541
542 break ShutdownOutcome::Completed(fut.as_mut().await);
546 }
547 })
548 }
549
550 #[must_use]
552 pub fn config(&self) -> &GracefulConfig {
553 &self.config
554 }
555}
556
/// Result of running a future under shutdown supervision.
#[derive(Debug)]
pub enum ShutdownOutcome<T> {
    /// The future finished and produced this value.
    Completed(T),
    /// Shutdown was signalled and was not forced.
    GracefulShutdown,
    /// Shutdown was signalled and was forced.
    ForcedShutdown,
}

impl<T> ShutdownOutcome<T> {
    /// Returns `true` if the outcome carries a completed value.
    #[must_use]
    pub fn is_completed(&self) -> bool {
        match self {
            Self::Completed(_) => true,
            Self::GracefulShutdown | Self::ForcedShutdown => false,
        }
    }

    /// Returns `true` for either shutdown variant.
    #[must_use]
    pub fn is_shutdown(&self) -> bool {
        match self {
            Self::GracefulShutdown | Self::ForcedShutdown => true,
            Self::Completed(_) => false,
        }
    }

    /// Returns `true` only for `ForcedShutdown`.
    #[must_use]
    pub fn is_forced(&self) -> bool {
        match self {
            Self::ForcedShutdown => true,
            Self::Completed(_) | Self::GracefulShutdown => false,
        }
    }

    /// Consumes the outcome, returning the completed value if there is one.
    #[must_use]
    pub fn into_completed(self) -> Option<T> {
        if let Self::Completed(value) = self {
            Some(value)
        } else {
            None
        }
    }
}
596
597#[must_use]
618pub fn subdivide_grace_budget(
619 grace_remaining: Duration,
620 in_flight_count: usize,
621 original_budget: Option<Budget>,
622) -> Budget {
623 use asupersync::Time;
624
625 let count = in_flight_count.max(1);
626 let per_request = grace_remaining / count as u32;
627
628 let deadline_nanos = per_request.as_nanos() as u64;
630 let grace_budget = Budget::new().with_deadline(Time::from_nanos(deadline_nanos));
631
632 match original_budget {
633 Some(original) => original.meet(grace_budget),
634 None => grace_budget,
635 }
636}
637
/// Returns the cancel reason produced by `CancelReason::shutdown()`.
#[must_use]
pub fn shutdown_cancel_reason() -> CancelReason {
    CancelReason::shutdown()
}
647
/// Returns the cancel reason used when the grace period has expired.
///
/// This is `CancelReason::timeout()`, not the shutdown reason.
#[must_use]
pub fn grace_expired_cancel_reason() -> CancelReason {
    CancelReason::timeout()
}
653
/// Implemented by types that can report shutdown status.
pub trait ShutdownAware {
    /// Returns `true` if shutdown is in progress.
    fn is_shutting_down(&self) -> bool;

    /// Returns the current shutdown phase, or `None`.
    // NOTE(review): the meaning of `None` is not visible here — presumably
    // "no shutdown controller attached"; confirm against implementors.
    fn shutdown_phase(&self) -> Option<ShutdownPhase>;
}
666
#[cfg(test)]
mod tests {
    use super::*;

    /// One predicate check per phase boundary.
    #[test]
    fn shutdown_phase_transitions() {
        let running = ShutdownPhase::Running;
        assert!(!running.should_reject_connections());

        let stop_accepting = ShutdownPhase::StopAccepting;
        assert!(stop_accepting.should_reject_connections());

        let flagged = ShutdownPhase::ShutdownFlagged;
        assert!(flagged.should_reject_requests());

        let grace = ShutdownPhase::GracePeriod;
        assert!(grace.is_shutting_down());

        let stopped = ShutdownPhase::Stopped;
        assert!(stopped.is_stopped());
    }

    /// A fresh controller is running; the first `shutdown` stops accepting.
    #[test]
    fn controller_basic() {
        let ctl = ShutdownController::new();
        assert_eq!(ctl.phase(), ShutdownPhase::Running);
        assert!(!ctl.is_shutting_down());

        ctl.shutdown();

        assert_eq!(ctl.phase(), ShutdownPhase::StopAccepting);
        assert!(ctl.is_shutting_down());
    }

    /// The second `shutdown` call escalates to a forced shutdown.
    #[test]
    fn controller_double_shutdown_forces() {
        let ctl = ShutdownController::new();

        ctl.shutdown();
        assert!(!ctl.is_forced());

        ctl.shutdown();
        assert!(ctl.is_forced());
    }

    /// `advance_phase` walks every phase in order and then refuses to move.
    #[test]
    fn controller_advance_phase() {
        let ctl = ShutdownController::new();

        let expected_sequence = [
            ShutdownPhase::StopAccepting,
            ShutdownPhase::ShutdownFlagged,
            ShutdownPhase::GracePeriod,
            ShutdownPhase::Cancelling,
            ShutdownPhase::RunningHooks,
            ShutdownPhase::Stopped,
        ];
        for expected in expected_sequence.iter().copied() {
            assert!(ctl.advance_phase());
            assert_eq!(ctl.phase(), expected);
        }

        assert!(!ctl.advance_phase());
    }

    /// The in-flight count follows guard creation and drop.
    #[test]
    fn in_flight_tracking() {
        let ctl = ShutdownController::new();
        let mut guards = Vec::new();
        assert_eq!(ctl.in_flight_count(), 0);

        guards.push(ctl.track_request());
        assert_eq!(ctl.in_flight_count(), 1);

        guards.push(ctl.track_request());
        assert_eq!(ctl.in_flight_count(), 2);

        // Release the first guard, then the second.
        drop(guards.remove(0));
        assert_eq!(ctl.in_flight_count(), 1);

        drop(guards.remove(0));
        assert_eq!(ctl.in_flight_count(), 0);
    }

    /// Hooks come back out in reverse registration order.
    #[test]
    fn shutdown_hooks_lifo() {
        let ctl = ShutdownController::new();
        let order: Arc<parking_lot::Mutex<Vec<i32>>> =
            Arc::new(parking_lot::Mutex::new(Vec::new()));

        for id in 1..=3 {
            let sink = Arc::clone(&order);
            ctl.register_hook(move || sink.lock().push(id));
        }
        assert_eq!(ctl.hook_count(), 3);

        std::iter::from_fn(|| ctl.pop_hook()).for_each(|hook| {
            hook.run();
        });

        assert_eq!(*order.lock(), vec![3, 2, 1]);
    }

    /// Without an original budget the result still carries a deadline.
    #[test]
    fn subdivide_grace_budget_basic() {
        let budget = subdivide_grace_budget(Duration::from_secs(30), 3, None);

        assert!(budget.deadline.is_some());
    }

    /// With an original budget the result still carries a deadline.
    #[test]
    fn subdivide_grace_budget_respects_original() {
        use asupersync::Time;

        let original = Budget::new().with_deadline(Time::from_secs(5));
        let budget = subdivide_grace_budget(Duration::from_secs(30), 3, Some(original));

        assert!(budget.deadline.is_some());
    }

    /// A subscribed receiver sees the controller's shutdown.
    #[test]
    fn receiver_is_shutting_down() {
        let ctl = ShutdownController::new();
        let rx = ctl.subscribe();
        assert!(!rx.is_shutting_down());

        ctl.shutdown();
        assert!(rx.is_shutting_down());
    }

    /// Accessor results for each outcome variant.
    #[test]
    fn shutdown_outcome_accessors() {
        let completed = ShutdownOutcome::<i32>::Completed(42);
        assert!(completed.is_completed());
        assert!(!completed.is_shutdown());
        assert_eq!(completed.into_completed(), Some(42));

        let graceful = ShutdownOutcome::<i32>::GracefulShutdown;
        assert!(!graceful.is_completed());
        assert!(graceful.is_shutdown());
        assert!(!graceful.is_forced());

        let forced = ShutdownOutcome::<i32>::ForcedShutdown;
        assert!(forced.is_shutdown());
        assert!(forced.is_forced());
    }

    /// The shutdown reason reports shutdown; the grace-expired one does not.
    #[test]
    fn cancel_reasons() {
        let for_shutdown = shutdown_cancel_reason();
        assert!(for_shutdown.is_shutdown());

        let for_expired_grace = grace_expired_cancel_reason();
        assert!(!for_expired_grace.is_shutdown());
    }
}