use proc_macro2::TokenStream;
#[cfg(feature = "weak_default")]
use quote::format_ident;
use quote::quote;
use syn::{parse_quote, Error, ItemTrait, TraitItem};
#[cfg(feature = "weak_default")]
use syn::{
punctuated::Punctuated, visit_mut::VisitMut, Block, Expr, ExprPath, Ident, Path, PathSegment,
Signature,
};
#[cfg(feature = "weak_default")]
use std::collections::HashMap;
use crate::args::DefInterfaceArgs;
use crate::errors::generic_not_allowed_error;
#[cfg(not(feature = "weak_default"))]
use crate::errors::weak_default_required_error;
use crate::naming::{
alias_guard_name, extern_fn_mod_name, extern_fn_name, extract_caller_args, namespace_guard_name,
};
use crate::validator::validate_fn_signature;
#[cfg(feature = "weak_default")]
fn rewrite_self_in_default_body(
default_body: &Block,
trait_name: &Ident,
namespace: Option<&str>,
method_signatures: &HashMap<String, Signature>,
) -> TokenStream {
struct SelfRefRewriter<'a> {
trait_name: &'a Ident,
namespace: Option<&'a str>,
method_signatures: &'a HashMap<String, Signature>,
generated_proxies: HashMap<String, TokenStream>,
}
impl SelfRefRewriter<'_> {
fn is_self_method_path(expr: &Expr) -> Option<Ident> {
if let Expr::Path(path_expr) = expr {
let path = &path_expr.path;
if path.segments.len() == 2 {
let first_seg = &path.segments[0];
let second_seg = &path.segments[1];
if first_seg.ident == "Self" {
return Some(second_seg.ident.clone());
}
}
}
None
}
fn proxy_name(method_name: &Ident) -> Ident {
format_ident!("__self_proxy_{}", method_name)
}
fn ensure_proxy_fn(&mut self, method_name: Ident) -> Option<Ident> {
let method_key = method_name.to_string();
if self.generated_proxies.contains_key(&method_key) {
return Some(Self::proxy_name(&method_name));
}
let sig = self.method_signatures.get(&method_key)?;
let mod_name = extern_fn_mod_name(self.trait_name);
let extern_fn = extern_fn_name(self.namespace, &self.trait_name, &method_name);
let proxy_name = Self::proxy_name(&method_name);
let caller_args = extract_caller_args(sig).ok()?;
let mut proxy_sig = sig.clone();
proxy_sig.ident = proxy_name.clone();
let proxy_fn = quote! {
#[allow(non_snake_case)]
#proxy_sig {
unsafe { #mod_name :: #extern_fn ( #caller_args ) }
}
};
self.generated_proxies.insert(method_key, proxy_fn);
Some(proxy_name)
}
fn replace_with_proxy(&mut self, expr: &mut Expr, method_name: Ident) {
if let Some(proxy_name) = self.ensure_proxy_fn(method_name) {
*expr = Expr::Path(ExprPath {
attrs: vec![],
qself: None,
path: Path {
leading_colon: None,
segments: {
let mut segs = Punctuated::new();
segs.push(PathSegment::from(proxy_name));
segs
},
},
});
}
}
}
impl VisitMut for SelfRefRewriter<'_> {
fn visit_expr_mut(&mut self, expr: &mut Expr) {
syn::visit_mut::visit_expr_mut(self, expr);
if let Some(method_name) = Self::is_self_method_path(expr) {
self.replace_with_proxy(expr, method_name);
}
}
}
let mut body = default_body.clone();
let mut rewriter = SelfRefRewriter {
trait_name,
namespace,
method_signatures,
generated_proxies: HashMap::new(),
};
rewriter.visit_block_mut(&mut body);
let proxy_fns: Vec<_> = rewriter.generated_proxies.into_values().collect();
if proxy_fns.is_empty() {
quote! { #body }
} else {
let stmts = &body.stmts;
quote! {
{
#(#proxy_fns)*
#(#stmts)*
}
}
}
}
pub fn def_interface(
mut ast: ItemTrait,
macro_arg: DefInterfaceArgs,
) -> Result<TokenStream, Error> {
let trait_name = &ast.ident;
let vis = &ast.vis;
if !ast.generics.params.is_empty() {
return Err(generic_not_allowed_error(&ast.generics));
}
let mod_name = extern_fn_mod_name(trait_name);
#[cfg(feature = "weak_default")]
let mut method_signatures: HashMap<String, Signature> = HashMap::new();
#[cfg(feature = "weak_default")]
for item in &ast.items {
if let TraitItem::Fn(method) = item {
let sig = &method.sig;
method_signatures.insert(sig.ident.to_string(), sig.clone());
}
}
let mut extern_fn_list = vec![];
let mut callers: Vec<TokenStream> = vec![];
for item in &mut ast.items {
if let TraitItem::Fn(method) = item {
let sig = &method.sig;
let fn_name = &sig.ident;
validate_fn_signature(sig)?;
let extern_fn_name =
extern_fn_name(macro_arg.namespace.as_deref(), trait_name, fn_name);
let mut extern_fn_sig = sig.clone();
extern_fn_sig.ident = extern_fn_name.clone();
extern_fn_list.push(quote! {
pub #extern_fn_sig;
});
#[cfg(not(feature = "weak_default"))]
if method.default.is_some() {
return Err(weak_default_required_error(method));
}
#[cfg(feature = "weak_default")]
if let Some(default_body) = &mut method.default {
let default_body_cleaned = rewrite_self_in_default_body(
default_body,
trait_name,
macro_arg.namespace.as_deref(),
&method_signatures,
);
let weak_default_impl = quote! {
#[allow(non_snake_case)]
#[linkage = "weak"]
#[no_mangle]
extern "Rust" #extern_fn_sig #default_body_cleaned
};
let caller_args = extract_caller_args(sig)?;
*default_body = syn::parse2(quote! {{
#weak_default_impl
#extern_fn_name ( #caller_args )
}})?;
}
if macro_arg.gen_caller {
let attrs = &method.attrs;
let caller_fn_sig = sig.clone();
let caller_args = extract_caller_args(sig)?;
callers.push(quote! {
#(#attrs)*
#[inline]
#vis #caller_fn_sig {
unsafe { #mod_name :: #extern_fn_name ( #caller_args ) }
}
})
}
}
}
let alias_guard_name = alias_guard_name(trait_name);
let alias_guard = parse_quote!(
#[allow(non_upper_case_globals)]
#[doc(hidden)]
const #alias_guard_name: () = ();
);
ast.items.push(alias_guard);
if let Some(ns) = ¯o_arg.namespace {
let ns_guard_name = namespace_guard_name(ns);
let ns_guard = parse_quote!(
#[allow(non_upper_case_globals)]
#[doc(hidden)]
const #ns_guard_name: ();
);
ast.items.push(ns_guard);
}
Ok(quote! {
#ast
#[doc(hidden)]
#[allow(non_snake_case)]
#vis mod #mod_name {
use super::*;
extern "Rust" {
#(#extern_fn_list)*
}
}
#(#callers)*
})
}