cooked_waker_derive/
lib.rs

//! `cooked-waker-derive` provides derive macros for the traits in
//! [`cooked-waker`](../cooked_waker). See that crate's documentation for
//! more info.
4
5extern crate proc_macro;
6
7use proc_macro as pm;
8use proc_macro2::TokenStream;
9use quote::quote;
10use syn::{self, parse_macro_input, parse_quote, Data, DeriveInput, Fields};
11
12/// `IntoWaker` derive implementation.
13///
14/// This derive creates an `IntoWaker` implementation for any concrete type.
15/// It does this be creating a static `RawWakerVTable` associated with that
16/// type, with methods that forward to the relevant trait methods, using
17/// stowaway to pack the waker object into a pointer, then wrapping it all in
18/// a Waker.
19///
20/// Note that `IntoWaker` requires `Wake + Clone + Send + Sync + 'static`.
21///
22/// # Interior mutability note
23///
24/// Note that, in some rare circumstances, interior mutability may not be
25/// respected in the by-reference methods (clone and wake_by_ref). This will
26/// happen if your struct is <= the size of a pointer, and if you attempt to
27/// mutate the bytes of the struct directly through shared mutability. In this
28/// case, the struct bytes are stored directly in the pointer. These bytes are
29/// copied into the relevant functions, but it is assumed they never chage, so
30/// if the `wake_by_ref` or `clone` methods change these bytes (for instance,
31/// if your waker is a Cell<Option<Box<T>>>), these changes will NOT be
32/// reflected in subsequent calls. If you *need* interior mutability on a small
33/// struct, you can manually Box it in a wrapper struct and derive `IntoWaker`
34/// on the wrapper.
35#[proc_macro_derive(IntoWaker)]
36pub fn into_waker_derive(stream: pm::TokenStream) -> pm::TokenStream {
37    let input = parse_macro_input!(stream as DeriveInput);
38
39    if !input.generics.params.is_empty() {
40        panic!("IntoWaker can only be derived for concrete types");
41    }
42
43    #[allow(non_snake_case)]
44    let WakerStruct = input.ident;
45
46    let implementation = quote! {
47        impl cooked_waker::IntoWaker for #WakerStruct {
48            #[must_use]
49            fn into_waker(self) -> core::task::Waker {
50                use core::task::{Waker, RawWaker, RawWakerVTable};
51                use core::clone::Clone;
52                use cooked_waker::{Wake, WakeRef};
53                use cooked_waker::stowaway::{self, Stowaway};
54
55                #[inline]
56                fn make_raw_waker(waker: #WakerStruct) -> RawWaker {
57                    let stowed = Stowaway::new(waker);
58                    RawWaker::new(Stowaway::into_raw(stowed), &VTABLE)
59                }
60
61                static VTABLE: RawWakerVTable = RawWakerVTable::new(
62                    // clone
63                    |raw| {
64                        let raw = raw as *mut ();
65                        let waker: & #WakerStruct = unsafe { stowaway::ref_from_stowed(&raw) };
66                        make_raw_waker(Clone::clone(waker))
67                    },
68                    // wake by value
69                    |raw| {
70                        let waker: #WakerStruct = unsafe { stowaway::unstow(raw as *mut ()) };
71                        Wake::wake(waker);
72                    },
73                    // wake by ref
74                    |raw| {
75                        let raw = raw as *mut ();
76                        let waker: & #WakerStruct = unsafe { stowaway::ref_from_stowed(&raw) };
77                        WakeRef::wake_by_ref(waker)
78                    },
79                    // Drop
80                    |raw| {
81                        let _waker: Stowaway<#WakerStruct> = unsafe {
82                            Stowaway::from_raw(raw as *mut ())
83                        };
84                    },
85                );
86
87                let raw_waker = make_raw_waker(self);
88                unsafe { Waker::from_raw(raw_waker) }
89            }
90        }
91    };
92
93    implementation.into()
94}
95
/// Selects which forwarding trait `derive_wake_like` should implement.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum WakeTrait {
    /// The `Wake` trait, whose method takes `self` by value.
    Wake,
    /// The `WakeRef` trait, whose method takes `&self`.
    WakeRef,
}
101
102impl WakeTrait {
103    #[inline]
104    fn trait_path(self) -> syn::Path {
105        match self {
106            WakeTrait::Wake => parse_quote! {::cooked_waker::Wake},
107            WakeTrait::WakeRef => parse_quote! {::cooked_waker::WakeRef},
108        }
109    }
110
111    #[inline]
112    fn method(self) -> syn::Ident {
113        match self {
114            WakeTrait::Wake => parse_quote! {wake},
115            WakeTrait::WakeRef => parse_quote! {wake_by_ref},
116        }
117    }
118
119    #[inline]
120    fn name(self) -> &'static str {
121        match self {
122            WakeTrait::Wake => "Wake",
123            WakeTrait::WakeRef => "WakeRef",
124        }
125    }
126
127    /// Change a token stream like `self` or `self.value` to `&self` or
128    /// `&self.value` if this is WakeRef
129    #[inline]
130    fn apply_reference(self, input: TokenStream) -> TokenStream {
131        match self {
132            WakeTrait::Wake => input,
133            WakeTrait::WakeRef => quote! {& #input},
134        }
135    }
136}
137
138fn derive_wake_like(spec: WakeTrait, stream: pm::TokenStream) -> pm::TokenStream {
139    let input = parse_macro_input!(stream as DeriveInput);
140
141    let trait_path = spec.trait_path();
142    let method = spec.method();
143
144    let type_name = input.ident;
145    let mut generics = input.generics;
146    let where_clause = generics.make_where_clause();
147
148    match input.data {
149        Data::Struct(s) => {
150            // Normalize named and unnamed struct fields.
151            let fields = match s.fields {
152                Fields::Named(fields) => fields.named,
153                Fields::Unnamed(fields) => fields.unnamed,
154                Fields::Unit => panic!(
155                    "`{name}` can only be derived on structs with a single `{name}` field",
156                    name = spec.name()
157                ),
158            };
159
160            if fields.len() != 1 {
161                panic!(
162                    "Can only derive `{name}` on structs with exactly 1 field",
163                    name = spec.name()
164                );
165            }
166
167            let field = fields.first().unwrap();
168            let field_type = &field.ty;
169
170            // field_name is either `name` or `0`; it allows for `self.name`
171            // or `self.0`.
172            let field_name: syn::Member = field
173                .ident
174                .clone()
175                .map(syn::Member::Named)
176                .unwrap_or_else(|| parse_quote!(0));
177
178            // Add "where FieldType: Wake"
179            where_clause
180                .predicates
181                .push(parse_quote! {#field_type: #trait_path});
182
183            let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
184
185            // The self parameter for the function signature: self or &self
186            let self_param = spec.apply_reference(quote! {self});
187
188            // The getter for the self field: self.0 or self.field_name
189            let field_invocation = spec.apply_reference(quote! {self.#field_name});
190
191            let implementation = quote! {
192                impl #impl_generics #trait_path for #type_name #ty_generics #where_clause {
193                    #[inline]
194                    fn #method(#self_param) {
195                        #trait_path::#method(#field_invocation)
196                    }
197                }
198            };
199
200            implementation.into()
201        }
202        Data::Enum(..) => unimplemented!("derive(Wake) for enums is still WIP"),
203        Data::Union(..) => panic!("`Wake` can only be derived for struct or enum types"),
204    }
205}
206
207/// Create a `Wake` implementation for a `struct` that forwards to the
208/// `struct`'s field. The `struct` must have exactly one field, and that
209/// field must implement `Wake`.
210///
211/// In the future this derive will also support `enum`.
212#[proc_macro_derive(Wake)]
213pub fn wake_derive(stream: pm::TokenStream) -> pm::TokenStream {
214    derive_wake_like(WakeTrait::Wake, stream)
215}
216
217/// Create a `WakeRef` implementation for a `struct` that forwards to the
218/// `struct`'s field. The `struct` must have exactly one field, and that
219/// field must implement `WakeRef`.
220///
221/// In the future this derive will also support `enum`.
222#[proc_macro_derive(WakeRef)]
223pub fn wake_ref_derive(stream: pm::TokenStream) -> pm::TokenStream {
224    derive_wake_like(WakeTrait::WakeRef, stream)
225}