use proc_macro2::TokenStream;
use quote::quote;
use syn::{Attribute, Error, FnArg, ImplItem, ImplItemFn, ItemImpl, Result, parse2};
use crate::attrs::{PromptAttrs, ResourceAttrs, ServerAttrs, ToolAttrs};
use crate::codegen::{ToolMethod, ToolParam, extract_param, is_result_type};
/// Metadata collected from a method annotated with `#[resource(...)]`.
#[derive(Debug)]
struct ResourceMethod {
    /// Identifier of the annotated method; it is called with the requested URI.
    name: syn::Ident,
    /// URI from the attribute. A pattern containing `{` is treated as a URI
    /// template; any other value is treated as a concrete resource URI.
    uri_pattern: String,
    /// Advertised resource name; defaults to the method identifier.
    resource_name: String,
    /// Description text; empty when the attribute omits it (emitted as `None`).
    description: String,
    /// MIME type; defaults to `text/plain` when the attribute omits it.
    mime_type: String,
    /// Whether the method is `async` (its call is followed by `.await`).
    is_async: bool,
    /// Whether the method's return type is a `Result` (errors are propagated with `?`).
    returns_result: bool,
}
/// Metadata collected from a method annotated with `#[prompt(...)]`.
#[derive(Debug)]
struct PromptMethod {
    /// Identifier of the annotated method; it is called from `get_prompt`.
    name: syn::Ident,
    /// Advertised prompt name; defaults to the method identifier.
    prompt_name: String,
    /// Description from the attribute; an empty string is emitted as `None`.
    description: String,
    /// Non-receiver arguments with a plain identifier pattern, in declaration order.
    params: Vec<PromptParam>,
    /// Whether the method is `async` (its call is followed by `.await`).
    is_async: bool,
    /// Whether the method's return type is a `Result`; otherwise the value is wrapped in `Ok`.
    returns_result: bool,
}
/// One argument of a prompt method.
#[derive(Debug)]
struct PromptParam {
    /// Argument identifier; also used as the JSON key when reading prompt arguments.
    name: syn::Ident,
    /// Declared argument type; values are deserialized into it with `serde_json`.
    ty: syn::Type,
    /// `///` doc lines on the argument, trimmed and joined with spaces; `None` if empty.
    doc: Option<String>,
    /// True when the type's last path segment is `Option`; such arguments are
    /// advertised as not required.
    is_optional: bool,
}
/// Expands the server attribute macro applied to an `impl` block.
///
/// Methods annotated with `#[tool(...)]`, `#[resource(...)]` or `#[prompt(...)]`
/// are collected (and those helper attributes stripped from the block). The
/// original impl block is then re-emitted followed by:
///
/// * a `::mcpkit::ServerHandler` impl (always),
/// * `ToolHandler` / `ResourceHandler` / `PromptHandler` impls, each only when at
///   least one method of that kind exists,
/// * an inherent `into_server` convenience constructor.
///
/// When `debug_expand` is set in the server attributes, every generated item is
/// printed to stderr during expansion.
///
/// # Errors
///
/// Fails if the server attributes cannot be parsed, if `item` is not an `impl`
/// block, or if any tool/resource/prompt attribute is malformed.
pub fn expand_mcp_server(attr: TokenStream, item: TokenStream) -> Result<TokenStream> {
    let attrs =
        ServerAttrs::parse(attr).map_err(|e| Error::new(proc_macro2::Span::call_site(), e))?;
    let mut impl_block: ItemImpl = parse2(item)?;
    let tool_methods = extract_tool_methods(&mut impl_block)?;
    let resource_methods = extract_resource_methods(&mut impl_block)?;
    let prompt_methods = extract_prompt_methods(&mut impl_block)?;

    let has_tools = !tool_methods.is_empty();
    let has_resources = !resource_methods.is_empty();
    let has_prompts = !prompt_methods.is_empty();

    let self_ty = &impl_block.self_ty;
    let server_handler_impl =
        generate_server_handler(&attrs, self_ty, has_tools, has_resources, has_prompts);
    let tool_handler_impl = if has_tools {
        generate_tool_handler(&tool_methods, self_ty)
    } else {
        quote!()
    };
    let resource_handler_impl = if has_resources {
        generate_resource_handler(&resource_methods, self_ty)
    } else {
        quote!()
    };
    let prompt_handler_impl = if has_prompts {
        generate_prompt_handler(&prompt_methods, self_ty)
    } else {
        quote!()
    };
    let convenience_methods =
        generate_convenience_methods(self_ty, has_tools, has_resources, has_prompts);

    if attrs.debug_expand {
        eprintln!("=== Generated code for {} ===", quote!(#self_ty));
        eprintln!("{server_handler_impl}");
        eprintln!("{tool_handler_impl}");
        eprintln!("{resource_handler_impl}");
        eprintln!("{prompt_handler_impl}");
        // Included so the dump covers everything this macro emits.
        eprintln!("{convenience_methods}");
        eprintln!("=== End generated code ===");
    }

    Ok(quote! {
        #impl_block
        #server_handler_impl
        #tool_handler_impl
        #resource_handler_impl
        #prompt_handler_impl
        #convenience_methods
    })
}
/// Collects every `#[tool(...)]` method of `impl_block`, in declaration order.
///
/// The matched `#[tool]` attribute is removed from each method so the block can be
/// re-emitted as ordinary Rust.
///
/// # Errors
///
/// Propagates the first attribute-parsing or extraction error encountered.
fn extract_tool_methods(impl_block: &mut ItemImpl) -> Result<Vec<ToolMethod>> {
    let mut collected = Vec::new();
    let methods = impl_block.items.iter_mut().filter_map(|entry| match entry {
        ImplItem::Fn(method) => Some(method),
        _ => None,
    });
    for method in methods {
        let Some((position, parsed)) = find_tool_attr(&method.attrs)? else {
            continue;
        };
        method.attrs.remove(position);
        collected.push(extract_tool_info(method, parsed)?);
    }
    Ok(collected)
}
fn find_tool_attr(attrs: &[Attribute]) -> Result<Option<(usize, ToolAttrs)>> {
for (idx, attr) in attrs.iter().enumerate() {
if attr.path().is_ident("tool") {
let tokens = match &attr.meta {
syn::Meta::List(list) => list.tokens.clone(),
syn::Meta::Path(_) => {
return Err(Error::new_spanned(
attr,
"missing tool attributes\n\
help: add description, e.g., #[tool(description = \"...\")]",
));
}
syn::Meta::NameValue(_) => {
return Err(Error::new_spanned(
attr,
"invalid #[tool] syntax\n\
help: use #[tool(description = \"...\")]",
));
}
};
let tool_attrs = ToolAttrs::parse(tokens)
.map_err(|e| Error::new(attr.bracket_token.span.join(), e))?;
return Ok(Some((idx, tool_attrs)));
}
}
Ok(None)
}
#[allow(clippy::unnecessary_wraps)] fn extract_tool_info(method: &ImplItemFn, attrs: ToolAttrs) -> Result<ToolMethod> {
let name = method.sig.ident.clone();
let tool_name = attrs.name.unwrap_or_else(|| name.to_string());
let params: Vec<ToolParam> = method
.sig
.inputs
.iter()
.filter_map(|arg| match arg {
FnArg::Receiver(_) => None,
FnArg::Typed(_) => extract_param(arg),
})
.collect();
let is_async = method.sig.asyncness.is_some();
let returns_result = is_result_type(&method.sig.output);
Ok(ToolMethod {
name,
tool_name,
description: attrs.description,
destructive: attrs.destructive,
idempotent: attrs.idempotent,
read_only: attrs.read_only,
params,
is_async,
returns_result,
})
}
/// Collects every `#[resource(...)]` method of `impl_block`, in declaration order.
///
/// The matched `#[resource]` attribute is removed from each method so the block can
/// be re-emitted as ordinary Rust.
///
/// # Errors
///
/// Propagates the first attribute-parsing or extraction error encountered.
fn extract_resource_methods(impl_block: &mut ItemImpl) -> Result<Vec<ResourceMethod>> {
    let mut collected = Vec::new();
    let methods = impl_block.items.iter_mut().filter_map(|entry| match entry {
        ImplItem::Fn(method) => Some(method),
        _ => None,
    });
    for method in methods {
        let Some((position, parsed)) = find_resource_attr(&method.attrs)? else {
            continue;
        };
        method.attrs.remove(position);
        collected.push(extract_resource_info(method, parsed)?);
    }
    Ok(collected)
}
/// Locates and parses the `#[resource(...)]` attribute among a method's attributes.
///
/// Returns the attribute's index (so the caller can strip it) together with the
/// parsed [`ResourceAttrs`], or `None` when the method has no `#[resource]` attribute.
///
/// # Errors
///
/// Fails when the attribute has no parenthesised argument list, when
/// `ResourceAttrs::parse` rejects its arguments, or when the method carries more
/// than one `#[resource]` attribute. Only one attribute is ever stripped, so a
/// duplicate would otherwise survive and fail later with a much less helpful diagnostic.
fn find_resource_attr(attrs: &[Attribute]) -> Result<Option<(usize, ResourceAttrs)>> {
    let mut candidates = attrs
        .iter()
        .enumerate()
        .filter(|(_, attr)| attr.path().is_ident("resource"));
    let Some((idx, attr)) = candidates.next() else {
        return Ok(None);
    };
    let tokens = match &attr.meta {
        syn::Meta::List(list) => list.tokens.clone(),
        syn::Meta::Path(_) => {
            return Err(Error::new_spanned(
                attr,
                "missing resource attributes\n\
                 help: add uri_pattern, e.g., #[resource(uri_pattern = \"myserver://data/{id}\")]",
            ));
        }
        syn::Meta::NameValue(_) => {
            return Err(Error::new_spanned(
                attr,
                "invalid #[resource] syntax\n\
                 help: use #[resource(uri_pattern = \"...\")]",
            ));
        }
    };
    let resource_attrs = ResourceAttrs::parse(tokens)
        .map_err(|e| Error::new(attr.bracket_token.span.join(), e))?;
    if let Some((_, duplicate)) = candidates.next() {
        return Err(Error::new_spanned(
            duplicate,
            "duplicate #[resource] attribute\n\
             help: a method may carry only one #[resource(...)] attribute",
        ));
    }
    Ok(Some((idx, resource_attrs)))
}
/// Builds the [`ResourceMethod`] description for a method that carried `#[resource(...)]`.
///
/// Defaults:
/// * the resource name falls back to the method identifier,
/// * a missing description becomes an empty string,
/// * a missing MIME type becomes `text/plain`.
// Always `Ok` today; the `Result` keeps the signature uniform with the other extractors.
#[allow(clippy::unnecessary_wraps)]
fn extract_resource_info(method: &ImplItemFn, attrs: ResourceAttrs) -> Result<ResourceMethod> {
    let signature = &method.sig;
    Ok(ResourceMethod {
        name: signature.ident.clone(),
        resource_name: attrs.name.unwrap_or_else(|| signature.ident.to_string()),
        uri_pattern: attrs.uri_pattern,
        description: attrs.description.unwrap_or_default(),
        mime_type: attrs.mime_type.unwrap_or_else(|| String::from("text/plain")),
        is_async: signature.asyncness.is_some(),
        returns_result: is_result_type(&signature.output),
    })
}
/// Collects every `#[prompt(...)]` method of `impl_block`, in declaration order.
///
/// The matched `#[prompt]` attribute is removed from each method so the block can be
/// re-emitted as ordinary Rust.
///
/// # Errors
///
/// Propagates the first attribute-parsing or extraction error encountered.
fn extract_prompt_methods(impl_block: &mut ItemImpl) -> Result<Vec<PromptMethod>> {
    let mut collected = Vec::new();
    let methods = impl_block.items.iter_mut().filter_map(|entry| match entry {
        ImplItem::Fn(method) => Some(method),
        _ => None,
    });
    for method in methods {
        let Some((position, parsed)) = find_prompt_attr(&method.attrs)? else {
            continue;
        };
        method.attrs.remove(position);
        collected.push(extract_prompt_info(method, parsed)?);
    }
    Ok(collected)
}
/// Locates and parses the `#[prompt(...)]` attribute among a method's attributes.
///
/// Returns the attribute's index (so the caller can strip it) together with the
/// parsed [`PromptAttrs`], or `None` when the method has no `#[prompt]` attribute.
///
/// # Errors
///
/// Fails when the attribute has no parenthesised argument list, when
/// `PromptAttrs::parse` rejects its arguments, or when the method carries more than
/// one `#[prompt]` attribute. Only one attribute is ever stripped, so a duplicate
/// would otherwise survive and fail later with a much less helpful diagnostic.
fn find_prompt_attr(attrs: &[Attribute]) -> Result<Option<(usize, PromptAttrs)>> {
    let mut candidates = attrs
        .iter()
        .enumerate()
        .filter(|(_, attr)| attr.path().is_ident("prompt"));
    let Some((idx, attr)) = candidates.next() else {
        return Ok(None);
    };
    let tokens = match &attr.meta {
        syn::Meta::List(list) => list.tokens.clone(),
        syn::Meta::Path(_) => {
            return Err(Error::new_spanned(
                attr,
                "missing prompt attributes\n\
                 help: add description, e.g., #[prompt(description = \"...\")]",
            ));
        }
        syn::Meta::NameValue(_) => {
            return Err(Error::new_spanned(
                attr,
                "invalid #[prompt] syntax\n\
                 help: use #[prompt(description = \"...\")]",
            ));
        }
    };
    let prompt_attrs = PromptAttrs::parse(tokens)
        .map_err(|e| Error::new(attr.bracket_token.span.join(), e))?;
    if let Some((_, duplicate)) = candidates.next() {
        return Err(Error::new_spanned(
            duplicate,
            "duplicate #[prompt] attribute\n\
             help: a method may carry only one #[prompt(...)] attribute",
        ));
    }
    Ok(Some((idx, prompt_attrs)))
}
#[allow(clippy::unnecessary_wraps)] fn extract_prompt_info(method: &ImplItemFn, attrs: PromptAttrs) -> Result<PromptMethod> {
let name = method.sig.ident.clone();
let prompt_name = attrs.name.unwrap_or_else(|| name.to_string());
let params: Vec<PromptParam> = method
.sig
.inputs
.iter()
.filter_map(extract_prompt_param)
.collect();
let is_async = method.sig.asyncness.is_some();
let returns_result = is_result_type(&method.sig.output);
Ok(PromptMethod {
name,
prompt_name,
description: attrs.description,
params,
is_async,
returns_result,
})
}
/// Converts one method argument into a [`PromptParam`].
///
/// Returns `None` for the `self` receiver and for arguments whose pattern is not a
/// plain identifier (e.g. destructuring patterns). String-literal `doc` attributes
/// on the argument are trimmed and joined with single spaces; the doc is `None`
/// when that joined text is empty.
fn extract_prompt_param(arg: &FnArg) -> Option<PromptParam> {
    let FnArg::Typed(typed) = arg else {
        return None;
    };
    let syn::Pat::Ident(binding) = typed.pat.as_ref() else {
        return None;
    };
    let doc_text = typed
        .attrs
        .iter()
        .filter(|attribute| attribute.path().is_ident("doc"))
        .filter_map(|attribute| match &attribute.meta {
            syn::Meta::NameValue(syn::MetaNameValue {
                value:
                    syn::Expr::Lit(syn::ExprLit {
                        lit: syn::Lit::Str(text),
                        ..
                    }),
                ..
            }) => Some(text.value().trim().to_string()),
            _ => None,
        })
        .collect::<Vec<_>>()
        .join(" ");
    Some(PromptParam {
        name: binding.ident.clone(),
        is_optional: is_option_type(&typed.ty),
        ty: typed.ty.as_ref().clone(),
        doc: (!doc_text.is_empty()).then_some(doc_text),
    })
}
/// Reports whether `ty` is a path type whose last segment is `Option`.
///
/// Only the final segment is inspected, so both `Option<T>` and
/// `std::option::Option<T>` count; references, tuples and other non-path
/// types never do.
fn is_option_type(ty: &syn::Type) -> bool {
    match ty {
        syn::Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .is_some_and(|segment| segment.ident == "Option"),
        _ => false,
    }
}
/// Returns `T` for a type spelled `...::Option<T>`; any other type is returned unchanged.
///
/// Only the first generic argument of the final `Option` path segment is considered,
/// and only when it is a type argument.
fn extract_option_inner_type(ty: &syn::Type) -> syn::Type {
    let inner = match ty {
        syn::Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .filter(|segment| segment.ident == "Option")
            .and_then(|segment| match &segment.arguments {
                syn::PathArguments::AngleBracketed(generics) => match generics.args.first() {
                    Some(syn::GenericArgument::Type(wrapped)) => Some(wrapped),
                    _ => None,
                },
                _ => None,
            }),
        _ => None,
    };
    inner.unwrap_or(ty).clone()
}
/// Emits the `::mcpkit::ServerHandler` impl for `self_ty`.
///
/// * `server_info` reports the `name`/`version` from the server attributes.
/// * `instructions` returns the optional instructions string.
/// * `capabilities` starts from `ServerCapabilities::new()` and chains
///   `.with_tools()`, `.with_resources()` and `.with_prompts()` for each handler
///   kind that is present.
fn generate_server_handler(
    attrs: &ServerAttrs,
    self_ty: &syn::Type,
    has_tools: bool,
    has_resources: bool,
    has_prompts: bool,
) -> TokenStream {
    let server_name = &attrs.name;
    let server_version = &attrs.version;
    let instructions = match attrs.instructions.as_ref() {
        Some(text) => quote!(Some(#text.to_string())),
        None => quote!(None),
    };
    let capability_calls: Vec<&TokenStream> = [
        (has_tools, quote!(.with_tools())),
        (has_resources, quote!(.with_resources())),
        (has_prompts, quote!(.with_prompts())),
    ]
    .iter()
    .filter(|(enabled, _)| *enabled)
    .map(|(_, call)| call)
    .collect();
    quote! {
        impl ::mcpkit::ServerHandler for #self_ty {
            fn server_info(&self) -> ::mcpkit::capability::ServerInfo {
                ::mcpkit::capability::ServerInfo::new(#server_name, #server_version)
            }
            fn capabilities(&self) -> ::mcpkit::capability::ServerCapabilities {
                ::mcpkit::capability::ServerCapabilities::new() #(#capability_calls)*
            }
            fn instructions(&self) -> Option<String> {
                #instructions
            }
        }
    }
}
fn generate_tool_handler(tools: &[ToolMethod], self_ty: &syn::Type) -> TokenStream {
let tool_defs: Vec<_> = tools
.iter()
.map(|tool| {
let name = &tool.tool_name;
let description = &tool.description;
let input_schema = tool.generate_input_schema();
let destructive = tool.destructive;
let idempotent = tool.idempotent;
let read_only = tool.read_only;
quote! {
::mcpkit::types::Tool {
name: #name.to_string(),
description: Some(#description.to_string()),
input_schema: #input_schema,
annotations: Some(::mcpkit::types::ToolAnnotations {
title: None,
read_only_hint: Some(#read_only),
destructive_hint: Some(#destructive),
idempotent_hint: Some(#idempotent),
open_world_hint: None,
}),
}
}
})
.collect();
let dispatch_arms: Vec<_> = tools
.iter()
.map(super::codegen::ToolMethod::generate_call_dispatch)
.collect();
let tool_names: Vec<_> = tools.iter().map(|t| t.tool_name.as_str()).collect();
let _available_tools = tool_names.join(", ");
quote! {
impl ::mcpkit::ToolHandler for #self_ty {
fn list_tools(
&self,
_ctx: &::mcpkit::Context,
) -> impl std::future::Future<Output = Result<Vec<::mcpkit::types::Tool>, ::mcpkit::error::McpError>> + Send {
async move {
Ok(vec![
#(#tool_defs),*
])
}
}
fn call_tool(
&self,
name: &str,
args: ::serde_json::Value,
_ctx: &::mcpkit::Context,
) -> impl std::future::Future<Output = Result<::mcpkit::types::ToolOutput, ::mcpkit::error::McpError>> + Send {
let args_clone = args.clone();
async move {
let args = match args_clone.as_object() {
Some(obj) => obj.clone(),
None => ::serde_json::Map::new(),
};
match name {
#(#dispatch_arms)*
_ => Err(::mcpkit::error::McpError::method_not_found_with_suggestions(
name,
vec![#(#tool_names.to_string()),*],
))
}
}
}
}
}
}
/// Emits the `::mcpkit::ResourceHandler` impl for `self_ty`.
///
/// Resources whose `uri_pattern` contains `{` are advertised by
/// `list_resource_templates`; all others by `list_resources`. An empty
/// description is emitted as `None`, decided at expansion time.
///
/// `read_resource` dispatches a requested URI in this order:
///
/// 1. Exact matches against the non-template URIs.
/// 2. Templates, from most to least literal text. A template matches when the URI
///    starts with the text before its first `{`, ends with the text after its last
///    `}`, and is long enough to contain both.
///
/// This ordering, plus the suffix check, stops a broad template from capturing a
/// concrete resource's URI, a more specific template's URI, or a URI with a
/// different suffix. Unmatched URIs yield `McpError::resource_not_found`.
fn generate_resource_handler(resources: &[ResourceMethod], self_ty: &syn::Type) -> TokenStream {
    let (mut templates, statics): (Vec<&ResourceMethod>, Vec<&ResourceMethod>) = resources
        .iter()
        .partition(|resource| resource.uri_pattern.contains('{'));

    let resource_defs: Vec<TokenStream> = statics
        .iter()
        .map(|resource| {
            let uri = &resource.uri_pattern;
            let name = &resource.resource_name;
            let mime_type = &resource.mime_type;
            let description = if resource.description.is_empty() {
                quote!(None)
            } else {
                let text = &resource.description;
                quote!(Some(#text.to_string()))
            };
            quote! {
                ::mcpkit::types::Resource {
                    uri: #uri.to_string(),
                    name: #name.to_string(),
                    description: #description,
                    mime_type: Some(#mime_type.to_string()),
                    size: None,
                    annotations: None,
                }
            }
        })
        .collect();

    let template_defs: Vec<TokenStream> = templates
        .iter()
        .map(|resource| {
            let uri = &resource.uri_pattern;
            let name = &resource.resource_name;
            let mime_type = &resource.mime_type;
            let description = if resource.description.is_empty() {
                quote!(None)
            } else {
                let text = &resource.description;
                quote!(Some(#text.to_string()))
            };
            quote! {
                ::mcpkit::types::ResourceTemplate {
                    uri_template: #uri.to_string(),
                    name: #name.to_string(),
                    description: #description,
                    mime_type: Some(#mime_type.to_string()),
                    annotations: None,
                }
            }
        })
        .collect();

    // Most specific templates first; the sort is stable, so ties keep declaration order.
    templates.sort_by_key(|resource| {
        let (prefix, suffix) = template_literal_bounds(&resource.uri_pattern);
        std::cmp::Reverse(prefix.len() + suffix.len())
    });

    let dispatch_arms: Vec<TokenStream> = statics
        .iter()
        .map(|resource| {
            let exact_uri = &resource.uri_pattern;
            read_resource_arm(resource, quote!(uri == #exact_uri))
        })
        .chain(templates.iter().map(|resource| {
            let (prefix, suffix) = template_literal_bounds(&resource.uri_pattern);
            let min_len = prefix.len() + suffix.len();
            read_resource_arm(
                resource,
                quote!(uri.len() >= #min_len && uri.starts_with(#prefix) && uri.ends_with(#suffix)),
            )
        }))
        .collect();

    quote! {
        impl ::mcpkit::ResourceHandler for #self_ty {
            fn list_resources(
                &self,
                _ctx: &::mcpkit::Context,
            ) -> impl std::future::Future<Output = Result<Vec<::mcpkit::types::Resource>, ::mcpkit::error::McpError>> + Send {
                async move {
                    Ok(vec![
                        #(#resource_defs),*
                    ])
                }
            }
            fn list_resource_templates(
                &self,
                _ctx: &::mcpkit::Context,
            ) -> impl std::future::Future<Output = Result<Vec<::mcpkit::types::ResourceTemplate>, ::mcpkit::error::McpError>> + Send {
                async move {
                    Ok(vec![
                        #(#template_defs),*
                    ])
                }
            }
            fn read_resource(
                &self,
                uri: &str,
                _ctx: &::mcpkit::Context,
            ) -> impl std::future::Future<Output = Result<Vec<::mcpkit::types::ResourceContents>, ::mcpkit::error::McpError>> + Send {
                let uri_owned = uri.to_string();
                async move {
                    let uri: &str = &uri_owned;
                    #(#dispatch_arms)*
                    Err(::mcpkit::error::McpError::resource_not_found(uri))
                }
            }
        }
    }
}

/// Splits a URI template into its literal text before the first `{` and after the last `}`.
///
/// The suffix is empty when there is no `}` after the first `{`.
fn template_literal_bounds(pattern: &str) -> (&str, &str) {
    let open = pattern.find('{').unwrap_or(pattern.len());
    let suffix = match pattern.rfind('}') {
        // `}` is a single byte, so `close + 1` is always a char boundary.
        Some(close) if close > open => &pattern[close + 1..],
        _ => "",
    };
    (&pattern[..open], suffix)
}

/// Emits one `read_resource` guard.
///
/// When `condition` holds, the resource method is called with the requested URI
/// and its contents are returned as a one-element list. Methods returning `Result`
/// have their error propagated with `?`.
fn read_resource_arm(resource: &ResourceMethod, condition: TokenStream) -> TokenStream {
    let method_name = &resource.name;
    let call = if resource.is_async {
        quote!(self.#method_name(uri).await)
    } else {
        quote!(self.#method_name(uri))
    };
    let contents = if resource.returns_result {
        quote!(#call?)
    } else {
        call
    };
    quote! {
        if #condition {
            let result = #contents;
            return Ok(vec![result]);
        }
    }
}
fn generate_prompt_handler(prompts: &[PromptMethod], self_ty: &syn::Type) -> TokenStream {
let prompt_defs: Vec<_> = prompts
.iter()
.map(|prompt| {
let name = &prompt.prompt_name;
let description = &prompt.description;
let arguments: Vec<_> = prompt
.params
.iter()
.map(|param| {
let param_name = param.name.to_string();
let param_desc = param.doc.as_deref().unwrap_or("");
let required = !param.is_optional;
quote! {
::mcpkit::types::PromptArgument {
name: #param_name.to_string(),
description: if #param_desc.is_empty() { None } else { Some(#param_desc.to_string()) },
required: Some(#required),
}
}
})
.collect();
let arguments_expr = if arguments.is_empty() {
quote!(None)
} else {
quote!(Some(vec![#(#arguments),*]))
};
quote! {
::mcpkit::types::Prompt {
name: #name.to_string(),
description: if #description.is_empty() { None } else { Some(#description.to_string()) },
arguments: #arguments_expr,
}
}
})
.collect();
let dispatch_arms: Vec<_> = prompts
.iter()
.map(|prompt| {
let method_name = &prompt.name;
let prompt_name = &prompt.prompt_name;
let param_extractions: Vec<_> = prompt
.params
.iter()
.map(|param| {
let name = ¶m.name;
let name_str = name.to_string();
let ty = ¶m.ty;
if param.is_optional {
let inner_ty = extract_option_inner_type(ty);
quote! {
let #name: #ty = match arguments
.as_ref()
.and_then(|args| args.get(#name_str))
{
Some(v) => ::serde_json::from_value::<#inner_ty>(v.clone()).ok(),
None => None,
};
}
} else {
quote! {
let #name: #ty = {
let value = match arguments
.as_ref()
.and_then(|args| args.get(#name_str))
{
Some(v) => v.clone(),
None => return Err(::mcpkit::error::McpError::invalid_params(
#prompt_name,
format!("missing required argument: {}", #name_str),
)),
};
match ::serde_json::from_value::<#ty>(value) {
Ok(v) => v,
Err(e) => return Err(::mcpkit::error::McpError::invalid_params(
#prompt_name,
format!("invalid argument '{}': {}", #name_str, e),
)),
}
};
}
}
})
.collect();
let param_names: Vec<_> = prompt.params.iter().map(|p| &p.name).collect();
let call = if prompt.is_async {
quote!(self.#method_name(#(#param_names),*).await)
} else {
quote!(self.#method_name(#(#param_names),*))
};
let call_with_conversion = if prompt.returns_result {
quote!(#call)
} else {
quote!(Ok(#call))
};
quote! {
#prompt_name => {
#(#param_extractions)*
#call_with_conversion
}
}
})
.collect();
let prompt_names: Vec<_> = prompts.iter().map(|p| p.prompt_name.as_str()).collect();
quote! {
impl ::mcpkit::PromptHandler for #self_ty {
fn list_prompts(
&self,
_ctx: &::mcpkit::Context,
) -> impl std::future::Future<Output = Result<Vec<::mcpkit::types::Prompt>, ::mcpkit::error::McpError>> + Send {
async move {
Ok(vec![
#(#prompt_defs),*
])
}
}
fn get_prompt(
&self,
name: &str,
arguments: Option<::serde_json::Map<String, ::serde_json::Value>>,
_ctx: &::mcpkit::Context,
) -> impl std::future::Future<Output = Result<::mcpkit::types::GetPromptResult, ::mcpkit::error::McpError>> + Send {
let name = name.to_string();
async move {
match name.as_str() {
#(#dispatch_arms)*
_ => Err(::mcpkit::error::McpError::method_not_found_with_suggestions(
&name,
vec![#(#prompt_names.to_string()),*],
))
}
}
}
}
}
}
/// Emits an inherent `into_server` constructor for `self_ty`.
///
/// The return type is a `::mcpkit::server::Server`. Each handler slot is
/// `Registered<Arc<Self>>` when that handler kind exists and `NotRegistered`
/// otherwise; the tasks slot is always `NotRegistered`.
///
/// The body:
/// * wraps `self` in an `Arc`,
/// * hands a clone to `ServerBuilder::new`,
/// * registers a further clone for every handler kind present,
/// * calls `.build()`.
///
/// With no handlers, the `Arc` itself is passed to the builder.
fn generate_convenience_methods(
    self_ty: &syn::Type,
    has_tools: bool,
    has_resources: bool,
    has_prompts: bool,
) -> TokenStream {
    let arc_self = quote!(::std::sync::Arc<Self>);
    let slot_ty = |registered: bool| {
        if registered {
            quote!(::mcpkit::server::Registered<#arc_self>)
        } else {
            quote!(::mcpkit::server::NotRegistered)
        }
    };
    let tools_ty = slot_ty(has_tools);
    let resources_ty = slot_ty(has_resources);
    let prompts_ty = slot_ty(has_prompts);
    let tasks_ty = slot_ty(false);

    let registrations: Vec<TokenStream> = [
        (has_tools, quote!(with_tools)),
        (has_resources, quote!(with_resources)),
        (has_prompts, quote!(with_prompts)),
    ]
    .iter()
    .filter(|(enabled, _)| *enabled)
    .map(|(_, builder_method)| quote!(.#builder_method(::std::sync::Arc::clone(&handler))))
    .collect();

    let builder_body = if registrations.is_empty() {
        quote! {
            let handler = ::std::sync::Arc::new(self);
            ::mcpkit::ServerBuilder::new(handler).build()
        }
    } else {
        quote! {
            let handler = ::std::sync::Arc::new(self);
            ::mcpkit::ServerBuilder::new(::std::sync::Arc::clone(&handler))
                #(#registrations)*
                .build()
        }
    };

    quote! {
        impl #self_ty {
            #[must_use]
            pub fn into_server(self) -> ::mcpkit::server::Server<#arc_self, #tools_ty, #resources_ty, #prompts_ty, #tasks_ty>
            where
                Self: Send + Sync + 'static,
            {
                #builder_body
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;

    #[test]
    fn test_parse_server_attrs() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let tokens = quote!(name = "test", version = "1.0.0");
        let attrs = ServerAttrs::parse(tokens)?;
        assert_eq!(attrs.name, "test");
        assert_eq!(attrs.version, "1.0.0");
        Ok(())
    }

    /// `is_option_type` looks only at the last path segment.
    #[test]
    fn test_is_option_type() {
        let plain_option: syn::Type = syn::parse_quote!(Option<String>);
        let qualified_option: syn::Type = syn::parse_quote!(std::option::Option<u32>);
        let not_option: syn::Type = syn::parse_quote!(String);
        let reference: syn::Type = syn::parse_quote!(&Option<u8>);
        assert!(is_option_type(&plain_option));
        assert!(is_option_type(&qualified_option));
        assert!(!is_option_type(&not_option));
        assert!(!is_option_type(&reference));
    }

    /// `extract_option_inner_type` unwraps `Option<T>` and passes other types through.
    #[test]
    fn test_extract_option_inner_type() {
        let wrapped: syn::Type = syn::parse_quote!(Option<Vec<u8>>);
        let inner = extract_option_inner_type(&wrapped);
        assert_eq!(quote!(#inner).to_string(), quote!(Vec<u8>).to_string());

        let plain: syn::Type = syn::parse_quote!(u64);
        let unchanged = extract_option_inner_type(&plain);
        assert_eq!(quote!(#unchanged).to_string(), "u64");
    }
}