mcp-host-macros 0.1.26

Procedural macros for mcp-host crate
Documentation
//! #[mcp_router] unified attribute macro implementation
//!
//! Transforms an impl block to generate all router methods for tools, prompts,
//! resources, and resource templates marked with their respective attributes.

use darling::{FromMeta, ast::NestedMeta};
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{Attribute, ImplItem, ItemImpl, parse_macro_input};

/// Options accepted inside `#[mcp_router(...)]`.
///
/// There are no options yet: the macro always emits a single `router()` constructor. The
/// struct is still parsed through darling so that options can be added later without touching
/// the entry point; with no fields, any argument passed to the attribute is presumably
/// reported by darling as an unknown field (TODO confirm against darling's defaults).
#[derive(Debug, Default, FromMeta)]
pub struct McpRouterAttrs {}

/// Everything needed to register one route for a marked method, as the tuple
/// `(info_fn, handler_fn, visibility_fn, has_visibility)`.
///
/// The three identifiers name associated functions on `Self`. The flag records whether the
/// method's marker attribute declared a `visible` key, i.e. whether `visibility_fn` should be
/// referenced (`Some(Self::visibility_fn)`) or omitted (`None`) at registration time.
type MethodInfo = (proc_macro2::Ident, proc_macro2::Ident, proc_macro2::Ident, bool);

/// Entry point for `#[mcp_router]`.
///
/// Validates the attribute arguments against [`McpRouterAttrs`], parses the annotated item as
/// an `impl` block, and returns that block followed by a generated `router()` constructor.
/// Every failure (bad attribute arguments, an item that is not an `impl`, or an error from
/// generation) is turned into `compile_error!` tokens instead of panicking.
pub fn expand_mcp_router(attr: TokenStream, item: TokenStream) -> TokenStream {
    // Both attribute-stage failures are mapped to ready-to-return diagnostic tokens.
    let parsed = NestedMeta::parse_meta_list(attr.into())
        .map_err(|err| TokenStream::from(err.to_compile_error()))
        .and_then(|args| {
            McpRouterAttrs::from_list(&args).map_err(|err| TokenStream::from(err.write_errors()))
        });
    let router_attrs = match parsed {
        Ok(router_attrs) => router_attrs,
        Err(diagnostics) => return diagnostics,
    };

    // `parse_macro_input!` returns early with a compile error if `item` is not an `impl`.
    let input_impl = parse_macro_input!(item as ItemImpl);

    generate_router_impl(router_attrs, input_impl)
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}

