1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{FnArg, ItemFn, Pat, PatIdent, Type, parse_macro_input};
4
5#[proc_macro_attribute]
6pub fn funcall(_attr: TokenStream, item: TokenStream) -> TokenStream {
7 let func = parse_macro_input!(item as ItemFn);
8 let name = &func.sig.ident;
9 let wrapper_name = format_ident!("{}_tool", name);
10
11 let mut arg_names = Vec::new();
12 let mut extract_stmts = Vec::new();
13
14 for (i, input) in func.sig.inputs.iter().enumerate() {
15 if let FnArg::Typed(pat_type) = input {
16 let ident = match &*pat_type.pat {
17 Pat::Ident(PatIdent { ident, .. }) => ident,
18 _ => panic!("Unsupported pattern"),
19 };
20
21 let (stmt, _) = extract_type_code(ident, &pat_type.ty, i);
22 arg_names.push(ident);
23 extract_stmts.push(stmt);
24 }
25 }
26
27 let arg_count = arg_names.len();
28
29 let expanded = quote! {
30 #func
31
32 pub fn #wrapper_name(args: &::serde_json::Value) -> ::serde_json::Value {
33 let args = args.as_array().expect("expected JSON array");
34 assert_eq!(args.len(), #arg_count, "wrong number of args");
35
36 #(#extract_stmts)*
37
38 let result = #name(#(#arg_names),*);
39 ::serde_json::json!(result)
40 }
41 };
42
43 TokenStream::from(expanded)
44}
45
46fn extract_type_code(
47 ident: &syn::Ident,
48 ty: &Box<Type>,
49 index: usize,
50) -> (proc_macro2::TokenStream, bool) {
51 let ty_str = quote!(#ty).to_string().replace(' ', "");
52
53 let stmt = if ty_str == "i32" {
54 quote! {
55 let #ident = args[#index].as_i64().expect("expected i64") as i32;
56 }
57 } else if ty_str == "f64" {
58 quote! {
59 let #ident = args[#index].as_f64().expect("expected f64");
60 }
61 } else if ty_str == "bool" {
62 quote! {
63 let #ident = args[#index].as_bool().expect("expected bool");
64 }
65 } else if ty_str == "String" {
66 quote! {
67 let #ident = args[#index].as_str().expect("expected string").to_string();
68 }
69 } else if ty_str.starts_with("Option<") {
70 quote! {
71 let #ident = match args.get(#index) {
72 Some(v) if !v.is_null() => Some(::serde::Deserialize::deserialize(v).expect("failed to parse Option")),
73 _ => None
74 };
75 }
76 } else if ty_str.starts_with("Vec<") {
77 quote! {
78 let #ident: #ty = ::serde::Deserialize::deserialize(&args[#index]).expect("failed to parse Vec");
79 }
80 } else {
81 quote! {
83 let #ident: #ty = ::serde::Deserialize::deserialize(&args[#index]).expect("failed to deserialize struct");
84 }
85 };
86
87 (stmt, true)
88}