1#![no_std]
4
5use core::{
6 cell::{Cell, UnsafeCell},
7 pin::Pin,
8 ptr::NonNull,
9 task::{Context, Poll, Waker},
10};
11
/// A payload that can be handed out to woken waiters one share at a time.
///
/// `take_one` splits off the share destined for a single waiter, and `append`
/// merges in another payload that accounts for `other_count` waiters.
pub trait IFulfillment {
    /// Splits off the share destined for a single waiter.
    fn take_one(&mut self) -> Self;

    /// Merges `other`, which accounts for `other_count` waiters, into `self`.
    fn append(&mut self, other: Self, other_count: usize);
}

/// The unit payload carries no data, so splitting and merging are no-ops.
impl IFulfillment for () {
    fn take_one(&mut self) -> Self {}

    fn append(&mut self, _other: Self, _other_count: usize) {}
}

/// A payload together with the number of waiters it may still wake.
///
/// Notifiers in this module pass `count: usize::MAX` to mean "as many waiters
/// as there are". For that reason `append` saturates rather than overflowing.
pub struct Fulfillment<T> {
    /// Number of waiters this fulfillment may still wake.
    pub count: usize,
    /// The payload itself.
    pub inner: T,
}

impl<T: IFulfillment> Fulfillment<T> {
    /// Takes the share for one waiter and decrements `count`.
    ///
    /// # Panics
    ///
    /// With overflow checks enabled (debug builds), panics if `count` is already 0.
    pub fn take_one(&mut self) -> T {
        self.count -= 1;
        self.inner.take_one()
    }

