Skip to main content

neuron_tool_macros/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{FnArg, ItemFn, Pat, ReturnType, Type};
6
7/// Derive a Tool implementation from an async function.
8///
9/// # Example
10///
11/// ```ignore
12/// #[neuron_tool(name = "calculate", description = "Evaluate a math expression")]
13/// async fn calculate(
14///     /// A mathematical expression like "2 + 2"
15///     expression: String,
16///     _ctx: &ToolContext,
17/// ) -> Result<CalculateOutput, CalculateError> {
18///     let result = eval(&expression);
19///     Ok(CalculateOutput { result })
20/// }
21/// ```
22#[proc_macro_attribute]
23pub fn neuron_tool(attr: TokenStream, item: TokenStream) -> TokenStream {
24    let args = syn::parse_macro_input!(attr as AgentToolArgs);
25    let func = syn::parse_macro_input!(item as ItemFn);
26
27    match expand_neuron_tool(args, func) {
28        Ok(tokens) => tokens.into(),
29        Err(err) => err.to_compile_error().into(),
30    }
31}
32
/// Arguments accepted by the `#[neuron_tool(...)]` attribute, written as
/// `name = "...", description = "..."`.
struct AgentToolArgs {
    /// Value of the `name = "..."` argument; becomes the tool's `NAME` constant.
    name: String,
    /// Value of the `description = "..."` argument; used in the tool definition.
    description: String,
}
38
39impl syn::parse::Parse for AgentToolArgs {
40    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
41        let mut name = None;
42        let mut description = None;
43
44        while !input.is_empty() {
45            let ident: syn::Ident = input.parse()?;
46            let _: syn::Token![=] = input.parse()?;
47            let value: syn::LitStr = input.parse()?;
48
49            match ident.to_string().as_str() {
50                "name" => name = Some(value.value()),
51                "description" => description = Some(value.value()),
52                other => {
53                    return Err(syn::Error::new(
54                        ident.span(),
55                        format!("unknown attribute: {other}"),
56                    ));
57                }
58            }
59
60            if !input.is_empty() {
61                let _: syn::Token![,] = input.parse()?;
62            }
63        }
64
65        Ok(AgentToolArgs {
66            name: name.ok_or_else(|| input.error("missing `name` attribute"))?,
67            description: description.ok_or_else(|| input.error("missing `description` attribute"))?,
68        })
69    }
70}
71
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// The input is split on `_`; the first character of each non-empty piece is
/// upper-cased and the rest of the piece is copied unchanged. Empty pieces
/// (from leading, trailing, or doubled underscores) contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());

    for word in s.split('_') {
        let mut rest = word.chars();
        if let Some(first) = rest.next() {
            // Upper-casing one char may yield several (e.g. 'ß' -> "SS").
            pascal.extend(first.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }

    pascal
}
83
84fn expand_neuron_tool(
85    args: AgentToolArgs,
86    func: ItemFn,
87) -> syn::Result<proc_macro2::TokenStream> {
88    let func_name = &func.sig.ident;
89    let vis = &func.vis;
90    let pascal = to_pascal_case(&func_name.to_string());
91    let tool_struct = format_ident!("{}Tool", pascal);
92    let args_struct = format_ident!("{}Args", pascal);
93
94    let tool_name = &args.name;
95    let tool_description = &args.description;
96
97    // Extract parameters (skip last one which is &ToolContext)
98    let params: Vec<_> = func.sig.inputs.iter().collect();
99    if params.is_empty() {
100        return Err(syn::Error::new_spanned(
101            &func.sig,
102            "function must have at least a ctx parameter",
103        ));
104    }
105
106    let tool_params = &params[..params.len() - 1]; // All except last (ctx)
107
108    // Build Args struct fields
109    let mut field_names = Vec::new();
110    let mut field_types = Vec::new();
111    let mut field_docs = Vec::new();
112
113    for param in tool_params {
114        match param {
115            FnArg::Typed(pat_type) => {
116                let name = match pat_type.pat.as_ref() {
117                    Pat::Ident(ident) => &ident.ident,
118                    _ => {
119                        return Err(syn::Error::new_spanned(
120                            pat_type,
121                            "expected identifier pattern",
122                        ));
123                    }
124                };
125                let ty = &pat_type.ty;
126
127                // Extract doc comments from attributes
128                let docs: Vec<_> = pat_type
129                    .attrs
130                    .iter()
131                    .filter(|a| a.path().is_ident("doc"))
132                    .cloned()
133                    .collect();
134
135                field_names.push(name.clone());
136                field_types.push(ty.clone());
137                field_docs.push(docs);
138            }
139            FnArg::Receiver(_) => {
140                return Err(syn::Error::new_spanned(
141                    param,
142                    "self parameter not supported",
143                ));
144            }
145        }
146    }
147
148    // Extract return type: Result<Output, Error>
149    let (output_type, error_type) = match &func.sig.output {
150        ReturnType::Type(_, ty) => extract_result_types(ty)?,
151        ReturnType::Default => {
152            return Err(syn::Error::new_spanned(
153                &func.sig,
154                "function must return Result<Output, Error>",
155            ));
156        }
157    };
158
159    // Get the function body
160    let body = &func.block;
161
162    // Build the field definitions with doc comments
163    let field_defs: Vec<_> = field_names
164        .iter()
165        .zip(field_types.iter())
166        .zip(field_docs.iter())
167        .map(|((name, ty), docs)| {
168            quote! {
169                #(#docs)*
170                pub #name: #ty
171            }
172        })
173        .collect();
174
175    // Build the destructuring pattern
176    let destructure_fields: Vec<_> = field_names.iter().map(|name| quote! { #name }).collect();
177
178    Ok(quote! {
179        /// Auto-generated args struct for the tool.
180        #[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
181        #vis struct #args_struct {
182            #(#field_defs,)*
183        }
184
185        /// Auto-generated tool struct.
186        #vis struct #tool_struct;
187
188        impl neuron_types::Tool for #tool_struct {
189            const NAME: &'static str = #tool_name;
190            type Args = #args_struct;
191            type Output = #output_type;
192            type Error = #error_type;
193
194            fn definition(&self) -> neuron_types::ToolDefinition {
195                neuron_types::ToolDefinition {
196                    name: Self::NAME.into(),
197                    title: None,
198                    description: #tool_description.into(),
199                    input_schema: serde_json::to_value(
200                        schemars::schema_for!(#args_struct)
201                    ).unwrap(),
202                    output_schema: None,
203                    annotations: None,
204                    cache_control: None,
205                }
206            }
207
208            async fn call(
209                &self,
210                args: Self::Args,
211                ctx: &neuron_types::ToolContext,
212            ) -> Result<Self::Output, Self::Error> {
213                let #args_struct { #(#destructure_fields,)* } = args;
214                // Suppress unused variable warning for ctx when it's used as _ctx
215                let _ = &ctx;
216                #body
217            }
218        }
219    })
220}
221
222fn extract_result_types(ty: &Type) -> syn::Result<(Box<Type>, Box<Type>)> {
223    if let Type::Path(type_path) = ty {
224        let last_segment = type_path
225            .path
226            .segments
227            .last()
228            .ok_or_else(|| syn::Error::new_spanned(ty, "expected Result type"))?;
229
230        if last_segment.ident != "Result" {
231            return Err(syn::Error::new_spanned(
232                ty,
233                "return type must be Result<Output, Error>",
234            ));
235        }
236
237        if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
238            let mut types = args.args.iter().filter_map(|arg| {
239                if let syn::GenericArgument::Type(t) = arg {
240                    Some(t.clone())
241                } else {
242                    None
243                }
244            });
245
246            let output = types
247                .next()
248                .ok_or_else(|| syn::Error::new_spanned(ty, "Result must have Output type"))?;
249            let error = types
250                .next()
251                .ok_or_else(|| syn::Error::new_spanned(ty, "Result must have Error type"))?;
252
253            return Ok((Box::new(output), Box::new(error)));
254        }
255    }
256
257    Err(syn::Error::new_spanned(
258        ty,
259        "return type must be Result<Output, Error>",
260    ))
261}