1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
//! forward-dll 是一个辅助构造转发 DLL 的库。
//!
//! # Example 1
//!
//! 在 `build.rs` 中添加如下代码:
//!
//! ```rust
//! use forward_dll::forward_dll;
//!
//! forward_dll("C:\\Windows\\System32\\version.dll").unwrap();
//! ```
//!
//! 这将会读取目标 `DLL` 的导出表,然后使用 `cargo:rustc-*` 输出来链接到目标 DLL。这种方式可以连带 `ordinal` 一起转发。
//!
//! # Example 2
//!
//! 这种方式是在运行时动态加载目标 `DLL`,然后在导出的函数中,跳转到目标 `DLL` 的地址。
//!
//! ```rust
//! use forward_dll::ForwardModule;
//!
//! #[derive(ForwardModule)]
//! #[forward(target = "C:\\Windows\\system32\\version.dll")]
//! pub struct VersionModule;
//!
//! const VERSION_LIB: VersionModule = VersionModule;
//!
//! #[no_mangle]
//! pub extern "system" fn DllMain(_inst: isize, reason: u32, _: *const u8) -> u32 {
//!     if reason == 1 {
//!         println!("==> version.dll loaded");
//!         VERSION_LIB.init().unwrap();
//!         println!("==> version.dll initialized");
//!     }
//!     1
//! }
//! ```

pub mod utils;

use std::{ffi::NulError, path::PathBuf};

use implib::{def::ModuleDef, Flavor, ImportLibrary, MachineType};
use object::read::pe::{PeFile32, PeFile64};
use utils::ForeignLibrary;

pub use forward_dll_derive::ForwardModule;
use windows_sys::Win32::Foundation::HMODULE;

/// 由过程宏实现的 trait。
pub trait ForwardModule {
    /// 初始化转发相关的信息,如,加载目标 DLL 获取目标函数地址。
    fn init(&self) -> ForwardResult<()>;
}

#[doc(hidden)]
#[macro_export]
macro_rules! count {
    () => (0usize);
    ( $x:tt $($xs:tt)* ) => (1usize + $crate::count!($($xs)*));
}

/// 生成转发的导出函数,以及初始化方法,须在 DllMain 中调用初始化方法,以使生成的函数指向转发的目标函数。
///
/// # Examples
///
/// ```rust
/// forward_dll::forward_dll!(
///   "C:\\Windows\\system32\\version.dll",
///   DLL_VERSION_FORWARDER,
///   GetFileVersionInfoA
///   GetFileVersionInfoByHandle
///   GetFileVersionInfoExA
///   GetFileVersionInfoExW
///   GetFileVersionInfoSizeA
///   GetFileVersionInfoSizeExA
///   GetFileVersionInfoSizeExW
///   GetFileVersionInfoSizeW
///   GetFileVersionInfoW
///   VerFindFileA
///   VerFindFileW
///   VerInstallFileA
///   VerInstallFileW
///   VerLanguageNameA
///   VerLanguageNameW
///   VerQueryValueA
///   VerQueryValueW
/// );
///
/// #[no_mangle]
/// pub extern "system" fn DllMain(_inst: isize, reason: u32, _: *const u8) -> u32 {
///   if reason == 1 {
///     // 这里要自行持有底层的 version.dll 的句柄,防止被释放。
///     let _ = forward_dll::utils::load_library("C:\\Windows\\system32\\version.dll");
///     // 调用 forward_all 方法,建立导出函数与目标函数之间的映射关系。
///     let _ = unsafe { DLL_VERSION_FORWARDER.forward_all() };
///   }
///   1
/// }
/// ```
#[macro_export]
macro_rules! forward_dll {
    ($lib:expr, $name:ident, $($proc:ident)*) => {
        static mut $name: $crate::DllForwarder<{ $crate::count!($($proc)*) }> = $crate::DllForwarder {
            initialized: false,
            module_handle: 0,
            lib_name: $lib,
            target_functions_address: [
                0;
                $crate::count!($($proc)*)
            ],
            target_function_names: [
                $(stringify!($proc),)*
            ]
        };
        $crate::define_function!($lib, $name, 0, $($proc)*);
    };
}