fn generate_router_impl(_attrs: McpRouterAttrs, input_impl: ItemImpl) -> syn::Result<TokenStream2> {
    let self_ty = &input_impl.self_ty;
    let (impl_generics, _ty_generics, where_clause) = input_impl.generics.split_for_impl();

    // Collect methods by type
    let mut tool_methods: Vec<MethodInfo> = Vec::new();
    let mut prompt_methods: Vec<MethodInfo> = Vec::new();
    let mut resource_methods: Vec<MethodInfo> = Vec::new();
    let mut template_methods: Vec<MethodInfo> = Vec::new();

    for item in &input_impl.items {
        if let ImplItem::Fn(method) = item {
            let method_name = &method.sig.ident;

            if has_attr(&method.attrs, "mcp_tool") {
                let info_fn = format_ident!("{}_tool_info", method_name);
                let handler_fn = format_ident!("{}_handler", method_name);
                let visibility_fn = format_ident!("{}_visibility", method_name);
                let has_visibility = has_visibility_attr(&method.attrs, "mcp_tool");
                tool_methods.push((info_fn, handler_fn, visibility_fn, has_visibility));
            }

            if has_attr(&method.attrs, "mcp_prompt") {
                let info_fn = format_ident!("{}_prompt_info", method_name);
                let handler_fn = format_ident!("{}_handler", method_name);
                let visibility_fn = format_ident!("{}_visibility", method_name);
                let has_visibility = has_visibility_attr(&method.attrs, "mcp_prompt");
                prompt_methods.push((info_fn, handler_fn, visibility_fn, has_visibility));
            }

            if has_attr(&method.attrs, "mcp_resource") {
                let info_fn = format_ident!("{}_resource_info", method_name);
                let handler_fn = format_ident!("{}_handler", method_name);
                let visibility_fn = format_ident!("{}_visibility", method_name);
                let has_visibility = has_visibility_attr(&method.attrs, "mcp_resource");
                resource_methods.push((info_fn, handler_fn, visibility_fn, has_visibility));
            }

            if has_attr(&method.attrs, "mcp_resource_template") {
                let info_fn = format_ident!("{}_template_info", method_name);
                let handler_fn = format_ident!("{}_handler", method_name);
                let visibility_fn = format_ident!("{}_visibility", method_name);
                let has_visibility = has_visibility_attr(&method.attrs, "mcp_resource_template");
                template_methods.push((info_fn, handler_fn, visibility_fn, has_visibility));
            }
        }
    }

    // Generate route additions for each type
    let tool_route_adds = generate_route_adds(&tool_methods, "with_tool");
    let prompt_route_adds = generate_route_adds(&prompt_methods, "with_prompt");
    let resource_route_adds = generate_route_adds(&resource_methods, "with_resource");
    let template_route_adds = generate_route_adds(&template_methods, "with_template");

    // Generate the complete output with a single router() function
    let expanded = quote! {
        #input_impl

        impl #impl_generics #self_ty #where_clause {
            /// Generated router collecting all MCP handlers
            pub fn router() -> mcp_host::registry::router::McpRouter<Self> {
                mcp_host::registry::router::McpRouter::new(
                    mcp_host::registry::router::McpToolRouter::new()
                        #(#tool_route_adds)*,
                    mcp_host::registry::router::McpPromptRouter::new()
                        #(#prompt_route_adds)*,
                    mcp_host::registry::router::McpResourceRouter::new()
                        #(#resource_route_adds)*,
                    mcp_host::registry::router::McpResourceTemplateRouter::new()
                        #(#template_route_adds)*,
                )
            }
        }
    };

    Ok(expanded)
}

