agent_chain_macros/
lib.rs1use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{FnArg, ItemFn, Pat, ReturnType, Type, parse_macro_input};
6
7#[proc_macro_attribute]
26pub fn tool(_attr: TokenStream, item: TokenStream) -> TokenStream {
27 let input = parse_macro_input!(item as ItemFn);
28
29 let fn_name = &input.sig.ident;
30 let fn_name_str = fn_name.to_string();
31 let mod_name = format_ident!("{}", fn_name);
32 let struct_name = format_ident!("{}Tool", to_pascal_case(&fn_name_str));
33
34 let fn_body = &input.block;
35 let fn_vis = &input.vis;
36 let _fn_asyncness = &input.sig.asyncness;
37 let fn_return_type = &input.sig.output;
38
39 let params: Vec<_> = input
41 .sig
42 .inputs
43 .iter()
44 .filter_map(|arg| {
45 if let FnArg::Typed(pat_type) = arg
46 && let Pat::Ident(pat_ident) = pat_type.pat.as_ref()
47 {
48 let param_name = &pat_ident.ident;
49 let param_type = &pat_type.ty;
50 return Some((param_name.clone(), param_type.clone()));
51 }
52 None
53 })
54 .collect();
55
56 let param_names: Vec<_> = params.iter().map(|(name, _)| name.clone()).collect();
57 let param_types: Vec<_> = params.iter().map(|(_, ty)| ty.clone()).collect();
58 let param_names_str: Vec<_> = params.iter().map(|(name, _)| name.to_string()).collect();
59
60 let schema_properties: Vec<_> = params
62 .iter()
63 .map(|(name, ty)| {
64 let name_str = name.to_string();
65 let type_str = get_json_type(ty);
66 quote! {
67 (#name_str.to_string(), serde_json::json!({ "type": #type_str }))
68 }
69 })
70 .collect();
71
72 let actual_return_type = match fn_return_type {
74 ReturnType::Default => quote! { () },
75 ReturnType::Type(_, ty) => quote! { #ty },
76 };
77
78 let expanded = quote! {
79 #fn_vis mod #mod_name {
80 use super::*;
81 use std::collections::HashMap;
82 use serde_json;
83
84 pub struct #struct_name;
86
87 impl #struct_name {
88 pub fn new() -> Self {
90 Self
91 }
92 }
93
94 impl Default for #struct_name {
95 fn default() -> Self {
96 Self::new()
97 }
98 }
99
100 #[agent_chain::async_trait]
101 impl agent_chain::tools::Tool for #struct_name {
102 fn name(&self) -> &str {
103 #fn_name_str
104 }
105
106 fn description(&self) -> &str {
107 concat!("Tool: ", #fn_name_str)
108 }
109
110 fn parameters_schema(&self) -> serde_json::Value {
111 let properties: HashMap<String, serde_json::Value> = [
112 #(#schema_properties),*
113 ].into_iter().collect();
114
115 let required: Vec<String> = vec![
116 #(#param_names_str.to_string()),*
117 ];
118
119 serde_json::json!({
120 "type": "object",
121 "properties": properties,
122 "required": required
123 })
124 }
125
126 async fn invoke(&self, tool_call: agent_chain::messages::ToolCall) -> agent_chain::messages::BaseMessage {
127 let args = tool_call.args();
128
129 #(
130 let #param_names: #param_types = serde_json::from_value(
131 args.get(#param_names_str).cloned().unwrap_or(serde_json::Value::Null)
132 ).expect(&format!("Failed to parse parameter '{}'", #param_names_str));
133 )*
134
135 let result: #actual_return_type = { #fn_body };
136
137 let result_str = serde_json::to_string(&result).unwrap_or_else(|_| format!("{:?}", result));
138
139 agent_chain::messages::ToolMessage::new(result_str, tool_call.id()).into()
140 }
141 }
142
143 pub fn tool() -> #struct_name {
145 #struct_name::new()
146 }
147 }
148 };
149
150 TokenStream::from(expanded)
151}
152
/// Converts a `snake_case` identifier into `PascalCase`.
///
/// The input is split on `_`; the first character of every non-empty segment
/// is upper-cased and the remainder of the segment is copied unchanged. Empty
/// segments (leading, trailing or doubled underscores) contribute nothing.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(initial) = rest.next() {
            // `to_uppercase` may expand to several characters (e.g. 'ß' -> "SS").
            pascal.extend(initial.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
165
166fn get_json_type(ty: &Type) -> &'static str {
168 let type_str = quote!(#ty).to_string();
169 match type_str.as_str() {
170 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64" | "u128"
171 | "usize" => "integer",
172 "f32" | "f64" => "number",
173 "bool" => "boolean",
174 "String" | "& str" | "& 'static str" => "string",
175 _ => "string", }
177}