1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{FnArg, ItemFn, Pat, PatIdent, Type, parse_macro_input};
4
5#[proc_macro_attribute]
6pub fn funcall(_attr: TokenStream, item: TokenStream) -> TokenStream {
7 let func = parse_macro_input!(item as ItemFn);
8 let name = &func.sig.ident;
9 let wrapper_name = format_ident!("{}_tool", name);
10
11 let mut arg_idents = Vec::new();
12 let mut positional_stmts = Vec::new();
13 let mut named_stmts = Vec::new();
14
15 for (i, input) in func.sig.inputs.iter().enumerate() {
16 if let FnArg::Typed(pat_type) = input {
17 let ident = match &*pat_type.pat {
18 Pat::Ident(PatIdent { ident, .. }) => ident.clone(),
19 _ => panic!("Unsupported argument pattern"),
20 };
21 let ty = &pat_type.ty;
22 let index = syn::Index::from(i);
23 let key = ident.to_string();
24
25 let (pos_stmt, named_stmt) = extract_dual(&ident, ty, &index, &key);
26 arg_idents.push(ident);
27 positional_stmts.push(pos_stmt);
28 named_stmts.push(named_stmt);
29 }
30 }
31
32 let expanded = quote! {
33 #func
34
35 pub fn #wrapper_name(args: &::serde_json::Value) -> ::serde_json::Value {
36 #(let #arg_idents;)*
37 if let Some(arr) = args.as_array() {
38 #(#positional_stmts)*
39 } else if let Some(obj) = args.as_object() {
40 #(#named_stmts)*
41 } else {
42 panic!("expected JSON array or object");
43 }
44
45 let result = #name(#(#arg_idents),*);
46 ::serde_json::json!(result)
47 }
48 };
49
50 TokenStream::from(expanded)
51}
52
53fn extract_dual(
54 ident: &syn::Ident,
55 ty: &Box<Type>,
56 index: &syn::Index,
57 key: &str,
58) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
59 let ty_str = quote!(#ty).to_string().replace(' ', "");
60
61 let positional = if ty_str == "i32" {
62 quote! { #ident = arr[#index].as_i64().expect("expected i64") as i32; }
63 } else if ty_str == "f64" {
64 quote! { #ident = arr[#index].as_f64().expect("expected f64"); }
65 } else if ty_str == "bool" {
66 quote! { #ident = arr[#index].as_bool().expect("expected bool"); }
67 } else if ty_str == "String" {
68 quote! { #ident = arr[#index].as_str().expect("expected string").to_string(); }
69 } else {
70 quote! {
72 #ident = ::serde::Deserialize::deserialize(&arr[#index]).expect("failed to deserialize positional");
73 }
74 };
75
76 let named = if ty_str == "i32" {
77 quote! { #ident = obj[#key].as_i64().expect("expected i64") as i32; }
78 } else if ty_str == "f64" {
79 quote! { #ident = obj[#key].as_f64().expect("expected f64"); }
80 } else if ty_str == "bool" {
81 quote! { #ident = obj[#key].as_bool().expect("expected bool"); }
82 } else if ty_str == "String" {
83 quote! { #ident = obj[#key].as_str().expect("expected string").to_string(); }
84 } else if ty_str.starts_with("Option<") {
85 quote! {
86 #ident = match obj.get(#key) {
87 Some(v) if !v.is_null() => Some(::serde::Deserialize::deserialize(v).expect("failed to parse Option")),
88 _ => None
89 };
90 }
91 } else {
92 quote! {
93 #ident = ::serde::Deserialize::deserialize(&obj[#key]).expect("failed to deserialize named param");
94 }
95 };
96
97 (positional, named)
98}