cooked_waker/
lib.rs

1#![no_std]
2
3//! cooked_waker provides safe traits for working with
4//! [`std::task::Waker`][Waker] and creating those wakers out of regular, safe
5//! Rust structs. It cooks `RawWaker` and `RawWakerVTable`, making them safe
6//! for consumption.
7//!
8//! It provides the [`Wake`] and [`WakeRef`] traits, which correspond to the
9//! [`wake`][Waker::wake] and [`wake_by_ref`][Waker::wake_by_ref] methods
//! on [`std::task::Waker`][Waker], and it provides implementations of these
//! traits for the common reference & pointer types (`Arc`, `Rc`, `&'static`,
12//! etc).
13//!
//! Additionally, it provides [`IntoWaker`], which allows converting any
//! `Wake + Clone` type into a [`Waker`]. This trait is automatically
//! implemented for any `Wake + Clone + Send + Sync + 'static + ViaRawPointer`
//! type whose pointer target is `Sized`.
17//!
18//! # Basic example
19//!
20//! ```
21//! use cooked_waker::{Wake, WakeRef, IntoWaker, ViaRawPointer};
22//! use std::sync::atomic::{AtomicUsize, Ordering};
23//! use std::task::Waker;
24//!
25//! static wake_ref_count: AtomicUsize = AtomicUsize::new(0);
26//! static wake_value_count: AtomicUsize = AtomicUsize::new(0);
27//! static drop_count: AtomicUsize = AtomicUsize::new(0);
28//!
29//! // A simple Waker struct that atomically increments the relevant static
30//! // counters.
31//! #[derive(Debug, Clone)]
32//! struct StaticWaker;
33//!
34//! impl WakeRef for StaticWaker {
35//!     fn wake_by_ref(&self) {
36//!         wake_ref_count.fetch_add(1, Ordering::SeqCst);
37//!     }
38//! }
39//!
40//! impl Wake for StaticWaker {
41//!     fn wake(self) {
42//!         wake_value_count.fetch_add(1, Ordering::SeqCst);
43//!     }
44//! }
45//!
46//! impl Drop for StaticWaker {
47//!     fn drop(&mut self) {
48//!         drop_count.fetch_add(1, Ordering::SeqCst);
49//!     }
50//! }
51//!
52//! // Usually in practice you'll be using an Arc or Box, which already
53//! // implement this, so there will be no need to implement it yourself.
54//! unsafe impl ViaRawPointer for StaticWaker {
55//!     type Target = ();
56//!
57//!     fn into_raw(self) -> *mut () {
58//!         // Need to forget self because we're being converted into a pointer,
59//!         // so destructors should not run.
60//!         std::mem::forget(self);
61//!         std::ptr::null_mut()
62//!     }
63//!
64//!     unsafe fn from_raw(ptr: *mut ()) -> Self {
65//!         StaticWaker
66//!     }
67//! }
68//!
69//! assert_eq!(drop_count.load(Ordering::SeqCst), 0);
70//!
71//! let waker = StaticWaker;
72//! {
73//!     let waker1: Waker = waker.into_waker();
74//!
75//!     waker1.wake_by_ref();
76//!     assert_eq!(wake_ref_count.load(Ordering::SeqCst), 1);
77//!
78//!     let waker2: Waker = waker1.clone();
79//!     waker2.wake_by_ref();
80//!     assert_eq!(wake_ref_count.load(Ordering::SeqCst), 2);
81//!
82//!     waker1.wake();
83//!     assert_eq!(wake_value_count.load(Ordering::SeqCst), 1);
84//!     assert_eq!(drop_count.load(Ordering::SeqCst), 1);
85//! }
86//! assert_eq!(drop_count.load(Ordering::SeqCst), 2);
87//! ```
88//!
89//! # Arc example
90//!
91//! ```
92//! use cooked_waker::{Wake, WakeRef, IntoWaker};
93//! use std::sync::atomic::{AtomicUsize, Ordering};
94//! use std::sync::Arc;
95//! use std::task::Waker;
96//!
97//! // A simple struct that counts the number of times it is awoken. Can't
98//! // be awoken by value (because that would discard the counter), so we
99//! // must instead wrap it in an Arc.
100//! #[derive(Debug, Default)]
101//! struct Counter {
102//!     // We use atomic usize because we need Send + Sync and also interior
103//!     // mutability
104//!     count: AtomicUsize,
105//! }
106//!
107//! impl Counter {
108//!     fn get(&self) -> usize {
109//!         self.count.load(Ordering::SeqCst)
110//!     }
111//! }
112//!
113//! impl WakeRef for Counter {
114//!     fn wake_by_ref(&self) {
115//!         let _prev = self.count.fetch_add(1, Ordering::SeqCst);
116//!     }
117//! }
118//!
119//! let counter_handle = Arc::new(Counter::default());
120//!
121//! // Create an std::task::Waker
122//! let waker: Waker = counter_handle.clone().into_waker();
123//!
124//! waker.wake_by_ref();
125//! waker.wake_by_ref();
126//!
127//! let waker2 = waker.clone();
128//! waker2.wake_by_ref();
129//!
//! // Because IntoWaker wraps the pointer directly, without additional
//! // boxing, we can use will_wake
132//! assert!(waker.will_wake(&waker2));
133//!
134//! // This calls Counter::wake_by_ref because the Arc doesn't have exclusive
135//! // ownership of the underlying Counter
136//! waker2.wake();
137//!
138//! assert_eq!(counter_handle.get(), 4);
139//! ```
140
141extern crate alloc;
142
143use alloc::boxed::Box;
144use alloc::rc;
145use alloc::sync as arc;
146use core::{
147    mem::ManuallyDrop,
148    ptr,
149    task::{RawWaker, RawWakerVTable, Waker},
150};
151
/// Trait for types that can be converted into raw pointers and back again.
///
/// # Safety
///
/// - Implementors must ensure that, for a given object, the pointer remains
///   fixed as long as no mutable operations are performed (that is, calling
///   [`from_raw`](ViaRawPointer::from_raw) followed by
///   [`into_raw`](ViaRawPointer::into_raw), with no mutable operations in
///   between, yields a pointer with the same value as the one passed in).
/// - Implementors also must not panic when the interface is used correctly.
///   The Waker constructed by IntoWaker can cause a double drop if either of
///   these functions panic.
///
/// In the future, we hope to have a similar trait added to the standard
/// library; see <https://github.com/rust-lang/rust/issues/75846> for details.
pub unsafe trait ViaRawPointer {
    /// The pointee type of the raw pointers produced by
    /// [`into_raw`](ViaRawPointer::into_raw) and accepted by
    /// [`from_raw`](ViaRawPointer::from_raw).
    type Target: ?Sized;

