pub mod utils;
use std::{ffi::NulError, path::PathBuf};
use implib::{def::ModuleDef, Flavor, ImportLibrary, MachineType};
use object::read::pe::{PeFile32, PeFile64};
use utils::ForeignLibrary;
pub use forward_dll_derive::ForwardModule;
use windows_sys::Win32::Foundation::HMODULE;
/// Initialization hook for a forwarding module.
///
/// A macro of the same name is re-exported from `forward_dll_derive` above; presumably it is a
/// derive that implements this trait — confirm against that crate.
pub trait ForwardModule {
/// Runs the implementor's initialization, reporting failure as a `ForwardError`.
fn init(&self) -> ForwardResult<()>;
}
/// Expands to the number of token trees it was given, as a `usize` constant expression.
///
/// The expansion is usable wherever a constant is required (array lengths, const generic
/// arguments); `forward_dll!` relies on this to size its `DllForwarder`.
#[doc(hidden)]
#[macro_export]
macro_rules! count {
    ( $($token:tt)* ) => {
        // One string literal per token tree; the slice length is the count. `<[_]>::len` is a
        // `const fn`, and the explicit element type keeps the empty case (`&[]`) well-typed.
        <[&'static str]>::len(&[$(stringify!($token)),*])
    };
}
/// Declares a `static mut` `DllForwarder` named `$name` for the library `$lib` and generates one
/// `#[no_mangle]` exported trampoline per identifier listed in `$proc`.
///
/// `$lib` is stored in the `lib_name: &'static str` field and is also passed to
/// `ForeignLibrary::new` by the trampolines, so it must be a `&'static str` expression.
/// The identifiers are separated by whitespace (no commas); each one becomes both the exported
/// symbol name and the name looked up in `$lib`, and is bound to its position in the table.
///
/// The address table starts zeroed. A trampoline whose slot is still 0 resolves its target
/// lazily on every call; calling `$name.forward_all()` fills the whole table up front.
#[macro_export]
macro_rules! forward_dll {
($lib:expr, $name:ident, $($proc:ident)*) => {
// Table size and both arrays are derived from the same identifier list, so they always agree.
static mut $name: $crate::DllForwarder<{ $crate::count!($($proc)*) }> = $crate::DllForwarder {
initialized: false,
module_handle: 0,
lib_name: $lib,
target_functions_address: [
0;
$crate::count!($($proc)*)
],
target_function_names: [
$(stringify!($proc),)*
]
};
// Slot indices start at 0 and follow declaration order.
$crate::define_function!($lib, $name, 0, $($proc)*);
};
}
/// Internal helper for `forward_dll!`: recursively emits one exported trampoline per procedure.
///
/// Arms:
/// - empty list: recursion ends.
/// - `export = proc ...`: emits `#[no_mangle] extern "system" fn export` bound to slot `$index`
///   of `$name.target_functions_address`, resolving `proc` in `$lib` when needed, then recurses
///   with `$index + 1`.
/// - `proc ...`: shorthand for `proc = proc ...`.
///
/// Each invocation recurses once per remaining procedure, so long lists are bounded by the
/// compiler's `recursion_limit`.
///
/// Each trampoline works as follows:
/// 1. It pushes rcx/rdx/r8/r9/r10/r11.
/// 2. It calls `default_jumper` with the slot value in rcx.
/// 3. It pops those registers back.
/// 4. It jumps to the address `default_jumper` returned in rax, so the caller's integer register
///    arguments reach the real function.
///
/// NOTE(review): this code is x86_64-only and depends on unspecified codegen:
/// - The three `asm!` blocks are independent. The compiler may use registers or the stack between
///   them (e.g. to evaluate the slot operand) without knowing about the six pushes.
/// - `options(nostack)` is declared although the blocks push, pop and `call`.
/// - The second block overwrites rax and rcx, which are declared only as `in` operands. `asm!`
///   requires inputs to be unchanged on exit unless they are also outputs.
/// - `default_jumper` is a Rust-ABI `fn` but is called assuming the Win64 convention (first
///   argument in rcx, result in rax).
/// - Only general-purpose registers are saved. Floating-point arguments in xmm registers are
///   not preserved across the call to `default_jumper`.
/// - `jmp rax` leaves without running this function's epilogue. That is only balanced if the
///   compiler emitted no stack frame.
///
/// A naked function would be the robust form; confirm toolchain support before switching.
#[doc(hidden)]
#[macro_export]
macro_rules! define_function {
($lib:expr, $name:ident, $index:expr, ) => {};
($lib:expr, $name:ident, $index:expr, $export_name:ident = $proc:ident $($procs:tt)*) => {
// Anonymous const scope keeps `default_jumper`/`exit_fn` private to this one export.
const _: () = {
// Returns the jump target: the cached slot value if non-zero, otherwise a fresh lookup of
// `$proc` in `$lib` (the result is not written back to the table). On any failure it prints
// to stderr and returns `exit_fn`, so the forwarded call terminates the process.
// NOTE(review): `lib` is dropped before the returned address is used; if ForeignLibrary's Drop
// unloads the module, the address could dangle — check utils.
fn default_jumper(original_fn_addr: *const ()) -> usize {
if original_fn_addr as usize != 0 {
return original_fn_addr as usize;
}
match $crate::utils::ForeignLibrary::new($lib) {
Ok(lib) => match lib.get_proc_address(std::stringify!($proc)) {
Ok(addr) => return addr as usize,
Err(err) => eprintln!("Error: {}", err)
}
Err(err) => eprintln!("Error: {}", err)
}
exit_fn as usize
}
// Fallback target when the real function cannot be resolved: exit with status 1.
fn exit_fn() {
std::process::exit(1);
}
#[no_mangle]
pub extern "system" fn $export_name() -> u32 {
unsafe {
// Save the integer argument registers (and r10/r11) before calling into Rust code.
std::arch::asm!(
"push rcx",
"push rdx",
"push r8",
"push r9",
"push r10",
"push r11",
options(nostack)
);
// 0x28 = 32-byte shadow space + 8 bytes, presumably to keep rsp 16-byte aligned at the call
// given the six pushes above — confirm against the emitted prologue.
std::arch::asm!(
"sub rsp, 28h",
"call rax",
"add rsp, 28h",
in("rax") default_jumper,
in("rcx") $name.target_functions_address[$index],
options(nostack)
);
// Restore the saved registers in reverse order and tail-jump to the resolved target in rax.
std::arch::asm!(
"pop r11",
"pop r10",
"pop r9",
"pop r8",
"pop rdx",
"pop rcx",
"jmp rax",
options(nostack)
);
}
// Only reached if the jump above is not taken.
1
}
};
$crate::define_function!($lib, $name, ($index + 1), $($procs)*);
};
($lib:expr, $name:ident, $index:expr, $proc:ident $($procs:tt)*) => {
$crate::define_function!($lib, $name, $index, $proc=$proc $($procs)*);
};
}
/// Errors produced while setting up DLL forwarding.
#[derive(Debug)]
pub enum ForwardError {
    /// A Win32 call failed: the name of the failing function and the numeric error code
    /// (presumably the `GetLastError` value — confirm in `utils`).
    Win32Error(&'static str, u32),
    /// A string could not be turned into a C string because it contained an interior NUL byte.
    StringError(NulError),
    /// `DllForwarder::forward_all` was called on a forwarder that is already initialized.
    AlreadyInitialized,
}

impl std::fmt::Display for ForwardError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            ForwardError::Win32Error(func_name, err_code) => {
                write!(f, "Win32Error: {} {}", func_name, err_code)
            }
            ForwardError::StringError(ref err) => write!(f, "StringError: {}", err),
            ForwardError::AlreadyInitialized => write!(f, "AlreadyInitialized"),
        }
    }
}

