Skip to main content

dllmain_rs/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::{Error, FnArg, Ident, Item, LitStr, Pat, ReturnType, Token, Type};
7
/// What the generated `DllMain` does when the annotated function panics.
#[derive(Copy, Clone, PartialEq, Eq)]
enum PanicPolicy {
    /// Terminate the process with `std::process::abort`.
    Abort,
    /// Swallow the panic and make `DllMain` return `FALSE` (0).
    ReturnFalse,
}
13
/// A `DllMain` notification the annotated function can be attached to.
///
/// Each variant maps to one `DLL_*` reason-code constant declared inside the
/// generated `DllMain`.
#[derive(Copy, Clone, PartialEq, Eq)]
enum DllEvent {
    /// `DLL_PROCESS_DETACH` (reason code 0).
    ProcessDetach,
    /// `DLL_PROCESS_ATTACH` (reason code 1).
    ProcessAttach,
    /// `DLL_THREAD_ATTACH` (reason code 2).
    ThreadAttach,
    /// `DLL_THREAD_DETACH` (reason code 3).
    ThreadDetach,
}
21
22impl DllEvent {
23    fn from_ident(ident: &Ident) -> Result<Self, Error> {
24        match ident.to_string().as_str() {
25            "process_attach" => Ok(Self::ProcessAttach),
26            "process_detach" => Ok(Self::ProcessDetach),
27            "thread_attach" => Ok(Self::ThreadAttach),
28            "thread_detach" => Ok(Self::ThreadDetach),
29            _ => Err(Error::new_spanned(
30                ident,
31                "unknown event; expected one of: process_attach, process_detach, thread_attach, thread_detach",
32            )),
33        }
34    }
35
36    fn match_arm_tokens(self, reason_binding: Option<&Pat>, block: &syn::Block) -> TokenStream2 {
37        let reason = match self {
38            Self::ProcessDetach => quote! { DLL_PROCESS_DETACH },
39            Self::ProcessAttach => quote! { DLL_PROCESS_ATTACH },
40            Self::ThreadAttach => quote! { DLL_THREAD_ATTACH },
41            Self::ThreadDetach => quote! { DLL_THREAD_DETACH },
42        };
43
44        let bind_reason = match reason_binding {
45            Some(pattern) => quote! { let #pattern: u32 = call_reason; },
46            None => quote! {},
47        };
48
49        quote! {
50            #reason => {
51                #bind_reason
52                #block
53            },
54        }
55    }
56}
57
58struct EntryArgs {
59    events: Vec<DllEvent>,
60    panic_policy: PanicPolicy,
61}
62
63impl Default for EntryArgs {
64    fn default() -> Self {
65        Self {
66            events: vec![DllEvent::ProcessAttach],
67            panic_policy: PanicPolicy::Abort,
68        }
69    }
70}
71
72impl Parse for EntryArgs {
73    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
74        let mut args = EntryArgs::default();
75        let mut seen_events = false;
76        let mut seen_panic = false;
77
78        while !input.is_empty() {
79            let option: Ident = input.parse()?;
80
81            if option == "events" {
82                if seen_events {
83                    return Err(Error::new_spanned(option, "duplicate option `events`"));
84                }
85                seen_events = true;
86
87                let content;
88                syn::parenthesized!(content in input);
89                let parsed_events: Punctuated<Ident, Token![,]> =
90                    content.parse_terminated(Ident::parse)?;
91
92                if parsed_events.is_empty() {
93                    return Err(Error::new_spanned(
94                        option,
95                        "`events(...)` must include at least one event",
96                    ));
97                }
98
99                let mut events = Vec::with_capacity(parsed_events.len());
100                for event_ident in parsed_events {
101                    let event = DllEvent::from_ident(&event_ident)?;
102                    if events.contains(&event) {
103                        return Err(Error::new_spanned(
104                            event_ident,
105                            "duplicate event in `events(...)`",
106                        ));
107                    }
108                    events.push(event);
109                }
110                args.events = events;
111            } else if option == "panic" {
112                if seen_panic {
113                    return Err(Error::new_spanned(option, "duplicate option `panic`"));
114                }
115                seen_panic = true;
116
117                input.parse::<Token![=]>()?;
118                let value: LitStr = input.parse()?;
119
120                args.panic_policy = match value.value().as_str() {
121                    "abort" => PanicPolicy::Abort,
122                    "return_false" => PanicPolicy::ReturnFalse,
123                    _ => {
124                        return Err(Error::new_spanned(
125                            value,
126                            "invalid panic policy; expected \"abort\" or \"return_false\"",
127                        ));
128                    }
129                };
130            } else {
131                return Err(Error::new_spanned(
132                    option,
133                    "unknown option; expected `events(...)` or `panic = \"...\"`",
134                ));
135            }
136
137            if input.is_empty() {
138                break;
139            }
140            input.parse::<Token![,]>()?;
141        }
142
143        Ok(args)
144    }
145}
146
147fn is_u32_type(ty: &Type) -> bool {
148    match ty {
149        Type::Path(path) => path.qself.is_none() && path.path.is_ident("u32"),
150        _ => false,
151    }
152}
153
154fn reason_pattern(sig: &syn::Signature) -> syn::Result<Option<&Pat>> {
155    if sig.constness.is_some() {
156        return Err(Error::new_spanned(
157            sig.constness,
158            "const functions are not supported by #[dllmain_rs::entry]",
159        ));
160    }
161
162    if sig.asyncness.is_some() {
163        return Err(Error::new_spanned(
164            sig.asyncness,
165            "async functions are not supported by #[dllmain_rs::entry]",
166        ));
167    }
168
169    if sig.unsafety.is_some() {
170        return Err(Error::new_spanned(
171            sig.unsafety,
172            "unsafe functions are not supported by #[dllmain_rs::entry]",
173        ));
174    }
175
176    if let Some(abi) = &sig.abi {
177        return Err(Error::new_spanned(
178            abi,
179            "explicit ABI is not supported; #[dllmain_rs::entry] generates DllMain ABI",
180        ));
181    }
182
183    if let Some(variadic) = &sig.variadic {
184        return Err(Error::new_spanned(
185            variadic,
186            "variadic functions are not supported by #[dllmain_rs::entry]",
187        ));
188    }
189
190    if !sig.generics.params.is_empty() || sig.generics.where_clause.is_some() {
191        return Err(Error::new_spanned(
192            &sig.generics,
193            "generic functions are not supported by #[dllmain_rs::entry]",
194        ));
195    }
196
197    if !matches!(sig.output, ReturnType::Default) {
198        return Err(Error::new_spanned(
199            &sig.output,
200            "function must return () for #[dllmain_rs::entry]",
201        ));
202    }
203
204    match sig.inputs.len() {
205        0 => Ok(None),
206        1 => match sig.inputs.first() {
207            Some(FnArg::Typed(arg)) => {
208                if !is_u32_type(&arg.ty) {
209                    return Err(Error::new_spanned(
210                        &arg.ty,
211                        "single argument must be `u32` (the DLL reason code)",
212                    ));
213                }
214                Ok(Some(&arg.pat))
215            }
216            Some(FnArg::Receiver(receiver)) => Err(Error::new_spanned(
217                receiver,
218                "#[dllmain_rs::entry] expects a free function",
219            )),
220            None => Ok(None),
221        },
222        _ => Err(Error::new_spanned(
223            &sig.inputs,
224            "function must have signature `fn name()` or `fn name(reason: u32)`",
225        )),
226    }
227}
228
229#[proc_macro_attribute]
230pub fn entry(attr: TokenStream, item: TokenStream) -> TokenStream {
231    let args = match syn::parse::<EntryArgs>(attr) {
232        Ok(args) => args,
233        Err(err) => return TokenStream::from(err.to_compile_error()),
234    };
235
236    let parsed_item = match syn::parse::<Item>(item) {
237        Ok(item) => item,
238        Err(err) => return TokenStream::from(err.to_compile_error()),
239    };
240
241    let func = match parsed_item {
242        Item::Fn(func) => func,
243        other => {
244            return TokenStream::from(
245                Error::new_spanned(other, "#[dllmain_rs::entry] expects a free function")
246                    .to_compile_error(),
247            );
248        }
249    };
250
251    let reason_binding = match reason_pattern(&func.sig) {
252        Ok(binding) => binding,
253        Err(err) => return TokenStream::from(err.to_compile_error()),
254    };
255
256    let block = &func.block;
257    let match_arms: Vec<_> = args
258        .events
259        .iter()
260        .copied()
261        .map(|event| event.match_arm_tokens(reason_binding, block))
262        .collect();
263
264    let wrapped_body = quote! {
265        match call_reason {
266            #(#match_arms)*
267            _ => {},
268        }
269        DLLMAIN_TRUE
270    };
271
272    let panic_policy = match args.panic_policy {
273        PanicPolicy::Abort => quote! {
274            match ::std::panic::catch_unwind(::std::panic::AssertUnwindSafe(|| {
275                #wrapped_body
276            })) {
277                Ok(value) => value,
278                Err(_) => ::std::process::abort(),
279            }
280        },
281        PanicPolicy::ReturnFalse => quote! {
282            match ::std::panic::catch_unwind(::std::panic::AssertUnwindSafe(|| {
283                #wrapped_body
284            })) {
285                Ok(value) => value,
286                Err(_) => DLLMAIN_FALSE,
287            }
288        },
289    };
290
291    let output = quote! {
292        #[unsafe(no_mangle)]
293        #[allow(non_snake_case, unused_variables)]
294        extern "system" fn DllMain(
295            _dll_module: *mut ::core::ffi::c_void,
296            call_reason: u32,
297            _reserved: *mut ::core::ffi::c_void,
298        ) -> i32 {
299            const DLL_PROCESS_DETACH: u32 = 0;
300            const DLL_PROCESS_ATTACH: u32 = 1;
301            const DLL_THREAD_ATTACH: u32 = 2;
302            const DLL_THREAD_DETACH: u32 = 3;
303            const DLLMAIN_TRUE: i32 = 1;
304            const DLLMAIN_FALSE: i32 = 0;
305
306            #panic_policy
307        }
308    };
309
310    TokenStream::from(output)
311}