1use convert_case::{Case, Casing};
8use proc_macro::TokenStream;
9use quote::{format_ident, quote};
10use syn::{FnArg, GenericArgument, ItemFn, Pat, PatType, PathArguments, Type, parse_macro_input};
11
12#[proc_macro_attribute]
58pub fn llm_tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
59 let func = parse_macro_input!(item as ItemFn);
60 match tool_impl(&func) {
61 Ok(tokens) => tokens.into(),
62 Err(err) => err.to_compile_error().into(),
63 }
64}
65
66struct ParamInfo {
70 name: syn::Ident,
71 ty: Box<syn::Type>,
72 doc_attrs: Vec<syn::Attribute>,
73 is_context: bool,
74}
75
76enum ReturnInfo {
78 ResultType {
80 ok_type: Box<syn::Type>,
81 err_type: Box<syn::Type>,
82 },
83 BareType,
85}
86
87fn tool_impl(func: &ItemFn) -> syn::Result<proc_macro2::TokenStream> {
88 let crate_path = quote! { ::llm_tool };
89 let fn_name = &func.sig.ident;
90 let tool_name_str = fn_name.to_string();
91 let struct_name = format_ident!("{}", tool_name_str.to_case(Case::Pascal));
92 let params_name = format_ident!("{}Params", struct_name);
93
94 let description = extract_doc_string(&func.attrs);
96 if description.is_empty() {
97 return Err(syn::Error::new_spanned(
98 fn_name,
99 "#[llm_tool] functions must have a doc comment (used as the tool description)",
100 ));
101 }
102
103 let all_params = extract_params(func)?;
105 let ctx_param = all_params.iter().find(|p| p.is_context);
106 let params: Vec<&ParamInfo> = all_params.iter().filter(|p| !p.is_context).collect();
107
108 for param in ¶ms {
110 if param.doc_attrs.is_empty() {
111 return Err(syn::Error::new_spanned(
112 ¶m.name,
113 format!(
114 "#[llm_tool] parameter `{}` must have a doc comment \
115 (used as the parameter description in the JSON schema)",
116 param.name
117 ),
118 ));
119 }
120 }
121
122 let return_info = parse_return_type(func)?;
124
125 let param_names: Vec<_> = params.iter().map(|p| &p.name).collect();
126 let param_descriptions: Vec<String> = params
127 .iter()
128 .map(|p| extract_doc_string(&p.doc_attrs))
129 .collect();
130
131 let (param_struct_types, borrow_bindings) = build_param_types_and_borrows(¶ms);
132 let serde_defaults = build_serde_defaults(¶ms);
133 let body_tokens = build_body_tokens(func, &return_info, &crate_path);
134
135 let vis = &func.vis;
136
137 let params_doc = format!("Auto-generated parameters for the [`{struct_name}`] tool.");
138 let struct_doc = format!(
139 "Auto-generated tool struct. See the `#[llm_tool]`-annotated function `{fn_name}` for the implementation."
140 );
141
142 let ctx_binding = if let Some(cp) = ctx_param {
145 let ctx_name = &cp.name;
146 quote! { let #ctx_name = _ctx; }
147 } else {
148 quote! {}
149 };
150
151 Ok(quote! {
152 #[doc = #params_doc]
153 #[derive(::serde::Deserialize, ::schemars::JsonSchema)]
154 #vis struct #params_name {
155 #(
156 #[schemars(description = #param_descriptions)]
157 #serde_defaults
158 pub #param_names: #param_struct_types,
159 )*
160 }
161
162 #[doc = #struct_doc]
163 #vis struct #struct_name;
164
165 impl #crate_path::RustTool for #struct_name {
166 type Params = #params_name;
167 const NAME: &'static str = #tool_name_str;
168 const DESCRIPTION: &'static str = #description;
169
170 async fn call(&self, params: Self::Params, _ctx: &#crate_path::ToolContext) -> ::std::result::Result<#crate_path::ToolOutput, #crate_path::ToolError> {
171 use #crate_path::__private::SerializeFallback as _;
174 let #params_name { #( #param_names, )* } = params;
177 #( #borrow_bindings )*
179 #ctx_binding
180 #body_tokens
181 }
182 }
183 })
184}
185
186fn build_param_types_and_borrows(
188 params: &[&ParamInfo],
189) -> (Vec<proc_macro2::TokenStream>, Vec<proc_macro2::TokenStream>) {
190 params
191 .iter()
192 .map(|p| {
193 if is_str_ref(&p.ty) {
194 let name = &p.name;
196 (quote! { String }, quote! { let #name: &str = &#name; })
197 } else {
198 let ty = &p.ty;
199 (quote! { #ty }, quote! {})
200 }
201 })
202 .unzip()
203}
204
205fn build_serde_defaults(params: &[&ParamInfo]) -> Vec<proc_macro2::TokenStream> {
207 params
208 .iter()
209 .map(|p| {
210 if is_option_type(&p.ty) {
211 quote! { #[serde(default)] }
212 } else {
213 quote! {}
214 }
215 })
216 .collect()
217}
218
219fn build_body_tokens(
226 func: &ItemFn,
227 return_info: &ReturnInfo,
228 crate_path: &proc_macro2::TokenStream,
229) -> proc_macro2::TokenStream {
230 let is_async = func.sig.asyncness.is_some();
231 let body_stmts = &func.block.stmts;
232
233 match return_info {
234 ReturnInfo::ResultType { ok_type, err_type } => {
235 let inner = if is_async {
236 quote! {
237 let __r: ::std::result::Result<#ok_type, #err_type> = async move {
238 #( #body_stmts )*
239 }.await;
240 }
241 } else {
242 quote! {
243 let __r: ::std::result::Result<#ok_type, #err_type> = (|| { #( #body_stmts )* })();
244 }
245 };
246 quote! {
247 #inner
248 match __r {
249 ::std::result::Result::Ok(__v) => #crate_path::__private::Wrap(__v).__convert(),
250 ::std::result::Result::Err(__e) => ::std::result::Result::Err(::std::convert::Into::into(__e)),
251 }
252 }
253 }
254 ReturnInfo::BareType => {
255 let inner = if is_async {
256 quote! {
257 let __v = async move { #( #body_stmts )* }.await;
258 }
259 } else {
260 quote! {
261 let __v = (|| { #( #body_stmts )* })();
262 }
263 };
264 quote! {
265 #inner
266 #crate_path::__private::Wrap(__v).__convert()
267 }
268 }
269 }
270}
271
272fn is_option_type(ty: &syn::Type) -> bool {
274 let Type::Path(type_path) = ty else {
275 return false;
276 };
277 let Some(last_seg) = type_path.path.segments.last() else {
278 return false;
279 };
280 if last_seg.ident != "Option" {
281 return false;
282 }
283 matches!(&last_seg.arguments, PathArguments::AngleBracketed(args)
284 if args.args.len() == 1
285 && matches!(args.args.first(), Some(GenericArgument::Type(_))))
286}
287
288fn is_tool_context_type(ty: &syn::Type) -> bool {
291 let inner = match ty {
292 Type::Reference(r) => r.elem.as_ref(),
293 other => other,
294 };
295 let Type::Path(type_path) = inner else {
296 return false;
297 };
298 type_path
299 .path
300 .segments
301 .last()
302 .is_some_and(|seg| seg.ident == "ToolContext")
303}
304
305fn is_str_ref(ty: &syn::Type) -> bool {
307 let Type::Reference(ref_type) = ty else {
308 return false;
309 };
310 if ref_type.mutability.is_some() {
311 return false;
312 }
313 let Type::Path(type_path) = ref_type.elem.as_ref() else {
314 return false;
315 };
316 type_path
317 .path
318 .segments
319 .last()
320 .is_some_and(|seg| seg.ident == "str" && seg.arguments.is_none())
321}
322
323fn is_explicit_context_attr(attr: &syn::Attribute) -> syn::Result<bool> {
324 if !attr.path().is_ident("llm_tool") {
325 return Ok(false);
326 }
327 let mut is_context = false;
328 attr.parse_nested_meta(|meta| {
329 if meta.path.is_ident("context") {
330 is_context = true;
331 Ok(())
332 } else {
333 Err(meta.error("unsupported llm_tool attribute"))
334 }
335 })?;
336 Ok(is_context)
337}
338
339fn extract_params(func: &ItemFn) -> syn::Result<Vec<ParamInfo>> {
340 let mut params = Vec::new();
341 for arg in &func.sig.inputs {
342 match arg {
343 FnArg::Receiver(r) => {
344 return Err(syn::Error::new_spanned(
345 r,
346 "#[llm_tool] functions must be free functions (no `self`)",
347 ));
348 }
349 FnArg::Typed(PatType { pat, ty, attrs, .. }) => {
350 let name = match pat.as_ref() {
351 Pat::Ident(ident) => ident.ident.clone(),
352 other => {
353 return Err(syn::Error::new_spanned(
354 other,
355 "#[llm_tool] parameters must be simple identifiers",
356 ));
357 }
358 };
359
360 let mut has_context_attr = false;
361 for a in attrs {
362 has_context_attr |= is_explicit_context_attr(a)?;
363 }
364 let is_tool_context = is_tool_context_type(ty);
365 let is_context = has_context_attr || is_tool_context;
366
367 if is_tool_context && !matches!(ty.as_ref(), syn::Type::Reference(_)) {
368 return Err(syn::Error::new_spanned(
369 ty,
370 "ToolContext parameter must be a reference type (e.g., `&ToolContext` or `&'a ToolContext`)",
371 ));
372 }
373
374 let doc_attrs: Vec<syn::Attribute> = attrs
375 .iter()
376 .filter(|a| a.path().is_ident("doc"))
377 .cloned()
378 .collect();
379 params.push(ParamInfo {
380 name,
381 ty: ty.clone(),
382 doc_attrs,
383 is_context,
384 });
385 }
386 }
387 }
388 Ok(params)
389}
390
391fn extract_doc_string(attrs: &[syn::Attribute]) -> String {
392 let lines: Vec<String> = attrs
393 .iter()
394 .filter_map(|attr| {
395 if !attr.path().is_ident("doc") {
396 return None;
397 }
398 if let syn::Meta::NameValue(nv) = &attr.meta
399 && let syn::Expr::Lit(lit) = &nv.value
400 && let syn::Lit::Str(s) = &lit.lit
401 {
402 return Some(s.value());
403 }
404 None
405 })
406 .collect();
407 lines
408 .iter()
409 .map(|l| l.trim())
410 .collect::<Vec<_>>()
411 .join("\n")
412 .trim()
413 .to_string()
414}
415
416fn parse_return_type(func: &ItemFn) -> syn::Result<ReturnInfo> {
418 let syn::ReturnType::Type(_, ty) = &func.sig.output else {
419 return Err(syn::Error::new_spanned(
420 &func.sig,
421 "#[llm_tool] functions must have an explicit return type",
422 ));
423 };
424
425 if let Some(result_types) = try_extract_result_types(ty) {
427 return Ok(result_types);
428 }
429
430 Ok(ReturnInfo::BareType)
432}
433
434fn try_extract_result_types(ty: &syn::Type) -> Option<ReturnInfo> {
437 let Type::Path(type_path) = ty else {
438 return None;
439 };
440
441 let last_seg = type_path.path.segments.last()?;
442
443 if last_seg.ident != "Result" {
444 return None;
445 }
446
447 let PathArguments::AngleBracketed(args) = &last_seg.arguments else {
448 return None;
449 };
450
451 if args.args.len() != 2 {
452 return None;
453 }
454
455 let GenericArgument::Type(ok_type) = &args.args[0] else {
456 return None;
457 };
458
459 let GenericArgument::Type(err_type) = &args.args[1] else {
460 return None;
461 };
462
463 Some(ReturnInfo::ResultType {
464 ok_type: Box::new(ok_type.clone()),
465 err_type: Box::new(err_type.clone()),
466 })
467}