    /// Convert this object into a raw pointer.
    fn into_raw(self) -> *mut Self::Target;

    /// Convert a raw pointer back into this object.
    ///
    /// # Safety
    ///
    /// This method must ONLY be called on a pointer that was received via
    /// `Self::into_raw`, and that pointer must not be used afterwards.
    unsafe fn from_raw(ptr: *mut Self::Target) -> Self;
}
180
/// A waker that can be triggered through a shared reference.
///
/// This is the building block that lets handle types which do not own their
/// target (such as `Arc<T>` and `&T`) provide a [`Wake`] implementation.
///
/// Implementations are provided for most container and reference types, for
/// example `&T where T: WakeRef`, `Box<T: WakeRef>`, and `Arc<T: WakeRef>`.
pub trait WakeRef {
    /// Wake up the task without consuming the waker.
    ///
    /// When it is available, [`Wake::wake`] should generally be preferred, as
    /// it's probably more efficient.
    ///
    /// A [`Waker`] created by [`IntoWaker`] calls this method from
    /// [`Waker::wake_by_ref`].
    fn wake_by_ref(&self);
}
195
196/// Wakers that can wake by value. This is the primary means of waking a task.
197///
198/// This trait is implemented for most container types, like `Box<T: Wake>`
199/// and `Option<T: Wake>`. It is also implemented for shared pointer types like
200/// `Arc<T>` and `&T`, but those implementations call `T::wake_by_ref`, because
201/// they don't have ownership of the underlying `T`.
202pub trait Wake: WakeRef + Sized {
203    /// Wake up the task by value. By default, this simply calls
204    /// [`WakeRef::wake_by_ref`].
205    ///
206    /// A [`Waker`] created by [`IntoWaker`] will call this method through
207    /// [`Waker::wake`].
208    #[inline]
209    fn wake(self) {
210        self.wake_by_ref()
211    }
212}
213
/// Objects that can be converted into a [`Waker`]. This trait is
/// automatically implemented for every type that fulfills the waker
/// interface, which means the type must be:
/// - [`Clone`]
/// - `Send + Sync`
/// - `'static`
/// - [`Wake`]
/// - [`ViaRawPointer`]
///
/// The implementation builds a [`RawWakerVTable`] for the type and performs
/// the conversion into a [`Waker`] through the [`ViaRawPointer`] trait, which
/// should be implemented by types that can be converted to and from
/// pointers. It is already implemented for the standard library pointer
/// types (such as `Arc` and `Box`), and you can implement it for your own
/// types if you want to use them as wakers.
///
/// It should never be necessary to implement this trait manually.
///
/// [`RawWakerVTable`]: core::task::RawWakerVTable
/// [`Waker`]: core::task::Waker
/// [`Clone`]: core::clone::Clone
pub trait IntoWaker {
    /// The RawWakerVTable for this type. Do not use this directly; it is
    /// managed entirely by `into_waker`. It exists as an associated const
    /// because that's the only way for it to work in generic contexts.
    #[doc(hidden)]
    const VTABLE: &'static RawWakerVTable;

