1#![deny(unsafe_code)]
59
60extern crate proc_macro;
61
62use proc_macro::TokenStream;
63use proc_macro2::{Span, TokenStream as TokenStream2};
64use quote::{format_ident, quote};
65use syn::{
66 parse::{Parse, ParseStream},
67 parse_macro_input,
68 punctuated::Punctuated,
69 token::Comma,
70 ConstParam, FnArg, GenericParam, Ident, ItemTrait, LitInt, Pat, ReturnType, Token, TraitItem,
71 TraitItemFn, Type, Visibility,
72};
73
/// Parsed arguments of `#[reifiable(range = START..=END)]`.
///
/// Both bounds are inclusive: the generated dispatch functions accept every
/// value in `range_start..=range_end`.
struct ReifiableArgs {
    /// First accepted value (inclusive).
    range_start: u64,
    /// Last accepted value (inclusive).
    range_end: u64,
}
83
84impl Parse for ReifiableArgs {
85 fn parse(input: ParseStream) -> syn::Result<Self> {
86 let ident: Ident = input.parse()?;
88 if ident != "range" {
89 return Err(syn::Error::new(ident.span(), "expected `range`"));
90 }
91 let _eq: Token![=] = input.parse()?;
92 let start: LitInt = input.parse()?;
93 let _dots: Token![..] = input.parse()?;
94 let _eq2: Token![=] = input.parse()?;
95 let end: LitInt = input.parse()?;
96
97 Ok(ReifiableArgs {
98 range_start: start.base10_parse()?,
99 range_end: end.base10_parse()?,
100 })
101 }
102}
103
104struct ConstMethod {
110 name: Ident,
112 _const_param_name: Ident,
114 const_param_ty: Type,
115 is_mut: bool,
117 params: Vec<(Ident, Type)>,
119 lifetime_params: Vec<syn::LifetimeParam>,
121 type_params: Vec<syn::TypeParam>,
123 return_type: ReturnType,
125}
126
127fn type_mentions_ident(ty: &Type, ident: &Ident) -> bool {
129 let ty_str = quote!(#ty).to_string();
130 let ident_str = ident.to_string();
131 ty_str
135 .split(|c: char| !c.is_alphanumeric() && c != '_')
136 .any(|word| word == ident_str)
137}
138
139fn analyze_method(method: &TraitItemFn) -> Option<Result<ConstMethod, syn::Error>> {
140 let const_params: Vec<&ConstParam> = method
142 .sig
143 .generics
144 .params
145 .iter()
146 .filter_map(|p| match p {
147 GenericParam::Const(cp) => Some(cp),
148 _ => None,
149 })
150 .collect();
151
152 if const_params.is_empty() {
153 return None; }
155
156 if const_params.len() > 1 {
157 return Some(Err(syn::Error::new_spanned(
158 &method.sig,
159 "#[reifiable] V1 only supports a single const generic parameter per method",
160 )));
161 }
162
163 let cp = const_params[0];
164
165 let receiver = method.sig.receiver();
167 let is_mut = match receiver {
168 Some(r) => r.mutability.is_some(),
169 None => {
170 return Some(Err(syn::Error::new_spanned(
171 &method.sig,
172 "#[reifiable] requires methods with &self or &mut self receiver",
173 )));
174 }
175 };
176
177 if let ReturnType::Type(_, ref ty) = method.sig.output {
179 if type_mentions_ident(ty, &cp.ident) {
180 return Some(Err(syn::Error::new_spanned(
181 ty,
182 format!(
183 "#[reifiable] V1 does not support return types that depend on \
184 the const parameter `{}`. Use NatCallback manually for this case.",
185 cp.ident
186 ),
187 )));
188 }
189 }
190
191 let params: Vec<(Ident, Type)> = method
193 .sig
194 .inputs
195 .iter()
196 .filter_map(|arg| match arg {
197 FnArg::Typed(pat_type) => {
198 let name = match pat_type.pat.as_ref() {
199 Pat::Ident(pi) => pi.ident.clone(),
200 _ => Ident::new("_arg", Span::call_site()),
201 };
202 Some((name, (*pat_type.ty).clone()))
203 }
204 FnArg::Receiver(_) => None,
205 })
206 .collect();
207
208 let lifetime_params: Vec<syn::LifetimeParam> = method
210 .sig
211 .generics
212 .params
213 .iter()
214 .filter_map(|p| match p {
215 GenericParam::Lifetime(lp) => Some(lp.clone()),
216 _ => None,
217 })
218 .collect();
219
220 let type_params: Vec<syn::TypeParam> = method
221 .sig
222 .generics
223 .params
224 .iter()
225 .filter_map(|p| match p {
226 GenericParam::Type(tp) => Some(tp.clone()),
227 _ => None,
228 })
229 .collect();
230
231 Some(Ok(ConstMethod {
232 name: method.sig.ident.clone(),
233 _const_param_name: cp.ident.clone(),
234 const_param_ty: cp.ty.clone(),
235 is_mut,
236 params,
237 lifetime_params,
238 type_params,
239 return_type: method.sig.output.clone(),
240 }))
241}
242
243fn generate_dispatch_fn(
248 trait_name: &Ident,
249 trait_generics: &syn::Generics,
250 trait_vis: &Visibility,
251 method: &ConstMethod,
252 range_start: u64,
253 range_end: u64,
254) -> TokenStream2 {
255 let fn_name = format_ident!("reify_{}", method.name);
256 let method_name = &method.name;
257 let const_ty = &method.const_param_ty;
258 let return_type = &method.return_type;
259
260 let range_lits: Vec<LitInt> = (range_start..=range_end)
262 .map(|n| LitInt::new(&n.to_string(), Span::call_site()))
263 .collect();
264
265 let param_names: Vec<&Ident> = method.params.iter().map(|(n, _)| n).collect();
267 let _param_types: Vec<&Type> = method.params.iter().map(|(_, t)| t).collect();
268 let param_decls: Vec<TokenStream2> =
269 method.params.iter().map(|(n, t)| quote!(#n: #t)).collect();
270
271 let _trait_generic_params = &trait_generics.params;
273 let _trait_where_clause = &trait_generics.where_clause;
274
275 let trait_generic_args: Punctuated<TokenStream2, Comma> = trait_generics
277 .params
278 .iter()
279 .map(|p| match p {
280 GenericParam::Type(tp) => {
281 let ident = &tp.ident;
282 quote!(#ident)
283 }
284 GenericParam::Lifetime(lp) => {
285 let lt = &lp.lifetime;
286 quote!(#lt)
287 }
288 GenericParam::Const(cp) => {
289 let ident = &cp.ident;
290 quote!(#ident)
291 }
292 })
293 .collect();
294
295 let trait_bound = if trait_generic_args.is_empty() {
296 quote!(#trait_name)
297 } else {
298 quote!(#trait_name<#trait_generic_args>)
299 };
300
301 let method_lifetime_params: Vec<TokenStream2> = method
303 .lifetime_params
304 .iter()
305 .map(|lp| quote!(#lp))
306 .collect();
307 let method_type_params: Vec<TokenStream2> =
308 method.type_params.iter().map(|tp| quote!(#tp)).collect();
309 let method_type_args: Vec<TokenStream2> = method
310 .type_params
311 .iter()
312 .map(|tp| {
313 let ident = &tp.ident;
314 quote!(#ident)
315 })
316 .collect();
317
318 let mut all_fn_generics: Vec<TokenStream2> = Vec::new();
320 for lp in &method_lifetime_params {
321 all_fn_generics.push(lp.clone());
322 }
323 for p in trait_generics.params.iter() {
325 all_fn_generics.push(quote!(#p));
326 }
327 for tp in &method_type_params {
328 all_fn_generics.push(tp.clone());
329 }
330 all_fn_generics.push(quote!(__ReifyT: #trait_bound));
331
332 let fn_generics = if all_fn_generics.is_empty() {
333 quote!()
334 } else {
335 quote!(<#(#all_fn_generics),*>)
336 };
337
338 let obj_param = if method.is_mut {
340 quote!(obj: &mut __ReifyT)
341 } else {
342 quote!(obj: &__ReifyT)
343 };
344
345 let match_arms: Vec<TokenStream2> = range_lits
347 .iter()
348 .map(|n| {
349 if method_type_args.is_empty() {
350 quote!(#n => obj.#method_name::<#n>(#(#param_names),*))
351 } else {
352 quote!(#n => obj.#method_name::<#n, #(#method_type_args),*>(#(#param_names),*))
353 }
354 })
355 .collect();
356
357 let range_end_display = range_end;
358
359 quote! {
360 #trait_vis fn #fn_name #fn_generics(
365 val: #const_ty,
366 #obj_param,
367 #(#param_decls),*
368 ) #return_type {
369 match val {
370 #(#match_arms,)*
371 other => panic!(
372 concat!(
373 "#[reifiable] dispatch for ",
374 stringify!(#trait_name),
375 "::",
376 stringify!(#method_name),
377 ": value {} out of range 0..={}",
378 ),
379 other,
380 #range_end_display,
381 ),
382 }
383 }
384 }
385}
386
387fn generate_callback_wrapper(
388 trait_name: &Ident,
389 trait_generics: &syn::Generics,
390 trait_vis: &Visibility,
391 method: &ConstMethod,
392) -> TokenStream2 {
393 let wrapper_name = format_ident!(
394 "{}{}Callback",
395 trait_name,
396 pascal_case(&method.name.to_string())
397 );
398 let method_name = &method.name;
399 let return_type_inner = match &method.return_type {
400 ReturnType::Default => quote!(()),
401 ReturnType::Type(_, ty) => quote!(#ty),
402 };
403
404 let param_names: Vec<&Ident> = method.params.iter().map(|(n, _)| n).collect();
406 let _param_types: Vec<&Type> = method.params.iter().map(|(_, t)| t).collect();
407
408 let trait_generic_params = &trait_generics.params;
410 let trait_generic_args: Punctuated<TokenStream2, Comma> = trait_generics
411 .params
412 .iter()
413 .map(|p| match p {
414 GenericParam::Type(tp) => {
415 let ident = &tp.ident;
416 quote!(#ident)
417 }
418 GenericParam::Lifetime(lp) => {
419 let lt = &lp.lifetime;
420 quote!(#lt)
421 }
422 GenericParam::Const(cp) => {
423 let ident = &cp.ident;
424 quote!(#ident)
425 }
426 })
427 .collect();
428
429 let trait_bound = if trait_generic_args.is_empty() {
430 quote!(#trait_name)
431 } else {
432 quote!(#trait_name<#trait_generic_args>)
433 };
434
435 let has_trait_generics = !trait_generics.params.is_empty();
437
438 let obj_ref = if method.is_mut {
439 return quote!();
441 } else {
442 quote!(&'__reify_a __ReifyT)
443 };
444
445 let struct_fields: Vec<TokenStream2> = std::iter::once(quote! {
446 pub obj: #obj_ref
448 })
449 .chain(method.params.iter().map(|(n, t)| quote!(pub #n: #t)))
450 .collect();
451
452 let struct_generics = if has_trait_generics {
453 quote!(<'__reify_a, #trait_generic_params, __ReifyT: #trait_bound>)
454 } else {
455 quote!(<'__reify_a, __ReifyT: #trait_bound>)
456 };
457
458 let impl_generics = if has_trait_generics {
459 quote!(<#trait_generic_params, __ReifyT: #trait_bound>)
460 } else {
461 quote!(<__ReifyT: #trait_bound>)
462 };
463
464 let method_type_args: Vec<TokenStream2> = method
466 .type_params
467 .iter()
468 .map(|tp| {
469 let ident = &tp.ident;
470 quote!(#ident)
471 })
472 .collect();
473
474 let call_expr = if method_type_args.is_empty() {
475 quote!(self.obj.#method_name::<N>(#(self.#param_names),*))
476 } else {
477 quote!(self.obj.#method_name::<N, #(#method_type_args),*>(#(self.#param_names),*))
478 };
479
480 quote! {
481 #trait_vis struct #wrapper_name #struct_generics {
484 #(#struct_fields,)*
485 }
486
487 impl #impl_generics const_reify::NatCallback<#return_type_inner>
488 for #wrapper_name<'_, #trait_generic_args __ReifyT>
489 {
490 fn call<const N: u64>(&self) -> #return_type_inner {
491 #call_expr
492 }
493 }
494 }
495}
496
/// Converts a `snake_case` identifier to `PascalCase` (`get_value` → `GetValue`).
///
/// Underscores are removed. The first character, and the first character
/// after each run of underscores, is upper-cased with ASCII rules; every
/// other character is copied unchanged.
///
/// A leading `r#` raw-identifier prefix is ignored. Without this, a raw
/// method name such as `r#type` would become `R#type`, which is not a valid
/// identifier fragment and would make `format_ident!` panic.
fn pascal_case(s: &str) -> String {
    let s = s.strip_prefix("r#").unwrap_or(s);
    s.split('_')
        .filter(|segment| !segment.is_empty())
        .flat_map(|segment| {
            let mut chars = segment.chars();
            let first = chars.next().map(|c| c.to_ascii_uppercase());
            first.into_iter().chain(chars)
        })
        .collect()
}
512
513#[proc_macro_attribute]
541pub fn reifiable(attr: TokenStream, item: TokenStream) -> TokenStream {
542 let args = parse_macro_input!(attr as ReifiableArgs);
543 let trait_def = parse_macro_input!(item as ItemTrait);
544
545 match reifiable_impl(args, &trait_def) {
546 Ok(tokens) => tokens.into(),
547 Err(e) => {
548 let trait_tokens = quote!(#trait_def);
549 let err = e.to_compile_error();
550 TokenStream::from(quote! {
553 #trait_tokens
554 #err
555 })
556 }
557 }
558}
559
560fn reifiable_impl(args: ReifiableArgs, trait_def: &ItemTrait) -> syn::Result<TokenStream2> {
561 let trait_name = &trait_def.ident;
562 let trait_vis = &trait_def.vis;
563 let trait_generics = &trait_def.generics;
564
565 if args.range_end > 1023 {
567 return Err(syn::Error::new(
568 Span::call_site(),
569 format!(
570 "#[reifiable] range 0..={} would generate {} monomorphizations per method. \
571 Maximum is 1024. Use a smaller range.",
572 args.range_end,
573 args.range_end + 1,
574 ),
575 ));
576 }
577
578 let mut dispatch_fns = Vec::new();
579 let mut callback_wrappers = Vec::new();
580
581 for item in &trait_def.items {
582 if let TraitItem::Fn(method) = item {
583 if let Some(result) = analyze_method(method) {
584 let cm = result?;
585
586 dispatch_fns.push(generate_dispatch_fn(
587 trait_name,
588 trait_generics,
589 trait_vis,
590 &cm,
591 args.range_start,
592 args.range_end,
593 ));
594
595 let wrapper = generate_callback_wrapper(trait_name, trait_generics, trait_vis, &cm);
596 if !wrapper.is_empty() {
597 callback_wrappers.push(wrapper);
598 }
599 }
600 }
601 }
602
603 Ok(quote! {
605 #trait_def
606
607 #(#dispatch_fns)*
608
609 #(#callback_wrappers)*
610 })
611}