funcall_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{FnArg, ItemFn, Pat, PatIdent, Type, parse_macro_input};
4
5#[proc_macro_attribute]
6pub fn funcall(_attr: TokenStream, item: TokenStream) -> TokenStream {
7    let func = parse_macro_input!(item as ItemFn);
8    let name = &func.sig.ident;
9    let wrapper_name = format_ident!("{}_tool", name);
10
11    let mut arg_names = Vec::new();
12    let mut extract_stmts = Vec::new();
13
14    for (i, input) in func.sig.inputs.iter().enumerate() {
15        if let FnArg::Typed(pat_type) = input {
16            let ident = match &*pat_type.pat {
17                Pat::Ident(PatIdent { ident, .. }) => ident,
18                _ => panic!("Unsupported pattern"),
19            };
20
21            let (stmt, _) = extract_type_code(ident, &pat_type.ty, i);
22            arg_names.push(ident);
23            extract_stmts.push(stmt);
24        }
25    }
26
27    let arg_count = arg_names.len();
28
29    let expanded = quote! {
30        #func
31
32        pub fn #wrapper_name(args: &::serde_json::Value) -> ::serde_json::Value {
33            let args = args.as_array().expect("expected JSON array");
34            assert_eq!(args.len(), #arg_count, "wrong number of args");
35
36            #(#extract_stmts)*
37
38            let result = #name(#(#arg_names),*);
39            ::serde_json::json!(result)
40        }
41    };
42
43    TokenStream::from(expanded)
44}
45
46fn extract_type_code(
47    ident: &syn::Ident,
48    ty: &Box<Type>,
49    index: usize,
50) -> (proc_macro2::TokenStream, bool) {
51    let ty_str = quote!(#ty).to_string().replace(' ', "");
52
53    let stmt = if ty_str == "i32" {
54        quote! {
55            let #ident = args[#index].as_i64().expect("expected i64") as i32;
56        }
57    } else if ty_str == "f64" {
58        quote! {
59            let #ident = args[#index].as_f64().expect("expected f64");
60        }
61    } else if ty_str == "bool" {
62        quote! {
63            let #ident = args[#index].as_bool().expect("expected bool");
64        }
65    } else if ty_str == "String" {
66        quote! {
67            let #ident = args[#index].as_str().expect("expected string").to_string();
68        }
69    } else if ty_str.starts_with("Option<") {
70        quote! {
71            let #ident = match args.get(#index) {
72                Some(v) if !v.is_null() => Some(::serde::Deserialize::deserialize(v).expect("failed to parse Option")),
73                _ => None
74            };
75        }
76    } else if ty_str.starts_with("Vec<") {
77        quote! {
78            let #ident: #ty = ::serde::Deserialize::deserialize(&args[#index]).expect("failed to parse Vec");
79        }
80    } else {
81        // Assume any other type implements Deserialize
82        quote! {
83            let #ident: #ty = ::serde::Deserialize::deserialize(&args[#index]).expect("failed to deserialize struct");
84        }
85    };
86
87    (stmt, true)
88}