#[doc(hidden)]
#[macro_export]
macro_rules! define_function {
    ($lib:expr, $name:ident, $index:expr, ) => {};
    ($lib:expr, $name:ident, $index:expr, $export_name:ident = $proc:ident $($procs:tt)*) => {
        const _: () = {
            fn default_jumper(original_fn_addr: *const ()) -> usize {
                if original_fn_addr as usize != 0 {
                    return original_fn_addr as usize;
                }
                match $crate::utils::ForeignLibrary::new($lib) {
                    Ok(lib) => match lib.get_proc_address(std::stringify!($proc)) {
                        Ok(addr) => return addr as usize,
                        Err(err) => eprintln!("Error: {}", err)
                    }
                    Err(err) => eprintln!("Error: {}", err)
                }
                exit_fn as usize
            }

            fn exit_fn() {
                std::process::exit(1);
            }

            #[no_mangle]
            pub extern "system" fn $export_name() -> u32 {
                unsafe {
                    std::arch::asm!(
                        "push rcx",
                        "push rdx",
                        "push r8",
                        "push r9",
                        "push r10",
                        "push r11",
                        options(nostack)
                    );
                    std::arch::asm!(
                        "sub rsp, 28h",
                        "call rax",
                        "add rsp, 28h",
                        in("rax") default_jumper,
                        in("rcx") $name.target_functions_address[$index],
                        options(nostack)
                    );
                    std::arch::asm!(
                        "pop r11",
                        "pop r10",
                        "pop r9",
                        "pop r8",
                        "pop rdx",
                        "pop rcx",
                        "jmp rax",
                        options(nostack)
                    );
                }
                1
            }
        };
        $crate::define_function!($lib, $name, ($index + 1), $($procs)*);
    };
    ($lib:expr, $name:ident, $index:expr, $proc:ident $($procs:tt)*) => {
        $crate::define_function!($lib, $name, $index, $proc=$proc $($procs)*);
    };
}

#[derive(Debug)]
pub enum ForwardError {
    /// Win32 API 返回的错误。第一个值为调用的 Win32 API 函数名称,第二个为错误代码。
    Win32Error(&'static str, u32),
    /// 字符串编码错误。
    StringError(NulError),
    /// 已经初始化过了,不需要再次初始化。
    AlreadyInitialized,
}

impl std::fmt::Display for ForwardError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            ForwardError::Win32Error(func_name, err_code) => {
                write!(f, "Win32Error: {} {}", func_name, err_code)
            }
            ForwardError::StringError(ref err) => write!(f, "StringError: {}", err),
            ForwardError::AlreadyInitialized => write!(f, "AlreadyInitialized"),
        }
    }
}

impl std::error::Error for ForwardError {}

pub type ForwardResult<T> = std::result::Result<T, ForwardError>;

