1use proc_macro::TokenStream;
2
3use heck::{ToPascalCase, ToSnakeCase};
4use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
5use quote::{ToTokens, quote};
6use syn::{
7 AngleBracketedGenericArguments, AssocType, ExprAssign, FnArg, GenericArgument, ImplItem,
8 ItemImpl, Path, PathArguments, PathSegment, ReturnType, TraitBound, Type, TypeImplTrait,
9 TypeParamBound, TypePath,
10 parse::{Parse, ParseStream},
11 parse_macro_input, parse2,
12 punctuated::Punctuated,
13 token::Comma,
14};
15
16struct Meta {
17 server: bool,
18 client: bool,
19 public: TokenStream2,
20 services: Vec<(TokenStream2, TokenStream2)>,
21}
22
23impl Parse for Meta {
24 fn parse(input: ParseStream) -> syn::Result<Self> {
25 let items = Punctuated::<ExprAssign, Comma>::parse_terminated(input).unwrap();
26 let mut server = false;
27 let mut client = false;
28 let mut public = quote!();
29 let services = items
30 .iter()
31 .filter_map(|i| {
32 let j = i.left.to_token_stream();
33 let k = i.right.to_token_stream();
34 if j.to_string() == "server" {
35 server = k.to_string() == "true";
36 None
37 } else if j.to_string() == "client" {
38 client = k.to_string() == "true";
39 None
40 } else if j.to_string() == "public" {
41 public = if k.to_string() == "true" {
42 quote! {pub}
43 } else {
44 quote! {pub(#k)}
45 };
46 None
47 } else {
48 Some((j, k))
49 }
50 })
51 .collect();
52 Ok(Meta {
53 server,
54 client,
55 public,
56 services,
57 })
58 }
59}
60
61fn unwrap_stream_item_type(ty: &Type) -> Option<(Type, Option<Type>)> {
62 match ty {
63 Type::ImplTrait(TypeImplTrait { bounds, .. }) => match bounds.first() {
64 Some(TypeParamBound::Trait(TraitBound { path, .. })) => match path.segments.last() {
65 Some(PathSegment {
66 arguments: PathArguments::AngleBracketed(path),
67 ..
68 }) => match path.args.first() {
69 Some(GenericArgument::AssocType(AssocType { ty, .. })) => {
70 Some((ty.clone(), None))
71 }
72 _ => None,
73 },
74 _ => None,
75 },
76 _ => panic!("Only support impl Stream."),
77 },
78 Type::Path(TypePath {
79 path: Path { segments, .. },
80 ..
81 }) => match segments.last() {
82 Some(PathSegment {
83 ident,
84 arguments:
85 PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }),
86 ..
87 }) if ident == "Result" => match args.first() {
88 Some(GenericArgument::Type(ty)) => {
89 unwrap_stream_item_type(ty).map_or(None, |(ty, _)| match args.last() {
90 Some(GenericArgument::Type(err_type)) => Some((ty, Some(err_type.clone()))),
91 _ => Some((ty, None)),
92 })
93 }
94 _ => None,
95 },
96 _ => None,
97 },
98 _ => None,
99 }
100}
101
102#[proc_macro_attribute]
130pub fn service(attrs: TokenStream, input: TokenStream) -> TokenStream {
131 let meta: Meta = parse2(Into::<TokenStream2>::into(attrs)).unwrap();
132 let item = parse_macro_input!(input as ItemImpl);
133 let service_name = item.self_ty.as_ref().clone();
134 let service_name_str = service_name.to_token_stream().to_string();
135 let public = meta.public;
136 let (request_name, response_name) = {
137 let name = service_name.to_token_stream().to_string();
138 (
139 Ident::new(&(name.clone() + "Request"), Span::call_site()),
140 Ident::new(&(name + "Response"), Span::call_site()),
141 )
142 };
143 let items = item
144 .items
145 .iter()
146 .filter_map(|i| match i {
147 ImplItem::Fn(f) => Some((f.sig.clone(), f.attrs.clone())),
148 _ => None,
149 })
150 .collect::<Vec<_>>();
151 let func_items = items
152 .iter()
153 .map(|(func, attrs)| {
154 if func.asyncness.is_none() {
155 panic!("Function `{}` must be asyncable.", func.ident);
156 }
157 let self_ = func.inputs.iter().find(|i| match i {
158 FnArg::Receiver(_) => true,
159 FnArg::Typed(_) => false,
160 });
161 if self_.is_none() {
162 panic!("Function `{}` must contain `self` argument.", func.ident);
163 }
164
165 let mut client_stream_item: Option<Type> = None;
166 let args = func
167 .inputs
168 .iter()
169 .filter_map(|i| match i {
170 FnArg::Receiver(..) => None,
171 FnArg::Typed(ty) => match unwrap_stream_item_type(ty.ty.as_ref()) {
172 None => Some((ty.pat.as_ref().clone(), ty.ty.as_ref().clone())),
173 Some((ty, _)) => {
174 client_stream_item.replace(ty);
175 None
176 }
177 },
178 })
179 .collect::<Vec<_>>();
180 let arg_names = args.iter().map(|i| i.0.clone()).collect::<Vec<_>>();
181 let arg_types = args.iter().map(|i| i.1.clone()).collect::<Vec<_>>();
182
183 let (server_stream_item, server_stream_err_type, ret) = match func.output {
184 ReturnType::Default => (None, None, None),
185 ReturnType::Type(_, ref ty) => unwrap_stream_item_type(ty.as_ref())
186 .map_or((None, None, Some(ty.as_ref().clone())), |(t, e)| {
187 (Some(t), e, Some(ty.as_ref().clone()))
188 }),
189 };
190
191 (
192 attrs,
193 func.ident.clone(),
194 Ident::new(&func.ident.to_string().to_pascal_case(), Span::call_site()),
195 arg_names,
196 arg_types,
197 ret,
198 client_stream_item,
199 server_stream_item,
200 server_stream_err_type,
201 )
202 })
203 .collect::<Vec<_>>();
204
205 let mut request_enum_variants = func_items
206 .iter()
207 .map(|(_, _, name, _, _, _, _, _, _)| {
208 let name2 = Ident::new(&(name.to_string() + "Request"), Span::call_site());
209 quote! {#name(#name2)}
210 })
211 .collect::<Vec<_>>();
212 request_enum_variants.extend(
213 func_items
214 .iter()
215 .filter_map(|(_, _, name, _, _, _, client_stream_item, _, _)| {
216 if client_stream_item.is_some() {
217 let name2 = Ident::new(&(name.to_string() + "Put"), Span::call_site());
218 return Some(quote! {#name2(#client_stream_item)});
219 }
220 None
221 })
222 .collect::<Vec<_>>(),
223 );
224 request_enum_variants.extend(
225 meta.services
226 .iter()
227 .map(|(subname, _)| {
228 let name = Ident::new(&(subname.to_string() + "Request"), Span::call_site());
229 quote! {#subname(#name)}
230 })
231 .collect::<Vec<_>>(),
232 );
233
234 let mut response_enum_variants = func_items
235 .iter()
236 .map(|(_, _, name, _, _, _, _, _, _)| {
237 let name2 = Ident::new(&(name.to_string() + "Response"), Span::call_site());
238 quote! {#name(#name2)}
239 })
240 .collect::<Vec<_>>();
241 response_enum_variants.extend(
242 meta.services
243 .iter()
244 .map(|(subname, _)| {
245 let name = Ident::new(&(subname.to_string() + "Response"), Span::call_site());
246 quote! {#name(#name)}
247 })
248 .collect::<Vec<_>>(),
249 );
250
251 let server = if meta.server {
252 let child_request_patterns = meta
253 .services
254 .iter()
255 .map(|(subname, field)| {
256 let handler = if field.to_string() == "None" {
257 quote!{quic_rpc_utils::GetServiceHandler::<#subname>::get_handler(self)}
258 } else {
259 quote!{self.#field.clone()}
260 };
261
262 quote! {
263 #request_name::#subname(req) => #handler.handle_rpc_request(req, chan.map().boxed(), rt).await?
264 }
265 })
266 .collect::<Vec<_>>();
267
268 let request_match_patterns = func_items
269 .iter()
270 .map(|(_, origin_name, name, arg_names, _, ret, client_stream_item, server_stream_item, server_stream_err_type)| {
271 let req_name = Ident::new(&(name.to_string() + "Request"), Span::call_site());
272 let res_name = Ident::new(&(name.to_string() + "Response"), Span::call_site());
273
274 let args = if arg_names.is_empty() {
275 quote!()
276 } else {
277 quote!{#(#arg_names),*}
278 };
279 let parse_args = if arg_names.is_empty() {
280 quote!{
281 let #req_name = req;
282 }
283 } else {
284 quote!{
285 let #req_name (#(ref #arg_names),*) = req;
286 let (#args) = (#(#arg_names.to_owned()),*);
287 }
288 };
289
290 if client_stream_item.is_some() && server_stream_item.is_some() {
291 let call_stream = if server_stream_err_type.is_some() {
292 quote!{
293 let stream = match self_.#origin_name(#args, rx2.into_stream()).await {
294 Ok(stream) => stream,
295 Err(e) => {
296 let _ = tx.send_async(#res_name(Err(e))).await;
297 return;
298 }
299 };
300 quic_rpc_utils::pin!(stream);
301 while let Some(i) = stream.next().await {
302 let _ = tx.send_async(#res_name(Ok(i))).await;
303 }
304 }
305 } else {
306 quote!{
307 let stream = self_.#origin_name(#args, rx2.into_stream()).await;
308 quic_rpc_utils::pin!(stream);
309 while let Some(i) = stream.next().await {
310 let _ = tx.send_async(#res_name(i)).await;
311 }
312 }
313 };
314
315 quote! {
316 #request_name::#name(req) => {
317 #parse_args
318
319 let (tx, rx) = quic_rpc_utils::flume_bounded(2);
320 let (tx2, rx2) = quic_rpc_utils::flume_bounded(2);
321 let self_ = self.clone();
322 let task = rt.spawn(async move {
323 #call_stream
324 });
325 let (tx3, rx3) = quic_rpc_utils::oneshot_channel();
326 match chan.bidi_streaming(req, self, |self_, req, updates| {
327 let _ = tx3.send(rt.spawn(async move {
328 quic_rpc_utils::pin!(updates);
329 while let Some(item) = updates.next().await {
330 let _ = tx2.send_async(item).await;
331 }
332 })).map_err(|e| e.abort());
333 rx.into_stream()
334 }).await {
335 Err(e) => {
336 rx3.await.map_err(|e2| quic_rpc_utils::QuicRpcWrapError::Recv(format!("{}: {}", e2, e)))?.abort();
337 Err(quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))
338 }
339 Ok(()) => Ok(()),
340 }?
341 }
342 }
343 } else if client_stream_item.is_some() {
344 let call_stream = if ret.is_some() {
345 quote!{
346 #res_name(self_.#origin_name(#args, updates).await)
347 }
348 } else {
349 quote!{
350 self_.#origin_name(#args, updates).await;
351 #res_name
352 }
353 };
354
355 quote! {
356 #request_name::#name(req) => chan.client_streaming(req, self, |self_, req, updates| async move {
357 #parse_args
358
359 #call_stream
360 }).await.map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?
361 }
362 } else if server_stream_item.is_some() {
363 let call_stream = if server_stream_err_type.is_some() {
364 quote!{
365 let stream = match self_.#origin_name(#args).await {
366 Ok(stream) => stream,
367 Err(e) => {
368 let _ = tx.send_async(#res_name(Err(e))).await;
369 return;
370 }
371 };
372 quic_rpc_utils::pin!(stream);
373 while let Some(i) = stream.next().await {
374 let _ = tx.send_async(#res_name(Ok(i))).await;
375 }
376 }
377 } else {
378 quote!{
379 let stream = self_.#origin_name(#args).await;
380 quic_rpc_utils::pin!(stream);
381 while let Some(i) = stream.next().await {
382 let _ = tx.send_async(#res_name(i)).await;
383 }
384 }
385 };
386
387 quote! {
388 #request_name::#name(req) => {
389 #parse_args
390
391 let (tx, rx) = quic_rpc_utils::flume_bounded(2);
392 let self_ = self.clone();
393 rt.spawn(async move {
394 #call_stream
395 });
396 chan.server_streaming(req, self, move |_, _| rx.into_stream()).await.map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?
397 }
398 }
399 } else {
400 let call = if ret.is_some() {
401 quote! {
402 #res_name(self_.#origin_name(#args).await)
403 }
404 } else {
405 quote! {
406 self_.#origin_name(#args).await;
407 #res_name
408 }
409 };
410
411 quote! {
412 #request_name::#name(req) => chan.rpc(req, self, |self_, req| async move {
413 #parse_args
414
415 #call
416 }).await.map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?
417 }
418 }
419 })
420 .collect::<Vec<_>>();
421
422 let handler_match =
423 if child_request_patterns.is_empty() && request_match_patterns.is_empty() {
424 quote!()
425 } else {
426 quote! {
427 match req {
428 #(#child_request_patterns,)*
429 #(#request_match_patterns,)*
430 _ => return Err(quic_rpc_utils::QuicRpcWrapError::Request)
431 }
432 }
433 };
434
435 quote! {
436 #item
437
438 impl<C: quic_rpc_utils::ChannelTypes<#service_name>> quic_rpc_utils::ServiceHandler<#service_name, C> for #service_name {
439 #[track_caller]
440 async fn handle_rpc_request(
441 self: std::sync::Arc<Self>,
442 req: #request_name,
443 chan: quic_rpc_utils::RpcChannel<#service_name, C>,
444 rt: &'static quic_rpc_utils::Runtime
445 ) -> quic_rpc_utils::Result<()> {
446 #handler_match
447 Ok(())
448 }
449 }
450 }
451 } else {
452 quote!()
453 };
454
455 let client = if meta.client {
456 let client_name = Ident::new(
457 &(service_name.to_token_stream().to_string() + "Client"),
458 Span::call_site(),
459 );
460 let client_methods = func_items
461 .iter()
462 .map(|(attrs, origin_name, name, arg_names, arg_types, ret, client_stream_item, server_stream_item, server_stream_err_type)| {
463 let args2 = arg_names
464 .iter()
465 .enumerate()
466 .map(|(i, j)| {
467 let ty = arg_types[i].clone();
468 quote! {#j: #ty}
469 })
470 .collect::<Vec<_>>();
471
472 let req_name = Ident::new(&(name.to_string() + "Request"), Span::call_site());
473 let res_name = Ident::new(&(name.to_string() + "Response"), Span::call_site());
474 let request = if arg_types.is_empty() {
475 quote! {#req_name}
476 } else {
477 quote! {#req_name(#(#arg_names),*)}
478 };
479
480 if client_stream_item.is_some() && server_stream_item.is_some() {
481 let server_stream_item = if server_stream_err_type.is_some() {
482 quote!{std::result::Result<#server_stream_item, #server_stream_err_type>}
483 } else {
484 quote!{#server_stream_item}
485 };
486
487 quote! {
488 #(#attrs)*
489 #[track_caller]
490 pub async fn #origin_name(
491 &self,
492 #(#args2),*
493 ) ->quic_rpc_utils:: Result<(
494 quic_rpc_utils::ClientStreamingResponse<#client_stream_item, #service_name, C, ()>,
495 quic_rpc_utils::ServerStreamingResponse<#server_stream_item>
496 )> {
497 let (sink, res) = self.client.bidi(#request).await.map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?;
498 let res = quic_rpc_utils::ServerStreamingResponse::new(res.map(|i| match i {
499 Ok(#res_name(i)) => Ok(i),
500 Ok(_) => Err(quic_rpc_utils::QuicRpcWrapError::Response("Invalid response.".to_string())),
501 Err(e) => Err(quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))
502 }));
503
504 Ok((quic_rpc_utils::ClientStreamingResponse::new(sink, async {
505 Ok(())
506 }), res))
507 }
508 }
509 } else if client_stream_item.is_some() {
510 quote! {
511 #(#attrs)*
512 #[track_caller]
513 pub async fn #origin_name(
514 &self,
515 #(#args2),*
516 ) -> quic_rpc_utils::Result<
517 quic_rpc_utils::ClientStreamingResponse<#client_stream_item, #service_name, C, #ret>,
518 > {
519 let (sink, res) = self.client.client_streaming(#request).await.map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?;
520 Ok(quic_rpc_utils::ClientStreamingResponse::new(sink, async move {
521 Ok(res.await.map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?.0)
522 }))
523 }
524 }
525 } else if server_stream_item.is_some() {
526 let server_stream_item = if server_stream_err_type.is_some() {
527 quote!{Result<#server_stream_item, #server_stream_err_type>}
528 } else {
529 quote!{#server_stream_item}
530 };
531
532 quote! {
533 #(#attrs)*
534 #[track_caller]
535 pub async fn #origin_name(
536 &self,
537 #(#args2),*
538 ) -> quic_rpc_utils::Result<
539 quic_rpc_utils::ServerStreamingResponse<#server_stream_item>
540 > {
541 let stream = self.client
542 .server_streaming(#request).await
543 .map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?
544 .map(|i| match i {
545 Ok(#res_name(i)) => Ok(i),
546 Ok(_) => Err(quic_rpc_utils::QuicRpcWrapError::Response("Invalid response.".to_string())),
547 Err(e) => Err(quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))
548 });
549 Ok(quic_rpc_utils::ServerStreamingResponse::new(stream))
550 }
551 }
552 } else {
553 let (ret, response) = if ret.is_some() {
554 (quote!{#ret}, quote!{
555 Ok(self.client.rpc(#request).await.map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?.0)
556 })
557 } else {
558 (quote!{()}, quote!{
559 self.client.rpc(#request).await.map_err(|e| quic_rpc_utils::QuicRpcWrapError::Response(e.to_string()))?;
560
561 Ok(())
562 })
563 };
564
565 quote! {
566 #(#attrs)*
567 #[track_caller]
568 pub async fn #origin_name(&self, #(#args2),*) -> quic_rpc_utils::Result<#ret> {
569 #response
570 }
571 }
572 }
573 })
574 .collect::<Vec<_>>();
575
576 let client_fields = meta
577 .services
578 .iter()
579 .map(|(subname, _)| {
580 let name = Ident::new(
581 subname
582 .to_string()
583 .trim_end_matches("Service")
584 .to_snake_case()
585 .as_str(),
586 Span::call_site(),
587 );
588 let name2 = Ident::new(&(subname.to_string() + "Client"), Span::call_site());
589 let field = quote! {pub #name: #name2};
590 (
591 field,
592 quote! {#name: #name2::new(&client.clone().map().boxed())},
593 )
594 })
595 .collect::<Vec<_>>();
596 let client_children = client_fields
597 .iter()
598 .map(|(_, ch)| ch.clone())
599 .collect::<Vec<_>>();
600 let client_fields = client_fields
601 .iter()
602 .map(|(f, _)| f.clone())
603 .collect::<Vec<_>>();
604
605 quote! {
606 #public struct #client_name<C: quic_rpc_utils::Connector<#service_name> = quic_rpc_utils::BoxedConnector<#service_name>> {
607 client: quic_rpc_utils::RpcClient<#service_name, C>,
608 #(#client_fields),*
609 }
610
611 impl<C: quic_rpc_utils::Connector<#service_name>> #client_name<C> {
612 pub fn new(client: &quic_rpc_utils::RpcClient<#service_name, C>) -> Self {
613 Self {
614 client: client.clone(),
615 #(#client_children),*
616 }
617 }
618
619 #(#client_methods)*
620 }
621 }
622 } else {
623 quote!()
624 };
625
626 let declared_types = func_items
627 .iter()
628 .map(
629 |(
630 _,
631 _,
632 name,
633 _,
634 arg_types,
635 ret,
636 client_stream_item,
637 server_stream_item,
638 server_stream_err_type,
639 )| {
640 let req_name = Ident::new(&(name.to_string() + "Request"), Span::call_site());
641 let res_name = Ident::new(&(name.to_string() + "Response"), Span::call_site());
642
643 let args = if arg_types.is_empty() {
644 quote!()
645 } else {
646 quote! {(#(#arg_types),*)}
647 };
648
649 let req_impls = if client_stream_item.is_some() && server_stream_item.is_some() {
650 quote! {
651 impl quic_rpc_utils::Msg<#service_name> for #req_name {
652 type Pattern = quic_rpc_utils::BidiStreaming;
653 }
654
655 impl quic_rpc_utils::BidiStreamingMsg<#service_name> for #req_name {
656 type Update = #client_stream_item;
657 type Response = #res_name;
658 }
659 }
660 } else if client_stream_item.is_some() {
661 quote! {
662 impl quic_rpc_utils::Msg<#service_name> for #req_name {
663 type Pattern = quic_rpc_utils::ClientStreaming;
664 }
665
666 impl quic_rpc_utils::ClientStreamingMsg<#service_name> for #req_name {
667 type Update = #client_stream_item;
668 type Response = #res_name;
669 }
670 }
671 } else if server_stream_item.is_some() {
672 quote! {
673 impl quic_rpc_utils::Msg<#service_name> for #req_name {
674 type Pattern = quic_rpc_utils::ServerStreaming;
675 }
676
677 impl quic_rpc_utils::ServerStreamingMsg<#service_name> for #req_name {
678 type Response = #res_name;
679 }
680 }
681 } else {
682 quote! {
683 impl quic_rpc_utils::RpcMsg<#service_name> for #req_name {
684 type Response = #res_name;
685 }
686 }
687 };
688
689 let res_type = if ret.is_none() {
690 quote! {struct #res_name;}
691 } else if server_stream_item.is_some() {
692 if server_stream_err_type.is_some() {
693 quote! {struct #res_name (std::result::Result<#server_stream_item, #server_stream_err_type>);}
694 } else {
695 quote! {struct #res_name (#server_stream_item);}
696 }
697 } else {
698 quote! {struct #res_name (#ret);}
699 };
700
701 quote! {
702 #[derive(Debug, serde::Serialize, serde::Deserialize)]
703 struct #req_name #args;
704
705 #req_impls
706
707 #[derive(Debug, serde::Serialize, serde::Deserialize)]
708 #res_type
709 }
710 },
711 )
712 .collect::<Vec<_>>();
713
714 let children_debug = meta
715 .services
716 .iter()
717 .map(|(_, field)| quote!(let res = write!(f, "{:?}", self.#field)))
718 .collect::<Vec<_>>();
719
720 let output = quote! {
721 #server
722
723 #client
724
725 #(#declared_types)*
726
727 #[derive(Debug, serde::Serialize, serde::Deserialize, derive_more::From, derive_more::TryInto)]
728 #public enum #request_name {
729 #(#request_enum_variants),*
730 }
731
732 #[derive(Debug, serde::Serialize, serde::Deserialize, derive_more::From, derive_more::TryInto)]
733 #public enum #response_name {
734 #(#response_enum_variants),*
735 }
736
737 impl quic_rpc_utils::RpcMsg<#service_name> for #request_name {
738 type Response = #response_name;
739 }
740
741 impl quic_rpc_utils::Service for #service_name {
742 type Req = #request_name;
743 type Res = #response_name;
744 }
745
746 impl std::fmt::Debug for #service_name {
747 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
748 let res = write!(f, "{}(Request:{}, Response:{})\n", #service_name_str, std::mem::size_of::<#request_name>(), std::mem::size_of::<#response_name>());
749 #(#children_debug;)*
750 res
751 }
752 }
753 };
754 output.into()
756}