mod trait_reflection;
use proc_macro::TokenStream;
use quote::quote;
use syn::{ImplItem, ItemImpl, parse_macro_input};
#[proc_macro_attribute]
pub fn instrumented_impl(_attr: TokenStream, item: TokenStream) -> TokenStream {
let impl_block = parse_macro_input!(item as ItemImpl);
#[cfg(kani)]
{
return TokenStream::from(quote! { #impl_block });
}
#[cfg(not(kani))]
{
let mut impl_block = impl_block;
for item in &mut impl_block.items {
if let ImplItem::Fn(method) = item {
if matches!(method.vis, syn::Visibility::Public(_)) {
let method_name = method.sig.ident.to_string();
let has_generics = !method.sig.generics.params.is_empty();
let instrument_attr = if is_constructor(&method_name) {
if has_generics {
let param_names: Vec<_> = method
.sig
.inputs
.iter()
.filter_map(|arg| {
if let syn::FnArg::Typed(pat_type) = arg
&& let syn::Pat::Ident(ident) = &*pat_type.pat
{
return Some(ident.ident.clone());
}
None
})
.collect();
quote! {
#[tracing::instrument(skip(#(#param_names),*), err)]
}
} else {
quote! {
#[tracing::instrument(err)]
}
}
} else if is_accessor(&method_name) {
quote! {
#[tracing::instrument(level = "trace", ret)]
}
} else {
quote! {
#[tracing::instrument(skip(self))]
}
};
let attr: syn::Attribute = syn::parse_quote! { #instrument_attr };
method.attrs.insert(0, attr);
}
}
}
TokenStream::from(quote! { #impl_block })
}
}
/// Returns `true` when `name` is treated as a constructor for instrumentation purposes:
/// exactly `new` or `default`, or any name beginning with `from_` or `try_`.
fn is_constructor(name: &str) -> bool {
    const CONSTRUCTOR_PREFIXES: [&str; 2] = ["from_", "try_"];
    matches!(name, "new" | "default")
        || CONSTRUCTOR_PREFIXES
            .iter()
            .any(|prefix| name.starts_with(prefix))
}
/// Returns `true` when `name` is treated as an accessor for instrumentation purposes:
/// exactly `get` or `into_inner`, or any name beginning with `as_`, `to_` or `get_`.
fn is_accessor(name: &str) -> bool {
    const ACCESSOR_PREFIXES: [&str; 3] = ["as_", "to_", "get_"];
    if matches!(name, "get" | "into_inner") {
        return true;
    }
    ACCESSOR_PREFIXES
        .iter()
        .any(|prefix| name.starts_with(prefix))
}
#[proc_macro_attribute]
pub fn elicit_tools(attr: TokenStream, item: TokenStream) -> TokenStream {
let impl_block = parse_macro_input!(item as ItemImpl);
let types_input = attr.to_string();
let types: Vec<&str> = types_input
.split(',')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
.collect();
if types.is_empty() {
return syn::Error::new_spanned(
&impl_block,
"elicit_tools requires at least one type: #[elicit_tools(Type1, Type2)]",
)
.to_compile_error()
.into();
}
let mut new_impl = impl_block.clone();
for ty_str in types {
let ty: syn::Type = match syn::parse_str(ty_str) {
Ok(t) => t,
Err(e) => {
return syn::Error::new(
proc_macro2::Span::call_site(),
format!("Failed to parse type '{}': {}", ty_str, e),
)
.to_compile_error()
.into();
}
};
let method_name = to_snake_case(ty_str);
let method_ident = syn::Ident::new(
&format!("elicit_{}", method_name),
proc_macro2::Span::call_site(),
);
let tool_description = format!("Elicit {} via MCP", ty_str);
let method: syn::ImplItemFn = syn::parse_quote! {
#[doc = concat!("Elicit `", #ty_str, "` via MCP.")]
#[tool(description = #tool_description)]
pub async fn #method_ident(
peer: ::rmcp::service::Peer<::rmcp::service::RoleServer>,
) -> ::std::result::Result<
::rmcp::handler::server::wrapper::Json<::elicitation::ElicitToolOutput<#ty>>,
::rmcp::ErrorData
> {
let value = #ty::elicit_checked(peer).await
.map_err(|e| ::rmcp::ErrorData::internal_error(e.to_string(), None))?;
Ok(::rmcp::handler::server::wrapper::Json(::elicitation::ElicitToolOutput::new(value)))
}
};
new_impl.items.push(syn::ImplItem::Fn(method));
}
TokenStream::from(quote! { #new_impl })
}
/// Converts a `PascalCase`/`camelCase` name to `snake_case`.
///
/// An `_` is inserted only before an uppercase character that directly follows a
/// lowercase one, so runs of capitals are not split (`HTTPServer` -> `httpserver`).
/// All other characters are copied through unchanged. Uppercase characters are lowered
/// with full Unicode rules, so non-ASCII capitals (`Ä` -> `ä`) are converted as well.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    // Starts `false`, so no separator is ever emitted before the first character.
    let mut prev_was_lowercase = false;
    for ch in s.chars() {
        if ch.is_uppercase() {
            if prev_was_lowercase {
                result.push('_');
            }
            // `to_lowercase` pairs with Unicode `is_uppercase`; `to_ascii_lowercase`
            // would leave non-ASCII capitals untouched.
            result.extend(ch.to_lowercase());
            prev_was_lowercase = false;
        } else {
            result.push(ch);
            prev_was_lowercase = ch.is_lowercase();
        }
    }
    result
}
/// Appends MCP tool methods that delegate to a field of `self`.
///
/// Usage: `#[elicit_trait_tools_router(TraitName, field_name, [method1, method2])]`.
/// For each listed `method`, it generates a
/// `pub async fn method(&self, params: Parameters<MethodParams>)` returning
/// `Result<Json<MethodResult>, ErrorData>`, annotated with `#[::rmcp::tool(...)]`. The body
/// is `self.field_name.method(params).await`.
///
/// The `MethodParams` / `MethodResult` names are the method name converted with
/// [`to_pascal_case`] plus a `Params` / `Result` suffix. `field_name` may be a named field
/// or a tuple index such as `0`. `TraitName` is required by the syntax but is not
/// otherwise used by the expansion.
///
/// # Errors
/// Emits a compile error when the attribute does not have exactly three top-level
/// arguments, when the method list is empty, or when a field, method or derived type
/// name is not valid Rust syntax.
#[proc_macro_attribute]
pub fn elicit_trait_tools_router(attr: TokenStream, item: TokenStream) -> TokenStream {
    let attr_str = attr.to_string();
    // Split on top-level commas only, so the commas inside `[...]` stay in the
    // third (method list) argument.
    let mut parts = Vec::new();
    let mut current = String::new();
    let mut bracket_depth = 0;
    for ch in attr_str.chars() {
        match ch {
            '[' => {
                bracket_depth += 1;
                current.push(ch);
            }
            ']' => {
                bracket_depth -= 1;
                current.push(ch);
            }
            ',' if bracket_depth == 0 => {
                parts.push(current.trim().to_string());
                current.clear();
            }
            _ => current.push(ch),
        }
    }
    // A whitespace-only tail (e.g. after a trailing comma) is not an extra argument.
    let tail = current.trim();
    if !tail.is_empty() {
        parts.push(tail.to_string());
    }
    if parts.len() != 3 {
        return syn::Error::new(
            proc_macro2::Span::call_site(),
            "elicit_trait_tools_router requires three arguments: #[elicit_trait_tools_router(TraitName, field_name, [method1, method2])]",
        )
        .to_compile_error()
        .into();
    }
    let _trait_name = &parts[0];
    let field_name = &parts[1];
    let methods_str = parts[2].trim_start_matches('[').trim_end_matches(']');
    let methods: Vec<&str> = methods_str
        .split(',')
        .map(|s| s.trim())
        .filter(|s| !s.is_empty())
        .collect();
    if methods.is_empty() {
        return syn::Error::new(
            proc_macro2::Span::call_site(),
            "elicit_trait_tools_router requires at least one method in the list",
        )
        .to_compile_error()
        .into();
    }
    // Parsed once, as a `Member`: accepts named and tuple fields, and reports malformed
    // input as a compile error instead of panicking inside `Ident::new`.
    let field_member: syn::Member = match syn::parse_str(field_name) {
        Ok(member) => member,
        Err(e) => {
            return syn::Error::new(
                proc_macro2::Span::call_site(),
                format!("Invalid field name '{}': {}", field_name, e),
            )
            .to_compile_error()
            .into();
        }
    };
    let mut impl_block = parse_macro_input!(item as ItemImpl);
    for method_name in methods {
        let method_ident: syn::Ident = match syn::parse_str(method_name) {
            Ok(ident) => ident,
            Err(e) => {
                return syn::Error::new(
                    proc_macro2::Span::call_site(),
                    format!("Invalid method name '{}': {}", method_name, e),
                )
                .to_compile_error()
                .into();
            }
        };
        let pascal_case = to_pascal_case(method_name);
        let params_type = format!("{}Params", pascal_case);
        let result_type = format!("{}Result", pascal_case);
        let params_ty: syn::Type = match syn::parse_str(&params_type) {
            Ok(t) => t,
            Err(e) => {
                return syn::Error::new(
                    proc_macro2::Span::call_site(),
                    format!("Failed to parse params type '{}': {}", params_type, e),
                )
                .to_compile_error()
                .into();
            }
        };
        let result_ty: syn::Type = match syn::parse_str(&result_type) {
            Ok(t) => t,
            Err(e) => {
                return syn::Error::new(
                    proc_macro2::Span::call_site(),
                    format!("Failed to parse result type '{}': {}", result_type, e),
                )
                .to_compile_error()
                .into();
            }
        };
        let tool_description = format!("{} operation", method_name.replace('_', " "));
        let method: syn::ImplItemFn = syn::parse_quote! {
            #[doc = concat!("`", #method_name, "` operation via trait method delegation.")]
            #[::rmcp::tool(description = #tool_description)]
            pub async fn #method_ident(
                &self,
                params: ::rmcp::handler::server::wrapper::Parameters<#params_ty>,
            ) -> ::std::result::Result<
                ::rmcp::handler::server::wrapper::Json<#result_ty>,
                ::rmcp::ErrorData
            > {
                self.#field_member.#method_ident(params).await
            }
        };
        impl_block.items.push(syn::ImplItem::Fn(method));
    }
    TokenStream::from(quote! { #impl_block })
}
/// Converts a `snake_case` name to `PascalCase`.
///
/// Each `_`-separated word has its first character uppercased (full Unicode rules, so
/// `ß` becomes `SS`), and the rest of the word is kept as-is. Empty words produced by
/// leading, trailing or repeated underscores contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut letters = word.chars();
        if let Some(head) = letters.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(letters.as_str());
        }
    }
    pascal
}
/// Attribute macro entry point for trait reflection.
///
/// Converts the compiler-provided token streams to `proc_macro2` tokens and passes them
/// to `trait_reflection::expand`. On failure it emits the error's `to_compile_error`
/// tokens instead of the item.
#[proc_macro_attribute]
pub fn reflect_trait(attr: TokenStream, item: TokenStream) -> TokenStream {
    let expansion = trait_reflection::expand(
        proc_macro2::TokenStream::from(attr),
        proc_macro2::TokenStream::from(item),
    );
    match expansion {
        Err(err) => err.to_compile_error().into(),
        Ok(tokens) => tokens.into(),
    }
}