use darling::{ast::NestedMeta, FromMeta};
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{parse_macro_input, Attribute, ImplItem, ItemImpl};
/// Arguments accepted by the tool-router attribute, parsed from the attribute's
/// argument list via `darling`.
///
/// All arguments are optional. When `router` is omitted the expansion names the generated
/// constructor `tool_router`.
#[derive(Default, Debug, FromMeta)]
pub struct McpToolRouterAttrs {
    /// Name of the generated router constructor function, written as `router = "name"`.
    #[darling(default)]
    pub router: Option<String>,
}
pub fn expand_mcp_tool_router(attr: TokenStream, item: TokenStream) -> TokenStream {
let attr_args = match NestedMeta::parse_meta_list(attr.into()) {
Ok(v) => v,
Err(e) => return TokenStream::from(e.to_compile_error()),
};
let attrs = match McpToolRouterAttrs::from_list(&attr_args) {
Ok(v) => v,
Err(e) => return TokenStream::from(e.write_errors()),
};
let input_impl = parse_macro_input!(item as ItemImpl);
match generate_router_impl(attrs, input_impl) {
Ok(tokens) => tokens.into(),
Err(e) => e.to_compile_error().into(),
}
}
fn generate_router_impl(
attrs: McpToolRouterAttrs,
input_impl: ItemImpl,
) -> syn::Result<TokenStream2> {
let router_fn_name = format_ident!(
"{}",
attrs.router.unwrap_or_else(|| "tool_router".to_string())
);
let self_ty = &input_impl.self_ty;
let mut tool_methods = Vec::new();
for item in &input_impl.items {
if let ImplItem::Fn(method) = item
&& has_mcp_tool_attr(&method.attrs)
{
let method_name = &method.sig.ident;
let info_fn = format_ident!("{}_tool_info", method_name);
let handler_fn = format_ident!("{}_handler", method_name);
let visibility_fn = format_ident!("{}_visibility", method_name);
let has_visibility = has_visibility_attr(&method.attrs);
tool_methods.push((info_fn, handler_fn, visibility_fn, has_visibility));
}
}
let route_adds: Vec<TokenStream2> = tool_methods
.iter()
.map(|(info_fn, handler_fn, visibility_fn, has_visibility)| {
if *has_visibility {
quote! {
.with_tool(
Self::#info_fn(),
Self::#handler_fn,
Some(Self::#visibility_fn)
)
}
} else {
quote! {
.with_tool(
Self::#info_fn(),
Self::#handler_fn,
None
)
}
}
})
.collect();
let expanded = quote! {
#input_impl
impl #self_ty {
pub fn #router_fn_name() -> mcp_host::registry::router::McpToolRouter<Self> {
mcp_host::registry::router::McpToolRouter::new()
#(#route_adds)*
}
}
};
Ok(expanded)
}
fn has_mcp_tool_attr(attrs: &[Attribute]) -> bool {
attrs.iter().any(|attr| attr.path().is_ident("mcp_tool"))
}
fn has_visibility_attr(attrs: &[Attribute]) -> bool {
for attr in attrs {
if attr.path().is_ident("mcp_tool") {
if let Ok(meta) = attr.meta.require_list() {
let tokens = meta.tokens.to_string();
if tokens.contains("visible") {
return true;
}
}
}
}
false
}