defunctionalize_proc_macro/
lib.rs1use heck::CamelCase;
2use proc_macro::TokenStream;
3use proc_macro_error::{diagnostic, Diagnostic, Level::Error};
4use quote::{format_ident, quote};
5use syn::{spanned::Spanned, FnArg, Ident, Item, ItemMod, Pat, ReturnType, Visibility};
6
7mod signature;
8mod simple_arg;
9
10use signature::Signature;
11use simple_arg::SimpleArg;
12
13#[proc_macro_attribute]
14#[proc_macro_error::proc_macro_error]
15pub fn defunctionalize(attr: TokenStream, item: TokenStream) -> TokenStream {
16 let mut mod_item = syn::parse_macro_input!(item as ItemMod);
17 let signature = syn::parse_macro_input!(attr as Signature);
18
19 let items = match &mod_item.content {
20 Some((.., items)) => items,
21 None => panic!(),
22 };
23
24 let derive_position = mod_item
25 .attrs
26 .iter()
27 .position(|attr| attr.path.segments[0].ident == "derive");
28 let derives = match derive_position {
29 Some(position) => vec![mod_item.attrs.remove(position)],
30 None => vec![],
31 };
32
33 let mod_name = &mod_item.ident;
34 let enum_name = signature
35 .ident
36 .clone()
37 .unwrap_or_else(|| format_ident!("{}", mod_name.to_string().to_camel_case()));
38
39 let functions = items
40 .iter()
41 .filter_map(|item| match item {
42 Item::Fn(item) => Some(item),
43 _ => None,
44 })
45 .filter(|item| matches!(item.vis, Visibility::Public(..)))
46 .collect::<Vec<_>>();
47
48 let case_names = functions
49 .iter()
50 .map(|item| item.sig.ident.to_string().to_camel_case())
51 .map(|name| format_ident!("{}", name))
52 .collect::<Vec<_>>();
53
54 let function_names = functions
55 .iter()
56 .map(|item| &item.sig.ident)
57 .collect::<Vec<_>>();
58
59 let case_arg_names = functions
60 .iter()
61 .map(|item| {
62 item.sig
63 .inputs
64 .iter()
65 .map(|arg| match arg {
66 FnArg::Receiver(..) => Err(diagnostic!(
67 arg.span(),
68 Error,
69 "defunctionalized functions cannot have receivers"
70 )),
71 FnArg::Typed(pat) => Ok(pat.pat.as_ref()),
72 })
73 .map(|pat| match pat? {
74 Pat::Ident(ident) => Ok(&ident.ident),
75 pat => Err(diagnostic!(
76 pat.span(),
77 Error,
78 "arguments to defunctionalized functions must be named"
79 )),
80 })
81 .collect::<Result<Vec<_>, _>>()
82 })
83 .map(|mut args| {
84 match &mut args {
85 Ok(args) => args.truncate(args.len() - signature.inputs.len()),
86 Err(..) => {}
87 }
88 args
89 })
90 .map(|args| {
91 let args = args?;
92 Ok(if args.is_empty() { vec![] } else { vec![args] })
93 })
94 .collect::<Result<Vec<_>, Diagnostic>>();
95 let case_arg_names = match case_arg_names {
96 Ok(case_arg_names) => case_arg_names,
97 Err(diagnostic) => diagnostic.abort(),
98 };
99
100 let case_arg_types = functions
101 .iter()
102 .map(|item| {
103 item.sig
104 .inputs
105 .iter()
106 .map(|arg| match arg {
107 FnArg::Receiver(..) => unreachable!(),
108 FnArg::Typed(pat) => pat.ty.as_ref(),
109 })
110 .collect::<Vec<_>>()
111 })
112 .map(|mut args| {
113 args.truncate(args.len() - signature.inputs.len());
114 args
115 })
116 .map(|args| if args.is_empty() { vec![] } else { vec![args] })
117 .collect::<Vec<_>>();
118
119 let visibility = &mod_item.vis;
120 let generics = &signature.generics;
121 let where_clause = &signature.generics.where_clause;
122 let inputs = &signature.inputs;
123 let input_types = inputs.iter().map(|arg| &arg.ty).collect::<Vec<_>>();
124 let input_names = &signature
125 .inputs
126 .iter()
127 .map(|arg| &arg.ident)
128 .collect::<Vec<&Ident>>();
129 let arg_idents = std::iter::repeat(&input_names);
130 let output = &signature.output;
131 let output_type = match output {
132 ReturnType::Default => quote!(()),
133 ReturnType::Type(.., ty) => quote!(#ty),
134 };
135
136 let output = quote! {
137 #mod_item
138
139 #(#derives)*
140 #visibility enum #enum_name {
141 #(#case_names#((#(#case_arg_types),*))*),*
142 }
143
144 impl #generics defunctionalize::DeFn<(#(#input_types),*)> for #enum_name #where_clause {
145 type Output = #output_type;
146
147 fn call (self, (#(#input_names),*): (#(#input_types),*)) #output {
148 self.call(#(#input_names),*)
149 }
150 }
151
152 impl #enum_name {
153 #visibility fn call #generics (self, #inputs) #output #where_clause {
154 match self {
155 #(Self::#case_names#((#(#case_arg_names),*))* => {
156 #mod_name::#function_names(
157 #(#(#case_arg_names,)*)*
158 #(#arg_idents),*
159 )
160 })*
161 }
162 }
163 }
164 };
165
166 output.into()
167}