cooked_waker/lib.rs
1#![no_std]
2
3//! cooked_waker provides safe traits for working with
4//! [`std::task::Waker`][Waker] and creating those wakers out of regular, safe
5//! Rust structs. It cooks `RawWaker` and `RawWakerVTable`, making them safe
6//! for consumption.
7//!
//! It provides the [`Wake`] and [`WakeRef`] traits, which correspond to the
//! [`wake`][Waker::wake] and [`wake_by_ref`][Waker::wake_by_ref] methods
//! on [`std::task::Waker`][Waker], and it provides implementations of these
//! traits for the common reference & pointer types (`Arc`, `Rc`, `&'static`,
//! etc).
13//!
//! Additionally, it provides [`IntoWaker`], which allows converting suitable
//! `Wake + Clone` types into a [`Waker`]. This trait is automatically
//! implemented for any `Wake + Clone + Send + Sync + 'static` type that also
//! implements [`ViaRawPointer`] (as `Arc` and `Box` do).
17//!
18//! # Basic example
19//!
20//! ```
21//! use cooked_waker::{Wake, WakeRef, IntoWaker, ViaRawPointer};
22//! use std::sync::atomic::{AtomicUsize, Ordering};
23//! use std::task::Waker;
24//!
25//! static wake_ref_count: AtomicUsize = AtomicUsize::new(0);
26//! static wake_value_count: AtomicUsize = AtomicUsize::new(0);
27//! static drop_count: AtomicUsize = AtomicUsize::new(0);
28//!
29//! // A simple Waker struct that atomically increments the relevant static
30//! // counters.
31//! #[derive(Debug, Clone)]
32//! struct StaticWaker;
33//!
34//! impl WakeRef for StaticWaker {
35//! fn wake_by_ref(&self) {
36//! wake_ref_count.fetch_add(1, Ordering::SeqCst);
37//! }
38//! }
39//!
40//! impl Wake for StaticWaker {
41//! fn wake(self) {
42//! wake_value_count.fetch_add(1, Ordering::SeqCst);
43//! }
44//! }
45//!
46//! impl Drop for StaticWaker {
47//! fn drop(&mut self) {
48//! drop_count.fetch_add(1, Ordering::SeqCst);
49//! }
50//! }
51//!
52//! // Usually in practice you'll be using an Arc or Box, which already
53//! // implement this, so there will be no need to implement it yourself.
54//! unsafe impl ViaRawPointer for StaticWaker {
55//! type Target = ();
56//!
57//! fn into_raw(self) -> *mut () {
58//! // Need to forget self because we're being converted into a pointer,
59//! // so destructors should not run.
60//! std::mem::forget(self);
61//! std::ptr::null_mut()
62//! }
63//!
64//! unsafe fn from_raw(ptr: *mut ()) -> Self {
65//! StaticWaker
66//! }
67//! }
68//!
69//! assert_eq!(drop_count.load(Ordering::SeqCst), 0);
70//!
71//! let waker = StaticWaker;
72//! {
73//! let waker1: Waker = waker.into_waker();
74//!
75//! waker1.wake_by_ref();
76//! assert_eq!(wake_ref_count.load(Ordering::SeqCst), 1);
77//!
78//! let waker2: Waker = waker1.clone();
79//! waker2.wake_by_ref();
80//! assert_eq!(wake_ref_count.load(Ordering::SeqCst), 2);
81//!
82//! waker1.wake();
83//! assert_eq!(wake_value_count.load(Ordering::SeqCst), 1);
84//! assert_eq!(drop_count.load(Ordering::SeqCst), 1);
85//! }
86//! assert_eq!(drop_count.load(Ordering::SeqCst), 2);
87//! ```
88//!
89//! # Arc example
90//!
91//! ```
92//! use cooked_waker::{Wake, WakeRef, IntoWaker};
93//! use std::sync::atomic::{AtomicUsize, Ordering};
94//! use std::sync::Arc;
95//! use std::task::Waker;
96//!
97//! // A simple struct that counts the number of times it is awoken. Can't
98//! // be awoken by value (because that would discard the counter), so we
99//! // must instead wrap it in an Arc.
100//! #[derive(Debug, Default)]
101//! struct Counter {
102//! // We use atomic usize because we need Send + Sync and also interior
103//! // mutability
104//! count: AtomicUsize,
105//! }
106//!
107//! impl Counter {
108//! fn get(&self) -> usize {
109//! self.count.load(Ordering::SeqCst)
110//! }
111//! }
112//!
113//! impl WakeRef for Counter {
114//! fn wake_by_ref(&self) {
115//! let _prev = self.count.fetch_add(1, Ordering::SeqCst);
116//! }
117//! }
118//!
119//! let counter_handle = Arc::new(Counter::default());
120//!
121//! // Create an std::task::Waker
122//! let waker: Waker = counter_handle.clone().into_waker();
123//!
124//! waker.wake_by_ref();
125//! waker.wake_by_ref();
126//!
127//! let waker2 = waker.clone();
128//! waker2.wake_by_ref();
129//!
//! // Because IntoWaker wraps the pointer directly, without additional
//! // boxing, we can use will_wake
132//! assert!(waker.will_wake(&waker2));
133//!
134//! // This calls Counter::wake_by_ref because the Arc doesn't have exclusive
135//! // ownership of the underlying Counter
136//! waker2.wake();
137//!
138//! assert_eq!(counter_handle.get(), 4);
139//! ```
140
141extern crate alloc;
142
143use alloc::boxed::Box;
144use alloc::rc;
145use alloc::sync as arc;
146use core::{
147 mem::ManuallyDrop,
148 ptr,
149 task::{RawWaker, RawWakerVTable, Waker},
150};
151
/// Trait for types that can be converted into raw pointers and back again.
///
/// # Safety
///
/// - Implementors must ensure that, for a given object, the pointer remains
///   fixed as long as no mutable operations are performed. That is, calling
///   [`from_raw(ptr)`][ViaRawPointer::from_raw] followed by
///   [`into_raw()`][ViaRawPointer::into_raw], with no mutable operations in
///   between, must return a pointer equal to `ptr`.
/// - Implementors also must not panic when the interface is used correctly.
///   The [`Waker`] constructed by [`IntoWaker`] can cause a double drop if
///   either of these functions panics.
///
/// In the future, we hope to have a similar trait added to the standard
/// library; see <https://github.com/rust-lang/rust/issues/75846> for details.
pub unsafe trait ViaRawPointer {
    /// The type the raw pointer points to.
    type Target: ?Sized;

