cdylib_shim_macros/
lib.rs

1use std::{collections::HashSet, path::Path, str::Utf8Error};
2
3use convert_case::{Case, Casing};
4use itertools::{Either, Itertools};
5use object::Object;
6use proc_macro2::{Span, TokenStream};
7use quote::{ToTokens, quote};
8use syn::{
9    Attribute, Error, FnArg, Ident, Item, ItemFn, ItemMod, LitStr, PatType, Signature,
10    parse_macro_input, parse_quote, spanned::Spanned,
11};
12
13#[proc_macro_attribute]
14pub fn shim(
15    attr: proc_macro::TokenStream,
16    item: proc_macro::TokenStream,
17) -> proc_macro::TokenStream {
18    let name = parse_macro_input!(attr as LitStr);
19
20    let Some(library) = Library::load(name.value()) else {
21        panic!("Failed to load library");
22    };
23
24    let mut module = parse_macro_input!(item as ItemMod);
25
26    let mut ctx = Context {
27        library,
28        init_fn: None,
29        hook_fns: Vec::new(),
30    };
31
32    let Some((_, content)) = &mut module.content else {
33        return module.into_token_stream().into();
34    };
35
36    for item in content.iter_mut() {
37        let result = match item {
38            Item::Fn(item_fn) => handle_item_fn(&mut ctx, item_fn),
39            _ => Ok(()),
40        };
41
42        if let Err(errors) = result {
43            return errors
44                .into_iter()
45                .map(|error| error.to_compile_error())
46                .collect::<TokenStream>()
47                .into();
48        }
49    }
50
51    content.push({
52        let original_mod = OriginalModule { ctx: &ctx };
53        parse_quote! { #original_mod }
54    });
55
56    module.into_token_stream().into()
57}
58
59struct Context {
60    library: Library,
61    init_fn: Option<InitFn>,
62    hook_fns: Vec<HookFn>,
63}
64
65fn handle_item_fn(ctx: &mut Context, item_fn: &mut ItemFn) -> Result<(), Vec<Error>> {
66    let Some((kind, attr)) = parse_attrs(item_fn)? else {
67        return Ok(());
68    };
69
70    match kind {
71        AttributeKind::Init => handle_init_fn(ctx, item_fn, &attr),
72        AttributeKind::Hook => handle_hook_fn(ctx, item_fn),
73    }
74}
75
76fn parse_attrs(item_fn: &mut ItemFn) -> Result<Option<(AttributeKind, Attribute)>, Vec<Error>> {
77    let (parsed_attrs, attrs): (Vec<_>, Vec<_>) = std::mem::take(&mut item_fn.attrs)
78        .into_iter()
79        .partition_map(|attr| match AttributeKind::try_from(&attr) {
80            Ok(kind) => Either::Left((kind, attr)),
81            Err(_) => Either::Right(attr),
82        });
83
84    item_fn.attrs = attrs;
85    let mut parsed_attrs = parsed_attrs.into_iter();
86
87    let Some(parsed_attr) = parsed_attrs.next() else {
88        return Ok(None);
89    };
90
91    let errors: Vec<_> = parsed_attrs
92        .map(|(_, attr)| {
93            Error::new(
94                attr.span(),
95                "Only one `init` or `hook` attribute is allowed per function",
96            )
97        })
98        .collect();
99
100    if !errors.is_empty() {
101        return Err(errors);
102    }
103
104    Ok(Some(parsed_attr))
105}
106
107fn handle_init_fn(
108    ctx: &mut Context,
109    item_fn: &mut ItemFn,
110    attr: &Attribute,
111) -> Result<(), Vec<Error>> {
112    if ctx.init_fn.is_some() {
113        return Err(vec![Error::new(
114            attr.span(),
115            "There can only be one `init` function",
116        )]);
117    }
118
119    item_fn.attrs.push(parse_quote!(#[allow(dead_code)]));
120
121    ctx.init_fn = Some(InitFn {
122        sig: item_fn.sig.clone(),
123    });
124
125    Ok(())
126}
127
128fn handle_hook_fn(ctx: &mut Context, item_fn: &mut ItemFn) -> Result<(), Vec<Error>> {
129    let export = item_fn.sig.ident.to_string().as_str().into();
130
131    if !ctx.library.exports.contains(&export) {
132        return Err(vec![Error::new(
133            item_fn.sig.ident.span(),
134            format!("Function is not an exported symbol in {}", ctx.library.name),
135        )]);
136    }
137
138    item_fn.attrs.push(parse_quote!(#[unsafe(no_mangle)]));
139    item_fn.attrs.push(parse_quote!(#[allow(non_snake_case)]));
140
141    ctx.hook_fns.push(HookFn {
142        sig: item_fn.sig.clone(),
143        export,
144    });
145
146    Ok(())
147}
148
149enum AttributeKind {
150    Init,
151    Hook,
152}
153
154impl TryFrom<&Attribute> for AttributeKind {
155    type Error = ();
156
157    fn try_from(value: &Attribute) -> Result<Self, Self::Error> {
158        if value.path().is_ident("init") {
159            Ok(Self::Init)
160        } else if value.path().is_ident("hook") {
161            Ok(Self::Hook)
162        } else {
163            Err(())
164        }
165    }
166}
167
168struct Library {
169    name: String,
170    exports: HashSet<Export>,
171}
172
173impl Library {
174    fn load(name: String) -> Option<Self> {
175        let separator = if cfg!(windows) { ';' } else { ':' };
176
177        let path = std::env::var("PATH")
178            .ok()?
179            .split(separator)
180            .map(|directory| Path::new(directory).join(&name))
181            .find(|path| path.exists())?;
182
183        let data = std::fs::read(&path).ok()?;
184
185        let exports = object::File::parse(data.as_slice())
186            .ok()?
187            .exports()
188            .ok()?
189            .into_iter()
190            .filter_map(|export| Export::try_from(&export).ok())
191            .collect();
192
193        Some(Self { name, exports })
194    }
195
196    fn lit_str(&self, span: Span) -> LitStr {
197        LitStr::new(&self.name, span)
198    }
199}
200
201#[derive(PartialEq, Eq, Hash)]
202struct Export {
203    name: String,
204}
205
206impl Export {
207    fn ident(&self, span: Span) -> Ident {
208        Ident::new(&self.name, span)
209    }
210
211    fn lit_str(&self, span: Span) -> LitStr {
212        LitStr::new(&self.name, span)
213    }
214
215    fn address(&self) -> ExportAddress {
216        ExportAddress { export: self }
217    }
218}
219
220impl From<&str> for Export {
221    fn from(value: &str) -> Self {
222        Self { name: value.into() }
223    }
224}
225
226impl TryFrom<&object::Export<'_>> for Export {
227    type Error = Utf8Error;
228
229    fn try_from(value: &object::Export) -> Result<Self, Self::Error> {
230        std::str::from_utf8(value.name()).map(Into::into)
231    }
232}
233
234struct ExportAddress<'a> {
235    export: &'a Export,
236}
237
238impl ExportAddress<'_> {
239    fn ident(&self, span: Span) -> Ident {
240        Ident::new(
241            &format!("{}_ADDRESS", self.export.name.to_case(Case::UpperSnake)),
242            span,
243        )
244    }
245}
246
247impl ToTokens for ExportAddress<'_> {
248    fn to_tokens(&self, tokens: &mut TokenStream) {
249        let ident = self.ident(Span::call_site());
250
251        tokens.extend(quote! {
252            static mut #ident: usize = 0;
253        });
254    }
255}
256
257struct ShimFn<'a> {
258    export: &'a Export,
259}
260
261impl ToTokens for ShimFn<'_> {
262    fn to_tokens(&self, tokens: &mut TokenStream) {
263        let ident = self.export.ident(Span::call_site());
264        let address_ident = self.export.address().ident(Span::call_site());
265
266        tokens.extend(quote! {
267            #[unsafe(naked)]
268            #[unsafe(no_mangle)]
269            unsafe extern "system" fn #ident() {
270                std::arch::naked_asm!("jmp [rip + {}]", sym #address_ident)
271            }
272        });
273    }
274}
275
276struct LoadLibraryFn<'a> {
277    library: &'a Library,
278}
279
280impl LoadLibraryFn<'_> {
281    fn ident(&self, span: Span) -> Ident {
282        Ident::new("load_library", span)
283    }
284
285    fn to_call_tokens(&self) -> TokenStream {
286        let ident = self.ident(Span::call_site());
287        quote! { #ident() }
288    }
289}
290
291impl ToTokens for LoadLibraryFn<'_> {
292    fn to_tokens(&self, tokens: &mut TokenStream) {
293        let ident = self.ident(Span::call_site());
294        let library_name = self.library.lit_str(Span::call_site());
295
296        let load_exports = self.library.exports.iter().map(|export| {
297            let address_ident = export.address().ident(Span::call_site());
298            let export_name = export.lit_str(Span::call_site());
299
300            quote! {
301                #address_ident = *library.get::<usize>(#export_name.as_bytes()).unwrap();
302            }
303        });
304
305        tokens.extend(quote! {
306            fn #ident() {
307                unsafe {
308                    let mut path = cdylib_shim::__private::system_dir().expect("should exist");
309                    path.push(#library_name);
310                    static mut LIBRARY: Option<cdylib_shim::__private::Library> = None;
311                    let library = LIBRARY.insert(cdylib_shim::__private::Library::new(path).unwrap());
312                    #(#load_exports)*
313                }
314            }
315        });
316    }
317}
318
319struct InitFn {
320    sig: Signature,
321}
322
323impl InitFn {
324    fn to_call_tokens(&self) -> TokenStream {
325        let ident = &self.sig.ident;
326        quote! { #ident() }
327    }
328}
329
330struct HookFn {
331    sig: Signature,
332    export: Export,
333}
334
335impl HookFn {
336    fn to_original_fn(&self) -> OriginalFn {
337        OriginalFn { hook_fn: self }
338    }
339}
340
341struct OriginalFn<'a> {
342    hook_fn: &'a HookFn,
343}
344
345impl ToTokens for OriginalFn<'_> {
346    fn to_tokens(&self, tokens: &mut TokenStream) {
347        let HookFn { sig, export } = self.hook_fn;
348        let abi = &sig.abi;
349        let output = &sig.output;
350        let address_ident = export.address().ident(Span::call_site());
351
352        let (pats, tys): (Vec<_>, Vec<_>) = sig
353            .inputs
354            .iter()
355            .filter_map(|arg| match arg {
356                FnArg::Typed(PatType { pat, ty, .. }) => Some((pat, ty)),
357                FnArg::Receiver(_) => None,
358            })
359            .collect();
360
361        tokens.extend(quote! {
362            #[allow(non_snake_case)]
363            pub #sig {
364                unsafe {
365                    std::mem::transmute::<_, #abi fn(#(#tys),*) #output>(#address_ident)(#(#pats),*)
366                }
367            }
368        })
369    }
370}
371
372struct Initializer<'a> {
373    load_library_fn: &'a LoadLibraryFn<'a>,
374    init_fn: Option<&'a InitFn>,
375}
376
377impl ToTokens for Initializer<'_> {
378    fn to_tokens(&self, tokens: &mut TokenStream) {
379        let load_library_fn_call = self.load_library_fn.to_call_tokens();
380        let init_fn_call = self.init_fn.map(|init_fn| {
381            let tokens = init_fn.to_call_tokens();
382            quote! { super::#tokens; }
383        });
384
385        tokens.extend(quote! {
386            #[used]
387            #[unsafe(link_section = ".CRT$XCU")]
388            static INITIALIZER: extern "C" fn() = {
389                extern "C" fn init() {
390                    #load_library_fn_call;
391                    #init_fn_call;
392                }
393                init
394            };
395        });
396    }
397}
398
399struct OriginalModule<'a> {
400    ctx: &'a Context,
401}
402
403impl ToTokens for OriginalModule<'_> {
404    fn to_tokens(&self, tokens: &mut TokenStream) {
405        let export_addresses = self.ctx.library.exports.iter().map(Export::address);
406        let original_fns = self.ctx.hook_fns.iter().map(HookFn::to_original_fn);
407
408        let hook_exports: HashSet<_> = self
409            .ctx
410            .hook_fns
411            .iter()
412            .map(|hook_fn| &hook_fn.export)
413            .collect();
414
415        let shim_fns = self
416            .ctx
417            .library
418            .exports
419            .iter()
420            .filter(|export| !hook_exports.contains(export))
421            .map(|export| ShimFn { export });
422
423        let load_library_fn = LoadLibraryFn {
424            library: &self.ctx.library,
425        };
426
427        let initializer = Initializer {
428            load_library_fn: &load_library_fn,
429            init_fn: self.ctx.init_fn.as_ref(),
430        };
431
432        tokens.extend(quote! {
433            mod original {
434                use super::*;
435
436                #(#export_addresses)*
437                #(#original_fns)*
438                #(#shim_fns)*
439                #load_library_fn
440                #initializer
441            }
442        })
443    }
444}