Skip to main content

praisonai_derive/
lib.rs

//! Procedural macros for PraisonAI.
//!
//! This crate provides the `#[tool]` attribute macro for defining tools.
//!
//! # Example
//!
//! ```rust,ignore
//! use praisonai::tool;
//!
//! #[tool(description = "Search the web for information")]
//! async fn search_web(query: String) -> String {
//!     format!("Results for: {}", query)
//! }
//! ```
15
16use proc_macro::TokenStream;
17use quote::{format_ident, quote};
18use syn::{parse_macro_input, Expr, ExprLit, ItemFn, Lit, Meta};
19
20/// The `#[tool]` attribute macro for defining tools.
21///
22/// This macro transforms a function into a tool that can be used by agents.
23///
24/// # Attributes
25///
26/// - `description`: A description of what the tool does (required for LLM understanding)
27/// - `name`: Override the tool name (defaults to function name)
28///
29/// # Example
30///
31/// ```rust,ignore
32/// use praisonai::tool;
33///
34/// #[tool(description = "Search the web")]
35/// async fn search(query: String) -> String {
36///     format!("Results for: {}", query)
37/// }
38///
39/// // With custom name
40/// #[tool(name = "web_search", description = "Search the internet")]
41/// async fn my_search_fn(query: String, max_results: u32) -> Vec<String> {
42///     vec![format!("Result for: {}", query)]
43/// }
44/// ```
45#[proc_macro_attribute]
46pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
47    let input_fn = parse_macro_input!(item as ItemFn);
48
49    // Extract attributes from the attr token stream
50    let mut description = String::new();
51    let mut custom_name: Option<String> = None;
52
53    // Parse attributes manually for syn 2.x
54    let attr_parser = syn::meta::parser(|meta| {
55        if meta.path.is_ident("description") {
56            let value: syn::LitStr = meta.value()?.parse()?;
57            description = value.value();
58            Ok(())
59        } else if meta.path.is_ident("name") {
60            let value: syn::LitStr = meta.value()?.parse()?;
61            custom_name = Some(value.value());
62            Ok(())
63        } else {
64            Err(meta.error("unsupported attribute"))
65        }
66    });
67
68    parse_macro_input!(attr with attr_parser);
69
70    // Get function details
71    let fn_name = &input_fn.sig.ident;
72    let fn_vis = &input_fn.vis;
73    let fn_block = &input_fn.block;
74    let fn_inputs = &input_fn.sig.inputs;
75    let fn_output = &input_fn.sig.output;
76    let fn_asyncness = &input_fn.sig.asyncness;
77
78    // Tool name (custom or function name)
79    let tool_name = custom_name.unwrap_or_else(|| fn_name.to_string());
80
81    // Use docstring as description if not provided
82    let description = if description.is_empty() {
83        // Try to extract from doc comments
84        let mut doc = String::new();
85        for attr in &input_fn.attrs {
86            if attr.path().is_ident("doc") {
87                if let Meta::NameValue(nv) = &attr.meta {
88                    if let Expr::Lit(ExprLit {
89                        lit: Lit::Str(lit), ..
90                    }) = &nv.value
91                    {
92                        if !doc.is_empty() {
93                            doc.push(' ');
94                        }
95                        doc.push_str(lit.value().trim());
96                    }
97                }
98            }
99        }
100        if doc.is_empty() {
101            format!("Tool: {}", tool_name)
102        } else {
103            doc
104        }
105    } else {
106        description
107    };
108
109    // Generate parameter schema
110    let mut param_names: Vec<syn::Ident> = Vec::new();
111    let mut param_types: Vec<syn::Type> = Vec::new();
112    let mut param_name_strs: Vec<String> = Vec::new();
113    let mut param_json_types: Vec<String> = Vec::new();
114
115    for input in fn_inputs.iter() {
116        if let syn::FnArg::Typed(pat_type) = input {
117            if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
118                let name = pat_ident.ident.clone();
119                let name_str = name.to_string();
120                let ty = (*pat_type.ty).clone();
121
122                // Map Rust types to JSON Schema types
123                let json_type = rust_type_to_json_schema(&pat_type.ty);
124
125                param_names.push(name);
126                param_name_strs.push(name_str);
127                param_types.push(ty);
128                param_json_types.push(json_type);
129            }
130        }
131    }
132
133    // Generate the struct name for the tool
134    let struct_name = format_ident!("{}Tool", to_pascal_case(&tool_name));
135
136    // Generate the implementation
137    let expanded = quote! {
138        // Keep the original function
139        #fn_vis #fn_asyncness fn #fn_name(#fn_inputs) #fn_output #fn_block
140
141        /// Auto-generated tool struct for #fn_name
142        #[derive(Debug, Clone)]
143        #fn_vis struct #struct_name;
144
145        impl #struct_name {
146            /// Create a new instance of this tool
147            pub fn new() -> Self {
148                Self
149            }
150        }
151
152        impl Default for #struct_name {
153            fn default() -> Self {
154                Self::new()
155            }
156        }
157
158        #[async_trait::async_trait]
159        impl praisonai::Tool for #struct_name {
160            fn name(&self) -> &str {
161                #tool_name
162            }
163
164            fn description(&self) -> &str {
165                #description
166            }
167
168            fn parameters_schema(&self) -> serde_json::Value {
169                let mut properties = serde_json::Map::new();
170                let mut required = Vec::new();
171
172                #(
173                    properties.insert(
174                        #param_name_strs.to_string(),
175                        serde_json::json!({ "type": #param_json_types })
176                    );
177                    required.push(serde_json::Value::String(#param_name_strs.to_string()));
178                )*
179
180                serde_json::json!({
181                    "type": "object",
182                    "properties": properties,
183                    "required": required
184                })
185            }
186
187            async fn execute(&self, args: serde_json::Value) -> praisonai::Result<serde_json::Value> {
188                #(
189                    let #param_names: #param_types = serde_json::from_value(
190                        args.get(#param_name_strs)
191                            .cloned()
192                            .unwrap_or(serde_json::Value::Null)
193                    ).map_err(|e| praisonai::Error::tool(format!("Failed to parse {}: {}", #param_name_strs, e)))?;
194                )*
195
196                let result = #fn_name(#(#param_names),*).await;
197                serde_json::to_value(result)
198                    .map_err(|e| praisonai::Error::tool(format!("Failed to serialize result: {}", e)))
199            }
200        }
201    };
202
203    TokenStream::from(expanded)
204}
205
206/// Convert a Rust type to JSON Schema type string
207fn rust_type_to_json_schema(ty: &syn::Type) -> String {
208    let type_str = quote!(#ty).to_string().replace(" ", "");
209
210    match type_str.as_str() {
211        "String" | "&str" | "str" => "string".to_string(),
212        "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64" | "u128"
213        | "usize" => "integer".to_string(),
214        "f32" | "f64" => "number".to_string(),
215        "bool" => "boolean".to_string(),
216        _ if type_str.starts_with("Vec<") => "array".to_string(),
217        _ if type_str.starts_with("Option<") => {
218            // Extract inner type
219            let inner = &type_str[7..type_str.len() - 1];
220            rust_type_str_to_json_schema(inner)
221        }
222        _ => "object".to_string(),
223    }
224}
225
/// Maps a whitespace-free Rust type string (as produced by `quote!`) to a JSON Schema `type`.
///
/// The mapping is:
///
/// - `Option<T>` resolves to the schema type of `T`, including nested options.
/// - Sequence-like types (`Vec`, `VecDeque`, `HashSet`, `BTreeSet`, slices and arrays) map to
///   `"array"`, because serde represents them as JSON arrays.
/// - `char` maps to `"string"`, because serde represents it as a one-character string.
/// - Anything unrecognized is `"object"`.
fn rust_type_str_to_json_schema(type_str: &str) -> String {
    if let Some(inner) = type_str
        .strip_prefix("Option<")
        .and_then(|rest| rest.strip_suffix('>'))
    {
        return rust_type_str_to_json_schema(inner);
    }

    const SEQUENCE_PREFIXES: [&str; 6] = ["Vec<", "VecDeque<", "HashSet<", "BTreeSet<", "[", "&["];

    let json_type = match type_str {
        "String" | "&str" | "str" | "char" => "string",
        "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64" | "u128"
        | "usize" => "integer",
        "f32" | "f64" => "number",
        "bool" => "boolean",
        _ if SEQUENCE_PREFIXES
            .iter()
            .any(|prefix| type_str.starts_with(*prefix)) =>
        {
            "array"
        }
        _ => "object",
    };
    json_type.to_string()
}
237
/// Converts a tool name to PascalCase for use in the generated struct name.
///
/// Words are split on every non-alphanumeric character (`_`, `-`, `.`, spaces, ...), so both
/// `search_web` and `search-web` become `SearchWeb`. Tool names with hyphens previously
/// produced `Search-webTool`, which is not a Rust identifier. The first character of each word
/// is uppercased and the rest is kept unchanged. Empty words from leading or repeated
/// separators contribute nothing.
fn to_pascal_case(s: &str) -> String {
    s.split(|c: char| !c.is_alphanumeric())
        .map(|word| {
            let mut chars = word.chars();
            match chars.next() {
                None => String::new(),
                Some(first) => first.to_uppercase().chain(chars).collect(),
            }
        })
        .collect()
}