funcall_macros/
lib.rs

use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::ext::IdentExt;
use syn::{FnArg, ItemFn, Pat, PatIdent, Type, parse_macro_input};

5#[proc_macro_attribute]
6pub fn funcall(_attr: TokenStream, item: TokenStream) -> TokenStream {
7    let func = parse_macro_input!(item as ItemFn);
8    let name = &func.sig.ident;
9    let wrapper_name = format_ident!("{}_tool", name);
10
11    let mut arg_idents = Vec::new();
12    let mut positional_stmts = Vec::new();
13    let mut named_stmts = Vec::new();
14
15    for (i, input) in func.sig.inputs.iter().enumerate() {
16        if let FnArg::Typed(pat_type) = input {
17            let ident = match &*pat_type.pat {
18                Pat::Ident(PatIdent { ident, .. }) => ident.clone(),
19                _ => panic!("Unsupported argument pattern"),
20            };
21            let ty = &pat_type.ty;
22            let index = syn::Index::from(i);
23            let key = ident.to_string();
24
25            let (pos_stmt, named_stmt) = extract_dual(&ident, ty, &index, &key);
26            arg_idents.push(ident);
27            positional_stmts.push(pos_stmt);
28            named_stmts.push(named_stmt);
29        }
30    }
31
32    let expanded = quote! {
33        #func
34
35        pub fn #wrapper_name(args: &::serde_json::Value) -> ::serde_json::Value {
36            #(let #arg_idents;)*
37            if let Some(arr) = args.as_array() {
38                #(#positional_stmts)*
39            } else if let Some(obj) = args.as_object() {
40                #(#named_stmts)*
41            } else {
42                panic!("expected JSON array or object");
43            }
44
45            let result = #name(#(#arg_idents),*);
46            ::serde_json::json!(result)
47        }
48    };
49
50    TokenStream::from(expanded)
51}
52
53fn extract_dual(
54    ident: &syn::Ident,
55    ty: &Box<Type>,
56    index: &syn::Index,
57    key: &str,
58) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
59    let ty_str = quote!(#ty).to_string().replace(' ', "");
60
61    let positional = if ty_str == "i32" {
62        quote! { #ident = arr[#index].as_i64().expect("expected i64") as i32; }
63    } else if ty_str == "f64" {
64        quote! { #ident = arr[#index].as_f64().expect("expected f64"); }
65    } else if ty_str == "bool" {
66        quote! { #ident = arr[#index].as_bool().expect("expected bool"); }
67    } else if ty_str == "String" {
68        quote! { #ident = arr[#index].as_str().expect("expected string").to_string(); }
69    } else {
70        // fallback to full deserialization for Option<T>, Vec<T>, struct
71        quote! {
72            #ident = ::serde::Deserialize::deserialize(&arr[#index]).expect("failed to deserialize positional");
73        }
74    };
75
76    let named = if ty_str == "i32" {
77        quote! { #ident = obj[#key].as_i64().expect("expected i64") as i32; }
78    } else if ty_str == "f64" {
79        quote! { #ident = obj[#key].as_f64().expect("expected f64"); }
80    } else if ty_str == "bool" {
81        quote! { #ident = obj[#key].as_bool().expect("expected bool"); }
82    } else if ty_str == "String" {
83        quote! { #ident = obj[#key].as_str().expect("expected string").to_string(); }
84    } else if ty_str.starts_with("Option<") {
85        quote! {
86            #ident = match obj.get(#key) {
87                Some(v) if !v.is_null() => Some(::serde::Deserialize::deserialize(v).expect("failed to parse Option")),
88                _ => None
89            };
90        }
91    } else {
92        quote! {
93            #ident = ::serde::Deserialize::deserialize(&obj[#key]).expect("failed to deserialize named param");
94        }
95    };
96
97    (positional, named)
98}