cgp_extra_macro_lib/entrypoints/
cgp_auto_dispatch.rs1use std::collections::BTreeSet;
2
3use cgp_macro_lib::utils::to_camel_case_str;
4use proc_macro2::{Span, TokenStream};
5use quote::quote;
6use syn::punctuated::Punctuated;
7use syn::spanned::Spanned;
8use syn::token::Comma;
9use syn::{
10 FnArg, GenericParam, Ident, ImplItem, ImplItemFn, ItemTrait, Lifetime, Pat, PatIdent,
11 ReturnType, TraitItemFn, Type, Visibility, parse2,
12};
13
14pub fn cgp_auto_dispatch(_attr: TokenStream, mut out: TokenStream) -> syn::Result<TokenStream> {
15 let item_trait: ItemTrait = parse2(out.clone())?;
16
17 let blanket_impl = derive_blanket_impl(&item_trait)?;
18 out.extend(blanket_impl);
19
20 for item in item_trait.items.iter() {
21 match item {
22 syn::TraitItem::Fn(fn_item) => {
23 let method_computer = derive_method_computer(&item_trait, fn_item)?;
24 out.extend(method_computer);
25 }
26 _ => {
27 return Err(syn::Error::new(
28 item.span(),
29 "Only function items are allowed in a dispatch trait",
30 ));
31 }
32 }
33 }
34
35 Ok(out)
36}
37
38fn derive_blanket_impl(item_trait: &ItemTrait) -> syn::Result<TokenStream> {
39 let trait_ident = &item_trait.ident;
40 let context_ident = quote! { __Variants__ };
41
42 let mut generics = item_trait.generics.clone();
43 generics
44 .params
45 .insert(0, parse2(quote! { #context_ident })?);
46
47 let where_clause = generics.make_where_clause();
48
49 let extra_life: Lifetime = parse2(quote! { '__a__ })?;
50
51 let mut impl_items: Vec<ImplItem> = Vec::new();
52
53 for trait_item in item_trait.items.iter() {
54 let method = if let syn::TraitItem::Fn(method) = trait_item {
55 method
56 } else {
57 return Err(syn::Error::new(
58 trait_item.span(),
59 "Only function items are allowed in a dispatch trait",
60 ));
61 };
62
63 let mut signature = method.sig.clone();
64 let method_ident = &signature.ident;
65 let mut hrtbs: BTreeSet<Ident> = BTreeSet::new();
66
67 let computer_ident = Ident::new(
68 &format!("Compute{}", to_camel_case_str(&method_ident.to_string())),
69 method_ident.span(),
70 );
71
72 for generic_param in signature.generics.params.iter() {
73 match generic_param {
74 GenericParam::Lifetime(_) => {}
75 _ => {
76 return Err(syn::Error::new(
77 generic_param.span(),
78 "Dispatch trait methods cannot contain non-lifetime generic parameters due to the lack of quantified constraints in Rust",
79 ));
80 }
81 }
82 }
83
84 let mut args = signature.inputs.iter_mut();
85
86 let receiver = if let Some(FnArg::Receiver(receiver)) = args.next() {
87 receiver
88 } else {
89 return Err(syn::Error::new(
90 signature.span(),
91 "Dispatcher method must have a self argument",
92 ));
93 };
94
95 let mut arg_idents = Punctuated::<_, Comma>::new();
96 let mut arg_types = Punctuated::<_, Comma>::new();
97
98 for (i, arg) in args.enumerate() {
99 if let FnArg::Typed(pat_type) = arg {
100 let arg_ident = Ident::new(&format!("arg_{}", i), pat_type.span());
101 arg_idents.push(arg_ident.clone());
102 *pat_type.pat = Pat::Ident(PatIdent {
103 ident: arg_ident,
104 attrs: Default::default(),
105 by_ref: Default::default(),
106 mutability: Default::default(),
107 subpat: Default::default(),
108 });
109
110 let mut arg_type = pat_type.ty.as_ref().clone();
111 if let Type::Reference(arg_type) = &mut arg_type {
112 match &arg_type.lifetime {
113 Some(lifetime) => {
114 hrtbs.insert(lifetime.ident.clone());
115 }
116 None => {
117 hrtbs.insert(extra_life.ident.clone());
118 arg_type.lifetime = Some(extra_life.clone());
119 }
120 }
121 }
122
123 arg_types.push(arg_type);
124 } else {
125 return Err(syn::Error::new(
126 arg.span(),
127 "Dispatcher method arguments must be typed",
128 ));
129 }
130 }
131
132 let output_type = match &signature.output {
133 ReturnType::Default => {
134 quote! { () }
135 }
136 ReturnType::Type(_, output) => {
137 let mut output = output.as_ref().clone();
138 if let Type::Reference(output_type) = &mut output {
139 match &output_type.lifetime {
140 Some(lifetime) => {
141 hrtbs.insert(lifetime.ident.clone());
142 }
143 None => {
144 hrtbs.insert(extra_life.ident.clone());
145 output_type.lifetime = Some(extra_life.clone());
146 }
147 }
148 }
149 quote! { #output }
150 }
151 };
152
153 let (context_type, matcher) = if let Some((_, life)) = &receiver.reference {
154 let life = life.as_ref().unwrap_or_else(|| {
155 hrtbs.insert(extra_life.ident.clone());
156 &extra_life
157 });
158
159 let mutability = &receiver.mutability;
160 let context_type = quote! { & #life #mutability #context_ident };
161 let matcher = if mutability.is_some() {
162 if arg_types.is_empty() {
163 quote! { MatchWithValueHandlersMut }
164 } else {
165 quote! { MatchFirstWithValueHandlersMut }
166 }
167 } else if arg_types.is_empty() {
168 quote! { MatchWithValueHandlersRef }
169 } else {
170 quote! { MatchFirstWithValueHandlersRef }
171 };
172
173 (context_type, matcher)
174 } else {
175 let context_type = quote! { #context_ident };
176 let matcher = if arg_types.is_empty() {
177 quote! { MatchWithValueHandlers }
178 } else {
179 quote! { MatchFirstWithValueHandlers }
180 };
181
182 (context_type, matcher)
183 };
184
185 let mut hrtb = TokenStream::new();
186
187 for ident in hrtbs {
188 if ident != "static" {
189 let lifetime = Lifetime {
190 apostrophe: Span::call_site(),
191 ident,
192 };
193 hrtb = quote! { for<#lifetime> }
194 }
195 }
196
197 let input_type = if arg_types.is_empty() {
198 quote! { #context_type }
199 } else {
200 quote! { (#context_type, (#arg_types)) }
201 };
202
203 if signature.asyncness.is_some() {
204 where_clause.predicates.push(parse2(quote! {
205 #matcher<#computer_ident>: #hrtb
206 AsyncComputer<(), (), #input_type, Output = #output_type>
207 })?);
208 } else {
209 where_clause.predicates.push(parse2(quote! {
210 #matcher<#computer_ident>: #hrtb
211 Computer<(), (), #input_type, Output = #output_type>
212 })?);
213 }
214
215 let args = if arg_idents.is_empty() {
216 quote! { self }
217 } else {
218 quote! { (self, (#arg_idents)) }
219 };
220
221 let method_body = if signature.asyncness.is_some() {
222 quote! {
223 #matcher::<#computer_ident>::compute_async(
224 &(),
225 ::core::marker::PhantomData::<()>,
226 #args,
227 ).await
228 }
229 } else {
230 quote! {
231 #matcher::<#computer_ident>::compute(
232 &(),
233 ::core::marker::PhantomData::<()>,
234 #args,
235 )
236 }
237 };
238
239 let impl_item = ImplItem::Fn(ImplItemFn {
240 attrs: Default::default(),
241 vis: Visibility::Inherited,
242 defaultness: None,
243 sig: signature,
244 block: parse2(quote! {
245 { #method_body }
246 })?,
247 });
248
249 impl_items.push(impl_item);
250 }
251
252 where_clause.predicates.push(parse2(quote! {
253 #context_ident: HasExtractor
254 })?);
255
256 let ty_generics = item_trait.generics.split_for_impl().1;
257 let (impl_generics, _, where_clause) = generics.split_for_impl();
258
259 let item_impl = quote! {
260 impl #impl_generics #trait_ident #ty_generics for #context_ident
261 #where_clause
262 {
263 #(#impl_items)*
264 }
265 };
266
267 Ok(item_impl)
268}
269
270fn derive_method_computer(
271 item_trait: &ItemTrait,
272 method: &TraitItemFn,
273) -> syn::Result<TokenStream> {
274 let mut signature = method.sig.clone();
275 let method_ident = &signature.ident;
276 let async_token = signature.asyncness;
277
278 let context_ident = quote! { __Variants__ };
279
280 let mut generics = {
281 let mut generics = item_trait.generics.clone();
282
283 generics
284 .params
285 .extend(signature.generics.params.iter().cloned());
286
287 if let Some(method_where_clause) = &signature.generics.where_clause {
288 generics
289 .make_where_clause()
290 .predicates
291 .extend(method_where_clause.predicates.iter().cloned());
292 }
293
294 let trait_ident = &item_trait.ident;
295
296 let type_generics = item_trait.generics.split_for_impl().1;
297
298 generics.params.insert(
299 0,
300 parse2(quote! {
301 #context_ident: #trait_ident #type_generics
302 })?,
303 );
304
305 generics
306 };
307
308 let mut args = signature.inputs.iter_mut();
309
310 let receiver = if let Some(FnArg::Receiver(receiver)) = args.next() {
311 receiver
312 } else {
313 return Err(syn::Error::new(
314 signature.span(),
315 "Dispatcher method must have a self argument",
316 ));
317 };
318
319 let extra_life: Lifetime = parse2(quote! { '__a__ })?;
320 let mut use_extra_life = false;
321
322 let context_type = match (&receiver.reference, &receiver.mutability) {
323 (Some((_, life)), Some(_)) => {
324 let life = life.as_ref().unwrap_or_else(|| {
325 use_extra_life = true;
326 &extra_life
327 });
328
329 quote! { &#life mut #context_ident }
330 }
331 (Some((_, life)), None) => {
332 let life = life.as_ref().unwrap_or_else(|| {
333 use_extra_life = true;
334 &extra_life
335 });
336
337 quote! { & #life #context_ident }
338 }
339 _ => quote! { #context_ident },
340 };
341
342 let mut arg_idents = Punctuated::<_, Comma>::new();
343 let mut arg_types = Punctuated::<_, Comma>::new();
344
345 for (i, arg) in args.enumerate() {
346 if let FnArg::Typed(pat_type) = arg {
347 arg_idents.push(Ident::new(&format!("arg_{}", i), pat_type.span()));
348
349 let arg_type = pat_type.ty.as_mut();
350 if let Type::Reference(arg_type) = arg_type
351 && arg_type.lifetime.is_none()
352 {
353 use_extra_life = true;
354 arg_type.lifetime = Some(extra_life.clone());
355 }
356
357 arg_types.push(arg_type);
358 } else {
359 return Err(syn::Error::new(
360 arg.span(),
361 "Dispatcher method arguments must be typed",
362 ));
363 }
364 }
365
366 let return_type = &mut signature.output;
367
368 if let ReturnType::Type(_, return_type) = return_type
369 && let Type::Reference(return_type) = return_type.as_mut()
370 && return_type.lifetime.is_none()
371 {
372 use_extra_life = true;
373 return_type.lifetime = Some(extra_life.clone());
374 }
375
376 if use_extra_life {
377 generics.params.insert(0, parse2(quote! { #extra_life })?);
378 }
379
380 let arg_params = if arg_idents.is_empty() {
381 TokenStream::new()
382 } else {
383 quote! {
384 (#arg_idents): (#arg_types)
385 }
386 };
387
388 let dot_await = if async_token.is_some() {
389 quote! { .await }
390 } else {
391 TokenStream::new()
392 };
393
394 let computer_ident = Ident::new(
395 &format!("Compute{}", to_camel_case_str(&method_ident.to_string())),
396 method_ident.span(),
397 );
398
399 let method_generics = {
400 let method_generics = method
401 .sig
402 .generics
403 .params
404 .iter()
405 .filter(|param| !matches!(param, syn::GenericParam::Lifetime(_)))
406 .collect::<Punctuated<_, Comma>>();
407
408 if method_generics.is_empty() {
409 TokenStream::new()
410 } else {
411 quote! { ::< #method_generics > }
412 }
413 };
414
415 let (impl_generics, _, where_clause) = generics.split_for_impl();
416
417 Ok(quote! {
418 #[cgp_computer( #computer_ident )]
419 #async_token fn #method_ident #impl_generics (
420 #context_ident: #context_type,
421 #arg_params
422 ) #return_type
423 #where_clause
424 {
425 #context_ident. #method_ident #method_generics ( #arg_idents ) #dot_await
426 }
427 })
428}