1use std::str::FromStr;
2
3use proc_macro::TokenStream;
4use proc_macro2 as pm2;
5use quote::{quote, ToTokens};
6use syn::{parse_macro_input, visit::Visit, visit_mut::VisitMut};
7
8struct MacroInput {
9 thunk_attrs: Vec<syn::Attribute>,
10 crate_path: syn::Path,
11 alias: syn::ItemType,
12 bare_fn: syn::TypeBareFn,
13}
14
15impl syn::parse::Parse for MacroInput {
16 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
17 let thunk_attrs = input.call(syn::Attribute::parse_outer)?;
18 let crate_path = input.parse()?;
19 let _: syn::Token![,] = input.parse()?;
20 let alias: syn::ItemType = input.parse()?;
21
22 match &*alias.ty {
23 syn::Type::BareFn(bare_fn) => {
24 if let Some(lt) = &bare_fn.lifetimes {
25 if lt.lifetimes.len() > 3 {
26 return Err(syn::Error::new_spanned(
27 <.lifetimes,
28 "At most 3 higher-ranked lifetimes are supported",
29 ));
30 }
31 }
32
33 struct HasImplicitBoundLt(Vec<pm2::Span>);
35 impl<'a> Visit<'a> for HasImplicitBoundLt {
36 fn visit_lifetime(&mut self, i: &'a syn::Lifetime) {
37 if i.ident == "_" {
38 self.0.push(i.span());
39 }
40 }
41
42 fn visit_type_reference(&mut self, i: &'a syn::TypeReference) {
43 match i.lifetime {
44 Some(_) => self.visit_type(&i.elem),
45 None => self.0.push(i.and_token.span),
46 }
47 }
48 }
49 let mut implicit_lt_check = HasImplicitBoundLt(Vec::default());
50 implicit_lt_check.visit_type_bare_fn(bare_fn);
51
52 let mut implicit_lt_err = None;
53 for err_span in implicit_lt_check.0 {
54 let err = syn::Error::new(
55 err_span,
56 "Implicit lifetimes are not permitted; you must name this lifetime",
57 );
58 match implicit_lt_err.as_mut() {
59 None => implicit_lt_err = Some(err),
60 Some(e) => e.combine(err),
61 }
62 }
63 match implicit_lt_err {
64 Some(err) => Err(err),
65 None => Ok(Self {
66 thunk_attrs,
67 crate_path,
68 bare_fn: bare_fn.clone(),
69 alias,
70 }),
71 }
72 }
73 other => Err(syn::Error::new_spanned(
74 other,
75 format!(
76 "Expected bare function type, got {}",
77 other.to_token_stream()
78 ),
79 )),
80 }
81 }
82}
83
84fn bare_fn_to_trait_bound(fun: &syn::TypeBareFn, mut path: syn::Path) -> syn::TraitBound {
85 let fn_part = path.segments.last_mut().unwrap();
86 fn_part.arguments = syn::PathArguments::Parenthesized(syn::ParenthesizedGenericArguments {
87 paren_token: Default::default(),
88 inputs: fun.inputs.iter().map(|arg| arg.ty.clone()).collect(),
89 output: fun.output.clone(),
90 });
91
92 syn::TraitBound {
93 paren_token: None,
94 modifier: syn::TraitBoundModifier::None,
95 lifetimes: fun.lifetimes.clone(),
96 path,
97 }
98}
99
100fn bare_fn_to_sig(
101 fun: &syn::TypeBareFn,
102 ident: syn::Ident,
103 arg_idents: &[syn::Ident],
104) -> syn::Signature {
105 syn::Signature {
106 constness: None,
107 asyncness: None,
108 unsafety: fun.unsafety,
109 abi: fun.abi.clone(),
110 fn_token: syn::Token),
111 ident,
112 generics: syn::Generics {
113 lt_token: fun.lifetimes.as_ref().map(|lt| lt.lt_token),
114 params: fun.lifetimes.as_ref().map(|lt| lt.lifetimes.clone()).unwrap_or_default(),
115 gt_token: fun.lifetimes.as_ref().map(|lt| lt.gt_token),
116 where_clause: None,
117 },
118 paren_token: syn::token::Paren::default(),
119 inputs: fun
120 .inputs
121 .iter()
122 .enumerate()
123 .map(|(i, input)| {
124 syn::FnArg::Typed(syn::PatType {
125 attrs: Default::default(),
126 pat: Box::new(syn::Pat::Ident(syn::PatIdent {
127 attrs: Default::default(),
128 by_ref: None,
129 mutability: None,
130 ident: arg_idents[i].clone(),
131 subpat: None,
132 })),
133 colon_token: syn::Token),
134 ty: Box::new(input.ty.clone()),
135 })
136 })
137 .collect(),
138 variadic: None,
139 output: fun.output.clone(),
140 }
141}
142
143fn path_from_str(str: &str) -> syn::Path {
144 syn::parse(TokenStream::from_str(str).unwrap()).unwrap()
145}
146
147struct ReplaceLt<F: FnMut(&mut syn::Lifetime)>(F);
148
149impl<F: FnMut(&mut syn::Lifetime)> syn::visit_mut::VisitMut for ReplaceLt<F> {
150 fn visit_lifetime_mut(&mut self, i: &mut syn::Lifetime) {
151 self.0(i)
152 }
153}
154
155#[proc_macro]
156pub fn bare_hrtb(tokens: TokenStream) -> TokenStream {
157 let mut input = parse_macro_input!(tokens as MacroInput);
158
159 input
160 .bare_fn
161 .unsafety
162 .get_or_insert(syn::Token));
163
164 let bare_fn = &input.bare_fn;
165
166 let thunk_ident = syn::Ident::new("thunk", pm2::Span::call_site());
167 let arg_idents: Vec<_> = (0..input.bare_fn.inputs.len())
168 .map(|i| syn::Ident::new(&format!("a{i}"), pm2::Span::call_site()))
169 .collect();
170
171 let mut thunk_sig = bare_fn_to_sig(bare_fn, thunk_ident.clone(), &arg_idents);
172
173 let bare_fn_lt_idents = bare_fn
174 .lifetimes
175 .as_ref()
176 .map(|lt| {
177 lt.lifetimes
178 .iter()
179 .map(|p| match p {
180 syn::GenericParam::Lifetime(lt) => lt.lifetime.ident.to_string(),
181 _ => unreachable!(),
182 })
183 .collect::<Vec<_>>()
184 })
185 .unwrap_or_default();
186
187 ReplaceLt(|lt| {
188 if let Some(for_ident) = bare_fn_lt_idents.iter().find(|&l| l == <.ident.to_string()) {
189 lt.ident = syn::Ident::new(&format!("for_{for_ident}"), pm2::Span::call_site())
190 }
191 })
192 .visit_signature_mut(&mut thunk_sig);
193
194 let f_ident = syn::Ident::new("_F", pm2::Span::call_site());
195 let cc_marker_ident = syn::Ident::new(
196 &format!("{}_CC", &input.alias.ident),
197 pm2::Span::call_site(),
198 );
199 let crate_path = &input.crate_path;
200 let thunk_attrs = &input.thunk_attrs;
201 let alias_ident = &input.alias.ident;
202 let alias_attrs = &input.alias.attrs;
203 let alias_vis = &input.alias.vis;
204 let alias_gen = &input.alias.generics;
205 let (alias_impl_gen, alias_ty_params, alias_where) = &input.alias.generics.split_for_impl();
206 let mut impl_lifetimes =
207 bare_fn.lifetimes.as_ref().map(|lt| lt.lifetimes.clone()).unwrap_or_default();
208 impl_lifetimes.extend((impl_lifetimes.len()..3).map(|i| {
209 syn::GenericParam::Lifetime(syn::LifetimeParam::new(syn::Lifetime::new(
210 &format!("'_extra_{i}"),
211 pm2::Span::call_site(),
212 )))
213 }));
214 let arg_indices: Vec<_> = (0..bare_fn.inputs.len() as u32)
215 .map(|index| {
216 syn::Member::Unnamed(syn::Index {
217 index,
218 span: pm2::Span::call_site(),
219 })
220 })
221 .collect();
222
223 struct ImplDetails {
224 thunk_trait_path: &'static str,
225 fn_trait_path: &'static str,
226 const_name: &'static str,
227 call_ident: &'static str,
228 call_receiver: pm2::TokenStream,
229 thunk_body: pm2::TokenStream,
230 factory_ident: &'static str,
231 factory_bound: &'static str,
232 }
233
234 let impl_blocks: [ImplDetails; 3] = [
235 ImplDetails {
236 thunk_trait_path: "traits::FnOnceThunk",
237 fn_trait_path: "::core::ops::FnOnce",
238 const_name: "THUNK_TEMPLATE_ONCE",
239 call_ident: "call_once",
240 factory_ident: "make_once_thunk",
241 factory_bound: "traits::PackedFnOnce",
242 call_receiver: syn::Token).into_token_stream(),
243 thunk_body: quote! {
244 if const { ::core::mem::size_of::<#f_ident>() == 0 } {
245 let closure: #f_ident = unsafe { ::core::mem::zeroed() };
246 closure(#(#arg_idents),*)
247 }
248 else {
249 let closure_ptr: *mut #f_ident;
250 #crate_path::arch::_thunk_asm!(closure_ptr);
251 #crate_path::arch::_invoke(|| closure_ptr.read()(#(#arg_idents),*))
252 }
253 },
254 },
255 ImplDetails {
256 thunk_trait_path: "traits::FnMutThunk",
257 fn_trait_path: "::core::ops::FnMut",
258 const_name: "THUNK_TEMPLATE_MUT",
259 call_ident: "call_mut",
260 factory_ident: "make_mut_thunk",
261 factory_bound: "traits::PackedFnMut",
262 call_receiver: quote! { &mut self },
263 thunk_body: quote! {
264 if const { ::core::mem::size_of::<#f_ident>() == 0 } {
265 let closure: &mut #f_ident = unsafe { &mut *::core::ptr::dangling_mut() };
266 closure(#(#arg_idents),*)
267 }
268 else {
269 let closure_ptr: *mut #f_ident;
270 #crate_path::arch::_thunk_asm!(closure_ptr);
271 #crate_path::arch::_invoke(|| (&mut *closure_ptr)(#(#arg_idents),*))
272 }
273 },
274 },
275 ImplDetails {
276 thunk_trait_path: "traits::FnThunk",
277 fn_trait_path: "::core::ops::Fn",
278 const_name: "THUNK_TEMPLATE",
279 call_ident: "call",
280 factory_ident: "make_thunk",
281 factory_bound: "traits::PackedFn",
282 call_receiver: quote! { &self },
283 thunk_body: quote! {
284 if const { ::core::mem::size_of::<#f_ident>() == 0 } {
285 let closure: &#f_ident = unsafe { &*::core::ptr::dangling() };
286 closure(#(#arg_idents),*)
287 }
288 else {
289 let closure_ptr: *const #f_ident;
290 #crate_path::arch::_thunk_asm!(closure_ptr);
291 #crate_path::arch::_invoke(|| (&*closure_ptr)(#(#arg_idents),*))
292 }
293 },
294 },
295 ];
296
297 struct ImplTokens {
298 thunk_impl: pm2::TokenStream,
299 thunk_factory_impl: pm2::TokenStream,
300 }
301
302 let impl_tokens: Vec<_> = impl_blocks.iter().map(|impl_block| {
303 let fn_bound =
304 bare_fn_to_trait_bound(&input.bare_fn, path_from_str(impl_block.fn_trait_path));
305
306 let const_ident = syn::Ident::new(impl_block.const_name, pm2::Span::call_site());
307 let call_ident = syn::Ident::new(impl_block.call_ident, pm2::Span::call_site());
308 let call_receiver = &impl_block.call_receiver;
309 let body = &impl_block.thunk_body;
310 let mut thunk_trait = input.crate_path.clone();
311 thunk_trait.segments.extend(path_from_str(impl_block.thunk_trait_path).segments);
312
313 let mut generics = input.alias.generics.clone();
314 generics.params.push(syn::GenericParam::Type(syn::TypeParam {
315 attrs: Default::default(),
316 ident: f_ident.clone(),
317 colon_token: Some(syn::Token)),
318 bounds: [syn::TypeParamBound::Trait(fn_bound)].into_iter().collect(),
319 eq_token: None,
320 default: None,
321 }));
322
323 let mut thunk_sig = thunk_sig.clone();
324 thunk_sig.generics.params.extend(generics.params.clone());
325
326 let (impl_generics, _, where_clause) = generics.split_for_impl();
327 let sig_tys = generics.type_params().map(|t| &t.ident);
328
329 let thunk_impl = quote! {
330 unsafe impl #impl_generics #thunk_trait<#alias_ident #alias_ty_params>
331 for (#cc_marker_ident, #f_ident) #where_clause
332 {
333 const #const_ident: *const ::core::primitive::u8 = {
334 #(#thunk_attrs)*
335 #[allow(clippy::too_many_arguments)]
336 #thunk_sig {
337 #body
338 }
339 #thunk_ident::<#(#sig_tys),*> as *const ::core::primitive::u8
340 };
341
342 #[allow(unused_variables)]
343 #[inline(always)]
344 unsafe fn #call_ident<'_x, '_y, '_z>(#call_receiver, args: <#alias_ident #alias_ty_params as #crate_path::traits::FnPtr>::Args<'_x, '_y, '_z>
345 ) ->
346 <#alias_ident #alias_ty_params as #crate_path::traits::FnPtr>::Ret<'_x, '_y, '_z>
347 {
348 (self.1)(#(args.#arg_indices,)*)
349 }
350 }
351 };
352
353 let factory_bound_path = path_from_str(impl_block.factory_bound);
354 let factory_ident = syn::Ident::new(impl_block.factory_ident, pm2::Span::call_site());
355 let thunk_factory_impl = quote! {
356 #[inline(always)]
357 #[allow(unused_mut)]
358 fn #factory_ident<F>(mut fun: F) -> impl #thunk_trait<Self>
359 where
360 F: for<'_x, '_y, '_z> #crate_path::#factory_bound_path<'_x, '_y, '_z, Self>
361 {
362 #[inline(always)]
363 fn coerce #impl_generics (fun: #f_ident) -> #f_ident #where_clause {
364 fun
365 }
366
367 let coerced = coerce(move |#(#arg_idents,)*| fun((#(#arg_idents,)*)));
368 (#cc_marker_ident, coerced)
369 }
370 };
371
372 ImplTokens {
373 thunk_impl,
374 thunk_factory_impl
375 }
376 }).collect();
377
378 let alias_ident_lit = syn::LitStr::new(&alias_ident.to_string(), pm2::Span::call_site());
379 let alias_ident_doc_lit =
380 syn::LitStr::new(&format!("[`{alias_ident}`]."), pm2::Span::call_site());
381
382 let tuple_args = bare_fn.inputs.iter().map(|i| &i.ty);
383 let bare_fn_output = match &bare_fn.output {
384 syn::ReturnType::Default => &syn::Type::Tuple(syn::TypeTuple {
385 paren_token: syn::token::Paren(pm2::Span::call_site()),
386 elems: syn::punctuated::Punctuated::new(),
387 }),
388 syn::ReturnType::Type(_, ty) => ty,
389 };
390
391 let trait_impls = impl_tokens.iter().map(|t| &t.thunk_impl);
392 let thunk_factory_impls = impl_tokens.iter().map(|t| &t.thunk_factory_impl);
393
394 quote! {
395 #[doc = #alias_ident_doc_lit]
397 #[derive(::core::fmt::Debug, ::core::clone::Clone, ::core::marker::Copy, ::core::default::Default)]
398 #alias_vis struct #cc_marker_ident;
399
400 #(#alias_attrs)*
401 #[repr(transparent)]
402 #alias_vis struct #alias_ident #alias_gen (pub #bare_fn) #alias_where;
403
404 impl #alias_impl_gen #alias_ident #alias_ty_params #alias_where {
405 pub fn cc() -> #cc_marker_ident {
407 #cc_marker_ident::default()
408 }
409 }
410
411 impl #alias_impl_gen ::core::clone::Clone for #alias_ident #alias_ty_params #alias_where {
412 fn clone(&self) -> Self {
413 Self(self.0)
414 }
415 }
416
417 impl #alias_impl_gen ::core::marker::Copy for #alias_ident #alias_ty_params #alias_where {}
418
419 impl #alias_impl_gen ::core::fmt::Debug for #alias_ident #alias_ty_params #alias_where {
420 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
421 f.debug_tuple(#alias_ident_lit)
422 .field(&self.0)
423 .finish()
424 }
425 }
426
427 impl #alias_impl_gen ::core::convert::From<#bare_fn> for #alias_ident #alias_ty_params #alias_where {
428 fn from(value: #bare_fn) -> Self {
429 Self(value)
430 }
431 }
432
433 impl #alias_impl_gen ::core::convert::From<#alias_ident #alias_ty_params> for #bare_fn #alias_where {
434 fn from(value: #alias_ident #alias_ty_params) -> Self {
435 value.0
436 }
437 }
438
439 impl #alias_impl_gen ::core::ops::Deref for #alias_ident #alias_ty_params #alias_where {
440 type Target = #bare_fn;
441
442 fn deref(&self) -> &Self::Target {
443 &self.0
444 }
445 }
446
447 unsafe impl #alias_impl_gen #crate_path::traits::FnPtr for #alias_ident #alias_ty_params #alias_where {
448 type CC = #cc_marker_ident;
449 type Args<#impl_lifetimes> = (#(#tuple_args,)*);
450 type Ret<#impl_lifetimes> = #bare_fn_output;
451
452 #[inline(always)]
453 unsafe fn call<#impl_lifetimes>(
454 self,
455 args: Self::Args<#impl_lifetimes>
456 ) -> Self::Ret<#impl_lifetimes>
457 {
458 (self.0)(#(args.#arg_indices,)*)
459 }
460
461 #[inline(always)]
462 unsafe fn from_ptr(ptr: *const ()) -> Self {
463 unsafe { core::mem::transmute_copy(&ptr) }
464 }
465
466 #[inline(always)]
467 fn to_ptr(self) -> *const () {
468 self.0 as *const _
469 }
470
471 #(#thunk_factory_impls)*
472 }
473
474 #(#trait_impls)*
475 }
476 .into()
477}