1use std::{collections::HashSet, path::Path, str::Utf8Error};
2
3use convert_case::{Case, Casing};
4use itertools::{Either, Itertools};
5use object::Object;
6use proc_macro2::{Span, TokenStream};
7use quote::{ToTokens, quote};
8use syn::{
9 Attribute, Error, FnArg, Ident, Item, ItemFn, ItemMod, LitStr, PatType, Signature,
10 parse_macro_input, parse_quote, spanned::Spanned,
11};
12
13#[proc_macro_attribute]
14pub fn shim(
15 attr: proc_macro::TokenStream,
16 item: proc_macro::TokenStream,
17) -> proc_macro::TokenStream {
18 let name = parse_macro_input!(attr as LitStr);
19
20 let Some(library) = Library::load(name.value()) else {
21 panic!("Failed to load library");
22 };
23
24 let mut module = parse_macro_input!(item as ItemMod);
25
26 let mut ctx = Context {
27 library,
28 init_fn: None,
29 hook_fns: Vec::new(),
30 };
31
32 let Some((_, content)) = &mut module.content else {
33 return module.into_token_stream().into();
34 };
35
36 for item in content.iter_mut() {
37 let result = match item {
38 Item::Fn(item_fn) => handle_item_fn(&mut ctx, item_fn),
39 _ => Ok(()),
40 };
41
42 if let Err(errors) = result {
43 return errors
44 .into_iter()
45 .map(|error| error.to_compile_error())
46 .collect::<TokenStream>()
47 .into();
48 }
49 }
50
51 content.push({
52 let original_mod = OriginalModule { ctx: &ctx };
53 parse_quote! { #original_mod }
54 });
55
56 module.into_token_stream().into()
57}
58
59struct Context {
60 library: Library,
61 init_fn: Option<InitFn>,
62 hook_fns: Vec<HookFn>,
63}
64
65fn handle_item_fn(ctx: &mut Context, item_fn: &mut ItemFn) -> Result<(), Vec<Error>> {
66 let Some((kind, attr)) = parse_attrs(item_fn)? else {
67 return Ok(());
68 };
69
70 match kind {
71 AttributeKind::Init => handle_init_fn(ctx, item_fn, &attr),
72 AttributeKind::Hook => handle_hook_fn(ctx, item_fn),
73 }
74}
75
76fn parse_attrs(item_fn: &mut ItemFn) -> Result<Option<(AttributeKind, Attribute)>, Vec<Error>> {
77 let (parsed_attrs, attrs): (Vec<_>, Vec<_>) = std::mem::take(&mut item_fn.attrs)
78 .into_iter()
79 .partition_map(|attr| match AttributeKind::try_from(&attr) {
80 Ok(kind) => Either::Left((kind, attr)),
81 Err(_) => Either::Right(attr),
82 });
83
84 item_fn.attrs = attrs;
85 let mut parsed_attrs = parsed_attrs.into_iter();
86
87 let Some(parsed_attr) = parsed_attrs.next() else {
88 return Ok(None);
89 };
90
91 let errors: Vec<_> = parsed_attrs
92 .map(|(_, attr)| {
93 Error::new(
94 attr.span(),
95 "Only one `init` or `hook` attribute is allowed per function",
96 )
97 })
98 .collect();
99
100 if !errors.is_empty() {
101 return Err(errors);
102 }
103
104 Ok(Some(parsed_attr))
105}
106
107fn handle_init_fn(
108 ctx: &mut Context,
109 item_fn: &mut ItemFn,
110 attr: &Attribute,
111) -> Result<(), Vec<Error>> {
112 if ctx.init_fn.is_some() {
113 return Err(vec![Error::new(
114 attr.span(),
115 "There can only be one `init` function",
116 )]);
117 }
118
119 item_fn.attrs.push(parse_quote!(#[allow(dead_code)]));
120
121 ctx.init_fn = Some(InitFn {
122 sig: item_fn.sig.clone(),
123 });
124
125 Ok(())
126}
127
128fn handle_hook_fn(ctx: &mut Context, item_fn: &mut ItemFn) -> Result<(), Vec<Error>> {
129 let export = item_fn.sig.ident.to_string().as_str().into();
130
131 if !ctx.library.exports.contains(&export) {
132 return Err(vec![Error::new(
133 item_fn.sig.ident.span(),
134 format!("Function is not an exported symbol in {}", ctx.library.name),
135 )]);
136 }
137
138 item_fn.attrs.push(parse_quote!(#[unsafe(no_mangle)]));
139 item_fn.attrs.push(parse_quote!(#[allow(non_snake_case)]));
140
141 ctx.hook_fns.push(HookFn {
142 sig: item_fn.sig.clone(),
143 export,
144 });
145
146 Ok(())
147}
148
149enum AttributeKind {
150 Init,
151 Hook,
152}
153
154impl TryFrom<&Attribute> for AttributeKind {
155 type Error = ();
156
157 fn try_from(value: &Attribute) -> Result<Self, Self::Error> {
158 if value.path().is_ident("init") {
159 Ok(Self::Init)
160 } else if value.path().is_ident("hook") {
161 Ok(Self::Hook)
162 } else {
163 Err(())
164 }
165 }
166}
167
168struct Library {
169 name: String,
170 exports: HashSet<Export>,
171}
172
173impl Library {
174 fn load(name: String) -> Option<Self> {
175 let separator = if cfg!(windows) { ';' } else { ':' };
176
177 let path = std::env::var("PATH")
178 .ok()?
179 .split(separator)
180 .map(|directory| Path::new(directory).join(&name))
181 .find(|path| path.exists())?;
182
183 let data = std::fs::read(&path).ok()?;
184
185 let exports = object::File::parse(data.as_slice())
186 .ok()?
187 .exports()
188 .ok()?
189 .into_iter()
190 .filter_map(|export| Export::try_from(&export).ok())
191 .collect();
192
193 Some(Self { name, exports })
194 }
195
196 fn lit_str(&self, span: Span) -> LitStr {
197 LitStr::new(&self.name, span)
198 }
199}
200
201#[derive(PartialEq, Eq, Hash)]
202struct Export {
203 name: String,
204}
205
206impl Export {
207 fn ident(&self, span: Span) -> Ident {
208 Ident::new(&self.name, span)
209 }
210
211 fn lit_str(&self, span: Span) -> LitStr {
212 LitStr::new(&self.name, span)
213 }
214
215 fn address(&self) -> ExportAddress {
216 ExportAddress { export: self }
217 }
218}
219
220impl From<&str> for Export {
221 fn from(value: &str) -> Self {
222 Self { name: value.into() }
223 }
224}
225
226impl TryFrom<&object::Export<'_>> for Export {
227 type Error = Utf8Error;
228
229 fn try_from(value: &object::Export) -> Result<Self, Self::Error> {
230 std::str::from_utf8(value.name()).map(Into::into)
231 }
232}
233
234struct ExportAddress<'a> {
235 export: &'a Export,
236}
237
238impl ExportAddress<'_> {
239 fn ident(&self, span: Span) -> Ident {
240 Ident::new(
241 &format!("{}_ADDRESS", self.export.name.to_case(Case::UpperSnake)),
242 span,
243 )
244 }
245}
246
247impl ToTokens for ExportAddress<'_> {
248 fn to_tokens(&self, tokens: &mut TokenStream) {
249 let ident = self.ident(Span::call_site());
250
251 tokens.extend(quote! {
252 static mut #ident: usize = 0;
253 });
254 }
255}
256
257struct ShimFn<'a> {
258 export: &'a Export,
259}
260
261impl ToTokens for ShimFn<'_> {
262 fn to_tokens(&self, tokens: &mut TokenStream) {
263 let ident = self.export.ident(Span::call_site());
264 let address_ident = self.export.address().ident(Span::call_site());
265
266 tokens.extend(quote! {
267 #[unsafe(naked)]
268 #[unsafe(no_mangle)]
269 unsafe extern "system" fn #ident() {
270 std::arch::naked_asm!("jmp [rip + {}]", sym #address_ident)
271 }
272 });
273 }
274}
275
276struct LoadLibraryFn<'a> {
277 library: &'a Library,
278}
279
280impl LoadLibraryFn<'_> {
281 fn ident(&self, span: Span) -> Ident {
282 Ident::new("load_library", span)
283 }
284
285 fn to_call_tokens(&self) -> TokenStream {
286 let ident = self.ident(Span::call_site());
287 quote! { #ident() }
288 }
289}
290
291impl ToTokens for LoadLibraryFn<'_> {
292 fn to_tokens(&self, tokens: &mut TokenStream) {
293 let ident = self.ident(Span::call_site());
294 let library_name = self.library.lit_str(Span::call_site());
295
296 let load_exports = self.library.exports.iter().map(|export| {
297 let address_ident = export.address().ident(Span::call_site());
298 let export_name = export.lit_str(Span::call_site());
299
300 quote! {
301 #address_ident = *library.get::<usize>(#export_name.as_bytes()).unwrap();
302 }
303 });
304
305 tokens.extend(quote! {
306 fn #ident() {
307 unsafe {
308 let mut path = cdylib_shim::__private::system_dir().expect("should exist");
309 path.push(#library_name);
310 static mut LIBRARY: Option<cdylib_shim::__private::Library> = None;
311 let library = LIBRARY.insert(cdylib_shim::__private::Library::new(path).unwrap());
312 #(#load_exports)*
313 }
314 }
315 });
316 }
317}
318
319struct InitFn {
320 sig: Signature,
321}
322
323impl InitFn {
324 fn to_call_tokens(&self) -> TokenStream {
325 let ident = &self.sig.ident;
326 quote! { #ident() }
327 }
328}
329
330struct HookFn {
331 sig: Signature,
332 export: Export,
333}
334
335impl HookFn {
336 fn to_original_fn(&self) -> OriginalFn {
337 OriginalFn { hook_fn: self }
338 }
339}
340
341struct OriginalFn<'a> {
342 hook_fn: &'a HookFn,
343}
344
345impl ToTokens for OriginalFn<'_> {
346 fn to_tokens(&self, tokens: &mut TokenStream) {
347 let HookFn { sig, export } = self.hook_fn;
348 let abi = &sig.abi;
349 let output = &sig.output;
350 let address_ident = export.address().ident(Span::call_site());
351
352 let (pats, tys): (Vec<_>, Vec<_>) = sig
353 .inputs
354 .iter()
355 .filter_map(|arg| match arg {
356 FnArg::Typed(PatType { pat, ty, .. }) => Some((pat, ty)),
357 FnArg::Receiver(_) => None,
358 })
359 .collect();
360
361 tokens.extend(quote! {
362 #[allow(non_snake_case)]
363 pub #sig {
364 unsafe {
365 std::mem::transmute::<_, #abi fn(#(#tys),*) #output>(#address_ident)(#(#pats),*)
366 }
367 }
368 })
369 }
370}
371
372struct Initializer<'a> {
373 load_library_fn: &'a LoadLibraryFn<'a>,
374 init_fn: Option<&'a InitFn>,
375}
376
377impl ToTokens for Initializer<'_> {
378 fn to_tokens(&self, tokens: &mut TokenStream) {
379 let load_library_fn_call = self.load_library_fn.to_call_tokens();
380 let init_fn_call = self.init_fn.map(|init_fn| {
381 let tokens = init_fn.to_call_tokens();
382 quote! { super::#tokens; }
383 });
384
385 tokens.extend(quote! {
386 #[used]
387 #[unsafe(link_section = ".CRT$XCU")]
388 static INITIALIZER: extern "C" fn() = {
389 extern "C" fn init() {
390 #load_library_fn_call;
391 #init_fn_call;
392 }
393 init
394 };
395 });
396 }
397}
398
399struct OriginalModule<'a> {
400 ctx: &'a Context,
401}
402
403impl ToTokens for OriginalModule<'_> {
404 fn to_tokens(&self, tokens: &mut TokenStream) {
405 let export_addresses = self.ctx.library.exports.iter().map(Export::address);
406 let original_fns = self.ctx.hook_fns.iter().map(HookFn::to_original_fn);
407
408 let hook_exports: HashSet<_> = self
409 .ctx
410 .hook_fns
411 .iter()
412 .map(|hook_fn| &hook_fn.export)
413 .collect();
414
415 let shim_fns = self
416 .ctx
417 .library
418 .exports
419 .iter()
420 .filter(|export| !hook_exports.contains(export))
421 .map(|export| ShimFn { export });
422
423 let load_library_fn = LoadLibraryFn {
424 library: &self.ctx.library,
425 };
426
427 let initializer = Initializer {
428 load_library_fn: &load_library_fn,
429 init_fn: self.ctx.init_fn.as_ref(),
430 };
431
432 tokens.extend(quote! {
433 mod original {
434 use super::*;
435
436 #(#export_addresses)*
437 #(#original_fns)*
438 #(#shim_fns)*
439 #load_library_fn
440 #initializer
441 }
442 })
443 }
444}