closure_ffi_proc_macros/
lib.rs

1use std::str::FromStr;
2
3use proc_macro::TokenStream;
4use proc_macro2 as pm2;
5use quote::{quote, ToTokens};
6use syn::{parse_macro_input, visit::Visit, visit_mut::VisitMut};
7
8struct MacroInput {
9    crate_path: syn::Path,
10    alias: syn::ItemType,
11    bare_fn: syn::TypeBareFn,
12}
13
14impl syn::parse::Parse for MacroInput {
15    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
16        let crate_path = input.parse()?;
17        let _: syn::Token![,] = input.parse()?;
18        let alias: syn::ItemType = input.parse()?;
19
20        match &*alias.ty {
21            syn::Type::BareFn(bare_fn) => {
22                if let Some(lt) = &bare_fn.lifetimes {
23                    if lt.lifetimes.len() > 3 {
24                        return Err(syn::Error::new_spanned(
25                            &lt.lifetimes,
26                            "At most 3 higher-ranked lifetimes are supported",
27                        ));
28                    }
29                }
30
31                // Check that bare_fn has no implicit lifetimes
32                struct HasImplicitBoundLt(Vec<pm2::Span>);
33                impl<'a> Visit<'a> for HasImplicitBoundLt {
34                    fn visit_lifetime(&mut self, i: &'a syn::Lifetime) {
35                        if i.ident.to_string() == "_" {
36                            self.0.push(i.span());
37                        }
38                    }
39
40                    fn visit_type_reference(&mut self, i: &'a syn::TypeReference) {
41                        match i.lifetime {
42                            Some(_) => self.visit_type(&i.elem),
43                            None => self.0.push(i.and_token.span),
44                        }
45                    }
46                }
47                let mut implicit_lt_check = HasImplicitBoundLt(Vec::default());
48                implicit_lt_check.visit_type_bare_fn(bare_fn);
49
50                let mut implicit_lt_err = None;
51                for err_span in implicit_lt_check.0 {
52                    let err = syn::Error::new(
53                        err_span,
54                        "Implicit lifetimes are not permitted; you must name this lifetime",
55                    );
56                    match implicit_lt_err.as_mut() {
57                        None => implicit_lt_err = Some(err),
58                        Some(e) => e.combine(err),
59                    }
60                }
61                match implicit_lt_err {
62                    Some(err) => Err(err),
63                    None => Ok(Self {
64                        crate_path,
65                        bare_fn: bare_fn.clone(),
66                        alias,
67                    }),
68                }
69            }
70            other => Err(syn::Error::new_spanned(
71                other,
72                &format!(
73                    "Expected bare function type, got {}",
74                    other.to_token_stream().to_string()
75                ),
76            )),
77        }
78    }
79}
80
81fn bare_fn_to_trait_bound(fun: &syn::TypeBareFn, mut path: syn::Path) -> syn::TraitBound {
82    let fn_part = path.segments.last_mut().unwrap();
83    fn_part.arguments = syn::PathArguments::Parenthesized(syn::ParenthesizedGenericArguments {
84        paren_token: Default::default(),
85        inputs: fun.inputs.iter().map(|arg| arg.ty.clone()).collect(),
86        output: fun.output.clone(),
87    });
88
89    syn::TraitBound {
90        paren_token: None,
91        modifier: syn::TraitBoundModifier::None,
92        lifetimes: fun.lifetimes.clone(),
93        path,
94    }
95}
96
97fn bare_fn_to_sig(
98    fun: &syn::TypeBareFn,
99    ident: syn::Ident,
100    arg_idents: &[syn::Ident],
101) -> syn::Signature {
102    syn::Signature {
103        constness: None,
104        asyncness: None,
105        unsafety: fun.unsafety,
106        abi: fun.abi.clone(),
107        fn_token: syn::Token![fn](pm2::Span::call_site()),
108        ident,
109        generics: syn::Generics {
110            lt_token: fun.lifetimes.as_ref().map(|lt| lt.lt_token),
111            params: fun.lifetimes.as_ref().map(|lt| lt.lifetimes.clone()).unwrap_or_default(),
112            gt_token: fun.lifetimes.as_ref().map(|lt| lt.gt_token),
113            where_clause: None,
114        },
115        paren_token: syn::token::Paren::default(),
116        inputs: fun
117            .inputs
118            .iter()
119            .enumerate()
120            .map(|(i, input)| {
121                syn::FnArg::Typed(syn::PatType {
122                    attrs: Default::default(),
123                    pat: Box::new(syn::Pat::Ident(syn::PatIdent {
124                        attrs: Default::default(),
125                        by_ref: None,
126                        mutability: None,
127                        ident: arg_idents[i].clone(),
128                        subpat: None,
129                    })),
130                    colon_token: syn::Token![:](pm2::Span::call_site()),
131                    ty: Box::new(input.ty.clone()),
132                })
133            })
134            .collect(),
135        variadic: None,
136        output: fun.output.clone(),
137    }
138}
139
140fn path_from_str(str: &str) -> syn::Path {
141    syn::parse(TokenStream::from_str(str).unwrap()).unwrap()
142}
143
144struct ReplaceLt<F: FnMut(&mut syn::Lifetime)>(F);
145
146impl<F: FnMut(&mut syn::Lifetime)> syn::visit_mut::VisitMut for ReplaceLt<F> {
147    fn visit_lifetime_mut(&mut self, i: &mut syn::Lifetime) {
148        self.0(i)
149    }
150}
151
152#[proc_macro]
153pub fn bare_hrtb(tokens: TokenStream) -> TokenStream {
154    let mut input = parse_macro_input!(tokens as MacroInput);
155
156    input
157        .bare_fn
158        .unsafety
159        .get_or_insert(syn::Token![unsafe](pm2::Span::call_site()));
160
161    let bare_fn = &input.bare_fn;
162
163    let thunk_ident = syn::Ident::new("thunk", pm2::Span::call_site());
164    let arg_idents: Vec<_> = (0..input.bare_fn.inputs.len())
165        .map(|i| syn::Ident::new(&format!("a{i}"), pm2::Span::call_site()))
166        .collect();
167
168    let mut thunk_sig = bare_fn_to_sig(bare_fn, thunk_ident.clone(), &arg_idents);
169
170    let bare_fn_lt_idents = bare_fn
171        .lifetimes
172        .as_ref()
173        .map(|lt| {
174            lt.lifetimes
175                .iter()
176                .map(|p| match p {
177                    syn::GenericParam::Lifetime(lt) => lt.lifetime.ident.to_string(),
178                    _ => unreachable!(),
179                })
180                .collect::<Vec<_>>()
181        })
182        .unwrap_or_default();
183
184    ReplaceLt(|lt| {
185        if let Some(for_ident) = bare_fn_lt_idents.iter().find(|&l| l == &lt.ident.to_string()) {
186            lt.ident = syn::Ident::new(&format!("for_{for_ident}"), pm2::Span::call_site())
187        }
188    })
189    .visit_signature_mut(&mut thunk_sig);
190
191    let f_ident = syn::Ident::new("_F", pm2::Span::call_site());
192    let cc_marker_ident = syn::Ident::new(
193        &format!("{}_CC", &input.alias.ident),
194        pm2::Span::call_site(),
195    );
196    let crate_path = &input.crate_path;
197
198    struct ImplDetails {
199        thunk_trait_path: &'static str,
200        fn_trait_path: &'static str,
201        const_name: &'static str,
202        body: pm2::TokenStream,
203    }
204
205    let impl_blocks: [ImplDetails; 3] = [
206        ImplDetails {
207            thunk_trait_path: "traits::FnOnceThunk",
208            fn_trait_path: "::core::ops::FnOnce",
209            const_name: "THUNK_TEMPLATE_ONCE",
210            body: quote! {
211                let closure_ptr: *mut #f_ident;
212                #crate_path::arch::_thunk_asm!(closure_ptr);
213                #crate_path::arch::_never_inline(|| closure_ptr.read()(#(#arg_idents),*))
214            },
215        },
216        ImplDetails {
217            thunk_trait_path: "traits::FnMutThunk",
218            fn_trait_path: "::core::ops::FnMut",
219            const_name: "THUNK_TEMPLATE_MUT",
220            body: quote! {
221                let closure_ptr: *mut #f_ident;
222                #crate_path::arch::_thunk_asm!(closure_ptr);
223                #crate_path::arch::_never_inline(|| (&mut *closure_ptr)(#(#arg_idents),*))
224            },
225        },
226        ImplDetails {
227            thunk_trait_path: "traits::FnThunk",
228            fn_trait_path: "::core::ops::Fn",
229            const_name: "THUNK_TEMPLATE",
230            body: quote! {
231                let closure_ptr: *const #f_ident;
232                #crate_path::arch::_thunk_asm!(closure_ptr);
233                #crate_path::arch::_never_inline(|| (&*closure_ptr)(#(#arg_idents),*))
234            },
235        },
236    ];
237
238    let alias_ident = &input.alias.ident;
239    let alias_attrs = &input.alias.attrs;
240    let alias_vis = &input.alias.vis;
241    let alias_gen = &input.alias.generics;
242    let (alias_impl_gen, alias_ty_params, alias_where) = &input.alias.generics.split_for_impl();
243
244    let impls = impl_blocks.iter().map(|impl_block| {
245        let fn_bound =
246            bare_fn_to_trait_bound(&input.bare_fn, path_from_str(impl_block.fn_trait_path));
247        let const_ident = syn::Ident::new(impl_block.const_name, pm2::Span::call_site());
248        let body = &impl_block.body;
249        let mut thunk_trait = input.crate_path.clone();
250        thunk_trait.segments.extend(path_from_str(impl_block.thunk_trait_path).segments);
251
252        let mut generics = input.alias.generics.clone();
253        generics.params.push(syn::GenericParam::Type(syn::TypeParam {
254            attrs: Default::default(),
255            ident: f_ident.clone(),
256            colon_token: Some(syn::Token![:](pm2::Span::call_site())),
257            bounds: [syn::TypeParamBound::Trait(fn_bound)].into_iter().collect(),
258            eq_token: None,
259            default: None,
260        }));
261
262        let mut thunk_sig = thunk_sig.clone();
263        thunk_sig.generics.params.extend(generics.params.clone());
264
265        let (impl_generics, _, where_clause) = generics.split_for_impl();
266        let sig_tys = generics.type_params().map(|t| &t.ident);
267
268        quote! {
269            unsafe impl #impl_generics #thunk_trait<#alias_ident #alias_ty_params>
270            for (#cc_marker_ident, #f_ident) #where_clause
271            {
272                const #const_ident: *const ::core::primitive::u8 = {
273                    #thunk_sig {
274                        #body
275                    }
276                    #thunk_ident::<#(#sig_tys),*> as *const ::core::primitive::u8
277                };
278            }
279        }
280    });
281
282    let alias_ident_lit = syn::LitStr::new(&alias_ident.to_string(), pm2::Span::call_site());
283    let alias_ident_doc_lit = syn::LitStr::new(
284        &format!("[`{}`].", alias_ident.to_string()),
285        pm2::Span::call_site(),
286    );
287
288    let mut punc_impl_lifetimes =
289        bare_fn.lifetimes.as_ref().map(|lt| lt.lifetimes.clone()).unwrap_or_default();
290    punc_impl_lifetimes.extend((punc_impl_lifetimes.len()..3).map(|i| {
291        syn::GenericParam::Lifetime(syn::LifetimeParam::new(syn::Lifetime::new(
292            &format!("'_extra_{i}"),
293            pm2::Span::call_site(),
294        )))
295    }));
296    let impl_lifetimes: Vec<_> = punc_impl_lifetimes.iter().collect();
297
298    let tuple_args = bare_fn.inputs.iter().map(|i| &i.ty);
299    let bare_fn_output = match &bare_fn.output {
300        syn::ReturnType::Default => &syn::Type::Tuple(syn::TypeTuple {
301            paren_token: syn::token::Paren(pm2::Span::call_site()),
302            elems: syn::punctuated::Punctuated::new(),
303        }),
304        syn::ReturnType::Type(_, ty) => &*ty,
305    };
306    let arg_indices = (0..bare_fn.inputs.len() as u32).map(|index| {
307        syn::Member::Unnamed(syn::Index {
308            index,
309            span: pm2::Span::call_site(),
310        })
311    });
312
313    quote! {
314        /// Calling convention marker type for higher-ranked bare function wrapper type
315        #[doc = #alias_ident_doc_lit]
316        #[derive(::core::fmt::Debug, ::core::clone::Clone, ::core::marker::Copy, ::core::default::Default)]
317        #alias_vis struct #cc_marker_ident;
318
319        #(#alias_attrs)*
320        #[repr(transparent)]
321        #alias_vis struct #alias_ident #alias_gen (pub #bare_fn) #alias_where;
322
323        impl #alias_impl_gen #alias_ident #alias_ty_params #alias_where {
324            /// Returns an instance of the calling convention marker type for this bare function.
325            pub fn cc() -> #cc_marker_ident {
326                #cc_marker_ident::default()
327            }
328        }
329
330        impl #alias_impl_gen ::core::clone::Clone for #alias_ident #alias_ty_params #alias_where {
331            fn clone(&self) -> Self {
332                Self(self.0)
333            }
334        }
335
336        impl #alias_impl_gen ::core::marker::Copy for #alias_ident #alias_ty_params #alias_where {}
337
338        impl #alias_impl_gen ::core::fmt::Debug for #alias_ident #alias_ty_params #alias_where {
339            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
340                f.debug_tuple(#alias_ident_lit)
341                    .field(&self.0)
342                    .finish()
343            }
344        }
345
346        impl #alias_impl_gen ::core::convert::From<#bare_fn> for #alias_ident #alias_ty_params #alias_where {
347            fn from(value: #bare_fn) -> Self {
348                Self(value)
349            }
350        }
351
352        impl #alias_impl_gen ::core::convert::Into<#bare_fn> for #alias_ident #alias_ty_params #alias_where {
353            fn into(self) -> #bare_fn {
354                self.0
355            }
356        }
357
358        impl #alias_impl_gen ::core::ops::Deref for #alias_ident #alias_ty_params #alias_where {
359            type Target = #bare_fn;
360
361            fn deref(&self) -> &Self::Target {
362                &self.0
363            }
364        }
365
366        unsafe impl #alias_impl_gen #crate_path::traits::FnPtr for #alias_ident #alias_ty_params #alias_where {
367            type CC = #cc_marker_ident;
368            type Args<#punc_impl_lifetimes> = (#(#tuple_args,)*) where Self: #(#impl_lifetimes)+*;
369            type Ret<#punc_impl_lifetimes> = #bare_fn_output where Self: #(#impl_lifetimes)+*;
370
371            #[inline(always)]
372            unsafe fn call<#punc_impl_lifetimes>(
373                self,
374                args: Self::Args<#punc_impl_lifetimes>
375            ) -> Self::Ret<#punc_impl_lifetimes>
376                where Self: #(#impl_lifetimes)+*
377            {
378                (self.0)(#(args.#arg_indices,)*)
379            }
380
381            #[inline(always)]
382            unsafe fn from_ptr(ptr: *const ()) -> Self {
383                unsafe { core::mem::transmute_copy(&ptr) }
384            }
385
386            #[inline(always)]
387            fn to_ptr(self) -> *const () {
388                self.0 as *const _
389            }
390        }
391
392        #(#impls)*
393    }
394    .into()
395}