fastmcp_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::parse::{Parse, ParseStream};
4use syn::punctuated::Punctuated;
5use syn::token::Comma;
6use syn::{ItemFn, Lit, Meta, parse_macro_input};
7
8/// 解析宏的属性参数
9struct MacroArgs {
10    /// 工具名称
11    name: Option<String>,
12    /// 工具描述
13    description: Option<String>,
14    /// 工具分类
15    categories: Vec<String>,
16    /// 是否支持流式响应
17    streaming: Option<bool>,
18    /// 工具版本
19    version: Option<String>,
20    /// 工具作者
21    author: Option<String>,
22    /// 文档URL
23    docs: Option<String>,
24}
25
26impl Parse for MacroArgs {
27    fn parse(input: ParseStream) -> syn::Result<Self> {
28        let args = Punctuated::<Meta, Comma>::parse_terminated(input)?;
29        let mut name = None;
30        let mut description = None;
31        let mut categories = Vec::new();
32        let mut streaming = None;
33        let mut version = None;
34        let mut author = None;
35        let mut docs = None;
36
37        for arg in args {
38            if let Meta::NameValue(nv) = arg {
39                let key = nv.path.get_ident().unwrap().to_string();
40                if let Lit::Str(lit) = nv.lit {
41                    let value = lit.value();
42                    match key.as_str() {
43                        "name" => name = Some(value),
44                        "description" => description = Some(value),
45                        "version" => version = Some(value),
46                        "author" => author = Some(value),
47                        "docs" => docs = Some(value),
48                        "categories" => {
49                            for cat in value.split(',') {
50                                categories.push(cat.trim().to_string());
51                            }
52                        }
53                        _ => {}
54                    }
55                } else if let Lit::Bool(b) = nv.lit {
56                    if key == "streaming" {
57                        streaming = Some(b.value);
58                    }
59                }
60            }
61        }
62
63        Ok(MacroArgs {
64            name,
65            description,
66            categories,
67            streaming,
68            version,
69            author,
70            docs,
71        })
72    }
73}
74
75/// MCP工具宏
76///
77/// 这个宏可以标记一个函数作为MCP工具,自动生成所需的Tool trait实现。
78///
79/// # 参数
80/// - `name` - 工具名称,默认为函数名
81/// - `description` - 工具描述
82/// - `categories` - 工具分类,用逗号分隔
83/// - `streaming` - 是否支持流式响应,默认为false
84/// - `version` - 工具版本
85/// - `author` - 工具作者
86/// - `docs` - 文档URL
87///
88/// # 示例
89/// ```
90/// #[mcp_tool(
91///     name = "calculator",
92///     description = "A simple calculator",
93///     categories = "math,utility"
94/// )]
95/// async fn add(a: i32, b: i32) -> i32 {
96///     a + b
97/// }
98/// ```
99#[proc_macro_attribute]
100pub fn mcp_tool(attr: TokenStream, item: TokenStream) -> TokenStream {
101    // 解析函数
102    let input_fn = parse_macro_input!(item as ItemFn);
103    let fn_name = &input_fn.sig.ident;
104
105    // 解析宏参数
106    let args = syn::parse_macro_input!(attr as MacroArgs);
107
108    // 使用函数名作为默认工具名
109    let tool_name = args.name.unwrap_or_else(|| fn_name.to_string());
110
111    // 构建工具描述,如果未提供则使用默认值
112    let tool_description = args
113        .description
114        .unwrap_or_else(|| format!("MCP tool: {tool_name}"));
115
116    // 创建工具结构体名称
117    let struct_name = format_ident!("{}Tool", fn_name);
118
119    // 生成分类列表
120    let categories = &args.categories;
121    let categories_expr = if categories.is_empty() {
122        quote! { vec![] }
123    } else {
124        let category_strings = categories.iter().map(|c| c.as_str());
125        quote! { vec![#(#category_strings.to_string()),*] }
126    };
127
128    // 流式响应标志
129    let streaming = args.streaming.unwrap_or(false);
130
131    // 版本、作者和文档URL
132    let version_expr = if let Some(v) = args.version {
133        quote! { Some(#v.to_string()) }
134    } else {
135        quote! { None }
136    };
137
138    let author_expr = if let Some(a) = args.author {
139        quote! { Some(#a.to_string()) }
140    } else {
141        quote! { None }
142    };
143
144    let docs_expr = if let Some(d) = args.docs {
145        quote! { Some(#d.to_string()) }
146    } else {
147        quote! { None }
148    };
149
150    // 生成工具结构体和实现
151    let output = quote! {
152        // 保留原始函数
153        #input_fn
154
155        // 创建工具结构体
156        #[derive(Debug)]
157        struct #struct_name;
158
159        #[async_trait::async_trait]
160        impl ::fastmcp::Tool for #struct_name {
161            fn name(&self) -> &str {
162                #tool_name
163            }
164
165            fn description(&self) -> &str {
166                #tool_description
167            }
168
169            fn parameters(&self) -> serde_json::Value {
170                // TODO: 从函数参数自动生成参数模式
171                // 当前在MVP阶段,返回空对象
172                serde_json::json!({
173                    "type": "object",
174                    "properties": {}
175                })
176            }
177
178            fn streaming(&self) -> bool {
179                #streaming
180            }
181
182            fn categories(&self) -> Vec<String> {
183                #categories_expr
184            }
185
186            fn version(&self) -> Option<String> {
187                #version_expr
188            }
189
190            fn author(&self) -> Option<String> {
191                #author_expr
192            }
193
194            fn documentation_url(&self) -> Option<String> {
195                #docs_expr
196            }
197
198            async fn execute(&self, params: serde_json::Value, context: std::sync::Arc<::fastmcp::ToolContext>) -> ::fastmcp::Result<serde_json::Value> {
199                // 调用原始函数
200                // TODO: 在完整实现中应该解析参数并正确调用函数
201                // 当前在MVP阶段,简单转发
202                let result = #fn_name().await;
203
204                // 将结果转换为JSON
205                // TODO: 在完整实现中应该正确序列化结果
206                Ok(serde_json::json!({
207                    "result": format!("{:?}", result)
208                }))
209            }
210        }
211    };
212
213    output.into()
214}