async_shutdown/
shutdown_signal.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::{Arc, Mutex};
4use std::task::{Context, Poll};
5
6use crate::waker_list::WakerToken;
7use crate::{WrapCancel, ShutdownManagerInner};
8
9/// A future to wait for a shutdown signal.
10///
11/// The future completes when the associated [`ShutdownManager`][crate::ShutdownManager] triggers a shutdown.
12///
13/// The shutdown signal can be cloned and sent between threads freely.
14pub struct ShutdownSignal<T: Clone> {
15	pub(crate) inner: Arc<Mutex<ShutdownManagerInner<T>>>,
16	pub(crate) waker_token: Option<WakerToken>,
17}
18
19impl<T: Clone> Clone for ShutdownSignal<T> {
20	fn clone(&self) -> Self {
21		// Clone only the reference to the shutdown manager, not the waker token.
22		// The waker token is personal to each future.
23		Self {
24			inner: self.inner.clone(),
25			waker_token: None,
26		}
27	}
28}
29
30impl<T: Clone> Drop for ShutdownSignal<T> {
31	fn drop(&mut self) {
32		if let Some(token) = self.waker_token.take() {
33			let mut inner = self.inner.lock().unwrap();
34			inner.on_shutdown.deregister(token);
35		}
36	}
37}
38
39impl<T: Clone> ShutdownSignal<T> {
40	/// Wrap a future so that it is cancelled when a shutdown is triggered.
41	///
42	/// The returned future completes with `Err(reason)` containing the shutdown reason if a shutdown is triggered,
43	/// and with `Ok(x)` when the wrapped future completes.
44	///
45	/// The wrapped future is dropped if the shutdown starts before the wrapped future completes.
46	#[inline]
47	pub fn wrap_cancel<F: Future>(&self, future: F) -> WrapCancel<T, F> {
48		WrapCancel {
49			shutdown_signal: self.clone(),
50			future: Ok(future),
51		}
52	}
53}
54
55impl<T: Clone> Future for ShutdownSignal<T> {
56	type Output = T;
57
58	#[inline]
59	fn poll(self: Pin<&mut Self>, context: &mut Context) -> Poll<Self::Output> {
60		let me = self.get_mut();
61		let mut inner = me.inner.lock().unwrap();
62
63		// We're being polled, so we should deregister the waker (if any).
64		if let Some(token) = me.waker_token.take() {
65			inner.on_shutdown.deregister(token);
66		}
67
68		if let Some(reason) = inner.shutdown_reason.clone() {
69			// Shutdown started, so we're ready.
70			Poll::Ready(reason)
71		} else {
72			// We're not ready, so register the waker to wake us on shutdown start.
73			me.waker_token = Some(inner.on_shutdown.register(context.waker().clone()));
74			Poll::Pending
75		}
76	}
77}
78
#[cfg(test)]
mod test {
	use assert2::assert;
	use std::future::Future;
	use std::pin::Pin;
	use std::task::Poll;

	/// Wrapper around a future to poll it only once.
	///
	/// Awaiting a `PollOnce` polls the inner future exactly once with the current task
	/// context. It yields that `Poll` result instead of waiting for the inner future to complete.
	struct PollOnce<'a, F>(&'a mut F);

	impl<'a, F: std::marker::Unpin + Future> Future for PollOnce<'a, F> {
		type Output = Poll<F::Output>;

		fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
			// Always ready: report the inner poll result, whether it is pending or not.
			Poll::Ready(Pin::new(&mut self.get_mut().0).poll(cx))
		}
	}

	/// Poll a future once.
	///
	/// Returns `Poll::Pending` if the future is not ready yet, otherwise `Poll::Ready` with its output.
	async fn poll_once<F: Future + Unpin>(future: &mut F) -> Poll<F::Output> {
		PollOnce(future).await
	}

	/// Wrapping many short-lived futures one after another must reuse the same waker
	/// slot instead of growing the waker list.
	#[tokio::test]
	async fn waker_list_doesnt_grow_infinitely() {
		let shutdown = crate::ShutdownManager::<()>::new();
		for i in 0..100_000 {
			// `yield_now` makes the wrapped future pend once, so a waker gets registered.
			let task = tokio::spawn(shutdown.wrap_cancel(async move {
				tokio::task::yield_now().await;
			}));
			assert!(let Ok(Ok(())) = task.await, "task = {i}");
		}

		// Since we wait for each task to complete before spawning another,
		// the total amount of waker slots used should be only 1.
		let inner = shutdown.inner.lock().unwrap();
		assert!(inner.on_shutdown.total_slots() == 1);
		assert!(inner.on_shutdown.empty_slots() == 1);
	}

	/// A clone of a polled signal must not share its waker token. Each polled signal
	/// holds its own slot and frees it again when it is dropped.
	#[tokio::test]
	async fn cloning_does_not_clone_waker_token() {
		let shutdown = crate::ShutdownManager::<()>::new();

		// A fresh signal has not been polled, so it holds no registration.
		let mut signal = shutdown.wait_shutdown_triggered();
		assert!(let None = &signal.waker_token);

		assert!(let Poll::Pending = poll_once(&mut signal).await);
		assert!(let Some(_) = &signal.waker_token);

		let mut cloned = signal.clone();
		assert!(let None = &cloned.waker_token);
		assert!(let Some(_) = &signal.waker_token);

		assert!(let Poll::Pending = poll_once(&mut cloned).await);
		assert!(let Some(_) = &cloned.waker_token);
		assert!(let Some(_) = &signal.waker_token);

		// Two independent registrations, both slots in use.
		{
			let inner = shutdown.inner.lock().unwrap();
			assert!(inner.on_shutdown.total_slots() == 2);
			assert!(inner.on_shutdown.empty_slots() == 0);
		}

		// Drop the signal before locking: its `Drop` impl locks the same mutex, and
		// `std::sync::Mutex` is not reentrant.
		{
			drop(signal);
			let inner = shutdown.inner.lock().unwrap();
			assert!(inner.on_shutdown.empty_slots() == 1);
		}

		{
			drop(cloned);
			let inner = shutdown.inner.lock().unwrap();
			assert!(inner.on_shutdown.empty_slots() == 2);
		}
	}
}