aither_derive/
lib.rs

1//! # aither-derive
2//!
3//! Procedural macros for converting Rust functions into AI tools that can be called by language models.
4//!
5//! This crate provides the `#[tool]` attribute macro that automatically generates the necessary
6//! boilerplate code to make your async functions callable by AI models through the `aither` framework.
7//!
8//! ## Quick Start
9//!
10//! Transform any async function into an AI tool by adding the `#[tool]` attribute:
11//!
12//! ```rust
13//! use aither::Result;
14//! use aither_derive::tool;
15//!
16//! #[tool(description = "Get the current UTC time")]
17//! pub async fn get_time() -> Result<&'static str> {
18//!     Ok("2023-10-01T12:00:00Z")
19//! }
20//! ```
21//!
22//! ## Function Patterns
23//!
24//! ### No Parameters
25//!
//! ```rust
//! # use aither::Result;
//! # use aither_derive::tool;
//! #[tool(description = "Check service health status")]
//! pub async fn health_check() -> Result<String> {
//!     Ok("Service is healthy".to_string())
//! }
//! ```
32//!
33//! ### Simple Parameters
34//!
//! ```rust
//! # use aither::Result;
//! # use aither_derive::tool;
//! use serde::Serialize;
//!
//! #[derive(Debug, Serialize)]
//! pub struct SearchResult {
//!     title: String,
//!     url: String,
//! }
//!
//! #[tool(description = "Search the web for content")]
//! pub async fn search(keywords: Vec<String>, limit: u32) -> Result<Vec<SearchResult>> {
//!     // Your search implementation here
//!     Ok(vec![])
//! }
//! ```
50//!
51//! ### Complex Parameters with Documentation
52//!
//! ```rust
//! # use aither::Result;
//! # use aither_derive::tool;
//! use schemars::JsonSchema;
//! use serde::Deserialize;
//!
//! #[derive(Debug, JsonSchema, Deserialize)]
//! pub struct ImageArgs {
//!     /// The text prompt for image generation
//!     pub prompt: String,
//!     /// Image width in pixels
//!     #[serde(default = "default_width")]
//!     pub width: u32,
//!     /// Image height in pixels
//!     #[serde(default = "default_height")]
//!     pub height: u32,
//! }
//!
//! fn default_width() -> u32 { 512 }
//! fn default_height() -> u32 { 512 }
//!
//! #[tool(description = "Generate an image from a text prompt")]
//! pub async fn generate_image(args: ImageArgs) -> Result<String> {
//!     // Your image generation logic here
//!     Ok(format!("Generated image: {}", args.prompt))
//! }
//! ```
78//!
//! ## Requirements
//!
//! - Functions must be `async`
//! - Return type must be `aither::Result<T>` where `T: serde::Serialize`
//! - Parameter types must implement `serde::Deserialize` and `schemars::JsonSchema`
//! - With two or more parameters, each parameter must be a plain identifier (e.g. `limit: u32`)
//! - No `self` parameters (free functions only)
//! - No lifetime or generic parameters
86
87use convert_case::{Case, Casing};
88use proc_macro::TokenStream;
89use quote::{format_ident, quote};
90use syn::{
91    FnArg, Ident, ItemFn, LitStr, Token, Type, Visibility,
92    parse::{Parse, ParseStream},
93    parse_macro_input, parse_quote,
94};
95
/// Parsed contents of a `#[tool(...)]` attribute.
struct ToolArgs {
    /// Value of the mandatory `description = "..."` entry.
    description: String,
    /// Value of the optional `rename = "..."` entry; `None` means the function name is used.
    rename: Option<String>,
}
101
102impl Parse for ToolArgs {
103    /// Parse the arguments from the `#[tool(...)]` attribute.
104    ///
105    /// Supports:
106    /// - `description = "..."` (required): Tool description for the AI model
107    /// - `rename = "..."` (optional): Custom name for the tool (defaults to function name)
108    fn parse(input: ParseStream) -> syn::Result<Self> {
109        let mut description = None;
110        let mut rename = None;
111
112        while !input.is_empty() {
113            let ident: Ident = input.parse()?;
114            let _: Token![=] = input.parse()?;
115            let value: LitStr = input.parse()?;
116
117            match ident.to_string().as_str() {
118                "description" => description = Some(value.value()),
119                "rename" => rename = Some(value.value()),
120                _ => {
121                    return Err(syn::Error::new_spanned(
122                        ident,
123                        "unknown attribute. Supported: description, rename",
124                    ));
125                }
126            }
127
128            if input.peek(Token![,]) {
129                let _: Token![,] = input.parse()?;
130            }
131        }
132
133        let description = description
134            .ok_or_else(|| syn::Error::new(input.span(), "description attribute is required"))?;
135
136        Ok(Self {
137            description,
138            rename,
139        })
140    }
141}
142
143/// Converts an async function into an AI tool that can be called by language models.
144///
145/// This procedural macro generates the necessary boilerplate code to make your function
146/// callable through the `aither::llm::Tool` trait.
147///
148/// # Arguments
149///
150/// - `description` (required): A clear description of what the tool does. This helps the AI model
151///   decide when to use this tool.
152/// - `rename` (optional): A custom name for the tool. If not provided, uses the function name.
153///
154/// # Examples
155///
156/// ## Basic Usage (No Parameters)
157///
158/// ```rust
159/// use aither::Result;
160/// use aither_derive::tool;
161///
162/// #[tool(description = "Get the current system time")]
163/// pub async fn current_time() -> Result<String> {
164///     Ok(chrono::Utc::now().to_rfc3339())
165/// }
166/// ```
167///
168/// ## With Simple Parameters
169///
170/// ```rust
171/// #[tool(description = "Calculate the sum of two numbers")]
172/// pub async fn add(a: f64, b: f64) -> Result<f64> {
173///     Ok(a + b)
174/// }
175/// ```
176///
177/// ## With Complex Parameters
178///
179/// ```rust
180/// use schemars::JsonSchema;
181/// use serde::Deserialize;
182///
183/// #[derive(JsonSchema, Deserialize)]
184/// pub struct EmailRequest {
185///     /// Recipient email address
186///     pub to: String,
187///     /// Email subject line
188///     pub subject: String,
189///     /// Email body content
190///     pub body: String,
191/// }
192///
193/// #[tool(description = "Send an email to a recipient")]
194/// pub async fn send_email(request: EmailRequest) -> Result<String> {
195///     // Your email sending logic here
196///     Ok(format!("Email sent to {}", request.to))
197/// }
198/// ```
199///
200/// ## With Custom Name
201///
202/// ```rust
203/// #[tool(
204///     description = "Perform complex mathematical calculations",
205///     rename = "calculator"
206/// )]
207/// pub async fn complex_math_function(expression: String) -> Result<f64> {
208///     // Your calculation logic here
209///     Ok(42.0)
210/// }
211/// ```
212///
213/// # Generated Code
214///
215/// For a function named `search`, the macro generates:
216///
217/// 1. A `SearchArgs` struct (if the function has multiple parameters)
218/// 2. A `Search` struct that implements `aither::llm::Tool`
219/// 3. All necessary trait implementations for JSON schema generation and deserialization
220///
221/// # Requirements
222///
223/// - Function must be `async`
224/// - Return type must be `Result<T>` where `T` implements `serde::Serialize`
225/// - Parameters must implement `serde::Deserialize` and `schemars::JsonSchema`
226/// - No `self` parameters (only free functions are supported)
227/// - No lifetime parameters or generics
228///
229/// # Errors
230///
231/// The macro will produce compile-time errors if:
232/// - The function is not async
233/// - The function has `self` parameters
234/// - The function has more than the supported number of parameters
235/// - Required attributes are missing
236#[proc_macro_attribute]
237pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream {
238    let args = parse_macro_input!(args as ToolArgs);
239    let input_fn = parse_macro_input!(input as ItemFn);
240
241    match tool_impl(args, input_fn) {
242        Ok(tokens) => tokens.into(),
243        Err(err) => err.to_compile_error().into(),
244    }
245}
246
247/// Implementation details for the `#[tool]` macro.
248///
249/// This function performs the actual code generation, transforming the annotated async function
250/// into a struct that implements the `Tool` trait.
251fn tool_impl(args: ToolArgs, input_fn: ItemFn) -> syn::Result<proc_macro2::TokenStream> {
252    let fn_name = &input_fn.sig.ident;
253    let tool_name = args.rename.unwrap_or_else(|| fn_name.to_string());
254    let description = args.description;
255    let fn_vis = &input_fn.vis;
256
257    let tool_struct_name = format_ident!("{}", fn_name.to_string().to_case(Case::Pascal));
258
259    // Analyze function signature
260    let AnalyzedArgs {
261        args_type,
262        params,
263        stream,
264    } = analyze_function_args(fn_vis, &tool_struct_name, &input_fn.sig.inputs)?;
265
266    if input_fn.sig.asyncness.is_none() {
267        return Err(syn::Error::new_spanned(
268            input_fn.sig,
269            "Tool functions must be async",
270        ));
271    }
272
273    let call_expr = if params.is_empty() {
274        // No parameters, call the function directly
275        quote! { #fn_name().await }
276    } else {
277        // Call the function with extracted parameters
278        let args_tuple = quote! { #(#params),* };
279        quote! { #fn_name(#args_tuple).await }
280    };
281
282    let extractor = if params.len() <= 1 {
283        quote! {}
284    } else {
285        quote! { let Self::Arguments { #(#params),* } = args; }
286    };
287
288    let expanded = quote! {
289        #input_fn
290
291        #stream
292
293
294        #[derive(::core::default::Default,::core::fmt::Debug)]
295        #fn_vis struct #tool_struct_name;
296
297        impl ::aither::llm::Tool for #tool_struct_name {
298            const NAME: &'static str = #tool_name;
299            const DESCRIPTION: &'static str = #description;
300            type Arguments = #args_type;
301
302            async fn call(&mut self, args: Self::Arguments) -> ::aither::Result {
303                #extractor
304                let result: ::aither::Result<_> = #call_expr;
305                let result: ::aither::Result<String> = result.map(|value|{
306                    // Convert the result to a JSON string
307                    ::aither::llm::tool::json(&value)
308                });
309                result
310            }
311        }
312    };
313
314    Ok(expanded)
315}
316
317/// Container for analyzed function arguments and generated types.
318struct AnalyzedArgs {
319    /// The type used for the Tool's Arguments associated type
320    args_type: Type,
321    /// Parameter names extracted from the function signature  
322    params: Vec<Ident>,
323    /// Generated argument struct definition (if needed)
324    stream: proc_macro2::TokenStream,
325}
326
327/// Analyzes function parameters and generates appropriate argument types.
328///
329/// This function handles three cases:
330/// - No parameters: Uses unit type `()`
331/// - Single parameter: Uses the parameter type directly
332/// - Multiple parameters: Generates a new struct with all parameters as fields
333fn analyze_function_args(
334    fn_vis: &Visibility,
335    struct_name: &Ident,
336    inputs: &syn::punctuated::Punctuated<FnArg, syn::Token![,]>,
337) -> syn::Result<AnalyzedArgs> {
338    match inputs.len() {
339        0 => {
340            // No arguments - use unit type
341
342            Ok(AnalyzedArgs {
343                args_type: parse_quote! { () },
344                params: vec![],
345                stream: quote! {},
346            })
347        }
348        1 => {
349            // Single argument
350            if let FnArg::Typed(pat_type) = &inputs[0] {
351                Ok(AnalyzedArgs {
352                    args_type: (*pat_type.ty).clone(),
353                    params: vec![format_ident!("args")],
354                    stream: quote! {},
355                })
356            } else {
357                Err(syn::Error::new_spanned(
358                    &inputs[0],
359                    "self parameters are not supported in tool functions",
360                ))
361            }
362        }
363        _ => {
364            let mut attributes = Vec::new();
365
366            for arg in inputs {
367                if let FnArg::Typed(pat_type) = arg {
368                    let pat = &pat_type.pat;
369                    let ty = &pat_type.ty;
370                    attributes.push(quote! {
371                        #pat: #ty,
372                    });
373                } else {
374                    return Err(syn::Error::new_spanned(
375                        arg,
376                        "self parameters are not supported in tool functions",
377                    ));
378                }
379            }
380
381            let arg_struct_name = format_ident!("{}Args", struct_name);
382
383            let new_type_gen = quote! {
384                #[derive(::schemars::JsonSchema, ::serde::Deserialize,::core::fmt::Debug)]
385                #fn_vis struct #arg_struct_name {
386                    #(
387                        #attributes
388                    )*
389                }
390            };
391
392            Ok(AnalyzedArgs {
393                args_type: parse_quote! { #arg_struct_name },
394                params: inputs
395                    .iter()
396                    .map(|arg| {
397                        if let FnArg::Typed(pat_type) = arg {
398                            let pat = &pat_type.pat;
399                            format_ident!("{}", quote! {#pat}.to_string())
400                        } else {
401                            panic!("Expected typed argument")
402                        }
403                    })
404                    .collect(),
405                stream: new_type_gen,
406            })
407        }
408    }
409}