closure_ffi_proc_macros/
lib.rs

1use std::str::FromStr;
2
3use proc_macro::TokenStream;
4use proc_macro2 as pm2;
5use quote::quote;
6use syn::{parse_macro_input, spanned::Spanned as _, visit_mut::VisitMut};
7
8// the Parse impl for syn::Generics ignores the where clause. This expects
9// it right after the generic parameters.
10struct GenericsWithWhere(syn::Generics);
11impl syn::parse::Parse for GenericsWithWhere {
12    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
13        Ok(GenericsWithWhere({
14            let mut generics: syn::Generics = input.parse()?;
15            generics.where_clause = input.parse()?;
16            generics
17        }))
18    }
19}
20
21struct MacroInput {
22    attrs: Vec<syn::Attribute>,
23    generics: syn::Generics,
24    bare_fn: syn::TypeBareFn,
25}
26
27impl syn::parse::Parse for MacroInput {
28    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
29        let all_attrs = input.call(syn::Attribute::parse_outer)?;
30        let mut attrs = Vec::new();
31        let mut generics = None;
32
33        for attr in all_attrs {
34            if !attr.path().is_ident("with") {
35                attrs.push(attr);
36            }
37            else if generics.is_some() {
38                return Err(syn::Error::new_spanned(
39                    attr.path().get_ident(),
40                    "with attribute is already present",
41                ));
42            }
43            else {
44                let meta_list = attr.meta.require_list()?;
45                generics = Some(meta_list.parse_args::<GenericsWithWhere>()?.0);
46            }
47        }
48
49        Ok(Self {
50            attrs,
51            generics: generics.unwrap_or_default(),
52            bare_fn: input.parse()?,
53        })
54    }
55}
56
57fn bare_fn_to_trait_bound(fun: &syn::TypeBareFn, mut path: syn::Path) -> syn::TraitBound {
58    let fn_part = path.segments.last_mut().unwrap();
59    fn_part.arguments = syn::PathArguments::Parenthesized(syn::ParenthesizedGenericArguments {
60        paren_token: Default::default(),
61        inputs: fun.inputs.iter().map(|arg| arg.ty.clone()).collect(),
62        output: fun.output.clone(),
63    });
64
65    syn::TraitBound {
66        paren_token: None,
67        modifier: syn::TraitBoundModifier::None,
68        lifetimes: fun.lifetimes.clone(),
69        path,
70    }
71}
72
73fn bare_fn_to_sig(
74    fun: &syn::TypeBareFn,
75    ident: syn::Ident,
76    arg_idents: &[syn::Ident],
77) -> syn::Signature {
78    syn::Signature {
79        constness: None,
80        asyncness: None,
81        unsafety: fun.unsafety,
82        abi: fun.abi.clone(),
83        fn_token: syn::Token![fn](pm2::Span::call_site()),
84        ident,
85        generics: syn::Generics {
86            lt_token: fun.lifetimes.as_ref().map(|lt| lt.lt_token),
87            params: fun.lifetimes.as_ref().map(|lt| lt.lifetimes.clone()).unwrap_or_default(),
88            gt_token: fun.lifetimes.as_ref().map(|lt| lt.gt_token),
89            where_clause: None,
90        },
91        paren_token: syn::token::Paren::default(),
92        inputs: fun
93            .inputs
94            .iter()
95            .enumerate()
96            .map(|(i, input)| {
97                syn::FnArg::Typed(syn::PatType {
98                    attrs: Default::default(),
99                    pat: Box::new(syn::Pat::Ident(syn::PatIdent {
100                        attrs: Default::default(),
101                        by_ref: None,
102                        mutability: None,
103                        ident: arg_idents[i].clone(),
104                        subpat: None,
105                    })),
106                    colon_token: syn::Token![:](pm2::Span::call_site()),
107                    ty: Box::new(input.ty.clone()),
108                })
109            })
110            .collect(),
111        variadic: None,
112        output: fun.output.clone(),
113    }
114}
115
116fn path_from_str(str: &str) -> syn::Path {
117    syn::parse(TokenStream::from_str(str).unwrap()).unwrap()
118}
119
120struct ReplaceLt<F: FnMut(&mut syn::Lifetime)>(F);
121
122impl<F: FnMut(&mut syn::Lifetime)> syn::visit_mut::VisitMut for ReplaceLt<F> {
123    fn visit_lifetime_mut(&mut self, i: &mut syn::Lifetime) {
124        self.0(i)
125    }
126}
127
128/// Creates an instance of an anonymous type which can be used as a calling convention
129/// for higher-kinded bare functions when instantiating bare closure wrappers.
130///
131/// For example, the following evaluates to an expression which can be passed to `BareFn*::new`
132/// to create an adapter for the closure of type *exactly* `unsafe extern "C" for<'a> fn(&'a str) ->
133/// &'a u32`:
134///
135/// ```ignore
136/// hrtb_cc!(extern "C" for<'a> fn(&'a str) -> &'a u32)
137/// ```
138///
139/// Note that the `unsafe` keyword is automatically added if not present.
140///
141/// The bare function signature can additionally contain generic arguments using the `#[with]`
142/// attribute:
143///
144/// ```ignore
145/// hrtb_cc!(#[with(<T>)] extern "C" for<'a> fn(&'a str) -> &'a T)
146/// ```
147///
148/// This hack is necessary as there is no way to blanket implement the `FnThunk` traits for all
149/// lifetime associations. For this reason, the following won't compile:
150///
151/// ```ignore
152/// use closure_ffi::BareFn;
153///
154/// fn take_higher_rank_fn(bare_fn: unsafe extern "C" fn(&Option<u32>) -> Option<&u32>) {}
155///
156/// let bare_closure = BareFn::new_c(|opt: &Option<u32>| opt.as_ref());
157/// take_higher_rank_fn(bare_closure.bare());
158/// ```
159///
160/// However, using the output of this macro as the calling convention, we can get it to work:
161///
162/// ```ignore
163/// use closure_ffi::BareFn;
164///
165/// fn take_higher_rank_fn(bare_fn: unsafe extern "C" fn(&Option<u32>) -> Option<&u32>) {}
166///
167/// let bare_closure = BareFn::new(
168///     hrtb_cc!(extern "C" fn(&Option<u32>) -> Option<&u32>),
169///     |opt| opt.as_ref()
170/// );
171/// take_higher_rank_fn(bare_closure.bare());
172/// ```
173#[proc_macro]
174pub fn hrtb_cc(tokens: TokenStream) -> TokenStream {
175    let mut input = parse_macro_input!(tokens as MacroInput);
176    input
177        .bare_fn
178        .unsafety
179        .get_or_insert(syn::Token![unsafe](pm2::Span::call_site()));
180
181    let attrs = &input.attrs;
182    let bare_fn = &input.bare_fn;
183
184    let thunk_ident = syn::Ident::new("thunk", pm2::Span::call_site());
185    let arg_idents: Vec<_> = (0..input.bare_fn.inputs.len())
186        .map(|i| syn::Ident::new(&format!("a{i}"), pm2::Span::call_site()))
187        .collect();
188
189    let mut thunk_sig = bare_fn_to_sig(bare_fn, thunk_ident.clone(), &arg_idents);
190
191    let bare_fn_lt_idents = bare_fn
192        .lifetimes
193        .as_ref()
194        .map(|lt| {
195            lt.lifetimes
196                .iter()
197                .map(|p| match p {
198                    syn::GenericParam::Lifetime(lt) => lt.lifetime.ident.to_string(),
199                    _ => unreachable!(),
200                })
201                .collect::<Vec<_>>()
202        })
203        .unwrap_or_default();
204
205    ReplaceLt(|lt| {
206        if let Some(for_ident) = bare_fn_lt_idents.iter().find(|&l| l == &lt.ident.to_string()) {
207            lt.ident = syn::Ident::new(&format!("for_{for_ident}"), pm2::Span::call_site())
208        }
209    })
210    .visit_signature_mut(&mut thunk_sig);
211
212    let f_ident = syn::Ident::new("_F", pm2::Span::call_site());
213
214    struct ImplDetails {
215        thunk_trait_path: &'static str,
216        fn_trait_path: &'static str,
217        const_name: &'static str,
218        body: pm2::TokenStream,
219    }
220
221    let impl_blocks: [ImplDetails; 3] = [
222        ImplDetails {
223            thunk_trait_path: "::closure_ffi::thunk::FnOnceThunk",
224            fn_trait_path: "::core::ops::FnOnce",
225            const_name: "THUNK_TEMPLATE_ONCE",
226            body: quote! {
227                let closure_ptr: *mut #f_ident;
228                ::closure_ffi::arch::_thunk_asm!(closure_ptr);
229                ::closure_ffi::thunk::_never_inline(|| closure_ptr.read()(#(#arg_idents),*))
230            },
231        },
232        ImplDetails {
233            thunk_trait_path: "::closure_ffi::thunk::FnMutThunk",
234            fn_trait_path: "::core::ops::FnMut",
235            const_name: "THUNK_TEMPLATE_MUT",
236            body: quote! {
237                let closure_ptr: *mut #f_ident;
238                ::closure_ffi::arch::_thunk_asm!(closure_ptr);
239                ::closure_ffi::thunk::_never_inline(|| (&mut *closure_ptr)(#(#arg_idents),*))
240            },
241        },
242        ImplDetails {
243            thunk_trait_path: "::closure_ffi::thunk::FnThunk",
244            fn_trait_path: "::core::ops::Fn",
245            const_name: "THUNK_TEMPLATE",
246            body: quote! {
247                let closure_ptr: *const #f_ident;
248                ::closure_ffi::arch::_thunk_asm!(closure_ptr);
249                ::closure_ffi::thunk::_never_inline(|| (&*closure_ptr)(#(#arg_idents),*))
250            },
251        },
252    ];
253
254    let impls = impl_blocks.iter().map(|impl_block| {
255        let fn_bound =
256            bare_fn_to_trait_bound(&input.bare_fn, path_from_str(impl_block.fn_trait_path));
257        let const_ident = syn::Ident::new(impl_block.const_name, pm2::Span::call_site());
258        let body = &impl_block.body;
259        let thunk_trait = path_from_str(impl_block.thunk_trait_path);
260
261        let mut generics = input.generics.clone();
262        generics.params.push(syn::GenericParam::Type(syn::TypeParam {
263            attrs: Default::default(),
264            ident: f_ident.clone(),
265            colon_token: Some(syn::Token![:](pm2::Span::call_site())),
266            bounds: [syn::TypeParamBound::Trait(fn_bound)].into_iter().collect(),
267            eq_token: None,
268            default: None,
269        }));
270
271        let mut thunk_sig = thunk_sig.clone();
272        thunk_sig.generics.params.extend(generics.params.clone());
273
274        let (impl_generics, _, where_clause) = generics.split_for_impl();
275        let sig_tys = generics.type_params().map(|t| &t.ident);
276
277        quote! {
278            unsafe impl #impl_generics #thunk_trait<_CustomThunk, #bare_fn>
279            for (_CustomThunk, #f_ident) #where_clause
280            {
281                const #const_ident: *const ::core::primitive::u8 = {
282                    #thunk_sig {
283                        #body
284                    }
285                    #thunk_ident::<#(#sig_tys),*> as *const ::core::primitive::u8
286                };
287            }
288        }
289    });
290
291    quote! {{
292        #(#attrs)*
293        #[derive(::core::fmt::Debug, ::core::clone::Clone, ::core::marker::Copy)]
294        struct _CustomThunk;
295
296        #(#impls)*
297
298        _CustomThunk
299    }}
300    .into()
301}
302
303struct BareDynInput {
304    dyn_trait: syn::TypeTraitObject,
305    bare_fn: pm2::TokenStream,
306    allocator: Option<syn::Type>,
307    type_path: pm2::TokenStream,
308}
309
310impl syn::parse::Parse for BareDynInput {
311    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
312        let abi: syn::LitStr = input.parse()?;
313        let _ = input.parse::<syn::Token![,]>()?;
314        let dyn_bounds =
315            syn::punctuated::Punctuated::<syn::TypeParamBound, syn::Token![+]>
316            ::parse_separated_nonempty(input)?;
317
318        let (bare_fn_tokens, type_path) = dyn_bounds
319            .iter()
320            .find_map(|bound| match bound {
321                syn::TypeParamBound::Trait(tb) => {
322                    tb.path.segments.last().and_then(|seg| match &seg.arguments {
323                        syn::PathArguments::Parenthesized(args) => {
324                            let bound_lt = &tb.lifetimes;
325                            let params = &args.inputs;
326                            let ret = &args.output;
327                            let bare_fn_tokens = quote! {
328                                #bound_lt unsafe extern #abi fn(#params) #ret
329                            };
330                            Some((
331                                bare_fn_tokens,
332                                match seg.ident.to_string().as_str() {
333                                    "FnOnce" => quote! { ::closure_ffi::BareFnOnce },
334                                    "FnMut" => quote! { ::closure_ffi::BareFnMut },
335                                    "Fn" => quote! { ::closure_ffi::BareFn },
336                                    _ => return None,
337                                },
338                            ))
339                        }
340                        _ => None,
341                    })
342                }
343                _ => None,
344            })
345            .ok_or_else(|| syn::Error::new(dyn_bounds.span(), "Expected a function trait"))?;
346
347        let allocator = input
348            .parse::<Option<syn::Token![,]>>()
349            .and_then(|comma| comma.map(|_| input.parse().map(Some)).unwrap_or(Ok(None)))?;
350
351        Ok(Self {
352            dyn_trait: syn::TypeTraitObject {
353                dyn_token: Some(syn::Token![dyn](pm2::Span::call_site())),
354                bounds: dyn_bounds,
355            },
356            bare_fn: bare_fn_tokens,
357            allocator,
358            type_path,
359        })
360    }
361}
362
363/// Shorthand for a `BareFn*` type taking a boxed closure.
364///
365/// Essentially,
366/// ```ignore
367/// type MyBareFnMut = bare_dyn!("C", FnMut(&u32) -> u32 + Send);
368/// ```
369/// becomes
370/// ```ignore
371/// type MyBareFnMut = BareFnMut<
372///     unsafe extern "C" fn(&u32) -> u32,
373///     Box<dyn FnMut(&u32) -> u32 + Send>
374/// >;
375/// ```
376///
377/// If desired, the JIT allocator used by the `BareFn*` closure wrapper can also be specified
378/// by passing it as a third parameter.
379#[proc_macro]
380pub fn bare_dyn(tokens: TokenStream) -> TokenStream {
381    let input = syn::parse_macro_input!(tokens as BareDynInput);
382    let type_path = &input.type_path;
383    let bare_fn = &input.bare_fn;
384    let dyn_trait = &input.dyn_trait;
385    let allocator = &input.allocator;
386
387    quote! {
388        #type_path::<#bare_fn, ::closure_ffi::bare_closure::Box<#dyn_trait>, #allocator>
389    }
390    .into()
391}