async_shutdown/
shutdown_complete.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::{Arc, Mutex};
4use std::task::{Context, Poll};
5
6use crate::waker_list::WakerToken;
7use crate::ShutdownManagerInner;
8
9/// Future to wait for a shutdown to complete.
10pub struct ShutdownComplete<T: Clone> {
11	pub(crate) inner: Arc<Mutex<ShutdownManagerInner<T>>>,
12	pub(crate) waker_token: Option<WakerToken>,
13}
14
15impl<T: Clone> Clone for ShutdownComplete<T> {
16	fn clone(&self) -> Self {
17		// Clone only the reference to the shutdown manager, not the waker token.
18		// The waker token is personal to each future.
19		Self {
20			inner: self.inner.clone(),
21			waker_token: None,
22		}
23	}
24}
25
26impl<T: Clone> Drop for ShutdownComplete<T> {
27	fn drop(&mut self) {
28		if let Some(token) = self.waker_token.take() {
29			let mut inner = self.inner.lock().unwrap();
30			inner.on_shutdown_complete.deregister(token);
31		}
32	}
33}
34
35impl<T: Clone> Future for ShutdownComplete<T> {
36	type Output = T;
37
38	#[inline]
39	fn poll(self: Pin<&mut Self>, context: &mut Context) -> Poll<Self::Output> {
40		let me = self.get_mut();
41		let mut inner = me.inner.lock().unwrap();
42
43		// We're being polled, so we should deregister the waker (if any).
44		if let Some(token) = me.waker_token.take() {
45			inner.on_shutdown_complete.deregister(token);
46		}
47
48		// Check if the shutdown is completed.
49		if inner.delay_tokens == 0 {
50			if let Some(reason) = inner.shutdown_reason.clone() {
51				return Poll::Ready(reason);
52			}
53		}
54
55		// We're not ready, so register the waker to wake us on shutdown completion.
56		me.waker_token = Some(inner.on_shutdown_complete.register(context.waker().clone()));
57
58		Poll::Pending
59	}
60}
61
#[cfg(test)]
mod test {
	use assert2::assert;
	use std::future::Future;
	use std::pin::Pin;
	use std::task::Poll;

	/// Future that polls the wrapped future exactly once and always completes with the raw
	/// `Poll` result of that single poll.
	struct SinglePoll<'f, F>(&'f mut F);

	impl<F> Future for SinglePoll<'_, F>
	where
		F: Future + Unpin,
	{
		type Output = Poll<F::Output>;

		fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
			let target = Pin::new(&mut *self.get_mut().0);
			Poll::Ready(target.poll(cx))
		}
	}

	/// Drive `future` through a single poll and hand back whatever it returned.
	async fn poll_once<F: Future + Unpin>(future: &mut F) -> Poll<F::Output> {
		SinglePoll(future).await
	}

	#[tokio::test]
	async fn waker_list_doesnt_grow_infinitely() {
		let shutdown = crate::ShutdownManager::<()>::new();

		// Each iteration polls a fresh future once on its own task (registering a waker)
		// and the future is dropped when that task ends.
		for i in 0..100_000 {
			let mut pending = shutdown.wait_shutdown_complete();
			let handle = tokio::spawn(async move {
				assert!(let Poll::Pending = poll_once(&mut pending).await);
			});
			assert!(let Ok(()) = handle.await, "task = {i}");
		}

		// Every task finished before the next was spawned, so one slot should have been
		// reused for the whole run and be free again now.
		let guard = shutdown.inner.lock().unwrap();
		let wakers = &guard.on_shutdown_complete;
		assert!(wakers.total_slots() == 1);
		assert!(wakers.empty_slots() == 1);
	}

	#[tokio::test]
	async fn cloning_does_not_clone_waker_token() {
		let shutdown = crate::ShutdownManager::<()>::new();

		// A fresh future has no token until it has been polled.
		let mut original = shutdown.wait_shutdown_complete();
		assert!(let None = &original.waker_token);
		assert!(let Poll::Pending = poll_once(&mut original).await);
		assert!(let Some(_) = &original.waker_token);

		// The clone starts without a token and leaves the original's token alone.
		let mut copy = original.clone();
		assert!(let None = &copy.waker_token);
		assert!(let Some(_) = &original.waker_token);

		// Polling the clone registers a second, independent waker.
		assert!(let Poll::Pending = poll_once(&mut copy).await);
		assert!(let Some(_) = &copy.waker_token);
		assert!(let Some(_) = &original.waker_token);

		{
			let guard = shutdown.inner.lock().unwrap();
			assert!(guard.on_shutdown_complete.total_slots() == 2);
			assert!(guard.on_shutdown_complete.empty_slots() == 0);
		}

		// Dropping each future must free exactly its own slot.
		drop(original);
		assert!(shutdown.inner.lock().unwrap().on_shutdown_complete.empty_slots() == 1);

		drop(copy);
		assert!(shutdown.inner.lock().unwrap().on_shutdown_complete.empty_slots() == 2);
	}
}