mcp-host-macros 0.1.3

Procedural macros for mcp-host crate
Documentation
//! #[mcp_tool_router] attribute macro implementation
//!
//! Transforms an impl block to generate a router method (named `tool_router()`
//! by default, configurable via `router = "..."`) that collects all methods
//! marked with `#[mcp_tool]`.

use darling::{ast::NestedMeta, FromMeta};
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{parse_macro_input, Attribute, ImplItem, ItemImpl};

/// Attributes for #[mcp_tool_router]
///
/// Arguments accepted by `#[mcp_tool_router(...)]`, parsed from the attribute's
/// argument list through `darling::FromMeta`.
#[derive(Debug, Default, FromMeta)]
pub struct McpToolRouterAttrs {
    /// Name of the generated router function (default: "tool_router")
    ///
    /// Example: `#[mcp_tool_router(router = "admin_router")]`. When the argument
    /// is omitted this stays `None` and the generator falls back to `tool_router`.
    #[darling(default)]
    pub router: Option<String>,
}

/// Parse and transform an impl block marked with #[mcp_tool_router]
pub fn expand_mcp_tool_router(attr: TokenStream, item: TokenStream) -> TokenStream {
    // The attribute arguments are validated before the annotated item is parsed,
    // so argument errors are reported first. Both syntax errors (syn) and
    // semantic errors (darling) are converted into compile-error tokens.
    let parsed = NestedMeta::parse_meta_list(attr.into())
        .map_err(|err| TokenStream::from(err.to_compile_error()))
        .and_then(|args| {
            McpToolRouterAttrs::from_list(&args)
                .map_err(|err| TokenStream::from(err.write_errors()))
        });
    let attrs = match parsed {
        Ok(attrs) => attrs,
        Err(errors) => return errors,
    };

    // Returns early with a compile error if the item is not an impl block.
    let input_impl = parse_macro_input!(item as ItemImpl);

    generate_router_impl(attrs, input_impl)
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}

fn generate_router_impl(
    attrs: McpToolRouterAttrs,
    input_impl: ItemImpl,
) -> syn::Result<TokenStream2> {
    let router_fn_name = format_ident!(
        "{}",
        attrs.router.unwrap_or_else(|| "tool_router".to_string())
    );

    let self_ty = &input_impl.self_ty;

    // Collect all methods that have #[mcp_tool] attribute
    let mut tool_methods = Vec::new();

    for item in &input_impl.items {
        if let ImplItem::Fn(method) = item
            && has_mcp_tool_attr(&method.attrs)
        {
            let method_name = &method.sig.ident;
            let info_fn = format_ident!("{}_tool_info", method_name);
            let handler_fn = format_ident!("{}_handler", method_name);
            let visibility_fn = format_ident!("{}_visibility", method_name);

            // Check if this tool has a visibility attribute
            let has_visibility = has_visibility_attr(&method.attrs);

            tool_methods.push((info_fn, handler_fn, visibility_fn, has_visibility));
        }
    }

    // Generate the router function body
    let route_adds: Vec<TokenStream2> = tool_methods
        .iter()
        .map(|(info_fn, handler_fn, visibility_fn, has_visibility)| {
            if *has_visibility {
                quote! {
                    .with_tool(
                        Self::#info_fn(),
                        Self::#handler_fn,
                        Some(Self::#visibility_fn)
                    )
                }
            } else {
                quote! {
                    .with_tool(
                        Self::#info_fn(),
                        Self::#handler_fn,
                        None
                    )
                }
            }
        })
        .collect();

    // Generate the complete output
    let expanded = quote! {
        #input_impl

        impl #self_ty {
            /// Generated tool router collecting all #[mcp_tool] methods
            pub fn #router_fn_name() -> mcp_host::registry::router::McpToolRouter<Self> {
                mcp_host::registry::router::McpToolRouter::new()
                    #(#route_adds)*
            }
        }
    };

    Ok(expanded)
}

/// Check if a method has the `#[mcp_tool]` attribute.
///
/// Only the last path segment is compared. Both the bare form `#[mcp_tool]` and
/// path-qualified forms such as `#[mcp_host::mcp_tool(...)]` are recognized;
/// otherwise qualified uses would be silently left out of the router.
fn has_mcp_tool_attr(attrs: &[Attribute]) -> bool {
    attrs.iter().any(|attr| {
        attr.path()
            .segments
            .last()
            .is_some_and(|segment| segment.ident == "mcp_tool")
    })
}

/// Check if a method has a visibility attribute in its #[mcp_tool(...)]
fn has_visibility_attr(attrs: &[Attribute]) -> bool {
    for attr in attrs {
        if attr.path().is_ident("mcp_tool") {
            // Try to parse the attribute arguments
            if let Ok(meta) = attr.meta.require_list() {
                let tokens = meta.tokens.to_string();
                if tokens.contains("visible") {
                    return true;
                }
            }
        }
    }
    false
}