Skip to main content

agentrs_macros/
lib.rs

1#![forbid(unsafe_code)]
2
3//! Procedural macros for defining `agentrs` tools ergonomically.
4
5use heck::ToUpperCamelCase;
6use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::{
9    parse_macro_input, punctuated::Punctuated, Expr, FnArg, ItemFn, Lit, Meta, MetaNameValue, Token,
10};
11
12/// Attribute macro that turns an async function into a zero-cost tool wrapper.
13#[proc_macro_attribute]
14pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream {
15    let args = parse_macro_input!(args with Punctuated::<Meta, Token![,]>::parse_terminated);
16    let function = parse_macro_input!(input as ItemFn);
17
18    expand_tool(args.into_iter().collect(), function)
19        .unwrap_or_else(|error| error.to_compile_error())
20        .into()
21}
22
23fn expand_tool(args: Vec<Meta>, function: ItemFn) -> syn::Result<proc_macro2::TokenStream> {
24    let mut name_literal = None;
25    let mut description_literal = None;
26
27    for arg in args {
28        let Meta::NameValue(MetaNameValue { path, value, .. }) = arg else {
29            return Err(syn::Error::new_spanned(arg, "expected name = \"...\""));
30        };
31
32        let Expr::Lit(expr_lit) = value else {
33            return Err(syn::Error::new_spanned(value, "expected string literal"));
34        };
35        let Lit::Str(lit) = expr_lit.lit else {
36            return Err(syn::Error::new_spanned(expr_lit, "expected string literal"));
37        };
38
39        if path.is_ident("name") {
40            name_literal = Some(lit);
41        } else if path.is_ident("description") {
42            description_literal = Some(lit);
43        } else {
44            return Err(syn::Error::new_spanned(path, "unsupported tool attribute"));
45        }
46    }
47
48    let tool_name = name_literal
49        .ok_or_else(|| syn::Error::new_spanned(&function.sig.ident, "missing tool name"))?;
50    let tool_description = description_literal
51        .ok_or_else(|| syn::Error::new_spanned(&function.sig.ident, "missing tool description"))?;
52
53    if function.sig.asyncness.is_none() {
54        return Err(syn::Error::new_spanned(
55            &function.sig.fn_token,
56            "#[tool] requires an async function",
57        ));
58    }
59
60    let function_name = &function.sig.ident;
61    let visibility = &function.vis;
62    let tool_struct_name = format_ident!("{}Tool", function_name.to_string().to_upper_camel_case());
63
64    let inputs = function.sig.inputs.iter().collect::<Vec<_>>();
65    if inputs.is_empty() || inputs.len() > 2 {
66        return Err(syn::Error::new_spanned(
67            &function.sig.inputs,
68            "#[tool] expects one input argument and an optional context argument",
69        ));
70    }
71
72    let input_ty = match inputs[0] {
73        FnArg::Typed(argument) => &argument.ty,
74        FnArg::Receiver(receiver) => {
75            return Err(syn::Error::new_spanned(
76                receiver,
77                "tool functions cannot take self",
78            ))
79        }
80    };
81
82    let call_expr = if inputs.len() == 2 {
83        quote! {
84            let ctx = ::agentrs_tools::ToolContext::default();
85            #function_name(parsed_input, &ctx).await?
86        }
87    } else {
88        quote! {
89            #function_name(parsed_input).await?
90        }
91    };
92
93    Ok(quote! {
94        #function
95
96        #[derive(Debug, Clone, Default)]
97        #visibility struct #tool_struct_name;
98
99        impl #tool_struct_name {
100            /// Creates the generated tool wrapper.
101            #visibility fn new() -> Self {
102                Self
103            }
104        }
105
106        #[::agentrs_tools::__private::async_trait::async_trait]
107        impl ::agentrs_core::Tool for #tool_struct_name {
108            fn name(&self) -> &str {
109                #tool_name
110            }
111
112            fn description(&self) -> &str {
113                #tool_description
114            }
115
116            fn schema(&self) -> ::agentrs_tools::__private::serde_json::Value {
117                ::agentrs_tools::__private::serde_json::to_value(
118                    &::agentrs_tools::__private::schemars::schema_for!(#input_ty)
119                )
120                .expect("tool schema should serialize")
121            }
122
123            async fn call(
124                &self,
125                input: ::agentrs_tools::__private::serde_json::Value,
126            ) -> ::agentrs_core::Result<::agentrs_core::ToolOutput> {
127                let parsed_input: #input_ty = ::agentrs_tools::__private::serde_json::from_value(input)?;
128                let output = #call_expr;
129                Ok(::agentrs_tools::IntoToolOutput::into_tool_output(output))
130            }
131        }
132    })
133}