impl std::error::Error for ForwardError {
    /// Exposes the wrapped `NulError` so error-chain reporters can show the underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            ForwardError::StringError(err) => Some(err),
            ForwardError::Win32Error(..) | ForwardError::AlreadyInitialized => None,
        }
    }
}

/// Result type used by the forwarding API.
pub type ForwardResult<T> = std::result::Result<T, ForwardError>;
/// Address table behind the trampolines generated by `forward_dll!`.
///
/// `N` is the number of forwarded functions. `target_function_names[i]` is looked up in
/// `lib_name`, and its address is stored in `target_functions_address[i]` by `forward_all`.
pub struct DllForwarder<const N: usize> {
/// Set by a successful `forward_all`; a second call then returns `AlreadyInitialized`.
pub initialized: bool,
/// Handle obtained from `ForeignLibrary::into_raw` at the end of `forward_all`.
pub module_handle: HMODULE,
/// Resolved addresses, parallel to `target_function_names`; 0 means "not resolved yet".
pub target_functions_address: [usize; N],
/// Export names to resolve in `lib_name`.
pub target_function_names: [&'static str; N],
/// Library name or path passed to `ForeignLibrary::new`.
pub lib_name: &'static str,
}
impl<const N: usize> DllForwarder<N> {
    /// Resolves every name in `target_function_names` from `lib_name` and stores each address in
    /// the matching slot of `target_functions_address`.
    ///
    /// On success the library's raw handle is kept in `module_handle` and `initialized` is set.
    ///
    /// # Errors
    ///
    /// - Returns `ForwardError::AlreadyInitialized` if a previous call already succeeded.
    /// - Propagates any error from loading the library or looking up a name.
    /// - A failed lookup stops the loop early; slots filled before the failure keep their new
    ///   values.
    pub fn forward_all(&mut self) -> ForwardResult<()> {
        if self.initialized {
            return Err(ForwardError::AlreadyInitialized);
        }
        let library = ForeignLibrary::new(self.lib_name)?;
        let names = self.target_function_names.iter().copied();
        for (slot, name) in self.target_functions_address.iter_mut().zip(names) {
            let proc_address = library.get_proc_address(name)?;
            *slot = proc_address as *const usize as usize;
        }
        self.module_handle = library.into_raw();
        self.initialized = true;
        Ok(())
    }
}
/// Build-script helper that forwards every export of `dll_path` to `dll_path` itself.
///
/// The export list is read from the same file that the generated forwarders point at. Use
/// `forward_dll_with_dev_path` when the build machine has the DLL at a different location.
///
/// # Errors
///
/// Returns a message if the DLL cannot be read or parsed, or if the import library cannot be
/// generated.
pub fn forward_dll(dll_path: &str) -> Result<(), String> {
    let export_source = dll_path;
    forward_dll_with_dev_path(dll_path, export_source)
}
/// Build-script helper that reads the export table from `dev_dll_path` and forwards those
/// exports to `dll_path`.
///
/// # Errors
///
/// Returns a message if `dev_dll_path` cannot be read or parsed as a PE file with an export
/// table, or if the forwarding directives and import library cannot be produced.
pub fn forward_dll_with_dev_path(dll_path: &str, dev_dll_path: &str) -> Result<(), String> {
    let owned_exports = get_dll_export_names(dev_dll_path)?;
    let mut borrowed_exports = Vec::with_capacity(owned_exports.len());
    for (ordinal, name) in &owned_exports {
        borrowed_exports.push((*ordinal, name.as_str()));
    }
    forward_dll_with_exports(dll_path, &borrowed_exports)
}
/// Emits Cargo build-script directives that make the crate being built forward `exports` to
/// `dll_path`.
///
/// For each named export this prints
/// `cargo:rustc-link-arg=/EXPORT:{name}={module}.{name},@{ordinal}`. Here `{module}` is
/// `dll_path` with a trailing `.dll` (any ASCII case) removed.
///
/// It then builds an MSVC-flavoured import library from the same names and writes it as
/// `version_proxy.lib` into `OUT_DIR` (or a temp-dir fallback). Finally it prints the
/// `rustc-link-search` / `rustc-link-lib` directives for that library.
///
/// Entries with an empty name (ordinal-only exports) cannot be expressed as
/// `/EXPORT:name=...`. They are skipped, and a `cargo:warning` is emitted for each one.
///
/// The import library's machine type follows the target architecture from
/// `CARGO_CFG_TARGET_ARCH`. When that variable is unset, it falls back to the architecture this
/// code is running on.
///
/// # Errors
///
/// Returns a message in any of these cases:
/// - the architecture is neither `x86_64` nor `x86`;
/// - the generated module-definition text cannot be parsed;
/// - the import library cannot be created or written.
///
/// # Panics
///
/// Panics if `OUT_DIR` is unset and the temp-dir fallback cannot be created.
pub fn forward_dll_with_exports(dll_path: &str, exports: &[(u32, &str)]) -> Result<(), String> {
    const SUFFIX: &str = ".dll";
    // ASCII lowercasing preserves byte length and SUFFIX is ASCII, so the cut below always lands
    // on a char boundary.
    let dll_path_without_ext = if dll_path.to_ascii_lowercase().ends_with(SUFFIX) {
        &dll_path[..dll_path.len() - SUFFIX.len()]
    } else {
        dll_path
    };

    // Build scripts are compiled for the *host*, so `#[cfg(target_arch = ...)]` here selects the
    // host's machine type when cross-compiling (and did not compile at all on non-x86 hosts).
    // Cargo hands build scripts the real target in CARGO_CFG_TARGET_ARCH.
    let target_arch = std::env::var("CARGO_CFG_TARGET_ARCH")
        .unwrap_or_else(|_| std::env::consts::ARCH.to_string());
    let machine = match target_arch.as_str() {
        "x86_64" => MachineType::AMD64,
        "x86" => MachineType::I386,
        other => {
            return Err(format!(
                "unsupported target architecture for DLL forwarding: {other}"
            ))
        }
    };

    // An empty name would produce `/EXPORT:=module.,@N` and a nameless `.def` entry. Neither is
    // valid MSVC syntax (an entry name is required), so ordinal-only exports are left out.
    let (named, unnamed): (Vec<_>, Vec<_>) =
        exports.iter().partition(|(_, name)| !name.is_empty());
    for (ordinal, _) in &unnamed {
        println!("cargo:warning=forward-dll: skipping unnamed export with ordinal {ordinal}");
    }

    let out_dir = get_tmp_dir();
    for (ordinal, name) in &named {
        println!("cargo:rustc-link-arg=/EXPORT:{name}={dll_path_without_ext}.{name},@{ordinal}");
    }

    // NOTE(review): the module name is hard-coded to `version` whatever `dll_path` is.
    // Presumably this is inherited from a version.dll proxy; confirm it is intended for other DLLs.
    let exports_def = String::from("LIBRARY version\nEXPORTS\n")
        + named
            .iter()
            .map(|(ordinal, name)| format!(" {name} @{ordinal}\n"))
            .collect::<String>()
            .as_str();
    let mut def = ModuleDef::parse(&exports_def, machine)
        .map_err(|err| format!("ImportLibrary::new error: {err}"))?;
    // NOTE(review): leading `_` is stripped from every symbol name. The reason is not visible here
    // (presumably to match undecorated names); confirm.
    for item in def.exports.iter_mut() {
        item.symbol_name = item.name.trim_start_matches('_').to_string();
    }
    let lib = ImportLibrary::from_def(def, machine, Flavor::Msvc);
    let version_lib_path = out_dir.join("version_proxy.lib");
    let mut lib_file = std::fs::OpenOptions::new()
        .create(true)
        .write(true)
        .truncate(true)
        .open(version_lib_path)
        .map_err(|err| format!("OpenOptions::open error: {err}"))?;
    lib.write_to(&mut lib_file)
        .map_err(|err| format!("ImportLibrary::write_to error: {err}"))?;
    println!("cargo:rustc-link-search={}", out_dir.display());
    println!("cargo:rustc-link-lib=version_proxy");
    Ok(())
}
/// Returns the directory the generated import library is written to.
///
/// Uses Cargo's `OUT_DIR` when it is set and valid Unicode. Otherwise it falls back to a
/// `forward-dll-libs` folder under the system temp directory, creating that folder if it does
/// not exist.
///
/// # Panics
///
/// Panics if the fallback directory is missing and cannot be created.
fn get_tmp_dir() -> PathBuf {
    if let Ok(out_dir) = std::env::var("OUT_DIR") {
        return PathBuf::from(out_dir);
    }
    let fallback = std::env::temp_dir().join("forward-dll-libs");
    if !fallback.exists() {
        std::fs::create_dir_all(&fallback).expect("Failed to create temp dir");
    }
    fallback
}
/// Reads the PE file at `dll_path` and returns `(ordinal, name)` for every entry of its export
/// table.
///
/// Names are decoded lossily as UTF-8. Exports without a name (ordinal-only) get an empty string.
///
/// # Errors
///
/// Returns a message if the file cannot be read, is not a 32- or 64-bit PE image, has no export
/// table, or has an export table that cannot be decoded.
fn get_dll_export_names(dll_path: &str) -> Result<Vec<(u32, String)>, String> {
    fn invalid_pe<E: std::fmt::Display>(err: E) -> String {
        format!("Invalid pe file: {err}")
    }
    fn no_export_table() -> String {
        "No export table".to_string()
    }

    let contents =
        std::fs::read(dll_path).map_err(|err| format!("Failed to read file: {err}"))?;
    let data = contents.as_slice();
    let file_kind =
        object::FileKind::parse(data).map_err(|err| format!("Invalid file: {err}"))?;
    let export_list = match file_kind {
        object::FileKind::Pe32 => PeFile32::parse(data)
            .map_err(invalid_pe)?
            .export_table()
            .map_err(invalid_pe)?
            .ok_or_else(no_export_table)?
            .exports(),
        object::FileKind::Pe64 => PeFile64::parse(data)
            .map_err(invalid_pe)?
            .export_table()
            .map_err(invalid_pe)?
            .ok_or_else(no_export_table)?
            .exports(),
        _ => return Err("Invalid file".to_string()),
    }
    .map_err(|err| format!("Invalid file: {err}"))?;

    Ok(export_list
        .into_iter()
        .map(|entry| {
            let name = match entry.name {
                Some(raw) => String::from_utf8_lossy(raw).into_owned(),
                None => String::new(),
            };
            (entry.ordinal, name)
        })
        .collect())
}