use proc_macro::TokenStream;
use quote::quote;
use syn::{
FnArg, Ident, ItemTrait, Pat, PatType, Result, ReturnType, TraitItem,
parse::{Parse, ParseStream},
parse_macro_input,
};
/// Arguments accepted by the `endpoint_async` attribute: a single identifier
/// that names the client struct to generate.
struct EndpointAsyncArgs {
    /// Identifier used as the name of the generated client struct.
    client_name: Ident,
}

impl Parse for EndpointAsyncArgs {
    /// Reads exactly one identifier from the attribute's argument stream.
    ///
    /// Any parse failure is returned unchanged as a `syn` error.
    fn parse(input: ParseStream) -> Result<Self> {
        input
            .parse::<Ident>()
            .map(|client_name| Self { client_name })
    }
}
pub fn endpoint_async(args: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(args as EndpointAsyncArgs);
let input_trait = parse_macro_input!(input as ItemTrait);
let client_name = args.client_name;
let has_async_trait = input_trait
.attrs
.iter()
.any(|attr| attr.path.segments.iter().any(|segment| segment.ident == "async_trait"));
let client_struct = generate_client_struct(&client_name);
let client_impl = generate_client_impl(&client_name);
let trait_impl = generate_trait_impl(&client_name, &input_trait, has_async_trait);
let expanded = quote! {
#input_trait
#client_struct
#client_impl
#trait_impl
};
TokenStream::from(expanded)
}
/// Extracts `T` from a type of the form `impl Future<Output = T>`.
///
/// Only `impl Trait` types are inspected. Among their trait bounds, the ones
/// whose last path segment is `Future` are searched (in declaration order) for
/// an `Output = ...` associated-type binding; the first one found is returned.
///
/// Returns `None` when `ty` is not an `impl Trait` type or no such binding
/// exists.
fn get_result_type_from_future(ty: &syn::Type) -> Option<&syn::Type> {
    let type_impl = match ty {
        syn::Type::ImplTrait(type_impl) => type_impl,
        _ => return None,
    };

    type_impl
        .bounds
        .iter()
        // Keep the last path segment of every trait bound (lifetimes are skipped).
        .filter_map(|bound| match bound {
            syn::TypeParamBound::Trait(trait_bound) => trait_bound.path.segments.last(),
            _ => None,
        })
        .filter(|segment| segment.ident == "Future")
        // Only `Future<...>` with angle-bracketed arguments can carry a binding.
        .filter_map(|segment| match &segment.arguments {
            syn::PathArguments::AngleBracketed(generics) => Some(generics),
            _ => None,
        })
        .flat_map(|generics| generics.args.iter())
        .find_map(|arg| match arg {
            syn::GenericArgument::Binding(binding) if binding.ident == "Output" => {
                Some(&binding.ty)
            }
            _ => None,
        })
}
fn check_return_type(return_type: &syn::ReturnType, returns_impl_future: bool) -> bool {
if returns_impl_future {
if let syn::ReturnType::Type(_, ty) = return_type {
if let Some(output_type) = get_result_type_from_future(ty) {
return check_result_type(output_type);
}
}
false
} else {
if let syn::ReturnType::Type(_, ty) = return_type {
return check_result_type(ty);
}
false
}
}
/// Checks that `ty` is syntactically `Result<_, RpcError...>`.
///
/// The type must be a path whose last segment is `Result` with exactly two
/// angle-bracketed generic arguments, and the second argument must itself be a
/// path type whose last segment is `RpcError`. Only identifiers are compared,
/// so any leading path (`std::result::Result`, `foo::RpcError`) is accepted.
fn check_result_type(ty: &syn::Type) -> bool {
    let result_segment = match ty {
        syn::Type::Path(type_path) => match type_path.path.segments.last() {
            Some(segment) if segment.ident == "Result" => segment,
            _ => return false,
        },
        _ => return false,
    };

    let generics = match &result_segment.arguments {
        syn::PathArguments::AngleBracketed(generics) if generics.args.len() == 2 => generics,
        _ => return false,
    };

    // With exactly two arguments, the last one is the error type.
    match generics.args.last() {
        Some(syn::GenericArgument::Type(syn::Type::Path(error_type))) => error_type
            .path
            .segments
            .last()
            .map_or(false, |error_segment| error_segment.ident == "RpcError"),
        _ => false,
    }
}
/// Emits the definition of the generated client struct.
///
/// The struct is `pub`, derives `Clone`, is generic over a caller type `C`
/// and holds a single public field `endpoint` of type
/// `occams_rpc::client::AsyncEndpoint<C>`.
fn generate_client_struct(client_name: &Ident) -> proc_macro2::TokenStream {
    // Bounds required of the caller type parameter.
    let caller_bounds = quote! {
        C: occams_rpc::client::ClientCaller,
        C: Clone,
        C: Sync,
        C: 'static,
        C::Facts: occams_rpc::client::ClientFacts<Task = occams_rpc::client::task::APIClientReq>,
    };
    let endpoint_field = quote! {
        pub endpoint: occams_rpc::client::AsyncEndpoint<C>,
    };

    quote! {
        #[derive(Clone)]
        pub struct #client_name<C>
        where
            #caller_bounds
        {
            #endpoint_field
        }
    }
}
/// Emits the inherent `impl` block of the generated client struct.
///
/// The block contains a single constructor, `pub fn new(caller: C) -> Self`,
/// which wraps `caller` in `occams_rpc::client::AsyncEndpoint::new`.
fn generate_client_impl(client_name: &Ident) -> proc_macro2::TokenStream {
    // Bounds required of the caller type parameter.
    let caller_bounds = quote! {
        C: occams_rpc::client::ClientCaller,
        C: Clone + Sync + 'static,
        C::Facts: occams_rpc::client::ClientFacts<Task = occams_rpc::client::task::APIClientReq>,
    };

    quote! {
        impl<C> #client_name<C>
        where
            #caller_bounds
        {
            pub fn new(caller: C) -> Self {
                Self {
                    endpoint: occams_rpc::client::AsyncEndpoint::new(caller),
                }
            }
        }
    }
}
fn generate_trait_impl(
client_name: &Ident, input_trait: &ItemTrait, has_async_trait: bool,
) -> proc_macro2::TokenStream {
let trait_name = &input_trait.ident;
let trait_items = &input_trait.items;
let mut impl_methods = Vec::new();
for item in trait_items {
if let TraitItem::Method(method) = item {
let method_sig = &method.sig;
let method_name = &method_sig.ident;
let (arg_name, arg_type) = if method_sig.inputs.len() == 1 {
panic!(
"Method `{}` has no parameters. Endpoint methods must have exactly one parameter besides &self.",
method_name
);
} else if method_sig.inputs.len() == 2 {
if let FnArg::Typed(PatType { pat, ty, .. }) = &method_sig.inputs[1] {
if let Pat::Ident(pat_ident) = pat.as_ref() {
(Some(pat_ident.ident.clone()), Some(ty.as_ref()))
} else {
(None, None)
}
} else {
(None, None)
}
} else {
panic!(
"Method `{}` methods must have exactly one parameter besides &self.",
method_name
);
};
let service_method = format!("{}.{}", trait_name, method_name);
let return_type = &method_sig.output;
let returns_impl_future = if let ReturnType::Type(_, ty) = return_type {
get_result_type_from_future(ty).is_some()
} else {
false
};
let is_async_method = method_sig.asyncness.is_some() || returns_impl_future;
if !check_return_type(return_type, returns_impl_future) {
panic!(
"Method `{}` has invalid return type. Endpoint methods must return `Result<_, RpcError<_>>`.",
method_name
);
}
let method_impl = if let Some(arg_type) = arg_type {
let arg_name = arg_name.unwrap();
if is_async_method {
if returns_impl_future {
quote! {
fn #method_name(&self, #arg_name: #arg_type) #return_type {
async move {
self.endpoint.call(#service_method, &#arg_name).await
}
}
}
} else {
quote! {
async fn #method_name(&self, #arg_name: #arg_type) #return_type {
self.endpoint.call(#service_method, &#arg_name).await
}
}
}
} else {
quote! {
async fn #method_name(&self, #arg_name: #arg_type) #return_type {
self.endpoint.call(#service_method, &#arg_name).await
}
}
}
} else {
if is_async_method {
if returns_impl_future {
quote! {
fn #method_name(&self) #return_type {
async move {
self.endpoint.call(#service_method, &()).await
}
}
}
} else {
quote! {
async fn #method_name(&self) #return_type {
self.endpoint.call(#service_method, &()).await
}
}
}
} else {
quote! {
async fn #method_name(&self) #return_type {
self.endpoint.call(#service_method, &()).await
}
}
}
};
impl_methods.push(method_impl);
}
}
if has_async_trait {
quote! {
#[::async_trait::async_trait]
impl<C> #trait_name for #client_name<C>
where
C: occams_rpc::client::ClientCaller,
C: Clone,
C: Sync,
C: 'static,
C::Facts: occams_rpc::client::ClientFacts<Task = occams_rpc::client::task::APIClientReq>,
{
#(#impl_methods)*
}
}
} else {
quote! {
impl<C> #trait_name for #client_name<C>
where
C: occams_rpc::client::ClientCaller,
C: Clone,
C: Sync,
C: 'static,
C::Facts: occams_rpc::client::ClientFacts<Task = occams_rpc::client::task::APIClientReq>,
{
#(#impl_methods)*
}
}
}
}
// Empty placeholder that is never called (hence `allow(dead_code)`) and is
// hidden from rustdoc output.
// NOTE(review): the name suggests it anchors a `compile_fail` doctest for
// methods with too many arguments; no doc comment is visible here — confirm.
#[doc(hidden)]
#[allow(dead_code)]
fn test_too_many_arguments_compile_fail() {}
// Empty placeholder that is never called (hence `allow(dead_code)`) and is
// hidden from rustdoc output.
// NOTE(review): the name suggests it anchors a `compile_fail` doctest for
// methods with an invalid return type; no doc comment is visible here — confirm.
#[doc(hidden)]
#[allow(dead_code)]
fn test_invalid_return_type_compile_fail() {}
// Empty placeholder that is never called (hence `allow(dead_code)`) and is
// hidden from rustdoc output.
// NOTE(review): the name suggests it anchors a `compile_fail` doctest for
// non-async trait methods; no doc comment is visible here — confirm.
#[doc(hidden)]
#[allow(dead_code)]
fn test_sync_method_compile_fail() {}
// Empty placeholder that is never called (hence `allow(dead_code)`) and is
// hidden from rustdoc output.
// NOTE(review): the name suggests it anchors a `compile_fail` doctest for
// methods without a parameter besides the receiver; no doc comment is visible
// here — confirm.
#[doc(hidden)]
#[allow(dead_code)]
fn test_zero_param_compile_fail() {}