use proc_macro2::TokenStream;
use quote::quote;
use std::collections::HashMap;
use syn::{Error, FnArg, ItemTrait, Lit, Pat, TraitItem, parse::Parser, parse2};
mod chain_extension;
mod chain_method;
mod method_impl;
mod types;
mod utils;
use chain_extension::generate_chain_extension_method;
use chain_method::generate_chain_method;
use method_impl::generate_method_impl;
use types::MethodAttrs;
use utils::{ParamAttrs, build_combined_where_clause, extract_param_attrs};
/// Expands the `#[proxy(...)]` attribute applied to a trait.
///
/// All real work happens in `proxy_impl`; any error it reports is converted into
/// `compile_error!` tokens so the user sees a regular compiler diagnostic instead of a
/// macro panic.
pub(crate) fn proxy(attr: TokenStream, input: TokenStream) -> TokenStream {
    proxy_impl(attr, input).unwrap_or_else(|failure| failure.to_compile_error())
}
/// Parses the annotated trait and produces the complete macro expansion.
///
/// For every method of the trait this:
/// 1. strips and parses the method-level attributes (`MethodAttrs::extract`),
/// 2. strips and parses the attributes of every typed parameter with a plain identifier
///    pattern, skipping the first input (presumably the `self` receiver — the generators
///    are assumed to rely on this),
/// 3. collects the chain-extension method/impl, the `Connection` method impl and the
///    chain method trait item/impl from the generator helpers (empty outputs are dropped).
///
/// The rewritten trait, its `Connection<S>` impl and the optional chain-extension trait
/// are emitted together.
///
/// # Errors
/// Propagates any error from parsing the input or the attribute arguments, from trait
/// validation, from attribute extraction or from the code generators.
fn proxy_impl(attr: TokenStream, input: TokenStream) -> Result<TokenStream, Error> {
    let mut trait_def = parse2::<ItemTrait>(input)?;
    let (interface_name, crate_path, chain_name) = parse_proxy_attributes(&attr, &trait_def)?;
    validate_trait(&trait_def)?;

    let mut methods = Vec::new();
    let mut chain_method_traits = Vec::new();
    let mut chain_method_impls = Vec::new();
    let mut chain_extension_methods = Vec::new();
    let mut chain_extension_impls = Vec::new();

    for item in &mut trait_def.items {
        let TraitItem::Fn(method) = item else {
            continue;
        };

        // Extraction removes the macro's own attributes from the method so they do not
        // end up in the re-emitted trait.
        let method_attrs = MethodAttrs::extract(&mut method.attrs)?;

        let mut param_attrs_map: HashMap<String, ParamAttrs> = HashMap::new();
        for arg in method.sig.inputs.iter_mut().skip(1) {
            if let FnArg::Typed(pat_type) = arg {
                if let Pat::Ident(pat_ident) = &*pat_type.pat {
                    let param_name = pat_ident.ident.to_string();
                    let attrs = extract_param_attrs(&mut pat_type.attrs)?;
                    param_attrs_map.insert(param_name, attrs);
                }
            }
        }

        let (extension_method, extension_impl) = generate_chain_extension_method(
            method,
            &interface_name,
            &trait_def.generics,
            &method_attrs,
            &crate_path,
            &param_attrs_map,
        )?;
        if !extension_method.is_empty() {
            chain_extension_methods.push(extension_method);
        }
        if !extension_impl.is_empty() {
            chain_extension_impls.push(extension_impl);
        }

        let method_impl = generate_method_impl(
            method,
            &interface_name,
            &trait_def.generics,
            &method_attrs,
            &crate_path,
            &param_attrs_map,
        )?;
        methods.push(method_impl);

        let (chain_trait, chain_impl) = generate_chain_method(
            method,
            &interface_name,
            &trait_def.generics,
            &method_attrs,
            &crate_path,
            &param_attrs_map,
        )?;
        if !chain_trait.is_empty() {
            chain_method_traits.push(chain_trait);
        }
        if !chain_impl.is_empty() {
            chain_method_impls.push(chain_impl);
        }
    }

    let trait_output = build_trait_output(&mut trait_def, &chain_method_traits, &crate_path)?;
    let impl_output = build_impl_output(
        &trait_def.ident,
        &trait_def.generics,
        &trait_def.generics.where_clause,
        &methods,
        &chain_method_impls,
        &crate_path,
    );
    let chain_extension_trait_output = build_chain_extension_trait(
        &trait_def.ident,
        &chain_extension_methods,
        &chain_extension_impls,
        &crate_path,
        chain_name,
    );

    Ok(quote! {
        #trait_output
        #impl_output
        #chain_extension_trait_output
    })
}
/// Parses the arguments of the `#[proxy(...)]` attribute.
///
/// Two forms are accepted:
/// * a single string literal, `#[proxy("org.example.Interface")]`, which only sets the
///   interface name (crate path defaults to `::zlink`, no custom chain name);
/// * a list of `key = "value"` pairs with the keys `interface` (required), `crate`
///   (path used to reach the zlink crate, default `::zlink`) and `chain_name` (name of
///   the generated chain-extension trait).
///
/// Returns `(interface_name, crate_path, chain_name)`.
///
/// # Errors
/// Fails when `attr` is empty, when an unknown key is used, when `interface` is missing,
/// when `crate` is not a valid path or when `chain_name` is not a valid identifier.
fn parse_proxy_attributes(
    attr: &TokenStream,
    trait_def: &ItemTrait,
) -> Result<(String, TokenStream, Option<syn::Ident>), Error> {
    if attr.is_empty() {
        return Err(Error::new_spanned(
            trait_def,
            "proxy macro requires interface name, e.g. #[proxy(\"org.example.Interface\")] \
            or #[proxy(interface = \"org.example.Interface\")]",
        ));
    }
    // Shorthand form: the whole attribute is just the interface name.
    if let Ok(Lit::Str(lit_str)) = parse2::<Lit>(attr.clone()) {
        return Ok((lit_str.value(), quote! { ::zlink }, None));
    }
    let mut interface_name = None;
    let mut crate_path = None;
    let mut chain_name = None;
    let parser = syn::meta::parser(|meta| {
        if meta.path.is_ident("interface") {
            let value: syn::LitStr = meta.value()?.parse()?;
            interface_name = Some(value.value());
        } else if meta.path.is_ident("crate") {
            let value: syn::LitStr = meta.value()?.parse()?;
            // Validate as a path so an empty or malformed value is reported at the
            // literal instead of producing broken `#crate_path::...` references later.
            let path: syn::Path = value.parse()?;
            crate_path = Some(quote! { #path });
        } else if meta.path.is_ident("chain_name") {
            let value: syn::LitStr = meta.value()?.parse()?;
            // `LitStr::parse` turns an invalid identifier into a spanned compile error;
            // `Ident::new` would panic inside the macro instead.
            chain_name = Some(value.parse::<syn::Ident>()?);
        } else {
            return Err(meta.error("unsupported attribute"));
        }
        Ok(())
    });
    parser.parse2(attr.clone())?;
    let interface_name = interface_name.ok_or_else(|| {
        Error::new_spanned(
            trait_def,
            "proxy macro requires 'interface' parameter, \
            e.g. #[proxy(interface = \"org.example.Interface\")]",
        )
    })?;
    let crate_path = crate_path.unwrap_or_else(|| quote! { ::zlink });
    Ok((interface_name, crate_path, chain_name))
}
fn validate_trait(trait_def: &ItemTrait) -> Result<(), Error> {
if !trait_def.items.is_empty()
&& trait_def
.items
.iter()
.any(|item| !matches!(item, TraitItem::Fn(_)))
{
return Err(Error::new_spanned(
trait_def,
"proxy macro only supports traits with method definitions",
));
}
Ok(())
}
fn build_trait_output(
trait_def: &mut ItemTrait,
chain_method_traits: &[TokenStream],
crate_path: &TokenStream,
) -> Result<TokenStream, Error> {
trait_def.items.push(syn::parse2(quote! {
type Socket: #crate_path::connection::socket::Socket;
})?);
for chain_trait in chain_method_traits {
trait_def.items.push(syn::parse2(chain_trait.clone())?);
}
Ok(quote! {
#[allow(async_fn_in_trait)]
#trait_def
})
}
/// Emits `impl Trait<..> for Connection<S>` containing all generated method impls.
///
/// The impl is generic over the trait's own parameters plus a socket type `S`, and sets
/// `type Socket = S`. The where clause comes from `build_combined_where_clause`, which is
/// given the trait's where clause, an `S: Socket` predicate and the trait generics.
///
/// `Generics::split_for_impl` is used so that type/const defaults never appear in the
/// impl parameter list, and so that the trait is referenced by bare parameter names only.
/// Emitting a whole `Generics` value there would produce bounds, defaults (which parse
/// as associated-type bindings) and `const N: usize` in type position.
///
/// NOTE(review): a trait generic that is itself named `S` would collide with the
/// injected socket parameter.
fn build_impl_output(
    trait_name: &syn::Ident,
    generics: &syn::Generics,
    where_clause: &Option<syn::WhereClause>,
    methods: &[TokenStream],
    chain_method_impls: &[TokenStream],
    crate_path: &TokenStream,
) -> TokenStream {
    let mut generics_with_socket = generics.clone();
    generics_with_socket.params.push(syn::parse_quote!(S));
    // Impl parameters: trait params (with bounds, without defaults) plus `S`.
    let (impl_generics, _, _) = generics_with_socket.split_for_impl();
    // Trait arguments: just the parameter names, e.g. `<'a, T, N>`.
    let (_, trait_generics, _) = generics.split_for_impl();
    let combined_where_clause = Some(build_combined_where_clause(
        where_clause.clone(),
        syn::parse_quote!(S: #crate_path::connection::socket::Socket),
        generics,
    ));
    quote! {
        impl #impl_generics #trait_name #trait_generics for #crate_path::Connection<S>
        #combined_where_clause
        {
            type Socket = S;
            #(#methods)*
            #(#chain_method_impls)*
        }
    }
}
/// Emits the chain-extension trait and its impl for `Chain<'c, S>`.
///
/// The trait is named `custom_chain_name` when provided, otherwise `<TraitName>Chain`.
/// When there are no chain-extension methods and no explicit name was requested, nothing
/// is emitted at all.
fn build_chain_extension_trait(
    trait_name: &syn::Ident,
    chain_extension_methods: &[TokenStream],
    chain_extension_impls: &[TokenStream],
    crate_path: &TokenStream,
    custom_chain_name: Option<syn::Ident>,
) -> TokenStream {
    let chain_trait_name = match custom_chain_name {
        Some(explicit) => explicit,
        // Neither methods nor an explicit name: skip generating an empty trait.
        None if chain_extension_methods.is_empty() => return quote! {},
        None => syn::Ident::new(&format!("{trait_name}Chain"), trait_name.span()),
    };

    let socket_bound = quote! { S: #crate_path::connection::socket::Socket };

    quote! {
        pub trait #chain_trait_name<'c, S>
        where
            #socket_bound,
        {
            #(#chain_extension_methods)*
        }
        impl<'c, S> #chain_trait_name<'c, S>
        for #crate_path::connection::chain::Chain<'c, S>
        where
            #socket_bound,
        {
            #(#chain_extension_impls)*
        }
    }
}