constify/
lib.rs

1//! # constify
2//!
3//! Example usage:
4//! ```rust
5//! use constify::constify;
6//! #[constify]
7//! fn foo(
8//!     #[constify] a: bool,
9//!     #[constify] b: bool,
10//!     c: bool
11//! ) -> u32 {
12//!     let mut sum = 0;
13//!     if a {
14//!         sum += 1;
15//!     }
16//!     if b {
17//!         sum += 10;
18//!     }
19//!     if c {
20//!         sum += 100;
21//!     }
22//!     sum
23//! }
24//! ```
25//!
//! Expands to roughly the following (the specialized copy additionally gets `#[allow(warnings)]`):
27//! ```rust
28//! #[inline(always)]
29//! fn foo(a: bool, b: bool, c: bool) -> u32 {
30//!     fn foo<const a: bool, const b: bool>(c: bool) -> u32 {
31//!         let mut sum = 0;
32//!         if a {
33//!             sum += 1;
34//!         }
35//!         if b {
36//!             sum += 10;
37//!         }
38//!         if c {
39//!             sum += 100;
40//!         }
41//!         sum
42//!     }
43//!     match (a, b) {
44//!         (false, false) => foo::<false, false>(c),
45//!         (true, false) => foo::<true, false>(c),
46//!         (false, true) => foo::<false, true>(c),
47//!         (true, true) => foo::<true, true>(c),
48//!     }
49//! }
50//! ```
51//!
52//! Inspired by <https://github.com/TennyZhuang/const-currying-rs>
53
54use proc_macro::TokenStream as TokenStream1;
55use proc_macro2::TokenStream;
56use quote::{quote, quote_spanned};
57use syn::{
58    Block, ConstParam, FnArg, GenericParam, ItemFn, Pat, PatIdent, Result, Token, Type,
59    parse_macro_input, parse_quote, spanned::Spanned,
60};
61
62#[proc_macro_attribute]
63pub fn constify(_attr: TokenStream1, item: TokenStream1) -> TokenStream1 {
64    let input = parse_macro_input!(item as ItemFn);
65    match inner(input) {
66        Ok(output) => output.into(),
67        Err(err) => err.to_compile_error().into(),
68    }
69}
70
71fn remove_attr(arg: &FnArg) -> FnArg {
72    match arg.clone() {
73        FnArg::Typed(mut typed) => {
74            typed.attrs.clear();
75            FnArg::Typed(typed)
76        }
77        r @ FnArg::Receiver(_) => r,
78    }
79}
80
81const CONSTIFY: &str = "constify";
82fn inner(item: ItemFn) -> Result<TokenStream> {
83    let ItemFn {
84        sig, attrs, block, ..
85    } = &item;
86
87    // collect constifyed bool targets
88    let mut constify_args = Vec::new();
89    let mut non_constify_args = Vec::new();
90    for input in sig.inputs.iter() {
91        if let FnArg::Typed(typed) = input {
92            if typed.attrs.iter().any(|a| a.path().is_ident(CONSTIFY)) {
93                let Pat::Ident(PatIdent {
94                    ident: arg_name, ..
95                }) = &*typed.pat
96                else {
97                    return Err(syn::Error::new(
98                        typed.pat.span(),
99                        "Only simple identifiers are supported for constifyed args",
100                    ));
101                };
102                if !matches!(&*typed.ty, Type::Path(tp) if tp.path.is_ident("bool")) {
103                    return Err(syn::Error::new(
104                        typed.ty.span(),
105                        "#[constify] only supports `bool` parameters",
106                    ));
107                }
108                constify_args.push((arg_name.clone(), typed.clone()));
109                continue;
110            }
111        }
112        non_constify_args.push(remove_attr(input));
113    }
114
115    // build specialized fn (add const<bool> generics and strip constifyed args)
116    if constify_args.is_empty() {
117        return Ok(quote! { #item });
118    }
119
120    let mut new_sig = sig.clone();
121
122    // convert Vec<FnArg> non_constify_args into Punctuated<FnArg, Comma>
123    new_sig.inputs = non_constify_args.into_iter().collect();
124
125    new_sig
126        .generics
127        .params
128        .extend(constify_args.iter().map(|(arg_name, input)| {
129            GenericParam::Const(ConstParam {
130                attrs: vec![],
131                const_token: Token![const](arg_name.span()),
132                ident: arg_name.clone(),
133                colon_token: input.colon_token,
134                ty: parse_quote!(bool),
135                default: None,
136                eq_token: None,
137            })
138        }));
139
140    let mut new_attrs = attrs.clone();
141    new_attrs.push(parse_quote!(#[allow(warnings)]));
142
143    let specialized_fn = ItemFn {
144        sig: new_sig,
145        attrs: new_attrs,
146        block: block.clone(),
147        ..item.clone()
148    };
149
150    let non_constifyed_arg_ts = specialized_fn.sig.inputs.iter()
151        .map(| input| match input {
152            FnArg::Receiver(_) => quote! { self },
153            FnArg::Typed(typed) => match &*typed.pat {
154                Pat::Ident(p) => {
155                    let name = &p.ident;
156                    quote! { #name }
157                }
158                _ => quote_spanned! { typed.pat.span()=> compile_error!("Only simple identifiers are supported for non-constify args"); },
159            }
160        })
161        .collect::<Vec<_>>();
162
163    // build branches for all 2^N bool combinations
164    let n = constify_args.len();
165    let total = 1usize << n;
166    let mut branches = Vec::with_capacity(total);
167
168    for mask in 0..total {
169        // bool literals in target order
170        let mut match_args = Vec::with_capacity(n);
171        let mut added_const_args = Vec::with_capacity(n);
172
173        for i in 0..n {
174            let b = ((mask >> i) & 1) == 1;
175            let lit = if b { quote!(true) } else { quote!(false) };
176            match_args.push(lit.clone());
177            added_const_args.push(lit);
178        }
179
180        // <type_generics..., const_generics..., added_bool_consts...>
181        let all_generic_args = sig
182            .generics
183            .params
184            .iter()
185            .filter_map(|p| {
186                let id = match p {
187                    GenericParam::Type(t) => &t.ident,
188                    GenericParam::Const(c) => &c.ident,
189                    _ => return None,
190                };
191                Some(quote! { #id })
192            })
193            .chain(added_const_args.into_iter());
194        let generic_appl = quote! { ::<#(#all_generic_args),*> };
195
196        let fn_ident = &sig.ident;
197        let call = if non_constifyed_arg_ts.is_empty() {
198            quote! { #fn_ident #generic_appl () }
199        } else {
200            quote! { #fn_ident #generic_appl (#(#non_constifyed_arg_ts),*,) }
201        };
202
203        branches.push(quote! { (#(#match_args),*) => { #call } });
204    }
205
206    // outer constifyer with original signature (minus arg attrs)
207    let constify_fn: ItemFn = {
208        let mut new_sig = sig.clone();
209        new_sig.inputs = sig.inputs.iter().map(remove_attr).collect();
210
211        let mut new_attrs = attrs.clone();
212        new_attrs.push(parse_quote!(#[inline(always)]));
213
214        let all_target_names: Vec<_> = constify_args.iter().map(|(name, _)| name).collect();
215
216        let body: Block = parse_quote! {{
217            #specialized_fn
218            match (#(#all_target_names),*) {
219                #(#branches),*
220            }
221        }};
222
223        ItemFn {
224            sig: new_sig,
225            attrs: new_attrs,
226            block: Box::new(body),
227            vis: item.vis.clone(),
228        }
229    };
230
231    Ok(quote! { #constify_fn })
232}