use proc_macro::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Expr, Field, Fields, Lit, Meta, Type};
#[proc_macro_derive(McpTool, attributes(mcp))]
pub fn derive_mcp_tool(input: TokenStream) -> TokenStream {
let input: DeriveInput = syn::parse2(input.into()).expect("Failed to parse input");
let struct_name = &input.ident;
let name_token = get_name(&input.attrs, struct_name);
let description_token = get_description(&input.attrs);
let schema = match &input.data {
Data::Struct(data_struct) => match &data_struct.fields {
Fields::Named(fields) => generate_schema_fields(fields.named.iter().collect()),
Fields::Unnamed(_) => panic!("McpTool does not support unnamed fields"),
Fields::Unit => quote! { {} },
},
Data::Enum(_) => panic!("McpTool does not support enums"),
Data::Union(_) => panic!("McpTool does not support unions"),
};
let expanded = quote! {
#[async_trait::async_trait]
impl Tool for #struct_name {
fn name(&self) -> &str {
#name_token
}
fn description(&self) -> Option<&str> {
#description_token
}
fn input_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": #schema,
"required": []
})
}
fn execution(&self) -> Option<mcp_host::protocol::types::ToolExecution> {
None
}
fn is_visible(&self, _ctx: &mcp_host::server::visibility::VisibilityContext) -> bool {
true
}
async fn execute(&self, ctx: ExecutionContext<'_>) -> Result<Vec<Box<dyn Content>>, ToolError> {
self.run(ctx).await
}
}
};
TokenStream::from(expanded)
}
#[proc_macro_derive(McpResource, attributes(mcp))]
pub fn derive_mcp_resource(input: TokenStream) -> TokenStream {
let input: DeriveInput = syn::parse2(input.into()).expect("Failed to parse input");
let struct_name = &input.ident;
let name_token = get_name(&input.attrs, struct_name);
let description_token = get_description(&input.attrs);
let (uri, mime_type) = match &input.data {
Data::Struct(data_struct) => match &data_struct.fields {
Fields::Named(fields) => extract_resource_fields(fields.named.iter().collect()),
Fields::Unnamed(_) => panic!("McpResource does not support unnamed fields"),
Fields::Unit => (quote! { "default:///" }, quote! { None }),
},
Data::Enum(_) => panic!("McpResource does not support enums"),
Data::Union(_) => panic!("McpResource does not support unions"),
};
let expanded = quote! {
#[async_trait::async_trait]
impl Resource for #struct_name {
fn uri(&self) -> &str {
#uri
}
fn name(&self) -> &str {
#name_token
}
fn description(&self) -> Option<&str> {
#description_token
}
fn mime_type(&self) -> Option<&str> {
#mime_type
}
fn is_visible(&self, _ctx: &mcp_host::server::visibility::VisibilityContext) -> bool {
true
}
async fn read(&self, ctx: ExecutionContext<'_>) -> Result<Vec<mcp_host::content::resource::ResourceContent>, mcp_host::registry::resources::ResourceError> {
self.read_resource(ctx).await
}
}
};
TokenStream::from(expanded)
}
#[proc_macro_derive(McpPrompt, attributes(mcp))]
pub fn derive_mcp_prompt(input: TokenStream) -> TokenStream {
let input: DeriveInput = syn::parse2(input.into()).expect("Failed to parse input");
let struct_name = &input.ident;
let name_token = get_name(&input.attrs, struct_name);
let description_token = get_description(&input.attrs);
let arguments = match &input.data {
Data::Struct(data_struct) => match &data_struct.fields {
Fields::Named(fields) => generate_prompt_arguments(fields.named.iter().collect()),
Fields::Unnamed(_) => panic!("McpPrompt does not support unnamed fields"),
Fields::Unit => quote! { None },
},
Data::Enum(_) => panic!("McpPrompt does not support enums"),
Data::Union(_) => panic!("McpPrompt does not support unions"),
};
let expanded = quote! {
#[async_trait::async_trait]
impl Prompt for #struct_name {
fn name(&self) -> &str {
#name_token
}
fn description(&self) -> Option<&str> {
#description_token
}
fn arguments(&self) -> Option<Vec<mcp_host::registry::prompts::PromptArgument>> {
#arguments
}
fn is_visible(&self, _ctx: &mcp_host::server::visibility::VisibilityContext) -> bool {
true
}
async fn get(&self, ctx: ExecutionContext<'_>) -> Result<mcp_host::registry::prompts::GetPromptResult, mcp_host::registry::prompts::PromptError> {
self.get_prompt(ctx).await
}
}
};
TokenStream::from(expanded)
}
fn has_skip_attr(field: &Field) -> bool {
for attr in &field.attrs {
if attr.path().is_ident("mcp") {
if let Ok(ident) = attr.parse_args::<syn::Ident>() {
if ident == "skip" {
return true;
}
}
}
}
false
}
/// Produces the string-literal body of the generated `name()` method.
///
/// Scans every `#[mcp(...)]` attribute, in order, as a comma-separated list of metas. The
/// first `name = "..."` entry with a string literal wins. Entries with a non-string value
/// are passed over. When nothing matches, the identifier `default` (the derived type's
/// name) is used instead.
fn get_name(attrs: &[syn::Attribute], default: &syn::Ident) -> proc_macro2::TokenStream {
    let explicit_name = attrs
        .iter()
        .filter(|attr| attr.path().is_ident("mcp"))
        .filter_map(|attr| {
            attr.parse_args_with(
                syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
            )
            .ok()
        })
        .flatten()
        .find_map(|meta| match meta {
            Meta::NameValue(pair) if pair.path.is_ident("name") => match &pair.value {
                Expr::Lit(syn::ExprLit {
                    lit: Lit::Str(text),
                    ..
                }) => Some(text.value()),
                _ => None,
            },
            _ => None,
        });
    let resolved = explicit_name.unwrap_or_else(|| default.to_string());
    quote! { #resolved }
}
/// Produces the `Option<&str>` body of the generated `description()` method.
///
/// Scans every `#[mcp(...)]` attribute, in order, as a comma-separated list of metas. The
/// first `description = "..."` entry with a string literal yields `Some("...")`. Entries
/// with a non-string value are passed over. If none matches, the body is `None`.
fn get_description(attrs: &[syn::Attribute]) -> proc_macro2::TokenStream {
    let found = attrs
        .iter()
        .filter(|attr| attr.path().is_ident("mcp"))
        .filter_map(|attr| {
            attr.parse_args_with(
                syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated,
            )
            .ok()
        })
        .flatten()
        .find_map(|meta| match meta {
            Meta::NameValue(pair) if pair.path.is_ident("description") => match &pair.value {
                Expr::Lit(syn::ExprLit {
                    lit: Lit::Str(text),
                    ..
                }) => Some(text.value()),
                _ => None,
            },
            _ => None,
        });
    found.map_or_else(|| quote! { None }, |text| quote! { Some(#text) })
}
/// Builds the JSON object literal used as `"properties"` in a tool's input schema.
///
/// Each field becomes a `"field_name": <schema>` entry, where the schema comes from
/// `parse_type_to_schema`. Fields tagged `#[mcp(skip)]` are left out, and so are fields
/// whose type has no schema mapping.
///
/// # Panics
///
/// Panics if a non-skipped field has no identifier; callers only pass named fields.
fn generate_schema_fields(fields: Vec<&Field>) -> proc_macro2::TokenStream {
    let entries: Vec<proc_macro2::TokenStream> = fields
        .into_iter()
        .filter(|field| !has_skip_attr(field))
        .filter_map(|field| {
            let key = field
                .ident
                .as_ref()
                .expect("Field should have identifier")
                .to_string();
            parse_type_to_schema(&field.ty)
                .ok()
                .map(|schema| quote! { #key: #schema })
        })
        .collect();
    quote! {
        { #(#entries,)* }
    }
}
/// Builds the body of the generated `arguments()` method for a prompt.
///
/// Every field that is neither `#[mcp(skip)]` nor a smart-pointer/lock type (see
/// `is_complex_type`) becomes a `PromptArgument` with no description. The argument is
/// marked required unless the field is `Option<..>`. The body is `None` when no field
/// qualifies, otherwise `Some(vec![...])`.
///
/// # Panics
///
/// Panics if a non-skipped field has no identifier; callers only pass named fields.
fn generate_prompt_arguments(fields: Vec<&Field>) -> proc_macro2::TokenStream {
    let prompt_args: Vec<proc_macro2::TokenStream> = fields
        .into_iter()
        .filter(|field| !has_skip_attr(field))
        .filter_map(|field| {
            let arg_name = field
                .ident
                .as_ref()
                .expect("Field should have identifier")
                .to_string();
            if is_complex_type(&field.ty) {
                return None;
            }
            let required = !is_option_type(&field.ty);
            Some(quote! {
                mcp_host::registry::prompts::PromptArgument {
                    name: #arg_name.to_string(),
                    description: None,
                    required: Some(#required),
                }
            })
        })
        .collect();
    match prompt_args.is_empty() {
        true => quote! { None },
        false => quote! { Some(vec![#(#prompt_args),*]) },
    }
}
/// Returns `true` when `ty` is an `Option<..>`.
///
/// The LAST path segment is inspected, so the fully qualified `std::option::Option<T>` and
/// `core::option::Option<T>` are recognised as well as the bare `Option<T>`. Invisible
/// groups (types forwarded through a `macro_rules!` `$t:ty`) and parentheses are looked
/// through.
fn is_option_type(ty: &Type) -> bool {
    match ty {
        Type::Group(group) => is_option_type(&group.elem),
        Type::Paren(paren) => is_option_type(&paren.elem),
        Type::Path(type_path) => type_path
            .path
            .segments
            .last()
            .map_or(false, |segment| segment.ident == "Option"),
        _ => false,
    }
}
/// Returns `true` for smart-pointer / interior-mutability wrappers (`Arc`, `Mutex`,
/// `RwLock`, `Rc`, `RefCell`, `Box`), which are treated as internal state rather than
/// user-facing input.
///
/// The LAST path segment is inspected, so qualified paths such as
/// `std::sync::Arc<Mutex<T>>` are caught as well as their bare forms. Invisible groups and
/// parentheses are looked through.
fn is_complex_type(ty: &Type) -> bool {
    match ty {
        Type::Group(group) => is_complex_type(&group.elem),
        Type::Paren(paren) => is_complex_type(&paren.elem),
        Type::Path(type_path) => type_path.path.segments.last().map_or(false, |segment| {
            matches!(
                segment.ident.to_string().as_str(),
                "Arc" | "Mutex" | "RwLock" | "Rc" | "RefCell" | "Box"
            )
        }),
        _ => false,
    }
}
fn extract_resource_fields(
fields: Vec<&Field>,
) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
let mut uri = quote! { "default:///" };
let mut mime_type = quote! { None };
for field in fields {
if has_skip_attr(field) {
continue;
}
let field_name = &field.ident.as_ref().expect("Field should have identifier");
let field_name_str = field_name.to_string();
if field_name_str == "uri" || field_name_str == "uri_template" {
uri = quote! { self.#field_name.as_str() };
}
if field_name_str == "mime_type" {
mime_type = quote! { self.#field_name.as_deref() };
}
}
(uri, mime_type)
}
fn parse_type_to_schema(ty: &Type) -> Result<proc_macro2::TokenStream, String> {
match ty {
Type::Path(type_path) => {
if type_path.path.segments.is_empty() {
return Err("Empty type path".to_string());
}
let type_segment = &type_path.path.segments[0];
match type_segment.ident.to_string().as_str() {
"String" => Ok(quote! { { "type": "string" } }),
"i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64"
| "u128" | "usize" => Ok(quote! { { "type": "integer" } }),
"f32" | "f64" => Ok(quote! { { "type": "number" } }),
"bool" => Ok(quote! { { "type": "boolean" } }),
"Vec" => Ok(quote! { { "type": "array", "items": { "type": "string" } } }),
"Option" => Ok(quote! { { "type": ["string", "null"] } }),
"Arc" | "Mutex" | "RwLock" | "Rc" | "RefCell" | "Box" => {
Err("Internal type, skip".to_string())
}
_ => Err(format!("Unsupported type: {}", type_segment.ident)),
}
}
_ => Err("Complex types not yet supported".to_string()),
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the token-generation helpers.
    //!
    //! The `#[proc_macro_derive]` entry points are not called here, because
    //! `proc_macro::TokenStream` is only usable while the compiler is expanding a macro.
    use super::*;
    use quote::quote;
    use syn::{parse_quote, Data, DeriveInput, Field, Fields, Type};

    /// Returns the named fields of a parsed struct definition.
    fn named_fields(input: DeriveInput) -> Vec<Field> {
        match input.data {
            Data::Struct(data) => match data.fields {
                Fields::Named(fields) => fields.named.into_iter().collect(),
                _ => panic!("expected named fields"),
            },
            _ => panic!("expected a struct"),
        }
    }

    #[test]
    fn skip_attribute_is_detected() {
        let fields = named_fields(parse_quote! {
            struct S {
                #[mcp(skip)]
                hidden: String,
                visible: u32,
            }
        });
        assert!(has_skip_attr(&fields[0]));
        assert!(!has_skip_attr(&fields[1]));
    }

    #[test]
    fn name_defaults_to_struct_ident() {
        let input: DeriveInput = parse_quote! { struct MyTool; };
        assert_eq!(
            get_name(&input.attrs, &input.ident).to_string(),
            quote! { "MyTool" }.to_string()
        );
        assert_eq!(get_description(&input.attrs).to_string(), quote! { None }.to_string());
    }

    #[test]
    fn name_and_description_are_read_from_mcp_attribute() {
        let input: DeriveInput = parse_quote! {
            #[mcp(name = "custom", description = "Does things")]
            struct MyTool;
        };
        assert_eq!(
            get_name(&input.attrs, &input.ident).to_string(),
            quote! { "custom" }.to_string()
        );
        assert_eq!(
            get_description(&input.attrs).to_string(),
            quote! { Some("Does things") }.to_string()
        );
    }

    #[test]
    fn scalar_types_map_to_json_types() {
        let string_ty: Type = parse_quote!(String);
        let int_ty: Type = parse_quote!(u32);
        let float_ty: Type = parse_quote!(f64);
        let bool_ty: Type = parse_quote!(bool);
        assert_eq!(
            parse_type_to_schema(&string_ty).unwrap().to_string(),
            quote! { { "type": "string" } }.to_string()
        );
        assert_eq!(
            parse_type_to_schema(&int_ty).unwrap().to_string(),
            quote! { { "type": "integer" } }.to_string()
        );
        assert_eq!(
            parse_type_to_schema(&float_ty).unwrap().to_string(),
            quote! { { "type": "number" } }.to_string()
        );
        assert_eq!(
            parse_type_to_schema(&bool_ty).unwrap().to_string(),
            quote! { { "type": "boolean" } }.to_string()
        );
    }

    #[test]
    fn internal_and_unsupported_types_are_rejected() {
        let wrapped: Type = parse_quote!(Arc<Mutex<u8>>);
        let borrowed: Type = parse_quote!(&'static str);
        assert!(parse_type_to_schema(&wrapped).is_err());
        assert!(parse_type_to_schema(&borrowed).is_err());
    }

    #[test]
    fn option_and_complex_detection() {
        let optional: Type = parse_quote!(Option<u8>);
        let plain: Type = parse_quote!(u8);
        let shared: Type = parse_quote!(Arc<u8>);
        let owned: Type = parse_quote!(String);
        assert!(is_option_type(&optional));
        assert!(!is_option_type(&plain));
        assert!(is_complex_type(&shared));
        assert!(!is_complex_type(&owned));
    }

    #[test]
    fn prompt_arguments_are_none_without_eligible_fields() {
        let fields = named_fields(parse_quote! {
            struct S {
                #[mcp(skip)]
                hidden: String,
                state: Arc<u8>,
            }
        });
        let args = generate_prompt_arguments(fields.iter().collect());
        assert_eq!(args.to_string(), quote! { None }.to_string());
    }

    #[test]
    fn prompt_arguments_are_listed_for_plain_fields() {
        let fields = named_fields(parse_quote! {
            struct S {
                topic: String,
                limit: Option<u32>,
            }
        });
        let args = generate_prompt_arguments(fields.iter().collect());
        assert_ne!(args.to_string(), quote! { None }.to_string());
    }
}