1use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::{
9 Attribute, FnArg, ImplItem, ItemImpl, Lit, Meta, Pat, ReturnType, Type, parse_macro_input,
10};
11
12#[proc_macro_attribute]
31pub fn tool_registry(_attr: TokenStream, item: TokenStream) -> TokenStream {
32 let mut impl_block = parse_macro_input!(item as ItemImpl);
33 let self_ty = &impl_block.self_ty;
34
35 let mut generated_items = Vec::new();
36
37 for item in &mut impl_block.items {
38 if let ImplItem::Fn(method) = item {
39 let mut is_tool = false;
41
42 method.attrs.retain(|attr| {
44 if attr.path().is_ident("tool") {
45 is_tool = true;
46 false } else {
48 true
49 }
50 });
51
52 if is_tool {
53 let tool_impl = generate_tool_impl(self_ty, method);
54 generated_items.push(tool_impl);
55 }
56 }
57 }
58
59 let expanded = quote! {
60 #impl_block
61
62 #(#generated_items)*
63 };
64
65 TokenStream::from(expanded)
66}
67
68fn extract_doc_comment(attrs: &[Attribute]) -> String {
70 let mut lines = Vec::new();
71
72 for attr in attrs {
73 if attr.path().is_ident("doc") {
74 if let Meta::NameValue(meta) = &attr.meta {
75 if let syn::Expr::Lit(expr_lit) = &meta.value {
76 if let Lit::Str(lit_str) = &expr_lit.lit {
77 let line = lit_str.value();
78 let trimmed = line.strip_prefix(' ').unwrap_or(&line);
80 lines.push(trimmed.to_string());
81 }
82 }
83 }
84 }
85 }
86
87 lines.join("\n")
88}
89
90fn extract_description_attr(attrs: &[syn::Attribute]) -> Option<String> {
92 for attr in attrs {
93 if attr.path().is_ident("description") {
94 if let Meta::NameValue(meta) = &attr.meta {
95 if let syn::Expr::Lit(expr_lit) = &meta.value {
96 if let Lit::Str(lit_str) = &expr_lit.lit {
97 return Some(lit_str.value());
98 }
99 }
100 }
101 }
102 }
103 None
104}
105
106fn generate_tool_impl(self_ty: &Type, method: &syn::ImplItemFn) -> proc_macro2::TokenStream {
108 let sig = &method.sig;
109 let method_name = &sig.ident;
110 let tool_name = method_name.to_string();
111
112 let pascal_name = to_pascal_case(&method_name.to_string());
114 let tool_struct_name = format_ident!("Tool{}", pascal_name);
115 let args_struct_name = format_ident!("{}Args", pascal_name);
116 let definition_name = format_ident!("{}_definition", method_name);
117
118 let description = extract_doc_comment(&method.attrs);
120 let description = if description.is_empty() {
121 format!("Tool: {}", tool_name)
122 } else {
123 description
124 };
125
126 let args: Vec<_> = sig
128 .inputs
129 .iter()
130 .filter_map(|arg| {
131 if let FnArg::Typed(pat_type) = arg {
132 Some(pat_type)
133 } else {
134 None }
136 })
137 .collect();
138
139 let arg_fields: Vec<_> = args
141 .iter()
142 .map(|pat_type| {
143 let pat = &pat_type.pat;
144 let ty = &pat_type.ty;
145 let desc = extract_description_attr(&pat_type.attrs);
146
147 let field_name = if let Pat::Ident(pat_ident) = pat.as_ref() {
149 &pat_ident.ident
150 } else {
151 panic!("Only simple identifiers are supported for tool arguments");
152 };
153
154 if let Some(desc_str) = desc {
156 quote! {
157 #[schemars(description = #desc_str)]
158 pub #field_name: #ty
159 }
160 } else {
161 quote! {
162 pub #field_name: #ty
163 }
164 }
165 })
166 .collect();
167
168 let arg_names: Vec<_> = args
170 .iter()
171 .map(|pat_type| {
172 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
173 let ident = &pat_ident.ident;
174 quote! { args.#ident }
175 } else {
176 panic!("Only simple identifiers are supported");
177 }
178 })
179 .collect();
180
181 let is_async = sig.asyncness.is_some();
183
184 let awaiter = if is_async {
186 quote! { .await }
187 } else {
188 quote! {}
189 };
190
191 let result_handling = if is_result_type(&sig.output) {
193 quote! {
194 match result {
195 Ok(val) => Ok(format!("{:?}", val)),
196 Err(e) => Err(::llm_worker::tool::ToolError::ExecutionFailed(format!("{}", e))),
197 }
198 }
199 } else {
200 quote! {
201 Ok(format!("{:?}", result))
202 }
203 };
204
205 let args_struct_def = if arg_fields.is_empty() {
207 quote! {
208 #[derive(serde::Deserialize, schemars::JsonSchema)]
209 struct #args_struct_name {}
210 }
211 } else {
212 quote! {
213 #[derive(serde::Deserialize, schemars::JsonSchema)]
214 struct #args_struct_name {
215 #(#arg_fields),*
216 }
217 }
218 };
219
220 let execute_body = if args.is_empty() {
222 quote! {
223 let _: #args_struct_name = serde_json::from_str(input_json)
225 .unwrap_or(#args_struct_name {});
226
227 let result = self.ctx.#method_name()#awaiter;
228 #result_handling
229 }
230 } else {
231 quote! {
232 let args: #args_struct_name = serde_json::from_str(input_json)
233 .map_err(|e| ::llm_worker::tool::ToolError::InvalidArgument(e.to_string()))?;
234
235 let result = self.ctx.#method_name(#(#arg_names),*)#awaiter;
236 #result_handling
237 }
238 };
239
240 quote! {
241 #args_struct_def
242
243 #[derive(Clone)]
244 pub struct #tool_struct_name {
245 ctx: #self_ty,
246 }
247
248 #[async_trait::async_trait]
249 impl ::llm_worker::tool::Tool for #tool_struct_name {
250 async fn execute(&self, input_json: &str) -> Result<String, ::llm_worker::tool::ToolError> {
251 #execute_body
252 }
253 }
254
255 impl #self_ty {
256 pub fn #definition_name(&self) -> ::llm_worker::tool::ToolDefinition {
258 let ctx = self.clone();
259 ::std::sync::Arc::new(move || {
260 let schema = schemars::schema_for!(#args_struct_name);
261 let meta = ::llm_worker::tool::ToolMeta::new(#tool_name)
262 .description(#description)
263 .input_schema(serde_json::to_value(schema).unwrap_or(serde_json::json!({})));
264 let tool: ::std::sync::Arc<dyn ::llm_worker::tool::Tool> =
265 ::std::sync::Arc::new(#tool_struct_name { ctx: ctx.clone() });
266 (meta, tool)
267 })
268 }
269 }
270 }
271}
272
273fn is_result_type(return_type: &ReturnType) -> bool {
275 match return_type {
276 ReturnType::Default => false,
277 ReturnType::Type(_, ty) => {
278 if let Type::Path(type_path) = ty.as_ref() {
280 if let Some(segment) = type_path.path.segments.last() {
281 return segment.ident == "Result";
282 }
283 }
284 false
285 }
286 }
287}
288
/// Converts a `snake_case` identifier to `PascalCase`.
///
/// The input is split on `_`. The first character of every non-empty piece is
/// upper-cased (which may expand to several characters, e.g. `ß` becomes `SS`),
/// and the rest of the piece is kept unchanged. Empty pieces produced by
/// leading, trailing or repeated underscores contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut chars = word.chars();
        if let Some(head) = chars.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(chars.as_str());
        }
    }
    pascal
}
301
302#[proc_macro_attribute]
304pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
305 item
306}
307
308#[proc_macro_attribute]
319pub fn description(_attr: TokenStream, item: TokenStream) -> TokenStream {
320 item
321}