neuron_tool_macros/
lib.rs1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{FnArg, ItemFn, Pat, ReturnType, Type};
6
7#[proc_macro_attribute]
23pub fn neuron_tool(attr: TokenStream, item: TokenStream) -> TokenStream {
24 let args = syn::parse_macro_input!(attr as AgentToolArgs);
25 let func = syn::parse_macro_input!(item as ItemFn);
26
27 match expand_neuron_tool(args, func) {
28 Ok(tokens) => tokens.into(),
29 Err(err) => err.to_compile_error().into(),
30 }
31}
32
/// Parsed arguments of `#[neuron_tool(name = "...", description = "...")]`.
struct AgentToolArgs {
    /// Value of `name = "..."`; emitted as the tool's `NAME` constant.
    name: String,
    /// Value of `description = "..."`; emitted as the tool definition's description.
    description: String,
}
38
39impl syn::parse::Parse for AgentToolArgs {
40 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
41 let mut name = None;
42 let mut description = None;
43
44 while !input.is_empty() {
45 let ident: syn::Ident = input.parse()?;
46 let _: syn::Token![=] = input.parse()?;
47 let value: syn::LitStr = input.parse()?;
48
49 match ident.to_string().as_str() {
50 "name" => name = Some(value.value()),
51 "description" => description = Some(value.value()),
52 other => {
53 return Err(syn::Error::new(
54 ident.span(),
55 format!("unknown attribute: {other}"),
56 ));
57 }
58 }
59
60 if !input.is_empty() {
61 let _: syn::Token![,] = input.parse()?;
62 }
63 }
64
65 Ok(AgentToolArgs {
66 name: name.ok_or_else(|| input.error("missing `name` attribute"))?,
67 description: description.ok_or_else(|| input.error("missing `description` attribute"))?,
68 })
69 }
70}
71
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// The input is split on `_`. Empty segments (from leading, trailing or doubled
/// underscores) are dropped. The first character of every other segment is
/// upper-cased, and the rest of the segment is kept verbatim, so `keep_camelCase`
/// becomes `KeepCamelCase`.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            // `to_uppercase` may yield several chars (e.g. 'ß' -> "SS").
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
83
84fn expand_neuron_tool(
85 args: AgentToolArgs,
86 func: ItemFn,
87) -> syn::Result<proc_macro2::TokenStream> {
88 let func_name = &func.sig.ident;
89 let vis = &func.vis;
90 let pascal = to_pascal_case(&func_name.to_string());
91 let tool_struct = format_ident!("{}Tool", pascal);
92 let args_struct = format_ident!("{}Args", pascal);
93
94 let tool_name = &args.name;
95 let tool_description = &args.description;
96
97 let params: Vec<_> = func.sig.inputs.iter().collect();
99 if params.is_empty() {
100 return Err(syn::Error::new_spanned(
101 &func.sig,
102 "function must have at least a ctx parameter",
103 ));
104 }
105
106 let tool_params = ¶ms[..params.len() - 1]; let mut field_names = Vec::new();
110 let mut field_types = Vec::new();
111 let mut field_docs = Vec::new();
112
113 for param in tool_params {
114 match param {
115 FnArg::Typed(pat_type) => {
116 let name = match pat_type.pat.as_ref() {
117 Pat::Ident(ident) => &ident.ident,
118 _ => {
119 return Err(syn::Error::new_spanned(
120 pat_type,
121 "expected identifier pattern",
122 ));
123 }
124 };
125 let ty = &pat_type.ty;
126
127 let docs: Vec<_> = pat_type
129 .attrs
130 .iter()
131 .filter(|a| a.path().is_ident("doc"))
132 .cloned()
133 .collect();
134
135 field_names.push(name.clone());
136 field_types.push(ty.clone());
137 field_docs.push(docs);
138 }
139 FnArg::Receiver(_) => {
140 return Err(syn::Error::new_spanned(
141 param,
142 "self parameter not supported",
143 ));
144 }
145 }
146 }
147
148 let (output_type, error_type) = match &func.sig.output {
150 ReturnType::Type(_, ty) => extract_result_types(ty)?,
151 ReturnType::Default => {
152 return Err(syn::Error::new_spanned(
153 &func.sig,
154 "function must return Result<Output, Error>",
155 ));
156 }
157 };
158
159 let body = &func.block;
161
162 let field_defs: Vec<_> = field_names
164 .iter()
165 .zip(field_types.iter())
166 .zip(field_docs.iter())
167 .map(|((name, ty), docs)| {
168 quote! {
169 #(#docs)*
170 pub #name: #ty
171 }
172 })
173 .collect();
174
175 let destructure_fields: Vec<_> = field_names.iter().map(|name| quote! { #name }).collect();
177
178 Ok(quote! {
179 #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
181 #vis struct #args_struct {
182 #(#field_defs,)*
183 }
184
185 #vis struct #tool_struct;
187
188 impl neuron_types::Tool for #tool_struct {
189 const NAME: &'static str = #tool_name;
190 type Args = #args_struct;
191 type Output = #output_type;
192 type Error = #error_type;
193
194 fn definition(&self) -> neuron_types::ToolDefinition {
195 neuron_types::ToolDefinition {
196 name: Self::NAME.into(),
197 title: None,
198 description: #tool_description.into(),
199 input_schema: serde_json::to_value(
200 schemars::schema_for!(#args_struct)
201 ).unwrap(),
202 output_schema: None,
203 annotations: None,
204 cache_control: None,
205 }
206 }
207
208 async fn call(
209 &self,
210 args: Self::Args,
211 ctx: &neuron_types::ToolContext,
212 ) -> Result<Self::Output, Self::Error> {
213 let #args_struct { #(#destructure_fields,)* } = args;
214 let _ = &ctx;
216 #body
217 }
218 }
219 })
220}
221
222fn extract_result_types(ty: &Type) -> syn::Result<(Box<Type>, Box<Type>)> {
223 if let Type::Path(type_path) = ty {
224 let last_segment = type_path
225 .path
226 .segments
227 .last()
228 .ok_or_else(|| syn::Error::new_spanned(ty, "expected Result type"))?;
229
230 if last_segment.ident != "Result" {
231 return Err(syn::Error::new_spanned(
232 ty,
233 "return type must be Result<Output, Error>",
234 ));
235 }
236
237 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
238 let mut types = args.args.iter().filter_map(|arg| {
239 if let syn::GenericArgument::Type(t) = arg {
240 Some(t.clone())
241 } else {
242 None
243 }
244 });
245
246 let output = types
247 .next()
248 .ok_or_else(|| syn::Error::new_spanned(ty, "Result must have Output type"))?;
249 let error = types
250 .next()
251 .ok_or_else(|| syn::Error::new_spanned(ty, "Result must have Error type"))?;
252
253 return Ok((Box::new(output), Box::new(error)));
254 }
255 }
256
257 Err(syn::Error::new_spanned(
258 ty,
259 "return type must be Result<Output, Error>",
260 ))
261}