    /// Convert this object into a raw pointer, transferring ownership into
    /// the pointer. Use [`from_raw`][ViaRawPointer::from_raw] to take it back.
    fn into_raw(self) -> *mut Self::Target;

    /// Convert a raw pointer back into this object.
    ///
    /// # Safety
    ///
    /// This method must ONLY be called on a pointer that was received via
    /// `Self::into_raw`, and that pointer must not be used afterwards.
    unsafe fn from_raw(ptr: *mut Self::Target) -> Self;
}
180
/// Wakers that are able to wake their task through a shared reference.
///
/// Types that don't own the underlying handle (such as `Arc<T>` or `&T`)
/// rely on this trait to provide their [`Wake`] implementation.
///
/// Implementations are provided for most container and reference types,
/// including `&T`, `Box<T>` and `Arc<T>` whenever `T: WakeRef`.
pub trait WakeRef {
    /// Wake the task without consuming `self`.
    ///
    /// Prefer [`Wake::wake`] when it is available, since it may be more
    /// efficient. Calling [`Waker::wake_by_ref`] on a [`Waker`] built by
    /// [`IntoWaker`] forwards to this method.
    fn wake_by_ref(&self);
}
195
196/// Wakers that can wake by value. This is the primary means of waking a task.
197///
198/// This trait is implemented for most container types, like `Box<T: Wake>`
199/// and `Option<T: Wake>`. It is also implemented for shared pointer types like
200/// `Arc<T>` and `&T`, but those implementations call `T::wake_by_ref`, because
201/// they don't have ownership of the underlying `T`.
202pub trait Wake: WakeRef + Sized {
203 /// Wake up the task by value. By default, this simply calls
204 /// [`WakeRef::wake_by_ref`].
205 ///
206 /// A [`Waker`] created by [`IntoWaker`] will call this method through
207 /// [`Waker::wake`].
208 #[inline]
209 fn wake(self) {
210 self.wake_by_ref()
211 }
212}
213
/// Objects that can be converted into a [`Waker`]. This trait is
/// automatically implemented for types that fulfill the waker interface.
/// Such types must be:
/// - [`Clone`]
/// - `Send + Sync`
/// - `'static`
/// - [`Wake`]
/// - [`ViaRawPointer`], with a `Sized` [`Target`][ViaRawPointer::Target]
///
/// The implementation of this trait sets up a [`RawWakerVTable`] for the type,
/// and arranges a conversion into a [`Waker`] through the [`ViaRawPointer`]
/// trait, which should be implemented for types that can be converted to and
/// from pointers. [`ViaRawPointer`] is implemented for the standard library
/// smart pointers (`Box`, `Arc`, `Rc`, and their `Weak` counterparts) as well
/// as for `Option` of such pointers, and you can implement it on your own
/// types if you want to use them for wakers.
///
/// It should never be necessary to implement this trait manually.
///
/// [`RawWakerVTable`]: core::task::RawWakerVTable
/// [`Waker`]: core::task::Waker
/// [`Clone`]: core::clone::Clone
pub trait IntoWaker {
    /// The [`RawWakerVTable`] for this type. This should never be used
    /// directly; it is entirely handled by `into_waker`. It is present as an
    /// associated const because that's the only way for it to work in generic
    /// contexts.
    #[doc(hidden)]
    const VTABLE: &'static RawWakerVTable;