    /// Convert this object into a `Waker`.
    #[must_use]
    fn into_waker(self) -> Waker;
}
246
247impl<T> IntoWaker for T
248where
249    T: Wake + Clone + Send + Sync + 'static + ViaRawPointer,
250    T::Target: Sized,
251{
252    const VTABLE: &'static RawWakerVTable = &RawWakerVTable::new(
253        // clone
254        |raw| {
255            let raw = raw as *mut T::Target;
256
257            let waker = ManuallyDrop::<T>::new(unsafe { ViaRawPointer::from_raw(raw) });
258            let cloned: T = (*waker).clone();
259
260            // We can't save the `into_raw` back into the raw waker, so we must
261            // simply assert that the pointer has remained the same. This is
262            // part of the ViaRawPointer safety contract, so we only check it
263            // in debug builds.
264            debug_assert_eq!(ManuallyDrop::into_inner(waker).into_raw(), raw);
265
266            let cloned_raw = cloned.into_raw();
267            let cloned_raw = cloned_raw as *const ();
268            RawWaker::new(cloned_raw, T::VTABLE)
269        },
270        // wake by value
271        |raw| {
272            let raw = raw as *mut T::Target;
273            let waker: T = unsafe { ViaRawPointer::from_raw(raw) };
274            waker.wake();
275        },
276        // wake by ref
277        |raw| {
278            let raw = raw as *mut T::Target;
279            let waker = ManuallyDrop::<T>::new(unsafe { ViaRawPointer::from_raw(raw) });
280            waker.wake_by_ref();
281
282            debug_assert_eq!(ManuallyDrop::into_inner(waker).into_raw(), raw);
283        },
284        // Drop
285        |raw| {
286            let raw = raw as *mut T::Target;
287            let _waker: T = unsafe { ViaRawPointer::from_raw(raw) };
288        },
289    );
290
291    fn into_waker(self) -> Waker {
292        let raw = self.into_raw();
293        let raw = raw as *const ();
294        let raw_waker = RawWaker::new(raw, T::VTABLE);
295        unsafe { Waker::from_raw(raw_waker) }
296    }
297}
298
299// Waker implementations for std types. Feel free to open PRs for additional
300// stdlib types here.
301
302// We'd prefer to implement WakeRef for T: Deref<Target=WakeRef>, but that
303// results in type coherence issues with non-deref stdlib types.
304
305impl<T: WakeRef + ?Sized> WakeRef for &T {
306    #[inline]
307    fn wake_by_ref(&self) {
308        T::wake_by_ref(*self)
309    }
310}
311
312impl<T: WakeRef + ?Sized> Wake for &T {}
313
314unsafe impl<T: ?Sized> ViaRawPointer for Box<T> {
315    type Target = T;
316
317    fn into_raw(self) -> *mut T {
318        Box::into_raw(self)
319    }
320
321    unsafe fn from_raw(ptr: *mut T) -> Self {
322        Box::from_raw(ptr)
323    }
324}
325
326impl<T: WakeRef + ?Sized> WakeRef for Box<T> {
327    #[inline]
328    fn wake_by_ref(&self) {
329        T::wake_by_ref(self.as_ref())
330    }
331}
332
333impl<T: Wake> Wake for Box<T> {
334    #[inline]
335    fn wake(self) {
336        T::wake(*self)
337    }
338}
339
340unsafe impl<T: ?Sized> ViaRawPointer for arc::Arc<T> {
341    type Target = T;
342
343    fn into_raw(self) -> *mut T {
344        arc::Arc::into_raw(self) as *mut T
345    }
346
347    unsafe fn from_raw(ptr: *mut T) -> Self {
348        arc::Arc::from_raw(ptr as *const T)
349    }
350}
351
352impl<T: WakeRef + ?Sized> WakeRef for arc::Arc<T> {
353    #[inline]
354    fn wake_by_ref(&self) {
355        T::wake_by_ref(self.as_ref())
356    }
357}
358
359impl<T: WakeRef + ?Sized> Wake for arc::Arc<T> {}
360
361unsafe impl<T> ViaRawPointer for arc::Weak<T> {
362    type Target = T;
363
364    fn into_raw(self) -> *mut T {
365        arc::Weak::into_raw(self) as *mut T
366    }
367
368    unsafe fn from_raw(ptr: *mut T) -> Self {
369        arc::Weak::from_raw(ptr as *const T)
370    }
371}
372
373impl<T: WakeRef + ?Sized> WakeRef for arc::Weak<T> {
374    #[inline]
375    fn wake_by_ref(&self) {
376        self.upgrade().wake()
377    }
378}
379
380impl<T: WakeRef + ?Sized> Wake for arc::Weak<T> {}
381
382impl<T: WakeRef + ?Sized> WakeRef for rc::Rc<T> {
383    #[inline]
384    fn wake_by_ref(&self) {
385        T::wake_by_ref(self.as_ref())
386    }
387}
388
389unsafe impl<T: ?Sized> ViaRawPointer for rc::Rc<T> {
390    type Target = T;
391
392    fn into_raw(self) -> *mut T {
393        rc::Rc::into_raw(self) as *mut T
394    }
395
396    unsafe fn from_raw(ptr: *mut T) -> Self {
397        rc::Rc::from_raw(ptr as *const T)
398    }
399}
400
401impl<T: WakeRef + ?Sized> Wake for rc::Rc<T> {
402    #[inline]
403    fn wake(self) {
404        T::wake_by_ref(self.as_ref())
405    }
406}
407
408unsafe impl<T> ViaRawPointer for rc::Weak<T> {
409    type Target = T;
410
411    fn into_raw(self) -> *mut T {
412        rc::Weak::into_raw(self) as *mut T
413    }
414
415    unsafe fn from_raw(ptr: *mut T) -> Self {
416        rc::Weak::from_raw(ptr as *const T)
417    }
418}
419
420impl<T: WakeRef + ?Sized> WakeRef for rc::Weak<T> {
421    #[inline]
422    fn wake_by_ref(&self) {
423        self.upgrade().wake()
424    }
425}
426
427impl<T: WakeRef + ?Sized> Wake for rc::Weak<T> {}
428
429unsafe impl<T: ViaRawPointer> ViaRawPointer for Option<T>
430where
431    T::Target: Sized,
432{
433    type Target = T::Target;
434
435    fn into_raw(self) -> *mut Self::Target {
436        match self {
437            Some(value) => match value.into_raw() {
438                ptr if ptr.is_null() => {
439                    let _ = unsafe { T::from_raw(ptr) };
440                    ptr::null_mut()
441                }
442                ptr => ptr,
443            },
444            None => ptr::null_mut(),
445        }
446    }
447
448    unsafe fn from_raw(ptr: *mut Self::Target) -> Self {
449        match ptr.is_null() {
450            false => Some(T::from_raw(ptr)),
451            true => None,
452        }
453    }
454}
455
456impl<T: WakeRef> WakeRef for Option<T> {
457    #[inline]
458    fn wake_by_ref(&self) {
459        if let Some(waker) = self {
460            waker.wake_by_ref()
461        }
462    }
463}
464
465impl<T: Wake> Wake for Option<T> {
466    #[inline]
467    fn wake(self) {
468        if let Some(waker) = self {
469            waker.wake()
470        }
471    }
472}
473
474impl WakeRef for Waker {
475    #[inline]
476    fn wake_by_ref(&self) {
477        Waker::wake_by_ref(self)
478    }
479}
480
481impl Wake for Waker {
482    #[inline]
483    fn wake(self) {
484        Waker::wake(self)
485    }
486}
487
#[cfg(test)]
mod test {
    extern crate std;