    /// Merges `other` into `self`, adding its waiter count.
    ///
    /// The count saturates at `usize::MAX`. That value is used as an "unbounded"
    /// count, and adding a second unbounded (or any) count to it must not overflow.
    pub fn append(&mut self, other: Self) {
        self.inner.append(other.inner, other.count);
        self.count = self.count.saturating_add(other.count);
    }
}
42
43pub struct WaiterQueue<T> {
44 state: spin::Mutex<WaiterQueueState<T>>,
45 local: thid::ThreadLocal<Local<T>>,
46}
47
48struct WaiterQueueState<T> {
49 front: Option<NonNull<WaiterNode<T>>>,
50 back: Option<NonNull<WaiterNode<T>>>,
51 count: usize,
52}
53
54struct WaiterNode<T> {
55 state: spin::Mutex<WaiterNodeState<T>>,
56
57 previous: UnsafeCell<Option<NonNull<Self>>>,
59 next: UnsafeCell<Option<NonNull<Self>>>,
60
61 local_lifecycle: Cell<WaiterLifecycle>,
64 local_state: UnsafeCell<WaiterNodeState<T>>,
65 local_next: Cell<Option<NonNull<Self>>>,
66 local_prev: Cell<Option<NonNull<Self>>>,
67}
68
69enum WaiterNodeState<T> {
70 Pending,
71 Polled { waker: Waker },
72 Notified { fulfillment: Fulfillment<T> },
73 Releasing,
74}
75
76unsafe impl<T> Send for WaiterQueue<T> {}
77unsafe impl<T> Sync for WaiterQueue<T> {}
78
79impl<T: IFulfillment> WaiterNode<T> {
80 pub fn new() -> Self {
81 Self {
82 previous: UnsafeCell::new(None),
83 next: UnsafeCell::new(None),
84 state: spin::Mutex::new(WaiterNodeState::Pending),
85 local_lifecycle: Cell::new(WaiterLifecycle::Unregistered),
86 local_state: UnsafeCell::new(WaiterNodeState::Pending),
87 local_next: Cell::new(None),
88 local_prev: Cell::new(None),
89 }
90 }
91
92 #[inline]
93 fn with_state<R>(&self, f: impl FnOnce(&mut WaiterNodeState<T>) -> R) -> R {
94 f(&mut self.state.lock())
95 }
96
97 #[inline]
98 fn fulfill(&self, fulfillment: Fulfillment<T>) -> Option<Waker> {
99 self.with_state(|state| Self::fulfill_common(state, fulfillment))
100 }
101
102 #[inline]
103 fn fulfill_local(&self, fulfillment: Fulfillment<T>) -> Option<Waker> {
104 let state = unsafe { &mut *self.local_state.get() };
105 Self::fulfill_common(state, fulfillment)
106 }
107
108 #[inline]
109 fn fulfill_common(
110 state: &mut WaiterNodeState<T>,
111 fulfillment: Fulfillment<T>,
112 ) -> Option<Waker> {
113 match state {
114 WaiterNodeState::Pending => {
115 *state = WaiterNodeState::Notified { fulfillment };
116 None
117 }
118 WaiterNodeState::Polled { .. } => {
119 let WaiterNodeState::Polled { waker } =
120 core::mem::replace(&mut *state, WaiterNodeState::Notified { fulfillment })
121 else {
122 unreachable!();
123 };
124 Some(waker)
125 }
126 WaiterNodeState::Notified {
128 fulfillment: existing_fulfillment,
129 } => {
130 existing_fulfillment.append(fulfillment);
131 None
132 }
133 WaiterNodeState::Releasing => unreachable!(),
135 }
136 }
137}
138
139struct Local<T> {
140 nodes: Cell<Option<(NonNull<WaiterNode<T>>, NonNull<WaiterNode<T>>)>>,
141 count: Cell<usize>,
142}
143
144unsafe impl<T> Send for Local<T> {}
145
146impl<T> Default for Local<T> {
147 fn default() -> Self {
148 Self {
149 nodes: Cell::new(None),
150 count: Cell::new(0),
151 }
152 }
153}
154
155impl<T> Local<T> {
156 #[inline]
157 fn add_node(&self, new_node: NonNull<WaiterNode<T>>) {
158 self.count.set(self.count.get() + 1);
159 if let Some((head, tail)) = self.nodes.get() {
160 unsafe { new_node.as_ref() }.local_prev.set(Some(tail));
161 unsafe { tail.as_ref() }.local_next.set(Some(new_node));
162 self.nodes.set(Some((head, new_node)));
163 } else {
164 self.nodes.set(Some((new_node, new_node)));
166 debug_assert_eq!(self.count.get(), 1);
167 }
168 }
169
170 #[inline]
171 fn remove_node(&self, to_remove: &WaiterNode<T>) {
172 let Some((head, tail)) = self.nodes.get() else {
173 return;
175 };
176
177 let prev = to_remove.local_prev.replace(None);
178 let next = to_remove.local_next.replace(None);
179
180 if prev.is_none() && next.is_none() && head != NonNull::from(to_remove) {
181 return;
183 }
184 self.count.set(self.count.get() - 1);
185
186 if let Some(next) = next {
187 unsafe { next.as_ref() }.local_prev.set(prev);
188 } else {
189 debug_assert_eq!(NonNull::from(to_remove), tail);
191 if let Some(prev) = prev {
192 unsafe { prev.as_ref() }.local_next.set(None);
193 self.nodes.set(Some((head, prev)));
194 } else {
195 debug_assert_eq!(head, tail);
197 debug_assert_eq!(self.count.get(), 0);
198 self.nodes.set(None);
199 }
200 return;
201 }
202
203 if let Some(prev) = prev {
204 unsafe { prev.as_ref() }.local_next.set(next);
205 } else {
206 debug_assert_eq!(NonNull::from(to_remove), head);
208 if let Some(next) = next {
209 self.nodes.set(Some((next, tail)));
210 } else {
211 debug_assert_eq!(head, tail);
213 debug_assert_eq!(self.count.get(), 0);
214 self.nodes.set(None);
215 return;
216 }
217 }
218 }
219
220 #[inline]
221 fn pop_node(&self) -> Option<NonNull<WaiterNode<T>>> {
222 let (head, tail) = self.nodes.take()?;
223 self.count.set(self.count.get() - 1);
224 if head != tail {
225 let new_head = unsafe { head.as_ref() }.local_next.take().unwrap();
226 unsafe { new_head.as_ref() }.local_prev.set(None);
227 self.nodes.set(Some((new_head, tail)));
228 } else {
229 debug_assert_eq!(self.count.get(), 0);
230 }
231
232 Some(head)
233 }
234}
235
236pub struct WaiterQueueGuard<'a, T> {
237 state: spin::MutexGuard<'a, WaiterQueueState<T>>,
238}
239
240impl<T> WaiterQueueGuard<'_, T> {
241 pub fn waiter_count(&self) -> usize {
242 self.state.count
243 }
244}
245
impl<T: IFulfillment> WaiterQueue<T> {
    /// Creates a queue with an empty shared list and no per-thread lists yet.
    pub fn new() -> Self {
        Self {
            state: spin::Mutex::new(WaiterQueueState {
                front: None,
                back: None,
                count: 0,
            }),
            local: thid::ThreadLocal::new(),
        }
    }