    /// Convert this object into a [`Waker`].
    #[must_use]
    fn into_waker(self) -> Waker;
}
246
247impl<T> IntoWaker for T
248where
249 T: Wake + Clone + Send + Sync + 'static + ViaRawPointer,
250 T::Target: Sized,
251{
252 const VTABLE: &'static RawWakerVTable = &RawWakerVTable::new(
253 // clone
254 |raw| {
255 let raw = raw as *mut T::Target;
256
257 let waker = ManuallyDrop::<T>::new(unsafe { ViaRawPointer::from_raw(raw) });
258 let cloned: T = (*waker).clone();
259
260 // We can't save the `into_raw` back into the raw waker, so we must
261 // simply assert that the pointer has remained the same. This is
262 // part of the ViaRawPointer safety contract, so we only check it
263 // in debug builds.
264 debug_assert_eq!(ManuallyDrop::into_inner(waker).into_raw(), raw);
265
266 let cloned_raw = cloned.into_raw();
267 let cloned_raw = cloned_raw as *const ();
268 RawWaker::new(cloned_raw, T::VTABLE)
269 },
270 // wake by value
271 |raw| {
272 let raw = raw as *mut T::Target;
273 let waker: T = unsafe { ViaRawPointer::from_raw(raw) };
274 waker.wake();
275 },
276 // wake by ref
277 |raw| {
278 let raw = raw as *mut T::Target;
279 let waker = ManuallyDrop::<T>::new(unsafe { ViaRawPointer::from_raw(raw) });
280 waker.wake_by_ref();
281
282 debug_assert_eq!(ManuallyDrop::into_inner(waker).into_raw(), raw);
283 },
284 // Drop
285 |raw| {
286 let raw = raw as *mut T::Target;
287 let _waker: T = unsafe { ViaRawPointer::from_raw(raw) };
288 },
289 );
290
291 fn into_waker(self) -> Waker {
292 let raw = self.into_raw();
293 let raw = raw as *const ();
294 let raw_waker = RawWaker::new(raw, T::VTABLE);
295 unsafe { Waker::from_raw(raw_waker) }
296 }
297}
298
299// Waker implementations for std types. Feel free to open PRs for additional
300// stdlib types here.
301
// We'd prefer a single blanket impl of WakeRef for any `T: Deref` whose
// `Target` implements WakeRef, but that causes coherence conflicts with the
// impls for non-Deref stdlib types (such as `Option` and `Waker`).
304
305impl<T: WakeRef + ?Sized> WakeRef for &T {
306 #[inline]
307 fn wake_by_ref(&self) {
308 T::wake_by_ref(*self)
309 }
310}
311
312impl<T: WakeRef + ?Sized> Wake for &T {}
313
314unsafe impl<T: ?Sized> ViaRawPointer for Box<T> {
315 type Target = T;
316
317 fn into_raw(self) -> *mut T {
318 Box::into_raw(self)
319 }
320
321 unsafe fn from_raw(ptr: *mut T) -> Self {
322 Box::from_raw(ptr)
323 }
324}
325
326impl<T: WakeRef + ?Sized> WakeRef for Box<T> {
327 #[inline]
328 fn wake_by_ref(&self) {
329 T::wake_by_ref(self.as_ref())
330 }
331}
332
333impl<T: Wake> Wake for Box<T> {
334 #[inline]
335 fn wake(self) {
336 T::wake(*self)
337 }
338}
339
340unsafe impl<T: ?Sized> ViaRawPointer for arc::Arc<T> {
341 type Target = T;
342
343 fn into_raw(self) -> *mut T {
344 arc::Arc::into_raw(self) as *mut T
345 }
346
347 unsafe fn from_raw(ptr: *mut T) -> Self {
348 arc::Arc::from_raw(ptr as *const T)
349 }
350}
351
352impl<T: WakeRef + ?Sized> WakeRef for arc::Arc<T> {
353 #[inline]
354 fn wake_by_ref(&self) {
355 T::wake_by_ref(self.as_ref())
356 }
357}
358
359impl<T: WakeRef + ?Sized> Wake for arc::Arc<T> {}
360
361unsafe impl<T> ViaRawPointer for arc::Weak<T> {
362 type Target = T;
363
364 fn into_raw(self) -> *mut T {
365 arc::Weak::into_raw(self) as *mut T
366 }
367
368 unsafe fn from_raw(ptr: *mut T) -> Self {
369 arc::Weak::from_raw(ptr as *const T)
370 }
371}
372
373impl<T: WakeRef + ?Sized> WakeRef for arc::Weak<T> {
374 #[inline]
375 fn wake_by_ref(&self) {
376 self.upgrade().wake()
377 }
378}
379
380impl<T: WakeRef + ?Sized> Wake for arc::Weak<T> {}
381
382impl<T: WakeRef + ?Sized> WakeRef for rc::Rc<T> {
383 #[inline]
384 fn wake_by_ref(&self) {
385 T::wake_by_ref(self.as_ref())
386 }
387}
388
389unsafe impl<T: ?Sized> ViaRawPointer for rc::Rc<T> {
390 type Target = T;
391
392 fn into_raw(self) -> *mut T {
393 rc::Rc::into_raw(self) as *mut T
394 }
395
396 unsafe fn from_raw(ptr: *mut T) -> Self {
397 rc::Rc::from_raw(ptr as *const T)
398 }
399}
400
401impl<T: WakeRef + ?Sized> Wake for rc::Rc<T> {
402 #[inline]
403 fn wake(self) {
404 T::wake_by_ref(self.as_ref())
405 }
406}
407
408unsafe impl<T> ViaRawPointer for rc::Weak<T> {
409 type Target = T;
410
411 fn into_raw(self) -> *mut T {
412 rc::Weak::into_raw(self) as *mut T
413 }
414
415 unsafe fn from_raw(ptr: *mut T) -> Self {
416 rc::Weak::from_raw(ptr as *const T)
417 }
418}
419
420impl<T: WakeRef + ?Sized> WakeRef for rc::Weak<T> {
421 #[inline]
422 fn wake_by_ref(&self) {
423 self.upgrade().wake()
424 }
425}
426
427impl<T: WakeRef + ?Sized> Wake for rc::Weak<T> {}
428
429unsafe impl<T: ViaRawPointer> ViaRawPointer for Option<T>
430where
431 T::Target: Sized,
432{
433 type Target = T::Target;
434
435 fn into_raw(self) -> *mut Self::Target {
436 match self {
437 Some(value) => match value.into_raw() {
438 ptr if ptr.is_null() => {
439 let _ = unsafe { T::from_raw(ptr) };
440 ptr::null_mut()
441 }
442 ptr => ptr,
443 },
444 None => ptr::null_mut(),
445 }
446 }
447
448 unsafe fn from_raw(ptr: *mut Self::Target) -> Self {
449 match ptr.is_null() {
450 false => Some(T::from_raw(ptr)),
451 true => None,
452 }
453 }
454}
455
456impl<T: WakeRef> WakeRef for Option<T> {
457 #[inline]
458 fn wake_by_ref(&self) {
459 if let Some(waker) = self {
460 waker.wake_by_ref()
461 }
462 }
463}
464
465impl<T: Wake> Wake for Option<T> {
466 #[inline]
467 fn wake(self) {
468 if let Some(waker) = self {
469 waker.wake()
470 }
471 }
472}
473
474impl WakeRef for Waker {
475 #[inline]
476 fn wake_by_ref(&self) {
477 Waker::wake_by_ref(self)
478 }
479}
480
481impl Wake for Waker {
482 #[inline]
483 fn wake(self) {
484 Waker::wake(self)
485 }
486}
487
#[cfg(test)]
mod test {
    extern crate std;

