// elevate_code_derive/lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::parse_macro_input;
4
5#[proc_macro_attribute]
6pub fn elevate_code(_attr: TokenStream, input: TokenStream) -> TokenStream {
7    let input = parse_macro_input!(input as syn::ItemFn);
8    let mut inner = input.clone();
9    let fn_ident = &input.sig.ident;
10    let fn_name = input.sig.ident.to_string();
11    let sig = &input.sig;
12    let args = input
13        .sig
14        .inputs
15        .iter()
16        .map(|arg| match arg {
17            syn::FnArg::Receiver(_) => format_ident!("self"),
18            syn::FnArg::Typed(typed) => match typed.pat.as_ref() {
19                syn::Pat::Ident(ident) => ident.ident.clone(),
20                _ => todo!(),
21            },
22        })
23        .collect::<Vec<_>>();
24    let type_list = input
25        .sig
26        .inputs
27        .iter()
28        .map(|arg| match arg {
29            syn::FnArg::Receiver(_) => todo!(),
30            syn::FnArg::Typed(typed) => {
31                let ty = &typed.ty;
32                quote!(#ty)
33            }
34        })
35        .collect::<Vec<_>>();
36
37    inner.sig.ident = format_ident!("_{}", fn_name);
38    let inner_name = &inner.sig.ident;
39
40    let (call, serialization) = if type_list.is_empty() {
41        (
42            quote! {
43                #fn_ident();
44                std::process::exit(0);
45            },
46            quote! {
47                Some(String::new())
48            },
49        )
50    } else {
51        (
52            quote! {
53                let (#(#args),*,) : (#(#type_list),*,) = _elevate_code::serde_json::from_str(&payload).map_err(|err| format!("{err}")).unwrap();
54                #fn_ident(
55                    #(#args),*
56                );
57                std::process::exit(0);
58            },
59            quote! {
60                _elevate_code::serde_json::to_string(
61                    &(#(&#args),*,)
62                ).map_err(|err| format!("{err}")).ok()
63            },
64        )
65    };
66
67    let decoration = quote! {
68        const _: () = {
69            extern crate elevate_code as _elevate_code;
70
71            struct T;
72
73            #[_elevate_code::ctor::ctor]
74            fn init() {
75                let id: &str = #fn_name;
76
77                let cmd_line = _elevate_code::ElevateToken::from_command_line();
78
79                match _elevate_code::ElevateToken::from_command_line() {
80                    Some(_elevate_code::ElevateToken::Execute { task_id, payload }) if id == task_id => {
81                        #call
82                    }
83                    _ => {},
84                }
85            }
86        };
87
88        #sig {
89            extern crate elevate_code as _elevate_code;
90
91            #inner
92
93            let id: &str = #fn_name;
94
95            if _elevate_code::is_elevated() {
96                return #inner_name(#(#args),*);
97            }
98
99            if let Some(json) = #serialization {
100                let token = _elevate_code::ElevateToken::Execute {
101                    task_id: id.to_string(),
102                    payload: json,
103                };
104                _elevate_code::create_process(&[&token.to_string()], |pid| {
105                    match _elevate_code::GLOBAL_CLIENT.request(_elevate_code::ElevationRequest::new(pid)) {
106                        Ok(_) => _elevate_code::ProcessControlFlow::ResumeMainThread,
107                        Err(err) => _elevate_code::ProcessControlFlow::Terminate,
108                    }
109                });
110            } else {
111                panic!("Error on serializing arguments")
112            }
113        }
114    };
115
116    decoration.into()
117}