Skip to main content

spmc_waker/
lib.rs

1//! A synchronization primitive for task wakeup.
2//!
3//! This crate provides [`SpmcWaker`], a single-producer, multiple-consumer (SPMC)
4//! atomic waker.
5//!
6//! # Features
7//!
8//! - `portable-atomic`: use `portable-atomic` crate to provide functionality to
9//!   targets without atomics.
10#![no_std]
11#[cfg(doc)]
12extern crate std;
13use core::{hint::assert_unchecked, mem::ManuallyDrop, task::Waker};
14
15use crate::{
16    loom::{
17        AtomicUsizeExt,
18        sync::atomic::{
19            AtomicUsize,
20            Ordering::{Relaxed, SeqCst},
21        },
22    },
23    waker_cell::WakerCell,
24};
25
26#[cfg(all(debug_assertions, not(loom)))]
27mod exclusive;
28mod loom;
29mod waker_cell;
30
/// State value meaning that no waker is registered. With `CACHED=true` it is
/// used as a bit flag, the state's LSB giving the index of the cached waker.
31const EMPTY: usize = 2;
/// State value meaning that a `wake` operation is ongoing. With `SYNC=true` it
/// is used as a bit flag OR-ed into the current state.
32const WAKING: usize = 4;
33
34/// A synchronization primitive for task wakeup.
35///
36/// Sometimes the task interested in a given event will change over time.
37/// A `SpmcWaker` can coordinate concurrent notifications with the consumer
38/// potentially "updating" the underlying task to wake up. This is useful in
39/// scenarios where a computation completes in another thread and wants to
40/// notify the consumer, but the consumer is in the process of being migrated to
41/// a new logical task.
42///
43/// Consumers should call `register` before checking the result of a computation
44/// and producers should call `wake` after producing the computation (this
45/// differs from the usual `thread::park` pattern). It is also permitted for
46/// `wake` to be called **before** `register`. This results in a no-op.
47///
48/// A single `SpmcWaker` may be reused for any number of calls to `register` or
49/// `wake`.
50///
51/// # Single-producer, multiple-consumer (SPMC)
52///
53/// The `SpmcWaker` algorithm assumes that at most one thread calls `register`/`unregister`
54/// at a time. This is enforced by the methods' safety conditions.
55///
56/// This assumption allows significant optimizations compared to an MPMC algorithm
57/// like [`AtomicWaker`].
58///
59/// # Memory ordering
60///
61/// `SpmcWaker` atomic operations use `SeqCst` ordering, and it has a generic
62/// `SYNC` parameter which determines the synchronization guarantees.
63///
64/// ### `SYNC=false` (the default)
65///
66/// There is no acquire-release synchronization between `register` and `wake`.
67///
68/// Because a `wake` call may not see the waker registered by a concurrent
69/// `register`, the waking condition should use a total order, i.e. `SeqCst`
70/// or RMW operations. It ensures that checking the waking condition after
71/// `register` succeeds even when a concurrent `wake` misses the registered
72/// waker.
73///
74/// When no waker is registered, `wake` is reduced to a single atomic load.
75///
76/// ### `SYNC=true`
77///
78/// Calling `register` "acquires" all memory "released" by calls to `wake`
79/// before the call to `register`.
80///
81/// It allows setting the waking condition and checking it with a relaxed
82/// ordering after the registration, at the cost of having a mandatory
83/// atomic RMW operation in `wake`.
84///
85/// If the waking condition is already set through an atomic RMW operation,
86/// adding `SeqCst` ordering to it and to the waking condition check
87/// comes at a minimal cost, and saves an atomic RMW operation
88/// in `wake` by switching to `SYNC=false`. As a matter of fact, `SYNC=true`
89/// should only be considered when the waking condition has no RMW involved.
90///
91/// # Waker caching
92///
93/// Most of the time, `SpmcWaker` is used in a single task, so the waker
94/// registered is always the same. That's why it provides a second generic
95/// parameter `CACHED`.
96///
97/// ### `CACHED=true` (the default)
98///
99/// The last waker registered is kept cached to avoid cloning it at the next
100/// registration. As a consequence, waking is done with [`Waker::wake_by_ref`].
101/// As wakers are often `Arc`s, caching avoids atomic RMW operations updating
102/// the reference counter.
103///
104/// ### `CACHED=false`
105///
106/// Waker is cloned when registered by reference, and the tasks are woken with
107/// [`Waker::wake`].
108///
109/// # Examples
110///
111/// Here is a simple example providing a `Flag` that can be signaled manually
112/// when it is ready.
113///
114/// ```rust
115/// use std::{
116///     pin::Pin,
117///     sync::{
118///         Arc,
119///         atomic::{
120///             AtomicBool,
121///             Ordering::{Relaxed, SeqCst},
122///         },
123///     },
124///     task::{Context, Poll},
125/// };
126///
127/// use spmc_waker::SpmcWaker;
128///
129/// #[derive(Default)]
130/// struct Inner {
131///     notified: AtomicBool,
132///     waker: SpmcWaker,
133/// }
134///
135/// #[derive(Clone)]
136/// struct Notifier(Arc<Inner>);
137///
138/// impl Notifier {
139///     pub fn new() -> Self {
140///         Self(Arc::new(Inner {
141///             waker: SpmcWaker::new(),
142///             notified: AtomicBool::new(false),
143///         }))
144///     }
145///
146///     pub fn signal(&self) {
147///         // Use seqcst ordering to synchronize with the load after `register`
148///         self.0.notified.store(true, SeqCst);
149///         self.0.waker.wake();
150///     }
151/// }
152///
153/// #[derive(Default)]
154/// struct Waiter(Arc<Inner>);
155///
156/// impl Waiter {
157///     fn notifier(&self) -> Notifier {
158///         Notifier(self.0.clone())
159///     }
160/// }
161///
162/// impl Future for Waiter {
163///     type Output = ();
164///
165///     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
166///         // quick check to avoid registration if already done.
167///         if self.0.notified.load(Relaxed) {
168///             return Poll::Ready(());
169///         }
170///
171///         // SAFETY: mutable reference on non-cloneable `Waiter` ensures no concurrent call
172///         unsafe { self.0.waker.register(cx.waker()) };
173///
174///         // Need to check condition **after** `register` to avoid a race
175///         // condition that would result in lost notifications.
176///         // Use seqcst ordering so it synchronizes with the store before wake.
177///         if self.0.notified.load(SeqCst) {
178///             // Unregister the waker to avoid spurious wakeups.
179///             // SAFETY: mutable reference on non-cloneable `Waiter` ensures no concurrent call
180///             unsafe { self.0.waker.unregister() };
181///             Poll::Ready(())
182///         } else {
183///             Poll::Pending
184///         }
185///     }
186/// }
187///
188/// fn event() -> (Notifier, Waiter) {
189///     let waiter = Waiter::default();
190///     (waiter.notifier(), waiter)
191/// }
192/// ```
193///
194/// [`AtomicWaker`]: https://docs.rs/futures/latest/futures/task/struct.AtomicWaker.html
195#[derive(Debug)]
196pub struct SpmcWaker<const SYNC: bool = false, const CACHED: bool = true> {
    /// The two waker cells. `state` tells which one, if any, currently holds
    /// the registered waker (or the cached one with CACHED=true); the other
    /// cell is used as scratch space when a registered waker is overwritten.
197    wakers: [WakerCell; 2],
198    /// State possible values are:
199    /// - 0 or 1: A waker is registered in `wakers[state]`
200    /// - EMPTY: there is no waker registered;
201    ///   with CACHED=true, it becomes a bit-flag and the state's LSB gives
202    ///   the cached waker index (cells are initialized with dummy wakers)
203    /// - WAKING: a `wake` operation is ongoing;
204    ///   with SYNC=true, it becomes a bit-flag
205    state: AtomicUsize,
    /// Debug-only guard checked at the start of `register`/`unregister`.
    /// NOTE(review): presumably detects concurrent calls that violate the
    /// methods' safety contract; confirm in the `exclusive` module.
206    #[cfg(all(debug_assertions, not(loom)))]
207    exclusive: exclusive::Exclusive,
208}
209
// SAFETY: NOTE(review): the rationale is not spelled out in this file. It
// appears to rely on (1) every access to the waker cells being arbitrated by
// the atomic `state` (see the SAFETY comments in `register`, `overwrite` and
// the `wake_*` methods), (2) `register`/`unregister` being `unsafe` with an
// exclusivity contract, and (3) `Waker` itself being `Send + Sync`.
// Confirm against `WakerCell`'s definition.
210unsafe impl<const SYNC: bool, const CACHED: bool> Send for SpmcWaker<SYNC, CACHED> {}
211unsafe impl<const SYNC: bool, const CACHED: bool> Sync for SpmcWaker<SYNC, CACHED> {}
212
213impl<const SYNC: bool, const CACHED: bool> Drop for SpmcWaker<SYNC, CACHED> {
214    #[inline]
215    fn drop(&mut self) {
216        let state = self.state.load_mut();
217        if CACHED || state < 2 {
218            // SAFETY: state is the index of a waker currently registered
219            // that must be taken back, and access is safe in destructor
220            unsafe { self.wakers[state % 2].drop() };
221        }
222    }
223}
224
225impl<const SYNC: bool, const CACHED: bool> SpmcWaker<SYNC, CACHED> {
226    /// Creates a new `SpmcWaker` with no waker registered (the state starts as `EMPTY`).
    // NOTE(review): under `cfg(loom)` the `const_fn` attribute presumably drops
    // the `const` qualifier because loom's atomics cannot be built in a const
    // context; confirm against the `const_fn` crate documentation.
227    #[cfg_attr(loom, const_fn::const_fn(cfg(false)))]
228    #[inline]
229    pub const fn new() -> Self {
230        Self {
231            wakers: [WakerCell::new(), WakerCell::new()],
232            state: AtomicUsize::new(EMPTY),
233            #[cfg(all(debug_assertions, not(loom)))]
234            exclusive: exclusive::Exclusive::new(),
235        }
236    }
237
238    /// Registers the waker to be notified on calls to `wake`.
239    ///
240    /// The new task will take the place of any previous tasks that were registered
241    /// by previous calls to `register`. Any calls to `wake` that happen after
242    /// a call to `register` (as defined by the memory ordering rules), will
243    /// notify the `register` caller's task and deregister the waker from future
244    /// notifications. Because of this, callers should ensure `register` gets
245    /// invoked with a new `Waker` **each** time they require a wakeup.
246    ///
247    /// It is safe to call `register` with multiple other threads concurrently
248    /// calling `wake`. This will result in the `register` caller's current
249    /// task being notified once. A concurrent `wake` may prevent `register`
250    /// from succeeding, in which case it will return `false`. If despite the
251    /// concurrent `wake`, the wakeup condition is still not fulfilled, then
252    /// `Waker::wake` might be called to reschedule the task and give it
253    /// another opportunity to register its waker — this would be equivalent
254    /// to [`std::thread::yield_now`]. It is also possible to call `register`
255    /// in a small [spin-loop](std::hint::spin_loop), before falling back to
256    /// calling `Waker::wake`.
257    ///
    /// Returns `true` when the waker is registered on return, and `false`
    /// when a concurrent `wake` prevented the registration.
    ///
258    /// # Safety
259    ///
260    /// `register` and `unregister` methods must not be called concurrently
261    /// from multiple threads.
262    #[inline]
263    pub unsafe fn register(&self, waker: &Waker) -> bool {
264        #[cfg(all(debug_assertions, not(loom)))]
265        let _guard = self.exclusive.check();
266        // State is loaded and expected to be EMPTY. Otherwise, it means
267        // there already is a registered waker that needs to be overwritten.
268        let state = self.state.load(SeqCst);
269        // The case `CACHED && state == EMPTY | 1` is handled in `overwrite`.
270        if state == EMPTY {
271            // SAFETY: SeqCst protects against outdated reads, and `register`
272            // cannot be called concurrently. It means that reading EMPTY
273            // ensures there cannot be any registered waker at this point.
274            // A concurrent `wake` will thus not attempt any read, so it's
275            // safe to access both cells mutably.
276            unsafe {
277                if !CACHED {
278                    self.wakers[0].set(waker.clone());
279                } else if !self.wakers[0].will_wake(waker) {
280                    return self.overwrite(waker, state);
281                }
282            }
283            // SYNC=true uses swap, as `wake` must synchronize with `register`
284            if SYNC {
285                self.state.swap(0, SeqCst);
286            } else {
287                self.state.store(0, SeqCst);
288            }
289            true
290        } else {
291            self.overwrite(waker, state)
292        }
293    }
294
    // Slow path of `register`, taking the `state` that `register` loaded.
    // It is reached when that state is not EMPTY, or when CACHED=true and the
    // waker cached in cell 0 does not match `waker`.
    // Returns `true` when `waker` (or a waker waking the same task) is
    // registered on return, `false` when a concurrent `wake` got in the way.
295    // Overwriting a registered waker is expected to be rare, hence the `#[cold]` attribute.
296    #[cold]
297    fn overwrite(&self, waker: &Waker, state: usize) -> bool {
298        // A concurrent `wake` may be happening.
299        if (SYNC && state & WAKING != 0) || (!SYNC && state == WAKING) {
300            // A thread is currently waking the registered waker, so we can
301            // assume we should not wait and return immediately.
302            // If a waking thread is preempted before resetting the state,
303            // the task could loop infinitely on this state. This
304            // is caught by loom and requires `spin_loop` to escape the
305            // infinite loop. In practice, `spin_loop` or `Waker::wake`
306            // are already expected to be called in between.
307            #[cfg(loom)]
308            ::loom::hint::spin_loop();
309            return false;
310        }
311        // We voluntarily don't handle `state & EMPTY != 0` in `register` and
312        // only handle index 0 instead to avoid dependency on the state when
313        // computing `self.wakers[0].will_wake(&waker)`, allowing speculative
314        // execution.
315        if CACHED && state & EMPTY != 0 {
316            // SAFETY: same as in `register`
317            unsafe {
318                if state == EMPTY {
319                    // State is `EMPTY | 0`, but the cached waker needs to be overwritten.
320                    self.wakers[0].drop();
321                    self.wakers[0].set(waker.clone());
322                } else if self.wakers[1].will_wake(waker) {
323                    // If the cached waker at index 1 matches, it is moved to
324                    // index 0 to optimize future `register`.
325                    self.wakers[0].set(ManuallyDrop::into_inner(self.wakers[1].get()));
326                } else {
327                    // Otherwise, overwrite the cached waker, writing the new
328                    // one at index 0 to optimize future `register`.
329                    self.wakers[1].drop();
330                    self.wakers[0].set(waker.clone());
331                }
332            }
333            // same as in `register`
334            if SYNC {
335                self.state.swap(0, SeqCst);
336            } else {
337                self.state.store(0, SeqCst);
338            }
339            return true;
340        }
341        let cur_idx = state;
342        // SAFETY: state is not EMPTY nor WAKING, so it must be the cell index
343        // of a registered waker.
344        unsafe { assert_unchecked(cur_idx < 2) };
345        // If the new waker wakes the same task, there is no need to replace it.
346        // Crucially, no state update is needed even for `SYNC=true`: the `SeqCst`
347        // load at the top of `register` already participates in the total SeqCst
348        // order, so any release from a preceding `wake` is already visible to
349        // the caller — the synchronization guarantee is satisfied regardless.
        // NOTE(review): the early return also appears safe from lost wakeups
        // because a `wake` ordered after that load observes `cur_idx` through
        // its RMW and wakes this very task; worth confirming with the loom tests.
350        // SAFETY: `overwrite` cannot be called concurrently, but `wake` could. However,
351        // both access the cell immutably, so it is safe.
352        if unsafe { self.wakers[cur_idx].will_wake(waker) } {
353            return true;
354        }
355        let new_idx = (cur_idx + 1) % 2;
356        // SAFETY: SeqCst protects against outdated reads, and `overwrite` cannot be called
357        // concurrently. It means that `wake` can only access the cell at `cur_idx`, so
358        // the cell at `new_idx` is safe to access mutably.
359        unsafe { self.wakers[new_idx].set(waker.clone()) };
360        // The cell index is attempted to be swapped with the new one just initialized.
361        if let Err(state) = (self.state).compare_exchange(cur_idx, new_idx, SeqCst, SeqCst) {
362            // State update failed, which means a concurrent `wake` was happening.
363            // The waker just written to the `new_idx` cell should be dropped.
364            debug_assert!(state >= 2);
365            // SAFETY: state has not been updated, so `new_idx` cell is still safe
366            // to access, and the waker previously set can be taken back.
367            unsafe { ManuallyDrop::drop(&mut self.wakers[new_idx].get()) }
368            false
369        } else {
370            // SAFETY: cell index has been successfully swapped, so the cell
371            // at `cur_idx` is now safe to access to drop its waker.
372            unsafe { self.wakers[cur_idx].drop() };
373            true
374        }
375    }
376
377    /// Removes the registered waker if there is one, returning `true` in this case.
378    ///
379    /// It allows avoiding spurious wakeups when a waker has been registered,
380    /// but the wake condition is already met.
381    ///
    /// Returns `false` when no waker is registered, or when a concurrent
    /// `wake` claims the waker first.
    ///
382    /// # Safety
383    ///
384    /// `register` and `unregister` methods must not be called concurrently
385    /// from multiple threads.
386    #[inline]
387    pub unsafe fn unregister(&self) -> bool {
388        #[cfg(all(debug_assertions, not(loom)))]
389        let _guard = self.exclusive.check();
390        let state = self.state.load(Relaxed);
        // `get` only succeeds for a cell index (0 or 1), i.e. when a waker is
        // registered; EMPTY/WAKING states are out of bounds and mean "nothing to do".
391        let Some(waker_cell) = self.wakers.get(state) else {
392            return false;
393        };
        // With CACHED=true the waker stays in its cell and the index is kept in the LSB.
394        let empty = if CACHED { state | EMPTY } else { EMPTY };
395        // Relaxed order is ok here, as `unregister` and `register` are called in the same
396        // thread, i.e. sequenced-before, so there is no risk that this CAS makes possible a
397        // stale load of an empty state instead of an inhabited state. It may provoke a stale
398        // load of an inhabited state while empty, but `wake` deals with it.
399        let res = self.state.compare_exchange(state, empty, Relaxed, Relaxed);
400        match res {
401            // SAFETY: state has been swapped to EMPTY, so the cell can
402            // no longer be accessed by `wake`, and its waker can be taken
403            Ok(_) if !CACHED => unsafe { waker_cell.drop() },
404            Ok(_) => {}
            // The CAS can only fail because a concurrent `wake` changed the state.
405            Err(s) => debug_assert!(s >= 2),
406        }
407        res.is_ok()
408    }
409
410    /// Returns `true` if a waker is currently registered.
411    ///
412    /// This provides a best-effort snapshot: a concurrent [`wake`] call may
413    /// consume the waker right after this returns `true`, and a concurrent
414    /// [`register`] call may store one right after this returns `false`.
415    ///
416    /// Calling `has_waker_registered` then `wake` if it is returned `true`
417    /// is guaranteed to provide the same synchronization as calling `wake`
418    /// alone.
419    ///
420    /// [`register`]: Self::register
421    /// [`wake`]: Self::wake
422    #[inline]
423    pub fn has_waker_registered(&self) -> bool {
424        if SYNC {
425            // See `check_before_wake` about `fetch_add(0)`
426            self.state.load(Relaxed) < 2 || self.state.fetch_add(0, SeqCst) < 2
427        } else {
428            self.state.load(SeqCst) < 2
429        }
430    }
431
432    /// Calls `wake` on the last `Waker` passed to `register`
    /// (`wake_by_ref` when `CACHED=true`).
433    ///
434    /// If no waker is currently registered, e.g. because `register` has not
    /// been called yet, then this does nothing.
435    #[inline]
436    pub fn wake(&self) {
437        self.check_before_wake(false, Self::wake_waker);
438    }
439
440    /// Same as [`wake`](Self::wake), but with the waking path marked `#[cold]`.
441    ///
442    /// Only the check for a registered waker stays on the inlined path, which
    /// allows the method to inline more effectively. Prefer this over
443    /// `wake` when waking is the uncommon case.
444    #[inline]
445    pub fn wake_cold(&self) {
446        self.check_before_wake(true, Self::wake_waker);
447    }
448
449    fn wake_waker(waker: Option<ManuallyDrop<Waker>>) {
450        match waker {
451            Some(w) if CACHED => w.wake_by_ref(),
452            Some(w) if !CACHED => ManuallyDrop::into_inner(w).wake(),
453            _ => {}
454        }
455    }
456
    // Shared front-end of `wake*`/`take*`: checks whether a waker may be
    // registered and dispatches to the SYNC or non-SYNC waking routine, using
    // its out-of-line `#[cold]` twin when `cold` is set.
    // `wake` is the continuation: it receives `Some(waker)` when a waker was
    // claimed and `None` otherwise, and its result is returned as-is.
457    #[inline(always)]
458    fn check_before_wake<R>(
459        &self,
460        cold: bool,
461        wake: impl FnOnce(Option<ManuallyDrop<Waker>>) -> R,
462    ) -> R {
463        if SYNC {
464            if cold {
465                // SYNC=true requires a Release write on the state, but we don't want to set
466                // the WAKING bit if there is no waker, as it would require unsetting it.
467                // So we attempt a `fetch_add(0)` and hope for no concurrent `register`.
468                if self.state.load(Relaxed) >= 2 && self.state.fetch_add(0, SeqCst) >= 2 {
469                    return wake(None);
470                }
471                self.wake_sync_cold(wake)
472            } else {
                // The non-cold path goes straight to `wake_sync`, whose `fetch_or`
                // is the RMW operation on the state.
473                self.wake_sync(wake)
474            }
475        } else {
476            // Load the state to check if there is a registered waker.
477            let state = self.state.load(SeqCst);
478            if state >= 2 {
479                wake(None)
480            } else if cold {
481                self.wake_unsync_cold(state, wake)
482            } else {
483                self.wake_unsync(state, wake)
484            }
485        }
486    }
487
    // Waking routine for SYNC=true: sets the WAKING bit, and if that reveals
    // a registered waker, claims it, resets the state to empty and passes it
    // to `wake`. In every other case `wake` is called with `None`.
488    fn wake_sync<R>(&self, wake: impl FnOnce(Option<ManuallyDrop<Waker>>) -> R) -> R {
489        // There might be a waker registered, set the WAKING bit.
490        let state = self.state.fetch_or(WAKING, SeqCst);
491        // A concurrent `wake` has won the race, just return.
492        if state & WAKING != 0 {
493            return wake(None);
494        }
495        if let Some(waker_cell) = self.wakers.get(state) {
496            // SAFETY: the state is locked on WAKING, the cell can be concurrently
497            // accessed with `will_wake`, but it can still be accessed immutably.
498            // The waker is taken before resetting the state.
499            let waker = unsafe { waker_cell.get() };
500            // At this point the only concurrent operation will be:
501            // - fetch_add(0), no issue
502            // - fetch_or(WAKING), another `wake` is losing the race
503            // - `overwrite`'s CAS(cur_idx, new_idx), will fail because of WAKING flag
504            // The state can thus be swapped to EMPTY without issue.
505            // It could be tempting to use a store instead, but it would not
506            // work as it might overwrite a potential fetch_or and prevent
507            // the synchronization of a racing wake with the next register.
508            let empty = if CACHED { state | EMPTY } else { EMPTY };
509            self.state.swap(empty, SeqCst);
510            wake(Some(waker))
511        } else {
512            // Too bad, no waker was registered. It means that a concurrent `register`
513            // might be concurrently storing a waker in cell 0 and then swapping the
514            // state to 0. We still need to unset the WAKING flag, but we don't care if it
515            // fails, as it would mean the flag has been unset anyway.
516            // It is theoretically possible that WAKING flag has been already unset and
517            // that another thread has already set it back. In this case, either the
518            // state was not EMPTY and this CAS will fail, or the state was EMPTY and
519            // the other thread doesn't care as much as us about its CAS succeeding.
520            debug_assert!((CACHED && state & EMPTY != 0) || (!CACHED && state == EMPTY));
521            let _ = (self.state).compare_exchange(state | WAKING, state, SeqCst, Relaxed);
522            wake(None)
523        }
524    }
525
    // Out-of-line, `#[cold]` twin of `wake_sync`, used when `check_before_wake`
    // is called with `cold = true`.
526    #[cold]
527    #[inline(never)]
528    fn wake_sync_cold<R>(&self, wake: impl FnOnce(Option<ManuallyDrop<Waker>>) -> R) -> R {
529        self.wake_sync(wake)
530    }
531
    // Waking routine for SYNC=false. `state` is the cell index (0 or 1) that
    // `check_before_wake` just loaded. The waker is claimed by swapping the
    // state to WAKING, then the state is reset to empty and the waker is
    // passed to `wake`. If the claim fails, `wake` is called with `None`.
532    fn wake_unsync<R>(
533        &self,
534        state: usize,
535        wake: impl FnOnce(Option<ManuallyDrop<Waker>>) -> R,
536    ) -> R {
        // SAFETY: `check_before_wake` only dispatches here (directly or through
        // `wake_unsync_cold`) after loading a state lower than 2.
537        unsafe { assert_unchecked(state < 2) };
538        // Try swapping the state with WAKING. If it fails, it means either:
539        // - a concurrent `wake` has won the race, so we can return
540        // - the waker was overwritten, so the registering thread is supposed
541        //   to check again its wakeup condition, so we can just return
542        if (self.state.compare_exchange(state, WAKING, SeqCst, Relaxed)).is_err() {
543            return wake(None);
544        };
545        // SAFETY: the state has been swapped, so a concurrent `overwrite` CAS
546        // will fail, and it is safe to access the cell to take its waker
547        let waker = unsafe { self.wakers[state].get() };
548        // The state can be reset to EMPTY with a simple store.
549        // (loom doesn't support SeqCst and uses RMW operation instead)
550        let empty = if CACHED { state | EMPTY } else { EMPTY };
551        self.state.store(empty, SeqCst);
552        wake(Some(waker))
553    }
554
    // Out-of-line, `#[cold]` twin of `wake_unsync`, used when `check_before_wake`
    // is called with `cold = true`; `state` is forwarded unchanged.
555    #[cold]
556    #[inline(never)]
557    fn wake_unsync_cold<R>(
558        &self,
559        state: usize,
560        wake: impl FnOnce(Option<ManuallyDrop<Waker>>) -> R,
561    ) -> R {
562        self.wake_unsync(state, wake)
563    }
564}
565
566impl<const SYNC: bool> SpmcWaker<SYNC, false> {
567    /// Returns the last `Waker` passed to `register`, so that the caller can wake it.
568    ///
569    /// Sometimes, just waking the `SpmcWaker` is not fine-grained enough. This allows the caller
570    /// to take the waker and then wake it separately, rather than performing both steps in one
571    /// atomic action.
572    ///
573    /// If a waker has not been registered, this returns `None`.
574    pub fn take(&self) -> Option<Waker> {
575        self.check_before_wake(false, Self::take_waker)
576    }
577
578    /// Same as [`take`](Self::take), but with the taking path marked `#[cold]`.
579    ///
580    /// This allows the method to inline more effectively. Prefer this over
581    /// `take` when taking is the uncommon case.
582    #[inline]
583    pub fn take_cold(&self) -> Option<Waker> {
584        self.check_before_wake(true, Self::take_waker)
585    }
586
587    fn take_waker(waker: Option<ManuallyDrop<Waker>>) -> Option<Waker> {
588        waker.map(ManuallyDrop::into_inner)
589    }
590}
591
592impl<const SYNC: bool, const CACHED: bool> Default for SpmcWaker<SYNC, CACHED> {
    /// Creates a `SpmcWaker` with no waker registered; equivalent to `SpmcWaker::new`.
593    fn default() -> Self {
594        Self::new()
595    }
596}