1use proc_macro::TokenStream;
2use quote::{format_ident, quote, ToTokens};
3use syn::{
4 fold::{self, Fold},
5 parse_macro_input, parse_quote,
6 punctuated::Punctuated,
7 token::Comma,
8 Expr, FnArg, GenericParam, ItemFn, Macro, ReturnType, Stmt, Type,
9};
10
11#[proc_macro_attribute]
12pub fn composable(_attr: TokenStream, item: TokenStream) -> TokenStream {
13 let item = parse_macro_input!(item as ItemFn);
14
15 let ident = item.sig.ident;
16 let vis = item.vis;
17 let mut generics = Vec::new();
18 for param in &item.sig.generics.params {
19 match param {
20 GenericParam::Type(type_param) => {
21 generics.push(type_param.ident.clone());
22 }
23 _ => todo!(),
24 }
25 }
26
27 let generics_clause = item.sig.generics.params;
28 let where_clause = item.sig.generics.where_clause;
29
30 let output = match item.sig.output {
31 ReturnType::Type(_, ty) => Some(*ty),
32 ReturnType::Default => None,
33 };
34 let output_ty = output.clone().unwrap_or(parse_quote!(()));
35
36 let block = Folder {
37 is_nested: false,
38 is_replaceable: false,
39 pos: 0,
40 }
41 .fold_block(*item.block);
42
43 let mut input_pats = Vec::new();
44 let mut input_types = Vec::new();
45 for input in item.sig.inputs {
46 match input {
47 FnArg::Typed(typed) => {
48 input_pats.push(typed.pat);
49 input_types.push(typed.ty);
50 }
51 _ => todo!(),
52 }
53 }
54
55 let struct_ident = format_ident!("{}_composable", ident);
56 let inputs: Vec<_> = input_pats
57 .iter()
58 .zip(&input_types)
59 .map(|(pat, ty)| quote!(#pat: #ty))
60 .collect();
61
62 let mut struct_fields = inputs.clone();
63
64 let input_generics: Vec<_> = input_types
65 .iter()
66 .filter_map(|ty| {
67 match &**ty {
68 Type::Path(type_path) => {
69 if let Some(ident) = type_path.path.get_ident() {
70 return Some(ident);
71 }
72 }
73 _ => {}
74 }
75
76 None
77 })
78 .collect();
79
80 let mut struct_markers = Vec::new();
81 for (idx, generic) in generics.iter().enumerate() {
82 if !input_generics.contains(&generic) {
83 let ident = format_ident!("_marker{}", idx);
84 struct_fields.push(parse_quote!(#ident: std::marker::PhantomData<#generic>));
85 struct_markers.push(quote!(#ident: std::marker::PhantomData));
86 }
87 }
88
89 let group_id = quote!(std::any::TypeId::of::<#struct_ident::<#(#generics,)*>>());
90 let group = if output.is_some() {
91 quote! {
92 composer.replaceable_group(#group_id, move |composer| #block)
93 }
94 } else {
95 if inputs.is_empty() {
96 quote! {
97 composer.restart_group(#group_id, move |composer| {
98 #block
107 });
108 }
109 } else {
110 let checks = input_pats.iter().enumerate().map(|(idx, input)| {
111 let i: u32 = 0b111 << (idx * 3 + 1);
112 quote! {
113 if changed & #i == 0 {
114 dirty = changed | if composer.changed(&x) { 4 } else { 2 };
115 }
116 }
117 });
118
119 let mut mask = 1u32;
120 let mut value = 0u32;
121 for idx in 0..input_pats.len() {
122 mask |= 0b101 << (idx * 3 + 1);
123 value |= 0b10 << (idx * 3);
124 }
125
126 quote! {
127 composer.restart_group(#group_id, move |composer| {
128 #block
141 });
142 }
143 }
144 };
145
146 let mut constructor_fields = Punctuated::<_, Comma>::new();
147 constructor_fields.extend(input_pats.iter().map(|pat| pat.to_token_stream()));
148 constructor_fields.extend(struct_markers.clone());
149
150 let mut struct_pattern = Punctuated::<_, Comma>::new();
151 struct_pattern.extend(input_pats.iter().map(|pat| pat.to_token_stream()));
152 struct_pattern.push(quote!(..));
153
154 let expanded = quote! {
155 #[must_use]
156 #vis fn #ident <#generics_clause> (#(#inputs),*) -> impl concoct::Composable<Output = #output_ty> #where_clause {
157 #[allow(non_camel_case_types)]
158 struct #struct_ident <#(#generics),*> {
159 #(#struct_fields),*
160 }
161
162 impl<#generics_clause> concoct::Composable<> for #struct_ident <#(#generics),*> #where_clause {
163 type Output = #output_ty;
164
165 fn compose(self, composer: &mut concoct::Composer, changed: u32) -> Self::Output {
166 compose!(());
167
168 let Self { #struct_pattern } = self;
169
170 #group
171 }
172 }
173
174 #struct_ident {
175 #constructor_fields
176 }
177 }
178 };
179
180 TokenStream::from(expanded)
181}
182
/// State carried while folding the body of a composable function.
struct Folder {
    /// Counter appended to the name of each generated `Group{pos}` marker struct,
    /// incremented after every use so the names stay distinct.
    pos: usize,
    /// `true` while the contents of an `if` expression are being folded; blocks
    /// folded in that state are candidates for group wrapping.
    is_nested: bool,
    /// Raised by the statement/expression folders to request that the nested block
    /// currently being folded be wrapped in a replaceable group.
    is_replaceable: bool,
}
188
189impl Fold for Folder {
190 fn fold_stmt(&mut self, mut i: syn::Stmt) -> syn::Stmt {
191 if let Stmt::Macro(stmt_macro) = &i {
192 if let Some(expr) = get_compose_macro(&stmt_macro.mac) {
193 self.is_replaceable = true;
194 i = parse_quote! {
195 (#expr).compose(composer, 0);
196 };
197 }
198 }
199
200 fold::fold_stmt(self, i)
201 }
202
203 fn fold_expr(&mut self, mut i: Expr) -> Expr {
204 match &mut i {
205 Expr::Macro(expr_macro) => {
206 self.is_replaceable = true;
207 if let Some(expr) = get_compose_macro(&expr_macro.mac) {
208 i = parse_quote! {
209 (#expr).compose(composer, 0)
210 };
211 }
212 }
213 Expr::If(expr_if) => {
214 let old = self.is_nested;
215 self.is_nested = true;
216
217 *expr_if = fold::fold_expr_if(self, expr_if.clone());
218 self.is_nested = old;
219 }
220 _ => {}
221 }
222
223 fold::fold_expr(self, i)
224 }
225
226 fn fold_block(&mut self, i: syn::Block) -> syn::Block {
227 if self.is_nested {
228 let old = self.is_replaceable;
229 self.is_replaceable = false;
230
231 let mut block = fold::fold_block(self, i);
232 if self.is_replaceable {
233 let ident = format_ident!("Group{}", self.pos);
234 self.pos += 1;
235
236 block = parse_quote!({
237 struct #ident;
238 composer.replaceable_group(std::any::TypeId::of::<#ident>(), |composer| #block)
239 });
240 }
241
242 self.is_replaceable = old;
243
244 block
245 } else {
246 fold::fold_block(self, i)
247 }
248 }
249}
250
251fn get_compose_macro(mac: &Macro) -> Option<Expr> {
252 if mac.path.get_ident().map(ToString::to_string).as_deref() == Some("compose") {
253 let body = mac.parse_body().unwrap();
254 Some(body)
255 } else {
256 None
257 }
258}