/// Turn each [`MethodInfo`] into one chained builder call on a sub-router, e.g.
/// `.with_tool(Self::echo_tool_info(), Self::echo_handler, None)`.
///
/// `builder_method` names the builder (`with_tool`, `with_prompt`, ...). The third argument is
/// `Some(Self::<method>_visibility)` when the entry's flag is set and `None` otherwise.
fn generate_route_adds(methods: &[MethodInfo], builder_method: &str) -> Vec<TokenStream2> {
    let builder = format_ident!("{}", builder_method);
    let mut calls = Vec::new();
    for (info_fn, handler_fn, visibility_fn, has_visibility) in methods {
        let visibility_arg = if *has_visibility {
            quote! { Some(Self::#visibility_fn) }
        } else {
            quote! { None }
        };
        calls.push(quote! {
            .#builder(Self::#info_fn(), Self::#handler_fn, #visibility_arg)
        });
    }
    calls
}

/// Check if a method has a specific attribute
fn has_attr(attrs: &[Attribute], name: &str) -> bool {
    attrs.iter().any(|attr| attr.path().is_ident(name))
}

/// Return `true` when an `#[attr_name(...)]` attribute in `attrs` contains a `visible` key in
/// any form: bare (`visible`), name-value (`visible = "..."`), or list (`visible(...)`).
///
/// Attributes without a parenthesized argument list, or whose arguments do not parse as meta
/// items, are treated as having no `visible` key rather than producing an error.
fn has_visibility_attr(attrs: &[Attribute], attr_name: &str) -> bool {
    attrs
        .iter()
        .filter(|attr| attr.path().is_ident(attr_name))
        .filter_map(|attr| attr.meta.require_list().ok())
        .filter_map(|list| NestedMeta::parse_meta_list(list.tokens.clone()).ok())
        .any(|args| {
            args.iter().any(|arg| match arg {
                NestedMeta::Meta(meta) => meta.path().is_ident("visible"),
                NestedMeta::Lit(_) => false,
            })
        })
}

#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    #[test]
    fn test_has_attr() {
        let attrs: Vec<Attribute> = vec![parse_quote!(#[mcp_tool(name = "test")])];
        assert!(has_attr(&attrs, "mcp_tool"));
        assert!(!has_attr(&attrs, "mcp_prompt"));
    }

    #[test]
    fn test_visibility_attr_detection() {
        let attrs: Vec<Attribute> =
            vec![parse_quote!(#[mcp_tool(name = "test", visible = "ctx.is_admin()")])];
        assert!(has_visibility_attr(&attrs, "mcp_tool"));

        let attrs: Vec<Attribute> = vec![parse_quote!(#[mcp_tool(name = "test")])];
        assert!(!has_visibility_attr(&attrs, "mcp_tool"));

        // A `visible` key on a different marker attribute must not count.
        let attrs: Vec<Attribute> =
            vec![parse_quote!(#[mcp_prompt(name = "test", visible = "ctx.is_admin()")])];
        assert!(!has_visibility_attr(&attrs, "mcp_tool"));
    }

    #[test]
    fn test_router_collects_all_types() {
        let input_impl: ItemImpl = parse_quote! {
            impl MyServer {
                #[mcp_tool(name = "echo")]
                async fn echo(&self) {}

                #[mcp_prompt(name = "greeting")]
                async fn greeting(&self) {}

                #[mcp_resource(uri = "test:///", name = "test")]
                async fn test_resource(&self) {}

                #[mcp_resource_template(uri_template = "file:///{path}", name = "files")]
                async fn files(&self) {}
            }
        };

        let tokens = generate_router_impl(McpRouterAttrs::default(), input_impl).unwrap();
        let output = tokens.to_string();

        // Verify single router() function is generated with all sub-routers
        assert!(output.contains("pub fn router"));
        assert!(output.contains("McpRouter"));
        assert!(output.contains("McpToolRouter"));
        assert!(output.contains("McpPromptRouter"));
        assert!(output.contains("McpResourceRouter"));
        assert!(output.contains("McpResourceTemplateRouter"));

        // Each marked method is registered with its category's builder, using the info
        // function name derived from the method name.
        assert!(output.contains("with_tool"));
        assert!(output.contains("echo_tool_info"));
        assert!(output.contains("with_prompt"));
        assert!(output.contains("greeting_prompt_info"));
        assert!(output.contains("with_resource"));
        assert!(output.contains("test_resource_resource_info"));
        assert!(output.contains("with_template"));
        assert!(output.contains("files_template_info"));
    }

    #[test]
    fn test_visibility_fn_only_passed_when_declared() {
        let input_impl: ItemImpl = parse_quote! {
            impl MyServer {
                #[mcp_tool(name = "admin", visible = "ctx.is_admin()")]
                async fn admin(&self) {}

                #[mcp_tool(name = "echo")]
                async fn echo(&self) {}
            }
        };

        let output = generate_router_impl(McpRouterAttrs::default(), input_impl)
            .unwrap()
            .to_string();

        // `admin` declares `visible`, so its visibility fn is passed; `echo` gets `None`.
        assert!(output.contains("admin_visibility"));
        assert!(!output.contains("echo_visibility"));
        assert!(output.contains("None"));
    }

    #[test]
    fn test_router_preserves_generics() {
        let input_impl: ItemImpl = parse_quote! {
            impl<T> MyServer<T>
            where
                T: Send,
            {
                #[mcp_tool(name = "test")]
                async fn test(&self) {}
            }
        };

        let tokens = generate_router_impl(McpRouterAttrs::default(), input_impl).unwrap();
        let file: syn::File = syn::parse2(tokens).unwrap();
        let impls: Vec<&syn::ItemImpl> = file
            .items
            .iter()
            .filter_map(|item| match item {
                syn::Item::Impl(impl_item) => Some(impl_item),
                _ => None,
            })
            .collect();

        assert_eq!(impls.len(), 2);
        let generated_impl = impls[1];
        assert_eq!(generated_impl.generics.params.len(), 1);
        assert!(generated_impl.generics.where_clause.is_some());
    }

    #[test]
    fn test_empty_impl_generates_router_without_routes() {
        let input_impl: ItemImpl = parse_quote! {
            impl MyServer {
                fn regular_method(&self) {}
            }
        };

        let tokens = generate_router_impl(McpRouterAttrs::default(), input_impl).unwrap();
        let output = tokens.to_string();

        // router() is always emitted (with four empty sub-routers), but with no marked
        // methods nothing may be registered on any of them.
        assert!(output.contains("pub fn router"));
        assert!(!output.contains("with_tool"));
        assert!(!output.contains("with_prompt"));
        assert!(!output.contains("with_resource"));
        assert!(!output.contains("with_template"));
    }
}