forward_dll/
lib.rs

//! forward-dll 是一个辅助构造转发 DLL 的库。
//!
//! # Example 1
//!
//! 在 `build.rs` 中添加如下代码:
//!
//! ```rust
//! use forward_dll::forward_dll;
//!
//! forward_dll("C:\\Windows\\System32\\version.dll").unwrap();
//! ```
//!
//! 这将会读取目标 `DLL` 的导出表,然后通过 `cargo:rustc-*` 指令链接到目标 DLL。这种方式可以连带 `ordinal` 一起转发。
//!
//! # Example 2
//!
//! 这种方式是在运行时动态加载目标 `DLL`,然后在导出的函数中,跳转到目标 `DLL` 的地址。
//!
//! ```rust
//! use forward_dll::ForwardModule;
//!
//! #[derive(ForwardModule)]
//! #[forward(target = "C:\\Windows\\system32\\version.dll")]
//! pub struct VersionModule;
//!
//! const VERSION_LIB: VersionModule = VersionModule;
//!
//! #[no_mangle]
//! pub extern "system" fn DllMain(_inst: isize, reason: u32, _: *const u8) -> u32 {
//!     if reason == 1 {
//!         println!("==> version.dll loaded");
//!         VERSION_LIB.init().unwrap();
//!         println!("==> version.dll initialized");
//!     }
//!     1
//! }
//! ```
38
39pub mod utils;
40
41use std::{collections::HashMap, ffi::NulError, path::PathBuf};
42
43use implib::{def::ModuleDef, Flavor, ImportLibrary, MachineType};
44use object::read::pe::{PeFile32, PeFile64};
45use utils::ForeignLibrary;
46
47pub use forward_dll_derive::ForwardModule;
48use windows_sys::Win32::Foundation::HMODULE;
49
50/// 由过程宏实现的 trait。
51pub trait ForwardModule {
52    /// 初始化转发相关的信息,如,加载目标 DLL 获取目标函数地址。
53    fn init(&self) -> ForwardResult<()>;
54}
55
/// Expands to the number of token trees passed in, as a `usize` constant expression.
///
/// Used by `forward_dll!` to size its fixed-length arrays.
#[doc(hidden)]
#[macro_export]
macro_rules! count {
    () => { 0usize };
    ($head:tt $($tail:tt)*) => { 1usize + $crate::count!($($tail)*) };
}
62
/// Generates the forwarding export functions together with a `static mut` [`DllForwarder`]
/// named by the second argument. Call its `forward_all` method from `DllMain` so that the
/// generated functions point at the forwarded target functions.
///
/// # Examples
///
/// ```rust
/// forward_dll::forward_dll!(
///   "C:\\Windows\\system32\\version.dll",
///   DLL_VERSION_FORWARDER,
///   GetFileVersionInfoA
///   GetFileVersionInfoByHandle
///   GetFileVersionInfoExA
///   GetFileVersionInfoExW
///   GetFileVersionInfoSizeA
///   GetFileVersionInfoSizeExA
///   GetFileVersionInfoSizeExW
///   GetFileVersionInfoSizeW
///   GetFileVersionInfoW
///   VerFindFileA
///   VerFindFileW
///   VerInstallFileA
///   VerInstallFileW
///   VerLanguageNameA
///   VerLanguageNameW
///   VerQueryValueA
///   VerQueryValueW
/// );
///
/// #[no_mangle]
/// pub extern "system" fn DllMain(_inst: isize, reason: u32, _: *const u8) -> u32 {
///   if reason == 1 {
///     // Keep our own handle to the underlying version.dll so it is not unloaded.
///     let _ = forward_dll::utils::load_library("C:\\Windows\\system32\\version.dll");
///     // Call forward_all to map each exported function to its target function.
///     let _ = unsafe { DLL_VERSION_FORWARDER.forward_all() };
///   }
///   1
/// }
/// ```
#[macro_export]
macro_rules! forward_dll {
    ($lib:expr, $forwarder:ident, $($func:ident)*) => {
        static mut $forwarder: $crate::DllForwarder<{ $crate::count!($($func)*) }> =
            $crate::DllForwarder {
                lib_name: $lib,
                initialized: false,
                module_handle: 0,
                target_function_names: [$(stringify!($func)),*],
                target_functions_address: [0; $crate::count!($($func)*)],
            };
        $crate::define_function!($lib, $forwarder, 0, $($func)*);
    };
}
119
/// Internal helper of `forward_dll!`: emits one exported stub function per entry.
///
/// Entries are either `name` (shorthand for `name = name`) or `export_name = proc`, where
/// `export_name` becomes the `#[no_mangle] extern "system"` symbol and `proc` is the name
/// looked up in `$lib` when no address has been cached yet. `$index` is the entry's slot in
/// `$name.target_functions_address`; each recursion step passes `($index + 1)` on to the
/// next entry.
///
/// NOTE(review): every bare-`name` entry takes two expansion steps (rewrite to
/// `name = name`, then emit), so very long lists may run into the compiler's macro
/// `recursion_limit`.
#[doc(hidden)]
#[macro_export]
macro_rules! define_function {
    // No entries left: end of recursion.
    ($lib:expr, $name:ident, $index:expr, ) => {};
    ($lib:expr, $name:ident, $index:expr, $export_name:ident = $proc:ident $($procs:tt)*) => {
        // The unnamed const keeps `default_jumper` / `exit_fn` local to this one stub, so
        // every entry can reuse the same helper names.
        const _: () = {
            // Returns the address the stub should jump to. A non-zero cached address (the
            // slot that `DllForwarder::forward_all` fills in) is used as-is; otherwise `$proc`
            // is looked up in `$lib`. If loading or lookup fails, the error is printed to
            // stderr and the address of `exit_fn` is returned instead.
            // NOTE(review): a lazily looked-up address is not written back into the table,
            // so each call repeats the lookup until `forward_all` has run.
            fn default_jumper(original_fn_addr: *const ()) -> usize {
                if original_fn_addr as usize != 0 {
                    return original_fn_addr as usize;
                }
                match $crate::utils::ForeignLibrary::new($lib) {
                    Ok(lib) => match lib.get_proc_address(std::stringify!($proc)) {
                        Ok(addr) => return addr as usize,
                        Err(err) => eprintln!("Error: {}", err)
                    }
                    Err(err) => eprintln!("Error: {}", err)
                }
                exit_fn as usize
            }

            // Fallback jump target: terminates the process with exit code 1.
            fn exit_fn() {
                std::process::exit(1);
            }

            // The exported stub. On x86/x86_64 the asm ends by jumping to the resolved
            // address, so control does not come back to this function body.
            #[no_mangle]
            pub extern "system" fn $export_name() -> u32 {
                // x86: push the cached address as `default_jumper`'s argument, call it, drop
                // the argument again, then jump to the address it returned in eax.
                // NOTE(review): `default_jumper` uses the Rust ABI, which does not guarantee
                // this stack-argument / eax-return convention; `options(nostack)` is declared
                // although the asm pushes; and any prologue of this stub is not undone before
                // the jump. Verify against the generated code (or make the helper
                // `extern "C"` and the stub a naked function).
                #[cfg(target_arch = "x86")]
                unsafe {
                    std::arch::asm!(
                        "push ecx",
                        "call eax",
                        "add esp, 4h",
                        "jmp eax",
                        in("eax") default_jumper,
                        in("ecx") $name.target_functions_address[$index],
                        options(nostack)
                    );
                }
                // x86_64: save rcx/rdx/r8/r9/r10/r11 (presumably to preserve the caller's
                // arguments), call `default_jumper` with the cached address in rcx while
                // 0x28 bytes of stack are reserved, restore the saved registers, then jump
                // to the address returned in rax.
                // NOTE(review): this is split across three `asm!` blocks. The compiler may
                // emit code between them and nothing declares that rax survives from the
                // second block into the third; the blocks push/pop despite
                // `options(nostack)`; and `default_jumper` is called as if it followed the
                // C ABI. This relies on unspecified codegen — a naked function or
                // `global_asm!` stub would make it well-defined.
                #[cfg(target_arch = "x86_64")]
                unsafe {
                    std::arch::asm!(
                        "push rcx",
                        "push rdx",
                        "push r8",
                        "push r9",
                        "push r10",
                        "push r11",
                        options(nostack)
                    );
                    std::arch::asm!(
                        "sub rsp, 28h",
                        "call rax",
                        "add rsp, 28h",
                        in("rax") default_jumper,
                        in("rcx") $name.target_functions_address[$index],
                        options(nostack)
                    );
                    std::arch::asm!(
                        "pop r11",
                        "pop r10",
                        "pop r9",
                        "pop r8",
                        "pop rdx",
                        "pop rcx",
                        "jmp rax",
                        options(nostack)
                    );
                }
                // Only reached on architectures other than x86/x86_64, where nothing is
                // forwarded.
                1
            }
        };
        $crate::define_function!($lib, $name, ($index + 1), $($procs)*);
    };
    // Bare `proc`: export it under its own name.
    ($lib:expr, $name:ident, $index:expr, $proc:ident $($procs:tt)*) => {
        $crate::define_function!($lib, $name, $index, $proc=$proc $($procs)*);
    };
}
197
/// Errors produced while setting up DLL forwarding.
#[derive(Debug)]
pub enum ForwardError {
    /// A Win32 API call failed. Holds the name of the API function that was called and the
    /// error code it reported.
    Win32Error(&'static str, u32),
    /// A string could not be converted because it contains an interior NUL byte.
    StringError(NulError),
    /// Initialization has already been performed and does not need to run again.
    AlreadyInitialized,
}

impl std::fmt::Display for ForwardError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ForwardError::Win32Error(func_name, err_code) => {
                write!(f, "Win32Error: {func_name} {err_code}")
            }
            ForwardError::StringError(err) => write!(f, "StringError: {err}"),
            ForwardError::AlreadyInitialized => f.write_str("AlreadyInitialized"),
        }
    }
}

