1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::Lit;
4use syn::{parse_macro_input, Expr, ExprLit, FnArg, ItemFn, Pat, PatType, Type};
5
6#[proc_macro_attribute]
16pub fn tool_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
17 let input_fn = parse_macro_input!(item as ItemFn);
19 let fn_name = &input_fn.sig.ident;
20 let fn_name_str = fn_name.to_string();
21
22 let attrs = parse_macro_input!(attr with syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated);
24
25 let mut tool_name = fn_name_str.clone();
27 let mut tool_description = format!("Tool function {}", fn_name_str);
28
29 for attr in attrs.iter() {
30 if let syn::Meta::NameValue(name_value) = attr {
31 if name_value.path.is_ident("name") {
32 if let Expr::Lit(ExprLit {
33 lit: Lit::Str(lit_str),
34 ..
35 }) = &name_value.value
36 {
37 tool_name = lit_str.value();
38 }
39 } else if name_value.path.is_ident("description") {
40 if let Expr::Lit(ExprLit {
41 lit: Lit::Str(lit_str),
42 ..
43 }) = &name_value.value
44 {
45 tool_description = lit_str.value();
46 }
47 }
48 }
49 }
50
51 let params = extract_params(&input_fn);
53
54 let tool_fn_name = format_ident!("{}_tool", fn_name);
56
57 let param_extractions = params.iter().map(|(name, type_name)| {
59 let param_name = format_ident!("{}", name);
60 match type_name.as_str() {
61 "i32" => quote! {
62 let #param_name = params[#name].as_i64()
63 .ok_or_else(|| AgentError::InvalidInput(format!("Missing or invalid parameter: {}", #name)))?
64 as i32;
65 },
66 "i64" => quote! {
67 let #param_name = params[#name].as_i64()
68 .ok_or_else(|| AgentError::InvalidInput(format!("Missing or invalid parameter: {}", #name)))?;
69 },
70 "u32" | "u64" => quote! {
71 let #param_name = params[#name].as_u64()
72 .ok_or_else(|| AgentError::InvalidInput(format!("Missing or invalid parameter: {}", #name)))?
73 as u32;
74 },
75 "f32" | "f64" => quote! {
76 let #param_name = params[#name].as_f64()
77 .ok_or_else(|| AgentError::InvalidInput(format!("Missing or invalid parameter: {}", #name)))?
78 as f64;
79 },
80 "String" => quote! {
81 let #param_name = params[#name].as_str()
82 .ok_or_else(|| AgentError::InvalidInput(format!("Missing or invalid parameter: {}", #name)))?
83 .to_string();
84 },
85 "&str" => quote! {
86 let #param_name = params[#name].as_str()
87 .ok_or_else(|| AgentError::InvalidInput(format!("Missing or invalid parameter: {}", #name)))?;
88 },
89 "bool" => quote! {
90 let #param_name = params[#name].as_bool()
91 .ok_or_else(|| AgentError::InvalidInput(format!("Missing or invalid parameter: {}", #name)))?;
92 },
93 _ => quote! {
94 let #param_name = serde_json::from_value::<#param_name>(params[#name].clone())
95 .map_err(|e| AgentError::InvalidInput(format!("Invalid parameter {}: {}", #name, e)))?;
96 },
97 }
98 });
99
100 let param_names = params.iter().map(|(name, _)| format_ident!("{}", name));
102
103 let schema_properties = params.iter().map(|(name, type_name)| {
105 let type_str = match type_name.as_str() {
106 "i32" | "i64" | "u32" | "u64" | "f32" | "f64" => "number",
107 "String" | "&str" => "string",
108 "bool" => "boolean",
109 _ => "object",
110 };
111
112 quote! {
113 let mut property = serde_json::Map::new();
114 property.insert("type".to_string(), serde_json::Value::String(#type_str.to_string()));
115 properties.insert(#name.to_string(), serde_json::Value::Object(property));
116 required.push(serde_json::Value::String(#name.to_string()));
117 }
118 });
119
120 let expanded = quote! {
122 #input_fn
124
125 pub fn #tool_fn_name() -> ::adk::tool::FunctionTool {
127 use adk::error::AgentError;
128 use adk::tool::ToolResult;
129
130 ::adk::tool::FunctionTool::new(
131 #tool_name,
132 #tool_description,
133 generate_parameter_schema(),
135 Box::new(|context, params_str| {
136 let params: serde_json::Value = serde_json::from_str(params_str)
138 .map_err(|e| AgentError::InvalidInput(e.to_string()))?;
139
140 #(#param_extractions)*
142
143 let result = #fn_name(context, #(#param_names),*);
145
146 Ok(ToolResult {
147 tool_name: #tool_name.to_string(),
148 output: result,
149 })
150 })
151 )
152 }
153
154 fn generate_parameter_schema() -> serde_json::Value {
156 let mut properties = serde_json::Map::new();
157 let mut required = Vec::new();
158
159 #(#schema_properties)*
160
161 let mut schema = serde_json::Map::new();
162 schema.insert("type".to_string(), serde_json::Value::String("object".to_string()));
163 schema.insert("properties".to_string(), serde_json::Value::Object(properties));
164 schema.insert("required".to_string(), serde_json::Value::Array(required));
165
166 serde_json::Value::Object(schema)
167 }
168 };
169
170 expanded.into()
171}
172
173fn extract_params(input_fn: &ItemFn) -> Vec<(String, String)> {
175 let mut params = Vec::new();
176
177 for arg in &input_fn.sig.inputs {
178 if let FnArg::Typed(PatType { pat, ty, .. }) = arg {
179 if let Pat::Ident(pat_ident) = &**pat {
180 let param_name = pat_ident.ident.to_string();
181 let param_type = get_type_name(ty);
182
183 if param_name != "context" && !param_type.contains("RunContext") {
185 params.push((param_name, param_type));
186 }
187 }
188 }
189 }
190
191 params
192}
193
194fn get_type_name(ty: &Box<Type>) -> String {
196 match ty.as_ref() {
197 Type::Path(type_path) => {
198 if let Some(segment) = type_path.path.segments.last() {
199 segment.ident.to_string()
200 } else {
201 "unknown".to_string()
202 }
203 }
204 Type::Reference(type_ref) => {
205 if let Type::Path(type_path) = type_ref.elem.as_ref() {
206 if let Some(segment) = type_path.path.segments.last() {
207 segment.ident.to_string()
208 } else {
209 "unknown".to_string()
210 }
211 } else {
212 "unknown".to_string()
213 }
214 }
215 _ => "unknown".to_string(),
216 }
217}