use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{quote, ToTokens};
use syn::{
parse_macro_input,
spanned::Spanned,
visit::Visit,
Expr,
ExprMethodCall,
FnArg,
ImplItem,
ImplItemFn,
ItemImpl,
Pat,
Type,
TypeImplTrait,
TypeParamBound,
};
/// Kind of network-facing parameter detected on an instrumented method.
enum NetworkType {
    /// A parameter whose `impl Trait` bounds include `IoSink`.
    Sink,
    /// A parameter whose `impl Trait` bounds include `IoStream`.
    Stream,
    /// A parameter whose `impl Trait` bounds include `MultipartyInterface`;
    /// carries the parameter's binding name as written in the signature.
    MultipartyInterface(String),
}
fn determine_network_info(network_types: &[NetworkType]) -> TokenStream2 {
if let Some(net_param) = network_types.iter().find_map(|t| match t {
NetworkType::MultipartyInterface(name) => Some(name.clone()),
_ => None,
}) {
let net_ident = syn::Ident::new(&net_param, proc_macro2::Span::call_site());
return quote! {
format!("network interface {:?} as peer {:?}",
network::context::Peer::context_id(&#net_ident),
network::context::Peer::local_peer_id(&#net_ident))
};
}
let has_sink = network_types.iter().any(|t| matches!(t, NetworkType::Sink));
let has_stream = network_types
.iter()
.any(|t| matches!(t, NetworkType::Stream));
match (has_sink, has_stream) {
(true, true) => quote! { "sink/stream channels".to_string() },
(true, false) => quote! { "a sink channel".to_string() },
(false, true) => quote! { "a stream channel".to_string() },
(false, false) => quote! { "local execution".to_string() },
}
}
/// Returns `true` when `bound` is a trait bound and any segment of its path
/// (not only the last one) is named `trait_name`. Lifetime and other
/// non-trait bounds always yield `false`.
fn bound_contains_trait(bound: &TypeParamBound, trait_name: &str) -> bool {
    let TypeParamBound::Trait(trait_bound) = bound else {
        return false;
    };
    trait_bound
        .path
        .segments
        .iter()
        .any(|segment| segment.ident == trait_name)
}
/// Collects the network-facing parameters of `func`.
///
/// Only typed arguments bound to a plain identifier whose type is written as
/// `impl Trait + ...` are inspected. Each bound contributes at most one entry:
/// `MultipartyInterface` (recording the parameter name) takes precedence over
/// `IoSink`, which takes precedence over `IoStream`. Entries appear in
/// parameter order, then bound order.
fn extract_network_types(func: &ImplItemFn) -> Vec<NetworkType> {
    let mut detected = Vec::new();
    for arg in &func.sig.inputs {
        let FnArg::Typed(typed) = arg else {
            continue;
        };
        let (Pat::Ident(binding), Type::ImplTrait(TypeImplTrait { bounds, .. })) =
            (&*typed.pat, &*typed.ty)
        else {
            continue;
        };
        let param = binding.ident.to_string();
        for bound in bounds {
            let kind = if bound_contains_trait(bound, "MultipartyInterface") {
                NetworkType::MultipartyInterface(param.clone())
            } else if bound_contains_trait(bound, "IoSink") {
                NetworkType::Sink
            } else if bound_contains_trait(bound, "IoStream") {
                NetworkType::Stream
            } else {
                continue;
            };
            detected.push(kind);
        }
    }
    detected
}
fn has_ref_receiver(func: &ImplItemFn) -> bool {
func.sig
.inputs
.iter()
.any(|arg| matches!(arg, FnArg::Receiver(r) if r.reference.is_some()))
}
/// Produces a `log::debug!` statement announcing a call to `func`.
///
/// The message contains the implementing type's name, `PROTOCOL_INFO.name()`
/// (which must be in scope where the macro expands), the method name
/// (qualified as `Trait.method` inside trait impls), the session id read via
/// `crate::types::party::Party::session_id`, and a description of the
/// method's network parameters.
fn generate_log_statement(
    struct_name: &syn::Ident,
    trait_name: Option<&str>,
    func: &ImplItemFn,
) -> TokenStream2 {
    let method = func.sig.ident.to_string();
    let func_display = match trait_name {
        Some(t) => format!("{t}.{method}"),
        None => method,
    };
    let struct_name_str = struct_name.to_string();
    let network_types = extract_network_types(func);
    let network_info = determine_network_info(&network_types);
    // A reference receiver can be passed through as-is; otherwise emit `&self`
    // so a borrow is passed instead of `self` itself.
    let self_ref = if has_ref_receiver(func) {
        quote!(self)
    } else {
        quote!(&self)
    };
    quote! {
        log::debug!("<{} - {}> {} with {:?} over {}",
            #struct_name_str, PROTOCOL_INFO.name(), #func_display,
            crate::types::party::Party::session_id(#self_ref), #network_info);
    }
}
/// Prepends the debug-log statement from `generate_log_statement` to the
/// body of `func`.
///
/// Associated functions without any `self` receiver are returned unchanged.
/// The generated statement reads the session id through `self`, which does
/// not exist in such functions, so instrumenting them would make the
/// expansion fail to compile.
fn inject_logging(
    mut func: ImplItemFn,
    struct_name: &syn::Ident,
    trait_name: Option<&str>,
) -> ImplItemFn {
    if func.sig.receiver().is_none() {
        return func;
    }
    let log_stmt = generate_log_statement(struct_name, trait_name, &func);
    // Insert into the existing block rather than rebuilding it, so the
    // original brace spans are kept for diagnostics.
    func.block.stmts.insert(0, syn::parse_quote!(#log_stmt));
    func
}
/// Returns the name of the trait being implemented, or `None` for an
/// inherent impl.
///
/// Uses the identifier of the trait path's last segment. If the path has no
/// segments, the whole path rendered as tokens is used instead.
fn extract_trait_name(impl_block: &ItemImpl) -> Option<String> {
    let (_, trait_path, _) = impl_block.trait_.as_ref()?;
    let name = match trait_path.segments.last() {
        Some(segment) => segment.ident.to_string(),
        None => trait_path.to_token_stream().to_string(),
    };
    Some(name)
}
/// Syntax visitor whose flag becomes `true` once a method call of the exact
/// form `self.refresh(...)` (receiver is the bare path `self`) is visited.
struct RefreshVisitor(bool);
impl<'ast> Visit<'ast> for RefreshVisitor {
    fn visit_expr_method_call(&mut self, node: &'ast ExprMethodCall) {
        let is_refresh = node.method == "refresh";
        let on_self = matches!(&*node.receiver, Expr::Path(p) if p.path.is_ident("self"));
        if is_refresh && on_self {
            self.0 = true;
        }
        // Continue into the receiver and arguments so nested calls are seen too.
        syn::visit::visit_expr_method_call(self, node);
    }
}
fn has_ref_self_fn(impl_block: &ItemImpl) -> bool {
impl_block.items.iter().any(|item| {
matches!(item, ImplItem::Fn(f) if f.sig.inputs.iter().any(
|a| matches!(a, FnArg::Receiver(r) if r.reference.is_some())))
})
}
/// Returns `true` if any `fn` item in `impl_block` contains a
/// `self.refresh(...)` call, as detected by `RefreshVisitor`.
/// Non-function items (consts, types, macros) are not inspected.
fn contains_refresh_call(impl_block: &ItemImpl) -> bool {
    let mut visitor = RefreshVisitor(false);
    for item in &impl_block.items {
        if let ImplItem::Fn(func) = item {
            visitor.visit_impl_item_fn(func);
        }
    }
    visitor.0
}
pub fn protocol_trait_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let _ = attr;
let mut impl_block = parse_macro_input!(item as ItemImpl);
if has_ref_self_fn(&impl_block) && !contains_refresh_call(&impl_block) {
return syn::Error::new(
impl_block.span(),
"Protocol impl block must contain at least one `self.refresh()` call",
)
.to_compile_error()
.into();
}
let trait_name = extract_trait_name(&impl_block);
let struct_name = match &*impl_block.self_ty {
Type::Path(tp) => tp
.path
.segments
.last()
.map(|s| s.ident.clone())
.unwrap_or_else(|| syn::Ident::new("Unknown", proc_macro2::Span::call_site())),
_ => syn::Ident::new("Unknown", proc_macro2::Span::call_site()),
};
impl_block.items = impl_block
.items
.into_iter()
.map(|item| match item {
ImplItem::Fn(f) => ImplItem::Fn(inject_logging(f, &struct_name, trait_name.as_deref())),
other => other,
})
.collect();
quote! { #impl_block }.into()
}