Skip to main content

cgp_extra_macro_lib/entrypoints/
cgp_auto_dispatch.rs

1use std::collections::BTreeSet;
2
3use cgp_macro_lib::utils::to_camel_case_str;
4use proc_macro2::{Span, TokenStream};
5use quote::quote;
6use syn::punctuated::Punctuated;
7use syn::spanned::Spanned;
8use syn::token::Comma;
9use syn::{
10    FnArg, GenericParam, Ident, ImplItem, ImplItemFn, ItemTrait, Lifetime, Pat, PatIdent,
11    ReturnType, TraitItemFn, Type, Visibility, parse2,
12};
13
14pub fn cgp_auto_dispatch(_attr: TokenStream, mut out: TokenStream) -> syn::Result<TokenStream> {
15    let item_trait: ItemTrait = parse2(out.clone())?;
16
17    let blanket_impl = derive_blanket_impl(&item_trait)?;
18    out.extend(blanket_impl);
19
20    for item in item_trait.items.iter() {
21        match item {
22            syn::TraitItem::Fn(fn_item) => {
23                let method_computer = derive_method_computer(&item_trait, fn_item)?;
24                out.extend(method_computer);
25            }
26            _ => {
27                return Err(syn::Error::new(
28                    item.span(),
29                    "Only function items are allowed in a dispatch trait",
30                ));
31            }
32        }
33    }
34
35    Ok(out)
36}
37
38fn derive_blanket_impl(item_trait: &ItemTrait) -> syn::Result<TokenStream> {
39    let trait_ident = &item_trait.ident;
40    let context_ident = quote! { __Variants__ };
41
42    let mut generics = item_trait.generics.clone();
43    generics
44        .params
45        .insert(0, parse2(quote! { #context_ident })?);
46
47    let where_clause = generics.make_where_clause();
48
49    let extra_life: Lifetime = parse2(quote! { '__a__ })?;
50
51    let mut impl_items: Vec<ImplItem> = Vec::new();
52
53    for trait_item in item_trait.items.iter() {
54        let method = if let syn::TraitItem::Fn(method) = trait_item {
55            method
56        } else {
57            return Err(syn::Error::new(
58                trait_item.span(),
59                "Only function items are allowed in a dispatch trait",
60            ));
61        };
62
63        let mut signature = method.sig.clone();
64        let method_ident = &signature.ident;
65        let mut hrtbs: BTreeSet<Ident> = BTreeSet::new();
66
67        let computer_ident = Ident::new(
68            &format!("Compute{}", to_camel_case_str(&method_ident.to_string())),
69            method_ident.span(),
70        );
71
72        for generic_param in signature.generics.params.iter() {
73            match generic_param {
74                GenericParam::Lifetime(_) => {}
75                _ => {
76                    return Err(syn::Error::new(
77                        generic_param.span(),
78                        "Dispatch trait methods cannot contain non-lifetime generic parameters due to the lack of quantified constraints in Rust",
79                    ));
80                }
81            }
82        }
83
84        let mut args = signature.inputs.iter_mut();
85
86        let receiver = if let Some(FnArg::Receiver(receiver)) = args.next() {
87            receiver
88        } else {
89            return Err(syn::Error::new(
90                signature.span(),
91                "Dispatcher method must have a self argument",
92            ));
93        };
94
95        let mut arg_idents = Punctuated::<_, Comma>::new();
96        let mut arg_types = Punctuated::<_, Comma>::new();
97
98        for (i, arg) in args.enumerate() {
99            if let FnArg::Typed(pat_type) = arg {
100                let arg_ident = Ident::new(&format!("arg_{}", i), pat_type.span());
101                arg_idents.push(arg_ident.clone());
102                *pat_type.pat = Pat::Ident(PatIdent {
103                    ident: arg_ident,
104                    attrs: Default::default(),
105                    by_ref: Default::default(),
106                    mutability: Default::default(),
107                    subpat: Default::default(),
108                });
109
110                let mut arg_type = pat_type.ty.as_ref().clone();
111                if let Type::Reference(arg_type) = &mut arg_type {
112                    match &arg_type.lifetime {
113                        Some(lifetime) => {
114                            hrtbs.insert(lifetime.ident.clone());
115                        }
116                        None => {
117                            hrtbs.insert(extra_life.ident.clone());
118                            arg_type.lifetime = Some(extra_life.clone());
119                        }
120                    }
121                }
122
123                arg_types.push(arg_type);
124            } else {
125                return Err(syn::Error::new(
126                    arg.span(),
127                    "Dispatcher method arguments must be typed",
128                ));
129            }
130        }
131
132        let output_type = match &signature.output {
133            ReturnType::Default => {
134                quote! { () }
135            }
136            ReturnType::Type(_, output) => {
137                let mut output = output.as_ref().clone();
138                if let Type::Reference(output_type) = &mut output {
139                    match &output_type.lifetime {
140                        Some(lifetime) => {
141                            hrtbs.insert(lifetime.ident.clone());
142                        }
143                        None => {
144                            hrtbs.insert(extra_life.ident.clone());
145                            output_type.lifetime = Some(extra_life.clone());
146                        }
147                    }
148                }
149                quote! { #output }
150            }
151        };
152
153        let (context_type, matcher) = if let Some((_, life)) = &receiver.reference {
154            let life = life.as_ref().unwrap_or_else(|| {
155                hrtbs.insert(extra_life.ident.clone());
156                &extra_life
157            });
158
159            let mutability = &receiver.mutability;
160            let context_type = quote! { & #life #mutability #context_ident };
161            let matcher = if mutability.is_some() {
162                if arg_types.is_empty() {
163                    quote! { MatchWithValueHandlersMut }
164                } else {
165                    quote! { MatchFirstWithValueHandlersMut }
166                }
167            } else if arg_types.is_empty() {
168                quote! { MatchWithValueHandlersRef }
169            } else {
170                quote! { MatchFirstWithValueHandlersRef }
171            };
172
173            (context_type, matcher)
174        } else {
175            let context_type = quote! { #context_ident  };
176            let matcher = if arg_types.is_empty() {
177                quote! { MatchWithValueHandlers }
178            } else {
179                quote! { MatchFirstWithValueHandlers }
180            };
181
182            (context_type, matcher)
183        };
184
185        let mut hrtb = TokenStream::new();
186
187        for ident in hrtbs {
188            if ident != "static" {
189                let lifetime = Lifetime {
190                    apostrophe: Span::call_site(),
191                    ident,
192                };
193                hrtb = quote! { for<#lifetime> }
194            }
195        }
196
197        let input_type = if arg_types.is_empty() {
198            quote! { #context_type }
199        } else {
200            quote! { (#context_type, (#arg_types)) }
201        };
202
203        if signature.asyncness.is_some() {
204            where_clause.predicates.push(parse2(quote! {
205                #matcher<#computer_ident>: #hrtb
206                    AsyncComputer<(), (), #input_type, Output = #output_type>
207            })?);
208        } else {
209            where_clause.predicates.push(parse2(quote! {
210                #matcher<#computer_ident>: #hrtb
211                    Computer<(), (), #input_type, Output = #output_type>
212            })?);
213        }
214
215        let args = if arg_idents.is_empty() {
216            quote! { self }
217        } else {
218            quote! { (self, (#arg_idents)) }
219        };
220
221        let method_body = if signature.asyncness.is_some() {
222            quote! {
223                #matcher::<#computer_ident>::compute_async(
224                    &(),
225                    ::core::marker::PhantomData::<()>,
226                    #args,
227                ).await
228            }
229        } else {
230            quote! {
231                #matcher::<#computer_ident>::compute(
232                    &(),
233                    ::core::marker::PhantomData::<()>,
234                    #args,
235                )
236            }
237        };
238
239        let impl_item = ImplItem::Fn(ImplItemFn {
240            attrs: Default::default(),
241            vis: Visibility::Inherited,
242            defaultness: None,
243            sig: signature,
244            block: parse2(quote! {
245                { #method_body }
246            })?,
247        });
248
249        impl_items.push(impl_item);
250    }
251
252    where_clause.predicates.push(parse2(quote! {
253        #context_ident: HasExtractor
254    })?);
255
256    let ty_generics = item_trait.generics.split_for_impl().1;
257    let (impl_generics, _, where_clause) = generics.split_for_impl();
258
259    let item_impl = quote! {
260        impl #impl_generics #trait_ident #ty_generics for #context_ident
261            #where_clause
262        {
263            #(#impl_items)*
264        }
265    };
266
267    Ok(item_impl)
268}
269
270fn derive_method_computer(
271    item_trait: &ItemTrait,
272    method: &TraitItemFn,
273) -> syn::Result<TokenStream> {
274    let mut signature = method.sig.clone();
275    let method_ident = &signature.ident;
276    let async_token = signature.asyncness;
277
278    let context_ident = quote! { __Variants__ };
279
280    let mut generics = {
281        let mut generics = item_trait.generics.clone();
282
283        generics
284            .params
285            .extend(signature.generics.params.iter().cloned());
286
287        if let Some(method_where_clause) = &signature.generics.where_clause {
288            generics
289                .make_where_clause()
290                .predicates
291                .extend(method_where_clause.predicates.iter().cloned());
292        }
293
294        let trait_ident = &item_trait.ident;
295
296        let type_generics = item_trait.generics.split_for_impl().1;
297
298        generics.params.insert(
299            0,
300            parse2(quote! {
301                #context_ident: #trait_ident #type_generics
302            })?,
303        );
304
305        generics
306    };
307
308    let mut args = signature.inputs.iter_mut();
309
310    let receiver = if let Some(FnArg::Receiver(receiver)) = args.next() {
311        receiver
312    } else {
313        return Err(syn::Error::new(
314            signature.span(),
315            "Dispatcher method must have a self argument",
316        ));
317    };
318
319    let extra_life: Lifetime = parse2(quote! { '__a__ })?;
320    let mut use_extra_life = false;
321
322    let context_type = match (&receiver.reference, &receiver.mutability) {
323        (Some((_, life)), Some(_)) => {
324            let life = life.as_ref().unwrap_or_else(|| {
325                use_extra_life = true;
326                &extra_life
327            });
328
329            quote! { &#life mut #context_ident }
330        }
331        (Some((_, life)), None) => {
332            let life = life.as_ref().unwrap_or_else(|| {
333                use_extra_life = true;
334                &extra_life
335            });
336
337            quote! { & #life #context_ident }
338        }
339        _ => quote! { #context_ident },
340    };
341
342    let mut arg_idents = Punctuated::<_, Comma>::new();
343    let mut arg_types = Punctuated::<_, Comma>::new();
344
345    for (i, arg) in args.enumerate() {
346        if let FnArg::Typed(pat_type) = arg {
347            arg_idents.push(Ident::new(&format!("arg_{}", i), pat_type.span()));
348
349            let arg_type = pat_type.ty.as_mut();
350            if let Type::Reference(arg_type) = arg_type
351                && arg_type.lifetime.is_none()
352            {
353                use_extra_life = true;
354                arg_type.lifetime = Some(extra_life.clone());
355            }
356
357            arg_types.push(arg_type);
358        } else {
359            return Err(syn::Error::new(
360                arg.span(),
361                "Dispatcher method arguments must be typed",
362            ));
363        }
364    }
365
366    let return_type = &mut signature.output;
367
368    if let ReturnType::Type(_, return_type) = return_type
369        && let Type::Reference(return_type) = return_type.as_mut()
370        && return_type.lifetime.is_none()
371    {
372        use_extra_life = true;
373        return_type.lifetime = Some(extra_life.clone());
374    }
375
376    if use_extra_life {
377        generics.params.insert(0, parse2(quote! { #extra_life })?);
378    }
379
380    let arg_params = if arg_idents.is_empty() {
381        TokenStream::new()
382    } else {
383        quote! {
384            (#arg_idents): (#arg_types)
385        }
386    };
387
388    let dot_await = if async_token.is_some() {
389        quote! { .await }
390    } else {
391        TokenStream::new()
392    };
393
394    let computer_ident = Ident::new(
395        &format!("Compute{}", to_camel_case_str(&method_ident.to_string())),
396        method_ident.span(),
397    );
398
399    let method_generics = {
400        let method_generics = method
401            .sig
402            .generics
403            .params
404            .iter()
405            .filter(|param| !matches!(param, syn::GenericParam::Lifetime(_)))
406            .collect::<Punctuated<_, Comma>>();
407
408        if method_generics.is_empty() {
409            TokenStream::new()
410        } else {
411            quote! { ::< #method_generics > }
412        }
413    };
414
415    let (impl_generics, _, where_clause) = generics.split_for_impl();
416
417    Ok(quote! {
418        #[cgp_computer( #computer_ident )]
419        #async_token fn #method_ident #impl_generics (
420            #context_ident: #context_type,
421            #arg_params
422        ) #return_type
423        #where_clause
424        {
425            #context_ident. #method_ident #method_generics ( #arg_idents ) #dot_await
426        }
427    })
428}