use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::{parse_macro_input, spanned::Spanned, ItemTrait, ReturnType, TraitItem, Type};
#[proc_macro_attribute]
pub fn nanorpc_derive(_: TokenStream, input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as ItemTrait);
let input_again = input.clone();
let protocol_name = input.ident;
if !protocol_name.to_string().ends_with("Protocol") {
panic!("trait must end with the word \"Protocol\"")
}
let server_struct_name = syn::Ident::new(
&format!(
"{}Service",
protocol_name.to_string().trim_end_matches("Protocol")
),
protocol_name.span(),
);
let client_struct_name = syn::Ident::new(
&format!(
"{}Client",
protocol_name.to_string().trim_end_matches("Protocol")
),
protocol_name.span(),
);
let error_struct_name = syn::Ident::new(
&format!(
"{}Error",
protocol_name.to_string().trim_end_matches("Protocol")
),
protocol_name.span(),
);
let mut server_match = quote! {};
let mut client_body = quote! {};
for item in input.items {
match item {
TraitItem::Method(inner) => {
let method_name = inner.sig.ident.clone();
let mut offset = 0;
let method_call = inner
.sig
.inputs
.iter()
.enumerate()
.map(|(idx, arg)| match arg {
syn::FnArg::Receiver(_) => {
offset += 1;
quote! {&self.0}
}
syn::FnArg::Typed(_) => {
let index = idx - offset;
quote! {if let ::std::option::Option::Some(::std::result::Result::Ok(v)) = __nrpc_args.get(#index).map(|v|::serde_json::from_value(v.clone())) {v} else {
return Some(
::std::result::Result::Err(nanorpc::ServerError{
code: 1,
message: format!("deserialization of argument {} failed", #index),
details: ::serde_json::Value::Null
})
)
}}
}
})
.reduce(|a, b| quote! {#a,#b})
.unwrap();
let method_name_str = method_name.to_string();
let is_fallible = inner
.sig
.output
.to_token_stream()
.to_string()
.contains("Result");
if is_fallible {
server_match = quote! {
#server_match
#method_name_str => {
let raw = #protocol_name::#method_name(#method_call).await;
let ok_mapped = raw.map(|o| ::serde_json::to_value(o).expect("serialization failed"));
let err_mapped = ok_mapped.map_err(|e| nanorpc::ServerError{
code: 1,
message: e.to_string(),
details: ::serde_json::to_value(e).expect("serialization failed")
});
::std::option::Option::Some(err_mapped)
}
};
} else {
server_match = quote! {
#server_match
#method_name_str => {
::std::option::Option::Some(::std::result::Result::Ok(::serde_json::to_value(#protocol_name::#method_name(#method_call).await).expect("serialization failed")))
}
};
}
let mut client_signature = inner.sig.clone();
let original_output = match &client_signature.output {
ReturnType::Default => quote! {()},
ReturnType::Type(_, t) => t.to_token_stream(),
};
client_signature.output = ReturnType::Type(
syn::Token! [->](client_signature.span()),
Box::new(Type::Verbatim(
quote! {::std::result::Result<#original_output, #error_struct_name<__nrpc_T::Error>>},
)),
);
let vec_build = client_signature
.inputs
.iter()
.filter_map(|arg| match arg {
syn::FnArg::Receiver(_) => None,
syn::FnArg::Typed(t) => match t.pat.as_ref() {
syn::Pat::Ident(varname) => {
Some(quote! {__vb.push(::serde_json::to_value(&#varname).unwrap())})
}
v => panic!("wild {:?}", v.to_token_stream()),
},
})
.fold(
quote! {
let mut __vb: ::std::vec::Vec<::serde_json::Value> = ::std::vec::Vec::with_capacity(8);
},
|a, b| quote! {#a; #b},
);
let method_name = client_signature.ident.to_string();
let return_handler = if is_fallible {
quote! {
match jsval {
Ok(jsval) => {
let retval = ::serde_json::from_value(jsval).map_err(#error_struct_name::FailedDecode)?;
Ok(Ok(retval))
}
Err(serverr) => {
Ok(Err(::serde_json::from_value(serverr.details).map_err(#error_struct_name::FailedDecode)?))
}
}
}
} else {
quote! {
match jsval {
Ok(jsval) => {
let retval: #original_output = ::serde_json::from_value(jsval).map_err(#error_struct_name::FailedDecode)?;
Ok(retval)
}
Err(serverr) => {
Err(#error_struct_name::ServerFail)
}
}
}
};
client_body = quote! {
#client_body
pub #client_signature {
#vec_build;
let result = nanorpc::RpcTransport::call(&self.0, #method_name, &__vb).await.map_err(#error_struct_name::Transport)?;
match result {
None => Err(#error_struct_name::NotFound),
Some(jsval) => {
#return_handler
}
}
}
}
}
_ => {
panic!("does not support things other than methods in the trait definition")
}
}
}
let client_type_comment = format!("Automatically generated client type that communicates to servers implementing the [{protocol_name}] protocol. The easiest way to use this is by using the `From<RpcTransport>` implementation. \n\nSee the [{protocol_name}] trait for further documentation on the functionality of the methods..");
let client_impl = quote! {
#[doc=#client_type_comment]
pub struct #client_struct_name<T: nanorpc::RpcTransport = nanorpc::DynRpcTransport>(pub T);
impl<T: nanorpc::RpcTransport> ::std::convert::From<T> for #client_struct_name
where
T::Error: Into<::anyhow::Error> {
fn from(transport: T) -> Self {
Self(nanorpc::DynRpcTransport::new(transport))
}
}
impl <__nrpc_T: nanorpc::RpcTransport + Send + Sync + 'static> #client_struct_name<__nrpc_T> {
#client_body
}
};
let error_type_comment = format!("Automatically generated error type that {client_struct_name} instances return from its methods");
let server_type_comment = format!("Automatically generated struct that wraps any 'business logic' struct implementing [{protocol_name}], and returns a JSON-RPC server implementing [nanorpc::RpcService]. See the [{protocol_name}] trait for further documentation.");
let assembled = quote! {
#input_again
#[doc=#server_type_comment]
pub struct #server_struct_name<T: #protocol_name>(pub T);
#[::async_trait::async_trait]
impl <__nrpc_T: #protocol_name + ::std::marker::Sync + ::std::marker::Send + 'static> nanorpc::RpcService for #server_struct_name<__nrpc_T> {
async fn respond(&self, __nrpc_method: &str, __nrpc_args: Vec<::serde_json::Value>) -> Option<Result<::serde_json::Value, nanorpc::ServerError>> {
match __nrpc_method {
#server_match
_ => {None}
}
}
}
#[derive(::thiserror::Error, Debug)]
#[doc=#error_type_comment]
pub enum #error_struct_name<T> {
#[error("verb not found")]
NotFound,
#[error("unexpected server error on an infallible verb")]
ServerFail,
#[error("failed to decode JSON response: {0:?}")]
FailedDecode(::serde_json::Error),
#[error("transport-level error: {0:?}")]
Transport(T)
}
#client_impl
};
assembled.into()
}