use darling::{FromMeta, ast::NestedMeta};
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{Attribute, ImplItem, ItemImpl, parse_macro_input};
/// Arguments parsed from the attribute handled by [`expand_mcp_router`].
///
/// No options are declared at the moment; the struct exists so the attribute
/// arguments are still run through `FromMeta::from_list`, and so options can
/// be added later without changing the signature of `generate_router_impl`.
#[derive(Debug, Default, FromMeta)]
pub struct McpRouterAttrs {
}
/// Per-method data gathered while scanning an impl block, consumed by
/// `generate_route_adds` to emit one router builder call.
type MethodInfo = (
// Identifier of the associated function that yields the item's info
// (`<method>_tool_info`, `<method>_prompt_info`, ...).
proc_macro2::Ident,
// Identifier of the handler function (`<method>_handler`).
proc_macro2::Ident,
// Identifier of the visibility function (`<method>_visibility`).
proc_macro2::Ident,
// Whether the attribute carried a `visible` argument; when false the
// visibility identifier is not referenced in the generated code.
bool,
);
/// Entry point for the router attribute macro.
///
/// Parses the attribute arguments into [`McpRouterAttrs`], parses the
/// annotated item as an `impl` block, and returns the original block followed
/// by a generated `router()` constructor. Any parse or generation failure is
/// returned as compile-error tokens instead of panicking.
pub fn expand_mcp_router(attr: TokenStream, item: TokenStream) -> TokenStream {
    // Step 1: raw attribute tokens -> list of nested meta items.
    let nested_args = match NestedMeta::parse_meta_list(attr.into()) {
        Ok(list) => list,
        Err(parse_err) => return parse_err.to_compile_error().into(),
    };

    // Step 2: nested meta items -> typed attribute options.
    let router_attrs = match McpRouterAttrs::from_list(&nested_args) {
        Ok(parsed) => parsed,
        Err(darling_err) => return darling_err.write_errors().into(),
    };

    // Step 3: the annotated item must be an `impl` block.
    let impl_block = parse_macro_input!(item as ItemImpl);

    generate_router_impl(router_attrs, impl_block)
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
fn generate_router_impl(_attrs: McpRouterAttrs, input_impl: ItemImpl) -> syn::Result<TokenStream2> {
let self_ty = &input_impl.self_ty;
let (impl_generics, _ty_generics, where_clause) = input_impl.generics.split_for_impl();
let mut tool_methods: Vec<MethodInfo> = Vec::new();
let mut prompt_methods: Vec<MethodInfo> = Vec::new();
let mut resource_methods: Vec<MethodInfo> = Vec::new();
let mut template_methods: Vec<MethodInfo> = Vec::new();
for item in &input_impl.items {
if let ImplItem::Fn(method) = item {
let method_name = &method.sig.ident;
if has_attr(&method.attrs, "mcp_tool") {
let info_fn = format_ident!("{}_tool_info", method_name);
let handler_fn = format_ident!("{}_handler", method_name);
let visibility_fn = format_ident!("{}_visibility", method_name);
let has_visibility = has_visibility_attr(&method.attrs, "mcp_tool");
tool_methods.push((info_fn, handler_fn, visibility_fn, has_visibility));
}
if has_attr(&method.attrs, "mcp_prompt") {
let info_fn = format_ident!("{}_prompt_info", method_name);
let handler_fn = format_ident!("{}_handler", method_name);
let visibility_fn = format_ident!("{}_visibility", method_name);
let has_visibility = has_visibility_attr(&method.attrs, "mcp_prompt");
prompt_methods.push((info_fn, handler_fn, visibility_fn, has_visibility));
}
if has_attr(&method.attrs, "mcp_resource") {
let info_fn = format_ident!("{}_resource_info", method_name);
let handler_fn = format_ident!("{}_handler", method_name);
let visibility_fn = format_ident!("{}_visibility", method_name);
let has_visibility = has_visibility_attr(&method.attrs, "mcp_resource");
resource_methods.push((info_fn, handler_fn, visibility_fn, has_visibility));
}
if has_attr(&method.attrs, "mcp_resource_template") {
let info_fn = format_ident!("{}_template_info", method_name);
let handler_fn = format_ident!("{}_handler", method_name);
let visibility_fn = format_ident!("{}_visibility", method_name);
let has_visibility = has_visibility_attr(&method.attrs, "mcp_resource_template");
template_methods.push((info_fn, handler_fn, visibility_fn, has_visibility));
}
}
}
let tool_route_adds = generate_route_adds(&tool_methods, "with_tool");
let prompt_route_adds = generate_route_adds(&prompt_methods, "with_prompt");
let resource_route_adds = generate_route_adds(&resource_methods, "with_resource");
let template_route_adds = generate_route_adds(&template_methods, "with_template");
let expanded = quote! {
#input_impl
impl #impl_generics #self_ty #where_clause {
pub fn router() -> mcp_host::registry::router::McpRouter<Self> {
mcp_host::registry::router::McpRouter::new(
mcp_host::registry::router::McpToolRouter::new()
#(#tool_route_adds)*,
mcp_host::registry::router::McpPromptRouter::new()
#(#prompt_route_adds)*,
mcp_host::registry::router::McpResourceRouter::new()
#(#resource_route_adds)*,
mcp_host::registry::router::McpResourceTemplateRouter::new()
#(#template_route_adds)*,
)
}
}
};
Ok(expanded)
}
/// Emits one chained builder call per entry in `methods`.
///
/// Each call has the form
/// `.<builder_method>(Self::<info_fn>(), Self::<handler_fn>, <visibility>)`,
/// where `<visibility>` is `Some(Self::<visibility_fn>)` when the entry's flag
/// is set and `None` otherwise. The result keeps the order of `methods`.
fn generate_route_adds(methods: &[MethodInfo], builder_method: &str) -> Vec<TokenStream2> {
    let builder = format_ident!("{}", builder_method);
    let mut routes = Vec::with_capacity(methods.len());

    for (info_fn, handler_fn, visibility_fn, has_visibility) in methods {
        // Only the third argument differs between the two cases, so build it
        // separately and splice it into a single call template.
        let visibility_arg = if *has_visibility {
            quote! { Some(Self::#visibility_fn) }
        } else {
            quote! { None }
        };

        routes.push(quote! {
            .#builder(
                Self::#info_fn(),
                Self::#handler_fn,
                #visibility_arg
            )
        });
    }

    routes
}
fn has_attr(attrs: &[Attribute], name: &str) -> bool {
attrs.iter().any(|attr| attr.path().is_ident(name))
}
/// Reports whether an attribute named `attr_name` carries a `visible`
/// argument.
///
/// Only list-style attributes (`#[name(...)]`) are inspected. Attributes whose
/// arguments fail to parse as a nested meta list are treated as not having the
/// argument. `visible` is recognised as a bare path, as `visible = ...`, or as
/// `visible(...)`; literal arguments never match.
fn has_visibility_attr(attrs: &[Attribute], attr_name: &str) -> bool {
    attrs
        .iter()
        .filter(|attr| attr.path().is_ident(attr_name))
        // Non-list forms and unparsable argument lists are skipped, not errors.
        .filter_map(|attr| attr.meta.require_list().ok())
        .filter_map(|list| NestedMeta::parse_meta_list(list.tokens.clone()).ok())
        .any(|nested| {
            nested.iter().any(|entry| match entry {
                // `Meta::path` yields the leading path for every meta form.
                NestedMeta::Meta(meta) => meta.path().is_ident("visible"),
                NestedMeta::Lit(_) => false,
            })
        })
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// `has_attr` matches on the attribute's identifier only.
    #[test]
    fn test_has_attr() {
        let attrs: Vec<Attribute> = vec![parse_quote!(#[mcp_tool(name = "test")])];
        assert!(has_attr(&attrs, "mcp_tool"));
        assert!(!has_attr(&attrs, "mcp_prompt"));
    }

    /// A `visible = ...` argument is detected; its absence is reported as such.
    #[test]
    fn test_visibility_attr_detection() {
        let attrs: Vec<Attribute> =
            vec![parse_quote!(#[mcp_tool(name = "test", visible = "ctx.is_admin()")])];
        assert!(has_visibility_attr(&attrs, "mcp_tool"));

        let attrs: Vec<Attribute> = vec![parse_quote!(#[mcp_tool(name = "test")])];
        assert!(!has_visibility_attr(&attrs, "mcp_tool"));
    }

    /// Every supported attribute kind must end up registered on its router.
    #[test]
    fn test_router_collects_all_types() {
        let input_impl: ItemImpl = parse_quote! {
            impl MyServer {
                #[mcp_tool(name = "echo")]
                async fn echo(&self) {}
                #[mcp_prompt(name = "greeting")]
                async fn greeting(&self) {}
                #[mcp_resource(uri = "test:///", name = "test")]
                async fn test_resource(&self) {}
                #[mcp_resource_template(uri_template = "file:///{path}", name = "files")]
                async fn files(&self) {}
            }
        };
        let tokens = generate_router_impl(McpRouterAttrs::default(), input_impl).unwrap();
        let output = tokens.to_string();

        assert!(output.contains("pub fn router"));
        assert!(output.contains("McpRouter"));
        assert!(output.contains("McpToolRouter"));
        assert!(output.contains("McpPromptRouter"));
        assert!(output.contains("McpResourceRouter"));
        assert!(output.contains("McpResourceTemplateRouter"));

        // The router type names above are emitted unconditionally, so also
        // check the per-method registrations that prove collection happened.
        for expected in [
            "with_tool",
            "echo_tool_info",
            "with_prompt",
            "greeting_prompt_info",
            "with_resource",
            "test_resource_resource_info",
            "with_template",
            "files_template_info",
        ] {
            assert!(
                output.contains(expected),
                "missing `{}` in generated router",
                expected
            );
        }
    }

    /// Generics and the where-clause of the input carry over to the
    /// generated impl.
    #[test]
    fn test_router_preserves_generics() {
        let input_impl: ItemImpl = parse_quote! {
            impl<T> MyServer<T>
            where
                T: Send,
            {
                #[mcp_tool(name = "test")]
                async fn test(&self) {}
            }
        };
        let tokens = generate_router_impl(McpRouterAttrs::default(), input_impl).unwrap();
        let file: syn::File = syn::parse2(tokens).unwrap();
        let impls: Vec<&syn::ItemImpl> = file
            .items
            .iter()
            .filter_map(|item| match item {
                syn::Item::Impl(impl_item) => Some(impl_item),
                _ => None,
            })
            .collect();

        // First impl is the user's block, second is the generated one.
        assert_eq!(impls.len(), 2);
        let generated_impl = impls[1];
        assert_eq!(generated_impl.generics.params.len(), 1);
        assert!(generated_impl.generics.where_clause.is_some());
    }

    /// An impl without annotated methods still gets `router()`, but no route
    /// is registered on any sub-router.
    #[test]
    fn test_empty_impl_generates_no_routers() {
        let input_impl: ItemImpl = parse_quote! {
            impl MyServer {
                fn regular_method(&self) {}
            }
        };
        let tokens = generate_router_impl(McpRouterAttrs::default(), input_impl).unwrap();
        let output = tokens.to_string();

        assert!(output.contains("pub fn router"));
        // Check the builder calls the generator really emits; strings it
        // never produces would make this test pass vacuously.
        for route_builder in ["with_tool", "with_prompt", "with_resource", "with_template"] {
            assert!(
                !output.contains(route_builder),
                "unexpected `{}` call for an impl without annotated methods",
                route_builder
            );
        }
    }
}