async_shutdown/
shutdown_signal.rs1use std::future::Future;
2use std::pin::Pin;
3use std::sync::{Arc, Mutex};
4use std::task::{Context, Poll};
5
6use crate::waker_list::WakerToken;
7use crate::{WrapCancel, ShutdownManagerInner};
8
9pub struct ShutdownSignal<T: Clone> {
15 pub(crate) inner: Arc<Mutex<ShutdownManagerInner<T>>>,
16 pub(crate) waker_token: Option<WakerToken>,
17}
18
19impl<T: Clone> Clone for ShutdownSignal<T> {
20 fn clone(&self) -> Self {
21 Self {
24 inner: self.inner.clone(),
25 waker_token: None,
26 }
27 }
28}
29
30impl<T: Clone> Drop for ShutdownSignal<T> {
31 fn drop(&mut self) {
32 if let Some(token) = self.waker_token.take() {
33 let mut inner = self.inner.lock().unwrap();
34 inner.on_shutdown.deregister(token);
35 }
36 }
37}
38
39impl<T: Clone> ShutdownSignal<T> {
40 #[inline]
47 pub fn wrap_cancel<F: Future>(&self, future: F) -> WrapCancel<T, F> {
48 WrapCancel {
49 shutdown_signal: self.clone(),
50 future: Ok(future),
51 }
52 }
53}
54
55impl<T: Clone> Future for ShutdownSignal<T> {
56 type Output = T;
57
58 #[inline]
59 fn poll(self: Pin<&mut Self>, context: &mut Context) -> Poll<Self::Output> {
60 let me = self.get_mut();
61 let mut inner = me.inner.lock().unwrap();
62
63 if let Some(token) = me.waker_token.take() {
65 inner.on_shutdown.deregister(token);
66 }
67
68 if let Some(reason) = inner.shutdown_reason.clone() {
69 Poll::Ready(reason)
71 } else {
72 me.waker_token = Some(inner.on_shutdown.register(context.waker().clone()));
74 Poll::Pending
75 }
76 }
77}
78
#[cfg(test)]
mod test {
    use assert2::assert;
    use std::future::Future;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Future adapter that polls the borrowed future exactly once and then
    /// resolves immediately with the outcome of that single poll.
    struct PollOnce<'a, F> {
        target: &'a mut F,
    }

    impl<F> Future for PollOnce<'_, F>
    where
        F: Future + Unpin,
    {
        type Output = Poll<F::Output>;

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            let target = &mut *self.get_mut().target;
            Poll::Ready(Pin::new(target).poll(cx))
        }
    }

    /// Polls `future` a single time from inside an async context and returns
    /// whatever that poll produced (`Ready` or `Pending`).
    async fn poll_once<F: Future + Unpin>(future: &mut F) -> Poll<F::Output> {
        PollOnce { target: future }.await
    }

    /// Wrapping many short-lived tasks must reuse waker slots instead of
    /// accumulating them.
    #[tokio::test]
    async fn waker_list_doesnt_grow_infinitely() {
        let manager = crate::ShutdownManager::<()>::new();
        for i in 0..100_000 {
            let wrapped = manager.wrap_cancel(async move {
                tokio::task::yield_now().await;
            });
            let handle = tokio::spawn(wrapped);
            assert!(let Ok(Ok(())) = handle.await, "task = {i}");
        }

        // With every task finished, the list should hold a single slot, now empty.
        let guard = manager.inner.lock().unwrap();
        assert!(guard.on_shutdown.total_slots() == 1);
        assert!(guard.on_shutdown.empty_slots() == 1);
    }

    /// A cloned signal must not share the original's waker registration.
    #[tokio::test]
    async fn cloning_does_not_clone_waker_token() {
        let manager = crate::ShutdownManager::<()>::new();

        // A signal that has never been polled holds no token.
        let mut original = manager.wait_shutdown_triggered();
        assert!(let None = &original.waker_token);

        // Polling before shutdown registers a waker.
        assert!(let Poll::Pending = poll_once(&mut original).await);
        assert!(let Some(_) = &original.waker_token);

        // The clone starts unregistered; the original keeps its token.
        let mut duplicate = original.clone();
        assert!(let None = &duplicate.waker_token);
        assert!(let Some(_) = &original.waker_token);

        // Polling the clone registers a second, independent waker.
        assert!(let Poll::Pending = poll_once(&mut duplicate).await);
        assert!(let Some(_) = &duplicate.waker_token);
        assert!(let Some(_) = &original.waker_token);

        {
            let guard = manager.inner.lock().unwrap();
            assert!(guard.on_shutdown.total_slots() == 2);
            assert!(guard.on_shutdown.empty_slots() == 0);
        }

        // Dropping each signal frees exactly its own slot.
        drop(original);
        {
            let guard = manager.inner.lock().unwrap();
            assert!(guard.on_shutdown.empty_slots() == 1);
        }

        drop(duplicate);
        {
            let guard = manager.inner.lock().unwrap();
            assert!(guard.on_shutdown.empty_slots() == 2);
        }
    }
}