extern crate proc_macro;
use inflector::Inflector as _;
use proc_macro::TokenStream;
use quote::{quote, quote_spanned};
use std::collections::HashSet;
use syn::spanned::Spanned as _;
mod api_def;
#[proc_macro]
pub fn rpc_api(input_token_stream: TokenStream) -> TokenStream {
let defs: api_def::ApiDefinitions = match syn::parse(input_token_stream) {
Ok(d) => d,
Err(err) => return err.to_compile_error().into(),
};
let mut out = Vec::with_capacity(defs.apis.len());
for api in defs.apis {
match build_api(api) {
Ok(a) => out.push(a),
Err(err) => return err.to_compile_error().into(),
};
}
TokenStream::from(quote! {
#(#out)*
})
}
fn build_api(api: api_def::ApiDefinition) -> Result<proc_macro2::TokenStream, syn::Error> {
let enum_name = &api.name;
let mut tweaked_generics = api.generics.clone();
tweaked_generics.params.insert(
0,
From::from(syn::LifetimeDef::new(
syn::parse_str::<syn::Lifetime>("'a").unwrap(),
)),
);
tweaked_generics
.params
.push(From::from(syn::TypeParam::from(
syn::parse_str::<syn::Ident>("R").unwrap(),
)));
tweaked_generics
.params
.push(From::from(syn::TypeParam::from(
syn::parse_str::<syn::Ident>("I").unwrap(),
)));
let (impl_generics, ty_generics, where_clause) = tweaked_generics.split_for_impl();
let generics = api
.generics
.params
.iter()
.filter_map(|gp| {
if let syn::GenericParam::Type(tp) = gp {
Some(tp.ident.clone())
} else {
None
}
})
.collect::<HashSet<_>>();
let visibility = &api.visibility;
let mut variants = Vec::new();
let mut tmp_variants = Vec::new();
for function in &api.definitions {
let function_is_notification = function.is_void_ret_type();
let variant_name = snake_case_to_camel_case(&function.signature.ident);
let ret = match &function.signature.output {
syn::ReturnType::Default => quote! {()},
syn::ReturnType::Type(_, ty) => quote_spanned!(ty.span()=> #ty),
};
let mut params_list = Vec::new();
for input in function.signature.inputs.iter() {
let (ty, pat_span, param_variant_name) = match input {
syn::FnArg::Receiver(_) => {
return Err(syn::Error::new(
input.span(),
"Having `self` is not allowed in RPC queries definitions",
));
}
syn::FnArg::Typed(syn::PatType { ty, pat, .. }) => {
(ty, pat.span(), param_variant_name(&pat)?)
}
};
params_list.push(quote_spanned!(pat_span=> #param_variant_name: #ty));
}
if !function_is_notification {
if params_list.is_empty() {
tmp_variants.push(quote_spanned!(function.signature.ident.span()=> #variant_name));
} else {
tmp_variants.push(quote_spanned!(function.signature.ident.span()=>
#variant_name {
#(#params_list,)*
}
));
}
}
if function_is_notification {
variants.push(quote_spanned!(function.signature.ident.span()=>
#variant_name {
#(#params_list,)*
}
));
} else {
variants.push(quote_spanned!(function.signature.ident.span()=>
#variant_name {
respond: jsonrpsee::raw::server::TypedResponder<'a, R, I, #ret>,
#(#params_list,)*
}
));
}
}
let next_request = {
let mut notifications_blocks = Vec::new();
let mut function_blocks = Vec::new();
let mut tmp_to_rq = Vec::new();
struct GenericParams {
generics: HashSet<syn::Ident>,
types: HashSet<syn::Ident>,
}
impl<'ast> syn::visit::Visit<'ast> for GenericParams {
fn visit_ident(&mut self, ident: &'ast syn::Ident) {
if self.generics.contains(ident) {
self.types.insert(ident.clone());
}
}
}
let mut generic_params = GenericParams {
generics,
types: HashSet::new(),
};
for function in &api.definitions {
let function_is_notification = function.is_void_ret_type();
let variant_name = snake_case_to_camel_case(&function.signature.ident);
let rpc_method_name = function
.attributes
.method
.clone()
.unwrap_or_else(|| function.signature.ident.to_string());
let mut params_builders = Vec::new();
let mut params_names_list = Vec::new();
for input in function.signature.inputs.iter() {
let (ty, param_variant_name, rpc_param_name) = match input {
syn::FnArg::Receiver(_) => {
return Err(syn::Error::new(
input.span(),
"Having `self` is not allowed in RPC queries definitions",
));
}
syn::FnArg::Typed(syn::PatType { ty, pat, attrs, .. }) => {
(ty, param_variant_name(&pat)?, rpc_param_name(&pat, &attrs)?)
}
};
syn::visit::visit_type(&mut generic_params, &ty);
params_names_list
.push(quote_spanned!(function.signature.span()=> #param_variant_name));
if !function_is_notification {
params_builders.push(quote_spanned!(function.signature.span()=>
let #param_variant_name: #ty = {
match request.params().get(#rpc_param_name) {
Ok(v) => v,
Err(_) => {
request.respond(Err(jsonrpsee::common::Error::invalid_params(#rpc_param_name))).await;
continue;
}
}
};
));
} else {
params_builders.push(quote_spanned!(function.signature.span()=>
let #param_variant_name: #ty = {
match request.params().get(#rpc_param_name) {
Ok(v) => v,
Err(_) => {
continue;
}
}
};
));
}
}
if function_is_notification {
notifications_blocks.push(quote_spanned!(function.signature.span()=>
if method == #rpc_method_name {
let request = n;
#(#params_builders)*
return Ok(#enum_name::#variant_name { #(#params_names_list),* });
}
));
} else {
function_blocks.push(quote_spanned!(function.signature.span()=>
if request_outcome.is_none() && method == #rpc_method_name {
let request = server.request_by_id(&request_id).unwrap();
#(#params_builders)*
request_outcome = Some(Tmp::#variant_name { #(#params_names_list),* });
}
));
tmp_to_rq.push(quote_spanned!(function.signature.span()=>
Some(Tmp::#variant_name { #(#params_names_list),* }) => {
let request = server.request_by_id(&request_id).unwrap();
let respond = jsonrpsee::raw::server::TypedResponder::from(request);
return Ok(#enum_name::#variant_name { respond #(, #params_names_list)* });
},
));
}
}
let params_tys = generic_params.types.iter();
let tmp_generics = if generic_params.types.is_empty() {
quote!()
} else {
quote_spanned!(api.name.span()=>
<#(#params_tys,)*>
)
};
let on_request = quote_spanned!(api.name.span()=> {
#[allow(unused)]
enum Tmp #tmp_generics {
#(#tmp_variants,)*
}
let request_id = r.id();
let method = r.method().to_owned();
let mut request_outcome: Option<Tmp #tmp_generics> = None;
#(#function_blocks)*
match request_outcome {
#(#tmp_to_rq)*
None => server.request_by_id(&request_id).unwrap().respond(Err(jsonrpsee::common::Error::method_not_found())).await,
}
});
let on_notification = quote_spanned!(api.name.span()=> {
let method = n.method().to_owned();
#(#notifications_blocks)*
});
let params_tys = generic_params.types.iter();
quote_spanned!(api.name.span()=>
#visibility async fn next_request(server: &'a mut jsonrpsee::raw::RawServer<R, I>) -> core::result::Result<#enum_name #ty_generics, std::io::Error>
where
R: jsonrpsee::transport::TransportServer<RequestId = I>,
I: Clone + PartialEq + Eq + std::hash::Hash + Send + Sync
#(, #params_tys: jsonrpsee::common::DeserializeOwned)*
{
loop {
match server.next_event().await {
jsonrpsee::raw::RawServerEvent::Notification(n) => #on_notification,
jsonrpsee::raw::RawServerEvent::SubscriptionsClosed(_) => unimplemented!(),
jsonrpsee::raw::RawServerEvent::SubscriptionsReady(_) => unimplemented!(),
jsonrpsee::raw::RawServerEvent::Request(r) => #on_request,
}
}
}
)
};
let client_impl_block = build_client_impl(&api)?;
let debug_variants = build_debug_variants(&api)?;
Ok(quote_spanned!(api.name.span()=>
#visibility enum #enum_name #tweaked_generics {
#(#variants),*
}
impl #impl_generics #enum_name #ty_generics #where_clause {
#next_request
}
#client_impl_block
impl #impl_generics std::fmt::Debug for #enum_name #ty_generics #where_clause {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
#(#debug_variants,)*
}
}
}
))
}
fn build_client_impl(api: &api_def::ApiDefinition) -> Result<proc_macro2::TokenStream, syn::Error> {
let enum_name = &api.name;
let (impl_generics_org, _, where_clause_org) = api.generics.split_for_impl();
let lifetimes_org = api.generics.lifetimes();
let type_params_org = api.generics.type_params();
let const_params_org = api.generics.const_params();
let client_functions = build_client_functions(&api)?;
Ok(quote_spanned!(api.name.span()=>
impl #impl_generics_org #enum_name<'static #(, #lifetimes_org)* #(, #type_params_org)* #(, #const_params_org)*, (), ()>
#where_clause_org
{
#(#client_functions)*
}
))
}
/// Generates one `async fn` per RPC definition. Each function sends its call
/// through a `jsonrpsee::raw::RawClient<C>`.
///
/// - Definitions whose `is_void_ret_type()` is true are sent with
///   `send_notification` and resolve to `Ok(())`.
/// - All others are sent with `start_request`; the function awaits the
///   response and decodes it into the declared return type with
///   `jsonrpsee::common::from_value`.
///
/// Parameters are encoded as follows:
///
/// - no parameters → `Params::None`;
/// - the definition's `positional_params` attribute is set → `Params::Array`;
/// - otherwise → `Params::Map`, keyed by `rpc_param_name`.
///
/// # Errors
///
/// Returns a `syn::Error` if a definition takes `self` or if `rpc_param_name`
/// rejects a parameter pattern.
fn build_client_functions(
    api: &api_def::ApiDefinition,
) -> Result<Vec<proc_macro2::TokenStream>, syn::Error> {
    let visibility = &api.visibility;
    let mut generated_fns = Vec::new();

    for definition in &api.definitions {
        let signature = &definition.signature;
        let f_name = &signature.ident;

        // `()` when no return type is written, otherwise the declared type.
        let ret_ty = if let syn::ReturnType::Type(_, ty) = &signature.output {
            quote_spanned!(ty.span()=> #ty)
        } else {
            quote!(())
        };

        // Method name sent on the wire: the explicit `method` attribute if
        // present, otherwise the Rust function name.
        let rpc_method_name = match &definition.attributes.method {
            Some(explicit) => explicit.clone(),
            None => f_name.to_string(),
        };

        let mut params_list = Vec::new();
        let mut params_to_json = Vec::new();
        let mut params_to_array = Vec::new();
        let mut params_tys = Vec::new();

        for (position, input) in signature.inputs.iter().enumerate() {
            let pat_type = match input {
                syn::FnArg::Typed(pat_type) => pat_type,
                syn::FnArg::Receiver(_) => {
                    return Err(syn::Error::new(
                        input.span(),
                        "Having `self` is not allowed in RPC queries definitions",
                    ));
                }
            };
            let ty = &pat_type.ty;
            let pat_span = pat_type.pat.span();
            let json_key = rpc_param_name(&pat_type.pat, &pat_type.attrs)?;
            // Client arguments get generated positional names. The user's
            // patterns only supply spans and JSON keys.
            let arg_name = syn::Ident::new(
                &format!("param{}", position),
                proc_macro2::Span::call_site(),
            );

            params_tys.push(ty);
            params_list.push(quote_spanned!(pat_span=> #arg_name: impl Into<#ty>));
            params_to_json.push(quote_spanned!(pat_span=>
                map.insert(
                    #json_key.to_string(),
                    jsonrpsee::common::to_value(#arg_name.into()).unwrap()
                );
            ));
            params_to_array.push(quote_spanned!(pat_span=>
                jsonrpsee::common::to_value(#arg_name.into()).unwrap()
            ));
        }

        let params_building = match (params_list.is_empty(), definition.attributes.positional_params) {
            (true, _) => quote! {jsonrpsee::common::Params::None},
            (false, true) => quote_spanned!(signature.span()=>
                jsonrpsee::common::Params::Array(vec![
                    #(#params_to_array),*
                ])
            ),
            (false, false) => {
                let params_list_len = params_list.len();
                quote_spanned!(signature.span()=>
                    jsonrpsee::common::Params::Map({
                        let mut map = jsonrpsee::common::JsonMap::with_capacity(#params_list_len);
                        #(#params_to_json)*
                        map
                    })
                )
            }
        };

        let function_body = if definition.is_void_ret_type() {
            quote_spanned!(signature.span()=>
                client.send_notification(#rpc_method_name, #params_building).await
                    .map_err(jsonrpsee::raw::client::RawClientError::Inner)?;
                Ok(())
            )
        } else {
            quote_spanned!(signature.span()=>
                let rq_id = client.start_request(#rpc_method_name, #params_building).await
                    .map_err(jsonrpsee::raw::client::RawClientError::Inner)?;
                let data = client.request_by_id(rq_id).unwrap().await?;
                Ok(jsonrpsee::common::from_value(data).unwrap())
            )
        };

        generated_fns.push(quote_spanned!(signature.span()=>
            #visibility async fn #f_name<C: jsonrpsee::transport::TransportClient>(client: &mut jsonrpsee::raw::RawClient<C> #(, #params_list)*)
                -> core::result::Result<#ret_ty, jsonrpsee::raw::client::RawClientError<<C as jsonrpsee::transport::TransportClient>::Error>>
            where
                #ret_ty: jsonrpsee::common::DeserializeOwned
                #(, #params_tys: jsonrpsee::common::Serialize)*
            {
                #function_body
            }
        ));
    }

    Ok(generated_fns)
}
/// Builds the arms of the `match self` inside the generated `Debug` impl.
///
/// Each arm ignores the variant's fields. The generated impl adds no `Debug`
/// bounds, so the fields cannot be formatted. The arm prints `Enum::Variant`,
/// which tells the reader which RPC the value represents. Previously every arm
/// printed only the enum name, making all variants indistinguishable.
///
/// # Errors
///
/// Currently never fails; the `Result` is kept for consistency with the other
/// builders.
fn build_debug_variants(
    api: &api_def::ApiDefinition,
) -> Result<Vec<proc_macro2::TokenStream>, syn::Error> {
    let enum_name = &api.name;
    let mut debug_variants = Vec::new();
    for function in &api.definitions {
        let variant_name = snake_case_to_camel_case(&function.signature.ident);
        debug_variants.push(quote_spanned!(function.signature.ident.span()=>
            #enum_name::#variant_name { .. } => {
                f.debug_struct(concat!(stringify!(#enum_name), "::", stringify!(#variant_name)))
                    .finish()
            }
        ));
    }
    Ok(debug_variants)
}
/// Converts a function identifier into the PascalCase identifier used for the
/// matching enum variant (despite the name, the conversion is `to_pascal_case`).
/// The original span is kept, so diagnostics point at the function name.
fn snake_case_to_camel_case(snake_case: &syn::Ident) -> syn::Ident {
    let span = snake_case.span();
    let pascal = snake_case.to_string().to_pascal_case();
    syn::Ident::new(&pascal, span)
}
fn param_variant_name(pat: &syn::Pat) -> syn::parse::Result<&syn::Ident> {
match pat {
syn::Pat::Ident(ident) => Ok(&ident.ident),
_ => unimplemented!(),
}
}
fn rpc_param_name(pat: &syn::Pat, attrs: &[syn::Attribute]) -> syn::parse::Result<String> {
match pat {
syn::Pat::Ident(ident) => Ok(ident.ident.to_string()),
_ => unimplemented!(),
}
}