1use proc_macro::TokenStream;
2
3use heck::{ToPascalCase, ToSnakeCase};
4use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
5use quote::{ToTokens, quote};
6use syn::{
7    AngleBracketedGenericArguments, AssocType, ExprAssign, FnArg, GenericArgument, ImplItem,
8    ItemImpl, Path, PathArguments, PathSegment, ReturnType, TraitBound, Type, TypeImplTrait,
9    TypeParamBound, TypePath,
10    parse::{Parse, ParseStream},
11    parse_macro_input, parse2,
12    punctuated::Punctuated,
13    token::Comma,
14};
15
16struct Meta {
17    server: bool,
18    client: bool,
19    public: TokenStream2,
20    services: Vec<(TokenStream2, TokenStream2)>,
21}
22
23impl Parse for Meta {
24    fn parse(input: ParseStream) -> syn::Result<Self> {
25        let items = Punctuated::<ExprAssign, Comma>::parse_terminated(input).unwrap();
26        let mut server = false;
27        let mut client = false;
28        let mut public = quote!();
29        let services = items
30            .iter()
31            .filter_map(|i| {
32                let j = i.left.to_token_stream();
33                let k = i.right.to_token_stream();
34                if j.to_string() == "server" {
35                    server = k.to_string() == "true";
36                    None
37                } else if j.to_string() == "client" {
38                    client = k.to_string() == "true";
39                    None
40                } else if j.to_string() == "public" {
41                    public = if k.to_string() == "true" {
42                        quote! {pub}
43                    } else {
44                        quote! {pub(#k)}
45                    };
46                    None
47                } else {
48                    Some((j, k))
49                }
50            })
51            .collect();
52        Ok(Meta {
53            server,
54            client,
55            public,
56            services,
57        })
58    }
59}
60
61fn unwrap_stream_item_type(ty: &Type) -> Option<Type> {
62    match ty {
63        Type::ImplTrait(TypeImplTrait { bounds, .. }) => match bounds.first() {
64            Some(TypeParamBound::Trait(TraitBound { path, .. })) => match path.segments.last() {
65                Some(PathSegment {
66                    arguments: PathArguments::AngleBracketed(path),
67                    ..
68                }) => match path.args.first() {
69                    Some(GenericArgument::AssocType(AssocType { ty, .. })) => Some(ty.clone()),
70                    _ => None,
71                },
72                _ => None,
73            },
74            _ => panic!("Only support impl Stream."),
75        },
76        Type::Path(TypePath {
77            path: Path { segments, .. },
78            ..
79        }) => match segments.last() {
80            Some(PathSegment {
81                ident,
82                arguments:
83                    PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }),
84                ..
85            }) if ident == "Result" => match args.first() {
86                Some(GenericArgument::Type(ty)) => Some(ty.clone()),
87                _ => None,
88            },
89            _ => None,
90        },
91        _ => None,
92    }
93}
94
95#[proc_macro_attribute]
123pub fn service(attrs: TokenStream, input: TokenStream) -> TokenStream {
124    let meta: Meta = parse2(Into::<TokenStream2>::into(attrs)).unwrap();
125    let item = parse_macro_input!(input as ItemImpl);
126    let service_name = item.self_ty.as_ref().clone();
127    let service_name_str = service_name.to_token_stream().to_string();
128    let public = meta.public;
129    let (request_name, response_name) = {
130        let name = service_name.to_token_stream().to_string();
131        (
132            Ident::new(&(name.clone() + "Request"), Span::call_site()),
133            Ident::new(&(name + "Response"), Span::call_site()),
134        )
135    };
136    let items = item
137        .items
138        .iter()
139        .filter_map(|i| match i {
140            ImplItem::Fn(f) => Some((f.sig.clone(), f.attrs.clone())),
141            _ => None,
142        })
143        .collect::<Vec<_>>();
144    let func_items = items
145        .iter()
146        .map(|(func, attrs)| {
147            if func.asyncness.is_none() {
148                panic!("Function `{}` must be asyncable.", func.ident);
149            }
150            let self_ = func.inputs.iter().find(|i| match i {
151                FnArg::Receiver(_) => true,
152                FnArg::Typed(_) => false,
153            });
154            if self_.is_none() {
155                panic!("Function `{}` must contain `self` argument.", func.ident);
156            }
157
158            let mut client_stream_item: Option<Type> = None;
159            let args = func
160                .inputs
161                .iter()
162                .filter_map(|i| match i {
163                    FnArg::Receiver(..) => None,
164                    FnArg::Typed(ty) => match unwrap_stream_item_type(ty.ty.as_ref()) {
165                        None => Some((ty.pat.as_ref().clone(), ty.ty.as_ref().clone())),
166                        Some(ty) => {
167                            client_stream_item.replace(ty);
168                            None
169                        }
170                    },
171                })
172                .collect::<Vec<_>>();
173            let arg_names = args.iter().map(|i| i.0.clone()).collect::<Vec<_>>();
174            let arg_types = args.iter().map(|i| i.1.clone()).collect::<Vec<_>>();
175
176            let (server_stream_item, ret) = match func.output {
177                ReturnType::Default => (None, None),
178                ReturnType::Type(_, ref ty) => unwrap_stream_item_type(ty.as_ref())
179                    .map_or((None, Some(ty.as_ref().clone())), |t| {
180                        (Some(t), Some(ty.as_ref().clone()))
181                    }),
182            };
183
184            (
185                attrs,
186                func.ident.clone(),
187                Ident::new(&func.ident.to_string().to_pascal_case(), Span::call_site()),
188                arg_names,
189                arg_types,
190                ret,
191                client_stream_item,
192                server_stream_item,
193            )
194        })
195        .collect::<Vec<_>>();
196
197    let mut request_enum_variants = func_items
198        .iter()
199        .map(|(_, _, name, _, _, _, _, _)| {
200            let name2 = Ident::new(&(name.to_string() + "Request"), Span::call_site());
201            quote! {#name(#name2)}
202        })
203        .collect::<Vec<_>>();
204    request_enum_variants.extend(
205        func_items
206            .iter()
207            .filter_map(|(_, _, name, _, _, _, client_stream_item, _)| {
208                if client_stream_item.is_some() {
209                    let name2 = Ident::new(&(name.to_string() + "Put"), Span::call_site());
210                    return Some(quote! {#name2(#client_stream_item)});
211                }
212                None
213            })
214            .collect::<Vec<_>>(),
215    );
216    request_enum_variants.extend(
217        meta.services
218            .iter()
219            .map(|(subname, _)| {
220                let name = Ident::new(&(subname.to_string() + "Request"), Span::call_site());
221                quote! {#subname(#name)}
222            })
223            .collect::<Vec<_>>(),
224    );
225
226    let mut response_enum_variants = func_items
227        .iter()
228        .map(|(_, _, name, _, _, _, _, _)| {
229            let name2 = Ident::new(&(name.to_string() + "Response"), Span::call_site());
230            quote! {#name(#name2)}
231        })
232        .collect::<Vec<_>>();
233    response_enum_variants.extend(
234        meta.services
235            .iter()
236            .map(|(subname, _)| {
237                let name = Ident::new(&(subname.to_string() + "Response"), Span::call_site());
238                quote! {#name(#name)}
239            })
240            .collect::<Vec<_>>(),
241    );
242
243    let server = if meta.server {
244        let child_request_patterns = meta
245            .services
246            .iter()
247            .map(|(subname, field)| {
248                let handler = if field.to_string() == "None" {
249                    quote!{quic_rpc_utils::GetServiceHandler::<#subname>::get_handler(self)}
250                } else {
251                    quote!{self.#field.clone()}
252                };
253
254                quote! {
255                    #request_name::#subname(req) => #handler.handle_rpc_request(req, chan.map().boxed(), rt).await?
256                }
257            })
258            .collect::<Vec<_>>();
259
260        let request_match_patterns = func_items
261            .iter()
262            .map(|(_, origin_name, name, arg_names, _, ret, client_stream_item, server_stream_item)| {
263                let req_name = Ident::new(&(name.to_string() + "Request"), Span::call_site());
264                let res_name = Ident::new(&(name.to_string() + "Response"), Span::call_site());
265
266                let args = if arg_names.is_empty() {
267                    quote!()
268                } else {
269                    quote!{#(#arg_names),*}
270                };
271                let parse_args = if arg_names.is_empty() {
272                    quote!{
273                        let #req_name = req;
274                    }
275                } else {
276                    quote!{
277                        let #req_name (#(ref #arg_names),*) = req;
278                        let (#args) = (#(#arg_names.to_owned()),*);
279                    }
280                };
281
282                if client_stream_item.is_some() && server_stream_item.is_some() {
283                    let call_stream = quote!{
284                        let stream = self_.#origin_name(#args, rx2.into_stream()).await;
285                        quic_rpc_utils::pin!(stream);
286                        while let Some(i) = stream.next().await {
287                            let _ = tx.send_async(#res_name(i)).await;
288                        }
289                    };
290
291                    quote! {
292                        #request_name::#name(req) => {
293                            #parse_args
294
295                            let (tx, rx) = quic_rpc_utils::flume_bounded(2);
296                            let (tx2, rx2) = quic_rpc_utils::flume_bounded(2);
297                            let self_ = self.clone();
298                            let task = rt.spawn(async move {
299                                #call_stream
300                            });
301                            let (tx3, rx3) = quic_rpc_utils::oneshot_channel();
302                            match chan.bidi_streaming(req, self, |self_, req, updates| {
303                                let _ = tx3.send(rt.spawn(async move {
304                                    quic_rpc_utils::pin!(updates);
305                                    while let Some(item) = updates.next().await {
306                                        let _ = tx2.send_async(item).await;
307                                    }
308                                })).map_err(|e| e.abort());
309                                rx.into_stream()
310                            }).await {
311                                Err(e) => {
312                                    rx3.await.map_err(|e2| quic_rpc_utils::QuicRpcWrapError::Recv(format!("{}: {}", e2, e)))?.abort();
313                                    Err(quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))
314                                }
315                                Ok(()) => Ok(()),
316                            }?
317                        }
318                    }
319                } else if client_stream_item.is_some() {
320                    let call_stream = if ret.is_some() {
321                        quote!{
322                            #res_name(self_.#origin_name(#args, updates).await)
323                        }
324                    } else {
325                        quote!{
326                            self_.#origin_name(#args, updates).await;
327                            #res_name
328                        }
329                    };
330
331                    quote! {
332                        #request_name::#name(req) => chan.client_streaming(req, self, |self_, req, updates| async move {
333                            #parse_args
334
335                            #call_stream
336                        }).await.map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?
337                    }
338                } else if server_stream_item.is_some() {
339                    let call_stream = quote!{
340                        let stream = self_.#origin_name(#args).await;
341                        quic_rpc_utils::pin!(stream);
342                        while let Some(i) = stream.next().await {
343                            let _ = tx.send_async(#res_name(i)).await;
344                        }
345                    };
346
347                    quote! {
348                        #request_name::#name(req) => {
349                            #parse_args
350
351                            let (tx, rx) = quic_rpc_utils::flume_bounded(2);
352                            let self_ = self.clone();
353                            rt.spawn(async move {
354                                #call_stream
355                            });
356                            chan.server_streaming(req, self, move |_, _| rx.into_stream()).await.map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?
357                        }
358                    }
359                } else {
360                    let call = if ret.is_some() {
361                        quote! {
362                            #res_name(self_.#origin_name(#args).await)
363                        }
364                    } else {
365                        quote! {
366                            self_.#origin_name(#args).await;
367                            #res_name
368                        }
369                    };
370
371                    quote! {
372                        #request_name::#name(req) => chan.rpc(req, self, |self_, req| async move {
373                            #parse_args
374
375                            #call
376                        }).await.map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?
377                    }
378                }
379            })
380            .collect::<Vec<_>>();
381
382        let handler_match =
383            if child_request_patterns.is_empty() && request_match_patterns.is_empty() {
384                quote!()
385            } else {
386                quote! {
387                    match req {
388                        #(#child_request_patterns,)*
389                        #(#request_match_patterns,)*
390                        _ => return Err(quic_rpc_utils::QuicRpcWrapError::Request)
391                    }
392                }
393            };
394
395        quote! {
396            #item
397
398            impl<C: quic_rpc_utils::ChannelTypes<#service_name>> quic_rpc_utils::ServiceHandler<#service_name, C> for #service_name {
399                #[track_caller]
400                async fn handle_rpc_request(
401                    self: std::sync::Arc<Self>,
402                    req: #request_name,
403                    chan: quic_rpc_utils::RpcChannel<#service_name, C>,
404                    rt: &'static quic_rpc_utils::Runtime
405                ) -> quic_rpc_utils::Result<()> {
406                    #handler_match
407                    Ok(())
408                }
409            }
410        }
411    } else {
412        quote!()
413    };
414
415    let client = if meta.client {
416        let client_name = Ident::new(
417            &(service_name.to_token_stream().to_string() + "Client"),
418            Span::call_site(),
419        );
420        let client_methods = func_items
421            .iter()
422            .map(|(attrs, origin_name, name, arg_names, arg_types, ret, client_stream_item, server_stream_item)| {
423                let args2 = arg_names
424                    .iter()
425                    .enumerate()
426                    .map(|(i, j)| {
427                        let ty = arg_types[i].clone();
428                        quote! {#j: #ty}
429                    })
430                    .collect::<Vec<_>>();
431
432                let req_name = Ident::new(&(name.to_string() + "Request"), Span::call_site());
433                let res_name = Ident::new(&(name.to_string() + "Response"), Span::call_site());
434                let request = if arg_types.is_empty() {
435                    quote! {#req_name}
436                } else {
437                    quote! {#req_name(#(#arg_names),*)}
438                };
439
440                if client_stream_item.is_some() && server_stream_item.is_some() {
441                    quote! {
442                        #(#attrs)*
443                        #[track_caller]
444                        pub async fn #origin_name(
445                            &self,
446                            #(#args2),*
447                        ) ->quic_rpc_utils:: Result<(
448                            quic_rpc_utils::ClientStreamingResponse<#client_stream_item, #service_name, C, ()>,
449                            quic_rpc_utils::ServerStreamingResponse<#server_stream_item>
450                        )> {
451                            let (sink, res) = self.client.bidi(#request).await.map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?;
452                            let res = quic_rpc_utils::ServerStreamingResponse::new(res.map(|i| match i {
453                                Ok(#res_name(i)) => Ok(i),
454                                Ok(_) => Err(quic_rpc_utils::QuicRpcWrapError::Response("Invalid response.".to_string())),
455                                Err(e) => Err(quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))
456                            }));
457
458                            Ok((quic_rpc_utils::ClientStreamingResponse::new(sink, async {
459                                Ok(())
460                            }), res))
461                        }
462                    }
463                } else if client_stream_item.is_some() {
464                    quote! {
465                        #(#attrs)*
466                        #[track_caller]
467                        pub async fn #origin_name(
468                            &self,
469                            #(#args2),*
470                        ) -> quic_rpc_utils::Result<
471                            quic_rpc_utils::ClientStreamingResponse<#client_stream_item, #service_name, C, #ret>,
472                        > {
473                            let (sink, res) = self.client.client_streaming(#request).await.map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?;
474                            Ok(quic_rpc_utils::ClientStreamingResponse::new(sink, async move {
475                                Ok(res.await.map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?.0)
476                            }))
477                        }
478                    }
479                } else if server_stream_item.is_some() {
480                    quote! {
481                        #(#attrs)*
482                        #[track_caller]
483                        pub async fn #origin_name(
484                            &self,
485                            #(#args2),*
486                        ) -> quic_rpc_utils::Result<
487                            quic_rpc_utils::ServerStreamingResponse<#server_stream_item>
488                        > {
489                            let stream = self.client
490                                .server_streaming(#request).await
491                                .map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?
492                                .map(|i| match i {
493                                    Ok(#res_name(i)) => Ok(i),
494                                    Ok(_) => Err(quic_rpc_utils::QuicRpcWrapError::Response("Invalid response.".to_string())),
495                                    Err(e) => Err(quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))
496                                });
497                            Ok(quic_rpc_utils::ServerStreamingResponse::new(stream))
498                        }
499                    }
500                } else {
501                    let (ret, response) = if ret.is_some() {
502                        (quote!{#ret}, quote!{
503                            Ok(self.client.rpc(#request).await.map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?.0)
504                        })
505                    } else {
506                        (quote!{()}, quote!{
507                            self.client.rpc(#request).await.map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?;
508
509                            Ok(())
510                        })
511                    };
512
513                    quote! {
514                        #(#attrs)*
515                        #[track_caller]
516                        pub async fn #origin_name(&self, #(#args2),*) -> quic_rpc_utils::Result<#ret> {
517                            #response
518                        }
519                    }
520                }
521            })
522            .collect::<Vec<_>>();
523
524        let client_fields = meta
525            .services
526            .iter()
527            .map(|(subname, _)| {
528                let name = Ident::new(
529                    subname
530                        .to_string()
531                        .trim_end_matches("Service")
532                        .to_snake_case()
533                        .as_str(),
534                    Span::call_site(),
535                );
536                let name2 = Ident::new(&(subname.to_string() + "Client"), Span::call_site());
537                let field = quote! {pub #name: #name2};
538                (
539                    field,
540                    quote! {#name: #name2::new(&client.clone().map().boxed())},
541                )
542            })
543            .collect::<Vec<_>>();
544        let client_children = client_fields
545            .iter()
546            .map(|(_, ch)| ch.clone())
547            .collect::<Vec<_>>();
548        let client_fields = client_fields
549            .iter()
550            .map(|(f, _)| f.clone())
551            .collect::<Vec<_>>();
552
553        quote! {
554            #public struct #client_name<C: quic_rpc_utils::Connector<#service_name> = quic_rpc_utils::BoxedConnector<#service_name>> {
555                client: quic_rpc_utils::RpcClient<#service_name, C>,
556                #(#client_fields),*
557            }
558
559            impl<C: quic_rpc_utils::Connector<#service_name>> #client_name<C> {
560                pub fn new(client: &quic_rpc_utils::RpcClient<#service_name, C>) -> Self {
561                    Self {
562                        client: client.clone(),
563                        #(#client_children),*
564                    }
565                }
566
567                #(#client_methods)*
568            }
569        }
570    } else {
571        quote!()
572    };
573
574    let declared_types = func_items
575        .iter()
576        .map(
577            |(_, _, name, _, arg_types, ret, client_stream_item, server_stream_item)| {
578                let req_name = Ident::new(&(name.to_string() + "Request"), Span::call_site());
579                let res_name = Ident::new(&(name.to_string() + "Response"), Span::call_site());
580
581                let args = if arg_types.is_empty() {
582                    quote!()
583                } else {
584                    quote! {(#(#arg_types),*)}
585                };
586
587                let req_impls = if client_stream_item.is_some() && server_stream_item.is_some() {
588                    quote! {
589                        impl quic_rpc_utils::Msg<#service_name> for #req_name {
590                            type Pattern = quic_rpc_utils::BidiStreaming;
591                        }
592
593                        impl quic_rpc_utils::BidiStreamingMsg<#service_name> for #req_name {
594                            type Update = #client_stream_item;
595                            type Response = #res_name;
596                        }
597                    }
598                } else if client_stream_item.is_some() {
599                    quote! {
600                        impl quic_rpc_utils::Msg<#service_name> for #req_name {
601                            type Pattern = quic_rpc_utils::ClientStreaming;
602                        }
603
604                        impl quic_rpc_utils::ClientStreamingMsg<#service_name> for #req_name {
605                            type Update = #client_stream_item;
606                            type Response = #res_name;
607                        }
608                    }
609                } else if server_stream_item.is_some() {
610                    quote! {
611                        impl quic_rpc_utils::Msg<#service_name> for #req_name {
612                            type Pattern = quic_rpc_utils::ServerStreaming;
613                        }
614
615                        impl quic_rpc_utils::ServerStreamingMsg<#service_name> for #req_name {
616                            type Response = #res_name;
617                        }
618                    }
619                } else {
620                    quote! {
621                        impl quic_rpc_utils::RpcMsg<#service_name> for #req_name {
622                            type Response = #res_name;
623                        }
624                    }
625                };
626
627                let res_type = if ret.is_none() {
628                    quote! {struct #res_name;}
629                } else if server_stream_item.is_some() {
630                    quote! {struct #res_name (#server_stream_item);}
631                } else {
632                    quote! {struct #res_name (#ret);}
633                };
634
635                quote! {
636                    #[derive(Debug, serde::Serialize, serde::Deserialize)]
637                    struct #req_name #args;
638
639                    #req_impls
640
641                    #[derive(Debug, serde::Serialize, serde::Deserialize)]
642                    #res_type
643                }
644            },
645        )
646        .collect::<Vec<_>>();
647
648    let children_debug = meta
649        .services
650        .iter()
651        .map(|(_, field)| quote!(let res = write!(f, "{:?}", self.#field)))
652        .collect::<Vec<_>>();
653
654    let output = quote! {
655        #server
656
657        #client
658
659        #(#declared_types)*
660
661        #[derive(Debug, serde::Serialize, serde::Deserialize, derive_more::From, derive_more::TryInto)]
662        #public enum #request_name {
663            #(#request_enum_variants),*
664        }
665
666        #[derive(Debug, serde::Serialize, serde::Deserialize, derive_more::From, derive_more::TryInto)]
667        #public enum #response_name {
668            #(#response_enum_variants),*
669        }
670
671        impl quic_rpc_utils::RpcMsg<#service_name> for #request_name {
672            type Response = #response_name;
673        }
674
675        impl quic_rpc_utils::Service for #service_name {
676            type Req = #request_name;
677            type Res = #response_name;
678        }
679
680        impl std::fmt::Debug for #service_name {
681            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
682                let res = write!(f, "{}(Request:{}, Response:{})\n", #service_name_str, std::mem::size_of::<#request_name>(), std::mem::size_of::<#response_name>());
683                #(#children_debug;)*
684                res
685            }
686        }
687    };
688    output.into()
690}