    use super::*;
    use std::panic;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::task::Waker;

    static PANIC_WAKE_REF_COUNT: AtomicUsize = AtomicUsize::new(0);
    static PANIC_WAKE_VALUE_COUNT: AtomicUsize = AtomicUsize::new(0);
    static PANIC_DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

    /// Current value of one of the counters above.
    fn count(counter: &AtomicUsize) -> usize {
        counter.load(Ordering::SeqCst)
    }

    /// Zero-sized waker that records every call. `wake_by_ref` panics after
    /// recording, while `wake` and `drop` only record.
    #[derive(Debug, Clone)]
    struct PanicWaker;

    impl WakeRef for PanicWaker {
        fn wake_by_ref(&self) {
            PANIC_WAKE_REF_COUNT.fetch_add(1, Ordering::SeqCst);
            panic!();
        }
    }

    impl Wake for PanicWaker {
        fn wake(self) {
            PANIC_WAKE_VALUE_COUNT.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl Drop for PanicWaker {
        fn drop(&mut self) {
            PANIC_DROP_COUNT.fetch_add(1, Ordering::SeqCst);
        }
    }

    // `PanicWaker` carries no data, so it is represented by a null pointer.
    unsafe impl ViaRawPointer for PanicWaker {
        type Target = ();

        fn into_raw(self) -> *mut () {
            // Turning into a pointer must not run the destructor.
            std::mem::forget(self);
            std::ptr::null_mut()
        }

        unsafe fn from_raw(_ptr: *mut ()) -> Self {
            PanicWaker
        }
    }

    /// A `Waker` whose `wake_by_ref` panics must not drop its handle. Only
    /// `wake` and dropping the `Waker` itself may run the destructor.
    #[test]
    fn panic_wake() {
        assert_eq!(count(&PANIC_DROP_COUNT), 0);

        {
            let first: Waker = PanicWaker.into_waker();
            let second: Waker = first.clone();

            // Waking by reference panics but must leave both handles alive.
            let outcome = panic::catch_unwind(|| second.wake_by_ref());
            assert!(outcome.is_err());
            assert_eq!(count(&PANIC_WAKE_REF_COUNT), 1);
            assert_eq!(count(&PANIC_DROP_COUNT), 0);

            let outcome = panic::catch_unwind(|| first.wake_by_ref());
            assert!(outcome.is_err());
            assert_eq!(count(&PANIC_WAKE_REF_COUNT), 2);
            assert_eq!(count(&PANIC_DROP_COUNT), 0);

            // Waking by value succeeds and consumes (drops) that handle.
            let outcome = panic::catch_unwind(move || first.wake());
            assert!(outcome.is_ok());
            assert_eq!(count(&PANIC_WAKE_VALUE_COUNT), 1);
            assert_eq!(count(&PANIC_DROP_COUNT), 1);
        }

        // `second` is dropped at the end of the scope above.
        assert_eq!(count(&PANIC_DROP_COUNT), 2);
    }
}