llm_worker_macros/
lib.rs

1//! llm-worker-macros - Tool生成用手続きマクロ
2//!
3//! `#[tool_registry]` と `#[tool]` マクロを提供し、
4//! ユーザー定義のメソッドから `Tool` トレイト実装を自動生成する。
5
6use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::{
9    Attribute, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, ReturnType, Type, parse_macro_input,
10};
11
12/// `impl` ブロックに付与し、内部の `#[tool]` 属性がついたメソッドからツールを生成するマクロ。
13///
14/// # Example
15/// ```ignore
16/// #[tool_registry]
17/// impl MyApp {
18///     /// ユーザー情報を取得する
19///     /// 指定されたIDのユーザーをDBから検索します。
20///     #[tool]
21///     async fn get_user(&self, user_id: String) -> Result<User, Error> { ... }
22/// }
23/// ```
24///
25/// これにより以下が生成されます:
26/// - `GetUserArgs` 構造体(引数用)
27/// - `Tool_get_user` 構造体(Toolラッパー)
28/// - `impl Tool for Tool_get_user`
29/// - `impl MyApp { fn get_user_tool(&self) -> Tool_get_user }`
30#[proc_macro_attribute]
31pub fn tool_registry(_attr: TokenStream, item: TokenStream) -> TokenStream {
32    let mut impl_block = parse_macro_input!(item as ItemImpl);
33    let self_ty = &impl_block.self_ty;
34
35    let mut generated_items = Vec::new();
36
37    for item in &mut impl_block.items {
38        if let ImplItem::Fn(method) = item {
39            // #[tool] 属性を探す
40            let mut is_tool = false;
41
42            // 属性を走査してtoolがあるか確認し、削除する
43            method.attrs.retain(|attr| {
44                if attr.path().is_ident("tool") {
45                    is_tool = true;
46                    false // 属性を削除
47                } else {
48                    true
49                }
50            });
51
52            if is_tool {
53                let tool_impl = generate_tool_impl(self_ty, method);
54                generated_items.push(tool_impl);
55            }
56        }
57    }
58
59    let expanded = quote! {
60        #impl_block
61
62        #(#generated_items)*
63    };
64
65    TokenStream::from(expanded)
66}
67
68/// ドキュメントコメントから説明文を抽出
69fn extract_doc_comment(attrs: &[Attribute]) -> String {
70    let mut lines = Vec::new();
71
72    for attr in attrs {
73        if attr.path().is_ident("doc") {
74            if let Meta::NameValue(meta) = &attr.meta {
75                if let syn::Expr::Lit(expr_lit) = &meta.value {
76                    if let Lit::Str(lit_str) = &expr_lit.lit {
77                        let line = lit_str.value();
78                        // 先頭の空白を1つだけ除去(/// の後のスペース)
79                        let trimmed = line.strip_prefix(' ').unwrap_or(&line);
80                        lines.push(trimmed.to_string());
81                    }
82                }
83            }
84        }
85    }
86
87    lines.join("\n")
88}
89
90/// #[description = "..."] 属性から説明を抽出
91fn extract_description_attr(attrs: &[syn::Attribute]) -> Option<String> {
92    for attr in attrs {
93        if attr.path().is_ident("description") {
94            if let Meta::NameValue(meta) = &attr.meta {
95                if let syn::Expr::Lit(expr_lit) = &meta.value {
96                    if let Lit::Str(lit_str) = &expr_lit.lit {
97                        return Some(lit_str.value());
98                    }
99                }
100            }
101        }
102    }
103    None
104}
105
106/// メソッドからTool実装を生成
107fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::TokenStream {
108    let sig = &method.sig;
109    let method_name = &sig.ident;
110    let tool_name = method_name.to_string();
111
112    // 構造体名を生成(PascalCase変換)
113    let pascal_name = to_pascal_case(&method_name.to_string());
114    let tool_struct_name = format_ident!("Tool{}", pascal_name);
115    let args_struct_name = format_ident!("{}Args", pascal_name);
116    let definition_name = format_ident!("{}_definition", method_name);
117
118    // ドキュメントコメントから説明を取得
119    let description = extract_doc_comment(&method.attrs);
120    let description = if description.is_empty() {
121        format!("Tool: {}", tool_name)
122    } else {
123        description
124    };
125
126    // 引数を解析(selfを除く)
127    let args: Vec<_> = sig
128        .inputs
129        .iter()
130        .filter_map(|arg| {
131            if let FnArg::Typed(pat_type) = arg {
132                Some(pat_type)
133            } else {
134                None // selfを除外
135            }
136        })
137        .collect();
138
139    // 引数構造体のフィールドを生成
140    let arg_fields: Vec<_> = args
141        .iter()
142        .map(|pat_type| {
143            let pat = &pat_type.pat;
144            let ty = &pat_type.ty;
145            let desc = extract_description_attr(&pat_type.attrs);
146
147            // パターンから識別子を抽出
148            let field_name = if let Pat::Ident(pat_ident) = pat.as_ref() {
149                &pat_ident.ident
150            } else {
151                panic!("Only simple identifiers are supported for tool arguments");
152            };
153
154            // #[description] があればschemarsのdocに変換
155            if let Some(desc_str) = desc {
156                quote! {
157                    #[schemars(description = #desc_str)]
158                    pub #field_name: #ty
159                }
160            } else {
161                quote! {
162                    pub #field_name: #ty
163                }
164            }
165        })
166        .collect();
167
168    // execute内で引数を展開するコード
169    let arg_names: Vec<_> = args
170        .iter()
171        .map(|pat_type| {
172            if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
173                let ident = &pat_ident.ident;
174                quote! { args.#ident }
175            } else {
176                panic!("Only simple identifiers are supported");
177            }
178        })
179        .collect();
180
181    // メソッドが非同期かどうか
182    let is_async = sig.asyncness.is_some();
183
184    // 戻り値の型を解析してResult判定
185    let awaiter = if is_async {
186        quote! { .await }
187    } else {
188        quote! {}
189    };
190
191    // 戻り値がResultかどうかを判定
192    let result_handling = if is_result_type(&sig.output) {
193        quote! {
194            match result {
195                Ok(val) => Ok(format!("{:?}", val)),
196                Err(e) => Err(::llm_worker::tool::ToolError::ExecutionFailed(format!("{}", e))),
197            }
198        }
199    } else {
200        quote! {
201            Ok(format!("{:?}", result))
202        }
203    };
204
205    // 引数がない場合は空のArgs構造体を作成
206    let args_struct_def = if arg_fields.is_empty() {
207        quote! {
208            #[derive(serde::Deserialize, schemars::JsonSchema)]
209            struct #args_struct_name {}
210        }
211    } else {
212        quote! {
213            #[derive(serde::Deserialize, schemars::JsonSchema)]
214            struct #args_struct_name {
215                #(#arg_fields),*
216            }
217        }
218    };
219
220    // 引数がない場合のexecute処理
221    let execute_body = if args.is_empty() {
222        quote! {
223            // 引数なしでも空のJSONオブジェクトを許容
224            let _: #args_struct_name = serde_json::from_str(input_json)
225                .unwrap_or(#args_struct_name {});
226
227            let result = self.ctx.#method_name()#awaiter;
228            #result_handling
229        }
230    } else {
231        quote! {
232            let args: #args_struct_name = serde_json::from_str(input_json)
233                .map_err(|e| ::llm_worker::tool::ToolError::InvalidArgument(e.to_string()))?;
234
235            let result = self.ctx.#method_name(#(#arg_names),*)#awaiter;
236            #result_handling
237        }
238    };
239
240    quote! {
241        #args_struct_def
242
243        #[derive(Clone)]
244        pub struct #tool_struct_name {
245            ctx: #self_ty,
246        }
247
248        #[async_trait::async_trait]
249        impl ::llm_worker::tool::Tool for #tool_struct_name {
250            async fn execute(&self, input_json: &str) -> Result<String, ::llm_worker::tool::ToolError> {
251                #execute_body
252            }
253        }
254
255        impl #self_ty {
256            /// ToolDefinition を取得(Worker への登録用)
257            pub fn #definition_name(&self) -> ::llm_worker::tool::ToolDefinition {
258                let ctx = self.clone();
259                ::std::sync::Arc::new(move || {
260                    let schema = schemars::schema_for!(#args_struct_name);
261                    let meta = ::llm_worker::tool::ToolMeta::new(#tool_name)
262                        .description(#description)
263                        .input_schema(serde_json::to_value(schema).unwrap_or(serde_json::json!({})));
264                    let tool: ::std::sync::Arc<dyn ::llm_worker::tool::Tool> =
265                        ::std::sync::Arc::new(#tool_struct_name { ctx: ctx.clone() });
266                    (meta, tool)
267                })
268            }
269        }
270    }
271}
272
273/// 戻り値の型がResultかどうかを判定
274fn is_result_type(return_type: &ReturnType) -> bool {
275    match return_type {
276        ReturnType::Default => false,
277        ReturnType::Type(_, ty) => {
278            // Type::Pathの場合、最後のセグメントが"Result"かチェック
279            if let Type::Path(type_path) = ty.as_ref() {
280                if let Some(segment) = type_path.path.segments.last() {
281                    return segment.ident == "Result";
282                }
283            }
284            false
285        }
286    }
287}
288
289/// snake_case を PascalCase に変換
290fn to_pascal_case(s: &str) -> String {
291    s.split('_')
292        .map(|part| {
293            let mut chars = part.chars();
294            match chars.next() {
295                None => String::new(),
296                Some(first) => first.to_uppercase().chain(chars).collect(),
297            }
298        })
299        .collect()
300}
301
302/// マーカー属性。`tool_registry` によって処理されるため、ここでは何もしない。
303#[proc_macro_attribute]
304pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
305    item
306}
307
308/// 引数属性用のマーカー。パース時に`tool_registry`で解釈される。
309///
310/// # Example
311/// ```ignore
312/// #[tool]
313/// async fn get_user(
314///     &self,
315///     #[description = "取得したいユーザーのID"] user_id: String
316/// ) -> Result<User, Error> { ... }
317/// ```
318#[proc_macro_attribute]
319pub fn description(_attr: TokenStream, item: TokenStream) -> TokenStream {
320    item
321}