const_currying/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use std::collections::HashSet;
4
5use auto_enums::auto_enum;
6use darling::{FromAttributes, FromMeta};
7use itertools::Itertools;
8use proc_macro2::TokenStream;
9use quote::{format_ident, quote};
10use syn::{
11    parse_macro_input, parse_quote, punctuated::Punctuated, Attribute, Block, ConstParam, Expr,
12    FnArg, GenericParam, Generics, Ident, ItemFn, Pat, PatIdent, PatType, Result, Signature, Token,
13    Type,
14};
15
16#[proc_macro_attribute]
17pub fn const_currying(
18    attr: proc_macro::TokenStream,
19    item: proc_macro::TokenStream,
20) -> proc_macro::TokenStream {
21    let input = parse_macro_input!(item as ItemFn);
22    match inner(attr.into(), input) {
23        Ok(output) => output.into(),
24        Err(err) => err.to_compile_error().into(),
25    }
26}
27
28#[derive(Debug, Clone, darling::FromAttributes)]
29#[darling(attributes(maybe_const))]
30struct FieldAttr {
31    #[darling(default)]
32    dispatch: Option<Ident>,
33    #[darling(default)]
34    consts: ConstsArray,
35}
36
37#[derive(Debug, Clone, Default)]
38struct ConstsArray {
39    inner: Punctuated<Expr, Token![,]>,
40}
41
42impl FromMeta for ConstsArray {
43    fn from_expr(expr: &Expr) -> darling::Result<Self> {
44        if let Expr::Array(array) = expr {
45            Ok(Self {
46                inner: array.elems.clone(),
47            })
48        } else {
49            Err(darling::Error::unexpected_expr_type(expr))
50        }
51    }
52}
53
54#[derive(Clone, Debug)]
55struct GenTarget {
56    attr: FieldAttr,
57    idx: usize,
58    arg_name: Ident,
59    input: PatType,
60    ty: Type,
61}
62
63fn remove_attr(arg: FnArg) -> FnArg {
64    match arg {
65        FnArg::Typed(mut typed) => {
66            typed.attrs.clear();
67            FnArg::Typed(typed)
68        }
69        FnArg::Receiver(receiver) => FnArg::Receiver(receiver),
70    }
71}
72
73fn contains_attr(attrs: &[Attribute]) -> bool {
74    attrs.iter().any(|attr| attr.path().is_ident("maybe_const"))
75}
76
77#[auto_enum]
78fn inner(_attr: TokenStream, item: ItemFn) -> Result<TokenStream> {
79    let item2 = item.clone();
80    let ItemFn { sig, .. } = item;
81
82    let Signature {
83        ident,
84        inputs,
85        generics,
86        ..
87    } = &sig;
88
89    let targets = inputs
90        .iter()
91        .enumerate()
92        .filter_map(|(idx, input)| match input {
93            FnArg::Receiver(..) => None,
94            FnArg::Typed(typed) => {
95                let PatType { attrs, ty, pat, .. } = typed;
96                let Pat::Ident(PatIdent {
97                    ident: arg_name, ..
98                }) = &**pat
99                else {
100                    return None;
101                };
102                if !contains_attr(attrs) {
103                    return None;
104                }
105                let attr = FieldAttr::from_attributes(attrs).ok()?;
106                Some(GenTarget {
107                    attr,
108                    idx,
109                    arg_name: arg_name.clone(),
110                    input: typed.clone(),
111                    ty: *ty.clone(),
112                })
113            }
114        })
115        .collect::<Vec<_>>();
116
117    let old_fn_name = format_ident!("{ident}_orig");
118
119    let orig_const_args: Vec<_> = generics
120        .const_params()
121        .map(|param| param.ident.clone())
122        .collect();
123
124    let fns = targets
125        .iter()
126        .cloned()
127        .powerset()
128        .zip(std::iter::from_fn(|| {
129            let item = item2.clone();
130            Some(item)
131        }))
132        .map(|(set, item)| {
133            let ItemFn { sig, .. } = item.clone();
134            let Signature {
135                ident,
136                inputs,
137                generics,
138                ..
139            } = &sig;
140            let new_fn_name = [ident.to_string()]
141                .into_iter()
142                .chain(set.iter().map(|t| {
143                    t.attr
144                        .dispatch
145                        .as_ref()
146                        .map(ToString::to_string)
147                        .unwrap_or(t.arg_name.to_string())
148                }))
149                .join("_");
150            let new_fn_ident = if set.is_empty() {
151                old_fn_name.clone()
152            } else {
153                Ident::new(&new_fn_name, ident.span())
154            };
155
156            let added_generic_params = set
157                .iter()
158                .map(|t: &GenTarget| {
159                    let GenTarget {
160                        attr: _,
161                        idx: _,
162                        arg_name,
163                        input,
164                        ty,
165                    } = t;
166                    ConstParam {
167                        attrs: vec![],
168                        const_token: Token![const](arg_name.span()),
169                        ident: arg_name.clone(),
170                        colon_token: input.colon_token,
171                        ty: ty.clone(),
172                        default: None,
173                        eq_token: None,
174                    }
175                })
176                .map(GenericParam::Const);
177
178            let mut old_generics_pararms = generics.params.clone();
179            for new_param in added_generic_params {
180                old_generics_pararms.push(new_param);
181            }
182            let new_generics = Generics {
183                params: old_generics_pararms,
184                ..generics.clone()
185            };
186            let new_inputs = {
187                let args_to_remove: HashSet<_> = set.iter().map(|t| t.idx).collect();
188                inputs
189                    .iter()
190                    .cloned()
191                    .enumerate()
192                    .filter(|(idx, _)| !args_to_remove.contains(idx))
193                    .map(|(_idx, input)| input)
194                    .map(remove_attr)
195                    .collect::<Punctuated<_, Token![,]>>()
196            };
197            let sig = sig.clone();
198            let new_sig = Signature {
199                ident: new_fn_ident,
200                inputs: new_inputs,
201                generics: new_generics,
202                ..sig
203            };
204            let item = item.clone();
205            let mut new_attrs = item.attrs.clone();
206            let new_attr: Attribute = parse_quote!(#[allow(warnings)]);
207            new_attrs.push(new_attr);
208            ItemFn {
209                sig: new_sig,
210                attrs: new_attrs,
211                ..item
212            }
213        })
214        .collect::<Vec<_>>();
215
216    // Generate the dispatch function
217    let all_target_names = targets
218        .iter()
219        .map(|target| target.arg_name.clone())
220        .collect::<Vec<_>>();
221
222    let mut branches = targets
223        .iter()
224        .cloned()
225        .enumerate()
226        .powerset()
227        .flat_map(|set| {
228            let new_fn_name = [ident.to_string()]
229                .into_iter()
230                .chain(set.iter().map(|(_, t)| {
231                    t.attr
232                        .dispatch
233                        .as_ref()
234                        .map(ToString::to_string)
235                        .unwrap_or(t.arg_name.to_string())
236                }))
237                .join("_");
238            let new_fn_ident = if set.is_empty() {
239                old_fn_name.clone()
240            } else {
241                Ident::new(&new_fn_name, ident.span())
242            };
243
244            let remain_args = {
245                let args_to_remove: HashSet<_> = set.iter().map(|(_, t)| t.idx).collect();
246                inputs
247                    .iter()
248                    .cloned()
249                    .enumerate()
250                    .filter(|(idx, _)| !args_to_remove.contains(idx))
251                    .map(|(_idx, input)| input)
252                    .map(|input| match input {
253                        FnArg::Receiver(_reciver) => quote! { self },
254                        FnArg::Typed(typed) => match *typed.pat {
255                            Pat::Ident(pat_ident) => {
256                                let name = pat_ident.ident;
257                                quote! { #name }
258                            }
259                            _ => panic!("Only support simple pattern"),
260                        },
261                    })
262                    .collect::<Vec<_>>()
263            };
264
265            #[auto_enum(Iterator)]
266            let const_sets = if set.is_empty() {
267                std::iter::once(vec![])
268            } else {
269                Itertools::multi_cartesian_product(set.iter().map(|(idx, target)| {
270                    itertools::izip!(std::iter::repeat(idx), target.attr.consts.inner.iter(),)
271                }))
272            };
273
274            const_sets
275                .map(|const_set| {
276                    let mut match_args = all_target_names
277                        .iter()
278                        .map(|target_name| quote! { #target_name })
279                        .collect::<Vec<_>>();
280                    let mut added_const_args = Vec::with_capacity(const_set.len());
281                    for (idx_in_target, r#const) in const_set {
282                        match_args[*idx_in_target] = quote! { #r#const };
283                        added_const_args.push(quote! { #r#const });
284                    }
285                    let const_args = orig_const_args
286                        .iter()
287                        .map(|ident| quote! { #ident })
288                        .chain(added_const_args.into_iter());
289                    if remain_args.is_empty() {
290                        quote! {
291                            (#(#match_args),*) => {
292                                #new_fn_ident::<#(#const_args),*>()
293                            }
294                        }
295                    } else {
296                        quote! {
297                            (#(#match_args),*) => {
298                                #new_fn_ident::<#(#const_args),*>(#(#remain_args),*,)
299                            }
300                        }
301                    }
302                })
303                .collect::<Vec<_>>()
304        })
305        .collect::<Vec<_>>();
306    branches.reverse();
307
308    let dispatch_fn = {
309        let body: Block = parse_quote! {
310            {
311                match (#(#all_target_names),*) {
312                    #(#branches),*
313                }
314            }
315        };
316        let new_inputs = sig
317            .inputs
318            .iter()
319            .cloned()
320            .map(remove_attr)
321            .collect::<Punctuated<_, Token![,]>>();
322        let new_sig = Signature {
323            inputs: new_inputs,
324            ..sig
325        };
326        let mut new_attrs = item2.attrs.clone();
327        new_attrs.push(parse_quote! { #[inline(always)] });
328        ItemFn {
329            sig: new_sig,
330            block: Box::new(body),
331            attrs: new_attrs,
332            ..item2
333        }
334    };
335
336    Ok(quote! {
337        #dispatch_fn
338        #(#fns)*
339    })
340}