use {
proc_macro::TokenStream,
quote::quote,
syn::{
Expr, ExprLit, ItemImpl, ItemStruct, Lit, Meta, Token, Type,
parse::{Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
},
};
/// Options accepted inside `#[zus_service(...)]`, filled in by the `Parse` impl below.
///
/// Each field stays `None` unless the matching `key = value` pair was written; the attribute
/// macro substitutes its own fallback for anything left unset.
#[derive(Default)]
struct ServiceConfig {
    /// `name = "..."` — service name; falls back to the implementing type's identifier.
    name: Option<String>,
    /// `host = "..."` — bind host for the generated `create_server()`.
    host: Option<String>,
    /// `port = N` — bind port for the generated `create_server()`.
    port: Option<u16>,
    /// `compression = "..."` — exposed as the generated `DEFAULT_COMPRESSION` constant.
    compression: Option<String>,
    /// `max_connections = N` — passed to the server via `with_max_connections`.
    max_connections: Option<usize>,
}
impl Parse for ServiceConfig {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut config = ServiceConfig::default();
if input.is_empty() {
return Ok(config);
}
let args = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
for arg in args {
if let Meta::NameValue(nv) = arg {
let ident = nv
.path
.get_ident()
.ok_or_else(|| syn::Error::new_spanned(&nv.path, "Expected identifier"))?;
match ident.to_string().as_str() {
| "name" => {
if let Expr::Lit(ExprLit {
lit: Lit::Str(lit_str), ..
}) = &nv.value
{
config.name = Some(lit_str.value());
}
}
| "port" => {
if let Expr::Lit(ExprLit {
lit: Lit::Int(lit_int), ..
}) = &nv.value
{
config.port = Some(lit_int.base10_parse()?);
}
}
| "host" => {
if let Expr::Lit(ExprLit {
lit: Lit::Str(lit_str), ..
}) = &nv.value
{
config.host = Some(lit_str.value());
}
}
| "compression" => {
if let Expr::Lit(ExprLit {
lit: Lit::Str(lit_str), ..
}) = &nv.value
{
config.compression = Some(lit_str.value());
}
}
| "max_connections" => {
if let Expr::Lit(ExprLit {
lit: Lit::Int(lit_int), ..
}) = &nv.value
{
config.max_connections = Some(lit_int.base10_parse()?);
}
}
| _ => {
return Err(syn::Error::new_spanned(
&nv.path,
format!("Unknown configuration option: {ident}"),
));
}
}
}
}
Ok(config)
}
}
#[proc_macro_attribute]
pub fn zus_service(args: TokenStream, input: TokenStream) -> TokenStream {
let config = parse_macro_input!(args as ServiceConfig);
let input_impl = parse_macro_input!(input as ItemImpl);
let struct_name = match &*input_impl.self_ty {
| Type::Path(type_path) => &type_path.path.segments.last().unwrap().ident,
| _ => {
return syn::Error::new_spanned(&input_impl.self_ty, "Expected a simple type path")
.to_compile_error()
.into();
}
};
let mut method_arms = Vec::new();
let mut impl_items = Vec::new();
for item in input_impl.items {
if let syn::ImplItem::Fn(mut method) = item {
let mut method_name = None;
method.attrs.retain(|attr| {
if attr.path().is_ident("method") {
if let Ok(Lit::Str(lit_str)) = attr.parse_args::<Lit>() {
method_name = Some(lit_str.value());
}
false } else {
true }
});
if let Some(rpc_method_name) = method_name {
let fn_name = &method.sig.ident;
method_arms.push(quote! {
#rpc_method_name => {
self.#fn_name(params, context).await
}
});
}
impl_items.push(syn::ImplItem::Fn(method));
} else {
impl_items.push(item);
}
}
let service_name = config.name.unwrap_or_else(|| struct_name.to_string());
let self_ty = &input_impl.self_ty;
let generics = &input_impl.generics;
let where_clause = &input_impl.generics.where_clause;
let config_impl = if config.port.is_some()
|| config.host.is_some()
|| config.compression.is_some()
|| config.max_connections.is_some()
{
let port = config.port.unwrap_or(9527);
let host = config.host.unwrap_or_else(|| "0.0.0.0".to_string());
let compression = config.compression.unwrap_or_else(|| "quicklz".to_string());
let max_connections = config.max_connections.unwrap_or(1000);
quote! {
impl #struct_name {
pub const DEFAULT_PORT: u16 = #port;
pub const DEFAULT_HOST: &'static str = #host;
pub const DEFAULT_COMPRESSION: &'static str = #compression;
pub const DEFAULT_MAX_CONNECTIONS: usize = #max_connections;
pub fn create_server() -> zus_rpc_server::ZusServerManager {
zus_rpc_server::ZusServerManager::new(
Self::DEFAULT_HOST.to_string(),
Self::DEFAULT_PORT
)
.with_max_connections(Self::DEFAULT_MAX_CONNECTIONS)
}
}
}
} else {
quote! {}
};
let expanded = quote! {
impl #generics #self_ty #where_clause {
#(#impl_items)*
}
#config_impl
#[async_trait::async_trait]
impl zus_rpc_server::Service for #struct_name {
fn service_name(&self) -> &str {
#service_name
}
async fn do_work(
&self,
method: &str,
params: bytes::Bytes,
context: zus_rpc_server::RequestContext,
) -> zus_common::Result<bytes::Bytes> {
match method {
#(#method_arms)*
_ => Err(zus_common::ZusError::MethodNotFound(
format!("Unknown method: {}", method)
))
}
}
}
};
TokenStream::from(expanded)
}
/// Marker attribute for RPC handler methods: `#[method("rpc_name")]`.
///
/// Inside a `#[zus_service]` impl block these attributes are read and stripped by
/// `zus_service` itself. If this macro is invoked directly, it ignores its arguments and hands
/// the annotated item back untouched.
#[proc_macro_attribute]
pub fn method(_attr: TokenStream, item: TokenStream) -> TokenStream {
    item
}
#[proc_macro_derive(ZusService, attributes(service_name))]
pub fn derive_zus_service(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as ItemStruct);
let name = &input.ident;
let service_name = input
.attrs
.iter()
.find(|attr| attr.path().is_ident("service_name"))
.and_then(|attr| {
if let Meta::NameValue(ref meta) = attr.meta
&& let Expr::Lit(ExprLit {
lit: Lit::Str(ref lit_str),
..
}) = meta.value
{
return Some(lit_str.value());
}
None
})
.unwrap_or_else(|| name.to_string());
let expanded = quote! {
impl #name {
pub const SERVICE_NAME: &'static str = #service_name;
pub const IS_ZUS_SERVICE: bool = true;
}
};
TokenStream::from(expanded)
}
#[proc_macro]
pub fn zus_server(_input: TokenStream) -> TokenStream {
let expanded = quote! {
compile_error!("zus_server! macro not yet fully implemented. Use ZusServerManager directly.");
};
TokenStream::from(expanded)
}