mcp-host-macros 0.1.0

Procedural macros for the mcp-host crate
Documentation
//! Procedural macros for mcp-host
//!
//! Provides derive macros for MCP tools and resources.

use proc_macro::TokenStream;
use quote::quote;
use syn::{Data, DeriveInput, Field, Fields, Type};

/// Derive macro for MCP tools
#[proc_macro_derive(McpTool, attributes(mcp))]
pub fn derive_mcp_tool(input: TokenStream) -> TokenStream {
    let input: DeriveInput = syn::parse2(input.into()).expect("Failed to parse input");

    let name = &input.ident;

    // Generate schema from fields
    let description_token = get_description(&input.attrs);
    let schema = match &input.data {
        Data::Struct(data_struct) => match &data_struct.fields {
            Fields::Named(fields) => generate_schema_fields(fields.named.iter().collect()),
            Fields::Unnamed(_) => panic!("McTool does not support unnamed fields"),
            Fields::Unit => panic!("McTool does not support unit structs"),
        },
        Data::Enum(_) => panic!("McTool does not support enums"),
        Data::Union(_) => panic!("McTool does not support unions"),
    };

    let expanded = quote! {
        #[async_trait::async_trait]
        impl Tool for #name {
            fn name(&self) -> &str {
                stringify!(#name)
            }

            fn description(&self) -> Option<&str> {
                #description_token
            }

            fn input_schema(&self) -> serde_json::Value {
                serde_json::json!({
                    "type": "object",
                    "properties": #schema,
                    "required": []
                })
            }

            fn execution(&self) -> Option<mcp_host::protocol::types::ToolExecution> {
                None
            }

            fn is_visible(&self, _ctx: &mcp_host::server::visibility::VisibilityContext) -> bool {
                true
            }

            async fn execute(&self, ctx: ExecutionContext<'_>) -> Result<Vec<Box<dyn Content>>, ToolError> {
                // Users should implement their logic directly in the execute method
                // using the struct fields
                unimplemented!("User must implement execute for {}", stringify!(#name))
            }
        }
    };

    TokenStream::from(expanded)
}

/// Derive macro for MCP resources
#[proc_macro_derive(McpResource, attributes(mcp))]
pub fn derive_mcp_resource(input: TokenStream) -> TokenStream {
    let input: DeriveInput = syn::parse2(input.into()).expect("Failed to parse input");

    let name = &input.ident;
    let description_token = get_description(&input.attrs);

    let (_impl_methods, uri, mime_type) = match &input.data {
        Data::Struct(data_struct) => match &data_struct.fields {
            Fields::Named(fields) => {
                let result = generate_resource_impl(fields.named.iter().collect());
                (result.0, result.1, result.2)
            }
            Fields::Unnamed(_) => {
                panic!("McResource does not support unnamed fields")
            }
            Fields::Unit => {
                panic!("McResource does not support unit structs")
            }
        },
        Data::Enum(_) => {
            panic!("McResource does not support enums")
        }
        Data::Union(_) => {
            panic!("McResource does not support unions")
        }
    };

    let expanded = quote! {
        #[async_trait::async_trait]
        impl Resource for #name {
            fn uri(&self) -> &str {
                #uri
            }

            fn name(&self) -> &str {
                stringify!(#name)
            }

            fn description(&self) -> Option<&str> {
                #description_token
            }

            fn mime_type(&self) -> Option<&str> {
                #mime_type
            }

            fn is_visible(&self, _ctx: &mcp_host::server::visibility::VisibilityContext) -> bool {
                true
            }
        }
    };

    TokenStream::from(expanded)
}

/// Generate schema fields from struct fields.
///
/// Produces the tokens of a JSON object literal, intended for use inside `serde_json::json!`,
/// that maps each field's name (as a string key) to the schema returned by
/// `parse_type_to_schema` for its type. Any type that helper rejects falls back to
/// `{ "type": "string" }`. Entries keep the fields' declaration order.
///
/// # Panics
/// Panics if a field has no identifier (i.e. it did not come from a named-field struct).
fn generate_schema_fields(fields: Vec<&Field>) -> proc_macro2::TokenStream {
    let entries = fields.into_iter().map(|field| {
        let key = field
            .ident
            .as_ref()
            .expect("Field should have identifier")
            .to_string();
        // Unsupported types degrade to a plain string schema rather than failing the derive.
        let value = parse_type_to_schema(&field.ty)
            .unwrap_or_else(|_| quote! { { "type": "string" } });
        quote! { #key: #value }
    });

    quote! {
        { #(#entries,)* }
    }
}

/// Generate implementation methods for tool execution.
///
/// For every named field, appends a `let <field> = <extraction>;` statement. Each extraction reads
/// the parameter of the same name from `ctx.params` and converts it to the field's declared type.
/// It returns `ToolError::InvalidArguments` from the surrounding function when the parameter is
/// missing, has the wrong JSON type, or is out of range for an integer field.
///
/// After those statements, it appends an `execute` method that forwards to `self.run(ctx)`.
///
/// Returns the generated tokens together with the field names, in declaration order.
///
/// NOTE(review): the `let` statements and the `async fn execute` item end up in one token
/// stream, which is not valid as-is in either statement or impl-item position; a caller will
/// need to place them separately. Nothing in this file calls this function yet.
///
/// # Panics
/// Panics if a field has no identifier.
// Allowed because it is not yet wired into `derive_mcp_tool`.
#[allow(dead_code)]
fn generate_impl_methods(fields: Vec<&Field>) -> (proc_macro2::TokenStream, Vec<String>) {
    let mut param_names = Vec::with_capacity(fields.len());
    let mut impl_methods = proc_macro2::TokenStream::new();

    for field in fields {
        let field_name = field.ident.as_ref().expect("Field should have identifier");
        let field_name_str = field_name.to_string();
        let field_type = &field.ty;
        param_names.push(field_name_str.clone());

        // Error produced when the parameter is missing, mistyped, or out of range.
        let invalid = quote! {
            ToolError::InvalidArguments(format!("Missing or invalid parameter: {}", #field_name_str))
        };

        // Match on the last path segment so `std::string::String` is treated like `String`.
        let type_ident = match field_type {
            Type::Path(type_path) => type_path
                .path
                .segments
                .last()
                .map(|segment| segment.ident.to_string()),
            _ => None,
        };

        let extract_code = match type_ident.as_deref() {
            Some("String") => quote! {
                ctx.params.get(#field_name_str)
                    .and_then(|v| v.as_str())
                    .ok_or_else(|| #invalid)?
                    .to_string()
            },
            // Read through the widest JSON integer of matching signedness, then narrow with a
            // checked conversion so out-of-range values are rejected instead of truncated.
            Some("i32" | "i64") => quote! {
                ctx.params.get(#field_name_str)
                    .and_then(|v| v.as_i64())
                    .and_then(|n| <#field_type as ::std::convert::TryFrom<i64>>::try_from(n).ok())
                    .ok_or_else(|| #invalid)?
            },
            Some("u32" | "u64") => quote! {
                ctx.params.get(#field_name_str)
                    .and_then(|v| v.as_u64())
                    .and_then(|n| <#field_type as ::std::convert::TryFrom<u64>>::try_from(n).ok())
                    .ok_or_else(|| #invalid)?
            },
            // Cast to the field's own float type so `f32` fields bind an `f32`.
            Some("f32" | "f64") => quote! {
                (ctx.params.get(#field_name_str)
                    .and_then(|v| v.as_f64())
                    .ok_or_else(|| #invalid)? as #field_type)
            },
            Some("bool") => quote! {
                ctx.params.get(#field_name_str)
                    .and_then(|v| v.as_bool())
                    .ok_or_else(|| #invalid)?
            },
            // Anything else (including non-path types) is passed through as the raw JSON value.
            _ => quote! {
                ctx.params.get(#field_name_str)
                    .cloned()
                    .ok_or_else(|| #invalid)?
            },
        };

        impl_methods.extend(quote! {
            let #field_name = #extract_code;
        });
    }

    // Add the execute method that calls the user's implementation.
    impl_methods.extend(quote! {
        async fn execute(&self, ctx: ExecutionContext<'_>) -> Result<Vec<Box<dyn Content>>, ToolError> {
            self.run(ctx).await
        }
    });

    (impl_methods, param_names)
}

/// Parse Rust type to JSON schema
fn parse_type_to_schema(ty: &Type) -> Result<proc_macro2::TokenStream, String> {
    match ty {
        Type::Path(type_path) => {
            let type_segment = &type_path.path.segments[0];
            match type_segment.ident.to_string().as_str() {
                "String" => Ok(quote! { { "type": "string" } }),
                "i32" | "i64" | "u32" | "u64" => Ok(quote! { { "type": "integer" } }),
                "f32" | "f64" => Ok(quote! { { "type": "number" } }),
                "bool" => Ok(quote! { { "type": "boolean" } }),
                "Vec" => Ok(quote! { { "type": "array", "items": { "type": "string" } } }),
                "Option" => {
                    // Handle Option<T> by ignoring the optional nature for now
                    Ok(quote! { { "type": ["string", "null"] } })
                }
                _ => Err(format!("Unsupported type: {}", type_segment.ident)),
            }
        }
        _ => Err("Complex types not yet supported".to_string()),
    }
}

/// Get description from attributes.
///
/// Returns the tokens used as the body of a generated `description()` method, i.e. an
/// `Option<&str>` expression. Arguments inside `#[mcp(...)]` are not parsed yet, so the result
/// is currently always `None`, whether or not that attribute is present.
///
/// Takes a slice, so callers may pass `&Vec<syn::Attribute>` (e.g. `&input.attrs`) unchanged.
fn get_description(attrs: &[syn::Attribute]) -> proc_macro2::TokenStream {
    // TODO: extract a description from `#[mcp(...)]` once its argument syntax is defined;
    // for now the attribute is only located, not read.
    let _mcp_attr = attrs.iter().find(|attr| attr.path().is_ident("mcp"));
    quote! { None }
}

/// Generate implementation methods for resource.
///
/// Scans the fields for special names and returns three token streams:
/// 1. an `async fn read` that forwards to `self.read_resource(ctx)`;
/// 2. the body for `uri()`: `self.<field>.as_str()` for a field named `uri` or `uri_template`
///    (if both exist, the one declared last wins), otherwise the literal `"default:///"`;
/// 3. the body for `mime_type()`: `self.mime_type.as_deref()` when a `mime_type` field exists,
///    otherwise `None`.
///
/// The generated accessor code requires the `uri`/`uri_template` field to have an `as_str()`
/// method and the `mime_type` field to have `as_deref()` returning `Option<&str>`.
///
/// # Panics
/// Panics if a field has no identifier.
fn generate_resource_impl(
    fields: Vec<&Field>,
) -> (
    proc_macro2::TokenStream,
    proc_macro2::TokenStream,
    proc_macro2::TokenStream,
) {
    let mut uri_template = quote! { "default:///" };
    let mut mime_type = quote! { None };

    for field in fields {
        let field_name = field.ident.as_ref().expect("Field should have identifier");

        match field_name.to_string().as_str() {
            // Special URI template field.
            "uri" | "uri_template" => {
                uri_template = quote! {
                    self.#field_name.as_str()
                };
            }
            // MIME type field.
            "mime_type" => {
                mime_type = quote! {
                    self.#field_name.as_deref()
                };
            }
            _ => {}
        }
    }

    let read_method = quote! {
        async fn read(&self, ctx: ExecutionContext<'_>) -> Result<Vec<ResourceContent>, ResourceError> {
            self.read_resource(ctx).await
        }
    };

    (read_method, uri_template, mime_type)
}