1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
use object::Object;
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, LitStr};

#[proc_macro_derive(ForwardModule, attributes(forward))]
pub fn derive_forward_module(item: TokenStream) -> TokenStream {
    let input = parse_macro_input!(item as syn::DeriveInput);
    let forward_attr = input
        .attrs
        .iter()
        .find(|i| i.path().is_ident("forward"))
        .expect(r#"你需要添加 #[forward("path/of/target_dll.dll")]"#);
    let dll_path: LitStr = forward_attr.parse_args().expect(r#"#[forward()] 的参数应为一个字符串字面量,如 #[forward("C:\Windows\System32\version.dll")]"#);
    let struct_name = input.ident;

    let export_names = get_dll_export_names(dll_path.value().as_str());
    let export_idents: Vec<_> = export_names
        .iter()
        .map(|i| format_ident!("{}", i))
        .collect();
    let export_count = export_names.len();

    let impl_code = quote! {
        const _ : () = {
            extern crate forward_dll as _forward_dll;

            static mut _FORWARDER: _forward_dll::DllForwarder<#export_count> = _forward_dll::DllForwarder {
                initialized: false,
                module_handle: 0,
                lib_name: #dll_path,
                target_functions_address: [0; #export_count],
                target_function_names: [#(#export_names),*],
            };

            _forward_dll::define_function!(#dll_path, _FORWARDER, 0, #(#export_idents)*);

            impl _forward_dll::ForwardModule for #struct_name {
                fn init(&self) -> _forward_dll::ForwardResult<()> {
                    unsafe { _FORWARDER.forward_all() }
                }
            }
        };
    };
    impl_code.into()
}

/// Reads the object file at `dll_path` and returns the names of its exports,
/// in the order `object::Object::exports` reports them.
///
/// `dll_path` is passed straight to `std::fs::read`, so a relative path is
/// resolved against the current working directory of the process running the
/// macro (the compiler invocation).
///
/// Names are decoded with `String::from_utf8_lossy`, so any invalid UTF-8 bytes
/// become `U+FFFD`.
///
/// # Panics
///
/// Panics, with a message that names `dll_path` and the underlying error, if
/// the file cannot be read, cannot be parsed as an object file, or its export
/// table cannot be read.
fn get_dll_export_names(dll_path: &str) -> Vec<String> {
    let dll_file = std::fs::read(dll_path)
        .unwrap_or_else(|err| panic!("无法读取 DLL 文件 `{}`: {}", dll_path, err));
    let pe = object::File::parse(&*dll_file)
        .unwrap_or_else(|err| panic!("无法解析 DLL 文件 `{}`: {}", dll_path, err));
    let exports = pe
        .exports()
        .unwrap_or_else(|err| panic!("无法读取 DLL 文件 `{}` 的导出表: {}", dll_path, err));
    exports
        .iter()
        .map(|export| String::from_utf8_lossy(export.name()).into_owned())
        .collect()
}