1use convert_case::{Case, Casing};
11use proc_macro::TokenStream;
12use quote::{format_ident, quote};
13use syn::{
14 FnArg, GenericArgument, Ident, ItemFn, LitStr, Pat, PatType, PathArguments, Token, Type,
15 parse::ParseStream, parse_macro_input,
16};
17
18#[proc_macro_attribute]
92pub fn llm_tool(attr: TokenStream, item: TokenStream) -> TokenStream {
93 let func = parse_macro_input!(item as ItemFn);
94 let tool_attr = if attr.is_empty() {
95 None
96 } else {
97 match syn::parse::<ToolAttr>(attr) {
98 Ok(parsed) => Some(parsed),
99 Err(err) => return err.to_compile_error().into(),
100 }
101 };
102 match tool_impl(&func, tool_attr.as_ref()) {
103 Ok(tokens) => tokens.into(),
104 Err(err) => err.to_compile_error().into(),
105 }
106}
107
108struct ToolAttr {
118 description_inline: Option<LitStr>,
120 template_path: Option<LitStr>,
122 #[cfg(feature = "prompt-templates")]
125 inline_params: Vec<(Ident, LitStr)>,
126 #[cfg(feature = "prompt-templates")]
128 context_fn: Option<syn::Path>,
129}
130
131impl syn::parse::Parse for ToolAttr {
132 fn parse(input: ParseStream) -> syn::Result<Self> {
133 let mut description_inline = None;
134 let mut template_path = None;
135 #[cfg(feature = "prompt-templates")]
136 let mut inline_params = Vec::new();
137 #[cfg(feature = "prompt-templates")]
138 let mut context_fn = None;
139 #[cfg(not(feature = "prompt-templates"))]
141 let mut has_inline_params = false;
142 #[cfg(not(feature = "prompt-templates"))]
143 let mut has_context_fn = false;
144
145 while !input.is_empty() {
146 let ident: Ident = input.parse()?;
147 if ident == "description" {
148 let _: Token![=] = input.parse()?;
149 description_inline = Some(input.parse::<LitStr>()?);
150 } else if ident == "template" {
151 let _: Token![=] = input.parse()?;
152 template_path = Some(input.parse::<LitStr>()?);
153 } else if ident == "params" {
154 let content;
155 syn::parenthesized!(content in input);
156 while !content.is_empty() {
157 let key: Ident = content.parse()?;
158 let _: Token![=] = content.parse()?;
159 let value: LitStr = content.parse()?;
160 #[cfg(feature = "prompt-templates")]
161 inline_params.push((key, value));
162 #[cfg(not(feature = "prompt-templates"))]
164 {
165 let _ = (key, value);
166 }
167 if !content.is_empty() {
168 let _: Token![,] = content.parse()?;
169 }
170 }
171 #[cfg(not(feature = "prompt-templates"))]
172 {
173 has_inline_params = true;
174 }
175 } else if ident == "context" {
176 let _: Token![=] = input.parse()?;
177 #[cfg(feature = "prompt-templates")]
178 {
179 context_fn = Some(input.parse::<syn::Path>()?);
180 }
181 #[cfg(not(feature = "prompt-templates"))]
182 {
183 let _path: syn::Path = input.parse()?;
184 has_context_fn = true;
185 }
186 } else {
187 return Err(syn::Error::new(
188 ident.span(),
189 "expected `description`, `template`, `params`, or `context`",
190 ));
191 }
192
193 if !input.is_empty() {
194 let _: Token![,] = input.parse()?;
195 }
196 }
197
198 #[cfg(feature = "prompt-templates")]
199 let (has_inline_params, has_context_fn) = (!inline_params.is_empty(), context_fn.is_some());
200
201 validate_tool_attr(
202 description_inline.as_ref(),
203 template_path.as_ref(),
204 has_inline_params,
205 has_context_fn,
206 )?;
207
208 Ok(Self {
209 description_inline,
210 template_path,
211 #[cfg(feature = "prompt-templates")]
212 inline_params,
213 #[cfg(feature = "prompt-templates")]
214 context_fn,
215 })
216 }
217}
218
219fn validate_tool_attr(
222 description_inline: Option<&LitStr>,
223 template_path: Option<&LitStr>,
224 has_inline_params: bool,
225 has_context_fn: bool,
226) -> syn::Result<()> {
227 if description_inline.is_some() && template_path.is_some() {
229 return Err(syn::Error::new(
230 proc_macro2::Span::call_site(),
231 "`description` and `template` are mutually exclusive",
232 ));
233 }
234
235 if template_path.is_none() && has_inline_params {
237 return Err(syn::Error::new(
238 proc_macro2::Span::call_site(),
239 "`params(...)` requires `template = \"...\"`",
240 ));
241 }
242 if template_path.is_none() && has_context_fn {
243 return Err(syn::Error::new(
244 proc_macro2::Span::call_site(),
245 "`context = ...` requires `template = \"...\"`",
246 ));
247 }
248
249 if has_inline_params && has_context_fn {
251 return Err(syn::Error::new(
252 proc_macro2::Span::call_site(),
253 "`params(...)` and `context = ...` are mutually exclusive; \
254 use `params` for compile-time values or `context` for runtime values",
255 ));
256 }
257
258 if description_inline.is_none() && template_path.is_none() {
260 return Err(syn::Error::new(
261 proc_macro2::Span::call_site(),
262 "expected `description = \"...\"` or `template = \"...\"`",
263 ));
264 }
265
266 Ok(())
267}
268
269struct ParamInfo {
273 name: syn::Ident,
274 ty: Box<syn::Type>,
275 doc_attrs: Vec<syn::Attribute>,
276 is_context: bool,
277}
278
279enum ReturnInfo {
281 ResultType {
283 ok_type: Box<syn::Type>,
284 err_type: Box<syn::Type>,
285 },
286 BareType,
288}
289
290fn tool_impl(func: &ItemFn, attr: Option<&ToolAttr>) -> syn::Result<proc_macro2::TokenStream> {
291 let crate_path = quote! { ::llm_tool };
292 let fn_name = &func.sig.ident;
293 let tool_name_str = fn_name.to_string();
294 let struct_name = format_ident!("{}", tool_name_str.to_case(Case::Pascal));
295 let params_name = format_ident!("{}Params", struct_name);
296
297 let DescriptionInfo {
299 static_description,
300 helper_tokens,
301 description_method,
302 dep_tracking,
303 } = resolve_description(func, attr)?;
304
305 let all_params = extract_params(func)?;
307 let ctx_param = all_params.iter().find(|p| p.is_context);
308 let params: Vec<&ParamInfo> = all_params.iter().filter(|p| !p.is_context).collect();
309
310 for param in ¶ms {
312 if param.doc_attrs.is_empty() {
313 return Err(syn::Error::new_spanned(
314 ¶m.name,
315 format!(
316 "#[llm_tool] parameter `{}` must have a doc comment \
317 (used as the parameter description in the JSON schema)",
318 param.name
319 ),
320 ));
321 }
322 }
323
324 let return_info = parse_return_type(func)?;
326
327 let param_names: Vec<_> = params.iter().map(|p| &p.name).collect();
328 let param_descriptions: Vec<String> = params
329 .iter()
330 .map(|p| extract_doc_string(&p.doc_attrs))
331 .collect();
332
333 let (param_struct_types, borrow_bindings) = build_param_types_and_borrows(¶ms);
334 let serde_defaults = build_serde_defaults(¶ms);
335 let body_tokens = build_body_tokens(func, &return_info, &crate_path);
336
337 let vis = &func.vis;
338
339 let params_doc = format!("Auto-generated parameters for the [`{struct_name}`] tool.");
340 let struct_doc = format!(
341 "Auto-generated tool struct. See the `#[llm_tool]`-annotated function `{fn_name}` for the implementation."
342 );
343
344 let ctx_binding = if let Some(cp) = ctx_param {
347 let ctx_name = &cp.name;
348 quote! { let #ctx_name = _ctx; }
349 } else {
350 quote! {}
351 };
352
353 Ok(quote! {
354 #dep_tracking
355 #helper_tokens
356
357 #[doc = #params_doc]
358 #[derive(::serde::Deserialize, ::schemars::JsonSchema)]
359 #vis struct #params_name {
360 #(
361 #[schemars(description = #param_descriptions)]
362 #serde_defaults
363 pub #param_names: #param_struct_types,
364 )*
365 }
366
367 #[doc = #struct_doc]
368 #vis struct #struct_name;
369
370 impl #crate_path::RustTool for #struct_name {
371 type Params = #params_name;
372 const NAME: &'static str = #tool_name_str;
373 const DESCRIPTION: &'static str = #static_description;
374
375 #description_method
376
377 async fn call(&self, params: Self::Params, _ctx: &#crate_path::ToolContext) -> ::std::result::Result<#crate_path::ToolOutput, #crate_path::ToolError> {
378 use #crate_path::__private::SerializeFallback as _;
381 let #params_name { #( #param_names, )* } = params;
384 #( #borrow_bindings )*
386 #ctx_binding
387 #body_tokens
388 }
389 }
390 })
391}
392
393struct DescriptionInfo {
397 static_description: String,
399 helper_tokens: proc_macro2::TokenStream,
401 description_method: Option<proc_macro2::TokenStream>,
403 dep_tracking: proc_macro2::TokenStream,
405}
406
407fn resolve_description(func: &ItemFn, attr: Option<&ToolAttr>) -> syn::Result<DescriptionInfo> {
409 match attr {
410 None => {
412 let desc = extract_doc_string(&func.attrs);
413 if desc.is_empty() {
414 return Err(syn::Error::new_spanned(
415 &func.sig.ident,
416 "#[llm_tool] functions must have a doc comment \
417 (used as the tool description), or use \
418 #[llm_tool(description = \"...\")]",
419 ));
420 }
421 Ok(DescriptionInfo {
422 static_description: desc,
423 helper_tokens: quote! {},
424 description_method: None,
425 dep_tracking: quote! {},
426 })
427 }
428 Some(ToolAttr {
430 description_inline: Some(desc),
431 ..
432 }) => Ok(DescriptionInfo {
433 static_description: desc.value(),
434 helper_tokens: quote! {},
435 description_method: None,
436 dep_tracking: quote! {},
437 }),
438 Some(
440 tool_attr @ ToolAttr {
441 template_path: Some(_),
442 ..
443 },
444 ) => resolve_template_description(tool_attr),
445 _ => Err(syn::Error::new(
447 proc_macro2::Span::call_site(),
448 "expected `description = \"...\"` or `template = \"...\"`",
449 )),
450 }
451}
452
453fn resolve_template_description(attr: &ToolAttr) -> syn::Result<DescriptionInfo> {
462 #[cfg(not(feature = "prompt-templates"))]
463 {
464 let span = attr
465 .template_path
466 .as_ref()
467 .map_or(proc_macro2::Span::call_site(), LitStr::span);
468 Err(syn::Error::new(
469 span,
470 "the `prompt-templates` feature must be enabled to use \
471 `#[llm_tool(template = \"...\")]`. \
472 Add `features = [\"prompt-templates\"]` to your llm-tool dependency.",
473 ))
474 }
475
476 #[cfg(feature = "prompt-templates")]
477 resolve_template_description_impl(attr)
478}
479
#[cfg(feature = "prompt-templates")]
/// Loads the template file at compile time and picks a rendering strategy.
///
/// The path is resolved against `CARGO_MANIFEST_DIR`, falling back to `"."`.
/// The strategy depends on what the attribute and the template provide:
/// - `params(...)` given: render now via `resolve_template_with_params`;
/// - `context = ...` given: render at runtime via `resolve_context_description`;
/// - neither, and no declarations in the frontmatter: the trimmed body is the
///   description;
/// - neither, but the template declares parameters: error.
fn resolve_template_description_impl(attr: &ToolAttr) -> syn::Result<DescriptionInfo> {
    let template_lit = attr
        .template_path
        .as_ref()
        .expect("template_path validated");
    let rel_path = template_lit.value();
    let lit_span = template_lit.span();

    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".to_string());
    let full_path = std::path::Path::new(&manifest_dir).join(&rel_path);

    let source = match std::fs::read_to_string(&full_path) {
        Ok(text) => text,
        Err(e) => {
            return Err(syn::Error::new(
                lit_span,
                format!("failed to read template '{}': {e}", full_path.display()),
            ));
        }
    };

    let (fm, body) = match prompt_templates::parse_frontmatter(&source) {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(syn::Error::new(
                lit_span,
                format!("template '{rel_path}' error: {e}"),
            ));
        }
    };

    let body_str = body.trim().to_string();
    let path_str = full_path.to_string_lossy().to_string();

    // Referencing the file through `include_str!` makes the compiler treat the
    // template as an input of the crate.
    let dep_tracking = quote! {
        const _: &str = include_str!(#path_str);
    };

    let has_params = !attr.inline_params.is_empty();
    let has_context = attr.context_fn.is_some();
    let has_declarations = !fm.declarations.is_empty();

    match (has_params, has_context) {
        (true, _) => resolve_template_with_params(
            attr,
            &fm,
            &source,
            &rel_path,
            lit_span,
            dep_tracking,
        ),
        (false, true) => resolve_context_description(
            attr,
            &rel_path,
            template_lit,
            &body_str,
            &path_str,
            has_declarations,
            dep_tracking,
        ),
        (false, false) if !has_declarations => Ok(DescriptionInfo {
            static_description: body_str,
            helper_tokens: quote! {},
            description_method: None,
            dep_tracking,
        }),
        (false, false) => {
            let declared = fm
                .declarations
                .iter()
                .map(|d| d.name.as_str())
                .collect::<Vec<_>>()
                .join(", ");
            Err(syn::Error::new(
                lit_span,
                format!(
                    "template '{rel_path}' declares parameters ({declared}) but neither \
                     `params(...)` nor `context = ...` was provided"
                ),
            ))
        }
    }
}
565
#[cfg(feature = "prompt-templates")]
/// Builds a description that is rendered at runtime from `context = ...`.
///
/// The generated `description(&self)` does three things:
/// - lazily parses the template, embedded with `include_str!`;
/// - calls the user's context function with `self`;
/// - renders the template with the result.
///
/// The raw template body is still used as the static `DESCRIPTION`.
///
/// # Errors
/// Errors when the template declares no parameters, since a context function
/// would be pointless.
fn resolve_context_description(
    attr: &ToolAttr,
    rel_path: &str,
    template_lit: &LitStr,
    body_str: &str,
    path_str: &str,
    has_declarations: bool,
    dep_tracking: proc_macro2::TokenStream,
) -> syn::Result<DescriptionInfo> {
    // The caller only dispatches here when `context = ...` was given. Report a
    // compile error rather than panicking inside the proc macro if that changes.
    let Some(context_fn) = attr.context_fn.as_ref() else {
        return Err(syn::Error::new(
            template_lit.span(),
            "internal error: expected `context = ...` to be provided",
        ));
    };

    if !has_declarations {
        return Err(syn::Error::new(
            template_lit.span(),
            format!(
                "template '{rel_path}' has no declared parameters, \
                 so `context = ...` is unnecessary. Remove `context` \
                 or add params to the template."
            ),
        ));
    }

    let description_method = quote! {
        fn description(&self) -> ::std::borrow::Cow<'static, str> {
            static TEMPLATE: ::std::sync::LazyLock<::prompt_templates::Template> =
                ::std::sync::LazyLock::new(|| {
                    ::prompt_templates::Template::from_source(
                        include_str!(#path_str)
                    ).expect("Valid template (verified at compile time)")
                });
            let ctx = #context_fn(self);
            let rendered = TEMPLATE.render(&ctx)
                .expect("Failed to render tool description template");
            ::std::borrow::Cow::Owned(rendered)
        }
    };

    Ok(DescriptionInfo {
        static_description: body_str.to_owned(),
        helper_tokens: quote! {},
        description_method: Some(description_method),
        dep_tracking,
    })
}
618
#[cfg(feature = "prompt-templates")]
/// Renders the template at compile time using the `params(...)` values.
///
/// The keys in `params(...)` must match the template's declared parameters
/// exactly: none missing, none extra, none repeated. The rendered text becomes
/// the static `DESCRIPTION`.
///
/// # Errors
/// Returns an error on duplicate, missing or undeclared parameters, or when the
/// template fails to parse or render. Name lists in the messages are sorted, so
/// diagnostics are stable across builds.
fn resolve_template_with_params(
    attr: &ToolAttr,
    fm: &prompt_templates::Frontmatter,
    source: &str,
    rel_path: &str,
    span: proc_macro2::Span,
    dep_tracking: proc_macro2::TokenStream,
) -> syn::Result<DescriptionInfo> {
    // Ordered sets rather than HashSet: the names end up in error messages, and
    // HashSet iteration order changes from one compilation to the next.
    use std::collections::BTreeSet;

    let declared_names: BTreeSet<&str> =
        fm.declarations.iter().map(|d| d.name.as_str()).collect();

    // A repeated key would otherwise silently overwrite the earlier value when
    // the render context is filled below.
    let mut provided_names: BTreeSet<String> = BTreeSet::new();
    for (key, _) in &attr.inline_params {
        if !provided_names.insert(key.to_string()) {
            return Err(syn::Error::new(
                key.span(),
                format!("param `{key}` is provided more than once in `params(...)`"),
            ));
        }
    }

    let missing: Vec<&str> = declared_names
        .iter()
        .copied()
        .filter(|name| !provided_names.contains(*name))
        .collect();
    if !missing.is_empty() {
        return Err(syn::Error::new(
            span,
            format!(
                "template '{rel_path}' declares parameters not provided in `params(...)`: {}",
                missing.join(", ")
            ),
        ));
    }

    for (key, _) in &attr.inline_params {
        let key_str = key.to_string();
        if !declared_names.contains(key_str.as_str()) {
            let declared_list = if declared_names.is_empty() {
                "(none)".to_owned()
            } else {
                declared_names.iter().copied().collect::<Vec<_>>().join(", ")
            };
            return Err(syn::Error::new(
                key.span(),
                format!(
                    "param `{key_str}` is not declared in template '{rel_path}'. \
                     Declared params: {declared_list}"
                ),
            ));
        }
    }

    let template = prompt_templates::Template::from_source(source)
        .map_err(|e| syn::Error::new(span, format!("template '{rel_path}' parse error: {e}")))?;

    let mut ctx = prompt_templates::Context::new();
    for (key, value) in &attr.inline_params {
        ctx.set(key.to_string(), value.value());
    }

    let rendered = template
        .render(&ctx)
        .map_err(|e| syn::Error::new(span, format!("template '{rel_path}' render error: {e}")))?;

    Ok(DescriptionInfo {
        static_description: rendered,
        helper_tokens: quote! {},
        description_method: None,
        dep_tracking,
    })
}
693
694fn build_param_types_and_borrows(
696 params: &[&ParamInfo],
697) -> (Vec<proc_macro2::TokenStream>, Vec<proc_macro2::TokenStream>) {
698 params
699 .iter()
700 .map(|p| {
701 if is_str_ref(&p.ty) {
702 let name = &p.name;
704 (quote! { String }, quote! { let #name: &str = &#name; })
705 } else {
706 let ty = &p.ty;
707 (quote! { #ty }, quote! {})
708 }
709 })
710 .unzip()
711}
712
713fn build_serde_defaults(params: &[&ParamInfo]) -> Vec<proc_macro2::TokenStream> {
715 params
716 .iter()
717 .map(|p| {
718 if is_option_type(&p.ty) {
719 quote! { #[serde(default)] }
720 } else {
721 quote! {}
722 }
723 })
724 .collect()
725}
726
727fn build_body_tokens(
734 func: &ItemFn,
735 return_info: &ReturnInfo,
736 crate_path: &proc_macro2::TokenStream,
737) -> proc_macro2::TokenStream {
738 let is_async = func.sig.asyncness.is_some();
739 let body_stmts = &func.block.stmts;
740
741 match return_info {
742 ReturnInfo::ResultType { ok_type, err_type } => {
743 let inner = if is_async {
744 quote! {
745 let __r: ::std::result::Result<#ok_type, #err_type> = async move {
746 #( #body_stmts )*
747 }.await;
748 }
749 } else {
750 quote! {
751 let __r: ::std::result::Result<#ok_type, #err_type> = (|| { #( #body_stmts )* })();
752 }
753 };
754 quote! {
755 #inner
756 match __r {
757 ::std::result::Result::Ok(__v) => #crate_path::__private::Wrap(__v).__convert(),
758 ::std::result::Result::Err(__e) => ::std::result::Result::Err(::std::convert::Into::into(__e)),
759 }
760 }
761 }
762 ReturnInfo::BareType => {
763 let inner = if is_async {
764 quote! {
765 let __v = async move { #( #body_stmts )* }.await;
766 }
767 } else {
768 quote! {
769 let __v = (|| { #( #body_stmts )* })();
770 }
771 };
772 quote! {
773 #inner
774 #crate_path::__private::Wrap(__v).__convert()
775 }
776 }
777 }
778}
779
780fn is_option_type(ty: &syn::Type) -> bool {
782 let Type::Path(type_path) = ty else {
783 return false;
784 };
785 let Some(last_seg) = type_path.path.segments.last() else {
786 return false;
787 };
788 if last_seg.ident != "Option" {
789 return false;
790 }
791 matches!(&last_seg.arguments, PathArguments::AngleBracketed(args)
792 if args.args.len() == 1
793 && matches!(args.args.first(), Some(GenericArgument::Type(_))))
794}
795
796fn is_tool_context_type(ty: &syn::Type) -> bool {
799 let inner = match ty {
800 Type::Reference(r) => r.elem.as_ref(),
801 other => other,
802 };
803 let Type::Path(type_path) = inner else {
804 return false;
805 };
806 type_path
807 .path
808 .segments
809 .last()
810 .is_some_and(|seg| seg.ident == "ToolContext")
811}
812
813fn is_str_ref(ty: &syn::Type) -> bool {
815 let Type::Reference(ref_type) = ty else {
816 return false;
817 };
818 if ref_type.mutability.is_some() {
819 return false;
820 }
821 let Type::Path(type_path) = ref_type.elem.as_ref() else {
822 return false;
823 };
824 type_path
825 .path
826 .segments
827 .last()
828 .is_some_and(|seg| seg.ident == "str" && seg.arguments.is_none())
829}
830
831fn is_explicit_context_attr(attr: &syn::Attribute) -> syn::Result<bool> {
832 if !attr.path().is_ident("llm_tool") {
833 return Ok(false);
834 }
835 let mut is_context = false;
836 attr.parse_nested_meta(|meta| {
837 if meta.path.is_ident("context") {
838 is_context = true;
839 Ok(())
840 } else {
841 Err(meta.error("unsupported llm_tool attribute"))
842 }
843 })?;
844 Ok(is_context)
845}
846
847fn extract_params(func: &ItemFn) -> syn::Result<Vec<ParamInfo>> {
848 let mut params = Vec::new();
849 for arg in &func.sig.inputs {
850 match arg {
851 FnArg::Receiver(r) => {
852 return Err(syn::Error::new_spanned(
853 r,
854 "#[llm_tool] functions must be free functions (no `self`)",
855 ));
856 }
857 FnArg::Typed(PatType { pat, ty, attrs, .. }) => {
858 let name = match pat.as_ref() {
859 Pat::Ident(ident) => ident.ident.clone(),
860 other => {
861 return Err(syn::Error::new_spanned(
862 other,
863 "#[llm_tool] parameters must be simple identifiers",
864 ));
865 }
866 };
867
868 let mut has_context_attr = false;
869 for a in attrs {
870 has_context_attr |= is_explicit_context_attr(a)?;
871 }
872 let is_tool_context = is_tool_context_type(ty);
873 let is_context = has_context_attr || is_tool_context;
874
875 if is_tool_context && !matches!(ty.as_ref(), syn::Type::Reference(_)) {
876 return Err(syn::Error::new_spanned(
877 ty,
878 "ToolContext parameter must be a reference type (e.g., `&ToolContext` or `&'a ToolContext`)",
879 ));
880 }
881
882 let doc_attrs: Vec<syn::Attribute> = attrs
883 .iter()
884 .filter(|a| a.path().is_ident("doc"))
885 .cloned()
886 .collect();
887 params.push(ParamInfo {
888 name,
889 ty: ty.clone(),
890 doc_attrs,
891 is_context,
892 });
893 }
894 }
895 }
896 Ok(params)
897}
898
899fn extract_doc_string(attrs: &[syn::Attribute]) -> String {
900 let lines: Vec<String> = attrs
901 .iter()
902 .filter_map(|attr| {
903 if !attr.path().is_ident("doc") {
904 return None;
905 }
906 if let syn::Meta::NameValue(nv) = &attr.meta
907 && let syn::Expr::Lit(lit) = &nv.value
908 && let syn::Lit::Str(s) = &lit.lit
909 {
910 return Some(s.value());
911 }
912 None
913 })
914 .collect();
915 lines
916 .iter()
917 .map(|l| l.trim())
918 .collect::<Vec<_>>()
919 .join("\n")
920 .trim()
921 .to_string()
922}
923
924fn parse_return_type(func: &ItemFn) -> syn::Result<ReturnInfo> {
926 let syn::ReturnType::Type(_, ty) = &func.sig.output else {
927 return Err(syn::Error::new_spanned(
928 &func.sig,
929 "#[llm_tool] functions must have an explicit return type",
930 ));
931 };
932
933 if let Some(result_types) = try_extract_result_types(ty) {
935 return Ok(result_types);
936 }
937
938 Ok(ReturnInfo::BareType)
940}
941
942fn try_extract_result_types(ty: &syn::Type) -> Option<ReturnInfo> {
945 let Type::Path(type_path) = ty else {
946 return None;
947 };
948
949 let last_seg = type_path.path.segments.last()?;
950
951 if last_seg.ident != "Result" {
952 return None;
953 }
954
955 let PathArguments::AngleBracketed(args) = &last_seg.arguments else {
956 return None;
957 };
958
959 if args.args.len() != 2 {
960 return None;
961 }
962
963 let GenericArgument::Type(ok_type) = &args.args[0] else {
964 return None;
965 };
966
967 let GenericArgument::Type(err_type) = &args.args[1] else {
968 return None;
969 };
970
971 Some(ReturnInfo::ResultType {
972 ok_type: Box::new(ok_type.clone()),
973 err_type: Box::new(err_type.clone()),
974 })
975}