// static_dispatch_macros/lib.rs

use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{ToTokens, quote};
use syn::{
    Error, Fields, FnArg, GenericParam, Ident, Item, ItemEnum, ItemTrait, Pat, Path, Signature,
    Token, TraitItem, Type, Variant, WhereClause, parse::Parse, parse_macro_input,
};

9/// Setup this type for static dispatch with [`implementation`].
10///
11/// See the module for documentation.
12#[proc_macro_attribute]
13pub fn setup(_attr: TokenStream, item: TokenStream) -> TokenStream {
14    // todo: avoid double parse, we just need the name
15    let input = parse_macro_input!(item as Item);
16    let name = match &input {
17        Item::Trait(value) => &value.ident,
18        Item::Enum(value) => &value.ident,
19        _ => {
20            return Error::new_spanned(&input, "dispatch is only valid on traits or enums")
21                .to_compile_error()
22                .into();
23        }
24    };
25
26    let save = macro_data::save(name, &input);
27
28    quote! {
29        #input
30        #save
31    }
32    .into()
33}
34
35struct GenerateInput {
36    trait_name: Path,
37    _for: Token![for],
38    enum_name: Path,
39}
40
41impl Parse for GenerateInput {
42    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
43        Ok(Self {
44            trait_name: input.parse()?,
45            _for: input.parse()?,
46            enum_name: input.parse()?,
47        })
48    }
49}
50
51/// Syntax: `implementation!(<trait> for <enum>)`
52///
53/// Generate the trait implementation for the enum.
54/// Both require a `![setup]` annotation.
55///
56/// Because this uses macros to read the data, use `<crate>::<name>`
57/// for external types.
58#[proc_macro]
59pub fn implementation(input: TokenStream) -> TokenStream {
60    let input = parse_macro_input!(input as GenerateInput);
61
62    let data = FinalTransfer {
63        trait_item: macro_data::request(&input.trait_name),
64        comma: syn::token::Comma(Span::mixed_site()),
65        enum_item: macro_data::request(&input.enum_name),
66    };
67
68    macro_data::transfer("static_dispatch", "generate_final", &data).into()
69}
70
71struct FinalTransfer<S: macro_data::Storage> {
72    trait_item: macro_data::Transfer<ItemTrait, S>,
73    comma: Token![,],
74    enum_item: macro_data::Transfer<ItemEnum, S>,
75}
76
77impl ToTokens for FinalTransfer<macro_data::Request> {
78    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
79        self.trait_item.to_tokens(tokens);
80        self.comma.to_tokens(tokens);
81        self.enum_item.to_tokens(tokens);
82    }
83}
84
85impl Parse for FinalTransfer<macro_data::Load> {
86    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
87        Ok(Self {
88            trait_item: input.parse()?,
89            comma: input.parse()?,
90            enum_item: input.parse()?,
91        })
92    }
93}
94
95/// This macro is designed to be called by other macros, not in normal code.
96///
97/// See the module for documentation.
98#[doc(hidden)]
99#[proc_macro]
100pub fn generate_final(input: TokenStream) -> TokenStream {
101    let input = parse_macro_input!(input as FinalTransfer<macro_data::Load>);
102    let trait_item = input.trait_item.0;
103    let enum_item = input.enum_item.0;
104
105    let trait_ident = &trait_item.ident;
106    let enum_ident = &enum_item.ident;
107
108    // Combine generic parameters from trait and enum.
109    let mut all_params = Vec::new();
110    for param in &trait_item.generics.params {
111        all_params.push(param.clone());
112    }
113    for param in &enum_item.generics.params {
114        all_params.push(param.clone());
115    }
116    all_params.sort_by_key(|param| match param {
117        GenericParam::Lifetime(_) => 0,
118        GenericParam::Const(_) => 1,
119        GenericParam::Type(_) => 2,
120    });
121
122    let impl_generics = if all_params.is_empty() {
123        quote! {}
124    } else {
125        quote! { < #(#all_params),* > }
126    };
127
128    // Combine where clauses from trait and enum.
129    let mut where_predicates = Vec::new();
130    if let Some(wc) = &trait_item.generics.where_clause {
131        where_predicates.extend(wc.predicates.iter().cloned());
132    }
133    if let Some(wc) = &enum_item.generics.where_clause {
134        where_predicates.extend(wc.predicates.iter().cloned());
135    }
136    all_params.sort_by_key(|param| match param {
137        GenericParam::Lifetime(_) => 0,
138        GenericParam::Const(_) => 1,
139        GenericParam::Type(_) => 2,
140    });
141
142    let where_clause = if where_predicates.is_empty() {
143        None
144    } else {
145        Some(WhereClause {
146            where_token: syn::token::Where::default(),
147            predicates: syn::punctuated::Punctuated::from_iter(where_predicates),
148        })
149    };
150    let trait_args = generic_args(&trait_item.generics);
151    let enum_args = generic_args(&enum_item.generics);
152
153    // Generate methods
154    let impl_methods = trait_item
155        .items
156        .iter()
157        .map(|item| {
158            let TraitItem::Fn(method) = item else {
159                return Error::new_spanned(item, "Only methods are supported").to_compile_error();
160            };
161            let sig = &method.sig;
162            let method_name = &sig.ident;
163            let method_gen = sig
164                .generics
165                .params
166                .iter()
167                .filter_map(|param| match param {
168                    GenericParam::Lifetime(_) => None,
169                    GenericParam::Const(param) => Some(&param.ident),
170                    GenericParam::Type(param) => Some(&param.ident),
171                })
172                .collect::<Vec<_>>();
173
174            let mut args = sig.inputs.iter();
175            let self_arg = match args.next() {
176                Some(FnArg::Receiver(rec)) => &rec.self_token,
177                _ => {
178                    return Error::new_spanned(sig, "Function requires self argument")
179                        .to_compile_error();
180                }
181            };
182
183            let args = sig
184                .inputs
185                .iter()
186                .skip(1)
187                .map(|arg| {
188                    if let syn::FnArg::Typed(pat_type) = arg {
189                        if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
190                            pat_ident.ident.clone()
191                        } else {
192                            panic!("Unsupported argument pattern");
193                        }
194                    } else {
195                        panic!("Expected typed argument");
196                    }
197                })
198                .collect::<Vec<_>>();
199
200            let async_suffix = match sig.asyncness {
201                None => quote! {},
202                Some(_) => quote! {.await},
203            };
204
205            // Build match arms
206            let arms = enum_item
207                .variants
208                .iter()
209                .map(|variant| {
210                    let variant_ident = &variant.ident;
211                    let Fields::Unnamed(fields) = &variant.fields else {
212                        panic!("Only enum tuples supported");
213                    };
214                    let field = fields.unnamed.iter().next().expect("expected a field");
215                    let field_type = &field.ty;
216                    let method_gen = quote! { ::<#(#method_gen,)*> };
217                    quote! {
218                        #enum_ident::#variant_ident(__static_dispatch_value) =>
219                            <#field_type as #trait_ident #trait_args>::#method_name #method_gen(
220                            __static_dispatch_value,
221                            #(#args),*
222                        ) #async_suffix
223                    }
224                })
225                .collect::<Vec<_>>();
226
227            quote! {
228                #sig {
229                    match #self_arg {
230                        #(#arms,)*
231                    }
232                }
233            }
234        })
235        .collect::<Vec<_>>();
236
237    let expanded = quote! {
238        impl #impl_generics #trait_ident #trait_args for #enum_ident #enum_args #where_clause {
239            #(#impl_methods)*
240        }
241    };
242
243    expanded.into()
244}
245
246fn generic_args(generics: &syn::Generics) -> proc_macro2::TokenStream {
247    let args: Vec<_> = generics
248        .params
249        .iter()
250        .map(|param| match param {
251            GenericParam::Type(ty) => {
252                let ident = &ty.ident;
253                quote! { #ident }
254            }
255            GenericParam::Lifetime(lifetime) => {
256                let lt = &lifetime.lifetime;
257                quote! { #lt }
258            }
259            GenericParam::Const(c) => {
260                let ident = &c.ident;
261                quote! { #ident }
262            }
263        })
264        .collect();
265    if args.is_empty() {
266        quote! {}
267    } else {
268        quote! { < #(#args),* > }
269    }
270}