dynamic_dispatch_proc_macro/
lib.rs1extern crate proc_macro;
2extern crate quote;
3
4use proc_macro2::{Ident, Span, TokenStream};
5use proc_macro_error::{abort, proc_macro_error};
6use quote::{quote, ToTokens};
7use std::collections::HashMap;
8use syn::parse::{Parse, ParseStream};
9use syn::spanned::Spanned;
10use syn::{
11 parse_macro_input, parse_quote, Expr, ExprArray, ExprPath, GenericParam, ItemFn, ItemImpl,
12 ItemTrait, Token, Type, TypeParamBound,
13};
14use syn::{FnArg, Item};
15
16struct FunctionSpecializations {
17 specs: Vec<(String, Vec<ExprPath>)>,
18}
19
20impl Parse for FunctionSpecializations {
21 fn parse(input: ParseStream) -> syn::Result<Self> {
22 let mut specs = Vec::new();
23
24 while !input.is_empty() {
25 let name: Ident = input.parse()?;
26 input.parse::<Token![=]>()?;
27 let array: ExprArray = input.parse()?;
28
29 let elems: Vec<_> = array
30 .elems
31 .iter()
32 .map(|x| {
33 if let Expr::Path(path) = x {
34 path.clone()
35 } else {
36 abort!(x.span(), "Expected path.");
37 }
38 })
39 .collect();
40
41 specs.push((name.to_string(), elems));
42 if input.is_empty() {
43 break;
44 }
45 input.parse::<Token![,]>()?;
46 }
47 Ok(FunctionSpecializations { specs })
48 }
49}
50
51fn static_dispatch_fn(args: FunctionSpecializations, function: ItemFn) -> TokenStream {
52 let mut generics_list = Vec::new();
53
54 let mut attr_params = HashMap::new();
55 for (name, arg) in args.specs {
56 attr_params.insert(name, arg);
57 }
58
59 for param in function.sig.generics.params.clone() {
60 let (name, const_type, first_bound) = match param.clone() {
61 GenericParam::Type(ty) => {
62 let first_bound = ty
63 .bounds
64 .iter()
65 .filter(|x| {
66 if let TypeParamBound::Trait(_) = x {
67 true
68 } else {
69 false
70 }
71 })
72 .next()
73 .expect("At least one bound for each generic parametere must be specified.")
74 .to_token_stream();
75
76 (ty.ident.to_string(), None, Some(first_bound))
77 }
78 GenericParam::Const(cs) => (cs.ident.to_string(), Some(cs.ty), None),
79 GenericParam::Lifetime(_) => continue, };
81
82 let names: Vec<_> = match attr_params.get(&name) {
83 None => {
84 abort!(
85 param.span(),
86 "Static dispatch not specified for generic attribute '{}'",
87 name
88 );
89 }
90 Some(names) => names.clone().into_iter().collect(),
91 };
92
93 generics_list.push((name, names, const_type, param.clone(), first_bound));
94 }
95
96 let fn_name = function.sig.ident.clone();
97 let static_fn_name = Ident::new(
98 &format!("{}_static", function.sig.ident),
99 function.sig.ident.span(),
100 );
101
102 let dynamic_dispatch_fn_name = Ident::new(
103 &format!("__{}_static", function.sig.ident),
104 function.sig.ident.span(),
105 );
106
107 let fn_args = function.sig.inputs.clone();
108 let fn_args_pass: Vec<_> = function
109 .sig
110 .inputs
111 .iter()
112 .map(|x| match x {
113 FnArg::Receiver(x) => x.self_token.to_token_stream(),
114 FnArg::Typed(x) => x.pat.to_token_stream(),
115 })
116 .collect();
117 let fn_rettype = function.sig.output.clone();
118
119 let make_function_name = |name| {
120 Ident::new(
121 &format!("dispatch_fn_{}_{}", function.sig.ident.to_string(), name),
122 function.sig.span(),
123 )
124 };
125
126 let mut dispatch_traits = TokenStream::new();
127 for (name, list, const_type, _, _) in &generics_list {
128 if let Some(const_type) = const_type {
129 let dispatch_function_name = make_function_name(name.clone());
130
131 let mut match_branches = TokenStream::new();
132 for (idx, value) in list.iter().enumerate() {
133 (quote! {
134 #value => #idx,
135 })
136 .to_tokens(&mut match_branches);
137 }
138
139 (quote! {
140 #[allow(non_snake_case)]
141 #[doc(hidden)]
142 fn #dispatch_function_name(x: #const_type) -> usize {
143 match x {
144 #match_branches
145 _ => panic!(concat!("Const range for variable ", concat!(#name, " not supported!")))
146 }
147 }
148 })
149 .to_tokens(&mut dispatch_traits);
150 }
151 }
152
153 let mut dispatch_generic_args = TokenStream::new();
154 let mut dispatch_generic_args_pass = TokenStream::new();
155 let mut dispatch_tuple_members = TokenStream::new();
156 let mut dispatch_tuple_builders = TokenStream::new();
157
158 for (name, _list, const_type, generic, first_bound) in &generics_list {
159 let ident_name = Ident::new(&name, Span::call_site());
160
161 if let Some(const_type) = const_type {
162 let dispatch_function = make_function_name(name.clone());
163
164 (quote! {
165 const #ident_name: #const_type,
166 })
167 .to_tokens(&mut dispatch_generic_args);
168
169 (quote! {
170 #dispatch_function(#ident_name),
171 })
172 .to_tokens(&mut dispatch_tuple_builders);
173
174 (quote! {
175 usize,
176 })
177 .to_tokens(&mut dispatch_tuple_members);
178 } else {
179 (quote! {
180 #generic,
181 })
182 .to_tokens(&mut dispatch_generic_args);
183
184 (quote! {
185 <#ident_name as #first_bound>::dynamic_dispatch_id(),
186 })
187 .to_tokens(&mut dispatch_tuple_builders);
188
189 (quote! {
190 ::dynamic_dispatch::DynamicDispatch<()>,
191 })
192 .to_tokens(&mut dispatch_tuple_members);
193 }
194
195 (quote! {
196 #ident_name,
197 })
198 .to_tokens(&mut dispatch_generic_args_pass);
199 }
200
201 fn recursive_dispatch_builder(
202 index: usize,
203 gen_args: TokenStream,
204 generics_list: &Vec<(
205 String,
206 Vec<ExprPath>,
207 Option<Type>,
208 GenericParam,
209 Option<TokenStream>,
210 )>,
211 fn_name: &Ident,
212 fn_args: &TokenStream,
213 ) -> TokenStream {
214 if index == generics_list.len() {
215 quote! { return #fn_name::<#gen_args>(#fn_args); }
216 } else {
217 let mut output_dispatcher = TokenStream::new();
218
219 let is_const = generics_list[index].2.is_some();
220 let tuple_index = syn::Index::from(index);
221
222 for (idx, ty) in generics_list[index].1.iter().enumerate() {
223 let attrs = &ty.attrs;
224 let path = &ty.path;
225
226 let gen_args = if index == 0 {
227 quote! { #path }
228 } else {
229 quote! { #gen_args, #path }
230 };
231
232 let nested = recursive_dispatch_builder(
233 index + 1,
234 gen_args,
235 generics_list,
236 fn_name,
237 fn_args,
238 );
239
240 if is_const {
241 quote! {
242 #(#attrs)*
243 if #idx == dispatch_tuple.#tuple_index {
244 #nested
245 }
246 }
247 } else {
248 let first_bound = generics_list[index].4.as_ref().unwrap();
249
250 quote! {
251 #(#attrs)*
252 if <#path as #first_bound>::dynamic_dispatch_id() == dispatch_tuple.#tuple_index {
253 #nested
254 }
255 }
256 }
257 .to_tokens(&mut output_dispatcher);
258 }
259
260 quote! {
261 #output_dispatcher
262 panic!("Static dispatch bug, arg {:?}!", dispatch_tuple.#tuple_index);
263 }
264 }
265 }
266
267 let final_dispatcher = recursive_dispatch_builder(
268 0,
269 TokenStream::new(),
270 &generics_list,
271 &fn_name.clone(),
272 "e! { #(#fn_args_pass),* },
273 );
274
275 quote! {
276
277 #dispatch_traits
278
279 #[doc(hidden)]
280 #[inline(always)]
281 fn __dispatch<#dispatch_generic_args>() -> (#dispatch_tuple_members) {
282 (#dispatch_tuple_builders)
283 }
284
285 #[doc(hidden)]
286 #[inline(never)]
287 pub fn #dynamic_dispatch_fn_name(dispatch_tuple: (#dispatch_tuple_members), #fn_args) #fn_rettype {
288 #final_dispatcher
289 }
290
291 #[doc(hidden)]
292 #[inline(always)]
293 pub fn #static_fn_name<#dispatch_generic_args>(#fn_args) #fn_rettype {
294 let dispatch_tuple = __dispatch::<#dispatch_generic_args_pass>();
295 #dynamic_dispatch_fn_name(dispatch_tuple, #(#fn_args_pass),*)
296 }
297
298 pub mod static_dispatch {
299 pub use super::#static_fn_name as #fn_name;
300 }
301
302 pub mod dynamic_dispatch {
303 pub use super::#dynamic_dispatch_fn_name as #fn_name;
304 }
305
306 }
307}
308
309fn static_dispatch_trait(mut trait_: ItemTrait) -> TokenStream {
310 trait_.items.push(
311 parse_quote! { fn dynamic_dispatch_id() -> ::dynamic_dispatch::DynamicDispatch<()>; },
312 );
313
314 trait_.to_token_stream()
315}
316
317fn static_dispatch_impl(mut impl_: ItemImpl) -> TokenStream {
318 impl_.impl_token;
319
320 impl_.items.push(parse_quote! {
321 fn dynamic_dispatch_id() -> ::dynamic_dispatch::DynamicDispatch::<()> {
322 ::dynamic_dispatch::DynamicDispatch::<()> { value: std::any::TypeId::of::<Self>(), _phantom: std::marker::PhantomData }
323 }
324 });
325
326 impl_.to_token_stream()
327}
328
329#[proc_macro_error]
330#[proc_macro_attribute]
331pub fn dynamic_dispatch(
332 args: proc_macro::TokenStream,
333 input: proc_macro::TokenStream,
334) -> proc_macro::TokenStream {
335 let input_ = input.clone();
336 let function = parse_macro_input!(input_ as Item);
337 let input = proc_macro2::TokenStream::from(input);
338
339 let (input, static_dispatch_module) = match function {
340 Item::Fn(function) => {
341 let args = parse_macro_input!(args as FunctionSpecializations);
342 (input, static_dispatch_fn(args, function))
343 }
344 Item::Trait(trait_) => (TokenStream::new(), static_dispatch_trait(trait_)),
345 Item::Impl(impl_) => (TokenStream::new(), static_dispatch_impl(impl_)),
346 _ => {
347 panic!(
348 "dynamic_dispatch attribute is applicable only to functions, traits or trait impls."
349 );
350 }
351 };
352
353 quote!(#input #static_dispatch_module).into()
356}