1use std::str::FromStr;
2
3use proc_macro::TokenStream;
4use proc_macro2 as pm2;
5use quote::{quote, ToTokens};
6use syn::{parse_macro_input, visit::Visit, visit_mut::VisitMut};
7
8struct MacroInput {
9 crate_path: syn::Path,
10 alias: syn::ItemType,
11 bare_fn: syn::TypeBareFn,
12}
13
14impl syn::parse::Parse for MacroInput {
15 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
16 let crate_path = input.parse()?;
17 let _: syn::Token![,] = input.parse()?;
18 let alias: syn::ItemType = input.parse()?;
19
20 match &*alias.ty {
21 syn::Type::BareFn(bare_fn) => {
22 if let Some(lt) = &bare_fn.lifetimes {
23 if lt.lifetimes.len() > 3 {
24 return Err(syn::Error::new_spanned(
25 <.lifetimes,
26 "At most 3 higher-ranked lifetimes are supported",
27 ));
28 }
29 }
30
31 struct HasImplicitBoundLt(Vec<pm2::Span>);
33 impl<'a> Visit<'a> for HasImplicitBoundLt {
34 fn visit_lifetime(&mut self, i: &'a syn::Lifetime) {
35 if i.ident == "_" {
36 self.0.push(i.span());
37 }
38 }
39
40 fn visit_type_reference(&mut self, i: &'a syn::TypeReference) {
41 match i.lifetime {
42 Some(_) => self.visit_type(&i.elem),
43 None => self.0.push(i.and_token.span),
44 }
45 }
46 }
47 let mut implicit_lt_check = HasImplicitBoundLt(Vec::default());
48 implicit_lt_check.visit_type_bare_fn(bare_fn);
49
50 let mut implicit_lt_err = None;
51 for err_span in implicit_lt_check.0 {
52 let err = syn::Error::new(
53 err_span,
54 "Implicit lifetimes are not permitted; you must name this lifetime",
55 );
56 match implicit_lt_err.as_mut() {
57 None => implicit_lt_err = Some(err),
58 Some(e) => e.combine(err),
59 }
60 }
61 match implicit_lt_err {
62 Some(err) => Err(err),
63 None => Ok(Self {
64 crate_path,
65 bare_fn: bare_fn.clone(),
66 alias,
67 }),
68 }
69 }
70 other => Err(syn::Error::new_spanned(
71 other,
72 format!(
73 "Expected bare function type, got {}",
74 other.to_token_stream()
75 ),
76 )),
77 }
78 }
79}
80
81fn bare_fn_to_trait_bound(fun: &syn::TypeBareFn, mut path: syn::Path) -> syn::TraitBound {
82 let fn_part = path.segments.last_mut().unwrap();
83 fn_part.arguments = syn::PathArguments::Parenthesized(syn::ParenthesizedGenericArguments {
84 paren_token: Default::default(),
85 inputs: fun.inputs.iter().map(|arg| arg.ty.clone()).collect(),
86 output: fun.output.clone(),
87 });
88
89 syn::TraitBound {
90 paren_token: None,
91 modifier: syn::TraitBoundModifier::None,
92 lifetimes: fun.lifetimes.clone(),
93 path,
94 }
95}
96
97fn bare_fn_to_sig(
98 fun: &syn::TypeBareFn,
99 ident: syn::Ident,
100 arg_idents: &[syn::Ident],
101) -> syn::Signature {
102 syn::Signature {
103 constness: None,
104 asyncness: None,
105 unsafety: fun.unsafety,
106 abi: fun.abi.clone(),
107 fn_token: syn::Token),
108 ident,
109 generics: syn::Generics {
110 lt_token: fun.lifetimes.as_ref().map(|lt| lt.lt_token),
111 params: fun.lifetimes.as_ref().map(|lt| lt.lifetimes.clone()).unwrap_or_default(),
112 gt_token: fun.lifetimes.as_ref().map(|lt| lt.gt_token),
113 where_clause: None,
114 },
115 paren_token: syn::token::Paren::default(),
116 inputs: fun
117 .inputs
118 .iter()
119 .enumerate()
120 .map(|(i, input)| {
121 syn::FnArg::Typed(syn::PatType {
122 attrs: Default::default(),
123 pat: Box::new(syn::Pat::Ident(syn::PatIdent {
124 attrs: Default::default(),
125 by_ref: None,
126 mutability: None,
127 ident: arg_idents[i].clone(),
128 subpat: None,
129 })),
130 colon_token: syn::Token),
131 ty: Box::new(input.ty.clone()),
132 })
133 })
134 .collect(),
135 variadic: None,
136 output: fun.output.clone(),
137 }
138}
139
140fn path_from_str(str: &str) -> syn::Path {
141 syn::parse(TokenStream::from_str(str).unwrap()).unwrap()
142}
143
144struct ReplaceLt<F: FnMut(&mut syn::Lifetime)>(F);
145
146impl<F: FnMut(&mut syn::Lifetime)> syn::visit_mut::VisitMut for ReplaceLt<F> {
147 fn visit_lifetime_mut(&mut self, i: &mut syn::Lifetime) {
148 self.0(i)
149 }
150}
151
152#[proc_macro]
153pub fn bare_hrtb(tokens: TokenStream) -> TokenStream {
154 let mut input = parse_macro_input!(tokens as MacroInput);
155
156 input
157 .bare_fn
158 .unsafety
159 .get_or_insert(syn::Token));
160
161 let bare_fn = &input.bare_fn;
162
163 let thunk_ident = syn::Ident::new("thunk", pm2::Span::call_site());
164 let arg_idents: Vec<_> = (0..input.bare_fn.inputs.len())
165 .map(|i| syn::Ident::new(&format!("a{i}"), pm2::Span::call_site()))
166 .collect();
167
168 let mut thunk_sig = bare_fn_to_sig(bare_fn, thunk_ident.clone(), &arg_idents);
169
170 let bare_fn_lt_idents = bare_fn
171 .lifetimes
172 .as_ref()
173 .map(|lt| {
174 lt.lifetimes
175 .iter()
176 .map(|p| match p {
177 syn::GenericParam::Lifetime(lt) => lt.lifetime.ident.to_string(),
178 _ => unreachable!(),
179 })
180 .collect::<Vec<_>>()
181 })
182 .unwrap_or_default();
183
184 ReplaceLt(|lt| {
185 if let Some(for_ident) = bare_fn_lt_idents.iter().find(|&l| l == <.ident.to_string()) {
186 lt.ident = syn::Ident::new(&format!("for_{for_ident}"), pm2::Span::call_site())
187 }
188 })
189 .visit_signature_mut(&mut thunk_sig);
190
191 let f_ident = syn::Ident::new("_F", pm2::Span::call_site());
192 let cc_marker_ident = syn::Ident::new(
193 &format!("{}_CC", &input.alias.ident),
194 pm2::Span::call_site(),
195 );
196 let crate_path = &input.crate_path;
197
198 struct ImplDetails {
199 thunk_trait_path: &'static str,
200 fn_trait_path: &'static str,
201 const_name: &'static str,
202 body: pm2::TokenStream,
203 }
204
205 let impl_blocks: [ImplDetails; 3] = [
206 ImplDetails {
207 thunk_trait_path: "traits::FnOnceThunk",
208 fn_trait_path: "::core::ops::FnOnce",
209 const_name: "THUNK_TEMPLATE_ONCE",
210 body: quote! {
211 let closure_ptr: *mut #f_ident;
212 #crate_path::arch::_thunk_asm!(closure_ptr);
213 #crate_path::arch::_never_inline(|| closure_ptr.read()(#(#arg_idents),*))
214 },
215 },
216 ImplDetails {
217 thunk_trait_path: "traits::FnMutThunk",
218 fn_trait_path: "::core::ops::FnMut",
219 const_name: "THUNK_TEMPLATE_MUT",
220 body: quote! {
221 let closure_ptr: *mut #f_ident;
222 #crate_path::arch::_thunk_asm!(closure_ptr);
223 #crate_path::arch::_never_inline(|| (&mut *closure_ptr)(#(#arg_idents),*))
224 },
225 },
226 ImplDetails {
227 thunk_trait_path: "traits::FnThunk",
228 fn_trait_path: "::core::ops::Fn",
229 const_name: "THUNK_TEMPLATE",
230 body: quote! {
231 let closure_ptr: *const #f_ident;
232 #crate_path::arch::_thunk_asm!(closure_ptr);
233 #crate_path::arch::_never_inline(|| (&*closure_ptr)(#(#arg_idents),*))
234 },
235 },
236 ];
237
238 let alias_ident = &input.alias.ident;
239 let alias_attrs = &input.alias.attrs;
240 let alias_vis = &input.alias.vis;
241 let alias_gen = &input.alias.generics;
242 let (alias_impl_gen, alias_ty_params, alias_where) = &input.alias.generics.split_for_impl();
243
244 let impls = impl_blocks.iter().map(|impl_block| {
245 let fn_bound =
246 bare_fn_to_trait_bound(&input.bare_fn, path_from_str(impl_block.fn_trait_path));
247 let const_ident = syn::Ident::new(impl_block.const_name, pm2::Span::call_site());
248 let body = &impl_block.body;
249 let mut thunk_trait = input.crate_path.clone();
250 thunk_trait.segments.extend(path_from_str(impl_block.thunk_trait_path).segments);
251
252 let mut generics = input.alias.generics.clone();
253 generics.params.push(syn::GenericParam::Type(syn::TypeParam {
254 attrs: Default::default(),
255 ident: f_ident.clone(),
256 colon_token: Some(syn::Token)),
257 bounds: [syn::TypeParamBound::Trait(fn_bound)].into_iter().collect(),
258 eq_token: None,
259 default: None,
260 }));
261
262 let mut thunk_sig = thunk_sig.clone();
263 thunk_sig.generics.params.extend(generics.params.clone());
264
265 let (impl_generics, _, where_clause) = generics.split_for_impl();
266 let sig_tys = generics.type_params().map(|t| &t.ident);
267
268 quote! {
269 unsafe impl #impl_generics #thunk_trait<#alias_ident #alias_ty_params>
270 for (#cc_marker_ident, #f_ident) #where_clause
271 {
272 const #const_ident: *const ::core::primitive::u8 = {
273 #thunk_sig {
274 #body
275 }
276 #thunk_ident::<#(#sig_tys),*> as *const ::core::primitive::u8
277 };
278 }
279 }
280 });
281
282 let alias_ident_lit = syn::LitStr::new(&alias_ident.to_string(), pm2::Span::call_site());
283 let alias_ident_doc_lit =
284 syn::LitStr::new(&format!("[`{alias_ident}`]."), pm2::Span::call_site());
285
286 let mut punc_impl_lifetimes =
287 bare_fn.lifetimes.as_ref().map(|lt| lt.lifetimes.clone()).unwrap_or_default();
288 punc_impl_lifetimes.extend((punc_impl_lifetimes.len()..3).map(|i| {
289 syn::GenericParam::Lifetime(syn::LifetimeParam::new(syn::Lifetime::new(
290 &format!("'_extra_{i}"),
291 pm2::Span::call_site(),
292 )))
293 }));
294 let impl_lifetimes: Vec<_> = punc_impl_lifetimes.iter().collect();
295
296 let tuple_args = bare_fn.inputs.iter().map(|i| &i.ty);
297 let bare_fn_output = match &bare_fn.output {
298 syn::ReturnType::Default => &syn::Type::Tuple(syn::TypeTuple {
299 paren_token: syn::token::Paren(pm2::Span::call_site()),
300 elems: syn::punctuated::Punctuated::new(),
301 }),
302 syn::ReturnType::Type(_, ty) => ty,
303 };
304 let arg_indices = (0..bare_fn.inputs.len() as u32).map(|index| {
305 syn::Member::Unnamed(syn::Index {
306 index,
307 span: pm2::Span::call_site(),
308 })
309 });
310
311 quote! {
312 #[doc = #alias_ident_doc_lit]
314 #[derive(::core::fmt::Debug, ::core::clone::Clone, ::core::marker::Copy, ::core::default::Default)]
315 #alias_vis struct #cc_marker_ident;
316
317 #(#alias_attrs)*
318 #[repr(transparent)]
319 #alias_vis struct #alias_ident #alias_gen (pub #bare_fn) #alias_where;
320
321 impl #alias_impl_gen #alias_ident #alias_ty_params #alias_where {
322 pub fn cc() -> #cc_marker_ident {
324 #cc_marker_ident::default()
325 }
326 }
327
328 impl #alias_impl_gen ::core::clone::Clone for #alias_ident #alias_ty_params #alias_where {
329 fn clone(&self) -> Self {
330 Self(self.0)
331 }
332 }
333
334 impl #alias_impl_gen ::core::marker::Copy for #alias_ident #alias_ty_params #alias_where {}
335
336 impl #alias_impl_gen ::core::fmt::Debug for #alias_ident #alias_ty_params #alias_where {
337 fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
338 f.debug_tuple(#alias_ident_lit)
339 .field(&self.0)
340 .finish()
341 }
342 }
343
344 impl #alias_impl_gen ::core::convert::From<#bare_fn> for #alias_ident #alias_ty_params #alias_where {
345 fn from(value: #bare_fn) -> Self {
346 Self(value)
347 }
348 }
349
350 impl #alias_impl_gen ::core::convert::Into<#bare_fn> for #alias_ident #alias_ty_params #alias_where {
351 fn into(self) -> #bare_fn {
352 self.0
353 }
354 }
355
356 impl #alias_impl_gen ::core::ops::Deref for #alias_ident #alias_ty_params #alias_where {
357 type Target = #bare_fn;
358
359 fn deref(&self) -> &Self::Target {
360 &self.0
361 }
362 }
363
364 unsafe impl #alias_impl_gen #crate_path::traits::FnPtr for #alias_ident #alias_ty_params #alias_where {
365 type CC = #cc_marker_ident;
366 type Args<#punc_impl_lifetimes> = (#(#tuple_args,)*) where Self: #(#impl_lifetimes)+*;
367 type Ret<#punc_impl_lifetimes> = #bare_fn_output where Self: #(#impl_lifetimes)+*;
368
369 #[inline(always)]
370 unsafe fn call<#punc_impl_lifetimes>(
371 self,
372 args: Self::Args<#punc_impl_lifetimes>
373 ) -> Self::Ret<#punc_impl_lifetimes>
374 where Self: #(#impl_lifetimes)+*
375 {
376 (self.0)(#(args.#arg_indices,)*)
377 }
378
379 #[inline(always)]
380 unsafe fn from_ptr(ptr: *const ()) -> Self {
381 unsafe { core::mem::transmute_copy(&ptr) }
382 }
383
384 #[inline(always)]
385 fn to_ptr(self) -> *const () {
386 self.0 as *const _
387 }
388 }
389
390 #(#impls)*
391 }
392 .into()
393}