1pub mod utils;
40
41use std::{collections::HashMap, ffi::NulError, path::PathBuf};
42
43use implib::{def::ModuleDef, Flavor, ImportLibrary, MachineType};
44use object::read::pe::{PeFile32, PeFile64};
45use utils::ForeignLibrary;
46
47pub use forward_dll_derive::ForwardModule;
48use windows_sys::Win32::Foundation::HMODULE;
49
50pub trait ForwardModule {
52 fn init(&self) -> ForwardResult<()>;
54}
55
/// Counts the token trees it is given, expanding to a constant `usize` expression.
///
/// Used to size the arrays of a `DllForwarder`, so the expansion must be usable in
/// const-generic and array-length position. Each token is replaced by `()` and the
/// length of the resulting slice is taken. The expansion depth stays constant no
/// matter how many tokens are passed; recursing once per token would run into
/// the compiler's default `recursion_limit` for long lists.
#[doc(hidden)]
#[macro_export]
macro_rules! count {
    // Internal rule: map any single token tree to the unit value.
    (@unit $_x:tt) => (());
    ($($x:tt)*) => (<[()]>::len(&[$($crate::count!(@unit $x)),*]));
}
62
/// Declares a `static mut` `DllForwarder` named `$name` for the library `$lib`, and
/// generates one exported forwarding function per listed procedure.
///
/// ```ignore
/// forward_dll!(
///     "C:\\Windows\\System32\\version.dll",
///     DLL_VERSION_FORWARDER,
///     GetFileVersionInfoA GetFileVersionInfoW
/// );
/// ```
///
/// Every address slot starts at 0. Until the forwarder's `forward_all` has filled
/// them in, each generated export resolves its target by loading `$lib` when it
/// is called.
#[macro_export]
macro_rules! forward_dll {
    ($lib:expr, $name:ident, $($proc:ident)*) => {
        static mut $name: $crate::DllForwarder<{ $crate::count!($($proc)*) }> =
            $crate::DllForwarder {
                lib_name: $lib,
                target_function_names: [$(stringify!($proc)),*],
                target_functions_address: [0; $crate::count!($($proc)*)],
                module_handle: 0,
                initialized: false,
            };
        $crate::define_function!($lib, $name, 0, $($proc)*);
    };
}
119
/// Generates one exported forwarding stub per procedure, recursing over the list.
///
/// Arguments:
/// - the library path;
/// - the name of the `static mut` `DllForwarder`;
/// - the slot index of the first procedure;
/// - the procedure list.
///
/// Each list entry is either `proc` (exported under its own name) or
/// `export_name = proc`.
///
/// On x86/x86_64 the generated `extern "system" fn $export_name` calls
/// `default_jumper` with `$name.target_functions_address[$index]`. That slot is
/// used as-is when non-zero; otherwise `$lib` is loaded and `$proc` looked up.
/// The stub then `jmp`s to the returned address. If resolution fails, the
/// address of `exit_fn` is returned instead, and `exit_fn` terminates the process
/// with exit code 1. On other architectures no asm is emitted and the stub simply
/// returns 1.
#[doc(hidden)]
#[macro_export]
macro_rules! define_function {
    ($lib:expr, $name:ident, $index:expr, ) => {};
    ($lib:expr, $name:ident, $index:expr, $export_name:ident = $proc:ident $($procs:tt)*) => {
        const _: () = {
            // `extern "C"` pins down the calling convention that the hand-written asm
            // below assumes. On x86 (cdecl) the argument is read from the stack (pushed
            // with `push ecx`, popped by the caller's `add esp, 4h`). On x86_64 Windows
            // it arrives in `rcx`. The result comes back in eax/rax. The default Rust
            // ABI guarantees none of this.
            extern "C" fn default_jumper(original_fn_addr: *const ()) -> usize {
                if original_fn_addr as usize != 0 {
                    return original_fn_addr as usize;
                }
                match $crate::utils::ForeignLibrary::new($lib) {
                    Ok(lib) => match lib.get_proc_address(std::stringify!($proc)) {
                        Ok(addr) => {
                            let addr = addr as usize;
                            // The stub is about to jump into this module, so keep it
                            // loaded. Hand the handle off instead of letting `lib` drop,
                            // as `DllForwarder::forward_all` does with `into_raw`.
                            let _ = lib.into_raw();
                            return addr;
                        }
                        Err(err) => eprintln!("Error: {}", err),
                    },
                    Err(err) => eprintln!("Error: {}", err),
                }
                exit_fn as usize
            }

            fn exit_fn() {
                std::process::exit(1);
            }

            #[no_mangle]
            pub extern "system" fn $export_name() -> u32 {
                // NOTE(review): these blocks push to the stack despite `options(nostack)`.
                // They also let the callee clobber caller-saved registers without
                // declaring them. On x86_64 they split save/call/restore across three
                // `asm!` invocations. All of this only works if the compiler emits no
                // stack frame or spills for this function. Verify the generated code
                // before touching anything here.
                #[cfg(target_arch = "x86")]
                unsafe {
                    std::arch::asm!(
                        "push ecx",
                        "call eax",
                        "add esp, 4h",
                        "jmp eax",
                        in("eax") default_jumper,
                        in("ecx") $name.target_functions_address[$index],
                        options(nostack)
                    );
                }
                #[cfg(target_arch = "x86_64")]
                unsafe {
                    // Preserve the argument registers of the original call across
                    // the call to `default_jumper`.
                    std::arch::asm!(
                        "push rcx",
                        "push rdx",
                        "push r8",
                        "push r9",
                        "push r10",
                        "push r11",
                        options(nostack)
                    );
                    std::arch::asm!(
                        "sub rsp, 28h",
                        "call rax",
                        "add rsp, 28h",
                        in("rax") default_jumper,
                        in("rcx") $name.target_functions_address[$index],
                        options(nostack)
                    );
                    std::arch::asm!(
                        "pop r11",
                        "pop r10",
                        "pop r9",
                        "pop r8",
                        "pop rdx",
                        "pop rcx",
                        "jmp rax",
                        options(nostack)
                    );
                }
                1
            }
        };
        $crate::define_function!($lib, $name, ($index + 1), $($procs)*);
    };
    ($lib:expr, $name:ident, $index:expr, $proc:ident $($procs:tt)*) => {
        $crate::define_function!($lib, $name, $index, $proc=$proc $($procs)*);
    };
}
197
/// Errors produced while setting up DLL forwarding.
#[derive(Debug)]
pub enum ForwardError {
    /// A Win32 call failed: the name of the failing function and its numeric error code.
    Win32Error(&'static str, u32),
    /// A string could not be turned into a C string because it contained a NUL byte.
    StringError(NulError),
    /// `DllForwarder::forward_all` was called on an already-initialized forwarder.
    AlreadyInitialized,
}

impl std::fmt::Display for ForwardError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ForwardError::Win32Error(func_name, err_code) => {
                write!(f, "Win32Error: {} {}", func_name, err_code)
            }
            ForwardError::StringError(err) => write!(f, "StringError: {}", err),
            ForwardError::AlreadyInitialized => write!(f, "AlreadyInitialized"),
        }
    }
}

