use darling::{ast::NestedMeta, FromMeta};
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{parse_macro_input, Attribute, ImplItem, ItemImpl};
/// Arguments accepted by the `#[mcp_tool_router(...)]` attribute, parsed via darling's
/// `FromMeta` derive from the attribute's nested meta list.
///
/// All fields are optional; `McpToolRouterAttrs::default()` represents "no arguments given".
#[derive(Debug, Default, FromMeta)]
pub struct McpToolRouterAttrs {
/// Name of the generated router constructor, written as `router = "name"`.
/// When absent (`None`), the expansion falls back to `tool_router`.
#[darling(default)]
pub router: Option<String>,
}
/// Entry point for the `#[mcp_tool_router]` attribute macro.
///
/// Steps, in order:
/// 1. parse `attr` as a nested meta list and decode it into [`McpToolRouterAttrs`];
/// 2. parse `item` as an `impl` block;
/// 3. hand both to `generate_router_impl`.
///
/// Every failure along the way is reported as `compile_error!` tokens rather than a panic,
/// so the user sees the diagnostic at the offending source location.
pub fn expand_mcp_tool_router(attr: TokenStream, item: TokenStream) -> TokenStream {
    let nested = match NestedMeta::parse_meta_list(attr.into()) {
        Ok(list) => list,
        Err(err) => return err.to_compile_error().into(),
    };

    let parsed_attrs = match McpToolRouterAttrs::from_list(&nested) {
        Ok(parsed) => parsed,
        Err(err) => return err.write_errors().into(),
    };

    // `parse_macro_input!` early-returns its own compile error when `item` is not an impl.
    let impl_block = parse_macro_input!(item as ItemImpl);

    generate_router_impl(parsed_attrs, impl_block)
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
fn generate_router_impl(
attrs: McpToolRouterAttrs,
input_impl: ItemImpl,
) -> syn::Result<TokenStream2> {
let router_fn_name = format_ident!(
"{}",
attrs.router.unwrap_or_else(|| "tool_router".to_string())
);
let self_ty = &input_impl.self_ty;
let (impl_generics, _ty_generics, where_clause) = input_impl.generics.split_for_impl();
let mut tool_methods = Vec::new();
for item in &input_impl.items {
if let ImplItem::Fn(method) = item
&& has_mcp_tool_attr(&method.attrs)
{
let method_name = &method.sig.ident;
let info_fn = format_ident!("{}_tool_info", method_name);
let handler_fn = format_ident!("{}_handler", method_name);
let visibility_fn = format_ident!("{}_visibility", method_name);
let has_visibility = has_visibility_attr(&method.attrs);
tool_methods.push((info_fn, handler_fn, visibility_fn, has_visibility));
}
}
let route_adds: Vec<TokenStream2> = tool_methods
.iter()
.map(|(info_fn, handler_fn, visibility_fn, has_visibility)| {
if *has_visibility {
quote! {
.with_tool(
Self::#info_fn(),
Self::#handler_fn,
Some(Self::#visibility_fn)
)
}
} else {
quote! {
.with_tool(
Self::#info_fn(),
Self::#handler_fn,
None
)
}
}
})
.collect();
let expanded = quote! {
#input_impl
impl #impl_generics #self_ty #where_clause {
pub fn #router_fn_name() -> mcp_host::registry::router::McpToolRouter<Self> {
mcp_host::registry::router::McpToolRouter::new()
#(#route_adds)*
}
}
};
Ok(expanded)
}
fn has_mcp_tool_attr(attrs: &[Attribute]) -> bool {
attrs.iter().any(|attr| attr.path().is_ident("mcp_tool"))
}
fn has_visibility_attr(attrs: &[Attribute]) -> bool {
attrs.iter().any(|attr| {
if !attr.path().is_ident("mcp_tool") {
return false;
}
let meta_list = match attr.meta.require_list() {
Ok(meta) => meta,
Err(_) => return false,
};
let nested = match NestedMeta::parse_meta_list(meta_list.tokens.clone()) {
Ok(list) => list,
Err(_) => return false,
};
nested.iter().any(|meta| match meta {
NestedMeta::Meta(syn::Meta::Path(path)) => path.is_ident("visible"),
NestedMeta::Meta(syn::Meta::NameValue(name_value)) => {
name_value.path.is_ident("visible")
}
NestedMeta::Meta(syn::Meta::List(list)) => list.path.is_ident("visible"),
NestedMeta::Lit(_) => false,
})
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Parses expansion tokens as a file and returns its `impl` blocks in emission order.
    fn generated_impls(tokens: TokenStream2) -> Vec<ItemImpl> {
        let file: syn::File =
            syn::parse2(tokens).expect("expansion should parse as a sequence of items");
        file.items
            .into_iter()
            .filter_map(|item| match item {
                syn::Item::Impl(impl_item) => Some(impl_item),
                _ => None,
            })
            .collect()
    }

    /// Names of the associated functions defined directly in `item_impl`.
    fn fn_names(item_impl: &ItemImpl) -> Vec<String> {
        item_impl
            .items
            .iter()
            .filter_map(|item| match item {
                ImplItem::Fn(method) => Some(method.sig.ident.to_string()),
                _ => None,
            })
            .collect()
    }

    #[test]
    fn visibility_attr_requires_visible_key() {
        let attrs: Vec<Attribute> =
            vec![parse_quote!(#[mcp_tool(description = "visible by default")])];
        assert!(!has_visibility_attr(&attrs));
        let attrs: Vec<Attribute> = vec![parse_quote!(#[mcp_tool(visible = "ctx.is_admin()")])];
        assert!(has_visibility_attr(&attrs));
    }

    #[test]
    fn visibility_attr_accepts_bare_visible_flag() {
        let attrs: Vec<Attribute> = vec![parse_quote!(#[mcp_tool(visible)])];
        assert!(has_visibility_attr(&attrs));
    }

    #[test]
    fn visibility_attr_ignores_bare_mcp_tool_and_foreign_attributes() {
        let attrs: Vec<Attribute> = vec![
            parse_quote!(#[mcp_tool]),
            parse_quote!(#[other(visible = "ctx.is_admin()")]),
        ];
        assert!(!has_visibility_attr(&attrs));
    }

    #[test]
    fn mcp_tool_attr_detection() {
        let tagged: Vec<Attribute> = vec![parse_quote!(#[mcp_tool(description = "x")])];
        assert!(has_mcp_tool_attr(&tagged));
        let untagged: Vec<Attribute> = vec![parse_quote!(#[inline])];
        assert!(!has_mcp_tool_attr(&untagged));
    }

    #[test]
    fn generated_router_preserves_generics_and_where_clause() {
        let input_impl: ItemImpl = parse_quote! {
            impl<T> MyServer<T>
            where
                T: Send,
            {
                #[mcp_tool]
                async fn tool(&self, ctx: Ctx, params: Parameters<()>) -> ToolResult {
                    unimplemented!()
                }
            }
        };
        let tokens = generate_router_impl(McpToolRouterAttrs::default(), input_impl).unwrap();
        let impls = generated_impls(tokens);
        assert_eq!(impls.len(), 2);
        let generated_impl = &impls[1];
        assert_eq!(generated_impl.generics.params.len(), 1);
        assert!(generated_impl.generics.where_clause.is_some());
        // Without a `router = "..."` argument the constructor uses the default name.
        assert_eq!(fn_names(generated_impl), vec!["tool_router".to_string()]);
    }

    #[test]
    fn generated_router_registers_only_tagged_methods() {
        let input_impl: ItemImpl = parse_quote! {
            impl MyServer {
                #[mcp_tool(visible = "ctx.is_admin()")]
                async fn alpha(&self) -> ToolResult {
                    unimplemented!()
                }
                #[mcp_tool]
                async fn beta(&self) -> ToolResult {
                    unimplemented!()
                }
                fn helper(&self) {}
            }
        };
        let tokens = generate_router_impl(McpToolRouterAttrs::default(), input_impl).unwrap();
        let rendered = tokens.to_string();
        assert_eq!(rendered.matches("with_tool").count(), 2);
        assert!(rendered.contains("alpha_visibility"));
        assert!(!rendered.contains("beta_visibility"));
        assert!(!rendered.contains("helper_handler"));
    }

    #[test]
    fn router_attribute_overrides_constructor_name() {
        let input_impl: ItemImpl = parse_quote! {
            impl MyServer {}
        };
        let attrs = McpToolRouterAttrs {
            router: Some("admin_router".to_string()),
        };
        let impls = generated_impls(generate_router_impl(attrs, input_impl).unwrap());
        assert_eq!(impls.len(), 2);
        assert_eq!(fn_names(&impls[1]), vec!["admin_router".to_string()]);
    }
}