Skip to main content

conducer/
waiter.rs

1use std::{
2	collections::VecDeque,
3	fmt,
4	future::Future,
5	marker::PhantomData,
6	pin::Pin,
7	sync::{Arc, Weak},
8	task::{Context, Poll, Waker},
9};
10
/// Handle passed to poll functions for registering with [`WaiterList`]s.
///
/// The wrapped [`Waker`] lives behind an [`Arc`] so that a list can hold a
/// [`Weak`] reference to it instead of owning it. Once the last strong
/// reference is gone, the registration can no longer be upgraded and is
/// treated as dead by the list.
#[derive(Debug)]
pub struct Waiter {
	/// Shared waker; lists store `Weak` handles pointing at this allocation.
	waker: Arc<Waker>,
}
15
16impl Waiter {
17	/// Create a new waiter from an async [`Waker`].
18	pub fn new(waker: Waker) -> Self {
19		Self { waker: Arc::new(waker) }
20	}
21
22	/// Create a no-op waiter that discards registrations.
23	///
24	/// Registrations are stored as `Weak<Waker>` refs, so a noop waiter's
25	/// weak ref will just be cleaned up on the next register call.
26	pub fn noop() -> Self {
27		Self {
28			waker: Arc::new(std::task::Waker::noop().clone()),
29		}
30	}
31
32	/// Register this waiter with a [`WaiterList`] for future notification.
33	pub fn register(&self, list: &mut WaiterList) {
34		list.register(self);
35	}
36}
37
/// Queue of weakly-held wakers that are waiting to be notified.
///
/// Entries live in a ring buffer ([`VecDeque`]). Dead entries, whose waiter
/// has been dropped, are removed from the front each time a new waiter is
/// registered.
pub struct WaiterList {
	/// Weak handles to the registered wakers, oldest at the front.
	entries: VecDeque<Weak<Waker>>,
}
44
45impl WaiterList {
46	pub fn new() -> Self {
47		Self {
48			entries: VecDeque::new(),
49		}
50	}
51
52	/// Register a waiter. Cleans up dead entries from the front first.
53	pub fn register(&mut self, waiter: &Waiter) {
54		// Clean up dead entries at front that fail to upgrade
55		while let Some(front) = self.entries.pop_front() {
56			if front.strong_count() == 0 {
57				// Dead entry, skip
58				continue;
59			}
60
61			// Add it to the back so we'll start at a different entry next time.
62			self.entries.push_back(front);
63			break;
64		}
65
66		self.entries.push_back(Arc::downgrade(&waiter.waker));
67	}
68
69	/// Drain all entries into a new [`WaiterList`], leaving this one empty.
70	pub fn take(&mut self) -> Self {
71		Self {
72			entries: std::mem::take(&mut self.entries),
73		}
74	}
75
76	/// Wake all live waiters, consuming the list.
77	pub fn wake(mut self) {
78		for waker in self.entries.drain(..).filter_map(|w| w.upgrade()) {
79			waker.wake_by_ref();
80		}
81	}
82}
83
84impl Default for WaiterList {
85	fn default() -> Self {
86		Self::new()
87	}
88}
89
90impl fmt::Debug for WaiterList {
91	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
92		f.debug_struct("WaiterList").field("len", &self.entries.len()).finish()
93	}
94}
95
96/// Future that drives a poll function, managing waiter lifetime across polls.
97struct WaiterFn<F, R> {
98	poll: F,
99	waiter: Option<Waiter>, // Store the previous waiter to avoid dropping it.
100	_marker: PhantomData<R>,
101}
102
103/// Create a [`Future`] from a poll function that receives a [`Waiter`].
104///
105/// The waiter is kept alive between polls so its registration in a
106/// [`WaiterList`] remains valid until the next poll replaces it.
107pub fn wait<F, R>(poll: F) -> impl Future<Output = R>
108where
109	F: FnMut(&Waiter) -> Poll<R> + Unpin,
110	R: Unpin,
111{
112	WaiterFn {
113		poll,
114		waiter: None,
115		_marker: PhantomData,
116	}
117}
118
119impl<F, R> Future for WaiterFn<F, R>
120where
121	F: FnMut(&Waiter) -> Poll<R> + Unpin,
122	R: Unpin,
123{
124	type Output = R;
125
126	fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<R> {
127		let this = &mut *self;
128		this.waiter = Some(Waiter::new(cx.waker().clone()));
129		(this.poll)(this.waiter.as_ref().unwrap())
130	}
131}