    use super::*;
    use std::panic;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::task::Waker;

    // Global counters recording how often each `PanicWaker` hook has run.
    static PANIC_WAKE_REF_COUNT: AtomicUsize = AtomicUsize::new(0);
    static PANIC_WAKE_VALUE_COUNT: AtomicUsize = AtomicUsize::new(0);
    static PANIC_DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

    /// Zero-sized waker whose `wake_by_ref` always panics after bumping its
    /// counter, while `wake` and `drop` merely bump theirs.
    #[derive(Debug, Clone)]
    struct PanicWaker;

    impl WakeRef for PanicWaker {
        fn wake_by_ref(&self) {
            PANIC_WAKE_REF_COUNT.fetch_add(1, Ordering::SeqCst);
            panic!();
        }
    }

    impl Wake for PanicWaker {
        fn wake(self) {
            PANIC_WAKE_VALUE_COUNT.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl Drop for PanicWaker {
        fn drop(&mut self) {
            PANIC_DROP_COUNT.fetch_add(1, Ordering::SeqCst);
        }
    }

    // The waker carries no data, so it is represented by a null pointer and
    // rebuilt from nothing.
    unsafe impl ViaRawPointer for PanicWaker {
        type Target = ();

        fn into_raw(self) -> *mut () {
            // The value lives on "inside" the pointer, so its destructor must
            // not run here.
            std::mem::forget(self);
            std::ptr::null_mut()
        }

        unsafe fn from_raw(_ptr: *mut ()) -> Self {
            PanicWaker
        }
    }

    // Test that the wake_by_ref() behaves correctly even if it panics: the
    // waker must not be dropped by a panicking wake_by_ref, and must be
    // dropped exactly once when it is consumed or goes out of scope.
    #[test]
    fn panic_wake() {
        let ref_wakes = || PANIC_WAKE_REF_COUNT.load(Ordering::SeqCst);
        let value_wakes = || PANIC_WAKE_VALUE_COUNT.load(Ordering::SeqCst);
        let drops = || PANIC_DROP_COUNT.load(Ordering::SeqCst);

        assert_eq!(drops(), 0);

        let waker = PanicWaker;
        {
            let original: Waker = waker.into_waker();
            let duplicate: Waker = original.clone();

            // A panicking wake_by_ref on the clone must not drop anything.
            let outcome = panic::catch_unwind(|| {
                duplicate.wake_by_ref();
            });
            assert!(outcome.is_err());
            assert_eq!(ref_wakes(), 1);
            assert_eq!(drops(), 0);

            // Same for the original waker.
            let outcome = panic::catch_unwind(|| {
                original.wake_by_ref();
            });
            assert!(outcome.is_err());
            assert_eq!(ref_wakes(), 2);
            assert_eq!(drops(), 0);

            // Waking by value does not panic and consumes (drops) the waker.
            let outcome = panic::catch_unwind(|| {
                original.wake();
            });
            assert!(outcome.is_ok());
            assert_eq!(value_wakes(), 1);
            assert_eq!(drops(), 1);
        }
        // Leaving the scope drops the remaining clone.
        assert_eq!(drops(), 2);
    }
}