defunctionalize_proc_macro/
lib.rs

1use heck::CamelCase;
2use proc_macro::TokenStream;
3use proc_macro_error::{diagnostic, Diagnostic, Level::Error};
4use quote::{format_ident, quote};
5use syn::{spanned::Spanned, FnArg, Ident, Item, ItemMod, Pat, ReturnType, Visibility};
6
7mod signature;
8mod simple_arg;
9
10use signature::Signature;
11use simple_arg::SimpleArg;
12
13#[proc_macro_attribute]
14#[proc_macro_error::proc_macro_error]
15pub fn defunctionalize(attr: TokenStream, item: TokenStream) -> TokenStream {
16    let mut mod_item = syn::parse_macro_input!(item as ItemMod);
17    let signature = syn::parse_macro_input!(attr as Signature);
18
19    let items = match &mod_item.content {
20        Some((.., items)) => items,
21        None => panic!(),
22    };
23
24    let derive_position = mod_item
25        .attrs
26        .iter()
27        .position(|attr| attr.path.segments[0].ident == "derive");
28    let derives = match derive_position {
29        Some(position) => vec![mod_item.attrs.remove(position)],
30        None => vec![],
31    };
32
33    let mod_name = &mod_item.ident;
34    let enum_name = signature
35        .ident
36        .clone()
37        .unwrap_or_else(|| format_ident!("{}", mod_name.to_string().to_camel_case()));
38
39    let functions = items
40        .iter()
41        .filter_map(|item| match item {
42            Item::Fn(item) => Some(item),
43            _ => None,
44        })
45        .filter(|item| matches!(item.vis, Visibility::Public(..)))
46        .collect::<Vec<_>>();
47
48    let case_names = functions
49        .iter()
50        .map(|item| item.sig.ident.to_string().to_camel_case())
51        .map(|name| format_ident!("{}", name))
52        .collect::<Vec<_>>();
53
54    let function_names = functions
55        .iter()
56        .map(|item| &item.sig.ident)
57        .collect::<Vec<_>>();
58
59    let case_arg_names = functions
60        .iter()
61        .map(|item| {
62            item.sig
63                .inputs
64                .iter()
65                .map(|arg| match arg {
66                    FnArg::Receiver(..) => Err(diagnostic!(
67                        arg.span(),
68                        Error,
69                        "defunctionalized functions cannot have receivers"
70                    )),
71                    FnArg::Typed(pat) => Ok(pat.pat.as_ref()),
72                })
73                .map(|pat| match pat? {
74                    Pat::Ident(ident) => Ok(&ident.ident),
75                    pat => Err(diagnostic!(
76                        pat.span(),
77                        Error,
78                        "arguments to defunctionalized functions must be named"
79                    )),
80                })
81                .collect::<Result<Vec<_>, _>>()
82        })
83        .map(|mut args| {
84            match &mut args {
85                Ok(args) => args.truncate(args.len() - signature.inputs.len()),
86                Err(..) => {}
87            }
88            args
89        })
90        .map(|args| {
91            let args = args?;
92            Ok(if args.is_empty() { vec![] } else { vec![args] })
93        })
94        .collect::<Result<Vec<_>, Diagnostic>>();
95    let case_arg_names = match case_arg_names {
96        Ok(case_arg_names) => case_arg_names,
97        Err(diagnostic) => diagnostic.abort(),
98    };
99
100    let case_arg_types = functions
101        .iter()
102        .map(|item| {
103            item.sig
104                .inputs
105                .iter()
106                .map(|arg| match arg {
107                    FnArg::Receiver(..) => unreachable!(),
108                    FnArg::Typed(pat) => pat.ty.as_ref(),
109                })
110                .collect::<Vec<_>>()
111        })
112        .map(|mut args| {
113            args.truncate(args.len() - signature.inputs.len());
114            args
115        })
116        .map(|args| if args.is_empty() { vec![] } else { vec![args] })
117        .collect::<Vec<_>>();
118
119    let visibility = &mod_item.vis;
120    let generics = &signature.generics;
121    let where_clause = &signature.generics.where_clause;
122    let inputs = &signature.inputs;
123    let input_types = inputs.iter().map(|arg| &arg.ty).collect::<Vec<_>>();
124    let input_names = &signature
125        .inputs
126        .iter()
127        .map(|arg| &arg.ident)
128        .collect::<Vec<&Ident>>();
129    let arg_idents = std::iter::repeat(&input_names);
130    let output = &signature.output;
131    let output_type = match output {
132        ReturnType::Default => quote!(()),
133        ReturnType::Type(.., ty) => quote!(#ty),
134    };
135
136    let output = quote! {
137        #mod_item
138
139        #(#derives)*
140        #visibility enum #enum_name {
141            #(#case_names#((#(#case_arg_types),*))*),*
142        }
143
144        impl #generics defunctionalize::DeFn<(#(#input_types),*)> for #enum_name #where_clause {
145            type Output = #output_type;
146
147            fn call (self, (#(#input_names),*): (#(#input_types),*)) #output {
148                self.call(#(#input_names),*)
149            }
150        }
151
152        impl #enum_name {
153            #visibility fn call #generics (self, #inputs) #output #where_clause {
154                match self {
155                    #(Self::#case_names#((#(#case_arg_names),*))* => {
156                        #mod_name::#function_names(
157                            #(#(#case_arg_names,)*)*
158                            #(#arg_idents),*
159                        )
160                    })*
161                }
162            }
163        }
164    };
165
166    output.into()
167}