// reductool_proc_macro/lib.rs

use inflector::Inflector;
use proc_macro::TokenStream;
use quote::quote;
use syn::ext::IdentExt;
use syn::{Attribute, FnArg, ItemFn, Pat, PatIdent, Type, parse_macro_input};

6fn path_last_ident_is(path: &syn::Path, name: &str) -> bool {
7    path.segments
8        .last()
9        .map(|seg| seg.ident == name)
10        .unwrap_or(false)
11}
12
13fn path_ends_with(path: &syn::Path, segments: &[&str]) -> bool {
14    if segments.is_empty() {
15        return false;
16    }
17    let pathlen = path.segments.len();
18    if pathlen < segments.len() {
19        return false;
20    }
21    path.segments
22        .iter()
23        .skip(pathlen - segments.len())
24        .zip(segments.iter())
25        .all(|(a, b)| a.ident == *b)
26}
27
28fn first_generic_arg<'a>(tp: &'a syn::TypePath) -> Option<&'a Type> {
29    tp.path.segments.last().and_then(|seg| {
30        if let syn::PathArguments::AngleBracketed(ab) = &seg.arguments {
31            ab.args.iter().find_map(|ga| {
32                if let syn::GenericArgument::Type(t) = ga {
33                    Some(t)
34                } else {
35                    None
36                }
37            })
38        } else {
39            None
40        }
41    })
42}
43
44fn primitive_json_type_name(path: &syn::Path) -> Option<&'static str> {
45    let ident = path.segments.last()?.ident.to_string();
46    Some(match ident.as_str() {
47        "i8" | "i16" | "i32" | "i64" | "i128" | "u8" | "u16" | "u32" | "u64" | "u128" | "isize"
48        | "usize" => "integer",
49        "f32" | "f64" => "number",
50        "bool" => "boolean",
51        "String" | "str" => "string",
52        _ => return None,
53    })
54}
55
56fn ty_to_schema(ty: &Type) -> serde_json::Value {
57    match ty {
58        // &T => T
59        Type::Reference(r) => ty_to_schema(&r.elem),
60
61        // [T; N] => array of T
62        Type::Array(arr) => serde_json::json!({
63            "type": "array",
64            "items": ty_to_schema(&arr.elem),
65        }),
66
67        // (T1, T2, ...) -> fixed-length array with item schemas
68        Type::Tuple(t) => {
69            let items: Vec<serde_json::Value> = t.elems.iter().map(ty_to_schema).collect();
70            serde_json::json!({
71                "type": "array",
72                "items": items,
73                "minItems": items.len(),
74                "maxItems": items.len(),
75            })
76        }
77
78        // T paths: primitives, String, Vec<T>, Option<T>, serde_json::Value, etc.
79        Type::Path(tp) => {
80            if path_last_ident_is(&tp.path, "Vec") {
81                if let Some(inner) = first_generic_arg(tp) {
82                    return serde_json::json!({
83                        "type": "array",
84                        "items": ty_to_schema(inner),
85                    });
86                }
87                return serde_json::json!({
88                    "type": "array",
89                    "items": { "type": "string" },
90                });
91            }
92
93            if path_last_ident_is(&tp.path, "Option") {
94                if let Some(inner) = first_generic_arg(tp) {
95                    return ty_to_schema(inner);
96                }
97                return serde_json::Value::Object(serde_json::Map::new());
98            }
99
100            if let Some(json_ty) = primitive_json_type_name(&tp.path) {
101                return serde_json::json!({"type": json_ty});
102            }
103
104            if path_ends_with(&tp.path, &["serde_json", "Value"])
105                || path_last_ident_is(&tp.path, "Value")
106            {
107                return serde_json::Value::Object(serde_json::Map::new());
108            }
109
110            serde_json::json!({"type": "string"})
111        }
112
113        _ => serde_json::json!({"type": "string"}),
114    }
115}
116
117fn collect_doc(attrs: &[Attribute]) -> String {
118    attrs
119        .iter()
120        .filter(|attr| attr.path().is_ident("doc"))
121        .filter_map(|attr| {
122            let Ok(nv) = attr.meta.require_name_value() else {
123                return None;
124            };
125            let syn::Expr::Lit(syn::ExprLit {
126                lit: syn::Lit::Str(s),
127                ..
128            }) = &nv.value
129            else {
130                return None;
131            };
132            Some(s.value())
133        })
134        .collect::<Vec<_>>()
135        .join("\n")
136}
137
138#[proc_macro_attribute]
139pub fn aitool(_attr: TokenStream, code: TokenStream) -> TokenStream {
140    let func: ItemFn = parse_macro_input!(code);
141    let funcsig = &func.sig;
142    let func_name = funcsig.ident.to_string();
143    let doc = collect_doc(&func.attrs);
144
145    let is_async = funcsig.asyncness.is_some();
146
147    let mut fields = Vec::new();
148    let mut field_names = Vec::new();
149    let mut field_idents: Vec<syn::Ident> = Vec::new();
150    let mut args = serde_json::Map::new();
151    let mut required_args = Vec::new();
152    let mut errors: Vec<syn::Error> = Vec::new();
153
154    for input in &funcsig.inputs {
155        match input {
156            FnArg::Typed(pat_ty) => {
157                //NOTE(@hadydotai): We'll only work with simple identifier patterns for now
158                let param_ident = match &*pat_ty.pat {
159                    Pat::Ident(PatIdent { ident, .. }) => ident.clone(),
160                    _ => {
161                        errors.push(syn::Error::new_spanned(
162                            &pat_ty.pat,
163                            "unsupported parameter pattern. expected a simple identifier like `arg: T`.\n\
164                            Examples of unsupported patterns: `(_: T)`, `(a, b): (T, U)`, `S { x, y }: S`."
165                        ));
166                        continue;
167                    }
168                };
169
170                let pat_ty_ty = &pat_ty.ty;
171                let param_name = param_ident.to_string();
172                fields.push(quote!(pub #param_ident: #pat_ty_ty));
173                field_names.push(param_name.clone());
174                field_idents.push(param_ident.clone());
175                let schema = ty_to_schema(pat_ty_ty);
176                args.insert(param_name.clone(), schema);
177                let mut is_optional = false;
178                if let Type::Path(tp) = &*pat_ty.ty {
179                    if path_last_ident_is(&tp.path, "Option") {
180                        is_optional = true;
181                    }
182                }
183                if !is_optional {
184                    required_args.push(param_name);
185                }
186            }
187            FnArg::Receiver(recv) => {
188                errors.push(syn::Error::new_spanned(
189                    recv,
190                    "#[aitool] must be placed on a free-standing function (no `self`).\
191                    Move the function out of the `impl` block or remove the receiver.",
192                ));
193            }
194        }
195    }
196
197    if !errors.is_empty() {
198        let compile_errors = errors.into_iter().map(|err| err.to_compile_error());
199        return quote! { #(#compile_errors)* }.into();
200    }
201
202    let args_struct_ident = syn::Ident::new(
203        &format!("{}Args", func_name.to_table_case().to_pascal_case()),
204        funcsig.ident.span(),
205    );
206
207    let fields_tokens = quote!(#(#fields),*);
208    let required_array = serde_json::Value::Array(
209        required_args
210            .iter()
211            .map(|arg| serde_json::Value::String(arg.clone()))
212            .collect(),
213    );
214
215    let mut schema = serde_json::Map::new();
216    schema.insert(
217        "name".to_string(),
218        serde_json::Value::String(func_name.clone()),
219    );
220    schema.insert(
221        "description".to_string(),
222        serde_json::Value::String(doc.clone()),
223    );
224
225    let mut parameters = serde_json::Map::new();
226    parameters.insert(
227        "type".to_string(),
228        serde_json::Value::String("object".to_string()),
229    );
230    parameters.insert("properties".to_string(), serde_json::Value::Object(args));
231    parameters.insert("required".to_string(), required_array);
232
233    schema.insert(
234        "parameters".to_string(),
235        serde_json::Value::Object(parameters),
236    );
237    let json_schema = serde_json::to_string(&schema).unwrap();
238    let name_lit = syn::LitStr::new(&func_name, funcsig.ident.span());
239    let desc_lit = syn::LitStr::new(&doc, funcsig.ident.span());
240    let json_schema_lit = syn::LitStr::new(&json_schema, funcsig.ident.span());
241
242    let func_wrapper_name = syn::Ident::new(
243        &format!("__invoke_{}", func_name.clone()),
244        funcsig.ident.span(),
245    );
246    let reg_name = syn::Ident::new(
247        &format!("__REG_{}", func_name.clone().to_screaming_snake_case()),
248        funcsig.ident.span(),
249    );
250
251    let ident = &funcsig.ident;
252    let invoke_fn = if is_async {
253        quote! {
254            fn #func_wrapper_name(args: ::serde_json::Value) -> ::reductool::InvokeFuture {
255                Box::pin(async move {
256                    let parsed: #args_struct_ident = ::serde_json::from_value(args)?;
257                    let out = #ident(#(parsed.#field_idents),*).await;
258                    ::serde_json::to_value(out).map_err(Into::into)
259                })
260            }
261        }
262    } else {
263        quote! {
264            fn #func_wrapper_name(args: ::serde_json::Value) -> ::reductool::InvokeFuture {
265                Box::pin(async move {
266                    let parsed: #args_struct_ident = ::serde_json::from_value(args)?;
267                    let out = #ident(#(parsed.#field_idents),*);
268                    ::serde_json::to_value(out).map_err(Into::into)
269                })
270            }
271        }
272    };
273
274    let expanded = quote! {
275        #func
276
277        #[derive(::serde::Deserialize)]
278        struct #args_struct_ident {
279            #fields_tokens
280        }
281
282        #invoke_fn
283
284        #[::reductool::__linkme::distributed_slice(::reductool::ALL_TOOLS)]
285        static #reg_name: ::reductool::ToolDefinition = ::reductool::ToolDefinition {
286            name: #name_lit,
287            description: #desc_lit,
288            json_schema: #json_schema_lit,
289            invoke: #func_wrapper_name,
290        };
291    };
292    expanded.into()
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{Map, Value, json};
    use syn::{FnArg, ItemFn, Pat, PatIdent, Type, parse_str};

    /// Parses `src` as a free function, panicking if it does not parse.
    fn parse_fn(src: &str) -> ItemFn {
        parse_str::<ItemFn>(src).expect("failed to parse function")
    }

    /// Mirrors the macro's per-parameter logic using the production helpers.
    ///
    /// Every simple `name: T` parameter gets `ty_to_schema(T)` as its property
    /// schema. It is listed as required unless its top-level type path ends in
    /// `Option`. Receivers and non-identifier patterns are skipped.
    fn build_props_and_required(func: &ItemFn) -> (Map<String, Value>, Vec<String>) {
        func.sig
            .inputs
            .iter()
            .filter_map(|input| match input {
                FnArg::Typed(pat_ty) => Some(pat_ty),
                FnArg::Receiver(_) => None,
            })
            .filter_map(|pat_ty| match &*pat_ty.pat {
                Pat::Ident(PatIdent { ident, .. }) => Some((ident.to_string(), &*pat_ty.ty)),
                _ => None,
            })
            .fold(
                (Map::new(), Vec::new()),
                |(mut props, mut required), (name, ty)| {
                    props.insert(name.clone(), ty_to_schema(ty));
                    let optional = matches!(
                        ty,
                        Type::Path(tp) if path_last_ident_is(&tp.path, "Option")
                    );
                    if !optional {
                        required.push(name);
                    }
                    (props, required)
                },
            )
    }

    #[test]
    fn test_primitives_and_refs() {
        let func = parse_fn("fn f(a: i32, b: f64, c: bool, d: String, e: &str) {}");
        let (props, required) = build_props_and_required(&func);

        let expected = [
            ("a", "integer"),
            ("b", "number"),
            ("c", "boolean"),
            ("d", "string"),
            ("e", "string"),
        ];
        for (name, json_ty) in expected {
            assert_eq!(props[name], json!({ "type": json_ty }), "property `{name}`");
        }

        assert_eq!(required, vec!["a", "b", "c", "d", "e"]);
    }

    #[test]
    fn test_array_and_tuple() {
        let func = parse_fn("fn g(x: [i32; 3], y: (i32, String, bool)) {}");
        let (props, required) = build_props_and_required(&func);

        let expected_x = json!({ "type": "array", "items": { "type": "integer" } });
        let expected_y = json!({
            "type": "array",
            "items": [
                { "type": "integer" },
                { "type": "string" },
                { "type": "boolean" }
            ],
            "minItems": 3,
            "maxItems": 3
        });
        assert_eq!(props["x"], expected_x);
        assert_eq!(props["y"], expected_y);

        assert_eq!(required, vec!["x", "y"]);
    }

    #[test]
    fn test_vec_and_option() {
        let func = parse_fn("fn h(a: Vec<i32>, b: Option<String>, c: Option<Vec<bool>>) {}");
        let (props, required) = build_props_and_required(&func);

        assert_eq!(
            props["a"],
            json!({ "type": "array", "items": { "type": "integer" } })
        );
        assert_eq!(props["b"], json!({ "type": "string" }));
        assert_eq!(
            props["c"],
            json!({ "type": "array", "items": { "type": "boolean" } })
        );

        // `b` and `c` are Option<_>, so only `a` is required.
        assert_eq!(required, vec!["a"]);
    }

    #[test]
    fn test_json_value_and_bare_value() {
        let func = parse_fn("fn j(a: serde_json::Value, b: Value) {}");
        let (props, required) = build_props_and_required(&func);

        for name in ["a", "b"] {
            assert_eq!(props[name], json!({}), "property `{name}`");
        }

        assert_eq!(required, vec!["a", "b"]);
    }

    #[test]
    fn test_custom_type_and_ref_of_option() {
        // A custom type falls back to a string schema. `&Option<String>` is not a
        // top-level `Option` path, so the current logic treats it as required.
        let func = parse_fn("fn k(a: MyType, b: &Option<String>) {}");
        let (props, required) = build_props_and_required(&func);

        assert_eq!(props["a"], json!({ "type": "string" }));
        assert_eq!(props["b"], json!({ "type": "string" }));
        assert_eq!(required, vec!["a", "b"]);
    }
}