1use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::{format_ident, quote};
9use syn::{
10 parse::{Parse, ParseStream},
11 parse_macro_input, Attribute, Expr, FnArg, Ident, ItemFn, Lit, Meta, Pat, PatType, Token, Type,
12};
13
14struct ToolFnArgs {
15 crate_path: Option<syn::Path>,
16 name_override: Option<String>,
17 description_override: Option<String>,
18}
19
20impl Parse for ToolFnArgs {
21 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
22 let mut crate_path = None;
23 let mut name_override = None;
24 let mut description_override = None;
25
26 while !input.is_empty() {
27 let key: Ident = input.parse()?;
28 input.parse::<Token![=]>()?;
29
30 match key.to_string().as_str() {
31 "crate_path" => {
32 let lit: Lit = input.parse()?;
33 if let Lit::Str(s) = lit {
34 crate_path = Some(s.parse()?);
35 }
36 }
37 "name" => {
38 let lit: Lit = input.parse()?;
39 if let Lit::Str(s) = lit {
40 name_override = Some(s.value());
41 }
42 }
43 "description" => {
44 let lit: Lit = input.parse()?;
45 if let Lit::Str(s) = lit {
46 description_override = Some(s.value());
47 }
48 }
49 other => {
50 return Err(syn::Error::new(key.span(), format!("unknown attribute `{other}`")));
51 }
52 }
53
54 if !input.is_empty() {
55 input.parse::<Token![,]>()?;
56 }
57 }
58
59 Ok(Self {
60 crate_path,
61 name_override,
62 description_override,
63 })
64 }
65}
66
67struct ParamInfo {
68 name: String,
69 ty: Type,
70 doc: Option<String>,
71 optional: bool,
72 inner_ty: Option<Type>,
73}
74
75#[proc_macro_attribute]
116pub fn tool_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
117 let args = parse_macro_input!(attr as ToolFnArgs);
118 let func = parse_macro_input!(item as ItemFn);
119
120 match expand_tool_fn(args, func) {
121 Ok(tokens) => tokens.into(),
122 Err(e) => e.to_compile_error().into(),
123 }
124}
125
126fn expand_tool_fn(args: ToolFnArgs, func: ItemFn) -> syn::Result<TokenStream2> {
127 let crate_path = args
128 .crate_path
129 .map(|p| quote!(#p))
130 .unwrap_or_else(|| quote!(::daimon));
131
132 let fn_name = &func.sig.ident;
133 let struct_name = format_ident!("{}", to_pascal_case(&fn_name.to_string()));
134 let tool_name_str = args.name_override.unwrap_or_else(|| fn_name.to_string());
135
136 let description = args
137 .description_override
138 .unwrap_or_else(|| extract_doc_comments(&func.attrs));
139
140 if func.sig.asyncness.is_none() {
141 return Err(syn::Error::new_spanned(
142 func.sig.fn_token,
143 "tool_fn requires an async function",
144 ));
145 }
146
147 let params = extract_params(&func)?;
148 let schema_tokens = generate_schema(¶ms, &crate_path);
149 let extraction_tokens = generate_extraction(¶ms, &crate_path);
150 let body = &func.block;
151
152 Ok(quote! {
153 pub struct #struct_name;
155
156 impl #crate_path::tool::Tool for #struct_name {
157 fn name(&self) -> &str {
158 #tool_name_str
159 }
160
161 fn description(&self) -> &str {
162 #description
163 }
164
165 fn parameters_schema(&self) -> ::serde_json::Value {
166 #schema_tokens
167 }
168
169 async fn execute(
170 &self,
171 __daimon_input: &::serde_json::Value,
172 ) -> #crate_path::Result<#crate_path::tool::ToolOutput> {
173 #extraction_tokens
174 #body
175 }
176 }
177 })
178}
179
180fn extract_doc_comments(attrs: &[Attribute]) -> String {
181 let mut lines = Vec::new();
182 for attr in attrs {
183 if attr.path().is_ident("doc") {
184 if let Meta::NameValue(nv) = &attr.meta {
185 if let Expr::Lit(lit) = &nv.value {
186 if let Lit::Str(s) = &lit.lit {
187 lines.push(s.value().trim().to_string());
188 }
189 }
190 }
191 }
192 }
193 lines.join(" ").trim().to_string()
194}
195
196fn extract_params(func: &ItemFn) -> syn::Result<Vec<ParamInfo>> {
197 let mut params = Vec::new();
198
199 for arg in &func.sig.inputs {
200 if let FnArg::Typed(PatType { pat, ty, attrs, .. }) = arg {
201 let name = match pat.as_ref() {
202 Pat::Ident(ident) => ident.ident.to_string(),
203 _ => {
204 return Err(syn::Error::new_spanned(pat, "expected a simple identifier"));
205 }
206 };
207
208 let doc = extract_doc_comments(attrs);
209 let doc = if doc.is_empty() { None } else { Some(doc) };
210
211 let (optional, inner_ty) = unwrap_option(ty);
212
213 params.push(ParamInfo {
214 name,
215 ty: *ty.clone(),
216 doc,
217 optional,
218 inner_ty,
219 });
220 }
221 }
222
223 Ok(params)
224}
225
226fn unwrap_option(ty: &Type) -> (bool, Option<Type>) {
227 if let Type::Path(tp) = ty {
228 if let Some(seg) = tp.path.segments.last() {
229 if seg.ident == "Option" {
230 if let syn::PathArguments::AngleBracketed(ab) = &seg.arguments {
231 if let Some(syn::GenericArgument::Type(inner)) = ab.args.first() {
232 return (true, Some(inner.clone()));
233 }
234 }
235 }
236 }
237 }
238 (false, None)
239}
240
241fn type_to_json_schema(ty: &Type) -> TokenStream2 {
242 if let Type::Path(tp) = ty {
243 if let Some(seg) = tp.path.segments.last() {
244 let name = seg.ident.to_string();
245 match name.as_str() {
246 "String" | "str" => return quote!(::serde_json::json!({"type": "string"})),
247 "bool" => return quote!(::serde_json::json!({"type": "boolean"})),
248 "f32" | "f64" => return quote!(::serde_json::json!({"type": "number"})),
249 "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
250 | "u64" | "u128" | "usize" => {
251 return quote!(::serde_json::json!({"type": "integer"}));
252 }
253 "Vec" => {
254 if let syn::PathArguments::AngleBracketed(ab) = &seg.arguments {
255 if let Some(syn::GenericArgument::Type(inner)) = ab.args.first() {
256 let inner_schema = type_to_json_schema(inner);
257 return quote!(::serde_json::json!({"type": "array", "items": #inner_schema}));
258 }
259 }
260 return quote!(::serde_json::json!({"type": "array"}));
261 }
262 "Value" => return quote!(::serde_json::json!({})),
263 _ => {}
264 }
265 }
266 }
267 quote!(::serde_json::json!({}))
268}
269
270fn generate_schema(params: &[ParamInfo], _crate_path: &TokenStream2) -> TokenStream2 {
271 let mut prop_entries = Vec::new();
272 let mut required_names = Vec::new();
273
274 for param in params {
275 let name = ¶m.name;
276 let effective_ty = param.inner_ty.as_ref().unwrap_or(¶m.ty);
277 let schema = type_to_json_schema(effective_ty);
278
279 if let Some(doc) = ¶m.doc {
280 prop_entries.push(quote! {
281 let mut __prop = #schema;
282 if let Some(obj) = __prop.as_object_mut() {
283 obj.insert("description".to_string(), ::serde_json::Value::String(#doc.to_string()));
284 }
285 __props.insert(#name.to_string(), __prop);
286 });
287 } else {
288 prop_entries.push(quote! {
289 __props.insert(#name.to_string(), #schema);
290 });
291 }
292
293 if !param.optional {
294 required_names.push(quote!(#name));
295 }
296 }
297
298 quote! {
299 {
300 let mut __props = ::serde_json::Map::new();
301 #(#prop_entries)*
302 let mut __schema = ::serde_json::Map::new();
303 __schema.insert("type".to_string(), ::serde_json::Value::String("object".to_string()));
304 __schema.insert("properties".to_string(), ::serde_json::Value::Object(__props));
305 let __required: Vec<&str> = vec![#(#required_names),*];
306 if !__required.is_empty() {
307 __schema.insert(
308 "required".to_string(),
309 ::serde_json::Value::Array(
310 __required.into_iter().map(|s| ::serde_json::Value::String(s.to_string())).collect()
311 ),
312 );
313 }
314 ::serde_json::Value::Object(__schema)
315 }
316 }
317}
318
319fn generate_extraction(params: &[ParamInfo], crate_path: &TokenStream2) -> TokenStream2 {
320 let mut extractions = Vec::new();
321
322 for param in params {
323 let name_str = ¶m.name;
324 let name_ident = format_ident!("{}", ¶m.name);
325 let ty = ¶m.ty;
326
327 if param.optional {
328 let inner = param.inner_ty.as_ref().unwrap_or(¶m.ty);
329 extractions.push(quote! {
330 let #name_ident: #ty = match __daimon_input.get(#name_str) {
331 Some(v) if !v.is_null() => {
332 Some(::serde_json::from_value::<#inner>(v.clone()).map_err(|__e| {
333 #crate_path::DaimonError::Other(
334 format!("parameter '{}': {}", #name_str, __e)
335 )
336 })?)
337 }
338 _ => None,
339 };
340 });
341 } else {
342 extractions.push(quote! {
343 let #name_ident: #ty = ::serde_json::from_value(
344 __daimon_input
345 .get(#name_str)
346 .cloned()
347 .unwrap_or(::serde_json::Value::Null),
348 )
349 .map_err(|__e| {
350 #crate_path::DaimonError::Other(
351 format!("parameter '{}': {}", #name_str, __e)
352 )
353 })?;
354 });
355 }
356 }
357
358 quote! { #(#extractions)* }
359}
360
/// Converts a `snake_case` string to `PascalCase`.
///
/// The input is split on `_`; empty segments (from leading, trailing or
/// repeated underscores) are dropped. In each remaining segment the first
/// character is upper-cased and the rest is lower-cased.
fn to_pascal_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());

    for word in s.split('_') {
        let mut rest = word.chars();
        // An empty segment has no first character and contributes nothing.
        if let Some(first) = rest.next() {
            result.extend(first.to_uppercase());
            result.push_str(&rest.as_str().to_lowercase());
        }
    }

    result
}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `to_pascal_case` against a table of input/expected pairs.
    #[test]
    fn test_to_pascal_case() {
        let cases = [
            ("add", "Add"),
            ("fetch_weather", "FetchWeather"),
            ("get_user_by_id", "GetUserById"),
            ("a", "A"),
        ];

        for &(input, expected) in cases.iter() {
            assert_eq!(to_pascal_case(input), expected);
        }
    }
}