impl std::error::Error for ForwardError {
    /// Exposes the wrapped `NulError` so error-reporting code can walk the chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            ForwardError::StringError(err) => Some(err),
            ForwardError::Win32Error(..) | ForwardError::AlreadyInitialized => None,
        }
    }
}

/// Result type returned by the forwarding APIs.
pub type ForwardResult<T> = std::result::Result<T, ForwardError>;
223
224pub struct DllForwarder<const N: usize> {
226 pub initialized: bool,
227 pub module_handle: HMODULE,
228 pub target_functions_address: [usize; N],
229 pub target_function_names: [&'static str; N],
230 pub lib_name: &'static str,
231}
232
233impl<const N: usize> DllForwarder<N> {
234 pub fn forward_all(&mut self) -> ForwardResult<()> {
236 if self.initialized {
237 return Err(ForwardError::AlreadyInitialized);
238 }
239
240 let lib = ForeignLibrary::new(self.lib_name)?;
241 for index in 0..self.target_functions_address.len() {
242 let addr_in_remote_module = lib.get_proc_address(self.target_function_names[index])?;
243 self.target_functions_address[index] = addr_in_remote_module as *const usize as usize;
244 }
245
246 self.module_handle = lib.into_raw();
247 self.initialized = true;
248
249 Ok(())
250 }
251}
252
/// One entry of a DLL export table, as fed to `forward_dll_impl`.
struct ExportItem {
    /// Export ordinal.
    ordinal: u32,
    /// Exported name, or `None` for an ordinal-only export.
    name: Option<String>,
}
257
258pub fn forward_dll(dll_path: &str) -> Result<(), String> {
260 forward_dll_with_dev_path(dll_path, dll_path)
261}
262
263pub fn forward_dll_with_dev_path(dll_path: &str, dev_dll_path: &str) -> Result<(), String> {
265 let exports = get_dll_export_names(dev_dll_path)?;
266 forward_dll_impl(dll_path, exports.as_slice())
267}
268
269pub fn forward_dll_with_exports(dll_path: &str, exports: &[(u32, &str)]) -> Result<(), String> {
271 forward_dll_impl(
272 dll_path,
273 exports
274 .iter()
275 .map(|(ord, name)| ExportItem {
276 ordinal: *ord,
277 name: Some(name.to_string()),
278 })
279 .collect::<Vec<_>>()
280 .as_slice(),
281 )
282}
283
284fn forward_dll_impl(dll_path: &str, exports: &[ExportItem]) -> Result<(), String> {
285 const SUFFIX: &str = ".dll";
286 let dll_path_without_ext = if dll_path.to_ascii_lowercase().ends_with(SUFFIX) {
287 &dll_path[..dll_path.len() - SUFFIX.len()]
288 } else {
289 dll_path
290 };
291
292 let out_dir = get_tmp_dir();
293
294 let mut anonymous_map = HashMap::new();
296 let mut anonymous_name_id = 0;
297
298 for ExportItem { name, ordinal } in exports {
300 match name {
301 Some(name) => println!(
302 "cargo:rustc-link-arg=/EXPORT:{name}={dll_path_without_ext}.{name},@{ordinal}"
303 ),
304 None => {
305 anonymous_name_id += 1;
306 let fn_name = format!("forward_dll_anonymous_{anonymous_name_id}");
307 println!(
308 "cargo:rustc-link-arg=/EXPORT:{fn_name}={dll_path_without_ext}.#{ordinal},@{ordinal},NONAME"
309 );
310 anonymous_map.insert(ordinal, fn_name);
311 }
312 };
313 }
314
315 let exports_def = String::from("LIBRARY version\nEXPORTS\n")
317 + exports
318 .iter()
319 .map(|ExportItem { name, ordinal }| match name {
320 Some(name) => format!(" {name} @{ordinal}\n"),
321 None => {
322 let fn_name = anonymous_map.get(ordinal).unwrap();
323 format!(" {fn_name} @{ordinal} NONAME\n")
324 }
325 })
326 .collect::<String>()
327 .as_str();
328 #[cfg(target_arch = "x86_64")]
329 let machine = MachineType::AMD64;
330 #[cfg(target_arch = "x86")]
331 let machine = MachineType::I386;
332 let mut def = ModuleDef::parse(&exports_def, machine)
333 .map_err(|err| format!("ImportLibrary::new error: {err}"))?;
334 for item in def.exports.iter_mut() {
335 item.symbol_name = item.name.trim_start_matches('_').to_string();
336 }
337 let lib = ImportLibrary::from_def(def, machine, Flavor::Msvc);
338 let version_lib_path = out_dir.join("version_proxy.lib");
339 let mut lib_file = std::fs::OpenOptions::new()
340 .create(true)
341 .write(true)
342 .truncate(true)
343 .open(version_lib_path)
344 .map_err(|err| format!("OpenOptions::open error: {err}"))?;
345 lib.write_to(&mut lib_file)
346 .map_err(|err| format!("ImportLibrary::write_to error: {err}"))?;
347
348 println!("cargo:rustc-link-search={}", out_dir.display());
349 println!("cargo:rustc-link-lib=version_proxy");
350
351 Ok(())
352}
353
/// Directory the generated import library is written to.
///
/// Uses Cargo's `OUT_DIR` when it is set. Otherwise falls back to
/// `<system temp>/forward-dll-libs`, creating it if missing.
///
/// # Panics
///
/// Panics if the fallback directory does not exist and cannot be created.
fn get_tmp_dir() -> PathBuf {
    if let Ok(out_dir) = std::env::var("OUT_DIR") {
        return PathBuf::from(out_dir);
    }

    let fallback = std::env::temp_dir().join("forward-dll-libs");
    if !fallback.exists() {
        std::fs::create_dir_all(&fallback).expect("Failed to create temp dir");
    }
    fallback
}
366
367fn get_dll_export_names(dll_path: &str) -> Result<Vec<ExportItem>, String> {
368 let dll_file = std::fs::read(dll_path).map_err(|err| format!("Failed to read file: {err}"))?;
369 let in_data = dll_file.as_slice();
370
371 let kind = object::FileKind::parse(in_data).map_err(|err| format!("Invalid file: {err}"))?;
372 let exports = match kind {
373 object::FileKind::Pe32 => PeFile32::parse(in_data)
374 .map_err(|err| format!("Invalid pe file: {err}"))?
375 .export_table()
376 .map_err(|err| format!("Invalid pe file: {err}"))?
377 .ok_or_else(|| "No export table".to_string())?
378 .exports(),
379 object::FileKind::Pe64 => PeFile64::parse(in_data)
380 .map_err(|err| format!("Invalid pe file: {err}"))?
381 .export_table()
382 .map_err(|err| format!("Invalid pe file: {err}"))?
383 .ok_or_else(|| "No export table".to_string())?
384 .exports(),
385 _ => return Err("Invalid file".to_string()),
386 }
387 .map_err(|err| format!("Invalid file: {err}"))?;
388
389 let mut export_list = Vec::new();
390 for export_item in exports {
391 let ordinal = export_item.ordinal;
392 let name = export_item
393 .name
394 .map(String::from_utf8_lossy)
395 .map(String::from);
396 let item = ExportItem { name, ordinal };
397 export_list.push(item);
398 }
399 Ok(export_list)
400}