use darling::{FromMeta, ast::NestedMeta};
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{FnArg, ItemFn, Type, parse_macro_input};
/// Task-execution support level accepted by the `task_support` key of the tool attribute.
///
/// Parsed from the attribute by darling's `FromMeta` derive.
/// NOTE(review): the exact spelling accepted in the attribute (e.g. `"required"` vs
/// `"Required"`) follows darling's unit-enum naming rules — confirm before documenting
/// it for users.
///
/// When generating `ToolInfo`, `Forbidden` and an absent `task_support` both produce
/// `execution: None`.
#[derive(Debug, Clone, Copy, Default, FromMeta)]
pub enum TaskSupportAttr {
    /// Default level; no `ToolExecution` is advertised for the tool.
    #[default]
    Forbidden,
    /// Emits `ToolExecution { task_support: Some(TaskSupport::Optional) }`.
    Optional,
    /// Emits `ToolExecution { task_support: Some(TaskSupport::Required) }`.
    Required,
}
/// Arguments accepted by the `#[mcp_tool(...)]` attribute, parsed with darling's
/// `FromMeta` derive. Every key is optional.
#[derive(Debug, FromMeta)]
pub struct McpToolAttrs {
    /// Tool name exposed to clients; defaults to the annotated function's name.
    #[darling(default)]
    pub name: Option<String>,
    /// Human-readable title copied into `ToolInfo::title`.
    #[darling(default)]
    pub title: Option<String>,
    /// Tool description; when absent, the function's doc comment is used instead.
    #[darling(default)]
    pub description: Option<String>,
    /// Rust type written as a string (parsed with `syn::parse_str`) whose schema becomes
    /// `ToolInfo::output_schema`.
    #[darling(default)]
    pub output: Option<String>,
    /// Rust expression written as a string; it becomes the body of a generated
    /// `<fn>_visibility(ctx: &VisibilityContext) -> bool` function and may refer to `ctx`.
    #[darling(default)]
    pub visible: Option<String>,
    /// Task-execution support level; see `TaskSupportAttr`.
    #[darling(default)]
    pub task_support: Option<TaskSupportAttr>,
    /// Read-only hint. Setting it together with `destructive = true` is rejected, and
    /// when it is `true` the `destructive`/`idempotent` hints are not emitted.
    #[darling(default)]
    pub read_only: Option<bool>,
    /// Destructive hint; only `false` is emitted into the generated annotations.
    #[darling(default)]
    pub destructive: Option<bool>,
    /// Idempotent hint; only `true` is emitted into the generated annotations.
    #[darling(default)]
    pub idempotent: Option<bool>,
    /// Open-world hint; only `false` is emitted into the generated annotations.
    #[darling(default)]
    pub open_world: Option<bool>,
}
/// Expands one `#[mcp_tool(...)]` invocation into the annotated method plus its generated
/// companion functions (presumably called from the crate's `#[proc_macro_attribute]`
/// entry point).
///
/// `attr` holds the attribute's arguments and `item` the annotated function. All failures
/// are turned into compile-error tokens instead of being propagated. This covers parsing
/// the argument list, mapping it onto `McpToolAttrs`, parsing `item` as an `ItemFn`, and
/// errors from `generate_tool_impl`.
pub fn expand_mcp_tool(attr: TokenStream, item: TokenStream) -> TokenStream {
    let parsed = NestedMeta::parse_meta_list(attr.into())
        .map_err(|syn_err| TokenStream::from(syn_err.to_compile_error()))
        .and_then(|nested| {
            McpToolAttrs::from_list(&nested)
                .map_err(|darling_err| TokenStream::from(darling_err.write_errors()))
        });
    let attrs = match parsed {
        Ok(attrs) => attrs,
        Err(error_tokens) => return error_tokens,
    };

    // `parse_macro_input!` returns the compile error early when `item` is not a function.
    let input_fn = parse_macro_input!(item as ItemFn);

    generate_tool_impl(attrs, input_fn)
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
fn generate_tool_impl(attrs: McpToolAttrs, input_fn: ItemFn) -> syn::Result<TokenStream2> {
let fn_name = &input_fn.sig.ident;
let fn_name_str = fn_name.to_string();
let tool_name = attrs.name.clone().unwrap_or_else(|| fn_name_str.clone());
let info_fn_name = format_ident!("{}_tool_info", fn_name);
let handler_fn_name = format_ident!("{}_handler", fn_name);
let params_type = extract_parameters_type(&input_fn)?;
let ctx_type = extract_context_type(&input_fn)?;
let title_tokens = match &attrs.title {
Some(title) => quote! { Some(#title.to_string()) },
None => quote! { None },
};
let description_tokens = match &attrs.description {
Some(desc) => quote! { Some(#desc.to_string()) },
None => {
let doc = extract_doc_comment(&input_fn);
if let Some(doc) = doc {
quote! { Some(#doc.to_string()) }
} else {
quote! { None }
}
}
};
let output_schema_tokens = if let Some(output_type_str) = &attrs.output {
let output_type: Type = syn::parse_str(output_type_str).map_err(|e| {
syn::Error::new_spanned(&input_fn.sig.ident, format!("Invalid output type: {}", e))
})?;
quote! { Some(mcp_host::macros::schema_for::<#output_type>()) }
} else {
quote! { None }
};
let execution_tokens = match attrs.task_support {
Some(TaskSupportAttr::Required) => quote! {
Some(mcp_host::protocol::types::ToolExecution {
task_support: Some(mcp_host::protocol::types::TaskSupport::Required),
})
},
Some(TaskSupportAttr::Optional) => quote! {
Some(mcp_host::protocol::types::ToolExecution {
task_support: Some(mcp_host::protocol::types::TaskSupport::Optional),
})
},
Some(TaskSupportAttr::Forbidden) | None => quote! { None },
};
if attrs.read_only == Some(true) && attrs.destructive == Some(true) {
return Err(syn::Error::new_spanned(
&input_fn.sig.ident,
"Tool cannot be both read_only and destructive. \
A read-only tool by definition does not modify its environment.",
));
}
let annotations_tokens = {
let read_only = attrs.read_only;
let destructive = if read_only == Some(true) {
None
} else {
attrs.destructive
};
let idempotent = if read_only == Some(true) {
None
} else {
attrs.idempotent
};
let open_world = attrs.open_world;
let emit_read_only = read_only == Some(true); let emit_destructive = destructive == Some(false); let emit_idempotent = idempotent == Some(true); let emit_open_world = open_world == Some(false);
let has_any = emit_read_only || emit_destructive || emit_idempotent || emit_open_world;
if has_any {
let read_only_token = if emit_read_only {
quote! { Some(true) }
} else {
quote! { None }
};
let destructive_token = if emit_destructive {
quote! { Some(false) }
} else {
quote! { None }
};
let idempotent_token = if emit_idempotent {
quote! { Some(true) }
} else {
quote! { None }
};
let open_world_token = if emit_open_world {
quote! { Some(false) }
} else {
quote! { None }
};
quote! {
Some(mcp_host::protocol::types::ToolAnnotations {
title: None,
read_only_hint: #read_only_token,
destructive_hint: #destructive_token,
idempotent_hint: #idempotent_token,
open_world_hint: #open_world_token,
})
}
} else {
quote! { None }
}
};
let visibility_fn = if let Some(visible_expr) = &attrs.visible {
let vis_fn_name = format_ident!("{}_visibility", fn_name);
let expr: syn::Expr = syn::parse_str(visible_expr).map_err(|e| {
syn::Error::new_spanned(
&input_fn.sig.ident,
format!("Invalid visibility expression: {}", e),
)
})?;
Some(quote! {
pub fn #vis_fn_name(ctx: &mcp_host::server::visibility::VisibilityContext) -> bool {
#expr
}
})
} else {
None
};
let hints_comment = generate_hints_comment(&attrs);
let expanded = quote! {
#hints_comment
#input_fn
pub fn #info_fn_name() -> mcp_host::registry::tools::ToolInfo {
mcp_host::registry::tools::ToolInfo {
name: #tool_name.to_string(),
title: #title_tokens,
description: #description_tokens,
input_schema: mcp_host::macros::schema_for::<#params_type>(),
output_schema: #output_schema_tokens,
execution: #execution_tokens,
annotations: #annotations_tokens,
}
}
pub fn #handler_fn_name<'a>(
server: &'a Self,
ctx: mcp_host::server::visibility::ExecutionContext<'a>,
) -> mcp_host::registry::router::ToolHandlerFuture<'a> {
Box::pin(async move {
let params: mcp_host::macros::Parameters<#params_type> = serde_json::from_value(ctx.params.clone())
.map_err(|e| mcp_host::registry::tools::ToolError::InvalidArguments(e.to_string()))?;
let ctx: #ctx_type = mcp_host::macros::FromExecutionContext::from_execution_context(&ctx);
server.#fn_name(ctx, params).await
})
}
#visibility_fn
};
Ok(expanded)
}
fn generate_hints_comment(attrs: &McpToolAttrs) -> TokenStream2 {
let mut hints = Vec::new();
if let Some(true) = attrs.read_only {
hints.push("read-only");
}
if let Some(true) = attrs.destructive {
hints.push("destructive");
}
if let Some(true) = attrs.idempotent {
hints.push("idempotent");
}
if hints.is_empty() {
quote! {}
} else {
let hint_str = hints.join(", ");
quote! {
#[doc = ""]
#[doc = concat!("**Hints:** ", #hint_str)]
}
}
}
fn extract_parameters_type(input_fn: &ItemFn) -> syn::Result<Type> {
for arg in &input_fn.sig.inputs {
if let FnArg::Typed(pat_type) = arg
&& let Type::Path(type_path) = &*pat_type.ty
&& let Some(segment) = type_path.path.segments.last()
&& segment.ident == "Parameters"
&& let syn::PathArguments::AngleBracketed(args) = &segment.arguments
&& let Some(syn::GenericArgument::Type(inner_type)) = args.args.first()
{
return Ok(inner_type.clone());
}
}
Err(syn::Error::new_spanned(
&input_fn.sig.ident,
"Missing Parameters<T> argument. Use Parameters<()> for tools with no parameters.",
))
}
/// Returns the type of the method's first typed (non-receiver) argument. The generated
/// handler fills this argument via `FromExecutionContext`.
///
/// # Errors
///
/// Fails when the method has no typed arguments, or when the first typed argument is a
/// `Parameters<...>` type (i.e. the context argument is missing).
fn extract_context_type(input_fn: &ItemFn) -> syn::Result<Type> {
    let first_typed = input_fn.sig.inputs.iter().find_map(|arg| {
        if let FnArg::Typed(pat_type) = arg {
            Some(pat_type)
        } else {
            None
        }
    });

    let Some(ctx_arg) = first_typed else {
        return Err(syn::Error::new_spanned(
            &input_fn.sig.ident,
            "Missing context argument. Expected signature: fn(&self, ctx: CtxType, params: Parameters<T>)",
        ));
    };

    if is_parameters_type(&ctx_arg.ty) {
        return Err(syn::Error::new_spanned(
            &ctx_arg.ty,
            "Missing context argument before Parameters<T>.",
        ));
    }

    let ctx_type: &Type = &ctx_arg.ty;
    Ok(ctx_type.clone())
}
/// Reports whether `ty` is a path type whose last segment is named `Parameters`
/// (e.g. `Parameters<T>` or `mcp_host::macros::Parameters<T>`).
///
/// Only the segment name is compared; generic arguments are not inspected, and
/// non-path types such as `&Parameters<T>` return `false`.
fn is_parameters_type(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .is_some_and(|segment| segment.ident == "Parameters"),
        _ => false,
    }
}
/// Collects the function's doc text (`///` comments / `#[doc = "..."]`) into a
/// single-line description.
///
/// Steps:
/// - split every string-valued `doc` attribute into lines;
/// - trim each line;
/// - skip blank lines, such as the empty `///` separating paragraphs, which previously
///   produced double spaces;
/// - join the remaining lines with a single space.
///
/// Returns `None` when there is no doc text at all, including when every doc line is
/// blank (previously `Some("")`).
fn extract_doc_comment(input_fn: &ItemFn) -> Option<String> {
    let mut lines: Vec<String> = Vec::new();
    for attr in &input_fn.attrs {
        if attr.path().is_ident("doc")
            && let syn::Meta::NameValue(nv) = &attr.meta
            && let syn::Expr::Lit(expr_lit) = &nv.value
            && let syn::Lit::Str(lit_str) = &expr_lit.lit
        {
            // A block doc comment (`/** ... */`) may arrive as a single value with
            // embedded newlines, so split it before trimming.
            let value = lit_str.value();
            lines.extend(
                value
                    .lines()
                    .map(str::trim)
                    .filter(|line| !line.is_empty())
                    .map(str::to_owned),
            );
        }
    }

    if lines.is_empty() {
        None
    } else {
        Some(lines.join(" "))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use syn::parse_quote;

    /// Attribute set with every key absent except the two behaviour flags under test.
    fn attrs_with_flags(read_only: Option<bool>, destructive: Option<bool>) -> McpToolAttrs {
        McpToolAttrs {
            name: None,
            title: None,
            description: None,
            output: None,
            visible: None,
            task_support: None,
            read_only,
            destructive,
            idempotent: None,
            open_world: None,
        }
    }

    #[test]
    fn extract_parameters_type_errors_without_parameters() {
        let input_fn: ItemFn = parse_quote! {
            async fn no_params(&self, ctx: Ctx) -> ToolResult {
                unimplemented!()
            }
        };
        let err = extract_parameters_type(&input_fn).unwrap_err();
        assert!(err.to_string().contains("Parameters"));
    }

    #[test]
    fn extract_parameters_type_returns_inner_type() {
        let input_fn: ItemFn = parse_quote! {
            async fn with_params(&self, ctx: Ctx, params: Parameters<MyArgs>) -> ToolResult {
                unimplemented!()
            }
        };
        let ty = extract_parameters_type(&input_fn).unwrap();
        assert_eq!(quote!(#ty).to_string(), "MyArgs");
    }

    #[test]
    fn extract_context_type_returns_first_typed_argument() {
        let input_fn: ItemFn = parse_quote! {
            async fn tool(&self, ctx: MyCtx, params: Parameters<()>) -> ToolResult {
                unimplemented!()
            }
        };
        let ty = extract_context_type(&input_fn).unwrap();
        assert_eq!(quote!(#ty).to_string(), "MyCtx");
    }

    #[test]
    fn extract_context_type_rejects_parameters_first() {
        let input_fn: ItemFn = parse_quote! {
            async fn tool(&self, params: Parameters<()>) -> ToolResult {
                unimplemented!()
            }
        };
        let err = extract_context_type(&input_fn).unwrap_err();
        assert!(err
            .to_string()
            .contains("Missing context argument before Parameters<T>"));
    }

    #[test]
    fn extract_context_type_errors_without_typed_arguments() {
        let input_fn: ItemFn = parse_quote! {
            async fn tool(&self) -> ToolResult {
                unimplemented!()
            }
        };
        let err = extract_context_type(&input_fn).unwrap_err();
        assert!(err.to_string().contains("Missing context argument"));
    }

    #[test]
    fn is_parameters_type_matches_only_parameters_paths() {
        assert!(is_parameters_type(&parse_quote!(Parameters<()>)));
        assert!(is_parameters_type(&parse_quote!(mcp_host::macros::Parameters<Args>)));
        assert!(!is_parameters_type(&parse_quote!(Ctx)));
        assert!(!is_parameters_type(&parse_quote!(&Parameters<()>)));
    }

    #[test]
    fn extract_doc_comment_reads_single_line_doc() {
        let input_fn: ItemFn = parse_quote! {
            #[doc = " Hello world. "]
            async fn tool(&self, ctx: Ctx, params: Parameters<()>) -> ToolResult {
                unimplemented!()
            }
        };
        assert_eq!(extract_doc_comment(&input_fn).as_deref(), Some("Hello world."));
    }

    #[test]
    fn extract_doc_comment_is_none_without_docs() {
        let input_fn: ItemFn = parse_quote! {
            async fn tool(&self, ctx: Ctx, params: Parameters<()>) -> ToolResult {
                unimplemented!()
            }
        };
        assert_eq!(extract_doc_comment(&input_fn), None);
    }

    #[test]
    fn generate_hints_comment_empty_without_flags() {
        let attrs = attrs_with_flags(None, None);
        assert!(generate_hints_comment(&attrs).is_empty());
    }

    #[test]
    fn generate_hints_comment_lists_read_only() {
        let attrs = attrs_with_flags(Some(true), None);
        assert!(generate_hints_comment(&attrs).to_string().contains("read-only"));
    }

    #[test]
    fn generate_tool_impl_rejects_read_only_and_destructive() {
        let attrs = attrs_with_flags(Some(true), Some(true));
        let input_fn: ItemFn = parse_quote! {
            async fn tool(&self, ctx: Ctx, params: Parameters<()>) -> ToolResult {
                unimplemented!()
            }
        };
        let err = generate_tool_impl(attrs, input_fn).unwrap_err();
        assert!(err.to_string().contains("both read_only and destructive"));
    }
}