1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::parse_macro_input;
4
5#[proc_macro_attribute]
6pub fn main(_attr: TokenStream, input: TokenStream) -> TokenStream {
7 let mut inner_fn = parse_macro_input!(input as syn::ItemFn);
8 let wrapper_fn = inner_fn.clone();
9 let wrapper_sig = &wrapper_fn.sig;
10 inner_fn.sig.ident = format_ident!("_{}", wrapper_sig.ident);
11 let inner_ident = &inner_fn.sig.ident;
12 let args = wrapper_sig.inputs.iter().map(|arg| match arg {
13 syn::FnArg::Receiver(_) => unimplemented!(),
14 syn::FnArg::Typed(typed) => match typed.pat.as_ref() {
15 syn::Pat::Ident(ident) => ident.ident.clone(),
16 _ => unimplemented!(),
17 },
18 });
19
20 let decoration = quote! {
21 #wrapper_sig {
22 extern crate elevated as _elevated;
23
24 #inner_fn
25
26 _elevated::execute_elevation_and_tasks();
27
28 #inner_ident(#(#args),*)
29 }
30 };
31
32 decoration.into()
33}
34
35#[proc_macro_attribute]
36pub fn elevated(_attr: TokenStream, input: TokenStream) -> TokenStream {
37 let input = parse_macro_input!(input as syn::ItemFn);
38 let fn_name = input.sig.ident.to_string();
39 let sig = &input.sig;
40 let args = input
41 .sig
42 .inputs
43 .iter()
44 .map(|arg| match arg {
45 syn::FnArg::Receiver(_) => format_ident!("self"),
46 syn::FnArg::Typed(typed) => match typed.pat.as_ref() {
47 syn::Pat::Ident(ident) => ident.ident.clone(),
48 _ => todo!(),
49 },
50 })
51 .collect::<Vec<_>>();
52
53 let mut inner = input.clone();
54 inner.sig.ident = format_ident!("_{}", fn_name);
55 let inner_name = &inner.sig.ident;
56
57 let arg_types = input.sig.inputs.iter().map(|arg| match arg {
58 syn::FnArg::Receiver(_) => unimplemented!(),
59 syn::FnArg::Typed(typed) => &typed.ty,
60 });
61
62 let decoration = quote! {
63 #sig {
64 extern crate elevated as _elevated;
65
66 #inner
67
68 fn caller(arg: String) -> String {
69 let (#(#args,)*): (#(#arg_types,)*) = _elevated::serde_json::from_str(&arg).unwrap();
70 let ret = #inner_name(#(#args),*);
71 _elevated::serde_json::to_string(&ret).unwrap()
72 }
73
74 let id: &str = #fn_name;
75
76 if _elevated::is_elevated() {
77 return #inner_name(#(#args),*);
78 }
79
80 let args = _elevated::serde_json::to_string(&(#(#args,)*)).unwrap();
81 _elevated::spawn_task(caller, args).unwrap()
82 }
83 };
84
85 decoration.into()
86}