neuron_tool_macros/
lib.rs1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{FnArg, ItemFn, Pat, ReturnType, Type};
6
7#[proc_macro_attribute]
23pub fn neuron_tool(attr: TokenStream, item: TokenStream) -> TokenStream {
24 let args = syn::parse_macro_input!(attr as AgentToolArgs);
25 let func = syn::parse_macro_input!(item as ItemFn);
26
27 match expand_neuron_tool(args, func) {
28 Ok(tokens) => tokens.into(),
29 Err(err) => err.to_compile_error().into(),
30 }
31}
32
/// Arguments accepted by `#[neuron_tool(name = "...", description = "...")]`.
struct AgentToolArgs {
    /// Tool name; emitted as the generated `Tool::NAME` constant.
    name: String,
    /// Free-form description copied into the generated `ToolDefinition`.
    description: String,
}
38
39impl syn::parse::Parse for AgentToolArgs {
40 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
41 let mut name = None;
42 let mut description = None;
43
44 while !input.is_empty() {
45 let ident: syn::Ident = input.parse()?;
46 let _: syn::Token![=] = input.parse()?;
47 let value: syn::LitStr = input.parse()?;
48
49 match ident.to_string().as_str() {
50 "name" => name = Some(value.value()),
51 "description" => description = Some(value.value()),
52 other => {
53 return Err(syn::Error::new(
54 ident.span(),
55 format!("unknown attribute: {other}"),
56 ));
57 }
58 }
59
60 if !input.is_empty() {
61 let _: syn::Token![,] = input.parse()?;
62 }
63 }
64
65 Ok(AgentToolArgs {
66 name: name.ok_or_else(|| input.error("missing `name` attribute"))?,
67 description: description
68 .ok_or_else(|| input.error("missing `description` attribute"))?,
69 })
70 }
71}
72
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// The input is split on `_`. Each word has its first character upper-cased and the rest of
/// the word copied unchanged. Upper-casing uses Unicode case mapping, so one character may
/// expand into several (`ß` -> `SS`). Empty words produced by leading, trailing or repeated
/// underscores contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(first) = rest.next() {
            pascal.extend(first.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
84
85fn expand_neuron_tool(args: AgentToolArgs, func: ItemFn) -> syn::Result<proc_macro2::TokenStream> {
86 let func_name = &func.sig.ident;
87 let vis = &func.vis;
88 let pascal = to_pascal_case(&func_name.to_string());
89 let tool_struct = format_ident!("{}Tool", pascal);
90 let args_struct = format_ident!("{}Args", pascal);
91
92 let tool_name = &args.name;
93 let tool_description = &args.description;
94
95 let params: Vec<_> = func.sig.inputs.iter().collect();
97 if params.is_empty() {
98 return Err(syn::Error::new_spanned(
99 &func.sig,
100 "function must have at least a ctx parameter",
101 ));
102 }
103
104 let tool_params = ¶ms[..params.len() - 1]; let mut field_names = Vec::new();
108 let mut field_types = Vec::new();
109 let mut field_docs = Vec::new();
110
111 for param in tool_params {
112 match param {
113 FnArg::Typed(pat_type) => {
114 let name = match pat_type.pat.as_ref() {
115 Pat::Ident(ident) => &ident.ident,
116 _ => {
117 return Err(syn::Error::new_spanned(
118 pat_type,
119 "expected identifier pattern",
120 ));
121 }
122 };
123 let ty = &pat_type.ty;
124
125 let docs: Vec<_> = pat_type
127 .attrs
128 .iter()
129 .filter(|a| a.path().is_ident("doc"))
130 .cloned()
131 .collect();
132
133 field_names.push(name.clone());
134 field_types.push(ty.clone());
135 field_docs.push(docs);
136 }
137 FnArg::Receiver(_) => {
138 return Err(syn::Error::new_spanned(
139 param,
140 "self parameter not supported",
141 ));
142 }
143 }
144 }
145
146 let (output_type, error_type) = match &func.sig.output {
148 ReturnType::Type(_, ty) => extract_result_types(ty)?,
149 ReturnType::Default => {
150 return Err(syn::Error::new_spanned(
151 &func.sig,
152 "function must return Result<Output, Error>",
153 ));
154 }
155 };
156
157 let body = &func.block;
159
160 let field_defs: Vec<_> = field_names
162 .iter()
163 .zip(field_types.iter())
164 .zip(field_docs.iter())
165 .map(|((name, ty), docs)| {
166 quote! {
167 #(#docs)*
168 pub #name: #ty
169 }
170 })
171 .collect();
172
173 let destructure_fields: Vec<_> = field_names.iter().map(|name| quote! { #name }).collect();
175
176 Ok(quote! {
177 #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
179 #vis struct #args_struct {
180 #(#field_defs,)*
181 }
182
183 #vis struct #tool_struct;
185
186 impl neuron_types::Tool for #tool_struct {
187 const NAME: &'static str = #tool_name;
188 type Args = #args_struct;
189 type Output = #output_type;
190 type Error = #error_type;
191
192 fn definition(&self) -> neuron_types::ToolDefinition {
193 neuron_types::ToolDefinition {
194 name: Self::NAME.into(),
195 title: None,
196 description: #tool_description.into(),
197 input_schema: serde_json::to_value(
198 schemars::schema_for!(#args_struct)
199 ).unwrap(),
200 output_schema: None,
201 annotations: None,
202 cache_control: None,
203 }
204 }
205
206 async fn call(
207 &self,
208 args: Self::Args,
209 ctx: &neuron_types::ToolContext,
210 ) -> Result<Self::Output, Self::Error> {
211 let #args_struct { #(#destructure_fields,)* } = args;
212 let _ = &ctx;
214 #body
215 }
216 }
217 })
218}
219
220fn extract_result_types(ty: &Type) -> syn::Result<(Box<Type>, Box<Type>)> {
221 if let Type::Path(type_path) = ty {
222 let last_segment = type_path
223 .path
224 .segments
225 .last()
226 .ok_or_else(|| syn::Error::new_spanned(ty, "expected Result type"))?;
227
228 if last_segment.ident != "Result" {
229 return Err(syn::Error::new_spanned(
230 ty,
231 "return type must be Result<Output, Error>",
232 ));
233 }
234
235 if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
236 let mut types = args.args.iter().filter_map(|arg| {
237 if let syn::GenericArgument::Type(t) = arg {
238 Some(t.clone())
239 } else {
240 None
241 }
242 });
243
244 let output = types
245 .next()
246 .ok_or_else(|| syn::Error::new_spanned(ty, "Result must have Output type"))?;
247 let error = types
248 .next()
249 .ok_or_else(|| syn::Error::new_spanned(ty, "Result must have Error type"))?;
250
251 return Ok((Box::new(output), Box::new(error)));
252 }
253 }
254
255 Err(syn::Error::new_spanned(
256 ty,
257 "return type must be Result<Output, Error>",
258 ))
259}