Skip to main content

neuron_tool_macros/
lib.rs

#![doc = include_str!("../README.md")]

use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::ext::IdentExt;
use syn::{FnArg, ItemFn, Pat, ReturnType, Type};
6
7/// Derive a Tool implementation from an async function.
8///
9/// # Example
10///
11/// ```ignore
12/// #[neuron_tool(name = "calculate", description = "Evaluate a math expression")]
13/// async fn calculate(
14///     /// A mathematical expression like "2 + 2"
15///     expression: String,
16///     _ctx: &ToolContext,
17/// ) -> Result<CalculateOutput, CalculateError> {
18///     let result = eval(&expression);
19///     Ok(CalculateOutput { result })
20/// }
21/// ```
22#[proc_macro_attribute]
23pub fn neuron_tool(attr: TokenStream, item: TokenStream) -> TokenStream {
24    let args = syn::parse_macro_input!(attr as AgentToolArgs);
25    let func = syn::parse_macro_input!(item as ItemFn);
26
27    match expand_neuron_tool(args, func) {
28        Ok(tokens) => tokens.into(),
29        Err(err) => err.to_compile_error().into(),
30    }
31}
32
/// Arguments of `#[neuron_tool(name = "...", description = "...")]`.
struct AgentToolArgs {
    /// Value of `name = "..."`; becomes the generated tool's `NAME` constant.
    name: String,
    /// Value of `description = "..."`; copied into the generated `ToolDefinition`.
    description: String,
}
38
39impl syn::parse::Parse for AgentToolArgs {
40    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
41        let mut name = None;
42        let mut description = None;
43
44        while !input.is_empty() {
45            let ident: syn::Ident = input.parse()?;
46            let _: syn::Token![=] = input.parse()?;
47            let value: syn::LitStr = input.parse()?;
48
49            match ident.to_string().as_str() {
50                "name" => name = Some(value.value()),
51                "description" => description = Some(value.value()),
52                other => {
53                    return Err(syn::Error::new(
54                        ident.span(),
55                        format!("unknown attribute: {other}"),
56                    ));
57                }
58            }
59
60            if !input.is_empty() {
61                let _: syn::Token![,] = input.parse()?;
62            }
63        }
64
65        Ok(AgentToolArgs {
66            name: name.ok_or_else(|| input.error("missing `name` attribute"))?,
67            description: description
68                .ok_or_else(|| input.error("missing `description` attribute"))?,
69        })
70    }
71}
72
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// Each `_`-separated word has its first character uppercased and the rest kept
/// unchanged. Uppercasing may expand one character into several (e.g. `ß` becomes
/// `SS`). Empty words from leading, trailing or doubled underscores contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
84
85fn expand_neuron_tool(args: AgentToolArgs, func: ItemFn) -> syn::Result<proc_macro2::TokenStream> {
86    let func_name = &func.sig.ident;
87    let vis = &func.vis;
88    let pascal = to_pascal_case(&func_name.to_string());
89    let tool_struct = format_ident!("{}Tool", pascal);
90    let args_struct = format_ident!("{}Args", pascal);
91
92    let tool_name = &args.name;
93    let tool_description = &args.description;
94
95    // Extract parameters (skip last one which is &ToolContext)
96    let params: Vec<_> = func.sig.inputs.iter().collect();
97    if params.is_empty() {
98        return Err(syn::Error::new_spanned(
99            &func.sig,
100            "function must have at least a ctx parameter",
101        ));
102    }
103
104    let tool_params = &params[..params.len() - 1]; // All except last (ctx)
105
106    // Build Args struct fields
107    let mut field_names = Vec::new();
108    let mut field_types = Vec::new();
109    let mut field_docs = Vec::new();
110
111    for param in tool_params {
112        match param {
113            FnArg::Typed(pat_type) => {
114                let name = match pat_type.pat.as_ref() {
115                    Pat::Ident(ident) => &ident.ident,
116                    _ => {
117                        return Err(syn::Error::new_spanned(
118                            pat_type,
119                            "expected identifier pattern",
120                        ));
121                    }
122                };
123                let ty = &pat_type.ty;
124
125                // Extract doc comments from attributes
126                let docs: Vec<_> = pat_type
127                    .attrs
128                    .iter()
129                    .filter(|a| a.path().is_ident("doc"))
130                    .cloned()
131                    .collect();
132
133                field_names.push(name.clone());
134                field_types.push(ty.clone());
135                field_docs.push(docs);
136            }
137            FnArg::Receiver(_) => {
138                return Err(syn::Error::new_spanned(
139                    param,
140                    "self parameter not supported",
141                ));
142            }
143        }
144    }
145
146    // Extract return type: Result<Output, Error>
147    let (output_type, error_type) = match &func.sig.output {
148        ReturnType::Type(_, ty) => extract_result_types(ty)?,
149        ReturnType::Default => {
150            return Err(syn::Error::new_spanned(
151                &func.sig,
152                "function must return Result<Output, Error>",
153            ));
154        }
155    };
156
157    // Get the function body
158    let body = &func.block;
159
160    // Build the field definitions with doc comments
161    let field_defs: Vec<_> = field_names
162        .iter()
163        .zip(field_types.iter())
164        .zip(field_docs.iter())
165        .map(|((name, ty), docs)| {
166            quote! {
167                #(#docs)*
168                pub #name: #ty
169            }
170        })
171        .collect();
172
173    // Build the destructuring pattern
174    let destructure_fields: Vec<_> = field_names.iter().map(|name| quote! { #name }).collect();
175
176    Ok(quote! {
177        /// Auto-generated args struct for the tool.
178        #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
179        #vis struct #args_struct {
180            #(#field_defs,)*
181        }
182
183        /// Auto-generated tool struct.
184        #vis struct #tool_struct;
185
186        impl neuron_types::Tool for #tool_struct {
187            const NAME: &'static str = #tool_name;
188            type Args = #args_struct;
189            type Output = #output_type;
190            type Error = #error_type;
191
192            fn definition(&self) -> neuron_types::ToolDefinition {
193                neuron_types::ToolDefinition {
194                    name: Self::NAME.into(),
195                    title: None,
196                    description: #tool_description.into(),
197                    input_schema: serde_json::to_value(
198                        schemars::schema_for!(#args_struct)
199                    ).unwrap(),
200                    output_schema: None,
201                    annotations: None,
202                    cache_control: None,
203                }
204            }
205
206            async fn call(
207                &self,
208                args: Self::Args,
209                ctx: &neuron_types::ToolContext,
210            ) -> Result<Self::Output, Self::Error> {
211                let #args_struct { #(#destructure_fields,)* } = args;
212                // Suppress unused variable warning for ctx when it's used as _ctx
213                let _ = &ctx;
214                #body
215            }
216        }
217    })
218}
219
220fn extract_result_types(ty: &Type) -> syn::Result<(Box<Type>, Box<Type>)> {
221    if let Type::Path(type_path) = ty {
222        let last_segment = type_path
223            .path
224            .segments
225            .last()
226            .ok_or_else(|| syn::Error::new_spanned(ty, "expected Result type"))?;
227
228        if last_segment.ident != "Result" {
229            return Err(syn::Error::new_spanned(
230                ty,
231                "return type must be Result<Output, Error>",
232            ));
233        }
234
235        if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
236            let mut types = args.args.iter().filter_map(|arg| {
237                if let syn::GenericArgument::Type(t) = arg {
238                    Some(t.clone())
239                } else {
240                    None
241                }
242            });
243
244            let output = types
245                .next()
246                .ok_or_else(|| syn::Error::new_spanned(ty, "Result must have Output type"))?;
247            let error = types
248                .next()
249                .ok_or_else(|| syn::Error::new_spanned(ty, "Result must have Error type"))?;
250
251            return Ok((Box::new(output), Box::new(error)));
252        }
253    }
254
255    Err(syn::Error::new_spanned(
256        ty,
257        "return type must be Result<Output, Error>",
258    ))
259}