precompile_macro/
lib.rs

1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::{punctuated::Punctuated, token::Comma, *};
4
5#[proc_macro_attribute]
6pub fn precompile(
7    attr: proc_macro::TokenStream,
8    item: proc_macro::TokenStream,
9) -> proc_macro::TokenStream {
10    let item: TokenStream = item.into();
11    let attr: TokenStream = attr.into();
12
13    let Ok(item) = parse2::<ItemFn>(item) else {
14        return quote! {
15            ::core::compile_error!("expected function");
16        }
17        .into();
18    };
19
20    if !attr.is_empty() {
21        return quote! { ::core::compile_error!("#[precompile::precompile] does not take any arguments
22        use #[precompile_with(A, B, ...)] to add types."); }
23        .into();
24    };
25
26    let ItemFn {
27        mut attrs,
28        vis,
29        mut sig,
30        block: _,
31    } = item.clone();
32    let original_name = sig.ident.clone();
33
34    let mut new_attrs = Vec::new();
35    let mut types: Vec<Vec<_>> = Vec::new();
36
37    for attr in attrs {
38        match attr.meta.clone() {
39            Meta::List(list) => {
40                if list
41                    .path
42                    .get_ident()
43                    .map(|ident| &*ident.to_string() == "precompile_with")
44                    == Some(true)
45                {
46                    let tokens = list.tokens.clone();
47                    if let Ok(ty) = parse2::<Type>(tokens) {
48                        types.push(vec![ty]);
49                    } else {
50                        let tokens = list.tokens;
51                        let Ok(ty_tuple) = parse2::<TypeTuple>(quote! { (#tokens) }) else {
52                            return quote! { ::core::compile_error!("expected comma-separated list of types"); }.into();
53                        };
54                        types.push(ty_tuple.elems.into_iter().collect());
55                    };
56                } else {
57                    new_attrs.push(attr);
58                }
59            }
60            _ => new_attrs.push(attr),
61        }
62    }
63    attrs = new_attrs;
64
65    let mut inner = item;
66    inner.attrs = attrs.clone();
67    inner.sig.ident = Ident::new("__precompile_inner_impl", Span::call_site());
68
69    let Generics {
70        lt_token: _,
71        params,
72        gt_token: _,
73        where_clause,
74    } = inner.sig.generics.clone();
75
76    let mut ty_param = Vec::new();
77    let mut ty_param_names = Vec::new();
78
79    for param in params.clone().into_iter() {
80        match param {
81            GenericParam::Type(ty) => {
82                ty_param_names.push(ty.ident.clone());
83                ty_param.push(ty);
84            }
85            GenericParam::Lifetime(_) => {}
86            GenericParam::Const(_) => {
87                return quote! { ::core::compile_error!("precompiling const generics is currently unsupported"); }.into();
88            }
89        }
90    }
91
92    let mut inputs = Punctuated::new();
93    let mut input_name = Vec::new();
94    let mut input_types = Vec::new();
95    let output_type = sig.output.clone();
96
97    for (idx, input) in sig.inputs.clone().into_iter().enumerate() {
98        match input {
99            FnArg::Receiver(_) => {
100                return quote! { ::core::compile_error!("only free functions can be precompiled"); }
101                    .into()
102            }
103            FnArg::Typed(mut ty) => {
104                let ident = Ident::new(&format!("__{idx}"), Span::call_site());
105                ty.pat = Box::new(Pat::Ident(PatIdent {
106                    attrs: Vec::new(),
107                    by_ref: None,
108                    mutability: None,
109                    ident: ident.clone(),
110                    subpat: None,
111                }));
112                let tyty = (*ty.ty).clone();
113                input_types.push(quote! { #tyty });
114                inputs.push_value(FnArg::Typed(ty));
115                inputs.push_punct(Comma {
116                    spans: [Span::call_site()],
117                });
118                input_name.push(ident);
119            }
120        }
121    }
122    sig.inputs = inputs.clone();
123
124    let spec_code = types
125        .into_iter()
126        .map(|types| {
127            let inner_no_generics = inner.clone();
128            let ItemFn {
129                attrs,
130                vis: _,
131                mut sig,
132                block: _,
133            } = inner_no_generics;
134
135            let Generics {
136                lt_token, gt_token, ..
137            } = sig.generics;
138
139            sig.inputs = inputs.clone();
140
141            sig.ident = Ident::new("__precompile_inner_impl_spec", Span::call_site());
142            sig.generics = Generics {
143                lt_token,
144                params: Punctuated::new(),
145                gt_token,
146                where_clause: None,
147            };
148
149            quote! {
150                #(type #ty_param_names = #types;)*
151                impl ::precompile::Impl for __PrecompileImplSpec<(#(#types,)*)> {
152                    const FN_PTR: *const () = {
153                        #(#attrs)*
154                        pub #sig {
155                            #[allow(unused_unsafe)]
156                            unsafe { __precompile_inner_impl::<#(#types,)*>(#(#input_name,)*) }
157                        }
158                        __precompile_inner_impl_spec as *const ()
159                    };
160                }
161            }
162        })
163        .collect::<Vec<_>>();
164
165    let abi = sig.abi.clone();
166
167    let code = quote! {
168        #(#attrs)*
169        #vis #sig {
170            #inner
171
172            struct __PrecompileImplGeneric<Inner>(::core::marker::PhantomData<Inner>);
173            struct __PrecompileImplSpec<Inner>(::core::marker::PhantomData<Inner>);
174
175            impl<#(#ty_param,)*> ::precompile::Impl for __PrecompileImplGeneric<(#(#ty_param_names,)*)> #where_clause {
176                const FN_PTR: *const () = __precompile_inner_impl::<#(#ty_param_names,)*> as *const ();
177            }
178
179            #({ #spec_code })*
180
181            unsafe {
182                use ::precompile::Impl;
183                let mut __fn_ptr: unsafe #abi fn(#(#input_types,)*) #output_type = ::core::mem::transmute(::precompile::pick(
184                    __PrecompileImplGeneric(::core::marker::PhantomData::<(#(#ty_param_names,)*)>),
185                    __PrecompileImplSpec   (::core::marker::PhantomData::<(#(#ty_param_names,)*)>),
186                ));
187
188                __fn_ptr(#(#input_name,)*)
189            }
190        }
191    };
192    code.into()
193}