proxygen_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{FnArg, ItemFn};
4
5const GET_ARG_TYPES: fn(&FnArg) -> syn::Ident = |arg: &FnArg| match arg {
6    FnArg::Receiver(_) => panic!("Cannot use receivers (self) with proxy functions"),
7    FnArg::Typed(arg) => {
8        if let syn::Type::Path(syn::TypePath {
9            path: syn::Path { segments, .. },
10            ..
11        }) = arg.ty.as_ref()
12        {
13            return segments.first().unwrap().ident.clone();
14        }
15        panic!("Unsupported function signature");
16    }
17};
18
19const GET_ARG_NAMES: fn(&FnArg) -> syn::Ident = |arg: &FnArg| match arg {
20    FnArg::Receiver(_) => panic!("Cannot use receivers (self) with proxy functions"),
21    FnArg::Typed(arg) => {
22        let syn::PatType { pat, .. } = &arg;
23        let pat = pat.clone();
24        match *pat {
25            syn::Pat::Ident(ident) => ident.ident,
26            _ => panic!("Unexpected arg name: {:?}", pat),
27        }
28    }
29};
30
/// Whether the macro user declared the proxied function's signature as known.
///
/// Selected through the `sig = "known"` / `sig = "unknown"` attribute argument.
#[derive(Clone, Copy, Debug)]
enum ProxySignatureType {
    /// Chosen by `sig = "known"`.
    Known,
    /// Chosen by `sig = "unknown"`.
    Unknown,
}
36
37impl From<syn::Meta> for ProxySignatureType {
38    fn from(meta: syn::Meta) -> Self {
39        match meta {
40            syn::Meta::Path(_) => panic!("Unsupported attribute inputs"),
41            syn::Meta::List(_) => panic!("Unsupported attribute inputs"),
42            syn::Meta::NameValue(sig) => {
43                if let Some(ident) = sig.path.get_ident() {
44                    if ident.to_string() != "sig" {
45                        panic!("Expected sig=\"unknown\" or sig=\"known\"")
46                    }
47                    if let syn::Expr::Lit(syn::ExprLit {
48                        lit: syn::Lit::Str(token),
49                        ..
50                    }) = sig.value
51                    {
52                        match token.value().as_str() {
53                            "known" => ProxySignatureType::Known,
54                            "unknown" => ProxySignatureType::Unknown,
55                            _ => panic!("Expected sig=\"unknown\" or sig=\"known\""),
56                        }
57                    } else {
58                        panic!("Expected sig=\"unknown\" or sig=\"known\"")
59                    }
60                } else {
61                    panic!("Expected sig=\"unknown\" or sig=\"known\"")
62                }
63            }
64        }
65    }
66}
67
68// Proc macro to forward a function call to the orginal function
69//
70// Note: You may not have any instructions in the function body when forwarding function calls
71#[proc_macro_attribute]
72pub fn forward(_attr_input: TokenStream, item: TokenStream) -> TokenStream {
73    let input: ItemFn = syn::parse(item).expect("You may only proxy a function");
74    let func_name = input.sig.clone().ident;
75    let func_body = input.block.stmts.clone();
76    let ret_type = input.sig.output.clone();
77    let orig_index_ident =
78        syn::parse_str::<syn::Path>(&format!("crate::export_indices::Index_{}", &func_name))
79            .unwrap();
80    let arg_types = input.sig.inputs.iter().map(GET_ARG_TYPES);
81    let attrs = input
82        .attrs
83        .into_iter()
84        .filter(|attr| !attr.path().is_ident("proxy"));
85
86    if arg_types.len() > 0 {
87        panic!("You may not specifiy arguments in a forwarding proxy");
88    }
89    match ret_type.clone() {
90        syn::ReturnType::Default => {}
91        syn::ReturnType::Type(_, ty) => match *ty {
92            syn::Type::Path(ref p) => {
93                if !p.path.is_ident("()") {
94                    panic!("You may not specify a return type when forwarding a function call");
95                }
96            }
97            syn::Type::Tuple(ref t) => {
98                if !t.elems.is_empty() {
99                    panic!("You may not specify a return type when forwarding a function call");
100                }
101            }
102            _ => panic!("You may not specify a return type when forwarding a function call"),
103        },
104    };
105    if func_body.len() > 0 {
106        panic!("Your function body will not get run in a forwarding proxy. Perhaps you meant to use a `pre_hook`?");
107    }
108
109    TokenStream::from(quote!(
110        #[naked]
111        #(#attrs)*
112        pub unsafe extern "C" fn #func_name() {
113            #[cfg(target_arch = "x86_64")]
114            {
115                std::arch::asm!(
116                    "call {wait_dll_proxy_init}",
117                    "mov rax, qword ptr [rip + {ORIG_FUNCS_PTR}]",
118                    "add rax, {orig_index} * 8",
119                    "mov rax, qword ptr [rax]",
120                    "push rax",
121                    "ret",
122                    wait_dll_proxy_init = sym crate::wait_dll_proxy_init,
123                    ORIG_FUNCS_PTR = sym crate::ORIG_FUNCS_PTR,
124                    orig_index = const #orig_index_ident,
125                    options(noreturn)
126                )
127            }
128
129            #[cfg(target_arch = "x86")]
130            {
131                std::arch::asm!(
132                    "call {wait_dll_proxy_init}",
133                    "mov eax, dword ptr [{ORIG_FUNCS_PTR}]",
134                    "add eax, {orig_index} * 4",
135                    "mov eax, dword ptr [eax]",
136                    "push eax",
137                    "ret",
138                    wait_dll_proxy_init = sym crate::wait_dll_proxy_init,
139                    ORIG_FUNCS_PTR = sym crate::ORIG_FUNCS_PTR,
140                    orig_index = const #orig_index_ident,
141                    options(noreturn)
142                )
143            }
144        }
145    ))
146}
147
148// Proc macro to bring the original function into the scope of an interceptor function as `orig_func`
149#[proc_macro_attribute]
150pub fn proxy(attr_input: TokenStream, item: TokenStream) -> TokenStream {
151    let input: ItemFn = syn::parse(item).expect("You may only proxy a function");
152    let attr_input = syn::parse::<syn::Meta>(attr_input);
153    let func_name = input.sig.clone().ident;
154    let func_sig = input.sig.clone();
155    let func_body = input.block.stmts.clone();
156    let ret_type = input.sig.output.clone();
157    let orig_index_ident =
158        syn::parse_str::<syn::Path>(&format!("crate::export_indices::Index_{}", &func_name))
159            .unwrap();
160    let arg_types = input.sig.inputs.iter().map(GET_ARG_TYPES);
161    let attrs = input
162        .attrs
163        .into_iter()
164        .filter(|attr| !attr.path().is_ident("proxy"));
165    let sig_type: ProxySignatureType = match attr_input {
166        Ok(attr_input) => attr_input.into(),
167        Err(_) => panic!("Please explictly set sig=\"known\" or sig=\"unknown\". Eg. #[post_hook(sig = \"known\")]"),
168    };
169
170    match sig_type {
171        ProxySignatureType::Known => {
172            TokenStream::from(quote!(
173                #(#attrs)*
174                #func_sig {
175                    crate::wait_dll_proxy_init();
176                    let orig_func: fn (#(#arg_types,)*) #ret_type = unsafe { std::mem::transmute(crate::ORIGINAL_FUNCS[#orig_index_ident]) };
177                    #(#func_body)*
178                }
179            ))
180        },
181        ProxySignatureType::Unknown => panic!("You may not manual-proxy a function with an unknown signature (only pre-hooking is supported)"),
182    }
183}
184
185/// Proc macro that indicates that any code in this function will be run just before the original function is called.
186//
187// You should explicitly set `sig="known"` if you know the function signature
188//
189/// The function being proxied can be accessed as `orig_func`
190///
191/// Note: Returning in this function will skip running the original.
192#[proc_macro_attribute]
193pub fn pre_hook(attr_input: TokenStream, item: TokenStream) -> TokenStream {
194    let input: ItemFn = syn::parse(item).expect("You may only proxy a function");
195    let attr_input = syn::parse::<syn::Meta>(attr_input);
196    let func_name = input.sig.ident.clone();
197    let func_sig = input.sig.clone();
198    let func_body = input.block.stmts.clone();
199    let ret_type = input.sig.output.clone();
200    let orig_index_ident =
201        syn::parse_str::<syn::Path>(&format!("crate::export_indices::Index_{}", &func_name))
202            .unwrap();
203    let arg_names = input.sig.inputs.iter().map(GET_ARG_NAMES);
204    let arg_types = input.sig.inputs.iter().map(GET_ARG_TYPES);
205    let attrs = input
206        .attrs
207        .into_iter()
208        .filter(|attr| !attr.path().is_ident("pre_hook"));
209    let sig_type: ProxySignatureType = match attr_input {
210            Ok(attr_input) => attr_input.into(),
211            Err(_) => panic!("Please explictly set sig=\"known\" or sig=\"unknown\". Eg. #[post_hook(sig = \"known\")]"),
212        };
213
214    match sig_type {
215        ProxySignatureType::Known => TokenStream::from(quote!(
216            #(#attrs)*
217            #func_sig {
218                let orig_func: fn (#(#arg_types,)*) #ret_type = unsafe { std::mem::transmute(crate::ORIGINAL_FUNCS[#orig_index_ident]) };
219                #(#func_body)*
220                orig_func(#(#arg_names,)*)
221            }
222        )),
223        ProxySignatureType::Unknown => {
224            if arg_names.clone().len() != 0 {
225                panic!("You may not specifiy any arguments when proxying a function with an unknown signature");
226            }
227            match ret_type.clone() {
228                syn::ReturnType::Default => {},
229                syn::ReturnType::Type(_, ty) => {
230                    match *ty {
231                        syn::Type::Path(ref p) => if !p.path.is_ident("()") {
232                            panic!("You may not specify a return type when proxying a function with an unknown signature");
233                        },
234                        syn::Type::Tuple(ref t) => if !t.elems.is_empty() {
235                            panic!("You may not specify a return type when proxying a function with an unknown signature");
236                        },
237                        _ => panic!("You may not specify a return type when proxying a function with an unknown signature")
238                    }
239                }
240            };
241            let hook_func_name =
242                syn::parse_str::<syn::Ident>(&format!("Proxygen_PreHook_{}", &func_name)).unwrap();
243            TokenStream::from(quote!(
244                #[cfg(not(target_arch = "x86_64"))]
245                compile_error!("Pre-hooks aren't yet implemented for non x86-64");
246
247                #[no_mangle]
248                // TODO: Use the same safety/unsafety modifier as the original here
249                pub unsafe extern "C" fn #hook_func_name() {
250                    let orig_func: fn () = std::mem::transmute(crate::ORIGINAL_FUNCS[#orig_index_ident]);
251                    #(#func_body)*
252                }
253
254                #[naked]
255                #(#attrs)*
256                pub unsafe extern "C" fn #func_name() {
257                    std::arch::asm!(
258                        // Wait for dll proxy to initialize
259                        "call {wait_dll_proxy_init}",
260                        "mov rax, qword ptr [rip + {ORIG_FUNCS_PTR}]",
261                        "add rax, {orig_index} * 8",
262                        "mov rax, qword ptr [rax]",
263
264                        // Push the original function onto the stack
265                        "push rax",
266
267                        // Save the general purpose registers
268                        "push rdi; push rsi; push rcx; push rdx; push r8; push r9",
269
270                        // Save the 128-bit floating point registers
271                        "sub rsp, 64",
272                        "movaps [rsp], xmm0",
273                        "movaps [rsp + 16], xmm1",
274                        "movaps [rsp + 32], xmm2",
275                        "movaps [rsp + 48], xmm3",
276
277                        // Call our hook code here
278                        "call {proxygen_pre_hook_func}",
279
280                        // Restore the 128-bit floating point registers
281                        "movaps xmm3, [rsp + 48]",
282                        "movaps xmm2, [rsp + 32]",
283                        "movaps xmm1, [rsp + 16]",
284                        "movaps xmm0, [rsp]",
285                        "add rsp, 64",
286
287                        // Restore the general purpose registers
288                        "pop r9; pop r8; pop rdx; pop rcx; pop rsi; pop rdi",
289
290                        // Return to the original function
291                        "ret",
292                        wait_dll_proxy_init = sym crate::wait_dll_proxy_init,
293                        ORIG_FUNCS_PTR = sym crate::ORIG_FUNCS_PTR,
294                        orig_index = const #orig_index_ident,
295                        proxygen_pre_hook_func = sym #hook_func_name,
296                        options(noreturn)
297                    );
298                }
299            ))
300        }
301    }
302}
303
304/// Proc macro that indicates that any code in this function will be run after the original function is called.
305///
306/// The result of calling the original function will be accessible in `orig_result`.
307///
308/// Note: `orig_result` will be returned unless you choose to return your own result from this function.
309#[proc_macro_attribute]
310pub fn post_hook(attr_input: TokenStream, item: TokenStream) -> TokenStream {
311    let input: ItemFn = syn::parse(item).expect("You may only proxy a function");
312    let attr_input = syn::parse::<syn::Meta>(attr_input);
313    let func_name = input.sig.clone().ident;
314    let func_sig = input.sig.clone();
315    let func_body = input.block.stmts.clone();
316    let ret_type = input.sig.output.clone();
317    let orig_index_ident =
318        syn::parse_str::<syn::Path>(&format!("crate::export_indices::Index_{}", &func_name))
319            .unwrap();
320    let arg_names = input.sig.inputs.iter().map(GET_ARG_NAMES);
321    let arg_types = input.sig.inputs.iter().map(GET_ARG_TYPES);
322    let attrs = input
323        .attrs
324        .into_iter()
325        .filter(|attr| !attr.path().is_ident("post_hook"));
326    let sig_type: ProxySignatureType = match attr_input {
327        Ok(attr_input) => attr_input.into(),
328        Err(_) => panic!("Please explictly set sig=\"known\" or sig=\"unknown\". Eg. #[post_hook(sig = \"known\")]"),
329    };
330
331    match sig_type {
332        ProxySignatureType::Known => TokenStream::from(quote!(
333            #(#attrs)*
334            #func_sig {
335                crate::wait_dll_proxy_init();
336                let orig_func: fn (#(#arg_types,)*) #ret_type = unsafe { std::mem::transmute(crate::ORIGINAL_FUNCS[#orig_index_ident]) };
337                let orig_result = orig_func(#(#arg_names,)*);
338                #(#func_body)*
339                orig_result
340            }
341        )),
342        ProxySignatureType::Unknown => {
343            panic!("You may not post-hook a function with an unknown signature (only pre-hooking is supported)");
344        }
345    }
346}