    /// Locks the shared waiter list.
    pub fn lock(&self) -> WaiterQueueGuard<'_, T> {
        WaiterQueueGuard {
            state: self.state.lock(),
        }
    }

    /// Notifies the oldest waiter registered from the calling thread.
    ///
    /// Returns `Some(fulfillment)` unchanged when this thread has no waiters on
    /// this queue. Returns `None` once the fulfillment has been delivered.
    pub fn notify_one_local(&self, fulfillment: T) -> Option<T> {
        let local = self.local.get_or_default();
        let Some((local_head, _)) = local.nodes.get() else {
            return Some(fulfillment);
        };

        // The head of the local list is expected to be the thread's one node
        // that is also linked into the shared queue.
        debug_assert!(unsafe { local_head.as_ref() }.local_prev.get().is_none());
        debug_assert_eq!(
            unsafe { local_head.as_ref() }.local_lifecycle.get(),
            WaiterLifecycle::Registered,
        );

        let fulfillment = Fulfillment {
            inner: fulfillment,
            count: 1,
        };

        let mut guard = self.lock();
        if guard.remove_waiter(local_head) {
            // The head was still queued. Detach it from both lists and promote
            // the next local waiter into the shared queue so this thread stays
            // reachable by other notifiers.
            local.pop_node();

            if let Some((new_head, _)) = local.nodes.get() {
                Self::upgrade_local_waiter(&mut guard, new_head);
            }
            drop(guard);

            // Now unlinked from the shared queue, so deliver via the local
            // state, which `poll_fulfillment` checks first.
            unsafe { local_head.as_ref() }
                .local_lifecycle
                .set(WaiterLifecycle::RegisteredLocal);
            if let Some(waker) = unsafe { local_head.as_ref() }.fulfill_local(fulfillment) {
                waker.wake();
            }
        } else {
            // Already unlinked from the shared queue, presumably by a concurrent
            // `notify`. Merge into its shared state instead.
            drop(guard);

            if let Some(waker) = unsafe { local_head.as_ref() }.fulfill(fulfillment) {
                waker.wake();
            }
        }

        None
    }

    /// Unlinks `to_remove` from the calling thread's local list. This is a
    /// no-op if the node is not linked there.
    fn remove_local_waiter(&self, to_remove: &WaiterNode<T>) {
        let local = self.local.get_or_default();
        local.remove_node(to_remove);
    }

    /// Links a node that is currently only on its thread's local list into the
    /// shared queue (requires the queue lock via `guard`).
    ///
    /// The node's local `Pending`/`Polled` state is copied into its shared
    /// state, cloning the waker, and its lifecycle becomes `Registered`.
    /// Panics (`unreachable!`) if the local state is `Notified` or `Releasing`.
    fn upgrade_local_waiter(guard: &mut WaiterQueueGuard<'_, T>, waiter: NonNull<WaiterNode<T>>) {
        debug_assert_eq!(
            unsafe { waiter.as_ref() }.local_lifecycle.get(),
            WaiterLifecycle::RegisteredLocal,
        );

        let waiter_ref = unsafe { waiter.as_ref() };
        *waiter_ref.state.lock() = match unsafe { &*waiter_ref.local_state.get() } {
            WaiterNodeState::Pending => WaiterNodeState::Pending,
            WaiterNodeState::Polled { waker } => WaiterNodeState::Polled {
                waker: waker.clone(),
            },
            WaiterNodeState::Notified { .. } => unreachable!(),
            WaiterNodeState::Releasing => unreachable!(),
        };

        guard.add_waiter(waiter);
        waiter_ref.local_lifecycle.set(WaiterLifecycle::Registered);
    }
}
342
343impl WaiterQueue<()> {
344 #[inline]
346 pub fn notify_one(&self) -> bool {
347 self.lock().notify((), 1).is_none()
348 }
349
350 #[inline]
351 pub fn notify_all(&self) -> usize {
352 self.lock().notify_all(())
353 }
354
355 #[inline]
359 pub unsafe fn wait(&self) -> Waiter<'_, ()> {
360 Waiter::new(&self)
361 }
362
363 #[inline]
364 pub async fn wait_for<R>(&self, mut condition: impl FnMut() -> Option<R>) -> R {
365 if let Some(r) = condition() {
366 return r;
367 }
368
369 let result = Cell::new(None);
370 loop {
371 let wait_until = core::pin::pin!(WaitUntil {
372 waiter: unsafe { self.wait() },
374 condition: UnsafeCell::new(|| {
375 if let Some(r) = condition() {
376 result.set(Some(r));
377 true
378 } else {
379 false
380 }
381 }),
382 });
383 core::future::poll_fn(|cx| wait_until.as_ref().poll(cx)).await;
384
385 if let Some(r) = result.take() {
386 return r;
387 }
388 }
389 }
390
391 #[inline]
392 pub async fn wait_until(&self, condition: impl Fn() -> bool) {
393 let condition = &condition;
394 loop {
395 if condition() {
396 return;
397 }
398
399 let wait_until = core::pin::pin!(WaitUntil {
400 waiter: unsafe { self.wait() },
402 condition: UnsafeCell::new(condition),
403 });
404 core::future::poll_fn(|cx| wait_until.as_ref().poll(cx)).await;
405 }
406 }
407}
408
409impl<T: IFulfillment> WaiterQueueGuard<'_, T> {
410 pub fn notify(mut self, fulfillment: T, count: usize) -> Option<Fulfillment<T>> {
411 let Some(front_ptr) = self.state.front else {
412 return Some(Fulfillment {
414 count,
415 inner: fulfillment,
416 });
417 };
418
419 self.state.count -= 1;
420
421 let next_ptr = core::mem::replace(unsafe { &mut *front_ptr.as_ref().next.get() }, None);
423 self.state.front = next_ptr;
424
425 if let Some(new_front_ptr) = self.state.front {
426 unsafe { *new_front_ptr.as_ref().previous.get() = None };
427 } else {
428 debug_assert_eq!(Some(front_ptr), self.state.back);
429 debug_assert!(unsafe { *front_ptr.as_ref().previous.get() }.is_none());
430
431 self.state.back = None;
433 }
434 drop(self);
436
437 let maybe_waker = unsafe { front_ptr.as_ref() }.fulfill(Fulfillment {
438 inner: fulfillment,
439 count,
440 });
441 if let Some(waker) = maybe_waker {
442 waker.wake();
443 }
444
445 None
446 }
447
448 fn remove_waiter(&mut self, node: NonNull<WaiterNode<T>>) -> bool {
451 let prev = unsafe { *node.as_ref().previous.get() };
452 let next = unsafe { *node.as_ref().next.get() };
453
454 if prev.is_none() && next.is_none() && self.state.front != Some(node) {
455 return false;
457 }
458
459 self.state.count -= 1;
460
461 unsafe {
462 *node.as_ref().next.get() = None;
463 }
464 unsafe {
465 *node.as_ref().previous.get() = None;
466 }
467
468 if Some(node) == self.state.back {
470 self.state.back = prev;
471 debug_assert!(next.is_none());
472 }
473
474 if Some(node) == self.state.front {
475 self.state.front = next;
477 if let Some(next) = next {
478 unsafe {
480 *next.as_ref().previous.get() = None;
481 }
482 } else {
483 debug_assert!(self.state.back.is_none());
484 }
485 } else if let Some(prev) = prev {
486 unsafe { *prev.as_ref().next.get() = next };
489 if let Some(next) = next {
490 unsafe {
492 *next.as_ref().previous.get() = Some(prev);
493 }
494 }
495 }
496
497 true
498 }
499
500 fn add_waiter(&mut self, new_node: NonNull<WaiterNode<T>>) {
501 let state = &mut self.state;
502 state.count += 1;
503
504 debug_assert!(unsafe { (*new_node.as_ref().next.get()).is_none() });
505 debug_assert!(unsafe { (*new_node.as_ref().previous.get()).is_none() });
506
507 let prev_back = core::mem::replace(&mut state.back, Some(new_node));
508 if let Some(prev_back) = prev_back {
509 unsafe {
510 *new_node.as_ref().previous.get() = Some(prev_back);
512 *prev_back.as_ref().next.get() = Some(new_node);
514 }
515 } else {
516 state.front = Some(new_node);
518 debug_assert!(unsafe { &*new_node.as_ref().next.get() }.is_none());
519 debug_assert!(unsafe { &*new_node.as_ref().previous.get() }.is_none());
520 }
521 }
522}
523
524impl<T: IFulfillment + Copy> WaiterQueueGuard<'_, T> {
525 pub fn notify_all(&mut self, fulfillment: T) -> usize {
526 let mut notified_count = 0;
527
528 while let Some(front_ptr) = self.state.front {
529 notified_count += 1;
530 self.state.count -= 1;
531
532 let next_ptr = core::mem::replace(unsafe { &mut *front_ptr.as_ref().next.get() }, None);
534 self.state.front = next_ptr;
535
536 if let Some(new_front_ptr) = self.state.front {
537 unsafe { *new_front_ptr.as_ref().previous.get() = None };
538 } else {
539 debug_assert_eq!(Some(front_ptr), self.state.back);
540 debug_assert!(unsafe { *front_ptr.as_ref().previous.get() }.is_none());
541
542 self.state.back = None;
544 }
545
546 let maybe_waker = unsafe { front_ptr.as_ref() }.fulfill(Fulfillment {
547 inner: fulfillment,
548 count: usize::MAX,
549 });
550 if let Some(waker) = maybe_waker {
551 waker.wake();
552 }
553 }
554
555 notified_count
556 }
557}
558
/// Where a waiter's node is currently linked, tracked in `local_lifecycle`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum WaiterLifecycle {
    // Not linked anywhere yet.
    Unregistered,
    // Linked into the shared queue. It is also on its thread's local list,
    // normally as the head.
    Registered,
    // Linked only into the owning thread's local list.
    RegisteredLocal,
    // Fulfilled or cancelled; `cancel` and `register` treat it as finished.
    Releasing,
}

