1use std::sync::LazyLock;
2
3use derive_syn_parse::Parse;
4use quote::{format_ident, quote};
5use syn::parse_macro_input;
6
/// Path prefix used by generated code to reach this library's items.
///
/// When the crate currently being compiled is `jslt` itself, its own items are
/// only reachable through `crate`, so that prefix is used. Every other crate
/// (including when `CARGO_PKG_NAME` is unset or not valid Unicode) refers to
/// the library by its name, `jslt`.
///
/// NOTE(review): the value is computed once and cached for the life of the
/// process; a host that expands macros for several crates in one process would
/// keep seeing the first crate's answer — confirm this is acceptable.
static CRATE_NAME: LazyLock<&str> = LazyLock::new(|| {
    match std::env::var("CARGO_PKG_NAME").as_deref() {
        Ok("jslt") => "crate",
        _ => "jslt",
    }
});
14
15#[proc_macro_attribute]
16pub fn static_function(
17 _attr: proc_macro::TokenStream,
18 item: proc_macro::TokenStream,
19) -> proc_macro::TokenStream {
20 let mut item = parse_macro_input!(item as syn::ItemFn);
21
22 let mut vis = syn::Visibility::Inherited;
23 std::mem::swap(&mut item.vis, &mut vis);
24
25 let mut ident = format_ident!("_wrapped_implementation");
27 std::mem::swap(&mut item.sig.ident, &mut ident);
28
29 let arguments_ident = format_ident!("arguments");
30
31 let arguments = item.sig.inputs.iter().enumerate().map(|(index, item)| {
32 let required = !matches!(item, syn::FnArg::Typed(syn::PatType { ty, .. }) if matches!(**ty, syn::Type::Path(_)));
33 let required = required.then(|| quote! { .unwrap_or(&serde_json::Value::Null) });
34
35 quote! {
36 #arguments_ident .get(#index) #required
37 }
38 });
39
40 let jslt = format_ident!("{}", *CRATE_NAME);
41
42 quote! {
43 #vis fn #ident(#arguments_ident: &[serde_json::Value]) -> Result<serde_json::Value, #jslt::error::JsltError> {
44 #item
45
46 _wrapped_implementation( #(#arguments,)* )
47 }
48 }.into()
49}
50
51#[derive(Parse, Debug)]
52struct ExpectInnerArgs {
53 pairs: syn::Ident,
54 _period_token: syn::Token![,],
55 rule: syn::Type,
56}
57
58#[proc_macro]
59pub fn expect_inner(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
60 let ExpectInnerArgs { pairs, rule, .. } = parse_macro_input!(item as ExpectInnerArgs);
61
62 let jslt = format_ident!("{}", *CRATE_NAME);
63
64 quote! {
65 {
66 let Some(pair) = #pairs.next() else {
67 return Err(#jslt::error::JsltError::UnexpectedEnd);
68 };
69
70 let rule = pair.as_rule();
71
72 if !matches!(rule, #rule) {
73 return Err(#jslt::error::JsltError::UnexpectedInput(
74 #rule,
75 rule,
76 pair.as_str().to_owned(),
77 ));
78 }
79
80 Ok::<Pairs<_>, #jslt::error::JsltError>(pair.into_inner())
81 }
82 }
83 .into()
84}