use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{parse_macro_input, Attribute, Data, DeriveInput, Expr, ExprLit, Fields, Lit};
/// Settings collected from the `#[rpc(...)]` attribute on a single enum variant.
struct RpcAttr {
    /// Name of the RPC status variant to report (for example `"Timeout"`).
    status: String,
    /// Numeric error code handed to the generated `ServiceError::new` call.
    code: u32,
}

impl Default for RpcAttr {
    /// Values used when a variant has no `#[rpc(...)]` attribute, or omits a key:
    /// status `"Error"` and code `500`.
    fn default() -> Self {
        const FALLBACK_STATUS: &str = "Error";
        const FALLBACK_CODE: u32 = 500;

        RpcAttr {
            status: String::from(FALLBACK_STATUS),
            code: FALLBACK_CODE,
        }
    }
}
fn parse_rpc_attr(attrs: &[Attribute]) -> RpcAttr {
let mut rpc_attr = RpcAttr::default();
for attr in attrs {
if !attr.path().is_ident("rpc") {
continue;
}
let _ = attr.parse_nested_meta(|meta| {
if meta.path.is_ident("status") {
let value: Expr = meta.value()?.parse()?;
if let Expr::Lit(ExprLit {
lit: Lit::Str(lit), ..
}) = value
{
rpc_attr.status = lit.value();
}
} else if meta.path.is_ident("code") {
let value: Expr = meta.value()?.parse()?;
if let Expr::Lit(ExprLit {
lit: Lit::Int(lit), ..
}) = value
{
rpc_attr.code = lit.base10_parse().unwrap_or(500);
}
}
Ok(())
});
}
rpc_attr
}
fn status_to_tokens(status: &str) -> TokenStream2 {
match status {
"Ok" => quote! { synapse_proto::RpcStatus::Ok },
"Error" => quote! { synapse_proto::RpcStatus::Error },
"Timeout" => quote! { synapse_proto::RpcStatus::Timeout },
"InterfaceNotFound" => quote! { synapse_proto::RpcStatus::InterfaceNotFound },
"MethodNotFound" => quote! { synapse_proto::RpcStatus::MethodNotFound },
"Unavailable" => quote! { synapse_proto::RpcStatus::Unavailable },
"InvalidRequest" => quote! { synapse_proto::RpcStatus::InvalidRequest },
"UnsupportedVersion" => quote! { synapse_proto::RpcStatus::UnsupportedVersion },
"ServiceDraining" => quote! { synapse_proto::RpcStatus::ServiceDraining },
"ServiceFrozen" => quote! { synapse_proto::RpcStatus::ServiceFrozen },
"MessageTooLarge" => quote! { synapse_proto::RpcStatus::MessageTooLarge },
_ => quote! { synapse_proto::RpcStatus::Error },
}
}
pub fn derive_rpc_error_inner(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
let name = &input.ident;
let variants = match &input.data {
Data::Enum(data) => &data.variants,
_ => {
return syn::Error::new_spanned(&input, "RpcError can only be derived for enums")
.to_compile_error()
.into();
}
};
let match_arms: Vec<TokenStream2> = variants
.iter()
.map(|variant| {
let variant_name = &variant.ident;
let rpc_attr = parse_rpc_attr(&variant.attrs);
let status = status_to_tokens(&rpc_attr.status);
let code = rpc_attr.code;
let pattern = match &variant.fields {
Fields::Unit => quote! { #name::#variant_name },
Fields::Unnamed(_) => quote! { #name::#variant_name(..) },
Fields::Named(_) => quote! { #name::#variant_name { .. } },
};
quote! {
#pattern => synapse_rpc::ServiceError::new(#status, #code, err.to_string()),
}
})
.collect();
let expanded = quote! {
impl From<#name> for synapse_rpc::ServiceError {
fn from(err: #name) -> synapse_rpc::ServiceError {
use std::string::ToString;
match &err {
#(#match_arms)*
}
}
}
impl synapse_rpc::IntoServiceError for #name {
fn into_service_error(self) -> synapse_rpc::ServiceError {
self.into()
}
}
};
expanded.into()
}