/// A single wait on a `WaiterQueue`.
///
/// The waiter's node is stored inline, and registration links the node's
/// address into the queue's lists. The waiter must therefore stay in place
/// once registered (registration takes `Pin<&Self>`).
pub struct Waiter<'a, T: IFulfillment> {
    waiter_queue: &'a WaiterQueue<T>,
    waiter_node: UnsafeCell<WaiterNode<T>>,
}
571
572impl<'a, T: IFulfillment> Waiter<'a, T> {
573 pub fn new(waiter_queue: &'a WaiterQueue<T>) -> Self {
574 Self {
575 waiter_queue,
576 waiter_node: UnsafeCell::new(WaiterNode::new()),
577 }
578 }
579
580 #[inline]
581 fn lifecycle(&self) -> WaiterLifecycle {
582 unsafe { &*self.waiter_node.get() }.local_lifecycle.get()
583 }
584
585 #[inline]
586 fn set_lifecycle(&self, new_value: WaiterLifecycle) {
587 unsafe { &*self.waiter_node.get() }
588 .local_lifecycle
589 .set(new_value);
590 }
591
592 #[inline]
593 fn register(
594 self: Pin<&Self>,
595 mut try_fulfill: impl FnMut() -> Option<Fulfillment<T>>,
596 ) -> Option<Fulfillment<T>> {
597 if self.lifecycle() != WaiterLifecycle::Unregistered {
598 return None;
599 }
600
601 let local = self.waiter_queue.local.get_or_default();
602 let waiter_node_ptr = NonNull::from(unsafe { &*self.waiter_node.get() });
603
604 if local.nodes.get().is_some() {
605 self.set_lifecycle(WaiterLifecycle::RegisteredLocal);
606 local.add_node(waiter_node_ptr);
607 None
608 } else {
609 let mut guard = self.waiter_queue.lock();
611 if let Some(fulfillment) = try_fulfill() {
612 drop(guard);
613 Some(fulfillment)
614 } else {
615 guard.add_waiter(waiter_node_ptr);
616 self.set_lifecycle(WaiterLifecycle::Registered);
617 local.add_node(waiter_node_ptr);
618 None
619 }
620 }
621 }
622
623 pub fn cancel(&self) -> Option<Fulfillment<T>> {
624 match self.lifecycle() {
625 WaiterLifecycle::Registered => {
626 self.set_lifecycle(WaiterLifecycle::Releasing);
627
628 let mut waiter_queue_guard = self.waiter_queue.lock();
630
631 let waiter_node = unsafe { &*self.waiter_node.get() };
632 let mut state = waiter_node.state.lock();
633 match core::mem::replace(&mut *state, WaiterNodeState::Releasing) {
634 WaiterNodeState::Notified { fulfillment } => {
635 self.waiter_queue.remove_local_waiter(waiter_node);
636 Some(fulfillment)
637 }
638 WaiterNodeState::Releasing => None,
640 _ => {
641 waiter_queue_guard.remove_waiter(NonNull::from(waiter_node));
643 self.waiter_queue.remove_local_waiter(waiter_node);
644 None
645 }
646 }
647 }
648 WaiterLifecycle::RegisteredLocal => {
649 self.set_lifecycle(WaiterLifecycle::Releasing);
650
651 let waiter_node = unsafe { &*self.waiter_node.get() };
652 let state = unsafe { &mut *waiter_node.local_state.get() };
653 match core::mem::replace(&mut *state, WaiterNodeState::Releasing) {
654 WaiterNodeState::Notified { fulfillment } => {
655 Some(fulfillment)
657 }
658 WaiterNodeState::Releasing => None,
660 _ => {
661 self.waiter_queue.remove_local_waiter(waiter_node);
663 None
664 }
665 }
666 }
667 _ => None,
668 }
669 }
670
671 pub fn poll_fulfillment(
672 self: Pin<&'_ Self>,
673 context: &'_ mut Context<'_>,
674 mut try_fulfill: impl FnMut() -> Option<Fulfillment<T>>,
675 ) -> Poll<Fulfillment<T>> {
676 if let Some(fulfillment) = self.as_ref().register(&mut try_fulfill) {
677 return Poll::Ready(fulfillment);
678 }
679
680 let waiter_node = unsafe { &*self.waiter_node.get() };
681
682 let update_state = |state: &mut WaiterNodeState<T>| {
683 let mut maybe_fulfillment = None;
684 let state_ptr = &mut *state as *mut WaiterNodeState<T>;
685 let taken_state = unsafe { core::ptr::read(state_ptr) };
686
687 let new_state = match taken_state {
689 WaiterNodeState::Pending => WaiterNodeState::Polled {
690 waker: context.waker().clone(),
691 },
692 WaiterNodeState::Polled { waker } => {
693 let new_waker = context.waker();
694 if !waker.will_wake(new_waker) {
695 WaiterNodeState::Polled {
696 waker: new_waker.clone(),
697 }
698 } else {
699 WaiterNodeState::Polled { waker }
700 }
701 }
702 WaiterNodeState::Notified { fulfillment } => {
703 maybe_fulfillment = Some(fulfillment);
704 WaiterNodeState::Releasing
705 }
706 WaiterNodeState::Releasing => unreachable!(),
707 };
708
709 unsafe {
710 state_ptr.write(new_state);
711 }
712
713 maybe_fulfillment
714 };
715
716 let local_state = unsafe { &mut *waiter_node.local_state.get() };
719 if let Some(fulfillment) = update_state(local_state) {
720 debug_assert_eq!(self.lifecycle(), WaiterLifecycle::RegisteredLocal);
722 self.set_lifecycle(WaiterLifecycle::Releasing);
723 return Poll::Ready(fulfillment);
724 }
725
726 if self.as_ref().lifecycle() == WaiterLifecycle::Registered {
727 if let Some(mut fulfillment) = waiter_node.with_state(update_state) {
730 let waiter_queue_local = self.as_ref().waiter_queue.local.get_or_default();
731 let popped_head = waiter_queue_local.pop_node();
732 debug_assert_eq!(popped_head, Some(NonNull::from(waiter_node)));
733
734 self.set_lifecycle(WaiterLifecycle::Releasing);
735
736 while fulfillment.count > 1 {
738 let Some(local_next) = waiter_queue_local.pop_node() else {
739 break;
740 };
741 debug_assert!(unsafe { local_next.as_ref() }.local_prev.get().is_none());
742 debug_assert_eq!(
743 unsafe { local_next.as_ref() }.local_lifecycle.get(),
744 WaiterLifecycle::RegisteredLocal,
745 );
746
747 if let Some(waker) = unsafe { local_next.as_ref() }.fulfill_local(Fulfillment {
748 inner: fulfillment.take_one(),
749 count: 1,
750 }) {
751 waker.wake();
752 }
753 }
754
755 if let Some((local_head, local_tail)) = waiter_queue_local.nodes.get() {
757 let mut guard = self.as_ref().waiter_queue.lock();
758
759 while let Some(new_fulfillment) = try_fulfill() {
761 fulfillment.append(new_fulfillment);
762 if fulfillment.count > waiter_queue_local.count.get() {
763 break;
764 }
765 }
766
767 if fulfillment.count == 1 {
768 WaiterQueue::<T>::upgrade_local_waiter(&mut guard, local_head);
771 drop(guard);
772 } else if fulfillment.count > waiter_queue_local.count.get() {
773 drop(guard);
776 while let Some(next_local) = waiter_queue_local.pop_node() {
777 let local_fulfillment = Fulfillment {
778 inner: fulfillment.take_one(),
779 count: 1,
780 };
781 if let Some(waker) =
782 unsafe { next_local.as_ref() }.fulfill_local(local_fulfillment)
783 {
784 waker.wake();
785 }
786 }
787 } else {
788 let notify_count = fulfillment.count - 1;
792 let mut cursor = local_head;
793 for _ in 0..notify_count - 1 {
794 cursor = unsafe { cursor.as_ref() }
795 .local_next
796 .get()
797 .expect("bug: missing local waiter");
798 }
799
800 let new_head = unsafe { cursor.as_ref() }
801 .local_next
802 .replace(None)
803 .expect("bug: missing local waiter");
804 unsafe { new_head.as_ref() }.local_prev.set(None);
805 waiter_queue_local.nodes.set(Some((new_head, local_tail)));
806 waiter_queue_local
807 .count
808 .set(waiter_queue_local.count.get() - notify_count);
809
810 WaiterQueue::<T>::upgrade_local_waiter(&mut guard, new_head);
813 drop(guard);
814
815 let mut wake_cursor = Some(local_head);
816 while let Some(next) = wake_cursor {
817 let local_fulfillment = Fulfillment {
818 inner: fulfillment.take_one(),
819 count: 1,
820 };
821 if let Some(waker) =
822 unsafe { next.as_ref() }.fulfill_local(local_fulfillment)
823 {
824 waker.wake();
825 }
826 unsafe { next.as_ref() }.local_prev.set(None);
827 wake_cursor = unsafe { next.as_ref() }.local_next.replace(None);
828 }
829 }
830 }
831
832 return Poll::Ready(fulfillment);
833 }
834 }
835
836 Poll::Pending
837 }
838}
839
840pub struct WaitUntil<'a, F> {
841 waiter: Waiter<'a, ()>,
842 condition: UnsafeCell<F>,
843}
844
845impl<F> WaitUntil<'_, F>
846where
847 F: FnMut() -> bool,
848{
849 fn poll(self: Pin<&Self>, context: &mut Context<'_>) -> Poll<()> {
852 let unpinned_self = unsafe { Pin::into_inner_unchecked(self) };
855 let waiter = unsafe { Pin::new_unchecked(&unpinned_self.waiter) };
856 let condition = unsafe { &mut *unpinned_self.condition.get() };
858
859 let Poll::Ready(fulfillment) = waiter.poll_fulfillment(context, || {
860 if condition() {
861 Some(Fulfillment {
862 inner: (),
863 count: usize::MAX,
864 })
865 } else {
866 None
867 }
868 }) else {
869 return Poll::Pending;
870 };
871
872 Poll::Ready(fulfillment.inner)
873 }
874}
875
876impl<F> Drop for WaitUntil<'_, F> {
877 fn drop(&mut self) {
878 let _ = self.waiter.cancel();
880 }
881}
882
#[cfg(test)]
mod test {
    use super::*;

