1use darling::{ast::NestedMeta, FromMeta};
37use proc_macro::TokenStream;
38use proc_macro2::TokenStream as TokenStream2;
39use quote::{format_ident, quote};
40use syn::{parse_macro_input, Expr, FnArg, ItemFn, Lit, Meta, Pat, PatType, Type};
41
42#[derive(Debug, FromMeta)]
44struct ToolArgs {
45 description: String,
47 #[darling(default)]
49 name: Option<String>,
50}
51
52#[derive(Debug, Default)]
54struct ParamArgs {
55 description: Option<String>,
56 default: Option<Expr>,
57}
58
59impl ParamArgs {
60 fn from_attrs(attrs: &[syn::Attribute]) -> Self {
61 let mut args = ParamArgs::default();
62
63 for attr in attrs {
64 if attr.path().is_ident("description") {
65 if let Meta::NameValue(nv) = &attr.meta {
66 if let Expr::Lit(expr_lit) = &nv.value {
67 if let Lit::Str(lit_str) = &expr_lit.lit {
68 args.description = Some(lit_str.value());
69 }
70 }
71 }
72 } else if attr.path().is_ident("default") {
73 if let Meta::NameValue(nv) = &attr.meta {
74 args.default = Some(nv.value.clone());
75 }
76 }
77 }
78
79 args
80 }
81}
82
83#[proc_macro_attribute]
108pub fn tool(attr: TokenStream, item: TokenStream) -> TokenStream {
109 let attr_args = match NestedMeta::parse_meta_list(attr.into()) {
110 Ok(v) => v,
111 Err(e) => return TokenStream::from(e.into_compile_error()),
112 };
113
114 let tool_args = match ToolArgs::from_list(&attr_args) {
115 Ok(v) => v,
116 Err(e) => return TokenStream::from(e.write_errors()),
117 };
118
119 let input_fn = parse_macro_input!(item as ItemFn);
120
121 match generate_tool_impl(tool_args, input_fn) {
122 Ok(tokens) => tokens.into(),
123 Err(e) => e.into_compile_error().into(),
124 }
125}
126
127fn generate_tool_impl(args: ToolArgs, input_fn: ItemFn) -> syn::Result<TokenStream2> {
128 let fn_name = &input_fn.sig.ident;
129 let fn_vis = &input_fn.vis;
130 let is_async = input_fn.sig.asyncness.is_some();
131
132 let tool_name = args.name.unwrap_or_else(|| fn_name.to_string());
133 let description = &args.description;
134
135 let tool_fn_name = format_ident!("{}_tool", fn_name);
137 let callback_fn_name = format_ident!("{}_callback", fn_name);
138 let combined_fn_name = format_ident!("{}_tool_with_callback", fn_name);
139 let args_struct_name = format_ident!("__{}Args", fn_name);
140
141 let mut stripped_fn = input_fn.clone();
143 for arg in &mut stripped_fn.sig.inputs {
144 if let FnArg::Typed(pat_type) = arg {
145 pat_type.attrs.retain(|attr| {
147 !attr.path().is_ident("description") && !attr.path().is_ident("default")
148 });
149 }
150 }
151
152 let mut param_names = Vec::new();
154 let mut param_types = Vec::new();
155 let mut param_descriptions = Vec::new();
156 let mut param_defaults = Vec::new();
157 let mut required_params = Vec::new();
158
159 for arg in &input_fn.sig.inputs {
160 if let FnArg::Typed(PatType { pat, ty, attrs, .. }) = arg {
161 if let Pat::Ident(pat_ident) = pat.as_ref() {
162 let param_name = &pat_ident.ident;
163 let param_args = ParamArgs::from_attrs(attrs);
164
165 param_names.push(param_name.clone());
166 param_types.push(ty.as_ref().clone());
167 param_descriptions.push(param_args.description);
168 param_defaults.push(param_args.default);
169
170 let is_optional = is_option_type(ty);
172 if !is_optional && param_defaults.last().unwrap().is_none() {
173 required_params.push(param_name.to_string());
174 }
175 }
176 }
177 }
178
179 let args_struct_fields: Vec<TokenStream2> = param_names
181 .iter()
182 .zip(param_types.iter())
183 .zip(param_defaults.iter())
184 .map(|((name, ty), default)| {
185 if default.is_some() {
186 let default_fn_name_str = format!("__default_{}", name);
187 quote! {
188 #[serde(default = #default_fn_name_str)]
189 pub #name: #ty
190 }
191 } else {
192 quote! {
193 pub #name: #ty
194 }
195 }
196 })
197 .collect();
198
199 let default_fns: Vec<TokenStream2> = param_names
201 .iter()
202 .zip(param_types.iter())
203 .zip(param_defaults.iter())
204 .filter_map(|((name, ty), default)| {
205 default.as_ref().map(|default_expr| {
206 let default_fn_name = format_ident!("__default_{}", name);
207 let value_expr = if is_option_type(ty) {
209 quote! { Some(#default_expr.into()) }
210 } else {
211 quote! { #default_expr }
212 };
213 quote! {
214 fn #default_fn_name() -> #ty {
215 #value_expr
216 }
217 }
218 })
219 })
220 .collect();
221
222 let property_schemas: Vec<TokenStream2> = param_names
224 .iter()
225 .zip(param_types.iter())
226 .zip(param_descriptions.iter())
227 .map(|((name, ty), desc)| {
228 let name_str = name.to_string();
229 let schema_type = extract_option_inner_type(ty).unwrap_or(ty);
231 let desc_insert = if let Some(d) = desc {
232 quote! {
233 if let Some(obj) = prop_schema.as_object_mut() {
234 obj.insert("description".to_string(), serde_json::json!(#d));
235 }
236 }
237 } else {
238 quote! {}
239 };
240 quote! {
241 {
242 let schema = schemars::schema_for!(#schema_type);
243 let mut prop_schema = serde_json::to_value(&schema).unwrap_or(serde_json::json!({}));
244 #desc_insert
245 properties.insert(#name_str.to_string(), prop_schema);
246 }
247 }
248 })
249 .collect();
250
251 let required_array: Vec<TokenStream2> = required_params
253 .iter()
254 .map(|name| quote! { #name.to_string() })
255 .collect();
256
257 let call_args: Vec<TokenStream2> = param_names
259 .iter()
260 .map(|name| quote! { args.#name })
261 .collect();
262
263 let output = if is_async {
265 quote! {
267 #stripped_fn
269
270 #(#default_fns)*
272
273 #[derive(serde::Deserialize)]
275 #[allow(non_camel_case_types)]
276 struct #args_struct_name {
277 #(#args_struct_fields),*
278 }
279
280 #fn_vis fn #tool_fn_name() -> hanzo::Tool {
282 let mut properties = std::collections::HashMap::<String, serde_json::Value>::new();
283
284 #(#property_schemas)*
285
286 let required: Vec<String> = vec![#(#required_array),*];
287
288 let parameters: std::collections::HashMap<String, serde_json::Value> = serde_json::from_value(
289 serde_json::json!({
290 "type": "object",
291 "properties": properties,
292 "required": required,
293 })
294 ).expect("Failed to create tool parameters");
295
296 hanzo::Tool {
297 tp: hanzo::ToolType::Function,
298 function: hanzo::Function {
299 description: Some(#description.to_string()),
300 name: #tool_name.to_string(),
301 parameters: Some(parameters),
302 strict: None,
303 },
304 }
305 }
306
307 #fn_vis fn #callback_fn_name() -> std::sync::Arc<hanzo::AsyncToolCallback> {
309 std::sync::Arc::new(|called: hanzo::CalledFunction| {
310 Box::pin(async move {
311 let args: #args_struct_name = serde_json::from_str(&called.arguments)
312 .map_err(|e| anyhow::anyhow!("Failed to parse tool arguments: {}", e))?;
313
314 let result = #fn_name(#(#call_args),*).await?;
315
316 serde_json::to_string(&result)
317 .map_err(|e| anyhow::anyhow!("Failed to serialize tool result: {}", e))
318 })
319 })
320 }
321
322 #fn_vis fn #combined_fn_name() -> (hanzo::Tool, hanzo::ToolCallbackType) {
324 (#tool_fn_name(), hanzo::ToolCallbackType::Async(#callback_fn_name()))
325 }
326 }
327 } else {
328 quote! {
330 #stripped_fn
332
333 #(#default_fns)*
335
336 #[derive(serde::Deserialize)]
338 #[allow(non_camel_case_types)]
339 struct #args_struct_name {
340 #(#args_struct_fields),*
341 }
342
343 #fn_vis fn #tool_fn_name() -> hanzo::Tool {
345 let mut properties = std::collections::HashMap::<String, serde_json::Value>::new();
346
347 #(#property_schemas)*
348
349 let required: Vec<String> = vec![#(#required_array),*];
350
351 let parameters: std::collections::HashMap<String, serde_json::Value> = serde_json::from_value(
352 serde_json::json!({
353 "type": "object",
354 "properties": properties,
355 "required": required,
356 })
357 ).expect("Failed to create tool parameters");
358
359 hanzo::Tool {
360 tp: hanzo::ToolType::Function,
361 function: hanzo::Function {
362 description: Some(#description.to_string()),
363 name: #tool_name.to_string(),
364 parameters: Some(parameters),
365 strict: None,
366 },
367 }
368 }
369
370 #fn_vis fn #callback_fn_name() -> std::sync::Arc<hanzo::ToolCallback> {
372 std::sync::Arc::new(|called: &hanzo::CalledFunction, _ctx: &hanzo::ToolCallContext| {
373 let args: #args_struct_name = serde_json::from_str(&called.arguments)
374 .map_err(|e| anyhow::anyhow!("Failed to parse tool arguments: {}", e))?;
375
376 let result = #fn_name(#(#call_args),*)?;
377
378 serde_json::to_string(&result)
379 .map_err(|e| anyhow::anyhow!("Failed to serialize tool result: {}", e))
380 })
381 }
382
383 #fn_vis fn #combined_fn_name() -> (hanzo::Tool, hanzo::ToolCallbackType) {
385 (#tool_fn_name(), hanzo::ToolCallbackType::Sync(#callback_fn_name()))
386 }
387 }
388 };
389
390 Ok(output)
391}
392
393fn is_option_type(ty: &Type) -> bool {
395 if let Type::Path(type_path) = ty {
396 if let Some(segment) = type_path.path.segments.last() {
397 return segment.ident == "Option";
398 }
399 }
400 false
401}
402
403fn extract_option_inner_type(ty: &Type) -> Option<&Type> {
405 if let Type::Path(type_path) = ty {
406 if let Some(segment) = type_path.path.segments.last() {
407 if segment.ident == "Option" {
408 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
409 if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
410 return Some(inner);
411 }
412 }
413 }
414 }
415 }
416 None
417}