cgp_extra_macro_lib/entrypoints/cgp_computer.rs

1use cgp_macro_lib::utils::to_camel_case_str;
2use proc_macro2::TokenStream;
3use quote::{ToTokens, quote};
4use syn::punctuated::Punctuated;
5use syn::spanned::Spanned;
6use syn::token::Comma;
7use syn::{FnArg, Ident, ItemFn, ItemImpl, ReturnType, Type, parse2};
8
9use crate::parse::MaybeResultType;
10
11pub fn cgp_computer(attr: TokenStream, body: TokenStream) -> syn::Result<TokenStream> {
12    let item_fn: ItemFn = parse2(body)?;
13
14    let fn_sig = &item_fn.sig;
15    let fn_ident = &fn_sig.ident;
16    let fn_inputs = &fn_sig.inputs;
17
18    let computer_ident = if attr.is_empty() {
19        Ident::new(&to_camel_case_str(&fn_ident.to_string()), fn_ident.span())
20    } else {
21        parse2(attr)?
22    };
23
24    let mut input_types = Punctuated::<Type, Comma>::new();
25    let mut input_idents = Punctuated::<Ident, Comma>::new();
26
27    for (i, input) in fn_inputs.iter().enumerate() {
28        match input {
29            FnArg::Receiver(_) => {
30                return Err(syn::Error::new(
31                    input.span(),
32                    "Computer functions cannot have a receiver",
33                ));
34            }
35            FnArg::Typed(pat) => {
36                input_types.push(pat.ty.as_ref().clone());
37                input_idents.push(Ident::new(&format!("arg_{i}"), pat.span()));
38            }
39        }
40    }
41
42    let fn_output = match &fn_sig.output {
43        ReturnType::Type(_, ty) => ty.as_ref().clone(),
44        ReturnType::Default => syn::parse_quote!(()),
45    };
46
47    let maybe_result_type = parse2::<MaybeResultType>(fn_output.to_token_stream())?;
48
49    if fn_sig.asyncness.is_none() {
50        let mut generics = fn_sig.generics.clone();
51        generics.params.push(parse2(quote! { __Context__ })?);
52        generics.params.push(parse2(quote! { __Code__ })?);
53
54        let (impl_generics, _, where_clause) = generics.split_for_impl();
55
56        let computer: ItemImpl = parse2(quote! {
57            #[cgp_new_provider]
58            impl #impl_generics
59                Computer<__Context__, __Code__, ( #input_types )>
60                for #computer_ident
61            #where_clause
62            {
63                type Output = #fn_output;
64
65                fn compute(_context: &__Context__, _code: PhantomData<__Code__>, ( #input_idents ): ( #input_types )) -> Self::Output {
66                    #fn_ident( #input_idents )
67                }
68            }
69        })?;
70
71        let delegate = if maybe_result_type.error_type.is_some() {
72            quote! {
73                delegate_components! {
74                    #computer_ident {
75                        [
76                            ComputerRefComponent,
77                            TryComputerComponent,
78                            TryComputerRefComponent,
79                            AsyncComputerComponent,
80                            AsyncComputerRefComponent,
81                            HandlerComponent,
82                            HandlerRefComponent,
83                        ] ->
84                            PromoteTryComputer<Self>,
85                    }
86                }
87            }
88        } else {
89            quote! {
90                delegate_components! {
91                    #computer_ident {
92                        [
93                            ComputerRefComponent,
94                            TryComputerComponent,
95                            TryComputerRefComponent,
96                            AsyncComputerComponent,
97                            AsyncComputerRefComponent,
98                            HandlerComponent,
99                            HandlerRefComponent,
100                        ] ->
101                            PromoteComputer<Self>,
102                    }
103                }
104            }
105        };
106
107        Ok(quote! {
108            #item_fn
109
110            #computer
111
112            #delegate
113        })
114    } else {
115        let mut generics = fn_sig.generics.clone();
116        generics.params.push(parse2(quote! { __Context__ })?);
117        generics.params.push(parse2(quote! { __Code__ })?);
118
119        let (impl_generics, _, where_clause) = generics.split_for_impl();
120
121        let computer: ItemImpl = parse2(quote! {
122            #[cgp_new_provider]
123            impl #impl_generics
124                AsyncComputer<__Context__, __Code__, ( #input_types )>
125                for #computer_ident
126            #where_clause
127            {
128                type Output = #fn_output;
129
130                async fn compute_async(
131                    _context: &__Context__,
132                    _code: PhantomData<__Code__>,
133                    ( #input_idents ): ( #input_types )
134                ) -> Self::Output {
135                    #fn_ident( #input_idents ).await
136                }
137            }
138        })?;
139
140        let delegate_ref = if maybe_result_type.error_type.is_some() {
141            quote! {
142                delegate_components! {
143                    #computer_ident {
144                        [
145                            AsyncComputerRefComponent,
146                            HandlerComponent,
147                            HandlerRefComponent,
148                        ] ->
149                            PromoteHandler<Self>,
150                    }
151                }
152            }
153        } else {
154            quote! {
155                delegate_components! {
156                    #computer_ident {
157                        [
158                            AsyncComputerRefComponent,
159                            HandlerComponent,
160                            HandlerRefComponent,
161                        ] ->
162                            PromoteAsyncComputer<Self>,
163                    }
164                }
165            }
166        };
167
168        Ok(quote! {
169            #item_fn
170
171            #computer
172
173            #delegate_ref
174        })
175    }
176}