    // Exercises the per-thread list: append three nodes, then unlink from the
    // middle, the head, and finally the last remaining node.
    #[test]
    fn test_add_remove_local_node() {
        let a = WaiterNode::new();
        let b = WaiterNode::new();
        let c = WaiterNode::new();

        let a_ptr = NonNull::from(&a);
        let b_ptr = NonNull::from(&b);
        let c_ptr = NonNull::from(&c);

        let local = Local::<()>::default();

        local.add_node(a_ptr);
        local.add_node(b_ptr);
        local.add_node(c_ptr);

        assert_eq!(local.nodes.get(), Some((a_ptr, c_ptr)));
        assert_eq!(a.local_prev.get(), None);
        assert_eq!(a.local_next.get(), Some(b_ptr));
        assert_eq!(b.local_prev.get(), Some(a_ptr));
        assert_eq!(b.local_next.get(), Some(c_ptr));
        assert_eq!(c.local_prev.get(), Some(b_ptr));
        assert_eq!(c.local_next.get(), None);

        // Remove from the middle.
        local.remove_node(&b);

        assert_eq!(local.nodes.get(), Some((a_ptr, c_ptr)));
        assert_eq!(a.local_prev.get(), None);
        assert_eq!(a.local_next.get(), Some(c_ptr));
        assert_eq!(c.local_prev.get(), Some(a_ptr));
        assert_eq!(c.local_next.get(), None);

        // Remove the head.
        local.remove_node(&a);

        assert_eq!(local.nodes.get(), Some((c_ptr, c_ptr)));
        assert_eq!(c.local_prev.get(), None);
        assert_eq!(c.local_next.get(), None);

        // Remove the only remaining node.
        local.remove_node(&c);

        assert_eq!(local.nodes.get(), None);
    }

