1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::{Error, FnArg, Ident, Item, LitStr, Pat, ReturnType, Token, Type};
7
/// How the generated `DllMain` reacts when the user's body panics
/// (selected with the `panic = "..."` attribute option).
#[derive(PartialEq, Eq, Clone, Copy)]
enum PanicPolicy {
    /// `panic = "abort"` (the default): the process is aborted.
    Abort,
    /// `panic = "return_false"`: `DllMain` returns `0` (`FALSE`).
    ReturnFalse,
}
13
/// A `DllMain` reason code that the user's body can be attached to.
///
/// Variants are declared in the order of the numeric reason codes used by the
/// generated `DllMain` (0 through 3).
#[derive(PartialEq, Eq, Clone, Copy)]
enum DllEvent {
    /// `DLL_PROCESS_DETACH` (0), spelled `process_detach` in `events(...)`.
    ProcessDetach,
    /// `DLL_PROCESS_ATTACH` (1), spelled `process_attach` in `events(...)`.
    ProcessAttach,
    /// `DLL_THREAD_ATTACH` (2), spelled `thread_attach` in `events(...)`.
    ThreadAttach,
    /// `DLL_THREAD_DETACH` (3), spelled `thread_detach` in `events(...)`.
    ThreadDetach,
}
21
22impl DllEvent {
23 fn from_ident(ident: &Ident) -> Result<Self, Error> {
24 match ident.to_string().as_str() {
25 "process_attach" => Ok(Self::ProcessAttach),
26 "process_detach" => Ok(Self::ProcessDetach),
27 "thread_attach" => Ok(Self::ThreadAttach),
28 "thread_detach" => Ok(Self::ThreadDetach),
29 _ => Err(Error::new_spanned(
30 ident,
31 "unknown event; expected one of: process_attach, process_detach, thread_attach, thread_detach",
32 )),
33 }
34 }
35
36 fn match_arm_tokens(self, reason_binding: Option<&Pat>, block: &syn::Block) -> TokenStream2 {
37 let reason = match self {
38 Self::ProcessDetach => quote! { DLL_PROCESS_DETACH },
39 Self::ProcessAttach => quote! { DLL_PROCESS_ATTACH },
40 Self::ThreadAttach => quote! { DLL_THREAD_ATTACH },
41 Self::ThreadDetach => quote! { DLL_THREAD_DETACH },
42 };
43
44 let bind_reason = match reason_binding {
45 Some(pattern) => quote! { let #pattern: u32 = call_reason; },
46 None => quote! {},
47 };
48
49 quote! {
50 #reason => {
51 #bind_reason
52 #block
53 },
54 }
55 }
56}
57
58struct EntryArgs {
59 events: Vec<DllEvent>,
60 panic_policy: PanicPolicy,
61}
62
63impl Default for EntryArgs {
64 fn default() -> Self {
65 Self {
66 events: vec![DllEvent::ProcessAttach],
67 panic_policy: PanicPolicy::Abort,
68 }
69 }
70}
71
72impl Parse for EntryArgs {
73 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
74 let mut args = EntryArgs::default();
75 let mut seen_events = false;
76 let mut seen_panic = false;
77
78 while !input.is_empty() {
79 let option: Ident = input.parse()?;
80
81 if option == "events" {
82 if seen_events {
83 return Err(Error::new_spanned(option, "duplicate option `events`"));
84 }
85 seen_events = true;
86
87 let content;
88 syn::parenthesized!(content in input);
89 let parsed_events: Punctuated<Ident, Token![,]> =
90 content.parse_terminated(Ident::parse)?;
91
92 if parsed_events.is_empty() {
93 return Err(Error::new_spanned(
94 option,
95 "`events(...)` must include at least one event",
96 ));
97 }
98
99 let mut events = Vec::with_capacity(parsed_events.len());
100 for event_ident in parsed_events {
101 let event = DllEvent::from_ident(&event_ident)?;
102 if events.contains(&event) {
103 return Err(Error::new_spanned(
104 event_ident,
105 "duplicate event in `events(...)`",
106 ));
107 }
108 events.push(event);
109 }
110 args.events = events;
111 } else if option == "panic" {
112 if seen_panic {
113 return Err(Error::new_spanned(option, "duplicate option `panic`"));
114 }
115 seen_panic = true;
116
117 input.parse::<Token![=]>()?;
118 let value: LitStr = input.parse()?;
119
120 args.panic_policy = match value.value().as_str() {
121 "abort" => PanicPolicy::Abort,
122 "return_false" => PanicPolicy::ReturnFalse,
123 _ => {
124 return Err(Error::new_spanned(
125 value,
126 "invalid panic policy; expected \"abort\" or \"return_false\"",
127 ));
128 }
129 };
130 } else {
131 return Err(Error::new_spanned(
132 option,
133 "unknown option; expected `events(...)` or `panic = \"...\"`",
134 ));
135 }
136
137 if input.is_empty() {
138 break;
139 }
140 input.parse::<Token![,]>()?;
141 }
142
143 Ok(args)
144 }
145}
146
147fn is_u32_type(ty: &Type) -> bool {
148 match ty {
149 Type::Path(path) => path.qself.is_none() && path.path.is_ident("u32"),
150 _ => false,
151 }
152}
153
154fn reason_pattern(sig: &syn::Signature) -> syn::Result<Option<&Pat>> {
155 if sig.constness.is_some() {
156 return Err(Error::new_spanned(
157 sig.constness,
158 "const functions are not supported by #[dllmain_rs::entry]",
159 ));
160 }
161
162 if sig.asyncness.is_some() {
163 return Err(Error::new_spanned(
164 sig.asyncness,
165 "async functions are not supported by #[dllmain_rs::entry]",
166 ));
167 }
168
169 if sig.unsafety.is_some() {
170 return Err(Error::new_spanned(
171 sig.unsafety,
172 "unsafe functions are not supported by #[dllmain_rs::entry]",
173 ));
174 }
175
176 if let Some(abi) = &sig.abi {
177 return Err(Error::new_spanned(
178 abi,
179 "explicit ABI is not supported; #[dllmain_rs::entry] generates DllMain ABI",
180 ));
181 }
182
183 if let Some(variadic) = &sig.variadic {
184 return Err(Error::new_spanned(
185 variadic,
186 "variadic functions are not supported by #[dllmain_rs::entry]",
187 ));
188 }
189
190 if !sig.generics.params.is_empty() || sig.generics.where_clause.is_some() {
191 return Err(Error::new_spanned(
192 &sig.generics,
193 "generic functions are not supported by #[dllmain_rs::entry]",
194 ));
195 }
196
197 if !matches!(sig.output, ReturnType::Default) {
198 return Err(Error::new_spanned(
199 &sig.output,
200 "function must return () for #[dllmain_rs::entry]",
201 ));
202 }
203
204 match sig.inputs.len() {
205 0 => Ok(None),
206 1 => match sig.inputs.first() {
207 Some(FnArg::Typed(arg)) => {
208 if !is_u32_type(&arg.ty) {
209 return Err(Error::new_spanned(
210 &arg.ty,
211 "single argument must be `u32` (the DLL reason code)",
212 ));
213 }
214 Ok(Some(&arg.pat))
215 }
216 Some(FnArg::Receiver(receiver)) => Err(Error::new_spanned(
217 receiver,
218 "#[dllmain_rs::entry] expects a free function",
219 )),
220 None => Ok(None),
221 },
222 _ => Err(Error::new_spanned(
223 &sig.inputs,
224 "function must have signature `fn name()` or `fn name(reason: u32)`",
225 )),
226 }
227}
228
229#[proc_macro_attribute]
230pub fn entry(attr: TokenStream, item: TokenStream) -> TokenStream {
231 let args = match syn::parse::<EntryArgs>(attr) {
232 Ok(args) => args,
233 Err(err) => return TokenStream::from(err.to_compile_error()),
234 };
235
236 let parsed_item = match syn::parse::<Item>(item) {
237 Ok(item) => item,
238 Err(err) => return TokenStream::from(err.to_compile_error()),
239 };
240
241 let func = match parsed_item {
242 Item::Fn(func) => func,
243 other => {
244 return TokenStream::from(
245 Error::new_spanned(other, "#[dllmain_rs::entry] expects a free function")
246 .to_compile_error(),
247 );
248 }
249 };
250
251 let reason_binding = match reason_pattern(&func.sig) {
252 Ok(binding) => binding,
253 Err(err) => return TokenStream::from(err.to_compile_error()),
254 };
255
256 let block = &func.block;
257 let match_arms: Vec<_> = args
258 .events
259 .iter()
260 .copied()
261 .map(|event| event.match_arm_tokens(reason_binding, block))
262 .collect();
263
264 let wrapped_body = quote! {
265 match call_reason {
266 #(#match_arms)*
267 _ => {},
268 }
269 DLLMAIN_TRUE
270 };
271
272 let panic_policy = match args.panic_policy {
273 PanicPolicy::Abort => quote! {
274 match ::std::panic::catch_unwind(::std::panic::AssertUnwindSafe(|| {
275 #wrapped_body
276 })) {
277 Ok(value) => value,
278 Err(_) => ::std::process::abort(),
279 }
280 },
281 PanicPolicy::ReturnFalse => quote! {
282 match ::std::panic::catch_unwind(::std::panic::AssertUnwindSafe(|| {
283 #wrapped_body
284 })) {
285 Ok(value) => value,
286 Err(_) => DLLMAIN_FALSE,
287 }
288 },
289 };
290
291 let output = quote! {
292 #[unsafe(no_mangle)]
293 #[allow(non_snake_case, unused_variables)]
294 extern "system" fn DllMain(
295 _dll_module: *mut ::core::ffi::c_void,
296 call_reason: u32,
297 _reserved: *mut ::core::ffi::c_void,
298 ) -> i32 {
299 const DLL_PROCESS_DETACH: u32 = 0;
300 const DLL_PROCESS_ATTACH: u32 = 1;
301 const DLL_THREAD_ATTACH: u32 = 2;
302 const DLL_THREAD_DETACH: u32 = 3;
303 const DLLMAIN_TRUE: i32 = 1;
304 const DLLMAIN_FALSE: i32 = 0;
305
306 #panic_policy
307 }
308 };
309
310 TokenStream::from(output)
311}