use darling::{FromMeta, ast::NestedMeta};
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{ItemFn, Type, parse_macro_input};
/// One `argument(...)` entry of the `#[mcp_prompt(...)]` attribute, parsed by darling.
///
/// Each entry becomes a `PromptArgument` inside the generated `<fn>_prompt_info()`.
#[derive(Debug, Clone, FromMeta)]
pub struct PromptArgumentAttr {
    /// Argument name; emitted verbatim as `PromptArgument::name`.
    pub name: String,
    /// Optional description; `None` when the attribute omits it.
    #[darling(default)]
    pub description: Option<String>,
    /// Optional `required` flag; `None` when the attribute omits it.
    #[darling(default)]
    pub required: Option<bool>,
}
/// All options accepted by `#[mcp_prompt(...)]`, parsed by darling from the attribute list.
#[derive(Debug, FromMeta)]
pub struct McpPromptAttrs {
    /// Prompt name; the annotated function's name is used when this is absent.
    #[darling(default)]
    pub name: Option<String>,
    /// Prompt description; the function's doc comment is used when this is absent.
    #[darling(default)]
    pub description: Option<String>,
    /// Source text of a Rust expression that becomes the body of the generated
    /// `<fn>_visibility(ctx: &VisibilityContext) -> bool` function.
    #[darling(default)]
    pub visible: Option<String>,
    /// Every `argument(...)` entry, in the order written.
    #[darling(default, multiple)]
    pub argument: Vec<PromptArgumentAttr>,
}
/// Entry point for the `#[mcp_prompt(...)]` attribute macro.
///
/// Parses the attribute arguments into [`McpPromptAttrs`] and the annotated item as an
/// [`ItemFn`], then hands both to `generate_prompt_impl`. Every failure (attribute
/// syntax, darling validation, item parsing, generation) is returned as
/// `compile_error!` tokens instead of panicking.
pub fn expand_mcp_prompt(attr: TokenStream, item: TokenStream) -> TokenStream {
    let parsed = NestedMeta::parse_meta_list(attr.into())
        .map_err(|err| TokenStream::from(err.to_compile_error()))
        .and_then(|nested| {
            McpPromptAttrs::from_list(&nested).map_err(|err| TokenStream::from(err.write_errors()))
        });
    let prompt_attrs = match parsed {
        Ok(prompt_attrs) => prompt_attrs,
        Err(compile_error) => return compile_error,
    };

    let function = parse_macro_input!(item as ItemFn);
    generate_prompt_impl(prompt_attrs, function)
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
fn generate_prompt_impl(attrs: McpPromptAttrs, input_fn: ItemFn) -> syn::Result<TokenStream2> {
let fn_name = &input_fn.sig.ident;
let fn_name_str = fn_name.to_string();
let prompt_name = attrs.name.clone().unwrap_or_else(|| fn_name_str.clone());
let info_fn_name = format_ident!("{}_prompt_info", fn_name);
let handler_fn_name = format_ident!("{}_handler", fn_name);
let ctx_type = extract_context_type(&input_fn)?;
let description_tokens = match &attrs.description {
Some(desc) => quote! { Some(#desc.to_string()) },
None => {
let doc = extract_doc_comment(&input_fn);
if let Some(doc) = doc {
quote! { Some(#doc.to_string()) }
} else {
quote! { None }
}
}
};
let arguments_tokens = if attrs.argument.is_empty() {
quote! { None }
} else {
let args = attrs.argument.iter().map(|arg| {
let name = &arg.name;
let desc = match &arg.description {
Some(d) => quote! { Some(#d.to_string()) },
None => quote! { None },
};
let req = match arg.required {
Some(r) => quote! { Some(#r) },
None => quote! { None },
};
quote! {
mcp_host::registry::prompts::PromptArgument {
name: #name.to_string(),
description: #desc,
required: #req,
}
}
});
quote! { Some(vec![#(#args),*]) }
};
let visibility_fn = if let Some(visible_expr) = &attrs.visible {
let vis_fn_name = format_ident!("{}_visibility", fn_name);
let expr: syn::Expr = syn::parse_str(visible_expr).map_err(|e| {
syn::Error::new_spanned(
&input_fn.sig.ident,
format!("Invalid visibility expression: {}", e),
)
})?;
Some(quote! {
pub fn #vis_fn_name(ctx: &mcp_host::server::visibility::VisibilityContext) -> bool {
#expr
}
})
} else {
None
};
let expanded = quote! {
#input_fn
pub fn #info_fn_name() -> mcp_host::registry::prompts::PromptInfo {
mcp_host::registry::prompts::PromptInfo {
name: #prompt_name.to_string(),
description: #description_tokens,
arguments: #arguments_tokens,
}
}
pub fn #handler_fn_name<'a>(
server: &'a Self,
ctx: mcp_host::server::visibility::ExecutionContext<'a>,
) -> mcp_host::registry::router::PromptHandlerFuture<'a> {
Box::pin(async move {
let args = ctx.params.clone();
let ctx: #ctx_type = mcp_host::macros::FromExecutionContext::from_execution_context(&ctx);
server.#fn_name(ctx, args).await
})
}
#visibility_fn
};
Ok(expanded)
}
/// Returns the type of the method's context parameter: the first typed (non-`self`) argument.
///
/// # Errors
///
/// Returns a spanned error on the method name when there is no typed argument at all.
fn extract_context_type(input_fn: &ItemFn) -> syn::Result<Type> {
    input_fn
        .sig
        .inputs
        .iter()
        .find_map(|fn_arg| {
            if let syn::FnArg::Typed(typed) = fn_arg {
                Some(typed.ty.as_ref().clone())
            } else {
                None
            }
        })
        .ok_or_else(|| {
            syn::Error::new_spanned(
                &input_fn.sig.ident,
                "Missing context argument. Expected signature: fn(&self, ctx: CtxType, args: Value)",
            )
        })
}
/// Flattens the function's doc comment (`///` lines or `#[doc = "..."]`) into one line.
///
/// Every doc string is split into lines. Each line is trimmed, and blank lines (such as
/// the empty `///` that separates paragraphs) are skipped. The rest are joined with
/// single spaces. Returns `None` when there is no non-blank doc text, so an empty doc
/// comment never becomes `Some("")`.
fn extract_doc_comment(input_fn: &ItemFn) -> Option<String> {
    let mut docs: Vec<String> = Vec::new();
    for attr in &input_fn.attrs {
        if attr.path().is_ident("doc")
            && let syn::Meta::NameValue(nv) = &attr.meta
            && let syn::Expr::Lit(expr_lit) = &nv.value
            && let syn::Lit::Str(lit_str) = &expr_lit.lit
        {
            // Keeping blank lines would leave double spaces in the joined description.
            docs.extend(
                lit_str
                    .value()
                    .lines()
                    .map(str::trim)
                    .filter(|line| !line.is_empty())
                    .map(str::to_owned),
            );
        }
    }
    (!docs.is_empty()).then(|| docs.join(" "))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A hand-built `PromptArgumentAttr` keeps the values it was given.
    #[test]
    fn test_prompt_argument_attr() {
        let attr = PromptArgumentAttr {
            name: String::from("test"),
            description: Some(String::from("A test argument")),
            required: Some(true),
        };
        assert_eq!(attr.name, "test");
        assert_eq!(attr.required, Some(true));
    }
}