forward_dll_derive/
lib.rs

1use object::read::pe::{PeFile32, PeFile64};
2use proc_macro::TokenStream;
3use quote::{format_ident, quote};
4use syn::{parse_macro_input, LitStr};
5
/// Diagnostic used when the deriving type has no `#[forward(...)]` attribute.
///
/// The text tells the user (in Chinese) that they need to add
/// `#[forward(target = "path/of/target_dll.dll")]`.
const FORWARD_ATTR_LACK_MESSAGE: &str = r#"你需要添加 #[forward(target = "path/of/target_dll.dll")]"#;

/// Diagnostic used when the arguments of `#[forward(...)]` cannot be understood.
///
/// The text tells the user (in Chinese) that the argument format is wrong and
/// shows a correctly formed example.
const FORWARD_ATTR_INVALID_MESSAGE: &str =
    r#"#[forward()] 的参数格式错误,正确格式如 #[forward(target = "C:\Windows\System32\version.dll")]"#;
9
10/// ForwardModule 派生宏。用于读取 DLL 的导出表,生成用于转发的导出函数。
11///
12/// # 使用方式:
13///
14/// ```rust,ignore
15/// use forward_dll::ForwardModule;
16///
17/// #[derive(ForwardModule)]
18/// #[forward(target = "C:\\Windows\\System32\\version.dll")]
19/// struct VersionModule;
20/// ```
21#[proc_macro_derive(ForwardModule, attributes(forward))]
22pub fn derive_forward_module(item: TokenStream) -> TokenStream {
23    let input = parse_macro_input!(item as syn::DeriveInput);
24    let forward_attr = input
25        .attrs
26        .iter()
27        .find(|i| i.path().is_ident("forward"))
28        .expect(FORWARD_ATTR_LACK_MESSAGE);
29
30    // 解析 #[forward(target = "")] 的参数。
31    let mut dll_path: Option<LitStr> = None;
32    forward_attr
33        .parse_nested_meta(|meta| {
34            let path = &meta.path;
35            if path.is_ident("target") {
36                let value = meta.value().expect(FORWARD_ATTR_INVALID_MESSAGE);
37                dll_path = Some(value.parse().expect(FORWARD_ATTR_INVALID_MESSAGE));
38            } else {
39                return Err(meta.error(FORWARD_ATTR_INVALID_MESSAGE));
40            }
41            Ok(())
42        })
43        .expect(FORWARD_ATTR_INVALID_MESSAGE);
44
45    let dll_path = dll_path.expect(FORWARD_ATTR_INVALID_MESSAGE);
46    let exports = get_dll_export_names(dll_path.value().as_str())
47        .expect("指定的 DLL 可能是一个无效的 PE 文件");
48
49    let export_names: Vec<_> = exports.iter().map(|(_, fn_name)| fn_name).collect();
50    let export_idents: Vec<_> = exports
51        .iter()
52        .map(|(_, fn_name)| format_ident!("{fn_name}"))
53        .collect();
54    let export_count = exports.len();
55    let struct_name = input.ident;
56    let impl_code = quote! {
57        const _ : () = {
58            extern crate forward_dll as _forward_dll;
59
60            static mut _FORWARDER: _forward_dll::DllForwarder<#export_count> = _forward_dll::DllForwarder {
61                initialized: false,
62                module_handle: 0,
63                lib_name: #dll_path,
64                target_functions_address: [0; #export_count],
65                target_function_names: [#(#export_names),*],
66            };
67
68            _forward_dll::define_function!(#dll_path, _FORWARDER, 0, #(#export_idents)*);
69
70            impl _forward_dll::ForwardModule for #struct_name {
71                fn init(&self) -> _forward_dll::ForwardResult<()> {
72                    unsafe { _FORWARDER.forward_all() }
73                }
74            }
75        };
76    };
77    impl_code.into()
78}
79
80fn get_dll_export_names(dll_path: &str) -> Result<Vec<(u32, String)>, String> {
81    let dll_file = std::fs::read(dll_path).map_err(|err| format!("Failed to read file: {err}"))?;
82    let in_data = dll_file.as_slice();
83
84    let kind = object::FileKind::parse(in_data).map_err(|err| format!("Invalid file: {err}"))?;
85    let exports = match kind {
86        object::FileKind::Pe32 => PeFile32::parse(in_data)
87            .map_err(|err| format!("Invalid pe file: {err}"))?
88            .export_table()
89            .map_err(|err| format!("Invalid pe file: {err}"))?
90            .ok_or_else(|| "No export table".to_string())?
91            .exports(),
92        object::FileKind::Pe64 => PeFile64::parse(in_data)
93            .map_err(|err| format!("Invalid pe file: {err}"))?
94            .export_table()
95            .map_err(|err| format!("Invalid pe file: {err}"))?
96            .ok_or_else(|| "No export table".to_string())?
97            .exports(),
98        _ => return Err("Invalid file".to_string()),
99    }
100    .map_err(|err| format!("Invalid file: {err}"))?;
101
102    let mut names = Vec::new();
103    for export_item in exports {
104        names.push((
105            export_item.ordinal,
106            export_item
107                .name
108                .map(String::from_utf8_lossy)
109                .map(String::from)
110                .unwrap_or_default(),
111        ));
112    }
113    Ok(names)
114}