    // Exercises the shared queue: removing each node once succeeds, and a
    // second removal of an unlinked node reports `false`.
    #[test]
    fn test_add_waiter() {
        let waiter_queue = WaiterQueue::<()>::new();

        let a = WaiterNode::new();
        let b = WaiterNode::new();
        let c = WaiterNode::new();

        let a_ptr = NonNull::from(&a);
        let b_ptr = NonNull::from(&b);
        let c_ptr = NonNull::from(&c);

        let mut guard = waiter_queue.lock();

        guard.add_waiter(a_ptr);
        guard.add_waiter(b_ptr);
        guard.add_waiter(c_ptr);

        assert!(guard.remove_waiter(b_ptr));
        assert!(guard.remove_waiter(a_ptr));
        assert!(guard.remove_waiter(c_ptr));

        assert!(!guard.remove_waiter(a_ptr));
        assert!(!guard.remove_waiter(b_ptr));
        assert!(!guard.remove_waiter(c_ptr));
    }

    // Registers three waiters from one thread (the first joins the shared queue,
    // the others only the local list) and cancels them, none of them notified.
    #[test]
    fn test_register_waiter() {
        let waiter_queue = WaiterQueue::<()>::new();

        let a = core::pin::pin!(Waiter::new(&waiter_queue));
        let b = core::pin::pin!(Waiter::new(&waiter_queue));
        let c = core::pin::pin!(Waiter::new(&waiter_queue));

        a.as_ref().register(|| None);
        b.as_ref().register(|| None);
        c.as_ref().register(|| None);

        assert!(b.cancel().is_none());
        assert!(a.cancel().is_none());
        assert!(c.cancel().is_none());
    }
}