1use proc_macro::TokenStream;
17use quote::{format_ident, quote};
18use syn::{parse_macro_input, Expr, ExprLit, ItemFn, Lit, Meta};
19
20#[proc_macro_attribute]
46pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
47 let input_fn = parse_macro_input!(item as ItemFn);
48
49 let mut description = String::new();
51 let mut custom_name: Option<String> = None;
52
53 let attr_parser = syn::meta::parser(|meta| {
55 if meta.path.is_ident("description") {
56 let value: syn::LitStr = meta.value()?.parse()?;
57 description = value.value();
58 Ok(())
59 } else if meta.path.is_ident("name") {
60 let value: syn::LitStr = meta.value()?.parse()?;
61 custom_name = Some(value.value());
62 Ok(())
63 } else {
64 Err(meta.error("unsupported attribute"))
65 }
66 });
67
68 parse_macro_input!(attr with attr_parser);
69
70 let fn_name = &input_fn.sig.ident;
72 let fn_vis = &input_fn.vis;
73 let fn_block = &input_fn.block;
74 let fn_inputs = &input_fn.sig.inputs;
75 let fn_output = &input_fn.sig.output;
76 let fn_asyncness = &input_fn.sig.asyncness;
77
78 let tool_name = custom_name.unwrap_or_else(|| fn_name.to_string());
80
81 let description = if description.is_empty() {
83 let mut doc = String::new();
85 for attr in &input_fn.attrs {
86 if attr.path().is_ident("doc") {
87 if let Meta::NameValue(nv) = &attr.meta {
88 if let Expr::Lit(ExprLit {
89 lit: Lit::Str(lit), ..
90 }) = &nv.value
91 {
92 if !doc.is_empty() {
93 doc.push(' ');
94 }
95 doc.push_str(lit.value().trim());
96 }
97 }
98 }
99 }
100 if doc.is_empty() {
101 format!("Tool: {}", tool_name)
102 } else {
103 doc
104 }
105 } else {
106 description
107 };
108
109 let mut param_names: Vec<syn::Ident> = Vec::new();
111 let mut param_types: Vec<syn::Type> = Vec::new();
112 let mut param_name_strs: Vec<String> = Vec::new();
113 let mut param_json_types: Vec<String> = Vec::new();
114
115 for input in fn_inputs.iter() {
116 if let syn::FnArg::Typed(pat_type) = input {
117 if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
118 let name = pat_ident.ident.clone();
119 let name_str = name.to_string();
120 let ty = (*pat_type.ty).clone();
121
122 let json_type = rust_type_to_json_schema(&pat_type.ty);
124
125 param_names.push(name);
126 param_name_strs.push(name_str);
127 param_types.push(ty);
128 param_json_types.push(json_type);
129 }
130 }
131 }
132
133 let struct_name = format_ident!("{}Tool", to_pascal_case(&tool_name));
135
136 let expanded = quote! {
138 #fn_vis #fn_asyncness fn #fn_name(#fn_inputs) #fn_output #fn_block
140
141 #[derive(Debug, Clone)]
143 #fn_vis struct #struct_name;
144
145 impl #struct_name {
146 pub fn new() -> Self {
148 Self
149 }
150 }
151
152 impl Default for #struct_name {
153 fn default() -> Self {
154 Self::new()
155 }
156 }
157
158 #[async_trait::async_trait]
159 impl praisonai::Tool for #struct_name {
160 fn name(&self) -> &str {
161 #tool_name
162 }
163
164 fn description(&self) -> &str {
165 #description
166 }
167
168 fn parameters_schema(&self) -> serde_json::Value {
169 let mut properties = serde_json::Map::new();
170 let mut required = Vec::new();
171
172 #(
173 properties.insert(
174 #param_name_strs.to_string(),
175 serde_json::json!({ "type": #param_json_types })
176 );
177 required.push(serde_json::Value::String(#param_name_strs.to_string()));
178 )*
179
180 serde_json::json!({
181 "type": "object",
182 "properties": properties,
183 "required": required
184 })
185 }
186
187 async fn execute(&self, args: serde_json::Value) -> praisonai::Result<serde_json::Value> {
188 #(
189 let #param_names: #param_types = serde_json::from_value(
190 args.get(#param_name_strs)
191 .cloned()
192 .unwrap_or(serde_json::Value::Null)
193 ).map_err(|e| praisonai::Error::tool(format!("Failed to parse {}: {}", #param_name_strs, e)))?;
194 )*
195
196 let result = #fn_name(#(#param_names),*).await;
197 serde_json::to_value(result)
198 .map_err(|e| praisonai::Error::tool(format!("Failed to serialize result: {}", e)))
199 }
200 }
201 };
202
203 TokenStream::from(expanded)
204}
205
206fn rust_type_to_json_schema(ty: &syn::Type) -> String {
208 let type_str = quote!(#ty).to_string().replace(" ", "");
209
210 match type_str.as_str() {
211 "String" | "&str" | "str" => "string".to_string(),
212 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64" | "u128"
213 | "usize" => "integer".to_string(),
214 "f32" | "f64" => "number".to_string(),
215 "bool" => "boolean".to_string(),
216 _ if type_str.starts_with("Vec<") => "array".to_string(),
217 _ if type_str.starts_with("Option<") => {
218 let inner = &type_str[7..type_str.len() - 1];
220 rust_type_str_to_json_schema(inner)
221 }
222 _ => "object".to_string(),
223 }
224}
225
/// Maps a whitespace-free Rust type string (e.g. `"Vec<String>"`) to a JSON
/// Schema `"type"` keyword.
///
/// - `Option<T>` is unwrapped (recursively) and mapped like `T`.
/// - `String`, `str`, `&str` and `char` map to `"string"`.
/// - Integer primitives map to `"integer"`; `f32`/`f64` to `"number"`; `bool`
///   to `"boolean"`.
/// - `Vec`, `VecDeque`, `HashSet`, `BTreeSet`, slices and arrays (`[...]`) map to
///   `"array"`, matching how serde_json serializes them.
/// - Anything else falls back to `"object"`.
fn rust_type_str_to_json_schema(type_str: &str) -> String {
    if let Some(inner) = type_str
        .strip_prefix("Option<")
        .and_then(|rest| rest.strip_suffix('>'))
    {
        return rust_type_str_to_json_schema(inner);
    }

    const SEQUENCE_PREFIXES: [&str; 5] = ["Vec<", "VecDeque<", "HashSet<", "BTreeSet<", "["];

    let schema_type = match type_str {
        "String" | "&str" | "str" | "char" => "string",
        "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64" | "u128"
        | "usize" => "integer",
        "f32" | "f64" => "number",
        "bool" => "boolean",
        _ if SEQUENCE_PREFIXES
            .iter()
            .any(|prefix| type_str.starts_with(prefix)) =>
        {
            "array"
        }
        _ => "object",
    };
    schema_type.to_string()
}
237
/// Converts a tool name into PascalCase for use as a struct-name prefix.
///
/// The input is split on every non-alphanumeric character (`_`, `-`, spaces,
/// dots, ...). Empty pieces are dropped, and each remaining word has its first
/// character upper-cased while the rest of the word is kept as-is. For example,
/// `web_search` and `web-search` both become `WebSearch`, and `getHTTP` becomes
/// `GetHTTP`.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for word in s.split(|c: char| !c.is_alphanumeric()) {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            pascal.extend(first.to_uppercase());
            pascal.push_str(chars.as_str());
        }
    }
    pascal
}