waitq/
lib.rs

1//! `no_std`, no-alloc implementation of an async wait queue using an intrusive linked list.
2
3#![no_std]
4
5use core::{
6    cell::{Cell, UnsafeCell},
7    pin::Pin,
8    ptr::NonNull,
9    task::{Context, Poll, Waker},
10};
11
/// A payload that notifiers hand to waiters in counted batches.
///
/// [`Fulfillment`] pairs an implementor with a count; the queue uses these two operations to
/// split a batch into single-waiter shares and to merge batches delivered to the same waiter.
pub trait IFulfillment {
    /// Splits off the share destined for exactly one waiter, leaving the remainder in `self`.
    fn take_one(&mut self) -> Self;

    /// Merges `other`, which was issued for `other_count` waiters, into `self`.
    fn append(&mut self, other: Self, other_count: usize);
}
17
18impl IFulfillment for () {
19    fn take_one(&mut self) -> Self {
20        ()
21    }
22
23    fn append(&mut self, _other: Self, _other_count: usize) {}
24}
25
/// A counted batch of fulfillment travelling from a notifier to a waiter.
pub struct Fulfillment<T> {
    /// How many waiters this batch may still satisfy. Callers in this crate pass
    /// `usize::MAX` to mean "as many as needed".
    pub count: usize,
    /// The payload, split into single-waiter shares via [`IFulfillment::take_one`].
    pub inner: T,
}
30
31impl<T: IFulfillment> Fulfillment<T> {
32    pub fn take_one(&mut self) -> T {
33        self.count -= 1;
34        self.inner.take_one()
35    }
36
37    pub fn append(&mut self, other: Self) {
38        self.inner.append(other.inner, other.count);
39        self.count += other.count;
40    }
41}
42
43pub struct WaiterQueue<T> {
44    state: spin::Mutex<WaiterQueueState<T>>,
45    local: thid::ThreadLocal<Local<T>>,
46}
47
48struct WaiterQueueState<T> {
49    front: Option<NonNull<WaiterNode<T>>>,
50    back: Option<NonNull<WaiterNode<T>>>,
51    count: usize,
52}
53
54struct WaiterNode<T> {
55    state: spin::Mutex<WaiterNodeState<T>>,
56
57    // SAFETY: These can only be accessed while the WaiterQueue lock is held.
58    previous: UnsafeCell<Option<NonNull<Self>>>,
59    next: UnsafeCell<Option<NonNull<Self>>>,
60
61    // SAFETY: These can only be accessed from the awaiting task's thread, and that task cannot be
62    // `Send`.
63    local_lifecycle: Cell<WaiterLifecycle>,
64    local_state: UnsafeCell<WaiterNodeState<T>>,
65    local_next: Cell<Option<NonNull<Self>>>,
66    local_prev: Cell<Option<NonNull<Self>>>,
67}
68
69enum WaiterNodeState<T> {
70    Pending,
71    Polled { waker: Waker },
72    Notified { fulfillment: Fulfillment<T> },
73    Releasing,
74}
75
76unsafe impl<T> Send for WaiterQueue<T> {}
77unsafe impl<T> Sync for WaiterQueue<T> {}
78
79impl<T: IFulfillment> WaiterNode<T> {
80    pub fn new() -> Self {
81        Self {
82            previous: UnsafeCell::new(None),
83            next: UnsafeCell::new(None),
84            state: spin::Mutex::new(WaiterNodeState::Pending),
85            local_lifecycle: Cell::new(WaiterLifecycle::Unregistered),
86            local_state: UnsafeCell::new(WaiterNodeState::Pending),
87            local_next: Cell::new(None),
88            local_prev: Cell::new(None),
89        }
90    }
91
92    #[inline]
93    fn with_state<R>(&self, f: impl FnOnce(&mut WaiterNodeState<T>) -> R) -> R {
94        f(&mut self.state.lock())
95    }
96
97    #[inline]
98    fn fulfill(&self, fulfillment: Fulfillment<T>) -> Option<Waker> {
99        self.with_state(|state| Self::fulfill_common(state, fulfillment))
100    }
101
102    #[inline]
103    fn fulfill_local(&self, fulfillment: Fulfillment<T>) -> Option<Waker> {
104        let state = unsafe { &mut *self.local_state.get() };
105        Self::fulfill_common(state, fulfillment)
106    }
107
108    #[inline]
109    fn fulfill_common(
110        state: &mut WaiterNodeState<T>,
111        fulfillment: Fulfillment<T>,
112    ) -> Option<Waker> {
113        match state {
114            WaiterNodeState::Pending => {
115                *state = WaiterNodeState::Notified { fulfillment };
116                None
117            }
118            WaiterNodeState::Polled { .. } => {
119                let WaiterNodeState::Polled { waker } =
120                    core::mem::replace(&mut *state, WaiterNodeState::Notified { fulfillment })
121                else {
122                    unreachable!();
123                };
124                Some(waker)
125            }
126            // A WaiterNode can be notified exactly once.
127            WaiterNodeState::Notified {
128                fulfillment: existing_fulfillment,
129            } => {
130                existing_fulfillment.append(fulfillment);
131                None
132            }
133            // A WaiterNode shouldn't be reachable after it's been dropped.
134            WaiterNodeState::Releasing => unreachable!(),
135        }
136    }
137}
138
139struct Local<T> {
140    nodes: Cell<Option<(NonNull<WaiterNode<T>>, NonNull<WaiterNode<T>>)>>,
141    count: Cell<usize>,
142}
143
144unsafe impl<T> Send for Local<T> {}
145
146impl<T> Default for Local<T> {
147    fn default() -> Self {
148        Self {
149            nodes: Cell::new(None),
150            count: Cell::new(0),
151        }
152    }
153}
154
155impl<T> Local<T> {
156    #[inline]
157    fn add_node(&self, new_node: NonNull<WaiterNode<T>>) {
158        self.count.set(self.count.get() + 1);
159        if let Some((head, tail)) = self.nodes.get() {
160            unsafe { new_node.as_ref() }.local_prev.set(Some(tail));
161            unsafe { tail.as_ref() }.local_next.set(Some(new_node));
162            self.nodes.set(Some((head, new_node)));
163        } else {
164            // List is empty
165            self.nodes.set(Some((new_node, new_node)));
166            debug_assert_eq!(self.count.get(), 1);
167        }
168    }
169
170    #[inline]
171    fn remove_node(&self, to_remove: &WaiterNode<T>) {
172        let Some((head, tail)) = self.nodes.get() else {
173            // List is empty, so to_remove is definitely not in the list.
174            return;
175        };
176
177        let prev = to_remove.local_prev.replace(None);
178        let next = to_remove.local_next.replace(None);
179
180        if prev.is_none() && next.is_none() && head != NonNull::from(to_remove) {
181            // This node is not in the local list
182            return;
183        }
184        self.count.set(self.count.get() - 1);
185
186        if let Some(next) = next {
187            unsafe { next.as_ref() }.local_prev.set(prev);
188        } else {
189            // This is the tail node.
190            debug_assert_eq!(NonNull::from(to_remove), tail);
191            if let Some(prev) = prev {
192                unsafe { prev.as_ref() }.local_next.set(None);
193                self.nodes.set(Some((head, prev)));
194            } else {
195                // This is the only node.
196                debug_assert_eq!(head, tail);
197                debug_assert_eq!(self.count.get(), 0);
198                self.nodes.set(None);
199            }
200            return;
201        }
202
203        if let Some(prev) = prev {
204            unsafe { prev.as_ref() }.local_next.set(next);
205        } else {
206            // This is the head node.
207            debug_assert_eq!(NonNull::from(to_remove), head);
208            if let Some(next) = next {
209                self.nodes.set(Some((next, tail)));
210            } else {
211                // This is the only node.
212                debug_assert_eq!(head, tail);
213                debug_assert_eq!(self.count.get(), 0);
214                self.nodes.set(None);
215                return;
216            }
217        }
218    }
219
220    #[inline]
221    fn pop_node(&self) -> Option<NonNull<WaiterNode<T>>> {
222        let (head, tail) = self.nodes.take()?;
223        self.count.set(self.count.get() - 1);
224        if head != tail {
225            let new_head = unsafe { head.as_ref() }.local_next.take().unwrap();
226            unsafe { new_head.as_ref() }.local_prev.set(None);
227            self.nodes.set(Some((new_head, tail)));
228        } else {
229            debug_assert_eq!(self.count.get(), 0);
230        }
231
232        Some(head)
233    }
234}
235
236pub struct WaiterQueueGuard<'a, T> {
237    state: spin::MutexGuard<'a, WaiterQueueState<T>>,
238}
239
240impl<T> WaiterQueueGuard<'_, T> {
241    pub fn waiter_count(&self) -> usize {
242        self.state.count
243    }
244}
245
246impl<T: IFulfillment> WaiterQueue<T> {
247    pub fn new() -> Self {
248        Self {
249            state: spin::Mutex::new(WaiterQueueState {
250                front: None,
251                back: None,
252                count: 0,
253            }),
254            local: thid::ThreadLocal::new(),
255        }
256    }
257
258    pub fn lock(&self) -> WaiterQueueGuard<'_, T> {
259        WaiterQueueGuard {
260            state: self.state.lock(),
261        }
262    }
263
264    pub fn notify_one_local(&self, fulfillment: T) -> Option<T> {
265        let local = self.local.get_or_default();
266        let Some((local_head, _)) = local.nodes.get() else {
267            return Some(fulfillment);
268        };
269
270        debug_assert!(unsafe { local_head.as_ref() }.local_prev.get().is_none());
271        debug_assert_eq!(
272            unsafe { local_head.as_ref() }.local_lifecycle.get(),
273            WaiterLifecycle::Registered,
274        );
275
276        let fulfillment = Fulfillment {
277            inner: fulfillment,
278            count: 1,
279        };
280
281        let mut guard = self.lock();
282        if guard.remove_waiter(local_head) {
283            // This waiter hasn't been notified yet. Convert it to a local registration and
284            // fulfill it. Also upgrade the next waiter to be in the shared queue.
285
286            local.pop_node();
287
288            if let Some((new_head, _)) = local.nodes.get() {
289                Self::upgrade_local_waiter(&mut guard, new_head);
290            }
291            drop(guard);
292
293            unsafe { local_head.as_ref() }
294                .local_lifecycle
295                .set(WaiterLifecycle::RegisteredLocal);
296            if let Some(waker) = unsafe { local_head.as_ref() }.fulfill_local(fulfillment) {
297                waker.wake();
298            }
299        } else {
300            // This waiter was already notified by another thread but hasn't been polled yet. We
301            // can tack this fulfillment on - when the waiter is polled, the extra fulfillments
302            // will be used to notify any local waiters, and the next local waiter in line will be
303            // upgraded to the shared queue.
304
305            drop(guard);
306
307            if let Some(waker) = unsafe { local_head.as_ref() }.fulfill(fulfillment) {
308                waker.wake();
309            }
310        }
311
312        None
313    }
314
315    fn remove_local_waiter(&self, to_remove: &WaiterNode<T>) {
316        let local = self.local.get_or_default();
317        local.remove_node(to_remove);
318    }
319
320    fn upgrade_local_waiter(guard: &mut WaiterQueueGuard<'_, T>, waiter: NonNull<WaiterNode<T>>) {
321        debug_assert_eq!(
322            unsafe { waiter.as_ref() }.local_lifecycle.get(),
323            WaiterLifecycle::RegisteredLocal,
324        );
325
326        // Before registering the waiter in the shared queue, set its shared state from
327        // its local state in case it was already polled (so that the waker is properly set).
328        let waiter_ref = unsafe { waiter.as_ref() };
329        *waiter_ref.state.lock() = match unsafe { &*waiter_ref.local_state.get() } {
330            WaiterNodeState::Pending => WaiterNodeState::Pending,
331            WaiterNodeState::Polled { waker } => WaiterNodeState::Polled {
332                waker: waker.clone(),
333            },
334            WaiterNodeState::Notified { .. } => unreachable!(),
335            WaiterNodeState::Releasing => unreachable!(),
336        };
337
338        guard.add_waiter(waiter);
339        waiter_ref.local_lifecycle.set(WaiterLifecycle::Registered);
340    }
341}
342
343impl WaiterQueue<()> {
344    /// Returns true if a waiter was notified.
345    #[inline]
346    pub fn notify_one(&self) -> bool {
347        self.lock().notify((), 1).is_none()
348    }
349
350    #[inline]
351    pub fn notify_all(&self) -> usize {
352        self.lock().notify_all(())
353    }
354
355    /// # Safety
356    ///
357    /// You must cancel the returned waiter before it is dropped.
358    #[inline]
359    pub unsafe fn wait(&self) -> Waiter<'_, ()> {
360        Waiter::new(&self)
361    }
362
363    #[inline]
364    pub async fn wait_for<R>(&self, mut condition: impl FnMut() -> Option<R>) -> R {
365        if let Some(r) = condition() {
366            return r;
367        }
368
369        let result = Cell::new(None);
370        loop {
371            let wait_until = core::pin::pin!(WaitUntil {
372                // SAFETY: WaitUntil::drop cancels the Waiter if necessary.
373                waiter: unsafe { self.wait() },
374                condition: UnsafeCell::new(|| {
375                    if let Some(r) = condition() {
376                        result.set(Some(r));
377                        true
378                    } else {
379                        false
380                    }
381                }),
382            });
383            core::future::poll_fn(|cx| wait_until.as_ref().poll(cx)).await;
384
385            if let Some(r) = result.take() {
386                return r;
387            }
388        }
389    }
390
391    #[inline]
392    pub async fn wait_until(&self, condition: impl Fn() -> bool) {
393        let condition = &condition;
394        loop {
395            if condition() {
396                return;
397            }
398
399            let wait_until = core::pin::pin!(WaitUntil {
400                // SAFETY: WaitUntil::drop cancels the Waiter if necessary.
401                waiter: unsafe { self.wait() },
402                condition: UnsafeCell::new(condition),
403            });
404            core::future::poll_fn(|cx| wait_until.as_ref().poll(cx)).await;
405        }
406    }
407}
408
409impl<T: IFulfillment> WaiterQueueGuard<'_, T> {
410    pub fn notify(mut self, fulfillment: T, count: usize) -> Option<Fulfillment<T>> {
411        let Some(front_ptr) = self.state.front else {
412            // There are currently no waiters.
413            return Some(Fulfillment {
414                count,
415                inner: fulfillment,
416            });
417        };
418
419        self.state.count -= 1;
420
421        // Advance the front cursor
422        let next_ptr = core::mem::replace(unsafe { &mut *front_ptr.as_ref().next.get() }, None);
423        self.state.front = next_ptr;
424
425        if let Some(new_front_ptr) = self.state.front {
426            unsafe { *new_front_ptr.as_ref().previous.get() = None };
427        } else {
428            debug_assert_eq!(Some(front_ptr), self.state.back);
429            debug_assert!(unsafe { *front_ptr.as_ref().previous.get() }.is_none());
430
431            // We've reached the end of the waiter list - clear the `back` pointer.
432            self.state.back = None;
433        }
434        // Release the waiter queue lock before waking.
435        drop(self);
436
437        let maybe_waker = unsafe { front_ptr.as_ref() }.fulfill(Fulfillment {
438            inner: fulfillment,
439            count,
440        });
441        if let Some(waker) = maybe_waker {
442            waker.wake();
443        }
444
445        None
446    }
447
448    /// Returns true if the waiter was removed.
449    /// Returns false if the waiter had already been removed before this call.
450    fn remove_waiter(&mut self, node: NonNull<WaiterNode<T>>) -> bool {
451        let prev = unsafe { *node.as_ref().previous.get() };
452        let next = unsafe { *node.as_ref().next.get() };
453
454        if prev.is_none() && next.is_none() && self.state.front != Some(node) {
455            // This waiter has already been removed.
456            return false;
457        }
458
459        self.state.count -= 1;
460
461        unsafe {
462            *node.as_ref().next.get() = None;
463        }
464        unsafe {
465            *node.as_ref().previous.get() = None;
466        }
467
468        // Check if we are removing the back node, move `back` pointer to earlier waiter in line.
469        if Some(node) == self.state.back {
470            self.state.back = prev;
471            debug_assert!(next.is_none());
472        }
473
474        if Some(node) == self.state.front {
475            // We are removing the front node
476            self.state.front = next;
477            if let Some(next) = next {
478                // SAFETY: the shared lock protects access to all `previous` values.
479                unsafe {
480                    *next.as_ref().previous.get() = None;
481                }
482            } else {
483                debug_assert!(self.state.back.is_none());
484            }
485        } else if let Some(prev) = prev {
486            // Previous node is guaranteed not to be the back node, and we have the shared lock, so
487            // we have exclusive access to `next`.
488            unsafe { *prev.as_ref().next.get() = next };
489            if let Some(next) = next {
490                // SAFETY: the shared lock protects access to all `previous` values.
491                unsafe {
492                    *next.as_ref().previous.get() = Some(prev);
493                }
494            }
495        }
496
497        true
498    }
499
500    fn add_waiter(&mut self, new_node: NonNull<WaiterNode<T>>) {
501        let state = &mut self.state;
502        state.count += 1;
503
504        debug_assert!(unsafe { (*new_node.as_ref().next.get()).is_none() });
505        debug_assert!(unsafe { (*new_node.as_ref().previous.get()).is_none() });
506
507        let prev_back = core::mem::replace(&mut state.back, Some(new_node));
508        if let Some(prev_back) = prev_back {
509            unsafe {
510                // Set my node's previous node.
511                *new_node.as_ref().previous.get() = Some(prev_back);
512                // Link my node as next after the previous `back`.
513                *prev_back.as_ref().next.get() = Some(new_node);
514            }
515        } else {
516            // We are the first in line - set the queue's front.
517            state.front = Some(new_node);
518            debug_assert!(unsafe { &*new_node.as_ref().next.get() }.is_none());
519            debug_assert!(unsafe { &*new_node.as_ref().previous.get() }.is_none());
520        }
521    }
522}
523
524impl<T: IFulfillment + Copy> WaiterQueueGuard<'_, T> {
525    pub fn notify_all(&mut self, fulfillment: T) -> usize {
526        let mut notified_count = 0;
527
528        while let Some(front_ptr) = self.state.front {
529            notified_count += 1;
530            self.state.count -= 1;
531
532            // Advance the front cursor
533            let next_ptr = core::mem::replace(unsafe { &mut *front_ptr.as_ref().next.get() }, None);
534            self.state.front = next_ptr;
535
536            if let Some(new_front_ptr) = self.state.front {
537                unsafe { *new_front_ptr.as_ref().previous.get() = None };
538            } else {
539                debug_assert_eq!(Some(front_ptr), self.state.back);
540                debug_assert!(unsafe { *front_ptr.as_ref().previous.get() }.is_none());
541
542                // We've reached the end of the waiter list - clear the `back` pointer.
543                self.state.back = None;
544            }
545
546            let maybe_waker = unsafe { front_ptr.as_ref() }.fulfill(Fulfillment {
547                inner: fulfillment,
548                count: usize::MAX,
549            });
550            if let Some(waker) = maybe_waker {
551                waker.wake();
552            }
553        }
554
555        notified_count
556    }
557}
558
/// Where a waiter stands with respect to its queue; only read or written on the waiting thread.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum WaiterLifecycle {
    /// Created but not yet registered anywhere.
    Unregistered,
    /// Head of its thread's local list and linked into the shared queue.
    Registered,
    /// Present only in its thread's local list, or already fulfilled locally.
    RegisteredLocal,
    /// Fulfilled or cancelled; further registration attempts are no-ops.
    Releasing,
}
566
567pub struct Waiter<'a, T: IFulfillment> {
568    waiter_queue: &'a WaiterQueue<T>,
569    waiter_node: UnsafeCell<WaiterNode<T>>,
570}
571
572impl<'a, T: IFulfillment> Waiter<'a, T> {
573    pub fn new(waiter_queue: &'a WaiterQueue<T>) -> Self {
574        Self {
575            waiter_queue,
576            waiter_node: UnsafeCell::new(WaiterNode::new()),
577        }
578    }
579
580    #[inline]
581    fn lifecycle(&self) -> WaiterLifecycle {
582        unsafe { &*self.waiter_node.get() }.local_lifecycle.get()
583    }
584
585    #[inline]
586    fn set_lifecycle(&self, new_value: WaiterLifecycle) {
587        unsafe { &*self.waiter_node.get() }
588            .local_lifecycle
589            .set(new_value);
590    }
591
592    #[inline]
593    fn register(
594        self: Pin<&Self>,
595        mut try_fulfill: impl FnMut() -> Option<Fulfillment<T>>,
596    ) -> Option<Fulfillment<T>> {
597        if self.lifecycle() != WaiterLifecycle::Unregistered {
598            return None;
599        }
600
601        let local = self.waiter_queue.local.get_or_default();
602        let waiter_node_ptr = NonNull::from(unsafe { &*self.waiter_node.get() });
603
604        if local.nodes.get().is_some() {
605            self.set_lifecycle(WaiterLifecycle::RegisteredLocal);
606            local.add_node(waiter_node_ptr);
607            None
608        } else {
609            // Try to fulfill the waiter with the waiter queue locked before registering it.
610            let mut guard = self.waiter_queue.lock();
611            if let Some(fulfillment) = try_fulfill() {
612                drop(guard);
613                Some(fulfillment)
614            } else {
615                guard.add_waiter(waiter_node_ptr);
616                self.set_lifecycle(WaiterLifecycle::Registered);
617                local.add_node(waiter_node_ptr);
618                None
619            }
620        }
621    }
622
623    pub fn cancel(&self) -> Option<Fulfillment<T>> {
624        match self.lifecycle() {
625            WaiterLifecycle::Registered => {
626                self.set_lifecycle(WaiterLifecycle::Releasing);
627
628                // Lock waker queue to prevent this node from being notified if it wasn't already notified.
629                let mut waiter_queue_guard = self.waiter_queue.lock();
630
631                let waiter_node = unsafe { &*self.waiter_node.get() };
632                let mut state = waiter_node.state.lock();
633                match core::mem::replace(&mut *state, WaiterNodeState::Releasing) {
634                    WaiterNodeState::Notified { fulfillment } => {
635                        self.waiter_queue.remove_local_waiter(waiter_node);
636                        Some(fulfillment)
637                    }
638                    // Fulfillment was already processed, so this waiter has already been deregistered.
639                    WaiterNodeState::Releasing => None,
640                    _ => {
641                        // Deregister the waiter.
642                        waiter_queue_guard.remove_waiter(NonNull::from(waiter_node));
643                        self.waiter_queue.remove_local_waiter(waiter_node);
644                        None
645                    }
646                }
647            }
648            WaiterLifecycle::RegisteredLocal => {
649                self.set_lifecycle(WaiterLifecycle::Releasing);
650
651                let waiter_node = unsafe { &*self.waiter_node.get() };
652                let state = unsafe { &mut *waiter_node.local_state.get() };
653                match core::mem::replace(&mut *state, WaiterNodeState::Releasing) {
654                    WaiterNodeState::Notified { fulfillment } => {
655                        // Local waiters are deregistered immediately when fulfilled, so don't need to deregister here.
656                        Some(fulfillment)
657                    }
658                    // Fulfillment was already processed, so this waiter has already been deregistered.
659                    WaiterNodeState::Releasing => None,
660                    _ => {
661                        // Deregister the waiter.
662                        self.waiter_queue.remove_local_waiter(waiter_node);
663                        None
664                    }
665                }
666            }
667            _ => None,
668        }
669    }
670
671    pub fn poll_fulfillment(
672        self: Pin<&'_ Self>,
673        context: &'_ mut Context<'_>,
674        mut try_fulfill: impl FnMut() -> Option<Fulfillment<T>>,
675    ) -> Poll<Fulfillment<T>> {
676        if let Some(fulfillment) = self.as_ref().register(&mut try_fulfill) {
677            return Poll::Ready(fulfillment);
678        }
679
680        let waiter_node = unsafe { &*self.waiter_node.get() };
681
682        let update_state = |state: &mut WaiterNodeState<T>| {
683            let mut maybe_fulfillment = None;
684            let state_ptr = &mut *state as *mut WaiterNodeState<T>;
685            let taken_state = unsafe { core::ptr::read(state_ptr) };
686
687            // SAFETY: the match block below must not panic.
688            let new_state = match taken_state {
689                WaiterNodeState::Pending => WaiterNodeState::Polled {
690                    waker: context.waker().clone(),
691                },
692                WaiterNodeState::Polled { waker } => {
693                    let new_waker = context.waker();
694                    if !waker.will_wake(new_waker) {
695                        WaiterNodeState::Polled {
696                            waker: new_waker.clone(),
697                        }
698                    } else {
699                        WaiterNodeState::Polled { waker }
700                    }
701                }
702                WaiterNodeState::Notified { fulfillment } => {
703                    maybe_fulfillment = Some(fulfillment);
704                    WaiterNodeState::Releasing
705                }
706                WaiterNodeState::Releasing => unreachable!(),
707            };
708
709            unsafe {
710                state_ptr.write(new_state);
711            }
712
713            maybe_fulfillment
714        };
715
716        // Always poll the local state - a waiter can be notified locally even if it is registered
717        // in the shared queue.
718        let local_state = unsafe { &mut *waiter_node.local_state.get() };
719        if let Some(fulfillment) = update_state(local_state) {
720            //debug_assert_eq!(fulfillment.count, 1);
721            debug_assert_eq!(self.lifecycle(), WaiterLifecycle::RegisteredLocal);
722            self.set_lifecycle(WaiterLifecycle::Releasing);
723            return Poll::Ready(fulfillment);
724        }
725
726        if self.as_ref().lifecycle() == WaiterLifecycle::Registered {
727            // This waiter is registered in the shared queue.
728
729            if let Some(mut fulfillment) = waiter_node.with_state(update_state) {
730                let waiter_queue_local = self.as_ref().waiter_queue.local.get_or_default();
731                let popped_head = waiter_queue_local.pop_node();
732                debug_assert_eq!(popped_head, Some(NonNull::from(waiter_node)));
733
734                self.set_lifecycle(WaiterLifecycle::Releasing);
735
736                // Use extra fulfillments to notify any local waiters.
737                while fulfillment.count > 1 {
738                    let Some(local_next) = waiter_queue_local.pop_node() else {
739                        break;
740                    };
741                    debug_assert!(unsafe { local_next.as_ref() }.local_prev.get().is_none());
742                    debug_assert_eq!(
743                        unsafe { local_next.as_ref() }.local_lifecycle.get(),
744                        WaiterLifecycle::RegisteredLocal,
745                    );
746
747                    if let Some(waker) = unsafe { local_next.as_ref() }.fulfill_local(Fulfillment {
748                        inner: fulfillment.take_one(),
749                        count: 1,
750                    }) {
751                        waker.wake();
752                    }
753                }
754
755                // Upgrade the next local waiter to be in the shared queue.
756                if let Some((local_head, local_tail)) = waiter_queue_local.nodes.get() {
757                    let mut guard = self.as_ref().waiter_queue.lock();
758
759                    // Acquire as many fulfillments as possible with the shared queue locked.
760                    while let Some(new_fulfillment) = try_fulfill() {
761                        fulfillment.append(new_fulfillment);
762                        if fulfillment.count > waiter_queue_local.count.get() {
763                            break;
764                        }
765                    }
766
767                    if fulfillment.count == 1 {
768                        // Can't notify any additional local waiters - upgrade the next local
769                        // waiter.
770                        WaiterQueue::<T>::upgrade_local_waiter(&mut guard, local_head);
771                        drop(guard);
772                    } else if fulfillment.count > waiter_queue_local.count.get() {
773                        // Notify all local waiters
774
775                        drop(guard);
776                        while let Some(next_local) = waiter_queue_local.pop_node() {
777                            let local_fulfillment = Fulfillment {
778                                inner: fulfillment.take_one(),
779                                count: 1,
780                            };
781                            if let Some(waker) =
782                                unsafe { next_local.as_ref() }.fulfill_local(local_fulfillment)
783                            {
784                                waker.wake();
785                            }
786                        }
787                    } else {
788                        // Upgrade as many local waiters as possible and upgrade the new local head
789                        // to be in the shared queue.
790
791                        let notify_count = fulfillment.count - 1;
792                        let mut cursor = local_head;
793                        for _ in 0..notify_count - 1 {
794                            cursor = unsafe { cursor.as_ref() }
795                                .local_next
796                                .get()
797                                .expect("bug: missing local waiter");
798                        }
799
800                        let new_head = unsafe { cursor.as_ref() }
801                            .local_next
802                            .replace(None)
803                            .expect("bug: missing local waiter");
804                        unsafe { new_head.as_ref() }.local_prev.set(None);
805                        waiter_queue_local.nodes.set(Some((new_head, local_tail)));
806                        waiter_queue_local
807                            .count
808                            .set(waiter_queue_local.count.get() - notify_count);
809
810                        // Upgrade the new local head to be in the shared queue and release the
811                        // shared queue lock.
812                        WaiterQueue::<T>::upgrade_local_waiter(&mut guard, new_head);
813                        drop(guard);
814
815                        let mut wake_cursor = Some(local_head);
816                        while let Some(next) = wake_cursor {
817                            let local_fulfillment = Fulfillment {
818                                inner: fulfillment.take_one(),
819                                count: 1,
820                            };
821                            if let Some(waker) =
822                                unsafe { next.as_ref() }.fulfill_local(local_fulfillment)
823                            {
824                                waker.wake();
825                            }
826                            unsafe { next.as_ref() }.local_prev.set(None);
827                            wake_cursor = unsafe { next.as_ref() }.local_next.replace(None);
828                        }
829                    }
830                }
831
832                return Poll::Ready(fulfillment);
833            }
834        }
835
836        Poll::Pending
837    }
838}
839
/// State for waiting on a [`WaiterQueue`] until a predicate `condition` holds.
///
/// This type does not implement `Future`. It is driven through its private `poll` method,
/// wrapped in `poll_fn`. That `poll` completes when the inner waiter's `poll_fulfillment`
/// reports ready. The fulfillment callback it supplies yields a fulfillment with
/// `count: usize::MAX` whenever `condition()` returns `true`.
///
/// Dropping a `WaitUntil` cancels the inner waiter, and any fulfillment it was already
/// handed is discarded.
pub struct WaitUntil<'a, F> {
    /// This task's registration with the queue. It is treated as structurally pinned:
    /// `poll` only ever accesses it through `Pin::new_unchecked`.
    waiter: Waiter<'a, ()>,
    /// The predicate, wrapped in an `UnsafeCell` so `poll` (which only receives
    /// `Pin<&Self>`) can still call it as `FnMut`. It is never treated as pinned.
    condition: UnsafeCell<F>,
}
844
845impl<F> WaitUntil<'_, F>
846where
847    F: FnMut() -> bool,
848{
849    // Miri is unhappy with `self: Pin<&mut Self>` even if we do `Pin::as_ref(self)` first thing. So
850    // we can't impl Future. So we do a custom poll method and use poll_fn instead.
851    fn poll(self: Pin<&Self>, context: &mut Context<'_>) -> Poll<()> {
852        // SAFETY: we continue to treat `self.waiter` as pinned, and `self.condition` is never
853        // considered pinned.
854        let unpinned_self = unsafe { Pin::into_inner_unchecked(self) };
855        let waiter = unsafe { Pin::new_unchecked(&unpinned_self.waiter) };
856        // SAFETY: This is the only place that dereferences `self.condition`.
857        let condition = unsafe { &mut *unpinned_self.condition.get() };
858
859        let Poll::Ready(fulfillment) = waiter.poll_fulfillment(context, || {
860            if condition() {
861                Some(Fulfillment {
862                    inner: (),
863                    count: usize::MAX,
864                })
865            } else {
866                None
867            }
868        }) else {
869            return Poll::Pending;
870        };
871
872        Poll::Ready(fulfillment.inner)
873    }
874}
875
876impl<F> Drop for WaitUntil<'_, F> {
877    fn drop(&mut self) {
878        // TODO we can't just drop already sent fulfillments on the floor.
879        let _ = self.waiter.cancel();
880    }
881}
882
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_add_remove_local_node() {
        // Stack-allocated nodes never move within this test, so pointers to them stay valid.
        let first = WaiterNode::new();
        let second = WaiterNode::new();
        let third = WaiterNode::new();

        let first_ptr = NonNull::from(&first);
        let second_ptr = NonNull::from(&second);
        let third_ptr = NonNull::from(&third);

        // `(local_prev, local_next)` of a node, for compact link assertions.
        let links = |node: &WaiterNode<()>| (node.local_prev.get(), node.local_next.get());

        let local = Local::<()>::default();
        for ptr in [first_ptr, second_ptr, third_ptr] {
            local.add_node(ptr);
        }

        // Appending in order links first <-> second <-> third.
        assert_eq!(local.nodes.get(), Some((first_ptr, third_ptr)));
        assert_eq!(links(&first), (None, Some(second_ptr)));
        assert_eq!(links(&second), (Some(first_ptr), Some(third_ptr)));
        assert_eq!(links(&third), (Some(second_ptr), None));

        // Removing the middle node joins its neighbours; head and tail are unchanged.
        local.remove_node(&second);
        assert_eq!(local.nodes.get(), Some((first_ptr, third_ptr)));
        assert_eq!(links(&first), (None, Some(third_ptr)));
        assert_eq!(links(&third), (Some(first_ptr), None));

        // Removing the head leaves a one-node list.
        local.remove_node(&first);
        assert_eq!(local.nodes.get(), Some((third_ptr, third_ptr)));
        assert_eq!(links(&third), (None, None));

        // Removing the last node empties the list.
        local.remove_node(&third);
        assert_eq!(local.nodes.get(), None);
    }

    #[test]
    fn test_add_waiter() {
        let waiter_queue = WaiterQueue::<()>::new();

        let first = WaiterNode::new();
        let second = WaiterNode::new();
        let third = WaiterNode::new();

        let first_ptr = NonNull::from(&first);
        let second_ptr = NonNull::from(&second);
        let third_ptr = NonNull::from(&third);

        let mut guard = waiter_queue.lock();

        for ptr in [first_ptr, second_ptr, third_ptr] {
            guard.add_waiter(ptr);
        }

        // Removing the middle, then the front, then the remaining node each succeeds...
        for ptr in [second_ptr, first_ptr, third_ptr] {
            assert!(guard.remove_waiter(ptr));
        }

        // ...and a second removal of any of them reports `false`.
        for ptr in [first_ptr, second_ptr, third_ptr] {
            assert!(!guard.remove_waiter(ptr));
        }
    }

    #[test]
    fn test_register_waiter() {
        let waiter_queue = WaiterQueue::<()>::new();

        let a = core::pin::pin!(Waiter::new(&waiter_queue));
        let b = core::pin::pin!(Waiter::new(&waiter_queue));
        let c = core::pin::pin!(Waiter::new(&waiter_queue));

        // Register each waiter with a closure that always returns `None`.
        for waiter in [&a, &b, &c] {
            waiter.as_ref().register(|| None);
        }

        // Nothing notified the queue, so every cancellation returns `None`.
        for waiter in [&b, &a, &c] {
            assert!(waiter.cancel().is_none());
        }
    }
}