closure_ffi_proc_macros/
lib.rs

1use std::str::FromStr;
2
3use proc_macro::TokenStream;
4use proc_macro2 as pm2;
5use quote::{quote, ToTokens};
6use syn::{parse_macro_input, visit::Visit, visit_mut::VisitMut};
7
8struct MacroInput {
9    thunk_attrs: Vec<syn::Attribute>,
10    crate_path: syn::Path,
11    alias: syn::ItemType,
12    bare_fn: syn::TypeBareFn,
13}
14
15impl syn::parse::Parse for MacroInput {
16    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
17        let thunk_attrs = input.call(syn::Attribute::parse_outer)?;
18        let crate_path = input.parse()?;
19        let _: syn::Token![,] = input.parse()?;
20        let alias: syn::ItemType = input.parse()?;
21
22        match &*alias.ty {
23            syn::Type::BareFn(bare_fn) => {
24                if let Some(lt) = &bare_fn.lifetimes {
25                    if lt.lifetimes.len() > 3 {
26                        return Err(syn::Error::new_spanned(
27                            &lt.lifetimes,
28                            "At most 3 higher-ranked lifetimes are supported",
29                        ));
30                    }
31                }
32
33                // Check that bare_fn has no implicit lifetimes
34                struct HasImplicitBoundLt(Vec<pm2::Span>);
35                impl<'a> Visit<'a> for HasImplicitBoundLt {
36                    fn visit_lifetime(&mut self, i: &'a syn::Lifetime) {
37                        if i.ident == "_" {
38                            self.0.push(i.span());
39                        }
40                    }
41
42                    fn visit_type_reference(&mut self, i: &'a syn::TypeReference) {
43                        match i.lifetime {
44                            Some(_) => self.visit_type(&i.elem),
45                            None => self.0.push(i.and_token.span),
46                        }
47                    }
48                }
49                let mut implicit_lt_check = HasImplicitBoundLt(Vec::default());
50                implicit_lt_check.visit_type_bare_fn(bare_fn);
51
52                let mut implicit_lt_err = None;
53                for err_span in implicit_lt_check.0 {
54                    let err = syn::Error::new(
55                        err_span,
56                        "Implicit lifetimes are not permitted; you must name this lifetime",
57                    );
58                    match implicit_lt_err.as_mut() {
59                        None => implicit_lt_err = Some(err),
60                        Some(e) => e.combine(err),
61                    }
62                }
63                match implicit_lt_err {
64                    Some(err) => Err(err),
65                    None => Ok(Self {
66                        thunk_attrs,
67                        crate_path,
68                        bare_fn: bare_fn.clone(),
69                        alias,
70                    }),
71                }
72            }
73            other => Err(syn::Error::new_spanned(
74                other,
75                format!(
76                    "Expected bare function type, got {}",
77                    other.to_token_stream()
78                ),
79            )),
80        }
81    }
82}
83
84fn bare_fn_to_trait_bound(fun: &syn::TypeBareFn, mut path: syn::Path) -> syn::TraitBound {
85    let fn_part = path.segments.last_mut().unwrap();
86    fn_part.arguments = syn::PathArguments::Parenthesized(syn::ParenthesizedGenericArguments {
87        paren_token: Default::default(),
88        inputs: fun.inputs.iter().map(|arg| arg.ty.clone()).collect(),
89        output: fun.output.clone(),
90    });
91
92    syn::TraitBound {
93        paren_token: None,
94        modifier: syn::TraitBoundModifier::None,
95        lifetimes: fun.lifetimes.clone(),
96        path,
97    }
98}
99
100fn bare_fn_to_sig(
101    fun: &syn::TypeBareFn,
102    ident: syn::Ident,
103    arg_idents: &[syn::Ident],
104) -> syn::Signature {
105    syn::Signature {
106        constness: None,
107        asyncness: None,
108        unsafety: fun.unsafety,
109        abi: fun.abi.clone(),
110        fn_token: syn::Token![fn](pm2::Span::call_site()),
111        ident,
112        generics: syn::Generics {
113            lt_token: fun.lifetimes.as_ref().map(|lt| lt.lt_token),
114            params: fun.lifetimes.as_ref().map(|lt| lt.lifetimes.clone()).unwrap_or_default(),
115            gt_token: fun.lifetimes.as_ref().map(|lt| lt.gt_token),
116            where_clause: None,
117        },
118        paren_token: syn::token::Paren::default(),
119        inputs: fun
120            .inputs
121            .iter()
122            .enumerate()
123            .map(|(i, input)| {
124                syn::FnArg::Typed(syn::PatType {
125                    attrs: Default::default(),
126                    pat: Box::new(syn::Pat::Ident(syn::PatIdent {
127                        attrs: Default::default(),
128                        by_ref: None,
129                        mutability: None,
130                        ident: arg_idents[i].clone(),
131                        subpat: None,
132                    })),
133                    colon_token: syn::Token![:](pm2::Span::call_site()),
134                    ty: Box::new(input.ty.clone()),
135                })
136            })
137            .collect(),
138        variadic: None,
139        output: fun.output.clone(),
140    }
141}
142
143fn path_from_str(str: &str) -> syn::Path {
144    syn::parse(TokenStream::from_str(str).unwrap()).unwrap()
145}
146
147struct ReplaceLt<F: FnMut(&mut syn::Lifetime)>(F);
148
149impl<F: FnMut(&mut syn::Lifetime)> syn::visit_mut::VisitMut for ReplaceLt<F> {
150    fn visit_lifetime_mut(&mut self, i: &mut syn::Lifetime) {
151        self.0(i)
152    }
153}
154
155#[proc_macro]
156pub fn bare_hrtb(tokens: TokenStream) -> TokenStream {
157    let mut input = parse_macro_input!(tokens as MacroInput);
158
159    input
160        .bare_fn
161        .unsafety
162        .get_or_insert(syn::Token![unsafe](pm2::Span::call_site()));
163
164    let bare_fn = &input.bare_fn;
165
166    let thunk_ident = syn::Ident::new("thunk", pm2::Span::call_site());
167    let arg_idents: Vec<_> = (0..input.bare_fn.inputs.len())
168        .map(|i| syn::Ident::new(&format!("a{i}"), pm2::Span::call_site()))
169        .collect();
170
171    let mut thunk_sig = bare_fn_to_sig(bare_fn, thunk_ident.clone(), &arg_idents);
172
173    let bare_fn_lt_idents = bare_fn
174        .lifetimes
175        .as_ref()
176        .map(|lt| {
177            lt.lifetimes
178                .iter()
179                .map(|p| match p {
180                    syn::GenericParam::Lifetime(lt) => lt.lifetime.ident.to_string(),
181                    _ => unreachable!(),
182                })
183                .collect::<Vec<_>>()
184        })
185        .unwrap_or_default();
186
187    ReplaceLt(|lt| {
188        if let Some(for_ident) = bare_fn_lt_idents.iter().find(|&l| l == &lt.ident.to_string()) {
189            lt.ident = syn::Ident::new(&format!("for_{for_ident}"), pm2::Span::call_site())
190        }
191    })
192    .visit_signature_mut(&mut thunk_sig);
193
194    let f_ident = syn::Ident::new("_F", pm2::Span::call_site());
195    let cc_marker_ident = syn::Ident::new(
196        &format!("{}_CC", &input.alias.ident),
197        pm2::Span::call_site(),
198    );
199    let crate_path = &input.crate_path;
200    let thunk_attrs = &input.thunk_attrs;
201    let alias_ident = &input.alias.ident;
202    let alias_attrs = &input.alias.attrs;
203    let alias_vis = &input.alias.vis;
204    let alias_gen = &input.alias.generics;
205    let (alias_impl_gen, alias_ty_params, alias_where) = &input.alias.generics.split_for_impl();
206    let mut impl_lifetimes =
207        bare_fn.lifetimes.as_ref().map(|lt| lt.lifetimes.clone()).unwrap_or_default();
208    impl_lifetimes.extend((impl_lifetimes.len()..3).map(|i| {
209        syn::GenericParam::Lifetime(syn::LifetimeParam::new(syn::Lifetime::new(
210            &format!("'_extra_{i}"),
211            pm2::Span::call_site(),
212        )))
213    }));
214    let arg_indices: Vec<_> = (0..bare_fn.inputs.len() as u32)
215        .map(|index| {
216            syn::Member::Unnamed(syn::Index {
217                index,
218                span: pm2::Span::call_site(),
219            })
220        })
221        .collect();
222
223    struct ImplDetails {
224        thunk_trait_path: &'static str,
225        fn_trait_path: &'static str,
226        const_name: &'static str,
227        call_ident: &'static str,
228        call_receiver: pm2::TokenStream,
229        thunk_body: pm2::TokenStream,
230        factory_ident: &'static str,
231        factory_bound: &'static str,
232    }
233
234    let impl_blocks: [ImplDetails; 3] = [
235        ImplDetails {
236            thunk_trait_path: "traits::FnOnceThunk",
237            fn_trait_path: "::core::ops::FnOnce",
238            const_name: "THUNK_TEMPLATE_ONCE",
239            call_ident: "call_once",
240            factory_ident: "make_once_thunk",
241            factory_bound: "traits::PackedFnOnce",
242            call_receiver: syn::Token![self](pm2::Span::call_site()).into_token_stream(),
243            thunk_body: quote! {
244                if const { ::core::mem::size_of::<#f_ident>() == 0 } {
245                    let closure: #f_ident = unsafe { ::core::mem::zeroed() };
246                    closure(#(#arg_idents),*)
247                }
248                else {
249                    let closure_ptr: *mut #f_ident;
250                    #crate_path::arch::_thunk_asm!(closure_ptr);
251                    #crate_path::arch::_invoke(|| closure_ptr.read()(#(#arg_idents),*))
252                }
253            },
254        },
255        ImplDetails {
256            thunk_trait_path: "traits::FnMutThunk",
257            fn_trait_path: "::core::ops::FnMut",
258            const_name: "THUNK_TEMPLATE_MUT",
259            call_ident: "call_mut",
260            factory_ident: "make_mut_thunk",
261            factory_bound: "traits::PackedFnMut",
262            call_receiver: quote! { &mut self },
263            thunk_body: quote! {
264                if const { ::core::mem::size_of::<#f_ident>() == 0 } {
265                    let closure: &mut #f_ident = unsafe { &mut *::core::ptr::dangling_mut() };
266                    closure(#(#arg_idents),*)
267                }
268                else {
269                    let closure_ptr: *mut #f_ident;
270                    #crate_path::arch::_thunk_asm!(closure_ptr);
271                    #crate_path::arch::_invoke(|| (&mut *closure_ptr)(#(#arg_idents),*))
272                }
273            },
274        },
275        ImplDetails {
276            thunk_trait_path: "traits::FnThunk",
277            fn_trait_path: "::core::ops::Fn",
278            const_name: "THUNK_TEMPLATE",
279            call_ident: "call",
280            factory_ident: "make_thunk",
281            factory_bound: "traits::PackedFn",
282            call_receiver: quote! { &self },
283            thunk_body: quote! {
284                if const { ::core::mem::size_of::<#f_ident>() == 0 } {
285                    let closure: &#f_ident = unsafe { &*::core::ptr::dangling() };
286                    closure(#(#arg_idents),*)
287                }
288                else {
289                    let closure_ptr: *const #f_ident;
290                    #crate_path::arch::_thunk_asm!(closure_ptr);
291                    #crate_path::arch::_invoke(|| (&*closure_ptr)(#(#arg_idents),*))
292                }
293            },
294        },
295    ];
296
297    struct ImplTokens {
298        thunk_impl: pm2::TokenStream,
299        thunk_factory_impl: pm2::TokenStream,
300    }
301
302    let impl_tokens: Vec<_> = impl_blocks.iter().map(|impl_block| {
303        let fn_bound =
304            bare_fn_to_trait_bound(&input.bare_fn, path_from_str(impl_block.fn_trait_path));
305
306        let const_ident = syn::Ident::new(impl_block.const_name, pm2::Span::call_site());
307        let call_ident = syn::Ident::new(impl_block.call_ident, pm2::Span::call_site());
308        let call_receiver = &impl_block.call_receiver;
309        let body = &impl_block.thunk_body;
310        let mut thunk_trait = input.crate_path.clone();
311        thunk_trait.segments.extend(path_from_str(impl_block.thunk_trait_path).segments);
312
313        let mut generics = input.alias.generics.clone();
314        generics.params.push(syn::GenericParam::Type(syn::TypeParam {
315            attrs: Default::default(),
316            ident: f_ident.clone(),
317            colon_token: Some(syn::Token![:](pm2::Span::call_site())),
318            bounds: [syn::TypeParamBound::Trait(fn_bound)].into_iter().collect(),
319            eq_token: None,
320            default: None,
321        }));
322
323        let mut thunk_sig = thunk_sig.clone();
324        thunk_sig.generics.params.extend(generics.params.clone());
325
326        let (impl_generics, _, where_clause) = generics.split_for_impl();
327        let sig_tys = generics.type_params().map(|t| &t.ident);
328
329        let thunk_impl = quote! {
330            unsafe impl #impl_generics #thunk_trait<#alias_ident #alias_ty_params>
331            for (#cc_marker_ident, #f_ident) #where_clause
332            {
333                const #const_ident: *const ::core::primitive::u8 = {
334                    #(#thunk_attrs)*
335                    #[allow(clippy::too_many_arguments)]
336                    #thunk_sig {
337                        #body
338                    }
339                    #thunk_ident::<#(#sig_tys),*> as *const ::core::primitive::u8
340                };
341
342                #[allow(unused_variables)]
343                #[inline(always)]
344                unsafe fn #call_ident<'_x, '_y, '_z>(#call_receiver, args: <#alias_ident #alias_ty_params as #crate_path::traits::FnPtr>::Args<'_x, '_y, '_z>
345                ) ->
346                    <#alias_ident #alias_ty_params as #crate_path::traits::FnPtr>::Ret<'_x, '_y, '_z>
347                {
348                    (self.1)(#(args.#arg_indices,)*)
349                }
350            }
351        };
352
353        let factory_bound_path = path_from_str(impl_block.factory_bound);
354        let factory_ident = syn::Ident::new(impl_block.factory_ident, pm2::Span::call_site());
355        let thunk_factory_impl = quote! {
356            #[inline(always)]
357            #[allow(unused_mut)]
358            fn #factory_ident<F>(mut fun: F) -> impl #thunk_trait<Self>
359            where
360                F: for<'_x, '_y, '_z> #crate_path::#factory_bound_path<'_x, '_y, '_z, Self>
361            {
362                #[inline(always)]
363                fn coerce #impl_generics (fun: #f_ident) -> #f_ident #where_clause {
364                    fun
365                }
366
367                let coerced = coerce(move |#(#arg_idents,)*| fun((#(#arg_idents,)*)));
368                (#cc_marker_ident, coerced)
369            }
370        };
371
372        ImplTokens {
373            thunk_impl,
374            thunk_factory_impl
375        }
376    }).collect();
377
378    let alias_ident_lit = syn::LitStr::new(&alias_ident.to_string(), pm2::Span::call_site());
379    let alias_ident_doc_lit =
380        syn::LitStr::new(&format!("[`{alias_ident}`]."), pm2::Span::call_site());
381
382    let tuple_args = bare_fn.inputs.iter().map(|i| &i.ty);
383    let bare_fn_output = match &bare_fn.output {
384        syn::ReturnType::Default => &syn::Type::Tuple(syn::TypeTuple {
385            paren_token: syn::token::Paren(pm2::Span::call_site()),
386            elems: syn::punctuated::Punctuated::new(),
387        }),
388        syn::ReturnType::Type(_, ty) => ty,
389    };
390
391    let trait_impls = impl_tokens.iter().map(|t| &t.thunk_impl);
392    let thunk_factory_impls = impl_tokens.iter().map(|t| &t.thunk_factory_impl);
393
394    quote! {
395        /// Calling convention marker type for higher-ranked bare function wrapper type
396        #[doc = #alias_ident_doc_lit]
397        #[derive(::core::fmt::Debug, ::core::clone::Clone, ::core::marker::Copy, ::core::default::Default)]
398        #alias_vis struct #cc_marker_ident;
399
400        #(#alias_attrs)*
401        #[repr(transparent)]
402        #alias_vis struct #alias_ident #alias_gen (pub #bare_fn) #alias_where;
403
404        impl #alias_impl_gen #alias_ident #alias_ty_params #alias_where {
405            /// Returns an instance of the calling convention marker type for this bare function.
406            pub fn cc() -> #cc_marker_ident {
407                #cc_marker_ident::default()
408            }
409        }
410
411        impl #alias_impl_gen ::core::clone::Clone for #alias_ident #alias_ty_params #alias_where {
412            fn clone(&self) -> Self {
413                Self(self.0)
414            }
415        }
416
417        impl #alias_impl_gen ::core::marker::Copy for #alias_ident #alias_ty_params #alias_where {}
418
419        impl #alias_impl_gen ::core::fmt::Debug for #alias_ident #alias_ty_params #alias_where {
420            fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
421                f.debug_tuple(#alias_ident_lit)
422                    .field(&self.0)
423                    .finish()
424            }
425        }
426
427        impl #alias_impl_gen ::core::convert::From<#bare_fn> for #alias_ident #alias_ty_params #alias_where {
428            fn from(value: #bare_fn) -> Self {
429                Self(value)
430            }
431        }
432
433        impl #alias_impl_gen ::core::convert::From<#alias_ident #alias_ty_params> for #bare_fn #alias_where {
434            fn from(value: #alias_ident #alias_ty_params) -> Self {
435                value.0
436            }
437        }
438
439        impl #alias_impl_gen ::core::ops::Deref for #alias_ident #alias_ty_params #alias_where {
440            type Target = #bare_fn;
441
442            fn deref(&self) -> &Self::Target {
443                &self.0
444            }
445        }
446
447        unsafe impl #alias_impl_gen #crate_path::traits::FnPtr for #alias_ident #alias_ty_params #alias_where {
448            type CC = #cc_marker_ident;
449            type Args<#impl_lifetimes> = (#(#tuple_args,)*);
450            type Ret<#impl_lifetimes> = #bare_fn_output;
451
452            #[inline(always)]
453            unsafe fn call<#impl_lifetimes>(
454                self,
455                args: Self::Args<#impl_lifetimes>
456            ) -> Self::Ret<#impl_lifetimes>
457            {
458                (self.0)(#(args.#arg_indices,)*)
459            }
460
461            #[inline(always)]
462            unsafe fn from_ptr(ptr: *const ()) -> Self {
463                unsafe { core::mem::transmute_copy(&ptr) }
464            }
465
466            #[inline(always)]
467            fn to_ptr(self) -> *const () {
468                self.0 as *const _
469            }
470
471            #(#thunk_factory_impls)*
472        }
473
474        #(#trait_impls)*
475    }
476    .into()
477}