/// DLL 转发类型的具体实现。该类型不要自己实例化,应调用 forward_dll 宏生成具体的实例。
pub struct DllForwarder<const N: usize> {
    pub initialized: bool,
    pub module_handle: HMODULE,
    pub target_functions_address: [usize; N],
    pub target_function_names: [&'static str; N],
    pub lib_name: &'static str,
}

impl<const N: usize> DllForwarder<N> {
    /// 将所有函数的跳转地址设置为对应的 DLL 的同名函数地址。
    pub fn forward_all(&mut self) -> ForwardResult<()> {
        if self.initialized {
            return Err(ForwardError::AlreadyInitialized);
        }

        let lib = ForeignLibrary::new(self.lib_name)?;
        for index in 0..self.target_functions_address.len() {
            let addr_in_remote_module = lib.get_proc_address(self.target_function_names[index])?;
            self.target_functions_address[index] = addr_in_remote_module as *const usize as usize;
        }

        self.module_handle = lib.into_raw();
        self.initialized = true;

        Ok(())
    }
}

/// 转发目标 `DLL` 的所有函数,同时会确保 `ordinal` 与目标函数一致。这个函数会读取目标 `DLL` 以获得导出函数信息,因此,要确保目标 `DLL` 在编译期存在。
pub fn forward_dll(dll_path: &str) -> Result<(), String> {
    forward_dll_with_dev_path(dll_path, dll_path)
}

/// 转发目标 `DLL` 的所有函数。与 `forward_dll` 类似,区别在于这个函数可以指定在编译时的目标 `DLL` 路径。
pub fn forward_dll_with_dev_path(dll_path: &str, dev_dll_path: &str) -> Result<(), String> {
    let exports = get_dll_export_names(dev_dll_path)?;
    forward_dll_with_exports(
        dll_path,
        exports
            .iter()
            .map(|(ord, name)| (*ord, name.as_str()))
            .collect::<Vec<_>>()
            .as_slice(),
    )
}

/// 转发目标 `DLL` 的所有函数。与 `forward_dll` 类似,区别在于这个函数不要求在编译期存在 dll。
pub fn forward_dll_with_exports(dll_path: &str, exports: &[(u32, &str)]) -> Result<(), String> {
    const SUFFIX: &str = ".dll";
    let dll_path_without_ext = if dll_path.to_ascii_lowercase().ends_with(SUFFIX) {
        &dll_path[..dll_path.len() - SUFFIX.len()]
    } else {
        dll_path
    };

    let out_dir = get_tmp_dir();

    // 输出链接参数,转发入口点到目标库。
    for (ordinal, name) in exports {
        println!("cargo:rustc-link-arg=/EXPORT:{name}={dll_path_without_ext}.{name},@{ordinal}")
    }

    // 构造 Import Library。
    let exports_def = String::from("LIBRARY version\nEXPORTS\n")
        + exports
            .iter()
            .map(|(ordinal, name)| format!("  {name} @{ordinal}\n"))
            .collect::<String>()
            .as_str();
    #[cfg(target_arch = "x86_64")]
    let machine = MachineType::AMD64;
    #[cfg(target_arch = "x86")]
    let machine = MachineType::I386;
    let mut def = ModuleDef::parse(&exports_def, machine)
        .map_err(|err| format!("ImportLibrary::new error: {err}"))?;
    for item in def.exports.iter_mut() {
        item.symbol_name = item.name.trim_start_matches('_').to_string();
    }
    let lib = ImportLibrary::from_def(def, machine, Flavor::Msvc);
    let version_lib_path = out_dir.join("version_proxy.lib");
    let mut lib_file = std::fs::OpenOptions::new()
        .create(true)
        .write(true)
        .truncate(true)
        .open(version_lib_path)
        .map_err(|err| format!("OpenOptions::open error: {err}"))?;
    lib.write_to(&mut lib_file)
        .map_err(|err| format!("ImportLibrary::write_to error: {err}"))?;

    println!("cargo:rustc-link-search={}", out_dir.display());
    println!("cargo:rustc-link-lib=version_proxy");

    Ok(())
}

/// 查询 OUT_DIR 变量,作为创建的 Import Library 存储路径。如果是在 doctest 的上下文中,是取不到 OUT_DIR 的。
fn get_tmp_dir() -> PathBuf {
    std::env::var("OUT_DIR")
        .map(PathBuf::from)
        .unwrap_or_else(|_| {
            let dir = std::env::temp_dir().join("forward-dll-libs");
            if !dir.exists() {
                std::fs::create_dir_all(&dir).expect("Failed to create temp dir");
            }
            dir
        })
}

fn get_dll_export_names(dll_path: &str) -> Result<Vec<(u32, String)>, String> {
    let dll_file = std::fs::read(dll_path).map_err(|err| format!("Failed to read file: {err}"))?;
    let in_data = dll_file.as_slice();

    let kind = object::FileKind::parse(in_data).map_err(|err| format!("Invalid file: {err}"))?;
    let exports = match kind {
        object::FileKind::Pe32 => PeFile32::parse(in_data)
            .map_err(|err| format!("Invalid pe file: {err}"))?
            .export_table()
            .map_err(|err| format!("Invalid pe file: {err}"))?
            .ok_or_else(|| "No export table".to_string())?
            .exports(),
        object::FileKind::Pe64 => PeFile64::parse(in_data)
            .map_err(|err| format!("Invalid pe file: {err}"))?
            .export_table()
            .map_err(|err| format!("Invalid pe file: {err}"))?
            .ok_or_else(|| "No export table".to_string())?
            .exports(),
        _ => return Err("Invalid file".to_string()),
    }
    .map_err(|err| format!("Invalid file: {err}"))?;

    let mut names = Vec::new();
    for export_item in exports {
        let export_name = export_item
            .name
            .map(String::from_utf8_lossy)
            .map(String::from)
            .unwrap_or_default();
        names.push((export_item.ordinal, export_name));
    }
    Ok(names)
}