1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
use object::read::pe::{PeFile32, PeFile64};
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{parse_macro_input, LitStr};
/// Diagnostic shown when the derive input carries no `#[forward(...)]` attribute.
///
/// English rendering of the text: "You need to add
/// #[forward(target = "path/of/target_dll.dll")]".
const FORWARD_ATTR_LACK_MESSAGE: &str =
r#"你需要添加 #[forward(target = "path/of/target_dll.dll")]"#;
/// Diagnostic shown when the arguments of `#[forward(...)]` cannot be used:
/// an unknown key, a `target` without a string value, or no `target` at all.
///
/// English rendering of the text: "The arguments of #[forward()] are malformed;
/// the correct format looks like #[forward(target = "C:\Windows\System32\version.dll")]".
const FORWARD_ATTR_INVALID_MESSAGE: &str = r#"#[forward()] 的参数格式错误,正确格式如 #[forward(target = "C:\Windows\System32\version.dll")]"#;
#[proc_macro_derive(ForwardModule, attributes(forward))]
pub fn derive_forward_module(item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as syn::DeriveInput);
let forward_attr = input
.attrs
.iter()
.find(|i| i.path().is_ident("forward"))
.expect(FORWARD_ATTR_LACK_MESSAGE);
let mut dll_path: Option<LitStr> = None;
let mut has_ordinal = false;
forward_attr
.parse_nested_meta(|meta| {
let path = &meta.path;
if path.is_ident("target") {
let value = meta.value().expect(FORWARD_ATTR_INVALID_MESSAGE);
dll_path = Some(value.parse().expect(FORWARD_ATTR_INVALID_MESSAGE));
} else if path.is_ident("ordinal") {
has_ordinal = true;
} else {
return Err(meta.error(FORWARD_ATTR_INVALID_MESSAGE));
}
Ok(())
})
.expect(FORWARD_ATTR_INVALID_MESSAGE);
let dll_path = dll_path.expect(FORWARD_ATTR_INVALID_MESSAGE);
let exports = get_dll_export_names(dll_path.value().as_str())
.expect("指定的 DLL 可能是一个无效的 PE 文件");
if has_ordinal {
generate_linker_args(&exports);
}
let export_names: Vec<_> = exports.iter().map(|(_, fn_name)| fn_name).collect();
let export_definitions: Vec<_> = exports
.iter()
.map(|(_, fn_name)| {
let export_name = match has_ordinal {
true => format_ident!("_{}", fn_name),
false => format_ident!("{}", fn_name),
};
let fn_name = format_ident!("{}", fn_name);
quote! {
#export_name = #fn_name
}
})
.collect();
let export_count = exports.len();
let struct_name = input.ident;
let impl_code = quote! {
const _ : () = {
extern crate forward_dll as _forward_dll;
static mut _FORWARDER: _forward_dll::DllForwarder<#export_count> = _forward_dll::DllForwarder {
initialized: false,
module_handle: 0,
lib_name: #dll_path,
target_functions_address: [0; #export_count],
target_function_names: [#(#export_names),*],
};
_forward_dll::define_function!(#dll_path, _FORWARDER, 0, #(#export_definitions)*);
impl _forward_dll::ForwardModule for #struct_name {
fn init(&self) -> _forward_dll::ForwardResult<()> {
unsafe { _FORWARDER.forward_all() }
}
}
};
};
impl_code.into()
}
/// Reads the PE image at `dll_path` and lists its exports as `(ordinal, name)` pairs.
///
/// Names are decoded lossily as UTF-8; an export without a name is returned
/// with an empty string.
///
/// # Errors
///
/// Returns a descriptive message when the file cannot be read, when its kind
/// cannot be determined or is neither 32-bit nor 64-bit PE, when the PE headers
/// or export directory fail to parse, or when the image has no export table.
fn get_dll_export_names(dll_path: &str) -> Result<Vec<(u32, String)>, String> {
    let image = std::fs::read(dll_path).map_err(|err| format!("Failed to read file: {err}"))?;
    let bytes = image.as_slice();
    let file_kind =
        object::FileKind::parse(bytes).map_err(|err| format!("Invalid file: {err}"))?;

    // Both PE flavours hand back the same export-table type, so only the
    // header parsing differs per arm.
    let export_table = match file_kind {
        object::FileKind::Pe32 => PeFile32::parse(bytes)
            .map_err(|err| format!("Invalid pe file: {err}"))?
            .export_table(),
        object::FileKind::Pe64 => PeFile64::parse(bytes)
            .map_err(|err| format!("Invalid pe file: {err}"))?
            .export_table(),
        _ => return Err("Invalid file".to_string()),
    }
    .map_err(|err| format!("Invalid pe file: {err}"))?
    .ok_or_else(|| "No export table".to_string())?;

    let entries = export_table
        .exports()
        .map_err(|err| format!("Invalid file: {err}"))?;

    Ok(entries
        .into_iter()
        .map(|entry| {
            let name = match entry.name {
                Some(raw) => String::from_utf8_lossy(raw).into_owned(),
                None => String::new(),
            };
            (entry.ordinal, name)
        })
        .collect())
}
/// Writes `ordinal_link_args.txt`, containing one `/EXPORT:name=_name,@ordinal`
/// line per entry in `exports`.
///
/// The destination directory is derived from this crate's own `OUT_DIR`
/// (captured when this crate is compiled) by dropping trailing path components
/// that are named `out` or `build`, or that contain `forward-dll-derive`.
/// Presumably that leaves Cargo's profile directory - verify against the
/// consuming build script.
///
/// Nothing is written when the derived path is not an existing directory, and
/// a failed write is ignored, so this never aborts macro expansion.
fn generate_linker_args(exports: &[(u32, String)]) {
    let out_dir: std::path::PathBuf = std::path::PathBuf::from(env!("OUT_DIR"))
        .components()
        .rev()
        .skip_while(|component| {
            let name = component.as_os_str().to_str().unwrap_or_default();
            name == "out" || name.contains("forward-dll-derive") || name == "build"
        })
        // `skip_while` is not double-ended, so buffer before restoring the order.
        .collect::<Vec<_>>()
        .into_iter()
        .rev()
        .collect();
    if !out_dir.is_dir() {
        return;
    }

    let ordinal_content = exports
        .iter()
        .map(|(ordinal, fn_name)| format!("/EXPORT:{fn_name}=_{fn_name},@{ordinal}"))
        .collect::<Vec<_>>()
        .join("\n");
    // Deliberately best-effort: a failed write must not break the derive.
    let _ = std::fs::write(out_dir.join("ordinal_link_args.txt"), ordinal_content);
}