Skip to main content

const_reify_derive/
lib.rs

1//! `#[reifiable]`: turn a trait with const-generic methods into a runtime
2//! dispatch table.
3//!
4//! [`const-reify`](https://docs.rs/const-reify) gives you a primitive way
5//! to lift a runtime `u64` into a `const N: u64`: implement
6//! [`NatCallback`](https://docs.rs/const-reify/latest/const_reify/trait.NatCallback.html)
7//! and call
8//! [`reify_nat`](https://docs.rs/const-reify/latest/const_reify/fn.reify_nat.html).
9//! That works, but it gets verbose if a single trait has several
10//! const-generic methods: you end up writing one `NatCallback` impl per
11//! method.
12//!
13//! `#[reifiable(range = 0..=255)]` on the trait declaration eliminates
14//! that boilerplate. The macro:
15//!
16//! - Generates a `reify_<method_name>` dispatch function for each
17//!   const-generic method. The function takes a runtime `u64`, picks
18//!   the matching monomorphization, and forwards.
19//! - Generates the `NatCallback` wrapper structs that
20//!   [`reify_nat`](https://docs.rs/const-reify/latest/const_reify/fn.reify_nat.html)
21//!   needs internally.
22//! - Leaves non-const-generic methods alone.
23//!
24//! You then implement the trait normally, and call the generated
25//! `reify_*` dispatch functions from runtime code.
26//!
27//! See [Guide 4: the `#[reifiable]` macro][guide4] for a full worked
28//! example, and [`docs/rfcs/0003-reifiable-proc-macro.md`][rfc] for the
29//! design.
30//!
31//! [guide4]: https://github.com/joshburgess/reify-reflect/blob/main/docs/guides/04-reifiable-macro.md
32//! [rfc]: https://github.com/joshburgess/reify-reflect/blob/main/docs/rfcs/0003-reifiable-proc-macro.md
33//!
34//! # Example
35//!
36//! ```ignore
37//! use const_reify_derive::reifiable;
38//!
39//! #[reifiable(range = 0..=255)]
40//! trait ModArith {
41//!     fn pow_mod<const N: u64>(&self, base: u64, exp: u64) -> u64;
42//!     fn mul_mod<const N: u64>(&self, a: u64, b: u64) -> u64;
43//!     fn name(&self) -> &str;  // not const-generic, left alone
44//! }
45//!
46//! struct FastMod;
47//! impl ModArith for FastMod {
48//!     fn pow_mod<const N: u64>(&self, base: u64, exp: u64) -> u64 { /* ... */ 0 }
49//!     fn mul_mod<const N: u64>(&self, a: u64, b: u64) -> u64 { (a * b) % N }
50//!     fn name(&self) -> &str { "fast" }
51//! }
52//!
53//! // Now `reify_pow_mod` and `reify_mul_mod` are generated dispatchers:
54//! let modulus: u64 = 13;
55//! let result = reify_pow_mod(modulus, &FastMod, 2, 12);
56//! ```
57
58#![deny(unsafe_code)]
59
60extern crate proc_macro;
61
62use proc_macro::TokenStream;
63use proc_macro2::{Span, TokenStream as TokenStream2};
64use quote::{format_ident, quote};
65use syn::{
66    parse::{Parse, ParseStream},
67    parse_macro_input,
68    punctuated::Punctuated,
69    token::Comma,
70    ConstParam, FnArg, GenericParam, Ident, ItemTrait, LitInt, Pat, ReturnType, Token, TraitItem,
71    TraitItemFn, Type, Visibility,
72};
73
74// ---------------------------------------------------------------------------
75// Attribute argument parsing
76// ---------------------------------------------------------------------------
77
78/// Parsed `#[reifiable(range = START..=END)]` arguments.
79struct ReifiableArgs {
80    range_start: u64,
81    range_end: u64,
82}
83
84impl Parse for ReifiableArgs {
85    fn parse(input: ParseStream) -> syn::Result<Self> {
86        // Parse: range = START..=END
87        let ident: Ident = input.parse()?;
88        if ident != "range" {
89            return Err(syn::Error::new(ident.span(), "expected `range`"));
90        }
91        let _eq: Token![=] = input.parse()?;
92        let start: LitInt = input.parse()?;
93        let _dots: Token![..] = input.parse()?;
94        let _eq2: Token![=] = input.parse()?;
95        let end: LitInt = input.parse()?;
96
97        Ok(ReifiableArgs {
98            range_start: start.base10_parse()?,
99            range_end: end.base10_parse()?,
100        })
101    }
102}
103
104// ---------------------------------------------------------------------------
105// Method analysis
106// ---------------------------------------------------------------------------
107
108/// A const-generic method extracted from the trait.
109struct ConstMethod {
110    /// Method name.
111    name: Ident,
112    /// The const generic parameter (name and type).
113    _const_param_name: Ident,
114    const_param_ty: Type,
115    /// Whether the method takes &self or &mut self.
116    is_mut: bool,
117    /// Non-self, non-const-generic parameters: (name, type) pairs.
118    params: Vec<(Ident, Type)>,
119    /// Lifetime parameters on the method.
120    lifetime_params: Vec<syn::LifetimeParam>,
121    /// Type parameters on the method (non-const generics).
122    type_params: Vec<syn::TypeParam>,
123    /// Return type (None = ()).
124    return_type: ReturnType,
125}
126
127/// Check if a return type mentions a given identifier (the const param).
128fn type_mentions_ident(ty: &Type, ident: &Ident) -> bool {
129    let ty_str = quote!(#ty).to_string();
130    let ident_str = ident.to_string();
131    // Simple heuristic: check if the ident appears as a token in the type.
132    // A proper implementation would walk the type AST, but this catches
133    // the common cases like [u8; N] and Foo<N>.
134    ty_str
135        .split(|c: char| !c.is_alphanumeric() && c != '_')
136        .any(|word| word == ident_str)
137}
138
139fn analyze_method(method: &TraitItemFn) -> Option<Result<ConstMethod, syn::Error>> {
140    // Find const generic parameters
141    let const_params: Vec<&ConstParam> = method
142        .sig
143        .generics
144        .params
145        .iter()
146        .filter_map(|p| match p {
147            GenericParam::Const(cp) => Some(cp),
148            _ => None,
149        })
150        .collect();
151
152    if const_params.is_empty() {
153        return None; // Not a const-generic method, skip
154    }
155
156    if const_params.len() > 1 {
157        return Some(Err(syn::Error::new_spanned(
158            &method.sig,
159            "#[reifiable] V1 only supports a single const generic parameter per method",
160        )));
161    }
162
163    let cp = const_params[0];
164
165    // Check receiver
166    let receiver = method.sig.receiver();
167    let is_mut = match receiver {
168        Some(r) => r.mutability.is_some(),
169        None => {
170            return Some(Err(syn::Error::new_spanned(
171                &method.sig,
172                "#[reifiable] requires methods with &self or &mut self receiver",
173            )));
174        }
175    };
176
177    // Check return type doesn't depend on N
178    if let ReturnType::Type(_, ref ty) = method.sig.output {
179        if type_mentions_ident(ty, &cp.ident) {
180            return Some(Err(syn::Error::new_spanned(
181                ty,
182                format!(
183                    "#[reifiable] V1 does not support return types that depend on \
184                     the const parameter `{}`. Use NatCallback manually for this case.",
185                    cp.ident
186                ),
187            )));
188        }
189    }
190
191    // Extract non-self parameters
192    let params: Vec<(Ident, Type)> = method
193        .sig
194        .inputs
195        .iter()
196        .filter_map(|arg| match arg {
197            FnArg::Typed(pat_type) => {
198                let name = match pat_type.pat.as_ref() {
199                    Pat::Ident(pi) => pi.ident.clone(),
200                    _ => Ident::new("_arg", Span::call_site()),
201                };
202                Some((name, (*pat_type.ty).clone()))
203            }
204            FnArg::Receiver(_) => None,
205        })
206        .collect();
207
208    // Extract lifetime and type params (non-const)
209    let lifetime_params: Vec<syn::LifetimeParam> = method
210        .sig
211        .generics
212        .params
213        .iter()
214        .filter_map(|p| match p {
215            GenericParam::Lifetime(lp) => Some(lp.clone()),
216            _ => None,
217        })
218        .collect();
219
220    let type_params: Vec<syn::TypeParam> = method
221        .sig
222        .generics
223        .params
224        .iter()
225        .filter_map(|p| match p {
226            GenericParam::Type(tp) => Some(tp.clone()),
227            _ => None,
228        })
229        .collect();
230
231    Some(Ok(ConstMethod {
232        name: method.sig.ident.clone(),
233        _const_param_name: cp.ident.clone(),
234        const_param_ty: cp.ty.clone(),
235        is_mut,
236        params,
237        lifetime_params,
238        type_params,
239        return_type: method.sig.output.clone(),
240    }))
241}
242
243// ---------------------------------------------------------------------------
244// Code generation
245// ---------------------------------------------------------------------------
246
247fn generate_dispatch_fn(
248    trait_name: &Ident,
249    trait_generics: &syn::Generics,
250    trait_vis: &Visibility,
251    method: &ConstMethod,
252    range_start: u64,
253    range_end: u64,
254) -> TokenStream2 {
255    let fn_name = format_ident!("reify_{}", method.name);
256    let method_name = &method.name;
257    let const_ty = &method.const_param_ty;
258    let return_type = &method.return_type;
259
260    // Build range literals
261    let range_lits: Vec<LitInt> = (range_start..=range_end)
262        .map(|n| LitInt::new(&n.to_string(), Span::call_site()))
263        .collect();
264
265    // Parameter names and types for the dispatch function signature
266    let param_names: Vec<&Ident> = method.params.iter().map(|(n, _)| n).collect();
267    let _param_types: Vec<&Type> = method.params.iter().map(|(_, t)| t).collect();
268    let param_decls: Vec<TokenStream2> =
269        method.params.iter().map(|(n, t)| quote!(#n: #t)).collect();
270
271    // Trait generic params and where clause
272    let _trait_generic_params = &trait_generics.params;
273    let _trait_where_clause = &trait_generics.where_clause;
274
275    // Build the trait bound: T: TraitName<GenericArgs>
276    let trait_generic_args: Punctuated<TokenStream2, Comma> = trait_generics
277        .params
278        .iter()
279        .map(|p| match p {
280            GenericParam::Type(tp) => {
281                let ident = &tp.ident;
282                quote!(#ident)
283            }
284            GenericParam::Lifetime(lp) => {
285                let lt = &lp.lifetime;
286                quote!(#lt)
287            }
288            GenericParam::Const(cp) => {
289                let ident = &cp.ident;
290                quote!(#ident)
291            }
292        })
293        .collect();
294
295    let trait_bound = if trait_generic_args.is_empty() {
296        quote!(#trait_name)
297    } else {
298        quote!(#trait_name<#trait_generic_args>)
299    };
300
301    // Method lifetime and type params
302    let method_lifetime_params: Vec<TokenStream2> = method
303        .lifetime_params
304        .iter()
305        .map(|lp| quote!(#lp))
306        .collect();
307    let method_type_params: Vec<TokenStream2> =
308        method.type_params.iter().map(|tp| quote!(#tp)).collect();
309    let method_type_args: Vec<TokenStream2> = method
310        .type_params
311        .iter()
312        .map(|tp| {
313            let ident = &tp.ident;
314            quote!(#ident)
315        })
316        .collect();
317
318    // All generic params for the dispatch function
319    let mut all_fn_generics: Vec<TokenStream2> = Vec::new();
320    for lp in &method_lifetime_params {
321        all_fn_generics.push(lp.clone());
322    }
323    // Trait's own generics
324    for p in trait_generics.params.iter() {
325        all_fn_generics.push(quote!(#p));
326    }
327    for tp in &method_type_params {
328        all_fn_generics.push(tp.clone());
329    }
330    all_fn_generics.push(quote!(__ReifyT: #trait_bound));
331
332    let fn_generics = if all_fn_generics.is_empty() {
333        quote!()
334    } else {
335        quote!(<#(#all_fn_generics),*>)
336    };
337
338    // Self receiver
339    let obj_param = if method.is_mut {
340        quote!(obj: &mut __ReifyT)
341    } else {
342        quote!(obj: &__ReifyT)
343    };
344
345    // Match arms — each calls obj.method::<N>(args...) with optional type args
346    let match_arms: Vec<TokenStream2> = range_lits
347        .iter()
348        .map(|n| {
349            if method_type_args.is_empty() {
350                quote!(#n => obj.#method_name::<#n>(#(#param_names),*))
351            } else {
352                quote!(#n => obj.#method_name::<#n, #(#method_type_args),*>(#(#param_names),*))
353            }
354        })
355        .collect();
356
357    let range_end_display = range_end;
358
359    quote! {
360        /// Auto-generated dispatch function for [`#trait_name::#method_name`].
361        ///
362        /// Dispatches a runtime `val` to the corresponding const-generic
363        /// instantiation of the method.
364        #trait_vis fn #fn_name #fn_generics(
365            val: #const_ty,
366            #obj_param,
367            #(#param_decls),*
368        ) #return_type {
369            match val {
370                #(#match_arms,)*
371                other => panic!(
372                    concat!(
373                        "#[reifiable] dispatch for ",
374                        stringify!(#trait_name),
375                        "::",
376                        stringify!(#method_name),
377                        ": value {} out of range 0..={}",
378                    ),
379                    other,
380                    #range_end_display,
381                ),
382            }
383        }
384    }
385}
386
387fn generate_callback_wrapper(
388    trait_name: &Ident,
389    trait_generics: &syn::Generics,
390    trait_vis: &Visibility,
391    method: &ConstMethod,
392) -> TokenStream2 {
393    let wrapper_name = format_ident!(
394        "{}{}Callback",
395        trait_name,
396        pascal_case(&method.name.to_string())
397    );
398    let method_name = &method.name;
399    let return_type_inner = match &method.return_type {
400        ReturnType::Default => quote!(()),
401        ReturnType::Type(_, ty) => quote!(#ty),
402    };
403
404    // Fields: obj reference + each parameter
405    let param_names: Vec<&Ident> = method.params.iter().map(|(n, _)| n).collect();
406    let _param_types: Vec<&Type> = method.params.iter().map(|(_, t)| t).collect();
407
408    // Trait generic params
409    let trait_generic_params = &trait_generics.params;
410    let trait_generic_args: Punctuated<TokenStream2, Comma> = trait_generics
411        .params
412        .iter()
413        .map(|p| match p {
414            GenericParam::Type(tp) => {
415                let ident = &tp.ident;
416                quote!(#ident)
417            }
418            GenericParam::Lifetime(lp) => {
419                let lt = &lp.lifetime;
420                quote!(#lt)
421            }
422            GenericParam::Const(cp) => {
423                let ident = &cp.ident;
424                quote!(#ident)
425            }
426        })
427        .collect();
428
429    let trait_bound = if trait_generic_args.is_empty() {
430        quote!(#trait_name)
431    } else {
432        quote!(#trait_name<#trait_generic_args>)
433    };
434
435    // Struct generics include a lifetime, the trait's generics, and T
436    let has_trait_generics = !trait_generics.params.is_empty();
437
438    let obj_ref = if method.is_mut {
439        // Can't have &mut in a NatCallback (call takes &self), so skip wrapper for mut methods
440        return quote!();
441    } else {
442        quote!(&'__reify_a __ReifyT)
443    };
444
445    let struct_fields: Vec<TokenStream2> = std::iter::once(quote! {
446        /// The trait implementor.
447        pub obj: #obj_ref
448    })
449    .chain(method.params.iter().map(|(n, t)| quote!(pub #n: #t)))
450    .collect();
451
452    let struct_generics = if has_trait_generics {
453        quote!(<'__reify_a, #trait_generic_params, __ReifyT: #trait_bound>)
454    } else {
455        quote!(<'__reify_a, __ReifyT: #trait_bound>)
456    };
457
458    let impl_generics = if has_trait_generics {
459        quote!(<#trait_generic_params, __ReifyT: #trait_bound>)
460    } else {
461        quote!(<__ReifyT: #trait_bound>)
462    };
463
464    // Method type params for the call
465    let method_type_args: Vec<TokenStream2> = method
466        .type_params
467        .iter()
468        .map(|tp| {
469            let ident = &tp.ident;
470            quote!(#ident)
471        })
472        .collect();
473
474    let call_expr = if method_type_args.is_empty() {
475        quote!(self.obj.#method_name::<N>(#(self.#param_names),*))
476    } else {
477        quote!(self.obj.#method_name::<N, #(#method_type_args),*>(#(self.#param_names),*))
478    };
479
480    quote! {
481        /// Auto-generated [`const_reify::NatCallback`] wrapper for
482        /// [`#trait_name::#method_name`].
483        #trait_vis struct #wrapper_name #struct_generics {
484            #(#struct_fields,)*
485        }
486
487        impl #impl_generics const_reify::NatCallback<#return_type_inner>
488            for #wrapper_name<'_, #trait_generic_args __ReifyT>
489        {
490            fn call<const N: u64>(&self) -> #return_type_inner {
491                #call_expr
492            }
493        }
494    }
495}
496
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// Underscores are removed. The first character of each underscore-separated
/// segment is ASCII-uppercased and the remaining characters are kept as-is,
/// so `pow_mod` becomes `PowMod`.
fn pascal_case(s: &str) -> String {
    s.split('_')
        .flat_map(|segment| {
            let mut chars = segment.chars();
            let head = chars.next().map(|first| first.to_ascii_uppercase());
            head.into_iter().chain(chars)
        })
        .collect()
}
512
513// ---------------------------------------------------------------------------
514// Entry point
515// ---------------------------------------------------------------------------
516
517/// Generates const-generic dispatch functions for a trait's methods.
518///
519/// Annotate a trait with `#[reifiable(range = 0..=255)]` and the macro
520/// generates a `reify_<method>` dispatch function for each method that
521/// has a const generic parameter.
522///
523/// # Examples
524///
525/// ```ignore
526/// #[reifiable(range = 0..=255)]
527/// trait ModArith {
528///     fn mul_mod<const N: u64>(&self, a: u64, b: u64) -> u64;
529/// }
530///
531/// // Generated:
532/// // fn reify_mul_mod<T: ModArith>(val: u64, obj: &T, a: u64, b: u64) -> u64
533/// ```
534///
535/// # Limitations (V1)
536///
537/// - Only supports a single const generic parameter per method
538/// - Return types must not depend on the const parameter
539/// - `&mut self` methods get dispatch functions but not NatCallback wrappers
540#[proc_macro_attribute]
541pub fn reifiable(attr: TokenStream, item: TokenStream) -> TokenStream {
542    let args = parse_macro_input!(attr as ReifiableArgs);
543    let trait_def = parse_macro_input!(item as ItemTrait);
544
545    match reifiable_impl(args, &trait_def) {
546        Ok(tokens) => tokens.into(),
547        Err(e) => {
548            let trait_tokens = quote!(#trait_def);
549            let err = e.to_compile_error();
550            // Emit the original trait so downstream code doesn't break,
551            // plus the error.
552            TokenStream::from(quote! {
553                #trait_tokens
554                #err
555            })
556        }
557    }
558}
559
560fn reifiable_impl(args: ReifiableArgs, trait_def: &ItemTrait) -> syn::Result<TokenStream2> {
561    let trait_name = &trait_def.ident;
562    let trait_vis = &trait_def.vis;
563    let trait_generics = &trait_def.generics;
564
565    // Validate range
566    if args.range_end > 1023 {
567        return Err(syn::Error::new(
568            Span::call_site(),
569            format!(
570                "#[reifiable] range 0..={} would generate {} monomorphizations per method. \
571                 Maximum is 1024. Use a smaller range.",
572                args.range_end,
573                args.range_end + 1,
574            ),
575        ));
576    }
577
578    let mut dispatch_fns = Vec::new();
579    let mut callback_wrappers = Vec::new();
580
581    for item in &trait_def.items {
582        if let TraitItem::Fn(method) = item {
583            if let Some(result) = analyze_method(method) {
584                let cm = result?;
585
586                dispatch_fns.push(generate_dispatch_fn(
587                    trait_name,
588                    trait_generics,
589                    trait_vis,
590                    &cm,
591                    args.range_start,
592                    args.range_end,
593                ));
594
595                let wrapper = generate_callback_wrapper(trait_name, trait_generics, trait_vis, &cm);
596                if !wrapper.is_empty() {
597                    callback_wrappers.push(wrapper);
598                }
599            }
600        }
601    }
602
603    // Emit the original trait unchanged, plus generated code
604    Ok(quote! {
605        #trait_def
606
607        #(#dispatch_fns)*
608
609        #(#callback_wrappers)*
610    })
611}