use proc_macro::TokenStream;
use quote::quote;
use syn::{ImplItem, ItemImpl, parse_macro_input};
#[proc_macro_attribute]
pub fn instrumented_impl(_attr: TokenStream, item: TokenStream) -> TokenStream {
let impl_block = parse_macro_input!(item as ItemImpl);
#[cfg(kani)]
{
return TokenStream::from(quote! { #impl_block });
}
#[cfg(not(kani))]
{
let mut impl_block = impl_block;
for item in &mut impl_block.items {
if let ImplItem::Fn(method) = item {
if matches!(method.vis, syn::Visibility::Public(_)) {
let method_name = method.sig.ident.to_string();
let has_generics = !method.sig.generics.params.is_empty();
let instrument_attr = if is_constructor(&method_name) {
if has_generics {
let param_names: Vec<_> = method
.sig
.inputs
.iter()
.filter_map(|arg| {
if let syn::FnArg::Typed(pat_type) = arg
&& let syn::Pat::Ident(ident) = &*pat_type.pat
{
return Some(ident.ident.clone());
}
None
})
.collect();
quote! {
#[tracing::instrument(skip(#(#param_names),*), err)]
}
} else {
quote! {
#[tracing::instrument(err)]
}
}
} else if is_accessor(&method_name) {
quote! {
#[tracing::instrument(level = "trace", ret)]
}
} else {
quote! {
#[tracing::instrument(skip(self))]
}
};
let attr: syn::Attribute = syn::parse_quote! { #instrument_attr };
method.attrs.insert(0, attr);
}
}
}
TokenStream::from(quote! { #impl_block })
}
}
/// Naming-convention check for constructor-like methods.
///
/// Matches the exact names `new` and `default`, plus anything starting with `from_` or `try_`.
fn is_constructor(name: &str) -> bool {
    match name {
        "new" | "default" => true,
        other => ["from_", "try_"]
            .iter()
            .any(|&prefix| other.starts_with(prefix)),
    }
}
/// Naming-convention check for accessor / conversion methods.
///
/// Matches the exact names `get` and `into_inner`, plus anything starting with `as_`, `to_`,
/// or `get_`.
fn is_accessor(name: &str) -> bool {
    const PREFIXES: [&str; 3] = ["as_", "to_", "get_"];
    matches!(name, "get" | "into_inner")
        || PREFIXES.iter().any(|&prefix| name.starts_with(prefix))
}
#[proc_macro_attribute]
pub fn elicit_tools(attr: TokenStream, item: TokenStream) -> TokenStream {
let impl_block = parse_macro_input!(item as ItemImpl);
let types_input = attr.to_string();
let types: Vec<&str> = types_input
.split(',')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.collect();
if types.is_empty() {
return syn::Error::new_spanned(
&impl_block,
"elicit_tools requires at least one type: #[elicit_tools(Type1, Type2)]",
)
.to_compile_error()
.into();
}
let mut new_impl = impl_block.clone();
for ty_str in types {
let ty: syn::Type = match syn::parse_str(ty_str) {
Ok(t) => t,
Err(e) => {
return syn::Error::new(
proc_macro2::Span::call_site(),
format!("Failed to parse type '{}': {}", ty_str, e),
)
.to_compile_error()
.into();
}
};
let method_name = to_snake_case(ty_str);
let method_ident = syn::Ident::new(
&format!("elicit_{}", method_name),
proc_macro2::Span::call_site(),
);
let tool_description = format!("Elicit {} via MCP", ty_str);
let method: syn::ImplItemFn = syn::parse_quote! {
#[doc = concat!("Elicit `", #ty_str, "` via MCP.")]
#[tool(description = #tool_description)]
pub async fn #method_ident(
peer: ::rmcp::service::Peer<::rmcp::service::RoleServer>,
) -> ::std::result::Result<::rmcp::handler::server::wrapper::Json<#ty>, ::rmcp::ErrorData> {
#ty::elicit_checked(peer)
.await
.map(::rmcp::handler::server::wrapper::Json)
.map_err(|e| ::rmcp::ErrorData::internal_error(e.to_string(), None))
}
};
new_impl.items.push(syn::ImplItem::Fn(method));
}
TokenStream::from(quote! { #new_impl })
}
/// Converts a CamelCase identifier to snake_case.
///
/// An underscore is inserted only where an uppercase character directly follows a lowercase
/// one, so runs of capitals are not split (`"HTTPServer"` becomes `"httpserver"`). Uppercase
/// characters are lowered with `to_ascii_lowercase`, so non-ASCII capitals are kept as-is.
/// All other characters are copied unchanged.
fn to_snake_case(s: &str) -> String {
    let mut snake = String::with_capacity(s.len());
    let mut previous: Option<char> = None;
    for current in s.chars() {
        if current.is_uppercase() {
            // Word boundary: a capital letter immediately after a lowercase letter.
            if previous.is_some_and(char::is_lowercase) {
                snake.push('_');
            }
            snake.push(current.to_ascii_lowercase());
        } else {
            snake.push(current);
        }
        previous = Some(current);
    }
    snake
}