1use std::str::FromStr;
2
3use proc_macro::TokenStream;
4use proc_macro2 as pm2;
5use quote::quote;
6use syn::{parse_macro_input, spanned::Spanned as _, visit_mut::VisitMut};
7
8struct GenericsWithWhere(syn::Generics);
11impl syn::parse::Parse for GenericsWithWhere {
12 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
13 Ok(GenericsWithWhere({
14 let mut generics: syn::Generics = input.parse()?;
15 generics.where_clause = input.parse()?;
16 generics
17 }))
18 }
19}
20
21struct MacroInput {
22 attrs: Vec<syn::Attribute>,
23 generics: syn::Generics,
24 bare_fn: syn::TypeBareFn,
25}
26
27impl syn::parse::Parse for MacroInput {
28 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
29 let all_attrs = input.call(syn::Attribute::parse_outer)?;
30 let mut attrs = Vec::new();
31 let mut generics = None;
32
33 for attr in all_attrs {
34 if !attr.path().is_ident("with") {
35 attrs.push(attr);
36 }
37 else if generics.is_some() {
38 return Err(syn::Error::new_spanned(
39 attr.path().get_ident(),
40 "with attribute is already present",
41 ));
42 }
43 else {
44 let meta_list = attr.meta.require_list()?;
45 generics = Some(meta_list.parse_args::<GenericsWithWhere>()?.0);
46 }
47 }
48
49 Ok(Self {
50 attrs,
51 generics: generics.unwrap_or_default(),
52 bare_fn: input.parse()?,
53 })
54 }
55}
56
57fn bare_fn_to_trait_bound(fun: &syn::TypeBareFn, mut path: syn::Path) -> syn::TraitBound {
58 let fn_part = path.segments.last_mut().unwrap();
59 fn_part.arguments = syn::PathArguments::Parenthesized(syn::ParenthesizedGenericArguments {
60 paren_token: Default::default(),
61 inputs: fun.inputs.iter().map(|arg| arg.ty.clone()).collect(),
62 output: fun.output.clone(),
63 });
64
65 syn::TraitBound {
66 paren_token: None,
67 modifier: syn::TraitBoundModifier::None,
68 lifetimes: fun.lifetimes.clone(),
69 path,
70 }
71}
72
73fn bare_fn_to_sig(
74 fun: &syn::TypeBareFn,
75 ident: syn::Ident,
76 arg_idents: &[syn::Ident],
77) -> syn::Signature {
78 syn::Signature {
79 constness: None,
80 asyncness: None,
81 unsafety: fun.unsafety,
82 abi: fun.abi.clone(),
83 fn_token: syn::Token),
84 ident,
85 generics: syn::Generics {
86 lt_token: fun.lifetimes.as_ref().map(|lt| lt.lt_token),
87 params: fun.lifetimes.as_ref().map(|lt| lt.lifetimes.clone()).unwrap_or_default(),
88 gt_token: fun.lifetimes.as_ref().map(|lt| lt.gt_token),
89 where_clause: None,
90 },
91 paren_token: syn::token::Paren::default(),
92 inputs: fun
93 .inputs
94 .iter()
95 .enumerate()
96 .map(|(i, input)| {
97 syn::FnArg::Typed(syn::PatType {
98 attrs: Default::default(),
99 pat: Box::new(syn::Pat::Ident(syn::PatIdent {
100 attrs: Default::default(),
101 by_ref: None,
102 mutability: None,
103 ident: arg_idents[i].clone(),
104 subpat: None,
105 })),
106 colon_token: syn::Token),
107 ty: Box::new(input.ty.clone()),
108 })
109 })
110 .collect(),
111 variadic: None,
112 output: fun.output.clone(),
113 }
114}
115
116fn path_from_str(str: &str) -> syn::Path {
117 syn::parse(TokenStream::from_str(str).unwrap()).unwrap()
118}
119
120struct ReplaceLt<F: FnMut(&mut syn::Lifetime)>(F);
121
122impl<F: FnMut(&mut syn::Lifetime)> syn::visit_mut::VisitMut for ReplaceLt<F> {
123 fn visit_lifetime_mut(&mut self, i: &mut syn::Lifetime) {
124 self.0(i)
125 }
126}
127
128#[proc_macro]
174pub fn hrtb_cc(tokens: TokenStream) -> TokenStream {
175 let mut input = parse_macro_input!(tokens as MacroInput);
176 input
177 .bare_fn
178 .unsafety
179 .get_or_insert(syn::Token));
180
181 let attrs = &input.attrs;
182 let bare_fn = &input.bare_fn;
183
184 let thunk_ident = syn::Ident::new("thunk", pm2::Span::call_site());
185 let arg_idents: Vec<_> = (0..input.bare_fn.inputs.len())
186 .map(|i| syn::Ident::new(&format!("a{i}"), pm2::Span::call_site()))
187 .collect();
188
189 let mut thunk_sig = bare_fn_to_sig(bare_fn, thunk_ident.clone(), &arg_idents);
190
191 let bare_fn_lt_idents = bare_fn
192 .lifetimes
193 .as_ref()
194 .map(|lt| {
195 lt.lifetimes
196 .iter()
197 .map(|p| match p {
198 syn::GenericParam::Lifetime(lt) => lt.lifetime.ident.to_string(),
199 _ => unreachable!(),
200 })
201 .collect::<Vec<_>>()
202 })
203 .unwrap_or_default();
204
205 ReplaceLt(|lt| {
206 if let Some(for_ident) = bare_fn_lt_idents.iter().find(|&l| l == <.ident.to_string()) {
207 lt.ident = syn::Ident::new(&format!("for_{for_ident}"), pm2::Span::call_site())
208 }
209 })
210 .visit_signature_mut(&mut thunk_sig);
211
212 let f_ident = syn::Ident::new("_F", pm2::Span::call_site());
213
214 struct ImplDetails {
215 thunk_trait_path: &'static str,
216 fn_trait_path: &'static str,
217 const_name: &'static str,
218 body: pm2::TokenStream,
219 }
220
221 let impl_blocks: [ImplDetails; 3] = [
222 ImplDetails {
223 thunk_trait_path: "::closure_ffi::thunk::FnOnceThunk",
224 fn_trait_path: "::core::ops::FnOnce",
225 const_name: "THUNK_TEMPLATE_ONCE",
226 body: quote! {
227 let closure_ptr: *mut #f_ident;
228 ::closure_ffi::arch::_thunk_asm!(closure_ptr);
229 closure_ptr.read()(#(#arg_idents),*)
230 },
231 },
232 ImplDetails {
233 thunk_trait_path: "::closure_ffi::thunk::FnMutThunk",
234 fn_trait_path: "::core::ops::FnMut",
235 const_name: "THUNK_TEMPLATE_MUT",
236 body: quote! {
237 let closure_ptr: *mut #f_ident;
238 ::closure_ffi::arch::_thunk_asm!(closure_ptr);
239 (&mut *closure_ptr)(#(#arg_idents),*)
240 },
241 },
242 ImplDetails {
243 thunk_trait_path: "::closure_ffi::thunk::FnThunk",
244 fn_trait_path: "::core::ops::Fn",
245 const_name: "THUNK_TEMPLATE",
246 body: quote! {
247 let closure_ptr: *const #f_ident;
248 ::closure_ffi::arch::_thunk_asm!(closure_ptr);
249 (&*closure_ptr)(#(#arg_idents),*)
250 },
251 },
252 ];
253
254 let impls = impl_blocks.iter().map(|impl_block| {
255 let fn_bound =
256 bare_fn_to_trait_bound(&input.bare_fn, path_from_str(impl_block.fn_trait_path));
257 let const_ident = syn::Ident::new(impl_block.const_name, pm2::Span::call_site());
258 let body = &impl_block.body;
259 let thunk_trait = path_from_str(impl_block.thunk_trait_path);
260
261 let mut generics = input.generics.clone();
262 generics.params.push(syn::GenericParam::Type(syn::TypeParam {
263 attrs: Default::default(),
264 ident: f_ident.clone(),
265 colon_token: Some(syn::Token)),
266 bounds: [syn::TypeParamBound::Trait(fn_bound)].into_iter().collect(),
267 eq_token: None,
268 default: None,
269 }));
270
271 let mut thunk_sig = thunk_sig.clone();
272 thunk_sig.generics.params.extend(generics.params.clone());
273
274 let (impl_generics, _, where_clause) = generics.split_for_impl();
275 let sig_tys = generics.type_params().map(|t| &t.ident);
276
277 quote! {
278 unsafe impl #impl_generics #thunk_trait<_CustomThunk, #bare_fn>
279 for (_CustomThunk, #f_ident) #where_clause
280 {
281 const #const_ident: *const ::core::primitive::u8 = {
282 #thunk_sig {
283 #body
284 }
285 #thunk_ident::<#(#sig_tys),*> as *const ::core::primitive::u8
286 };
287 }
288 }
289 });
290
291 quote! {{
292 #(#attrs)*
293 #[derive(::core::fmt::Debug, ::core::clone::Clone, ::core::marker::Copy)]
294 struct _CustomThunk;
295
296 #(#impls)*
297
298 _CustomThunk
299 }}
300 .into()
301}
302
303struct BareDynInput {
304 dyn_trait: syn::TypeTraitObject,
305 bare_fn: pm2::TokenStream,
306 allocator: Option<syn::Type>,
307 type_path: pm2::TokenStream,
308}
309
310impl syn::parse::Parse for BareDynInput {
311 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
312 let abi: syn::LitStr = input.parse()?;
313 let _ = input.parse::<syn::Token![,]>()?;
314 let dyn_bounds =
315 syn::punctuated::Punctuated::<syn::TypeParamBound, syn::Token![+]>
316 ::parse_separated_nonempty(input)?;
317
318 let (bare_fn_tokens, type_path) = dyn_bounds
319 .iter()
320 .find_map(|bound| match bound {
321 syn::TypeParamBound::Trait(tb) => {
322 tb.path.segments.last().and_then(|seg| match &seg.arguments {
323 syn::PathArguments::Parenthesized(args) => {
324 let bound_lt = &tb.lifetimes;
325 let params = &args.inputs;
326 let ret = &args.output;
327 let bare_fn_tokens = quote! {
328 #bound_lt unsafe extern #abi fn(#params) #ret
329 };
330 Some((
331 bare_fn_tokens,
332 match seg.ident.to_string().as_str() {
333 "FnOnce" => quote! { ::closure_ffi::BareFnOnce },
334 "FnMut" => quote! { ::closure_ffi::BareFnMut },
335 "Fn" => quote! { ::closure_ffi::BareFn },
336 _ => return None,
337 },
338 ))
339 }
340 _ => None,
341 })
342 }
343 _ => None,
344 })
345 .ok_or_else(|| syn::Error::new(dyn_bounds.span(), "Expected a function trait"))?;
346
347 let allocator = input
348 .parse::<Option<syn::Token![,]>>()
349 .and_then(|comma| comma.map(|_| input.parse().map(Some)).unwrap_or(Ok(None)))?;
350
351 Ok(Self {
352 dyn_trait: syn::TypeTraitObject {
353 dyn_token: Some(syn::Token)),
354 bounds: dyn_bounds,
355 },
356 bare_fn: bare_fn_tokens,
357 allocator,
358 type_path,
359 })
360 }
361}
362
363#[proc_macro]
380pub fn bare_dyn(tokens: TokenStream) -> TokenStream {
381 let input = syn::parse_macro_input!(tokens as BareDynInput);
382 let type_path = &input.type_path;
383 let bare_fn = &input.bare_fn;
384 let dyn_trait = &input.dyn_trait;
385 let allocator = &input.allocator;
386
387 quote! {
388 #type_path::<#bare_fn, ::closure_ffi::bare_closure::Box<#dyn_trait>, #allocator>
389 }
390 .into()
391}