1use convert_case::{Case, Casing};
11use proc_macro::TokenStream;
12use quote::{format_ident, quote};
13use syn::{
14 FnArg, GenericArgument, Ident, ItemFn, LitStr, Pat, PatType, PathArguments, Token, Type,
15 parse::ParseStream, parse_macro_input,
16};
17
18#[proc_macro_attribute]
92pub fn llm_tool(attr: TokenStream, item: TokenStream) -> TokenStream {
93 let func = parse_macro_input!(item as ItemFn);
94 let tool_attr = if attr.is_empty() {
95 None
96 } else {
97 match syn::parse::<ToolAttr>(attr) {
98 Ok(parsed) => Some(parsed),
99 Err(err) => return err.to_compile_error().into(),
100 }
101 };
102 match tool_impl(&func, tool_attr.as_ref()) {
103 Ok(tokens) => tokens.into(),
104 Err(err) => err.to_compile_error().into(),
105 }
106}
107
108struct ToolAttr {
118 description_inline: Option<LitStr>,
120 template_path: Option<LitStr>,
122 #[cfg(feature = "prompt-templates")]
125 inline_params: Vec<(Ident, LitStr)>,
126 #[cfg(feature = "prompt-templates")]
128 context_fn: Option<syn::Path>,
129}
130
131impl syn::parse::Parse for ToolAttr {
132 fn parse(input: ParseStream) -> syn::Result<Self> {
133 let mut description_inline = None;
134 let mut template_path = None;
135 #[cfg(feature = "prompt-templates")]
136 let mut inline_params = Vec::new();
137 #[cfg(feature = "prompt-templates")]
138 let mut context_fn = None;
139 #[cfg(not(feature = "prompt-templates"))]
141 let mut has_inline_params = false;
142 #[cfg(not(feature = "prompt-templates"))]
143 let mut has_context_fn = false;
144
145 while !input.is_empty() {
146 let ident: Ident = input.parse()?;
147 if ident == "description" {
148 let _: Token![=] = input.parse()?;
149 description_inline = Some(input.parse::<LitStr>()?);
150 } else if ident == "template" {
151 let _: Token![=] = input.parse()?;
152 template_path = Some(input.parse::<LitStr>()?);
153 } else if ident == "params" {
154 let content;
155 syn::parenthesized!(content in input);
156 while !content.is_empty() {
157 let key: Ident = content.parse()?;
158 let _: Token![=] = content.parse()?;
159 let value: LitStr = content.parse()?;
160 #[cfg(feature = "prompt-templates")]
161 inline_params.push((key, value));
162 #[cfg(not(feature = "prompt-templates"))]
164 {
165 drop(key);
166 drop(value);
167 }
168 if !content.is_empty() {
169 let _: Token![,] = content.parse()?;
170 }
171 }
172 #[cfg(not(feature = "prompt-templates"))]
173 {
174 has_inline_params = true;
175 }
176 } else if ident == "context" {
177 let _: Token![=] = input.parse()?;
178 #[cfg(feature = "prompt-templates")]
179 {
180 context_fn = Some(input.parse::<syn::Path>()?);
181 }
182 #[cfg(not(feature = "prompt-templates"))]
183 {
184 let _path: syn::Path = input.parse()?;
185 has_context_fn = true;
186 }
187 } else {
188 return Err(syn::Error::new(
189 ident.span(),
190 "expected `description`, `template`, `params`, or `context`",
191 ));
192 }
193
194 if !input.is_empty() {
195 let _: Token![,] = input.parse()?;
196 }
197 }
198
199 #[cfg(feature = "prompt-templates")]
200 let (has_inline_params, has_context_fn) = (!inline_params.is_empty(), context_fn.is_some());
201
202 validate_tool_attr(
203 description_inline.as_ref(),
204 template_path.as_ref(),
205 has_inline_params,
206 has_context_fn,
207 )?;
208
209 Ok(Self {
210 description_inline,
211 template_path,
212 #[cfg(feature = "prompt-templates")]
213 inline_params,
214 #[cfg(feature = "prompt-templates")]
215 context_fn,
216 })
217 }
218}
219
220fn validate_tool_attr(
223 description_inline: Option<&LitStr>,
224 template_path: Option<&LitStr>,
225 has_inline_params: bool,
226 has_context_fn: bool,
227) -> syn::Result<()> {
228 if description_inline.is_some() && template_path.is_some() {
230 return Err(syn::Error::new(
231 proc_macro2::Span::call_site(),
232 "`description` and `template` are mutually exclusive",
233 ));
234 }
235
236 if template_path.is_none() && has_inline_params {
238 return Err(syn::Error::new(
239 proc_macro2::Span::call_site(),
240 "`params(...)` requires `template = \"...\"`",
241 ));
242 }
243 if template_path.is_none() && has_context_fn {
244 return Err(syn::Error::new(
245 proc_macro2::Span::call_site(),
246 "`context = ...` requires `template = \"...\"`",
247 ));
248 }
249
250 if has_inline_params && has_context_fn {
252 return Err(syn::Error::new(
253 proc_macro2::Span::call_site(),
254 "`params(...)` and `context = ...` are mutually exclusive; \
255 use `params` for compile-time values or `context` for runtime values",
256 ));
257 }
258
259 if description_inline.is_none() && template_path.is_none() {
261 return Err(syn::Error::new(
262 proc_macro2::Span::call_site(),
263 "expected `description = \"...\"` or `template = \"...\"`",
264 ));
265 }
266
267 Ok(())
268}
269
270struct ParamInfo {
274 name: syn::Ident,
275 ty: Box<syn::Type>,
276 doc_attrs: Vec<syn::Attribute>,
277 is_context: bool,
278}
279
280enum ReturnInfo {
282 ResultType {
284 ok_type: Box<syn::Type>,
285 err_type: Box<syn::Type>,
286 },
287 BareType,
289}
290
291fn tool_impl(func: &ItemFn, attr: Option<&ToolAttr>) -> syn::Result<proc_macro2::TokenStream> {
292 let crate_path = quote! { ::llm_tool };
293 let fn_name = &func.sig.ident;
294 let tool_name_str = fn_name.to_string();
295 let struct_name = format_ident!("{}", tool_name_str.to_case(Case::Pascal));
296 let params_name = format_ident!("{}Params", struct_name);
297
298 let DescriptionInfo {
300 static_description,
301 helper_tokens,
302 description_method,
303 dep_tracking,
304 } = resolve_description(func, attr)?;
305
306 let all_params = extract_params(func)?;
308 let ctx_param = all_params.iter().find(|p| p.is_context);
309 let params: Vec<&ParamInfo> = all_params.iter().filter(|p| !p.is_context).collect();
310
311 for param in ¶ms {
313 if param.doc_attrs.is_empty() {
314 return Err(syn::Error::new_spanned(
315 ¶m.name,
316 format!(
317 "#[llm_tool] parameter `{}` must have a doc comment \
318 (used as the parameter description in the JSON schema)",
319 param.name
320 ),
321 ));
322 }
323 }
324
325 let return_info = parse_return_type(func)?;
327
328 let param_names: Vec<_> = params.iter().map(|p| &p.name).collect();
329 let param_descriptions: Vec<String> = params
330 .iter()
331 .map(|p| extract_doc_string(&p.doc_attrs))
332 .collect();
333
334 let (param_struct_types, borrow_bindings) = build_param_types_and_borrows(¶ms);
335 let serde_defaults = build_serde_defaults(¶ms);
336 let body_tokens = build_body_tokens(func, &return_info, &crate_path);
337
338 let vis = &func.vis;
339
340 let params_doc = format!("Auto-generated parameters for the [`{struct_name}`] tool.");
341 let struct_doc = format!(
342 "Auto-generated tool struct. See the `#[llm_tool]`-annotated function `{fn_name}` for the implementation."
343 );
344
345 let ctx_binding = if let Some(cp) = ctx_param {
348 let ctx_name = &cp.name;
349 quote! { let #ctx_name = _ctx; }
350 } else {
351 quote! {}
352 };
353
354 Ok(quote! {
355 #dep_tracking
356 #helper_tokens
357
358 #[doc = #params_doc]
359 #[derive(::serde::Deserialize, ::schemars::JsonSchema)]
360 #vis struct #params_name {
361 #(
362 #[schemars(description = #param_descriptions)]
363 #serde_defaults
364 pub #param_names: #param_struct_types,
365 )*
366 }
367
368 #[doc = #struct_doc]
369 #vis struct #struct_name;
370
371 impl #crate_path::RustTool for #struct_name {
372 type Params = #params_name;
373 const NAME: &'static str = #tool_name_str;
374 const DESCRIPTION: &'static str = #static_description;
375
376 #description_method
377
378 async fn call(&self, params: Self::Params, _ctx: &#crate_path::ToolContext) -> ::std::result::Result<#crate_path::ToolOutput, #crate_path::ToolError> {
379 use #crate_path::__private::SerializeFallback as _;
382 let #params_name { #( #param_names, )* } = params;
385 #( #borrow_bindings )*
387 #ctx_binding
388 #body_tokens
389 }
390 }
391 })
392}
393
394struct DescriptionInfo {
398 static_description: String,
400 helper_tokens: proc_macro2::TokenStream,
402 description_method: Option<proc_macro2::TokenStream>,
404 dep_tracking: proc_macro2::TokenStream,
406}
407
408fn resolve_description(func: &ItemFn, attr: Option<&ToolAttr>) -> syn::Result<DescriptionInfo> {
410 match attr {
411 None => {
413 let desc = extract_doc_string(&func.attrs);
414 if desc.is_empty() {
415 return Err(syn::Error::new_spanned(
416 &func.sig.ident,
417 "#[llm_tool] functions must have a doc comment \
418 (used as the tool description), or use \
419 #[llm_tool(description = \"...\")]",
420 ));
421 }
422 Ok(DescriptionInfo {
423 static_description: desc,
424 helper_tokens: quote! {},
425 description_method: None,
426 dep_tracking: quote! {},
427 })
428 }
429 Some(ToolAttr {
431 description_inline: Some(desc),
432 ..
433 }) => Ok(DescriptionInfo {
434 static_description: desc.value(),
435 helper_tokens: quote! {},
436 description_method: None,
437 dep_tracking: quote! {},
438 }),
439 Some(
441 tool_attr @ ToolAttr {
442 template_path: Some(_),
443 ..
444 },
445 ) => resolve_template_description(tool_attr),
446 _ => Err(syn::Error::new(
448 proc_macro2::Span::call_site(),
449 "expected `description = \"...\"` or `template = \"...\"`",
450 )),
451 }
452}
453
454fn resolve_template_description(attr: &ToolAttr) -> syn::Result<DescriptionInfo> {
463 #[cfg(not(feature = "prompt-templates"))]
464 {
465 let span = attr
466 .template_path
467 .as_ref()
468 .map_or(proc_macro2::Span::call_site(), LitStr::span);
469 Err(syn::Error::new(
470 span,
471 "the `prompt-templates` feature must be enabled to use \
472 `#[llm_tool(template = \"...\")]`. \
473 Add `features = [\"prompt-templates\"]` to your llm-tool dependency.",
474 ))
475 }
476
477 #[cfg(feature = "prompt-templates")]
478 resolve_template_description_impl(attr)
479}
480
/// Loads a `template = "..."` file and turns it into a description.
///
/// The path is resolved relative to `CARGO_MANIFEST_DIR` (or `.` when that is unset) and the
/// frontmatter is split off. Then one of four outcomes applies:
/// - no declared parameters, no `params`, no `context`: the trimmed body is used verbatim;
/// - `params(...)` given: the template is rendered now with those compile-time values;
/// - `context = ...` given: the template is parsed now (so syntax errors are compile errors)
///   and rendered at runtime by the generated `description()`;
/// - parameters declared but neither `params` nor `context` given: error.
///
/// Every successful result carries `dep_tracking`, which references the file through
/// `include_str!` so that editing the template triggers a rebuild.
///
/// # Errors
/// The file cannot be read, its frontmatter or template syntax is invalid, or the supplied
/// `params` / `context` do not match the template's declarations.
#[cfg(feature = "prompt-templates")]
fn resolve_template_description_impl(attr: &ToolAttr) -> syn::Result<DescriptionInfo> {
    // Reaching this function without a template is a bug in the caller; report it as a
    // compile error rather than panicking inside the proc macro.
    let template_lit = attr.template_path.as_ref().ok_or_else(|| {
        syn::Error::new(
            proc_macro2::Span::call_site(),
            "internal error: resolve_template_description_impl called without `template`",
        )
    })?;
    let rel_path = template_lit.value();
    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".to_string());
    let full_path = std::path::Path::new(&manifest_dir).join(&rel_path);

    let source = std::fs::read_to_string(&full_path).map_err(|e| {
        syn::Error::new(
            template_lit.span(),
            format!("failed to read template '{}': {e}", full_path.display()),
        )
    })?;

    let (fm, body) = prompt_templates::parse_frontmatter(&source).map_err(|e| {
        syn::Error::new(
            template_lit.span(),
            format!("template '{rel_path}' error: {e}"),
        )
    })?;

    let body_str = body.trim().to_string();
    let path_str = full_path.to_string_lossy().to_string();

    // Referencing the file via `include_str!` makes cargo rebuild when the template changes.
    let dep_tracking = quote! {
        const _: &str = include_str!(#path_str);
    };

    let has_params = !attr.inline_params.is_empty();
    let has_context = attr.context_fn.is_some();
    let has_declarations = !fm.declarations.is_empty();

    if !has_declarations && !has_params && !has_context {
        Ok(DescriptionInfo {
            static_description: body_str,
            helper_tokens: quote! {},
            description_method: None,
            dep_tracking,
        })
    } else if has_params {
        resolve_template_with_params(
            attr,
            &fm,
            &source,
            &rel_path,
            template_lit.span(),
            dep_tracking,
        )
    } else if has_context {
        // The generated `description()` parses this same file at runtime and panics if that
        // fails ("verified at compile time"). Parse it here so that claim holds and a
        // malformed template is reported as a compile error instead.
        let _ = prompt_templates::Template::from_source(source.as_str()).map_err(|e| {
            syn::Error::new(
                template_lit.span(),
                format!("template '{rel_path}' parse error: {e}"),
            )
        })?;
        resolve_context_description(
            attr,
            &rel_path,
            template_lit,
            &body_str,
            &path_str,
            has_declarations,
            dep_tracking,
        )
    } else {
        let declared: Vec<&str> = fm.declarations.iter().map(|d| d.name.as_str()).collect();
        Err(syn::Error::new(
            template_lit.span(),
            format!(
                "template '{rel_path}' declares parameters ({}) but neither \
                 `params(...)` nor `context = ...` was provided",
                declared.join(", ")
            ),
        ))
    }
}
566
/// Builds a description that is rendered at runtime from `context = path::to::fn`.
///
/// The generated `description(&self)` override parses the template file once (through a
/// `LazyLock` static), calls the context function with `self`, and renders the template with
/// the result. `DESCRIPTION` keeps the unrendered template body.
///
/// Note: the generated method panics at runtime if parsing or rendering fails.
///
/// # Errors
/// No context function is set, or the template declares no parameters (so `context` would
/// have nothing to fill).
#[cfg(feature = "prompt-templates")]
fn resolve_context_description(
    attr: &ToolAttr,
    rel_path: &str,
    template_lit: &LitStr,
    body_str: &str,
    path_str: &str,
    has_declarations: bool,
    dep_tracking: proc_macro2::TokenStream,
) -> syn::Result<DescriptionInfo> {
    let span = template_lit.span();

    let Some(context_fn) = attr.context_fn.as_ref() else {
        return Err(syn::Error::new(
            span,
            "internal error: resolve_context_description called without context_fn",
        ));
    };

    if !has_declarations {
        let message = format!(
            "template '{rel_path}' has no declared parameters, \
             so `context = ...` is unnecessary. Remove `context` \
             or add params to the template."
        );
        return Err(syn::Error::new(span, message));
    }

    let description_method = quote! {
        fn description(&self) -> ::std::borrow::Cow<'static, str> {
            static TEMPLATE: ::std::sync::LazyLock<::prompt_templates::Template> =
                ::std::sync::LazyLock::new(|| {
                    ::prompt_templates::Template::from_source(
                        include_str!(#path_str)
                    ).expect("Valid template (verified at compile time)")
                });
            let ctx = #context_fn(self);
            let rendered = TEMPLATE.render(&ctx)
                .expect("Failed to render tool description template");
            ::std::borrow::Cow::Owned(rendered)
        }
    };

    let static_description = body_str.to_owned();
    Ok(DescriptionInfo {
        static_description,
        helper_tokens: quote! {},
        description_method: Some(description_method),
        dep_tracking,
    })
}
624
/// Renders a template at expansion time using the compile-time `params(...)` values.
///
/// Every name declared in the frontmatter must be supplied, every supplied key must be
/// declared, and no key may be supplied twice. Names in diagnostics follow frontmatter order,
/// so messages are identical from build to build.
///
/// # Errors
/// A missing, undeclared or duplicated parameter, or a template that fails to parse or render.
#[cfg(feature = "prompt-templates")]
fn resolve_template_with_params(
    attr: &ToolAttr,
    fm: &prompt_templates::Frontmatter,
    source: &str,
    rel_path: &str,
    span: proc_macro2::Span,
    dep_tracking: proc_macro2::TokenStream,
) -> syn::Result<DescriptionInfo> {
    // Declared names in frontmatter order, de-duplicated. Iterating a `HashSet` for the
    // error messages (as before) made their order vary between builds.
    let mut declared_set: std::collections::HashSet<&str> = std::collections::HashSet::new();
    let declared: Vec<&str> = fm
        .declarations
        .iter()
        .map(|d| d.name.as_str())
        .filter(|name| declared_set.insert(*name))
        .collect();

    // A repeated key would otherwise silently overwrite the earlier value when rendering.
    let mut provided: std::collections::HashSet<String> = std::collections::HashSet::new();
    for (key, _) in &attr.inline_params {
        if !provided.insert(key.to_string()) {
            return Err(syn::Error::new(
                key.span(),
                format!("param `{key}` is provided more than once in `params(...)`"),
            ));
        }
    }

    let missing: Vec<&str> = declared
        .iter()
        .copied()
        .filter(|name| !provided.contains(*name))
        .collect();
    if !missing.is_empty() {
        return Err(syn::Error::new(
            span,
            format!(
                "template '{rel_path}' declares parameters not provided in `params(...)`: {}",
                missing.join(", ")
            ),
        ));
    }

    for (key, _) in &attr.inline_params {
        let key_str = key.to_string();
        if !declared_set.contains(key_str.as_str()) {
            return Err(syn::Error::new(
                key.span(),
                format!(
                    "param `{key_str}` is not declared in template '{rel_path}'. \
                     Declared params: {}",
                    declared.join(", ")
                ),
            ));
        }
    }

    let template = prompt_templates::Template::from_source(source)
        .map_err(|e| syn::Error::new(span, format!("template '{rel_path}' parse error: {e}")))?;

    let mut ctx = prompt_templates::Context::new();
    for (key, value) in &attr.inline_params {
        ctx.set(key.to_string(), value.value());
    }

    let rendered = template
        .render(&ctx)
        .map_err(|e| syn::Error::new(span, format!("template '{rel_path}' render error: {e}")))?;

    Ok(DescriptionInfo {
        static_description: rendered,
        helper_tokens: quote! {},
        description_method: None,
        dep_tracking,
    })
}
699
700fn build_param_types_and_borrows(
702 params: &[&ParamInfo],
703) -> (Vec<proc_macro2::TokenStream>, Vec<proc_macro2::TokenStream>) {
704 params
705 .iter()
706 .map(|p| {
707 if is_str_ref(&p.ty) {
708 let name = &p.name;
710 (quote! { String }, quote! { let #name: &str = &#name; })
711 } else {
712 let ty = &p.ty;
713 (quote! { #ty }, quote! {})
714 }
715 })
716 .unzip()
717}
718
719fn build_serde_defaults(params: &[&ParamInfo]) -> Vec<proc_macro2::TokenStream> {
721 params
722 .iter()
723 .map(|p| {
724 if is_option_type(&p.ty) {
725 quote! { #[serde(default)] }
726 } else {
727 quote! {}
728 }
729 })
730 .collect()
731}
732
733fn build_body_tokens(
740 func: &ItemFn,
741 return_info: &ReturnInfo,
742 crate_path: &proc_macro2::TokenStream,
743) -> proc_macro2::TokenStream {
744 let is_async = func.sig.asyncness.is_some();
745 let body_stmts = &func.block.stmts;
746
747 match return_info {
748 ReturnInfo::ResultType { ok_type, err_type } => {
749 let inner = if is_async {
750 quote! {
751 let __r: ::std::result::Result<#ok_type, #err_type> = async move {
752 #( #body_stmts )*
753 }.await;
754 }
755 } else {
756 quote! {
757 let __r: ::std::result::Result<#ok_type, #err_type> = (|| { #( #body_stmts )* })();
758 }
759 };
760 quote! {
761 #inner
762 match __r {
763 ::std::result::Result::Ok(__v) => #crate_path::__private::Wrap(__v).__convert(),
764 ::std::result::Result::Err(__e) => ::std::result::Result::Err(::std::convert::Into::into(__e)),
765 }
766 }
767 }
768 ReturnInfo::BareType => {
769 let inner = if is_async {
770 quote! {
771 let __v = async move { #( #body_stmts )* }.await;
772 }
773 } else {
774 quote! {
775 let __v = (|| { #( #body_stmts )* })();
776 }
777 };
778 quote! {
779 #inner
780 #crate_path::__private::Wrap(__v).__convert()
781 }
782 }
783 }
784}
785
786fn is_option_type(ty: &syn::Type) -> bool {
788 let Type::Path(type_path) = ty else {
789 return false;
790 };
791 let Some(last_seg) = type_path.path.segments.last() else {
792 return false;
793 };
794 if last_seg.ident != "Option" {
795 return false;
796 }
797 matches!(&last_seg.arguments, PathArguments::AngleBracketed(args)
798 if args.args.len() == 1
799 && matches!(args.args.first(), Some(GenericArgument::Type(_))))
800}
801
802fn is_tool_context_type(ty: &syn::Type) -> bool {
805 let inner = match ty {
806 Type::Reference(r) => r.elem.as_ref(),
807 other => other,
808 };
809 let Type::Path(type_path) = inner else {
810 return false;
811 };
812 type_path
813 .path
814 .segments
815 .last()
816 .is_some_and(|seg| seg.ident == "ToolContext")
817}
818
819fn is_str_ref(ty: &syn::Type) -> bool {
821 let Type::Reference(ref_type) = ty else {
822 return false;
823 };
824 if ref_type.mutability.is_some() {
825 return false;
826 }
827 let Type::Path(type_path) = ref_type.elem.as_ref() else {
828 return false;
829 };
830 type_path
831 .path
832 .segments
833 .last()
834 .is_some_and(|seg| seg.ident == "str" && seg.arguments.is_none())
835}
836
837fn is_explicit_context_attr(attr: &syn::Attribute) -> syn::Result<bool> {
838 if !attr.path().is_ident("llm_tool") {
839 return Ok(false);
840 }
841 let mut is_context = false;
842 attr.parse_nested_meta(|meta| {
843 if meta.path.is_ident("context") {
844 is_context = true;
845 Ok(())
846 } else {
847 Err(meta.error("unsupported llm_tool attribute"))
848 }
849 })?;
850 Ok(is_context)
851}
852
853fn extract_params(func: &ItemFn) -> syn::Result<Vec<ParamInfo>> {
854 let mut params = Vec::new();
855 for arg in &func.sig.inputs {
856 match arg {
857 FnArg::Receiver(r) => {
858 return Err(syn::Error::new_spanned(
859 r,
860 "#[llm_tool] functions must be free functions (no `self`)",
861 ));
862 }
863 FnArg::Typed(PatType { pat, ty, attrs, .. }) => {
864 let name = match pat.as_ref() {
865 Pat::Ident(ident) => ident.ident.clone(),
866 other => {
867 return Err(syn::Error::new_spanned(
868 other,
869 "#[llm_tool] parameters must be simple identifiers",
870 ));
871 }
872 };
873
874 let mut has_context_attr = false;
875 for a in attrs {
876 has_context_attr |= is_explicit_context_attr(a)?;
877 }
878 let is_tool_context = is_tool_context_type(ty);
879 let is_context = has_context_attr || is_tool_context;
880
881 if is_tool_context && !matches!(ty.as_ref(), syn::Type::Reference(_)) {
882 return Err(syn::Error::new_spanned(
883 ty,
884 "ToolContext parameter must be a reference type (e.g., `&ToolContext` or `&'a ToolContext`)",
885 ));
886 }
887
888 let doc_attrs: Vec<syn::Attribute> = attrs
889 .iter()
890 .filter(|a| a.path().is_ident("doc"))
891 .cloned()
892 .collect();
893 params.push(ParamInfo {
894 name,
895 ty: ty.clone(),
896 doc_attrs,
897 is_context,
898 });
899 }
900 }
901 }
902 Ok(params)
903}
904
905fn extract_doc_string(attrs: &[syn::Attribute]) -> String {
906 let lines: Vec<String> = attrs
907 .iter()
908 .filter_map(|attr| {
909 if !attr.path().is_ident("doc") {
910 return None;
911 }
912 if let syn::Meta::NameValue(nv) = &attr.meta
913 && let syn::Expr::Lit(lit) = &nv.value
914 && let syn::Lit::Str(s) = &lit.lit
915 {
916 return Some(s.value());
917 }
918 None
919 })
920 .collect();
921 lines
922 .iter()
923 .map(|l| l.trim())
924 .collect::<Vec<_>>()
925 .join("\n")
926 .trim()
927 .to_string()
928}
929
930fn parse_return_type(func: &ItemFn) -> syn::Result<ReturnInfo> {
932 let syn::ReturnType::Type(_, ty) = &func.sig.output else {
933 return Err(syn::Error::new_spanned(
934 &func.sig,
935 "#[llm_tool] functions must have an explicit return type",
936 ));
937 };
938
939 if let Some(result_types) = try_extract_result_types(ty) {
941 return Ok(result_types);
942 }
943
944 Ok(ReturnInfo::BareType)
946}
947
948fn try_extract_result_types(ty: &syn::Type) -> Option<ReturnInfo> {
951 let Type::Path(type_path) = ty else {
952 return None;
953 };
954
955 let last_seg = type_path.path.segments.last()?;
956
957 if last_seg.ident != "Result" {
958 return None;
959 }
960
961 let PathArguments::AngleBracketed(args) = &last_seg.arguments else {
962 return None;
963 };
964
965 if args.args.len() != 2 {
966 return None;
967 }
968
969 let GenericArgument::Type(ok_type) = &args.args[0] else {
970 return None;
971 };
972
973 let GenericArgument::Type(err_type) = &args.args[1] else {
974 return None;
975 };
976
977 Some(ReturnInfo::ResultType {
978 ok_type: Box::new(ok_type.clone()),
979 err_type: Box::new(err_type.clone()),
980 })
981}