1#![doc = include_str!("../README.md")]
2
3use std::collections::HashSet;
4
5use auto_enums::auto_enum;
6use darling::{FromAttributes, FromMeta};
7use itertools::Itertools;
8use proc_macro2::TokenStream;
9use quote::{format_ident, quote};
10use syn::{
11 parse_macro_input, parse_quote, punctuated::Punctuated, Attribute, Block, ConstParam, Expr,
12 FnArg, GenericParam, Generics, Ident, ItemFn, Pat, PatIdent, PatType, Result, Signature, Token,
13 Type,
14};
15
16#[proc_macro_attribute]
17pub fn const_currying(
18 attr: proc_macro::TokenStream,
19 item: proc_macro::TokenStream,
20) -> proc_macro::TokenStream {
21 let input = parse_macro_input!(item as ItemFn);
22 match inner(attr.into(), input) {
23 Ok(output) => output.into(),
24 Err(err) => err.to_compile_error().into(),
25 }
26}
27
28#[derive(Debug, Clone, darling::FromAttributes)]
29#[darling(attributes(maybe_const))]
30struct FieldAttr {
31 #[darling(default)]
32 dispatch: Option<Ident>,
33 #[darling(default)]
34 consts: ConstsArray,
35}
36
37#[derive(Debug, Clone, Default)]
38struct ConstsArray {
39 inner: Punctuated<Expr, Token![,]>,
40}
41
42impl FromMeta for ConstsArray {
43 fn from_expr(expr: &Expr) -> darling::Result<Self> {
44 if let Expr::Array(array) = expr {
45 Ok(Self {
46 inner: array.elems.clone(),
47 })
48 } else {
49 Err(darling::Error::unexpected_expr_type(expr))
50 }
51 }
52}
53
54#[derive(Clone, Debug)]
55struct GenTarget {
56 attr: FieldAttr,
57 idx: usize,
58 arg_name: Ident,
59 input: PatType,
60 ty: Type,
61}
62
63fn remove_attr(arg: FnArg) -> FnArg {
64 match arg {
65 FnArg::Typed(mut typed) => {
66 typed.attrs.clear();
67 FnArg::Typed(typed)
68 }
69 FnArg::Receiver(receiver) => FnArg::Receiver(receiver),
70 }
71}
72
73fn contains_attr(attrs: &[Attribute]) -> bool {
74 attrs.iter().any(|attr| attr.path().is_ident("maybe_const"))
75}
76
77#[auto_enum]
78fn inner(_attr: TokenStream, item: ItemFn) -> Result<TokenStream> {
79 let item2 = item.clone();
80 let ItemFn { sig, .. } = item;
81
82 let Signature {
83 ident,
84 inputs,
85 generics,
86 ..
87 } = &sig;
88
89 let targets = inputs
90 .iter()
91 .enumerate()
92 .filter_map(|(idx, input)| match input {
93 FnArg::Receiver(..) => None,
94 FnArg::Typed(typed) => {
95 let PatType { attrs, ty, pat, .. } = typed;
96 let Pat::Ident(PatIdent {
97 ident: arg_name, ..
98 }) = &**pat
99 else {
100 return None;
101 };
102 if !contains_attr(attrs) {
103 return None;
104 }
105 let attr = FieldAttr::from_attributes(attrs).ok()?;
106 Some(GenTarget {
107 attr,
108 idx,
109 arg_name: arg_name.clone(),
110 input: typed.clone(),
111 ty: *ty.clone(),
112 })
113 }
114 })
115 .collect::<Vec<_>>();
116
117 let old_fn_name = format_ident!("{ident}_orig");
118
119 let orig_const_args: Vec<_> = generics
120 .const_params()
121 .map(|param| param.ident.clone())
122 .collect();
123
124 let fns = targets
125 .iter()
126 .cloned()
127 .powerset()
128 .zip(std::iter::from_fn(|| {
129 let item = item2.clone();
130 Some(item)
131 }))
132 .map(|(set, item)| {
133 let ItemFn { sig, .. } = item.clone();
134 let Signature {
135 ident,
136 inputs,
137 generics,
138 ..
139 } = &sig;
140 let new_fn_name = [ident.to_string()]
141 .into_iter()
142 .chain(set.iter().map(|t| {
143 t.attr
144 .dispatch
145 .as_ref()
146 .map(ToString::to_string)
147 .unwrap_or(t.arg_name.to_string())
148 }))
149 .join("_");
150 let new_fn_ident = if set.is_empty() {
151 old_fn_name.clone()
152 } else {
153 Ident::new(&new_fn_name, ident.span())
154 };
155
156 let added_generic_params = set
157 .iter()
158 .map(|t: &GenTarget| {
159 let GenTarget {
160 attr: _,
161 idx: _,
162 arg_name,
163 input,
164 ty,
165 } = t;
166 ConstParam {
167 attrs: vec![],
168 const_token: Token),
169 ident: arg_name.clone(),
170 colon_token: input.colon_token,
171 ty: ty.clone(),
172 default: None,
173 eq_token: None,
174 }
175 })
176 .map(GenericParam::Const);
177
178 let mut old_generics_pararms = generics.params.clone();
179 for new_param in added_generic_params {
180 old_generics_pararms.push(new_param);
181 }
182 let new_generics = Generics {
183 params: old_generics_pararms,
184 ..generics.clone()
185 };
186 let new_inputs = {
187 let args_to_remove: HashSet<_> = set.iter().map(|t| t.idx).collect();
188 inputs
189 .iter()
190 .cloned()
191 .enumerate()
192 .filter(|(idx, _)| !args_to_remove.contains(idx))
193 .map(|(_idx, input)| input)
194 .map(remove_attr)
195 .collect::<Punctuated<_, Token![,]>>()
196 };
197 let sig = sig.clone();
198 let new_sig = Signature {
199 ident: new_fn_ident,
200 inputs: new_inputs,
201 generics: new_generics,
202 ..sig
203 };
204 let item = item.clone();
205 let mut new_attrs = item.attrs.clone();
206 let new_attr: Attribute = parse_quote!(#[allow(warnings)]);
207 new_attrs.push(new_attr);
208 ItemFn {
209 sig: new_sig,
210 attrs: new_attrs,
211 ..item
212 }
213 })
214 .collect::<Vec<_>>();
215
216 let all_target_names = targets
218 .iter()
219 .map(|target| target.arg_name.clone())
220 .collect::<Vec<_>>();
221
222 let mut branches = targets
223 .iter()
224 .cloned()
225 .enumerate()
226 .powerset()
227 .flat_map(|set| {
228 let new_fn_name = [ident.to_string()]
229 .into_iter()
230 .chain(set.iter().map(|(_, t)| {
231 t.attr
232 .dispatch
233 .as_ref()
234 .map(ToString::to_string)
235 .unwrap_or(t.arg_name.to_string())
236 }))
237 .join("_");
238 let new_fn_ident = if set.is_empty() {
239 old_fn_name.clone()
240 } else {
241 Ident::new(&new_fn_name, ident.span())
242 };
243
244 let remain_args = {
245 let args_to_remove: HashSet<_> = set.iter().map(|(_, t)| t.idx).collect();
246 inputs
247 .iter()
248 .cloned()
249 .enumerate()
250 .filter(|(idx, _)| !args_to_remove.contains(idx))
251 .map(|(_idx, input)| input)
252 .map(|input| match input {
253 FnArg::Receiver(_reciver) => quote! { self },
254 FnArg::Typed(typed) => match *typed.pat {
255 Pat::Ident(pat_ident) => {
256 let name = pat_ident.ident;
257 quote! { #name }
258 }
259 _ => panic!("Only support simple pattern"),
260 },
261 })
262 .collect::<Vec<_>>()
263 };
264
265 #[auto_enum(Iterator)]
266 let const_sets = if set.is_empty() {
267 std::iter::once(vec![])
268 } else {
269 Itertools::multi_cartesian_product(set.iter().map(|(idx, target)| {
270 itertools::izip!(std::iter::repeat(idx), target.attr.consts.inner.iter(),)
271 }))
272 };
273
274 const_sets
275 .map(|const_set| {
276 let mut match_args = all_target_names
277 .iter()
278 .map(|target_name| quote! { #target_name })
279 .collect::<Vec<_>>();
280 let mut added_const_args = Vec::with_capacity(const_set.len());
281 for (idx_in_target, r#const) in const_set {
282 match_args[*idx_in_target] = quote! { #r#const };
283 added_const_args.push(quote! { #r#const });
284 }
285 let const_args = orig_const_args
286 .iter()
287 .map(|ident| quote! { #ident })
288 .chain(added_const_args.into_iter());
289 if remain_args.is_empty() {
290 quote! {
291 (#(#match_args),*) => {
292 #new_fn_ident::<#(#const_args),*>()
293 }
294 }
295 } else {
296 quote! {
297 (#(#match_args),*) => {
298 #new_fn_ident::<#(#const_args),*>(#(#remain_args),*,)
299 }
300 }
301 }
302 })
303 .collect::<Vec<_>>()
304 })
305 .collect::<Vec<_>>();
306 branches.reverse();
307
308 let dispatch_fn = {
309 let body: Block = parse_quote! {
310 {
311 match (#(#all_target_names),*) {
312 #(#branches),*
313 }
314 }
315 };
316 let new_inputs = sig
317 .inputs
318 .iter()
319 .cloned()
320 .map(remove_attr)
321 .collect::<Punctuated<_, Token![,]>>();
322 let new_sig = Signature {
323 inputs: new_inputs,
324 ..sig
325 };
326 let mut new_attrs = item2.attrs.clone();
327 new_attrs.push(parse_quote! { #[inline(always)] });
328 ItemFn {
329 sig: new_sig,
330 block: Box::new(body),
331 attrs: new_attrs,
332 ..item2
333 }
334 };
335
336 Ok(quote! {
337 #dispatch_fn
338 #(#fns)*
339 })
340}