aither_derive/lib.rs
1//! # aither-derive
2//!
3//! Procedural macros for converting Rust functions into AI tools that can be called by language models.
4//!
5//! This crate provides the `#[tool]` attribute macro that automatically generates the necessary
6//! boilerplate code to make your async functions callable by AI models through the `aither` framework.
7//!
8//! ## Quick Start
9//!
10//! Transform any async function into an AI tool by adding the `#[tool]` attribute:
11//!
12//! ```rust
13//! use aither::Result;
14//! use aither_derive::tool;
15//!
16//! #[tool(description = "Get the current UTC time")]
17//! pub async fn get_time() -> Result<&'static str> {
18//! Ok("2023-10-01T12:00:00Z")
19//! }
20//! ```
21//!
22//! ## Function Patterns
23//!
24//! ### No Parameters
25//!
//! ```rust
//! # use aither::Result;
//! # use aither_derive::tool;
27//! #[tool(description = "Check service health status")]
28//! pub async fn health_check() -> Result<String> {
29//! Ok("Service is healthy".to_string())
30//! }
31//! ```
32//!
33//! ### Simple Parameters
34//!
//! ```rust
//! # use aither::Result;
//! # use aither_derive::tool;
36//! use serde::Serialize;
37//!
38//! #[derive(Debug, Serialize)]
39//! pub struct SearchResult {
40//! title: String,
41//! url: String,
42//! }
43//!
44//! #[tool(description = "Search the web for content")]
45//! pub async fn search(keywords: Vec<String>, limit: u32) -> Result<Vec<SearchResult>> {
46//! // Your search implementation here
47//! Ok(vec![])
48//! }
49//! ```
50//!
51//! ### Complex Parameters with Documentation
52//!
//! ```rust
//! # use aither::Result;
//! # use aither_derive::tool;
54//! use schemars::JsonSchema;
55//! use serde::Deserialize;
56//!
57//! #[derive(Debug, JsonSchema, Deserialize)]
58//! pub struct ImageArgs {
59//! /// The text prompt for image generation
60//! pub prompt: String,
61//! /// Image width in pixels
62//! #[serde(default = "default_width")]
63//! pub width: u32,
64//! /// Image height in pixels
65//! #[serde(default = "default_height")]
66//! pub height: u32,
67//! }
68//!
69//! fn default_width() -> u32 { 512 }
70//! fn default_height() -> u32 { 512 }
71//!
72//! #[tool(description = "Generate an image from a text prompt")]
73//! pub async fn generate_image(args: ImageArgs) -> Result<String> {
74//! // Your image generation logic here
75//! Ok(format!("Generated image: {}", args.prompt))
76//! }
77//! ```
78//!
79//! ## Requirements
80//!
81//! - Functions must be `async`
82//! - Return type must be `Result<T>` where `T: serde::Serialize`
83//! - Parameters must implement `serde::Deserialize` and `schemars::JsonSchema`
84//! - No `self` parameters (static functions only)
85//! - No lifetime or generic parameters
86
87use convert_case::{Case, Casing};
88use proc_macro::TokenStream;
89use quote::{format_ident, quote};
90use syn::{
91 FnArg, Ident, ItemFn, LitStr, Token, Type, Visibility,
92 parse::{Parse, ParseStream},
93 parse_macro_input, parse_quote,
94};
95
/// Parsed contents of a `#[tool(...)]` attribute.
struct ToolArgs {
    /// Summary of the tool shown to the model, taken from `description = "..."`.
    description: String,
    /// Optional tool-name override from `rename = "..."`; `None` means the
    /// function's own name is used.
    rename: Option<String>,
}
100}
101
102impl Parse for ToolArgs {
103 /// Parse the arguments from the `#[tool(...)]` attribute.
104 ///
105 /// Supports:
106 /// - `description = "..."` (required): Tool description for the AI model
107 /// - `rename = "..."` (optional): Custom name for the tool (defaults to function name)
108 fn parse(input: ParseStream) -> syn::Result<Self> {
109 let mut description = None;
110 let mut rename = None;
111
112 while !input.is_empty() {
113 let ident: Ident = input.parse()?;
114 let _: Token![=] = input.parse()?;
115 let value: LitStr = input.parse()?;
116
117 match ident.to_string().as_str() {
118 "description" => description = Some(value.value()),
119 "rename" => rename = Some(value.value()),
120 _ => {
121 return Err(syn::Error::new_spanned(
122 ident,
123 "unknown attribute. Supported: description, rename",
124 ));
125 }
126 }
127
128 if input.peek(Token![,]) {
129 let _: Token![,] = input.parse()?;
130 }
131 }
132
133 let description = description
134 .ok_or_else(|| syn::Error::new(input.span(), "description attribute is required"))?;
135
136 Ok(Self {
137 description,
138 rename,
139 })
140 }
141}
142
143/// Converts an async function into an AI tool that can be called by language models.
144///
145/// This procedural macro generates the necessary boilerplate code to make your function
146/// callable through the `aither::llm::Tool` trait.
147///
148/// # Arguments
149///
150/// - `description` (required): A clear description of what the tool does. This helps the AI model
151/// decide when to use this tool.
152/// - `rename` (optional): A custom name for the tool. If not provided, uses the function name.
153///
154/// # Examples
155///
156/// ## Basic Usage (No Parameters)
157///
158/// ```rust
159/// use aither::Result;
160/// use aither_derive::tool;
161///
162/// #[tool(description = "Get the current system time")]
163/// pub async fn current_time() -> Result<String> {
164/// Ok(chrono::Utc::now().to_rfc3339())
165/// }
166/// ```
167///
168/// ## With Simple Parameters
169///
170/// ```rust
171/// #[tool(description = "Calculate the sum of two numbers")]
172/// pub async fn add(a: f64, b: f64) -> Result<f64> {
173/// Ok(a + b)
174/// }
175/// ```
176///
177/// ## With Complex Parameters
178///
179/// ```rust
180/// use schemars::JsonSchema;
181/// use serde::Deserialize;
182///
183/// #[derive(JsonSchema, Deserialize)]
184/// pub struct EmailRequest {
185/// /// Recipient email address
186/// pub to: String,
187/// /// Email subject line
188/// pub subject: String,
189/// /// Email body content
190/// pub body: String,
191/// }
192///
193/// #[tool(description = "Send an email to a recipient")]
194/// pub async fn send_email(request: EmailRequest) -> Result<String> {
195/// // Your email sending logic here
196/// Ok(format!("Email sent to {}", request.to))
197/// }
198/// ```
199///
200/// ## With Custom Name
201///
202/// ```rust
203/// #[tool(
204/// description = "Perform complex mathematical calculations",
205/// rename = "calculator"
206/// )]
207/// pub async fn complex_math_function(expression: String) -> Result<f64> {
208/// // Your calculation logic here
209/// Ok(42.0)
210/// }
211/// ```
212///
213/// # Generated Code
214///
215/// For a function named `search`, the macro generates:
216///
217/// 1. A `SearchArgs` struct (if the function has multiple parameters)
218/// 2. A `Search` struct that implements `aither::llm::Tool`
219/// 3. All necessary trait implementations for JSON schema generation and deserialization
220///
221/// # Requirements
222///
223/// - Function must be `async`
224/// - Return type must be `Result<T>` where `T` implements `serde::Serialize`
225/// - Parameters must implement `serde::Deserialize` and `schemars::JsonSchema`
226/// - No `self` parameters (only free functions are supported)
227/// - No lifetime parameters or generics
228///
229/// # Errors
230///
231/// The macro will produce compile-time errors if:
232/// - The function is not async
233/// - The function has `self` parameters
234/// - The function has more than the supported number of parameters
235/// - Required attributes are missing
236#[proc_macro_attribute]
237pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream {
238 let args = parse_macro_input!(args as ToolArgs);
239 let input_fn = parse_macro_input!(input as ItemFn);
240
241 match tool_impl(args, input_fn) {
242 Ok(tokens) => tokens.into(),
243 Err(err) => err.to_compile_error().into(),
244 }
245}
246
247/// Implementation details for the `#[tool]` macro.
248///
249/// This function performs the actual code generation, transforming the annotated async function
250/// into a struct that implements the `Tool` trait.
251fn tool_impl(args: ToolArgs, input_fn: ItemFn) -> syn::Result<proc_macro2::TokenStream> {
252 let fn_name = &input_fn.sig.ident;
253 let tool_name = args.rename.unwrap_or_else(|| fn_name.to_string());
254 let description = args.description;
255 let fn_vis = &input_fn.vis;
256
257 let tool_struct_name = format_ident!("{}", fn_name.to_string().to_case(Case::Pascal));
258
259 // Analyze function signature
260 let AnalyzedArgs {
261 args_type,
262 params,
263 stream,
264 } = analyze_function_args(fn_vis, &tool_struct_name, &input_fn.sig.inputs)?;
265
266 if input_fn.sig.asyncness.is_none() {
267 return Err(syn::Error::new_spanned(
268 input_fn.sig,
269 "Tool functions must be async",
270 ));
271 }
272
273 let call_expr = if params.is_empty() {
274 // No parameters, call the function directly
275 quote! { #fn_name().await }
276 } else {
277 // Call the function with extracted parameters
278 let args_tuple = quote! { #(#params),* };
279 quote! { #fn_name(#args_tuple).await }
280 };
281
282 let extractor = if params.len() <= 1 {
283 quote! {}
284 } else {
285 quote! { let Self::Arguments { #(#params),* } = args; }
286 };
287
288 let expanded = quote! {
289 #input_fn
290
291 #stream
292
293
294 #[derive(::core::default::Default,::core::fmt::Debug)]
295 #fn_vis struct #tool_struct_name;
296
297 impl ::aither::llm::Tool for #tool_struct_name {
298 const NAME: &'static str = #tool_name;
299 const DESCRIPTION: &'static str = #description;
300 type Arguments = #args_type;
301
302 async fn call(&mut self, args: Self::Arguments) -> ::aither::Result {
303 #extractor
304 let result: ::aither::Result<_> = #call_expr;
305 let result: ::aither::Result<String> = result.map(|value|{
306 // Convert the result to a JSON string
307 ::aither::llm::tool::json(&value)
308 });
309 result
310 }
311 }
312 };
313
314 Ok(expanded)
315}
316
317/// Container for analyzed function arguments and generated types.
318struct AnalyzedArgs {
319 /// The type used for the Tool's Arguments associated type
320 args_type: Type,
321 /// Parameter names extracted from the function signature
322 params: Vec<Ident>,
323 /// Generated argument struct definition (if needed)
324 stream: proc_macro2::TokenStream,
325}
326
327/// Analyzes function parameters and generates appropriate argument types.
328///
329/// This function handles three cases:
330/// - No parameters: Uses unit type `()`
331/// - Single parameter: Uses the parameter type directly
332/// - Multiple parameters: Generates a new struct with all parameters as fields
333fn analyze_function_args(
334 fn_vis: &Visibility,
335 struct_name: &Ident,
336 inputs: &syn::punctuated::Punctuated<FnArg, syn::Token![,]>,
337) -> syn::Result<AnalyzedArgs> {
338 match inputs.len() {
339 0 => {
340 // No arguments - use unit type
341
342 Ok(AnalyzedArgs {
343 args_type: parse_quote! { () },
344 params: vec![],
345 stream: quote! {},
346 })
347 }
348 1 => {
349 // Single argument
350 if let FnArg::Typed(pat_type) = &inputs[0] {
351 Ok(AnalyzedArgs {
352 args_type: (*pat_type.ty).clone(),
353 params: vec![format_ident!("args")],
354 stream: quote! {},
355 })
356 } else {
357 Err(syn::Error::new_spanned(
358 &inputs[0],
359 "self parameters are not supported in tool functions",
360 ))
361 }
362 }
363 _ => {
364 let mut attributes = Vec::new();
365
366 for arg in inputs {
367 if let FnArg::Typed(pat_type) = arg {
368 let pat = &pat_type.pat;
369 let ty = &pat_type.ty;
370 attributes.push(quote! {
371 #pat: #ty,
372 });
373 } else {
374 return Err(syn::Error::new_spanned(
375 arg,
376 "self parameters are not supported in tool functions",
377 ));
378 }
379 }
380
381 let arg_struct_name = format_ident!("{}Args", struct_name);
382
383 let new_type_gen = quote! {
384 #[derive(::schemars::JsonSchema, ::serde::Deserialize,::core::fmt::Debug)]
385 #fn_vis struct #arg_struct_name {
386 #(
387 #attributes
388 )*
389 }
390 };
391
392 Ok(AnalyzedArgs {
393 args_type: parse_quote! { #arg_struct_name },
394 params: inputs
395 .iter()
396 .map(|arg| {
397 if let FnArg::Typed(pat_type) = arg {
398 let pat = &pat_type.pat;
399 format_ident!("{}", quote! {#pat}.to_string())
400 } else {
401 panic!("Expected typed argument")
402 }
403 })
404 .collect(),
405 stream: new_type_gen,
406 })
407 }
408 }
409}