dynamic_dispatch_proc_macro/
lib.rs

1extern crate proc_macro;
2extern crate quote;
3
4use proc_macro2::{Ident, Span, TokenStream};
5use proc_macro_error::{abort, proc_macro_error};
6use quote::{quote, ToTokens};
7use std::collections::HashMap;
8use syn::parse::{Parse, ParseStream};
9use syn::spanned::Spanned;
10use syn::{
11    parse_macro_input, parse_quote, Expr, ExprArray, ExprPath, GenericParam, ItemFn, ItemImpl,
12    ItemTrait, Token, Type, TypeParamBound,
13};
14use syn::{FnArg, Item};
15
16struct FunctionSpecializations {
17    specs: Vec<(String, Vec<ExprPath>)>,
18}
19
20impl Parse for FunctionSpecializations {
21    fn parse(input: ParseStream) -> syn::Result<Self> {
22        let mut specs = Vec::new();
23
24        while !input.is_empty() {
25            let name: Ident = input.parse()?;
26            input.parse::<Token![=]>()?;
27            let array: ExprArray = input.parse()?;
28
29            let elems: Vec<_> = array
30                .elems
31                .iter()
32                .map(|x| {
33                    if let Expr::Path(path) = x {
34                        path.clone()
35                    } else {
36                        abort!(x.span(), "Expected path.");
37                    }
38                })
39                .collect();
40
41            specs.push((name.to_string(), elems));
42            if input.is_empty() {
43                break;
44            }
45            input.parse::<Token![,]>()?;
46        }
47        Ok(FunctionSpecializations { specs })
48    }
49}
50
51fn static_dispatch_fn(args: FunctionSpecializations, function: ItemFn) -> TokenStream {
52    let mut generics_list = Vec::new();
53
54    let mut attr_params = HashMap::new();
55    for (name, arg) in args.specs {
56        attr_params.insert(name, arg);
57    }
58
59    for param in function.sig.generics.params.clone() {
60        let (name, const_type, first_bound) = match param.clone() {
61            GenericParam::Type(ty) => {
62                let first_bound = ty
63                    .bounds
64                    .iter()
65                    .filter(|x| {
66                        if let TypeParamBound::Trait(_) = x {
67                            true
68                        } else {
69                            false
70                        }
71                    })
72                    .next()
73                    .expect("At least one bound for each generic parametere must be specified.")
74                    .to_token_stream();
75
76                (ty.ident.to_string(), None, Some(first_bound))
77            }
78            GenericParam::Const(cs) => (cs.ident.to_string(), Some(cs.ty), None),
79            GenericParam::Lifetime(_) => continue, // Ignored
80        };
81
82        let names: Vec<_> = match attr_params.get(&name) {
83            None => {
84                abort!(
85                    param.span(),
86                    "Static dispatch not specified for generic attribute '{}'",
87                    name
88                );
89            }
90            Some(names) => names.clone().into_iter().collect(),
91        };
92
93        generics_list.push((name, names, const_type, param.clone(), first_bound));
94    }
95
96    let fn_name = function.sig.ident.clone();
97    let static_fn_name = Ident::new(
98        &format!("{}_static", function.sig.ident),
99        function.sig.ident.span(),
100    );
101
102    let dynamic_dispatch_fn_name = Ident::new(
103        &format!("__{}_static", function.sig.ident),
104        function.sig.ident.span(),
105    );
106
107    let fn_args = function.sig.inputs.clone();
108    let fn_args_pass: Vec<_> = function
109        .sig
110        .inputs
111        .iter()
112        .map(|x| match x {
113            FnArg::Receiver(x) => x.self_token.to_token_stream(),
114            FnArg::Typed(x) => x.pat.to_token_stream(),
115        })
116        .collect();
117    let fn_rettype = function.sig.output.clone();
118
119    let make_function_name = |name| {
120        Ident::new(
121            &format!("dispatch_fn_{}_{}", function.sig.ident.to_string(), name),
122            function.sig.span(),
123        )
124    };
125
126    let mut dispatch_traits = TokenStream::new();
127    for (name, list, const_type, _, _) in &generics_list {
128        if let Some(const_type) = const_type {
129            let dispatch_function_name = make_function_name(name.clone());
130
131            let mut match_branches = TokenStream::new();
132            for (idx, value) in list.iter().enumerate() {
133                (quote! {
134                    #value => #idx,
135                })
136                .to_tokens(&mut match_branches);
137            }
138
139            (quote! {
140                    #[allow(non_snake_case)]
141                    #[doc(hidden)]
142                    fn #dispatch_function_name(x: #const_type) -> usize {
143                        match x {
144                            #match_branches
145                            _ => panic!(concat!("Const range for variable ", concat!(#name, " not supported!")))
146                        }
147                    }
148                })
149                .to_tokens(&mut dispatch_traits);
150        }
151    }
152
153    let mut dispatch_generic_args = TokenStream::new();
154    let mut dispatch_generic_args_pass = TokenStream::new();
155    let mut dispatch_tuple_members = TokenStream::new();
156    let mut dispatch_tuple_builders = TokenStream::new();
157
158    for (name, _list, const_type, generic, first_bound) in &generics_list {
159        let ident_name = Ident::new(&name, Span::call_site());
160
161        if let Some(const_type) = const_type {
162            let dispatch_function = make_function_name(name.clone());
163
164            (quote! {
165                const #ident_name: #const_type,
166            })
167            .to_tokens(&mut dispatch_generic_args);
168
169            (quote! {
170                #dispatch_function(#ident_name),
171            })
172            .to_tokens(&mut dispatch_tuple_builders);
173
174            (quote! {
175                usize,
176            })
177            .to_tokens(&mut dispatch_tuple_members);
178        } else {
179            (quote! {
180                #generic,
181            })
182            .to_tokens(&mut dispatch_generic_args);
183
184            (quote! {
185                <#ident_name as #first_bound>::dynamic_dispatch_id(),
186            })
187            .to_tokens(&mut dispatch_tuple_builders);
188
189            (quote! {
190                ::dynamic_dispatch::DynamicDispatch<()>,
191            })
192            .to_tokens(&mut dispatch_tuple_members);
193        }
194
195        (quote! {
196            #ident_name,
197        })
198        .to_tokens(&mut dispatch_generic_args_pass);
199    }
200
201    fn recursive_dispatch_builder(
202        index: usize,
203        gen_args: TokenStream,
204        generics_list: &Vec<(
205            String,
206            Vec<ExprPath>,
207            Option<Type>,
208            GenericParam,
209            Option<TokenStream>,
210        )>,
211        fn_name: &Ident,
212        fn_args: &TokenStream,
213    ) -> TokenStream {
214        if index == generics_list.len() {
215            quote! { return #fn_name::<#gen_args>(#fn_args); }
216        } else {
217            let mut output_dispatcher = TokenStream::new();
218
219            let is_const = generics_list[index].2.is_some();
220            let tuple_index = syn::Index::from(index);
221
222            for (idx, ty) in generics_list[index].1.iter().enumerate() {
223                let attrs = &ty.attrs;
224                let path = &ty.path;
225
226                let gen_args = if index == 0 {
227                    quote! { #path }
228                } else {
229                    quote! { #gen_args, #path }
230                };
231
232                let nested = recursive_dispatch_builder(
233                    index + 1,
234                    gen_args,
235                    generics_list,
236                    fn_name,
237                    fn_args,
238                );
239
240                if is_const {
241                    quote! {
242                        #(#attrs)*
243                        if #idx == dispatch_tuple.#tuple_index {
244                            #nested
245                        }
246                    }
247                } else {
248                    let first_bound = generics_list[index].4.as_ref().unwrap();
249
250                    quote! {
251                        #(#attrs)*
252                        if <#path as #first_bound>::dynamic_dispatch_id() == dispatch_tuple.#tuple_index {
253                            #nested
254                        }
255                    }
256                }
257                .to_tokens(&mut output_dispatcher);
258            }
259
260            quote! {
261                #output_dispatcher
262                panic!("Static dispatch bug, arg {:?}!", dispatch_tuple.#tuple_index);
263            }
264        }
265    }
266
267    let final_dispatcher = recursive_dispatch_builder(
268        0,
269        TokenStream::new(),
270        &generics_list,
271        &fn_name.clone(),
272        &quote! { #(#fn_args_pass),* },
273    );
274
275    quote! {
276
277        #dispatch_traits
278
279        #[doc(hidden)]
280        #[inline(always)]
281        fn __dispatch<#dispatch_generic_args>() -> (#dispatch_tuple_members) {
282            (#dispatch_tuple_builders)
283        }
284
285        #[doc(hidden)]
286        #[inline(never)]
287        pub fn #dynamic_dispatch_fn_name(dispatch_tuple: (#dispatch_tuple_members), #fn_args) #fn_rettype {
288             #final_dispatcher
289        }
290
291        #[doc(hidden)]
292        #[inline(always)]
293        pub fn #static_fn_name<#dispatch_generic_args>(#fn_args) #fn_rettype {
294            let dispatch_tuple = __dispatch::<#dispatch_generic_args_pass>();
295            #dynamic_dispatch_fn_name(dispatch_tuple, #(#fn_args_pass),*)
296        }
297
298        pub mod static_dispatch {
299            pub use super::#static_fn_name as #fn_name;
300        }
301
302        pub mod dynamic_dispatch {
303            pub use super::#dynamic_dispatch_fn_name as #fn_name;
304        }
305
306    }
307}
308
309fn static_dispatch_trait(mut trait_: ItemTrait) -> TokenStream {
310    trait_.items.push(
311        parse_quote! { fn dynamic_dispatch_id() -> ::dynamic_dispatch::DynamicDispatch<()>; },
312    );
313
314    trait_.to_token_stream()
315}
316
317fn static_dispatch_impl(mut impl_: ItemImpl) -> TokenStream {
318    impl_.impl_token;
319
320    impl_.items.push(parse_quote! {
321        fn dynamic_dispatch_id() -> ::dynamic_dispatch::DynamicDispatch::<()> {
322            ::dynamic_dispatch::DynamicDispatch::<()> { value: std::any::TypeId::of::<Self>(), _phantom: std::marker::PhantomData }
323        }
324    });
325
326    impl_.to_token_stream()
327}
328
329#[proc_macro_error]
330#[proc_macro_attribute]
331pub fn dynamic_dispatch(
332    args: proc_macro::TokenStream,
333    input: proc_macro::TokenStream,
334) -> proc_macro::TokenStream {
335    let input_ = input.clone();
336    let function = parse_macro_input!(input_ as Item);
337    let input = proc_macro2::TokenStream::from(input);
338
339    let (input, static_dispatch_module) = match function {
340        Item::Fn(function) => {
341            let args = parse_macro_input!(args as FunctionSpecializations);
342            (input, static_dispatch_fn(args, function))
343        }
344        Item::Trait(trait_) => (TokenStream::new(), static_dispatch_trait(trait_)),
345        Item::Impl(impl_) => (TokenStream::new(), static_dispatch_impl(impl_)),
346        _ => {
347            panic!(
348                "dynamic_dispatch attribute is applicable only to functions, traits or trait impls."
349            );
350        }
351    };
352
353    // panic!("{}", static_dispatch_module.to_string());
354
355    quote!(#input #static_dispatch_module).into()
356}