1#![forbid(unsafe_code)]
2
3use heck::ToUpperCamelCase;
6use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::{
9 parse_macro_input, punctuated::Punctuated, Expr, FnArg, ItemFn, Lit, Meta, MetaNameValue, Token,
10};
11
12#[proc_macro_attribute]
14pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream {
15 let args = parse_macro_input!(args with Punctuated::<Meta, Token![,]>::parse_terminated);
16 let function = parse_macro_input!(input as ItemFn);
17
18 expand_tool(args.into_iter().collect(), function)
19 .unwrap_or_else(|error| error.to_compile_error())
20 .into()
21}
22
23fn expand_tool(args: Vec<Meta>, function: ItemFn) -> syn::Result<proc_macro2::TokenStream> {
24 let mut name_literal = None;
25 let mut description_literal = None;
26
27 for arg in args {
28 let Meta::NameValue(MetaNameValue { path, value, .. }) = arg else {
29 return Err(syn::Error::new_spanned(arg, "expected name = \"...\""));
30 };
31
32 let Expr::Lit(expr_lit) = value else {
33 return Err(syn::Error::new_spanned(value, "expected string literal"));
34 };
35 let Lit::Str(lit) = expr_lit.lit else {
36 return Err(syn::Error::new_spanned(expr_lit, "expected string literal"));
37 };
38
39 if path.is_ident("name") {
40 name_literal = Some(lit);
41 } else if path.is_ident("description") {
42 description_literal = Some(lit);
43 } else {
44 return Err(syn::Error::new_spanned(path, "unsupported tool attribute"));
45 }
46 }
47
48 let tool_name = name_literal
49 .ok_or_else(|| syn::Error::new_spanned(&function.sig.ident, "missing tool name"))?;
50 let tool_description = description_literal
51 .ok_or_else(|| syn::Error::new_spanned(&function.sig.ident, "missing tool description"))?;
52
53 if function.sig.asyncness.is_none() {
54 return Err(syn::Error::new_spanned(
55 &function.sig.fn_token,
56 "#[tool] requires an async function",
57 ));
58 }
59
60 let function_name = &function.sig.ident;
61 let visibility = &function.vis;
62 let tool_struct_name = format_ident!("{}Tool", function_name.to_string().to_upper_camel_case());
63
64 let inputs = function.sig.inputs.iter().collect::<Vec<_>>();
65 if inputs.is_empty() || inputs.len() > 2 {
66 return Err(syn::Error::new_spanned(
67 &function.sig.inputs,
68 "#[tool] expects one input argument and an optional context argument",
69 ));
70 }
71
72 let input_ty = match inputs[0] {
73 FnArg::Typed(argument) => &argument.ty,
74 FnArg::Receiver(receiver) => {
75 return Err(syn::Error::new_spanned(
76 receiver,
77 "tool functions cannot take self",
78 ))
79 }
80 };
81
82 let call_expr = if inputs.len() == 2 {
83 quote! {
84 let ctx = ::agentrs_tools::ToolContext::default();
85 #function_name(parsed_input, &ctx).await?
86 }
87 } else {
88 quote! {
89 #function_name(parsed_input).await?
90 }
91 };
92
93 Ok(quote! {
94 #function
95
96 #[derive(Debug, Clone, Default)]
97 #visibility struct #tool_struct_name;
98
99 impl #tool_struct_name {
100 #visibility fn new() -> Self {
102 Self
103 }
104 }
105
106 #[::agentrs_tools::__private::async_trait::async_trait]
107 impl ::agentrs_core::Tool for #tool_struct_name {
108 fn name(&self) -> &str {
109 #tool_name
110 }
111
112 fn description(&self) -> &str {
113 #tool_description
114 }
115
116 fn schema(&self) -> ::agentrs_tools::__private::serde_json::Value {
117 ::agentrs_tools::__private::serde_json::to_value(
118 &::agentrs_tools::__private::schemars::schema_for!(#input_ty)
119 )
120 .expect("tool schema should serialize")
121 }
122
123 async fn call(
124 &self,
125 input: ::agentrs_tools::__private::serde_json::Value,
126 ) -> ::agentrs_core::Result<::agentrs_core::ToolOutput> {
127 let parsed_input: #input_ty = ::agentrs_tools::__private::serde_json::from_value(input)?;
128 let output = #call_expr;
129 Ok(::agentrs_tools::IntoToolOutput::into_tool_output(output))
130 }
131 }
132 })
133}