use proc_macro::TokenStream;
use heck::{ToPascalCase, ToSnakeCase};
use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
use quote::{quote, ToTokens};
use syn::{
parse::{Parse, ParseStream},
parse2, parse_macro_input,
punctuated::Punctuated,
token::Comma,
AngleBracketedGenericArguments, AssocType, ExprAssign, FnArg, GenericArgument, ImplItem,
ItemImpl, Path, PathArguments, PathSegment, ReturnType, TraitBound, Type, TypeImplTrait,
TypeParamBound, TypePath,
};
/// Parsed arguments of the `#[service(...)]` attribute.
///
/// The attribute body is a comma-separated list of `key = value` assignments;
/// the recognised keys populate the first three fields and every other
/// assignment is collected into `services`.
struct Meta {
/// `true` when the attribute contains `server = true`; enables generation of
/// the `ServiceHandler` implementation.
server: bool,
/// `true` when the attribute contains `client = true`; enables generation of
/// the `<Service>Client` struct.
client: bool,
/// Visibility tokens placed in front of the generated client struct and the
/// request/response enums (empty when no `public` key is given).
public: TokenStream2,
/// Remaining `left = right` pairs. `service` treats the left side as a child
/// service name and the right side as either the literal `None` or a field
/// of `self` that holds the child handler.
services: Vec<(TokenStream2, TokenStream2)>,
}
impl Parse for Meta {
    /// Parses the comma-separated `key = value` list given to `#[service(...)]`.
    ///
    /// Recognised keys:
    /// * `server = true` / `client = true` — toggle the respective flag (any
    ///   value other than the literal `true` leaves the flag `false`);
    /// * `public = true` — `pub`; `public = false` — private (no visibility
    ///   tokens); `public = <tokens>` — `pub(<tokens>)`, e.g. `pub(crate)`.
    ///
    /// Every other assignment is stored verbatim in `services`, in source order.
    ///
    /// # Errors
    ///
    /// Returns the underlying `syn::Error` when the input is not a
    /// comma-separated list of assignment expressions.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        // Propagate malformed input to the caller instead of panicking, so it
        // can be reported with a proper span.
        let items = Punctuated::<ExprAssign, Comma>::parse_terminated(input)?;
        let mut server = false;
        let mut client = false;
        let mut public = quote!();
        let mut services = Vec::new();
        for item in &items {
            let key = item.left.to_token_stream();
            let value = item.right.to_token_stream();
            match key.to_string().as_str() {
                "server" => server = value.to_string() == "true",
                "client" => client = value.to_string() == "true",
                "public" => {
                    public = match value.to_string().as_str() {
                        "true" => quote! {pub},
                        // `pub(false)` is not valid Rust; treat it as "private".
                        "false" => quote!(),
                        _ => quote! {pub(#value)},
                    };
                }
                _ => services.push((key, value)),
            }
        }
        Ok(Meta {
            server,
            client,
            public,
            services,
        })
    }
}
/// Extracts the item type carried by an `impl Trait<Assoc = T>` type, looking
/// through one (or more, recursively) enclosing `Result<..>` layers.
///
/// Returns `Some((item, None))` for a bare `impl Trait<Assoc = item>` and
/// `Some((item, Some(err)))` for `Result<impl Trait<Assoc = item>, err>`,
/// where `err` is the *last* generic type argument of `Result`. Returns `None`
/// for every other shape.
///
/// The trait name itself is not inspected: any `impl` type whose first bound
/// has an associated-type binding as its first generic argument is accepted.
///
/// # Panics
///
/// Panics with "Only support impl Stream." when `ty` is an `impl` type whose
/// first bound is not a trait bound.
///
/// NOTE(review): for a one-argument alias such as `Result<impl Stream<..>>`
/// the first and last generic arguments coincide, so the returned error type
/// is the `impl` type itself — confirm whether such aliases must be supported.
fn unwrap_stream_item_type(ty: &Type) -> Option<(Type, Option<Type>)> {
    // Case 1: `impl Trait<Assoc = Item>`.
    if let Type::ImplTrait(TypeImplTrait { bounds, .. }) = ty {
        let trait_path = match bounds.first() {
            Some(TypeParamBound::Trait(TraitBound { path, .. })) => path,
            _ => panic!("Only support impl Stream."),
        };
        let generics = match &trait_path.segments.last()?.arguments {
            PathArguments::AngleBracketed(generics) => generics,
            _ => return None,
        };
        return match generics.args.first()? {
            GenericArgument::AssocType(AssocType { ty: item, .. }) => Some((item.clone(), None)),
            _ => None,
        };
    }

    // Case 2: `..::Result<Inner, ..>` where `Inner` is itself unwrappable.
    let segments = match ty {
        Type::Path(TypePath {
            path: Path { segments, .. },
            ..
        }) => segments,
        _ => return None,
    };
    let last_segment = segments.last()?;
    if last_segment.ident != "Result" {
        return None;
    }
    let args = match &last_segment.arguments {
        PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) => args,
        _ => return None,
    };
    let inner = match args.first()? {
        GenericArgument::Type(inner) => inner,
        _ => return None,
    };
    let (item, _) = unwrap_stream_item_type(inner)?;
    let err_type = match args.last() {
        Some(GenericArgument::Type(err_type)) => Some(err_type.clone()),
        _ => None,
    };
    Some((item, err_type))
}
/// Attribute macro applied to an inherent `impl` block that turns its `async`
/// methods into RPC endpoints.
///
/// For every `fn` in the block the macro generates a `<Method>Request` and a
/// `<Method>Response` struct plus the matching `quic_rpc_utils` message-trait
/// impls. It also generates the `<Service>Request` / `<Service>Response`
/// enums, a `quic_rpc_utils::Service` impl and a `Debug` impl for the service
/// type.
///
/// Attribute arguments (see `Meta`):
/// * `server = true` — re-emit the original `impl` block and generate a
///   `quic_rpc_utils::ServiceHandler` impl dispatching requests to the methods.
///   The original `impl` block is only emitted in this mode.
/// * `client = true` — generate a `<Service>Client` struct with one method per
///   endpoint.
/// * `public = ...` — visibility of the generated client struct and enums.
/// * `Child = field` / `Child = None` — child services whose requests are
///   forwarded to `self.field` or to `GetServiceHandler::<Child>::get_handler`.
///
/// An argument (or return value) whose type is `impl Trait<Item = T>` —
/// optionally wrapped in `Result` for return values — marks the method as
/// client-streaming / server-streaming / bidirectional.
///
/// # Panics
///
/// Panics (reported by rustc as a macro error) when a method is not `async`
/// or has no `self` receiver.
#[proc_macro_attribute]
pub fn service(attrs: TokenStream, input: TokenStream) -> TokenStream {
    // Report malformed attribute arguments as a compile error instead of panicking.
    let meta: Meta = match parse2(Into::<TokenStream2>::into(attrs)) {
        Ok(meta) => meta,
        Err(err) => return err.to_compile_error().into(),
    };
    let item = parse_macro_input!(input as ItemImpl);
    let service_name = item.self_ty.as_ref().clone();
    let service_name_str = service_name.to_token_stream().to_string();
    let public = meta.public;
    let request_name = Ident::new(&format!("{}Request", service_name_str), Span::call_site());
    let response_name = Ident::new(&format!("{}Response", service_name_str), Span::call_site());

    // Signatures and attributes of every method in the impl block.
    let items = item
        .items
        .iter()
        .filter_map(|i| match i {
            ImplItem::Fn(f) => Some((f.sig.clone(), f.attrs.clone())),
            _ => None,
        })
        .collect::<Vec<_>>();

    // Per-method description:
    // (attrs, original ident, PascalCase ident, arg patterns, arg types, return type,
    //  client stream item, server stream item, server stream error type).
    let func_items = items
        .iter()
        .map(|(func, attrs)| {
            if func.asyncness.is_none() {
                panic!("Function `{}` must be asyncable.", func.ident);
            }
            let has_receiver = func
                .inputs
                .iter()
                .any(|i| matches!(i, FnArg::Receiver(_)));
            if !has_receiver {
                panic!("Function `{}` must contain `self` argument.", func.ident);
            }
            // A stream-typed argument is removed from the plain argument list and
            // remembered as the client-side update stream.
            let mut client_stream_item: Option<Type> = None;
            let args = func
                .inputs
                .iter()
                .filter_map(|i| match i {
                    FnArg::Receiver(..) => None,
                    FnArg::Typed(ty) => match unwrap_stream_item_type(ty.ty.as_ref()) {
                        None => Some((ty.pat.as_ref().clone(), ty.ty.as_ref().clone())),
                        Some((ty, _)) => {
                            client_stream_item.replace(ty);
                            None
                        }
                    },
                })
                .collect::<Vec<_>>();
            let arg_names = args.iter().map(|i| i.0.clone()).collect::<Vec<_>>();
            let arg_types = args.iter().map(|i| i.1.clone()).collect::<Vec<_>>();
            let (server_stream_item, server_stream_err_type, ret) = match func.output {
                ReturnType::Default => (None, None, None),
                ReturnType::Type(_, ref ty) => unwrap_stream_item_type(ty.as_ref())
                    .map_or((None, None, Some(ty.as_ref().clone())), |(t, e)| {
                        (Some(t), e, Some(ty.as_ref().clone()))
                    }),
            };
            (
                attrs,
                func.ident.clone(),
                Ident::new(&func.ident.to_string().to_pascal_case(), Span::call_site()),
                arg_names,
                arg_types,
                ret,
                client_stream_item,
                server_stream_item,
                server_stream_err_type,
            )
        })
        .collect::<Vec<_>>();

    // Variants of the service-wide request enum: one per method, one `<Method>Put`
    // per client stream, one per child service.
    let mut request_enum_variants = func_items
        .iter()
        .map(|(_, _, name, _, _, _, _, _, _)| {
            let name2 = Ident::new(&(name.to_string() + "Request"), Span::call_site());
            quote! {#name(#name2)}
        })
        .collect::<Vec<_>>();
    request_enum_variants.extend(func_items.iter().filter_map(
        |(_, _, name, _, _, _, client_stream_item, _, _)| {
            if client_stream_item.is_some() {
                let name2 = Ident::new(&(name.to_string() + "Put"), Span::call_site());
                return Some(quote! {#name2(#client_stream_item)});
            }
            None
        },
    ));
    request_enum_variants.extend(meta.services.iter().map(|(subname, _)| {
        let name = Ident::new(&(subname.to_string() + "Request"), Span::call_site());
        quote! {#subname(#name)}
    }));

    // Variants of the service-wide response enum.
    let mut response_enum_variants = func_items
        .iter()
        .map(|(_, _, name, _, _, _, _, _, _)| {
            let name2 = Ident::new(&(name.to_string() + "Response"), Span::call_site());
            quote! {#name(#name2)}
        })
        .collect::<Vec<_>>();
    response_enum_variants.extend(meta.services.iter().map(|(subname, _)| {
        let name = Ident::new(&(subname.to_string() + "Response"), Span::call_site());
        // NOTE(review): the variant is named `<Child>Response`, whereas the request
        // enum names the matching variant `<Child>`. Kept as-is because renaming
        // would change the generated public enum — confirm whether this is intended.
        quote! {#name(#name)}
    }));

    let server = if meta.server {
        // Requests addressed to a child service are forwarded to its handler.
        let child_request_patterns = meta
            .services
            .iter()
            .map(|(subname, field)| {
                let handler = if field.to_string() == "None" {
                    quote! {quic_rpc_utils::GetServiceHandler::<#subname>::get_handler(self)}
                } else {
                    quote! {self.#field.clone()}
                };
                quote! {
                    #request_name::#subname(req) => #handler.handle_rpc_request(req, chan.map().boxed(), rt).await?
                }
            })
            .collect::<Vec<_>>();
        let request_match_patterns = func_items
            .iter()
            .map(|(_, origin_name, name, arg_names, _, ret, client_stream_item, server_stream_item, server_stream_err_type)| {
                let req_name = Ident::new(&(name.to_string() + "Request"), Span::call_site());
                let res_name = Ident::new(&(name.to_string() + "Response"), Span::call_site());
                let args = quote! {#(#arg_names),*};
                // Argument prefix for calls that append a stream argument. Each name
                // carries its own trailing comma, so a method without plain
                // arguments expands to `f(stream)` rather than the invalid `f(, stream)`.
                // The stream is always passed last — this assumes the stream
                // parameter is declared last in the method signature.
                let stream_args = quote! {#(#arg_names,)*};
                let parse_args = if arg_names.is_empty() {
                    quote! {
                        let #req_name = req;
                    }
                } else {
                    quote! {
                        let #req_name (#(ref #arg_names),*) = req;
                        let (#args) = (#(#arg_names.to_owned()),*);
                    }
                };
                if client_stream_item.is_some() && server_stream_item.is_some() {
                    let call_stream = if server_stream_err_type.is_some() {
                        quote! {
                            let stream = match self_.#origin_name(#stream_args rx2.into_stream()).await {
                                Ok(stream) => stream,
                                Err(e) => {
                                    let _ = tx.send_async(#res_name(Err(e))).await;
                                    return;
                                }
                            };
                            quic_rpc_utils::pin!(stream);
                            while let Some(i) = stream.next().await {
                                let _ = tx.send_async(#res_name(Ok(i))).await;
                            }
                        }
                    } else {
                        quote! {
                            let stream = self_.#origin_name(#stream_args rx2.into_stream()).await;
                            quic_rpc_utils::pin!(stream);
                            while let Some(i) = stream.next().await {
                                let _ = tx.send_async(#res_name(i)).await;
                            }
                        }
                    };
                    quote! {
                        #request_name::#name(req) => {
                            #parse_args
                            let (tx, rx) = quic_rpc_utils::flume_bounded(2);
                            let (tx2, rx2) = quic_rpc_utils::flume_bounded(2);
                            let self_ = self.clone();
                            let task = rt.spawn(async move {
                                #call_stream
                            });
                            let (tx3, rx3) = quic_rpc_utils::oneshot_channel();
                            match chan.bidi_streaming(req, self, |self_, req, updates| {
                                let _ = tx3.send(rt.spawn(async move {
                                    quic_rpc_utils::pin!(updates);
                                    while let Some(item) = updates.next().await {
                                        let _ = tx2.send_async(item).await;
                                    }
                                })).map_err(|e| e.abort());
                                rx.into_stream()
                            }).await {
                                Err(e) => {
                                    rx3.await.map_err(|e2| quic_rpc_utils::Error::msg(format!("{}: {}", e2, e)))?.abort();
                                    Err(e)
                                }
                                ok => ok,
                            }?
                        }
                    }
                } else if client_stream_item.is_some() {
                    let call_stream = if ret.is_some() {
                        quote! {
                            #res_name(self_.#origin_name(#stream_args updates).await)
                        }
                    } else {
                        quote! {
                            self_.#origin_name(#stream_args updates).await;
                            #res_name
                        }
                    };
                    quote! {
                        #request_name::#name(req) => chan.client_streaming(req, self, |self_, req, updates| async move {
                            #parse_args
                            #call_stream
                        }).await?
                    }
                } else if server_stream_item.is_some() {
                    let call_stream = if server_stream_err_type.is_some() {
                        quote! {
                            let stream = match self_.#origin_name(#args).await {
                                Ok(stream) => stream,
                                Err(e) => {
                                    let _ = tx.send_async(#res_name(Err(e))).await;
                                    return;
                                }
                            };
                            quic_rpc_utils::pin!(stream);
                            while let Some(i) = stream.next().await {
                                let _ = tx.send_async(#res_name(Ok(i))).await;
                            }
                        }
                    } else {
                        quote! {
                            let stream = self_.#origin_name(#args).await;
                            quic_rpc_utils::pin!(stream);
                            while let Some(i) = stream.next().await {
                                let _ = tx.send_async(#res_name(i)).await;
                            }
                        }
                    };
                    quote! {
                        #request_name::#name(req) => {
                            #parse_args
                            let (tx, rx) = quic_rpc_utils::flume_bounded(2);
                            let self_ = self.clone();
                            rt.spawn(async move {
                                #call_stream
                            });
                            chan.server_streaming(req, self, move |_, _| rx.into_stream()).await?
                        }
                    }
                } else {
                    let call = if ret.is_some() {
                        quote! {
                            #res_name(self_.#origin_name(#args).await)
                        }
                    } else {
                        quote! {
                            self_.#origin_name(#args).await;
                            #res_name
                        }
                    };
                    quote! {
                        #request_name::#name(req) => chan.rpc(req, self, |self_, req| async move {
                            #parse_args
                            #call
                        }).await?
                    }
                }
            })
            .collect::<Vec<_>>();
        let handler_match =
            if child_request_patterns.is_empty() && request_match_patterns.is_empty() {
                quote!()
            } else {
                quote! {
                    match req {
                        #(#child_request_patterns,)*
                        #(#request_match_patterns,)*
                        _ => return Err(quic_rpc_utils::Error::msg("Response error."))
                    }
                }
            };
        quote! {
            #item
            impl<C: quic_rpc_utils::ChannelTypes<#service_name>> quic_rpc_utils::ServiceHandler<#service_name, C> for #service_name {
                #[track_caller]
                async fn handle_rpc_request(
                    self: std::sync::Arc<Self>,
                    req: #request_name,
                    chan: quic_rpc_utils::RpcChannel<#service_name, C>,
                    rt: &'static quic_rpc_utils::Runtime
                ) -> quic_rpc_utils::Result<()> {
                    #handler_match
                    Ok(())
                }
            }
        }
    } else {
        quote!()
    };

    let client = if meta.client {
        let client_name = Ident::new(&format!("{}Client", service_name_str), Span::call_site());
        let client_methods = func_items
            .iter()
            .map(|(attrs, origin_name, name, arg_names, arg_types, ret, client_stream_item, server_stream_item, server_stream_err_type)| {
                let args2 = arg_names
                    .iter()
                    .zip(arg_types.iter())
                    .map(|(pat, ty)| quote! {#pat: #ty})
                    .collect::<Vec<_>>();
                let req_name = Ident::new(&(name.to_string() + "Request"), Span::call_site());
                let res_name = Ident::new(&(name.to_string() + "Response"), Span::call_site());
                let request = if arg_types.is_empty() {
                    quote! {#req_name}
                } else {
                    quote! {#req_name(#(#arg_names),*)}
                };
                if client_stream_item.is_some() && server_stream_item.is_some() {
                    let server_stream_item = if server_stream_err_type.is_some() {
                        quote! {Result<#server_stream_item, #server_stream_err_type>}
                    } else {
                        quote! {#server_stream_item}
                    };
                    quote! {
                        #(#attrs)*
                        #[track_caller]
                        pub async fn #origin_name(
                            &self,
                            #(#args2),*
                        ) -> Result<(
                            quic_rpc_utils::ClientStreamingResponse<#client_stream_item, #service_name, C, ()>,
                            quic_rpc_utils::ServerStreamingResponse<#server_stream_item>
                        ), quic_rpc_utils::BidiStreamingError<C>> {
                            let (sink, res) = self.client.bidi(#request).await?;
                            let res = quic_rpc_utils::ServerStreamingResponse::new(res.map(|i| match i {
                                Ok(#res_name(i)) => Ok(i),
                                Ok(_) => Err(quic_rpc_utils::Error::msg("Response error.")),
                                Err(e) => Err(quic_rpc_utils::Error::msg(format!("Response error. ({})", e)))
                            }));
                            Ok((quic_rpc_utils::ClientStreamingResponse::new(sink, async {
                                Ok(())
                            }), res))
                        }
                    }
                } else if client_stream_item.is_some() {
                    quote! {
                        #(#attrs)*
                        #[track_caller]
                        pub async fn #origin_name(
                            &self,
                            #(#args2),*
                        ) -> Result<
                            quic_rpc_utils::ClientStreamingResponse<#client_stream_item, #service_name, C, #ret>,
                            quic_rpc_utils::ClientStreamingError<C>
                        > {
                            let (sink, res) = self.client.client_streaming(#request).await?;
                            Ok(quic_rpc_utils::ClientStreamingResponse::new(sink, async move {
                                Ok(res.await?.0)
                            }))
                        }
                    }
                } else if server_stream_item.is_some() {
                    let server_stream_item = if server_stream_err_type.is_some() {
                        quote! {Result<#server_stream_item, #server_stream_err_type>}
                    } else {
                        quote! {#server_stream_item}
                    };
                    quote! {
                        #(#attrs)*
                        #[track_caller]
                        pub async fn #origin_name(
                            &self,
                            #(#args2),*
                        ) -> Result<quic_rpc_utils::ServerStreamingResponse<#server_stream_item>, quic_rpc_utils::ServerStreamingError<C>> {
                            let stream = self.client
                                .server_streaming(#request)
                                .await?
                                .map(|i| match i {
                                    Ok(#res_name(i)) => Ok(i),
                                    Ok(_) => Err(quic_rpc_utils::Error::msg("Response error.")),
                                    Err(e) => Err(quic_rpc_utils::Error::msg(format!("Response error. ({})", e)))
                                });
                            Ok(quic_rpc_utils::ServerStreamingResponse::new(stream))
                        }
                    }
                } else {
                    let (ret, response) = if ret.is_some() {
                        (quote! {#ret}, quote! {Ok(self.client.rpc(#request).await?.0)})
                    } else {
                        (quote! {()}, quote! {
                            self.client.rpc(#request).await?;
                            Ok(())
                        })
                    };
                    quote! {
                        #(#attrs)*
                        #[track_caller]
                        pub async fn #origin_name(&self, #(#args2),*) -> Result<#ret, quic_rpc_utils::RpcError<C>> {
                            #response
                        }
                    }
                }
            })
            .collect::<Vec<_>>();
        // One public field (and its initialiser) per child service; the field name
        // is the child name without a trailing "Service", in snake_case.
        let (client_fields, client_children): (Vec<_>, Vec<_>) = meta
            .services
            .iter()
            .map(|(subname, _)| {
                let name = Ident::new(
                    subname
                        .to_string()
                        .trim_end_matches("Service")
                        .to_snake_case()
                        .as_str(),
                    Span::call_site(),
                );
                let name2 = Ident::new(&(subname.to_string() + "Client"), Span::call_site());
                (
                    quote! {pub #name: #name2},
                    quote! {#name: #name2::new(&client.clone().map().boxed())},
                )
            })
            .unzip();
        quote! {
            #public struct #client_name<C: quic_rpc_utils::Connector<#service_name> = quic_rpc_utils::BoxedConnector<#service_name>> {
                client: quic_rpc_utils::RpcClient<#service_name, C>,
                #(#client_fields),*
            }
            impl<C: quic_rpc_utils::Connector<#service_name>> #client_name<C> {
                pub fn new(client: &quic_rpc_utils::RpcClient<#service_name, C>) -> Self {
                    Self {
                        client: client.clone(),
                        #(#client_children),*
                    }
                }
                #(#client_methods)*
            }
        }
    } else {
        quote!()
    };

    // Per-method request/response structs and their message-trait impls.
    let declared_types = func_items
        .iter()
        .map(
            |(_, _, name, _, arg_types, ret, client_stream_item, server_stream_item, server_stream_err_type)| {
                let req_name = Ident::new(&(name.to_string() + "Request"), Span::call_site());
                let res_name = Ident::new(&(name.to_string() + "Response"), Span::call_site());
                let args = if arg_types.is_empty() {
                    quote!()
                } else {
                    quote! {(#(#arg_types),*)}
                };
                let req_impls = if client_stream_item.is_some() && server_stream_item.is_some() {
                    quote! {
                        impl quic_rpc_utils::Msg<#service_name> for #req_name {
                            type Pattern = quic_rpc_utils::BidiStreaming;
                        }
                        impl quic_rpc_utils::BidiStreamingMsg<#service_name> for #req_name {
                            type Update = #client_stream_item;
                            type Response = #res_name;
                        }
                    }
                } else if client_stream_item.is_some() {
                    quote! {
                        impl quic_rpc_utils::Msg<#service_name> for #req_name {
                            type Pattern = quic_rpc_utils::ClientStreaming;
                        }
                        impl quic_rpc_utils::ClientStreamingMsg<#service_name> for #req_name {
                            type Update = #client_stream_item;
                            type Response = #res_name;
                        }
                    }
                } else if server_stream_item.is_some() {
                    quote! {
                        impl quic_rpc_utils::Msg<#service_name> for #req_name {
                            type Pattern = quic_rpc_utils::ServerStreaming;
                        }
                        impl quic_rpc_utils::ServerStreamingMsg<#service_name> for #req_name {
                            type Response = #res_name;
                        }
                    }
                } else {
                    quote! {
                        impl quic_rpc_utils::RpcMsg<#service_name> for #req_name {
                            type Response = #res_name;
                        }
                    }
                };
                let res_type = if ret.is_none() {
                    quote! {struct #res_name;}
                } else if server_stream_item.is_some() {
                    if server_stream_err_type.is_some() {
                        quote! {struct #res_name (Result<#server_stream_item, #server_stream_err_type>);}
                    } else {
                        quote! {struct #res_name (#server_stream_item);}
                    }
                } else {
                    quote! {struct #res_name (#ret);}
                };
                quote! {
                    #[derive(Debug, serde::Serialize, serde::Deserialize)]
                    struct #req_name #args;
                    #req_impls
                    #[derive(Debug, serde::Serialize, serde::Deserialize)]
                    #res_type
                }
            },
        )
        .collect::<Vec<_>>();

    // Debug output of child handlers stored in fields. Children declared with
    // `= None` have no backing field, so they are skipped (emitting `self.None`
    // would not compile).
    let children_debug = meta
        .services
        .iter()
        .filter(|(_, field)| field.to_string() != "None")
        .map(|(_, field)| quote!(write!(f, "{:?}", self.#field)?;))
        .collect::<Vec<_>>();

    let output = quote! {
        #server
        #client
        #(#declared_types)*
        #[derive(Debug, serde::Serialize, serde::Deserialize, derive_more::From, derive_more::TryInto)]
        #public enum #request_name {
            #(#request_enum_variants),*
        }
        #[derive(Debug, serde::Serialize, serde::Deserialize, derive_more::From, derive_more::TryInto)]
        #public enum #response_name {
            #(#response_enum_variants),*
        }
        impl quic_rpc_utils::RpcMsg<#service_name> for #request_name {
            type Response = #response_name;
        }
        impl quic_rpc_utils::Service for #service_name {
            type Req = #request_name;
            type Res = #response_name;
        }
        impl std::fmt::Debug for #service_name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
                write!(f, "{}(Request:{}, Response:{})\n", #service_name_str, std::mem::size_of::<#request_name>(), std::mem::size_of::<#response_name>())?;
                #(#children_debug)*
                Ok(())
            }
        }
    };
    output.into()
}