Skip to main content

daimon_macros/
lib.rs

1//! Proc macros for the Daimon AI agent framework.
2//!
3//! Provides `#[tool_fn]` to derive [`Tool`](https://docs.rs/daimon/latest/daimon/tool/trait.Tool.html)
4//! implementations from plain async functions.
5
6use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::{format_ident, quote};
9use syn::{
10    parse::{Parse, ParseStream},
11    parse_macro_input, Attribute, Expr, FnArg, Ident, ItemFn, Lit, Meta, Pat, PatType, Token, Type,
12};
13
14struct ToolFnArgs {
15    crate_path: Option<syn::Path>,
16    name_override: Option<String>,
17    description_override: Option<String>,
18}
19
20impl Parse for ToolFnArgs {
21    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
22        let mut crate_path = None;
23        let mut name_override = None;
24        let mut description_override = None;
25
26        while !input.is_empty() {
27            let key: Ident = input.parse()?;
28            input.parse::<Token![=]>()?;
29
30            match key.to_string().as_str() {
31                "crate_path" => {
32                    let lit: Lit = input.parse()?;
33                    if let Lit::Str(s) = lit {
34                        crate_path = Some(s.parse()?);
35                    }
36                }
37                "name" => {
38                    let lit: Lit = input.parse()?;
39                    if let Lit::Str(s) = lit {
40                        name_override = Some(s.value());
41                    }
42                }
43                "description" => {
44                    let lit: Lit = input.parse()?;
45                    if let Lit::Str(s) = lit {
46                        description_override = Some(s.value());
47                    }
48                }
49                other => {
50                    return Err(syn::Error::new(key.span(), format!("unknown attribute `{other}`")));
51                }
52            }
53
54            if !input.is_empty() {
55                input.parse::<Token![,]>()?;
56            }
57        }
58
59        Ok(Self {
60            crate_path,
61            name_override,
62            description_override,
63        })
64    }
65}
66
67struct ParamInfo {
68    name: String,
69    ty: Type,
70    doc: Option<String>,
71    optional: bool,
72    inner_ty: Option<Type>,
73}
74
75/// Derives a [`Tool`] implementation from an async function.
76///
77/// The function's parameters become the tool's JSON Schema properties.
78/// Doc comments on the function become the tool description; doc comments
79/// on individual parameters become property descriptions.
80///
81/// # Supported types
82///
83/// `String`, `i8`–`i128`, `u8`–`u128`, `isize`, `usize`, `f32`, `f64`,
84/// `bool`, `Option<T>` (marks the parameter as not required).
85///
86/// # Attributes
87///
88/// - `name = "..."` — override the tool name (defaults to the function name)
89/// - `description = "..."` — override the description (defaults to doc comments)
90/// - `crate_path = "..."` — override the path to the daimon crate (defaults to `::daimon`)
91///
92/// # Example
93///
94/// ```ignore
95/// use daimon::prelude::*;
96/// use daimon::tool_fn;
97///
98/// /// Adds two numbers together.
99/// #[tool_fn]
100/// async fn add(
101///     /// The first number.
102///     a: f64,
103///     /// The second number.
104///     b: f64,
105/// ) -> daimon::Result<ToolOutput> {
106///     Ok(ToolOutput::text(format!("{}", a + b)))
107/// }
108///
109/// // Use the generated struct:
110/// let agent = Agent::builder()
111///     .model(model)
112///     .tool(Add) // PascalCase struct generated from `add`
113///     .build()?;
114/// ```
115#[proc_macro_attribute]
116pub fn tool_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
117    let args = parse_macro_input!(attr as ToolFnArgs);
118    let func = parse_macro_input!(item as ItemFn);
119
120    match expand_tool_fn(args, func) {
121        Ok(tokens) => tokens.into(),
122        Err(e) => e.to_compile_error().into(),
123    }
124}
125
126fn expand_tool_fn(args: ToolFnArgs, func: ItemFn) -> syn::Result<TokenStream2> {
127    let crate_path = args
128        .crate_path
129        .map(|p| quote!(#p))
130        .unwrap_or_else(|| quote!(::daimon));
131
132    let fn_name = &func.sig.ident;
133    let struct_name = format_ident!("{}", to_pascal_case(&fn_name.to_string()));
134    let tool_name_str = args.name_override.unwrap_or_else(|| fn_name.to_string());
135
136    let description = args
137        .description_override
138        .unwrap_or_else(|| extract_doc_comments(&func.attrs));
139
140    if func.sig.asyncness.is_none() {
141        return Err(syn::Error::new_spanned(
142            func.sig.fn_token,
143            "tool_fn requires an async function",
144        ));
145    }
146
147    let params = extract_params(&func)?;
148    let schema_tokens = generate_schema(&params, &crate_path);
149    let extraction_tokens = generate_extraction(&params, &crate_path);
150    let body = &func.block;
151
152    Ok(quote! {
153        /// Auto-generated tool struct from `#[tool_fn]` on [`#fn_name`].
154        pub struct #struct_name;
155
156        impl #crate_path::tool::Tool for #struct_name {
157            fn name(&self) -> &str {
158                #tool_name_str
159            }
160
161            fn description(&self) -> &str {
162                #description
163            }
164
165            fn parameters_schema(&self) -> ::serde_json::Value {
166                #schema_tokens
167            }
168
169            async fn execute(
170                &self,
171                __daimon_input: &::serde_json::Value,
172            ) -> #crate_path::Result<#crate_path::tool::ToolOutput> {
173                #extraction_tokens
174                #body
175            }
176        }
177    })
178}
179
180fn extract_doc_comments(attrs: &[Attribute]) -> String {
181    let mut lines = Vec::new();
182    for attr in attrs {
183        if attr.path().is_ident("doc") {
184            if let Meta::NameValue(nv) = &attr.meta {
185                if let Expr::Lit(lit) = &nv.value {
186                    if let Lit::Str(s) = &lit.lit {
187                        lines.push(s.value().trim().to_string());
188                    }
189                }
190            }
191        }
192    }
193    lines.join(" ").trim().to_string()
194}
195
196fn extract_params(func: &ItemFn) -> syn::Result<Vec<ParamInfo>> {
197    let mut params = Vec::new();
198
199    for arg in &func.sig.inputs {
200        if let FnArg::Typed(PatType { pat, ty, attrs, .. }) = arg {
201            let name = match pat.as_ref() {
202                Pat::Ident(ident) => ident.ident.to_string(),
203                _ => {
204                    return Err(syn::Error::new_spanned(pat, "expected a simple identifier"));
205                }
206            };
207
208            let doc = extract_doc_comments(attrs);
209            let doc = if doc.is_empty() { None } else { Some(doc) };
210
211            let (optional, inner_ty) = unwrap_option(ty);
212
213            params.push(ParamInfo {
214                name,
215                ty: *ty.clone(),
216                doc,
217                optional,
218                inner_ty,
219            });
220        }
221    }
222
223    Ok(params)
224}
225
226fn unwrap_option(ty: &Type) -> (bool, Option<Type>) {
227    if let Type::Path(tp) = ty {
228        if let Some(seg) = tp.path.segments.last() {
229            if seg.ident == "Option" {
230                if let syn::PathArguments::AngleBracketed(ab) = &seg.arguments {
231                    if let Some(syn::GenericArgument::Type(inner)) = ab.args.first() {
232                        return (true, Some(inner.clone()));
233                    }
234                }
235            }
236        }
237    }
238    (false, None)
239}
240
241fn type_to_json_schema(ty: &Type) -> TokenStream2 {
242    if let Type::Path(tp) = ty {
243        if let Some(seg) = tp.path.segments.last() {
244            let name = seg.ident.to_string();
245            match name.as_str() {
246                "String" | "str" => return quote!(::serde_json::json!({"type": "string"})),
247                "bool" => return quote!(::serde_json::json!({"type": "boolean"})),
248                "f32" | "f64" => return quote!(::serde_json::json!({"type": "number"})),
249                "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32"
250                | "u64" | "u128" | "usize" => {
251                    return quote!(::serde_json::json!({"type": "integer"}));
252                }
253                "Vec" => {
254                    if let syn::PathArguments::AngleBracketed(ab) = &seg.arguments {
255                        if let Some(syn::GenericArgument::Type(inner)) = ab.args.first() {
256                            let inner_schema = type_to_json_schema(inner);
257                            return quote!(::serde_json::json!({"type": "array", "items": #inner_schema}));
258                        }
259                    }
260                    return quote!(::serde_json::json!({"type": "array"}));
261                }
262                "Value" => return quote!(::serde_json::json!({})),
263                _ => {}
264            }
265        }
266    }
267    quote!(::serde_json::json!({}))
268}
269
270fn generate_schema(params: &[ParamInfo], _crate_path: &TokenStream2) -> TokenStream2 {
271    let mut prop_entries = Vec::new();
272    let mut required_names = Vec::new();
273
274    for param in params {
275        let name = &param.name;
276        let effective_ty = param.inner_ty.as_ref().unwrap_or(&param.ty);
277        let schema = type_to_json_schema(effective_ty);
278
279        if let Some(doc) = &param.doc {
280            prop_entries.push(quote! {
281                let mut __prop = #schema;
282                if let Some(obj) = __prop.as_object_mut() {
283                    obj.insert("description".to_string(), ::serde_json::Value::String(#doc.to_string()));
284                }
285                __props.insert(#name.to_string(), __prop);
286            });
287        } else {
288            prop_entries.push(quote! {
289                __props.insert(#name.to_string(), #schema);
290            });
291        }
292
293        if !param.optional {
294            required_names.push(quote!(#name));
295        }
296    }
297
298    quote! {
299        {
300            let mut __props = ::serde_json::Map::new();
301            #(#prop_entries)*
302            let mut __schema = ::serde_json::Map::new();
303            __schema.insert("type".to_string(), ::serde_json::Value::String("object".to_string()));
304            __schema.insert("properties".to_string(), ::serde_json::Value::Object(__props));
305            let __required: Vec<&str> = vec![#(#required_names),*];
306            if !__required.is_empty() {
307                __schema.insert(
308                    "required".to_string(),
309                    ::serde_json::Value::Array(
310                        __required.into_iter().map(|s| ::serde_json::Value::String(s.to_string())).collect()
311                    ),
312                );
313            }
314            ::serde_json::Value::Object(__schema)
315        }
316    }
317}
318
319fn generate_extraction(params: &[ParamInfo], crate_path: &TokenStream2) -> TokenStream2 {
320    let mut extractions = Vec::new();
321
322    for param in params {
323        let name_str = &param.name;
324        let name_ident = format_ident!("{}", &param.name);
325        let ty = &param.ty;
326
327        if param.optional {
328            let inner = param.inner_ty.as_ref().unwrap_or(&param.ty);
329            extractions.push(quote! {
330                let #name_ident: #ty = match __daimon_input.get(#name_str) {
331                    Some(v) if !v.is_null() => {
332                        Some(::serde_json::from_value::<#inner>(v.clone()).map_err(|__e| {
333                            #crate_path::DaimonError::Other(
334                                format!("parameter '{}': {}", #name_str, __e)
335                            )
336                        })?)
337                    }
338                    _ => None,
339                };
340            });
341        } else {
342            extractions.push(quote! {
343                let #name_ident: #ty = ::serde_json::from_value(
344                    __daimon_input
345                        .get(#name_str)
346                        .cloned()
347                        .unwrap_or(::serde_json::Value::Null),
348                )
349                .map_err(|__e| {
350                    #crate_path::DaimonError::Other(
351                        format!("parameter '{}': {}", #name_str, __e)
352                    )
353                })?;
354            });
355        }
356    }
357
358    quote! { #(#extractions)* }
359}
360
/// Converts a snake_case identifier to PascalCase (`fetch_weather` → `FetchWeather`).
///
/// A leading `r#` raw-identifier prefix is dropped (`r#type` → `Type`), because
/// `#` cannot appear in the generated struct name. Empty segments caused by
/// leading, trailing or doubled underscores are skipped. Every character after
/// a segment's first is lowercased (`get_HTTP` → `GetHttp`).
fn to_pascal_case(s: &str) -> String {
    let s = s.strip_prefix("r#").unwrap_or(s);
    let mut out = String::with_capacity(s.len());

    for part in s.split('_').filter(|part| !part.is_empty()) {
        let mut chars = part.chars();
        if let Some(first) = chars.next() {
            out.extend(first.to_uppercase());
            out.push_str(&chars.as_str().to_lowercase());
        }
    }

    out
}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of snake_case → PascalCase struct naming.
    #[test]
    fn test_to_pascal_case() {
        let cases = [
            ("add", "Add"),
            ("fetch_weather", "FetchWeather"),
            ("get_user_by_id", "GetUserById"),
            ("a", "A"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_pascal_case(input), expected, "input: {input:?}");
        }
    }
}