impl std::error::Error for ForwardError {
    /// Exposes the wrapped [`NulError`] of [`ForwardError::StringError`] so error reporters
    /// can walk the cause chain; the other variants carry no underlying error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            ForwardError::StringError(err) => Some(err),
            ForwardError::Win32Error(..) | ForwardError::AlreadyInitialized => None,
        }
    }
}

/// Result type used by the forwarding APIs of this crate.
pub type ForwardResult<T> = std::result::Result<T, ForwardError>;
223
224/// DLL 转发类型的具体实现。该类型不要自己实例化,应调用 forward_dll 宏生成具体的实例。
225pub struct DllForwarder<const N: usize> {
226    pub initialized: bool,
227    pub module_handle: HMODULE,
228    pub target_functions_address: [usize; N],
229    pub target_function_names: [&'static str; N],
230    pub lib_name: &'static str,
231}
232
233impl<const N: usize> DllForwarder<N> {
234    /// 将所有函数的跳转地址设置为对应的 DLL 的同名函数地址。
235    pub fn forward_all(&mut self) -> ForwardResult<()> {
236        if self.initialized {
237            return Err(ForwardError::AlreadyInitialized);
238        }
239
240        let lib = ForeignLibrary::new(self.lib_name)?;
241        for index in 0..self.target_functions_address.len() {
242            let addr_in_remote_module = lib.get_proc_address(self.target_function_names[index])?;
243            self.target_functions_address[index] = addr_in_remote_module as *const usize as usize;
244        }
245
246        self.module_handle = lib.into_raw();
247        self.initialized = true;
248
249        Ok(())
250    }
251}
252
/// One entry of a DLL export table.
struct ExportItem {
    /// Exported name; `None` for exports that are only reachable by ordinal.
    name: Option<String>,
    /// Ordinal number of the export.
    ordinal: u32,
}
257
258/// 转发目标 `DLL` 的所有函数,同时会确保 `ordinal` 与目标函数一致。这个函数会读取目标 `DLL` 以获得导出函数信息,因此,要确保目标 `DLL` 在编译期存在。
259pub fn forward_dll(dll_path: &str) -> Result<(), String> {
260    forward_dll_with_dev_path(dll_path, dll_path)
261}
262
263/// 转发目标 `DLL` 的所有函数。与 `forward_dll` 类似,区别在于这个函数可以指定在编译时的目标 `DLL` 路径。
264pub fn forward_dll_with_dev_path(dll_path: &str, dev_dll_path: &str) -> Result<(), String> {
265    let exports = get_dll_export_names(dev_dll_path)?;
266    forward_dll_impl(dll_path, exports.as_slice())
267}
268
269/// 转发目标 `DLL` 的所有函数。与 `forward_dll` 类似,区别在于这个函数不要求在编译期存在 dll。
270pub fn forward_dll_with_exports(dll_path: &str, exports: &[(u32, &str)]) -> Result<(), String> {
271    forward_dll_impl(
272        dll_path,
273        exports
274            .iter()
275            .map(|(ord, name)| ExportItem {
276                ordinal: *ord,
277                name: Some(name.to_string()),
278            })
279            .collect::<Vec<_>>()
280            .as_slice(),
281    )
282}
283
284fn forward_dll_impl(dll_path: &str, exports: &[ExportItem]) -> Result<(), String> {
285    const SUFFIX: &str = ".dll";
286    let dll_path_without_ext = if dll_path.to_ascii_lowercase().ends_with(SUFFIX) {
287        &dll_path[..dll_path.len() - SUFFIX.len()]
288    } else {
289        dll_path
290    };
291
292    let out_dir = get_tmp_dir();
293
294    // 有些导出符号没有名称,在编译的过程中,临时取一个符号名。
295    let mut anonymous_map = HashMap::new();
296    let mut anonymous_name_id = 0;
297
298    // 输出链接参数,转发入口点到目标库。
299    for ExportItem { name, ordinal } in exports {
300        match name {
301            Some(name) => println!(
302                "cargo:rustc-link-arg=/EXPORT:{name}={dll_path_without_ext}.{name},@{ordinal}"
303            ),
304            None => {
305                anonymous_name_id += 1;
306                let fn_name = format!("forward_dll_anonymous_{anonymous_name_id}");
307                println!(
308                    "cargo:rustc-link-arg=/EXPORT:{fn_name}={dll_path_without_ext}.#{ordinal},@{ordinal},NONAME"
309                );
310                anonymous_map.insert(ordinal, fn_name);
311            }
312        };
313    }
314
315    // 构造 Import Library。
316    let exports_def = String::from("LIBRARY version\nEXPORTS\n")
317        + exports
318            .iter()
319            .map(|ExportItem { name, ordinal }| match name {
320                Some(name) => format!("  {name} @{ordinal}\n"),
321                None => {
322                    let fn_name = anonymous_map.get(ordinal).unwrap();
323                    format!("  {fn_name} @{ordinal} NONAME\n")
324                }
325            })
326            .collect::<String>()
327            .as_str();
328    #[cfg(target_arch = "x86_64")]
329    let machine = MachineType::AMD64;
330    #[cfg(target_arch = "x86")]
331    let machine = MachineType::I386;
332    let mut def = ModuleDef::parse(&exports_def, machine)
333        .map_err(|err| format!("ImportLibrary::new error: {err}"))?;
334    for item in def.exports.iter_mut() {
335        item.symbol_name = item.name.trim_start_matches('_').to_string();
336    }
337    let lib = ImportLibrary::from_def(def, machine, Flavor::Msvc);
338    let version_lib_path = out_dir.join("version_proxy.lib");
339    let mut lib_file = std::fs::OpenOptions::new()
340        .create(true)
341        .write(true)
342        .truncate(true)
343        .open(version_lib_path)
344        .map_err(|err| format!("OpenOptions::open error: {err}"))?;
345    lib.write_to(&mut lib_file)
346        .map_err(|err| format!("ImportLibrary::write_to error: {err}"))?;
347
348    println!("cargo:rustc-link-search={}", out_dir.display());
349    println!("cargo:rustc-link-lib=version_proxy");
350
351    Ok(())
352}
353
/// Directory in which the generated import library is stored.
///
/// Uses `OUT_DIR` when it is set, as it is for build scripts run by Cargo. In other contexts
/// such as doctests `OUT_DIR` is unavailable, so a `forward-dll-libs` directory under the
/// system temp directory is used instead and created if missing.
///
/// `var_os` is used so a non-UTF-8 `OUT_DIR` is honoured instead of being silently ignored.
///
/// # Panics
///
/// Panics if the fallback directory does not exist and cannot be created.
fn get_tmp_dir() -> PathBuf {
    if let Some(out_dir) = std::env::var_os("OUT_DIR") {
        return PathBuf::from(out_dir);
    }
    let dir = std::env::temp_dir().join("forward-dll-libs");
    if !dir.exists() {
        std::fs::create_dir_all(&dir).expect("Failed to create temp dir");
    }
    dir
}
366
367fn get_dll_export_names(dll_path: &str) -> Result<Vec<ExportItem>, String> {
368    let dll_file = std::fs::read(dll_path).map_err(|err| format!("Failed to read file: {err}"))?;
369    let in_data = dll_file.as_slice();
370
371    let kind = object::FileKind::parse(in_data).map_err(|err| format!("Invalid file: {err}"))?;
372    let exports = match kind {
373        object::FileKind::Pe32 => PeFile32::parse(in_data)
374            .map_err(|err| format!("Invalid pe file: {err}"))?
375            .export_table()
376            .map_err(|err| format!("Invalid pe file: {err}"))?
377            .ok_or_else(|| "No export table".to_string())?
378            .exports(),
379        object::FileKind::Pe64 => PeFile64::parse(in_data)
380            .map_err(|err| format!("Invalid pe file: {err}"))?
381            .export_table()
382            .map_err(|err| format!("Invalid pe file: {err}"))?
383            .ok_or_else(|| "No export table".to_string())?
384            .exports(),
385        _ => return Err("Invalid file".to_string()),
386    }
387    .map_err(|err| format!("Invalid file: {err}"))?;
388
389    let mut export_list = Vec::new();
390    for export_item in exports {
391        let ordinal = export_item.ordinal;
392        let name = export_item
393            .name
394            .map(String::from_utf8_lossy)
395            .map(String::from);
396        let item = ExportItem { name, ordinal };
397        export_list.push(item);
398    }
399    Ok(export_list)
400}