spmc_waker/lib.rs
1//! A synchronization primitive for task wakeup.
2//!
3//! This crate provides [`SpmcWaker`], a single-producer, multiple-consumer (SPMC)
4//! atomic waker.
5//!
6//! # Features
7//!
//! - `portable-atomic`: use the `portable-atomic` crate to provide atomic
//!   operations on targets without native atomic support.
10#![no_std]
11#[cfg(doc)]
12extern crate std;
13use core::{hint::assert_unchecked, mem::ManuallyDrop, task::Waker};
14
15use crate::{
16 loom::{
17 AtomicUsizeExt,
18 sync::atomic::{
19 AtomicUsize,
20 Ordering::{Relaxed, SeqCst},
21 },
22 },
23 waker_cell::WakerCell,
24};
25
26#[cfg(all(debug_assertions, not(loom)))]
27mod exclusive;
28mod loom;
29mod waker_cell;
30
/// State value meaning "no waker registered".
///
/// With `CACHED=true` it is used as a bit-flag, the state's LSB then giving
/// the index of the cached waker.
const EMPTY: usize = 0b010;
/// State value meaning "a `wake` operation is ongoing".
///
/// With `SYNC=true` it is used as a bit-flag.
const WAKING: usize = 0b100;
33
34/// A synchronization primitive for task wakeup.
35///
36/// Sometimes the task interested in a given event will change over time.
37/// A `SpmcWaker` can coordinate concurrent notifications with the consumer
38/// potentially "updating" the underlying task to wake up. This is useful in
39/// scenarios where a computation completes in another thread and wants to
40/// notify the consumer, but the consumer is in the process of being migrated to
41/// a new logical task.
42///
43/// Consumers should call `register` before checking the result of a computation
44/// and producers should call `wake` after producing the computation (this
45/// differs from the usual `thread::park` pattern). It is also permitted for
46/// `wake` to be called **before** `register`. This results in a no-op.
47///
48/// A single `SpmcWaker` may be reused for any number of calls to `register` or
49/// `wake`.
50///
51/// # Single-producer, multiple-consumer (SPMC)
52///
/// The `SpmcWaker` algorithm assumes that at most one thread calls
/// `register`/`unregister` at a time. This is enforced by the methods' safety
/// conditions.
55///
56/// This assumption allows significant optimizations compared to an MPMC algorithm
57/// like [`AtomicWaker`].
58///
59/// # Memory ordering
60///
61/// `SpmcWaker` atomic operations use `SeqCst` ordering, and it has a generic
62/// `SYNC` parameter which determines the synchronization guarantees.
63///
64/// ### `SYNC=false` (the default)
65///
66/// There is no acquire-release synchronization between `register` and `wake`.
67///
/// Because a `wake` call may not see the waker registered by a concurrent
/// `register`, the waking condition should be set and checked with operations
/// that participate in a total order, i.e. `SeqCst` or RMW operations. This
/// ensures that, when a concurrent `wake` misses the registered waker, the
/// check of the waking condition performed after `register` still observes it.
73///
74/// When no waker is registered, `wake` is reduced to a single atomic load.
75///
76/// ### `SYNC=true`
77///
78/// Calling `register` "acquires" all memory "released" by calls to `wake`
79/// before the call to `register`.
80///
81/// It allows setting the waking condition and checking it with a relaxed
82/// ordering after the registration, at the cost of having a mandatory
83/// atomic RMW operation in `wake`.
84///
/// If the waking condition is already set through an atomic RMW operation,
/// adding `SeqCst` ordering to it and to the waking condition check
/// comes at a minimal cost, and allows saving an atomic RMW operation
/// in `wake` by switching to `SYNC=false`. In fact, `SYNC=true`
/// should only be considered when setting the waking condition involves no
/// RMW operation.
90///
91/// # Waker caching
92///
93/// Most of the time, `SpmcWaker` is used in a single task, so the waker
94/// registered is always the same. That's why it provides a second generic
95/// parameter `CACHED`.
96///
97/// ### `CACHED=true` (the default)
98///
99/// The last waker registered is kept cached to avoid cloning it at the next
100/// registration. As a consequence, waking is done with [`Waker::wake_by_ref`].
101/// As wakers are often `Arc`s, caching avoids atomic RMW operations updating
102/// the reference counter.
103///
104/// ### `CACHED=false`
105///
/// The waker passed by reference to `register` is cloned when it is stored,
/// and tasks are woken with [`Waker::wake`], which consumes that clone.
108///
109/// # Examples
110///
111/// Here is a simple example providing a `Flag` that can be signaled manually
112/// when it is ready.
113///
114/// ```rust
115/// use std::{
116/// pin::Pin,
117/// sync::{
118/// Arc,
119/// atomic::{
120/// AtomicBool,
121/// Ordering::{Relaxed, SeqCst},
122/// },
123/// },
124/// task::{Context, Poll},
125/// };
126///
127/// use spmc_waker::SpmcWaker;
128///
129/// #[derive(Default)]
130/// struct Inner {
131/// notified: AtomicBool,
132/// waker: SpmcWaker,
133/// }
134///
135/// #[derive(Clone)]
136/// struct Notifier(Arc<Inner>);
137///
138/// impl Notifier {
139/// pub fn new() -> Self {
140/// Self(Arc::new(Inner {
141/// waker: SpmcWaker::new(),
142/// notified: AtomicBool::new(false),
143/// }))
144/// }
145///
146/// pub fn signal(&self) {
147/// // Use seqcst ordering to synchronize with the load after `register`
148/// self.0.notified.store(true, SeqCst);
149/// self.0.waker.wake();
150/// }
151/// }
152///
153/// #[derive(Default)]
154/// struct Waiter(Arc<Inner>);
155///
156/// impl Waiter {
157/// fn notifier(&self) -> Notifier {
158/// Notifier(self.0.clone())
159/// }
160/// }
161///
162/// impl Future for Waiter {
163/// type Output = ();
164///
165/// fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
166/// // quick check to avoid registration if already done.
167/// if self.0.notified.load(Relaxed) {
168/// return Poll::Ready(());
169/// }
170///
171/// // SAFETY: mutable reference on non-cloneable `Waiter` ensures no concurrent call
172/// unsafe { self.0.waker.register(cx.waker()) };
173///
174/// // Need to check condition **after** `register` to avoid a race
175/// // condition that would result in lost notifications.
176/// // Use seqcst ordering so it synchronizes with the store before wake.
177/// if self.0.notified.load(SeqCst) {
178/// // Unregister the waker to avoid spurious wakeups.
179/// // SAFETY: mutable reference on non-cloneable `Waiter` ensures no concurrent call
180/// unsafe { self.0.waker.unregister() };
181/// Poll::Ready(())
182/// } else {
183/// Poll::Pending
184/// }
185/// }
186/// }
187///
188/// fn event() -> (Notifier, Waiter) {
189/// let waiter = Waiter::default();
190/// (waiter.notifier(), waiter)
191/// }
192/// ```
193///
194/// [`AtomicWaker`]: https://docs.rs/futures/latest/futures/task/struct.AtomicWaker.html
#[derive(Debug)]
pub struct SpmcWaker<const SYNC: bool = false, const CACHED: bool = true> {
    /// Two waker slots: `register` can store a new waker in the slot not
    /// designated by `state`, then publish it by switching the index.
    wakers: [WakerCell; 2],
    /// State possible values are:
    /// - 0 or 1: A waker is registered in `wakers[state]`
    /// - EMPTY: there is no waker registered;
    ///   with CACHED=true, it becomes a bit-flag and the state's LSB gives
    ///   the cached waker index (cells are initialized with dummy wakers)
    /// - WAKING: a `wake` operation is ongoing;
    ///   with SYNC=true, it becomes a bit-flag
    state: AtomicUsize,
    /// Debug-only helper whose `check()` guard is taken by `register` and
    /// `unregister`; presumably it detects violations of their "no concurrent
    /// call" safety condition — confirm in the `exclusive` module.
    #[cfg(all(debug_assertions, not(loom)))]
    exclusive: exclusive::Exclusive,
}
209
// SAFETY: the waker cells are only accessed following the `state` protocol:
// mutable accesses are performed by the single `register`/`unregister` caller
// (their safety condition forbids concurrent calls) on cells `wake` cannot
// reach, or by `Drop` through `&mut self`; other accesses are shared reads.
// `Waker` itself is `Send + Sync`, so it can be used from any thread.
unsafe impl<const SYNC: bool, const CACHED: bool> Send for SpmcWaker<SYNC, CACHED> {}
unsafe impl<const SYNC: bool, const CACHED: bool> Sync for SpmcWaker<SYNC, CACHED> {}
212
213impl<const SYNC: bool, const CACHED: bool> Drop for SpmcWaker<SYNC, CACHED> {
214 #[inline]
215 fn drop(&mut self) {
216 let state = self.state.load_mut();
217 if CACHED || state < 2 {
218 // SAFETY: state is the index of a waker currently registered
219 // that must be taken back, and access is safe in destructor
220 unsafe { self.wakers[state % 2].drop() };
221 }
222 }
223}
224
225impl<const SYNC: bool, const CACHED: bool> SpmcWaker<SYNC, CACHED> {
226 /// Creates a new `SpmcWaker`.
227 #[cfg_attr(loom, const_fn::const_fn(cfg(false)))]
228 #[inline]
229 pub const fn new() -> Self {
230 Self {
231 wakers: [WakerCell::new(), WakerCell::new()],
232 state: AtomicUsize::new(EMPTY),
233 #[cfg(all(debug_assertions, not(loom)))]
234 exclusive: exclusive::Exclusive::new(),
235 }
236 }
237
238 /// Registers the waker to be notified on calls to `wake`.
239 ///
240 /// The new task will take place of any previous tasks that were registered
241 /// by previous calls to `register`. Any calls to `wake` that happen after
242 /// a call to `register` (as defined by the memory ordering rules), will
243 /// notify the `register` caller's task and deregister the waker from future
244 /// notifications. Because of this, callers should ensure `register` gets
245 /// invoked with a new `Waker` **each** time they require a wakeup.
246 ///
247 /// It is safe to call `register` with multiple other threads concurrently
248 /// calling `wake`. This will result in the `register` caller's current
249 /// task being notified once. A concurrent `wake` may prevent `register`
250 /// to succeed, in which case it will return `false`. If despite the
251 /// concurrent `wake`, the wakeup condition is still not fulfilled, then
252 /// `Waker::wake` might be called to reschedule the task and give it
253 /// another opportunity to register is waker — this would be equivalent
254 /// to [`std::thread::yield_now`]. It is also possible to call `register`
255 /// in small [spin-loop](std::hint::spin_loop), before falling back to
256 /// calling `Waker::wake`.
257 ///
258 /// # Safety
259 ///
260 /// `register` and `unregister` methods must not be called concurrently
261 /// from multiple threads.
262 #[inline]
263 pub unsafe fn register(&self, waker: &Waker) -> bool {
264 #[cfg(all(debug_assertions, not(loom)))]
265 let _guard = self.exclusive.check();
266 // State is loaded and expected to be EMPTY. Otherwise, it means
267 // there already is a registered waker that needs to be overwritten.
268 let state = self.state.load(SeqCst);
269 // The case `CACHED && state == EMPTY | 1` is handled in `overwrite`.
270 if state == EMPTY {
271 // SAFETY: SeqCst protect against outdated read, and `register`
272 // cannot be called concurrently. It means that reading EMPTY
273 // ensures there cannot be any registered waker at this point.
274 // A concurrent `wake` will thus not attempt any read, so it's
275 // safe to access both cells mutably.
276 unsafe {
277 if !CACHED {
278 self.wakers[0].set(waker.clone());
279 } else if !self.wakers[0].will_wake(waker) {
280 return self.overwrite(waker, state);
281 }
282 }
283 // SYNC=true uses swap, as `wake` must synchronize with `register`
284 if SYNC {
285 self.state.swap(0, SeqCst);
286 } else {
287 self.state.store(0, SeqCst);
288 }
289 true
290 } else {
291 self.overwrite(waker, state)
292 }
293 }
294
295 // Overwriting a registered waker is expected to be rare, hence the `#[cold]` attribute.
296 #[cold]
297 fn overwrite(&self, waker: &Waker, state: usize) -> bool {
298 // A concurrent `wake` may be happening.
299 if (SYNC && state & WAKING != 0) || (!SYNC && state == WAKING) {
300 // A thread is currently waking the registered waker, so we can
301 // assume we should not wait and return immediately.
302 // If a waking thread is preempted before resetting the state,
303 // the task could loop infinitely on this state. This
304 // is caught by loom and requires `spin_loop` to escape the
305 // infinite loop. In practice, `spin_loop` or `Waker::wake`
306 // are already expected to be called in between.
307 #[cfg(loom)]
308 ::loom::hint::spin_loop();
309 return false;
310 }
311 // We voluntarily don't handle `state & EMPTY != 0` in `register` and
312 // only handle index 0 instead to avoid dependency on the state when
313 // computing `self.wakers[0].will_wake(&waker)`, allowing speculative
314 // execution.
315 if CACHED && state & EMPTY != 0 {
316 // SAFETY: same as in `register`
317 unsafe {
318 if state == EMPTY {
319 // State is `EMPTY | 0`, but the cached waker needs to be overwritten.
320 self.wakers[0].drop();
321 self.wakers[0].set(waker.clone());
322 } else if self.wakers[1].will_wake(waker) {
323 // If the cached waker at index 1 matches, it is moved to
324 // index 0 to optimize future `register`.
325 self.wakers[0].set(ManuallyDrop::into_inner(self.wakers[1].get()));
326 } else {
327 // Otherwise, overwrite the cached waker, writing the new
328 // one at index 0 to optimize future `register`.
329 self.wakers[1].drop();
330 self.wakers[0].set(waker.clone());
331 }
332 }
333 // same as in `register`
334 if SYNC {
335 self.state.swap(0, SeqCst);
336 } else {
337 self.state.store(0, SeqCst);
338 }
339 return true;
340 }
341 let cur_idx = state;
342 // SAFETY: state is not EMPTY nor WAKING, so it must be the cell index
343 // of a registered waker.
344 unsafe { assert_unchecked(cur_idx < 2) };
345 // If the new waker wakes the same task, there is no need to replace it.
346 // Crucially, no state update is needed even for `SYNC=true`: the `SeqCst`
347 // load at the top of `register` already participates in the total SeqCst
348 // order, so any release from a preceding `wake` is already visible to
349 // the caller — the synchronization guarantee is satisfied regardless.
350 // SAFETY: `overwrite` cannot be called concurrently, but `wake` could. However,
351 // both access the cell immutably, so it is safe.
352 if unsafe { self.wakers[cur_idx].will_wake(waker) } {
353 return true;
354 }
355 let new_idx = (cur_idx + 1) % 2;
356 // SAFETY: SeqCst protect against outdated read, and `overwrite` cannot be called
357 // concurrently. It means that `wake` can only access the cell at `cur_idx`, so
358 // the cell at `new_idx` is safe to access mutably.
359 unsafe { self.wakers[new_idx].set(waker.clone()) };
360 // The cell index is attempted to be swapped with the new one just initialized.
361 if let Err(state) = (self.state).compare_exchange(cur_idx, new_idx, SeqCst, SeqCst) {
362 // State update failed, which means a concurrent `wake` was happening.
363 // The registered waker should be dropped.
364 debug_assert!(state >= 2);
365 // SAFETY: state has not been updated, so `new_idx` cell is still safe
366 // to access, and the waker previously set can be taken back.
367 unsafe { ManuallyDrop::drop(&mut self.wakers[new_idx].get()) }
368 false
369 } else {
370 // SAFETY: cell index has been successfully swapped, so the cell
371 // at `cur_idx` is now safe to access to drop its waker.
372 unsafe { self.wakers[cur_idx].drop() };
373 true
374 }
375 }
376
377 /// Removes the registered waker if there is one, returning `true` in this case.
378 ///
379 /// It allows avoiding spurious wakeups when a waker has been registered,
380 /// but the wake condition is already met.
381 ///
382 /// # Safety
383 ///
384 /// `register` and `unregister` methods must not be called concurrently
385 /// from multiple threads.
386 #[inline]
387 pub unsafe fn unregister(&self) -> bool {
388 #[cfg(all(debug_assertions, not(loom)))]
389 let _guard = self.exclusive.check();
390 let state = self.state.load(Relaxed);
391 let Some(waker_cell) = self.wakers.get(state) else {
392 return false;
393 };
394 let empty = if CACHED { state | EMPTY } else { EMPTY };
395 // Relaxed order is ok here, as `unregister` and `register` are called in the same
396 // thread, i.e. sequenced-before, so there is no risk that this CAS make possible a
397 // stale load of an empty state instead of inhabited state. It may provoke a stale load
398 // of inhabited state while empty, but wake deals with it.
399 let res = self.state.compare_exchange(state, empty, Relaxed, Relaxed);
400 match res {
401 // SAFETY: state has been swapped to EMPTY, so the cell can
402 // no longer be accessed by `wake`, and its waker can be taken
403 Ok(_) if !CACHED => unsafe { waker_cell.drop() },
404 Ok(_) => {}
405 Err(s) => debug_assert!(s >= 2),
406 }
407 res.is_ok()
408 }
409
410 /// Returns `true` if a waker is currently registered.
411 ///
412 /// This provides a best-effort snapshot: a concurrent [`wake`] call may
413 /// consume the waker right after this returns `true`, and a concurrent
414 /// [`register`] call may store one right after this returns `false`.
415 ///
416 /// Calling `has_waker_registered` then `wake` if it is returned `true`
417 /// is guaranteed to provide the same synchronization as calling `wake`
418 /// alone.
419 ///
420 /// [`register`]: Self::register
421 /// [`wake`]: Self::wake
422 #[inline]
423 pub fn has_waker_registered(&self) -> bool {
424 if SYNC {
425 // See `check_before_wake` about `fetch_add(0)`
426 self.state.load(Relaxed) < 2 || self.state.fetch_add(0, SeqCst) < 2
427 } else {
428 self.state.load(SeqCst) < 2
429 }
430 }
431
432 /// Calls `wake` on the last `Waker` passed to `register`.
433 ///
434 /// If `register` has not been called yet, then this does nothing.
435 #[inline]
436 pub fn wake(&self) {
437 self.check_before_wake(false, Self::wake_waker);
438 }
439
440 /// Same as [`wake`](Self::wake), but with the waking path marked `#[cold]`.
441 ///
442 /// This allows the method to inline more effectively. Prefer this over
443 /// `wake` when waking is the uncommon case.
444 #[inline]
445 pub fn wake_cold(&self) {
446 self.check_before_wake(true, Self::wake_waker);
447 }
448
449 fn wake_waker(waker: Option<ManuallyDrop<Waker>>) {
450 match waker {
451 Some(w) if CACHED => w.wake_by_ref(),
452 Some(w) if !CACHED => ManuallyDrop::into_inner(w).wake(),
453 _ => {}
454 }
455 }
456
457 #[inline(always)]
458 fn check_before_wake<R>(
459 &self,
460 cold: bool,
461 wake: impl FnOnce(Option<ManuallyDrop<Waker>>) -> R,
462 ) -> R {
463 if SYNC {
464 if cold {
465 // SYNC=true requires a Release write on the state, but we don't want to set
466 // the WAKING bit if there is no waker, as it would require unsetting it.
467 // So we attempt a `fetch_add(0)` and hope for no concurrent `register`.
468 if self.state.load(Relaxed) >= 2 && self.state.fetch_add(0, SeqCst) >= 2 {
469 return wake(None);
470 }
471 self.wake_sync_cold(wake)
472 } else {
473 self.wake_sync(wake)
474 }
475 } else {
476 // Load the state to check if there is a registered waker.
477 let state = self.state.load(SeqCst);
478 if state >= 2 {
479 wake(None)
480 } else if cold {
481 self.wake_unsync_cold(state, wake)
482 } else {
483 self.wake_unsync(state, wake)
484 }
485 }
486 }
487
488 fn wake_sync<R>(&self, wake: impl FnOnce(Option<ManuallyDrop<Waker>>) -> R) -> R {
489 // There might be a waker registered, set the WAKING bit.
490 let state = self.state.fetch_or(WAKING, SeqCst);
491 // A concurrent `wake` has won the race, just return.
492 if state & WAKING != 0 {
493 return wake(None);
494 }
495 if let Some(waker_cell) = self.wakers.get(state) {
496 // SAFETY: the state is locked on WAKING, the cell can be concurrently
497 // accessed with `will_wake`, but it can still be accessed immutably.
498 // The waker is taken before resetting the state.
499 let waker = unsafe { waker_cell.get() };
500 // At this point the only concurrent operation will be:
501 // - fetch_add(0), no issue
502 // - fetch_or(WAKING), another `wake` is losing the race
503 // - CAS(new_idx, cur_idx), will fail because of WAKING flag
504 // The state can thus be swapped to EMPTY without issue.
505 // It could be tempting to use a store instead, but it would not
506 // work as it might overwrite a potential fetch_or and prevent
507 // the synchronization of a racing wake with the next register.
508 let empty = if CACHED { state | EMPTY } else { EMPTY };
509 self.state.swap(empty, SeqCst);
510 wake(Some(waker))
511 } else {
512 // Too bad, no waker was registered. It means that a concurrent `register`
513 // might be concurrently storing a waker in cell 0 and swap the state with
514 // EMPTY. We still need to unset the WAKING flag, but we don't care if it
515 // fails, as it would mean the flag has been unset anyway.
516 // It is theoretically possible that WAKING flag has been already unset and
517 // that another thread has already set it back. In this case, either the
518 // state was not EMPTY and this CAS will fail, or the state was EMPTY and
519 // the other thread doesn't care as much as us about its CAS succeeding.
520 debug_assert!((CACHED && state & EMPTY != 0) || (!CACHED && state == EMPTY));
521 let _ = (self.state).compare_exchange(state | WAKING, state, SeqCst, Relaxed);
522 wake(None)
523 }
524 }
525
526 #[cold]
527 #[inline(never)]
528 fn wake_sync_cold<R>(&self, wake: impl FnOnce(Option<ManuallyDrop<Waker>>) -> R) -> R {
529 self.wake_sync(wake)
530 }
531
532 fn wake_unsync<R>(
533 &self,
534 state: usize,
535 wake: impl FnOnce(Option<ManuallyDrop<Waker>>) -> R,
536 ) -> R {
537 unsafe { assert_unchecked(state < 2) };
538 // Try swapping the state with WAKING. If it fails, it means either:
539 // - a concurrent `wake` has won the race, so we can return
540 // - the waker was overwritten, so the registering thread is supposed
541 // to check again its wakeup condition, so we can just return
542 if (self.state.compare_exchange(state, WAKING, SeqCst, Relaxed)).is_err() {
543 return wake(None);
544 };
545 // SAFETY: the state has been swapped, so a concurrent `overwrite` CAS
546 // will fail, and it is safe to access the cell to take its waker
547 let waker = unsafe { self.wakers[state].get() };
548 // The state can be reset to EMPTY with a simple store.
549 // (loom doesn't support SeqCst and uses RMW operation instead)
550 let empty = if CACHED { state | EMPTY } else { EMPTY };
551 self.state.store(empty, SeqCst);
552 wake(Some(waker))
553 }
554
555 #[cold]
556 #[inline(never)]
557 fn wake_unsync_cold<R>(
558 &self,
559 state: usize,
560 wake: impl FnOnce(Option<ManuallyDrop<Waker>>) -> R,
561 ) -> R {
562 self.wake_unsync(state, wake)
563 }
564}
565
566impl<const SYNC: bool> SpmcWaker<SYNC, false> {
567 /// Returns the last `Waker` passed to `register`, so that the caller can wake it.
568 ///
569 /// Sometimes, just waking the `SpmcWaker` is not fine-grained enough. This allows the caller
570 /// to take the waker and then wake it separately, rather than performing both steps in one
571 /// atomic action.
572 ///
573 /// If a waker has not been registered, this returns `None`.
574 pub fn take(&self) -> Option<Waker> {
575 self.check_before_wake(false, Self::take_waker)
576 }
577
578 /// Same as [`take`](Self::take), but with the taking path marked `#[cold]`.
579 ///
580 /// This allows the method to inline more effectively. Prefer this over
581 /// `take` when taking is the uncommon case.
582 #[inline]
583 pub fn take_cold(&self) -> Option<Waker> {
584 self.check_before_wake(true, Self::take_waker)
585 }
586
587 fn take_waker(waker: Option<ManuallyDrop<Waker>>) -> Option<Waker> {
588 waker.map(ManuallyDrop::into_inner)
589 }
590}
591
592impl<const SYNC: bool, const CACHED: bool> Default for SpmcWaker<SYNC, CACHED> {
593 fn default() -> Self {
594 Self::new()
595 }
596}