poggers_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use proc_macro_crate::crate_name;
4// use proc_macro2::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{parse::Parse, parse_macro_input, punctuated::Punctuated, Ident, ItemFn, Token};
7
/// Flags accepted inside `#[create_entry(...)]`.
///
/// Each flag is a bare identifier; a flag that is not written stays `false`.
struct CreateEntryArguments {
    /// `no_console`: do not allocate a console (which also means none is freed).
    no_console: bool,
    /// `no_free`: keep the allocated console instead of freeing it afterwards.
    no_free: bool,
    /// `no_thread`: request that the entry function not be run on a spawned thread.
    no_thread: bool,
}
13
14impl Parse for CreateEntryArguments {
15    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
16        let argss = Punctuated::<Ident, Token![,]>::parse_terminated(input)?;
17        let mut no_console = false;
18        let mut no_thread = false;
19        let mut no_free = false;
20        for arg in argss {
21            match arg.to_string().as_str() {
22                "no_console" => {
23                    no_console = true;
24                }
25                "no_thread" => {
26                    no_thread = true;
27                }
28                "no_free" => {
29                    no_free = true;
30                }
31                _ => {}
32            }
33        }
34        Ok(CreateEntryArguments {
35            no_console,
36            no_thread,
37            no_free,
38        })
39    }
40}
41
42/// This macro allows you to define a function which will be called upon dll injection
43/// you can get the HMODULE of this dll by simply having a parameter of type `HMODULE`
44/// ## Notes
45/// On windows, this will automatically allocate a console, if you don't want to do that, use the `no_console` attribute
46/// On windows, this will automatically free the console upon dll unload, if you don't want to do that, use the `no_free` attribute
47#[proc_macro_attribute]
48pub fn create_entry(attr: TokenStream, item: TokenStream) -> TokenStream {
49    let input = parse_macro_input!(item as ItemFn);
50    let inputb = input.clone();
51    let arg = parse_macro_input!(attr as CreateEntryArguments);
52    let input_name = input.sig.ident;
53    let has_hmd = !input.sig.inputs.is_empty();
54
55    let curr_crate = match crate_name("poggers").expect("poggers-derive to be found") {
56        proc_macro_crate::FoundCrate::Itself => quote!(crate),
57        proc_macro_crate::FoundCrate::Name(x) => {
58            let i = Ident::new(&x, Span::call_site());
59            quote!(#i)
60        }
61    };
62
63    let ret = input.sig.output;
64
65    let handle_ret = match ret {
66        syn::ReturnType::Default => quote!(),
67        syn::ReturnType::Type(_, ty) => {
68            if ty.to_token_stream().to_string().contains("Result") {
69                quote! {
70                    match r {
71                        Ok(_) => (),
72                        Err(e) => {
73                            println!(concat!(stringify!{#input_name}," has errored: {:?}"), e);
74                        }
75                    }
76                }
77            } else {
78                quote!()
79            }
80        }
81    };
82
83    let alloc_console = if arg.no_console {
84        quote! {}
85    } else {
86        quote! {
87            unsafe {
88                #curr_crate::exports::AllocConsole();
89            };
90        }
91    };
92    let free_console = if arg.no_console || arg.no_free {
93        quote! {}
94    } else {
95        quote! {
96            unsafe {
97                #curr_crate::exports::FreeConsole();
98            };
99        }
100    };
101    let call_main = if has_hmd {
102        quote! {
103            #input_name(h_module)
104        }
105    } else {
106        quote! {
107            #input_name()
108        }
109    };
110    let cross_platform = quote! {
111        use ::std::panic;
112
113        match panic::catch_unwind(move || #call_main) {
114            Err(e) => {
115                println!("`{}` has panicked: {:#?}",stringify!{#input_name}, e);
116            }
117            Ok(r) => {#handle_ret},
118        };
119    };
120
121    let thread_spawn = if arg.no_thread {
122        quote! {#alloc_console;#cross_platform;#free_console}
123    } else {
124        quote! {
125            std::thread::spawn(move || {
126                #alloc_console
127                #cross_platform
128                #free_console
129            });
130        }
131    };
132
133    #[cfg(target_os = "windows")]
134    let generated = quote! {
135        #[no_mangle]
136        extern "system" fn DllMain(
137            h_module : #curr_crate::exports::HMODULE,
138            reason : u32,
139            _: *const ::std::ffi::c_void
140        ) -> #curr_crate::exports::BOOL {
141            match reason {
142                #curr_crate::exports::DLL_PROCESS_ATTACH => {
143                    #thread_spawn
144                    (true).into()
145                }
146                _ => (false).into()
147            }
148        }
149    };
150    #[cfg(not(target_os = "windows"))]
151    let generated = quote! {
152        #[#curr_crate::exports::ctor]
153        fn lib_init() {
154            std::thread::spawn(|| {
155
156                #cross_platform
157
158            });
159        }
160    };
161
162    TokenStream::from(quote! {
163        #inputb
164
165        #generated
166    })
167}