d2_stampede_macros/
lib.rs

1mod protobuf_map;
2
3use crate::protobuf_map::get_enum_from_struct;
4use proc_macro::TokenStream;
5use quote::{quote, ToTokens};
6use syn::{parse_macro_input, ItemImpl, Type, FnArg};
7
8#[proc_macro_attribute]
9pub fn observer(_attr: TokenStream, item: TokenStream) -> TokenStream {
10    let input = parse_macro_input!(item as ItemImpl);
11    let struct_name = &input.self_ty;
12
13    let mut on_dota_user_message_body = quote!();
14    let mut on_base_user_message_body = quote!();
15    let mut on_svc_message_body = quote!();
16    let mut on_net_message_body = quote!();
17    let mut on_base_game_event_body = quote!();
18    let mut on_demo_command_body = quote!();
19    let mut on_tick_start_body = quote!();
20    let mut on_tick_end_body = quote!();
21    let mut on_entity_body = quote!();
22    let mut on_combat_log_body = quote!();
23
24    for item in &input.items {
25        if let syn::ImplItem::Fn(method) = item {
26            let method_name = &method.sig.ident;
27            for attr in &method.attrs {
28                if attr.path().is_ident("on_message") {
29                    check_second_arg_is_context(method);
30                    if let Some((arg_type, is_reference)) = get_message_type(method) {
31                        let enum_type = get_enum_from_struct(&arg_type.to_token_stream().to_string());
32                        let call_message = if is_reference {
33                            quote! { self.#method_name(ctx, &message)?; }
34                        } else {
35                            quote! { self.#method_name(ctx, message)?; }
36                        };
37                        match enum_type.to_token_stream().to_string().split("::").collect::<Vec<_>>()[0].trim() {
38                            "EDotaUserMessages" => {
39                                on_dota_user_message_body = quote! {
40                                    #on_dota_user_message_body
41                                    if msg_type == #enum_type {
42                                        if let Ok(message) = #arg_type::decode(msg) {
43                                            #call_message
44                                        }
45                                    }
46                                };
47                            }
48                            "EBaseUserMessages" => {
49                                on_base_user_message_body = quote! {
50                                    #on_base_user_message_body
51                                    if msg_type == #enum_type {
52                                        if let Ok(message) = #arg_type::decode(msg) {
53                                            #call_message
54                                        }
55                                    }
56                                };
57                            }
58                            "SvcMessages" => {
59                                on_svc_message_body = quote! {
60                                    #on_svc_message_body
61                                    if msg_type == #enum_type {
62                                        if let Ok(message) = #arg_type::decode(msg) {
63                                            #call_message
64                                        }
65                                    }
66                                };
67                            }
68                            "EBaseGameEvents" => {
69                                on_base_game_event_body = quote! {
70                                    #on_base_game_event_body
71                                    if msg_type == #enum_type {
72                                        if let Ok(message) = #arg_type::decode(msg) {
73                                            #call_message
74                                        }
75                                    }
76                                };
77                            }
78                            "NetMessages" => {
79                                on_net_message_body = quote! {
80                                    #on_net_message_body
81                                    if msg_type == #enum_type {
82                                        if let Ok(message) = #arg_type::decode(msg) {
83                                            #call_message
84                                        }
85                                    }
86                                };
87                            }
88                            "EDemoCommands" => {
89                                on_demo_command_body = quote! {
90                                    #on_demo_command_body
91                                    if msg_type == #enum_type {
92                                        if let Ok(message) = #arg_type::decode(msg) {
93                                            #call_message
94                                        }
95                                    }
96                                };
97                            }
98                            x => unreachable!("{}", x),
99                        }
100                    } else {
101                        panic!("Message handler must have a message argument")
102                    }
103                }
104
105                if attr.path().is_ident("on_tick_start") {
106                    check_second_arg_is_context(method);
107                    on_tick_start_body = quote! {
108                        #on_tick_start_body
109                        self.#method_name(ctx)?;
110                    };
111                }
112
113                if attr.path().is_ident("on_tick_end") {
114                    check_second_arg_is_context(method);
115                    on_tick_end_body = quote! {
116                        #on_tick_end_body
117                        self.#method_name(ctx)?;
118                    };
119                }
120
121                if attr.path().is_ident("on_entity") {
122                    check_second_arg_is_context(method);
123                    on_entity_body = quote! {
124                        #on_entity_body
125                        self.#method_name(ctx, event, entity)?;
126                    };
127                }
128
129                if attr.path().is_ident("on_combat_log") {
130                    check_second_arg_is_context(method);
131                    on_combat_log_body = quote! {
132                        #on_combat_log_body
133                        self.#method_name(ctx, cle)?;
134                    };
135                }
136            }
137        }
138    }
139
140    let expanded = quote! {
141        impl Observer for #struct_name {
142            fn on_dota_user_message(
143                &mut self,
144                ctx: &Context,
145                msg_type: EDotaUserMessages,
146                msg: &[u8],
147            ) -> ObserverResult {
148                #on_dota_user_message_body
149                Ok(())
150            }
151
152            fn on_base_user_message(
153                &mut self,
154                ctx: &Context,
155                msg_type: EBaseUserMessages,
156                msg: &[u8],
157            ) -> ObserverResult {
158                #on_base_user_message_body
159                Ok(())
160            }
161
162            fn on_svc_message(
163                &mut self,
164                ctx: &Context,
165                msg_type: SvcMessages,
166                msg: &[u8],
167            ) -> ObserverResult {
168                #on_svc_message_body
169                Ok(())
170            }
171
172            fn on_net_message(
173                &mut self,
174                ctx: &Context,
175                msg_type: NetMessages,
176                msg: &[u8],
177            ) -> ObserverResult {
178                #on_net_message_body
179                Ok(())
180            }
181
182            fn on_base_game_event(
183                &mut self,
184                ctx: &Context,
185                msg_type: EBaseGameEvents,
186                msg: &[u8],
187            ) -> ObserverResult {
188                #on_base_game_event_body
189                Ok(())
190            }
191
192            fn on_demo_command(
193                &mut self,
194                ctx: &Context,
195                msg_type: EDemoCommands,
196                msg: &[u8],
197            ) -> ObserverResult {
198                #on_demo_command_body
199                Ok(())
200            }
201
202            fn on_tick_start(
203                &mut self,
204                ctx: &Context,
205            ) -> ObserverResult {
206                #on_tick_start_body
207                Ok(())
208            }
209
210            fn on_tick_end(
211                &mut self,
212                ctx: &Context,
213            ) -> ObserverResult {
214                #on_tick_end_body
215                Ok(())
216            }
217
218            fn on_entity(
219                &mut self,
220                ctx: &Context,
221                event: EntityEvents,
222                entity: &Entity,
223            ) -> ObserverResult {
224                #on_entity_body
225                Ok(())
226            }
227
228            fn on_combat_log(
229                &mut self,
230                ctx: &Context,
231                cle: &CombatLogEntry,
232            ) -> ObserverResult {
233                #on_combat_log_body
234                Ok(())
235            }
236        }
237
238        #input
239    };
240
241    TokenStream::from(expanded)
242}
243
244#[proc_macro_attribute]
245pub fn on_message(_attr: TokenStream, item: TokenStream) -> TokenStream {
246    item
247}
248
249#[proc_macro_attribute]
250pub fn on_tick_start(_attr: TokenStream, item: TokenStream) -> TokenStream {
251    item
252}
253
254#[proc_macro_attribute]
255pub fn on_tick_end(_attr: TokenStream, item: TokenStream) -> TokenStream {
256    item
257}
258
259#[proc_macro_attribute]
260pub fn on_entity(_attr: TokenStream, item: TokenStream) -> TokenStream {
261    item
262}
263
264#[proc_macro_attribute]
265pub fn on_combat_log(_attr: TokenStream, item: TokenStream) -> TokenStream {
266    item
267}
268
269fn get_message_type(method: &syn::ImplItemFn) -> Option<(Type, bool)> {
270    method.sig.inputs.iter().nth(2).and_then(|arg| {
271        if let syn::FnArg::Typed(pat_type) = arg {
272            if let Type::Reference(x) = pat_type.ty.as_ref() {
273                if x.mutability.is_some() {
274                    panic!("Mutable reference not supported")
275                }
276                Some((*x.elem.clone(), true))
277            } else {
278                Some((*pat_type.ty.clone(), false))
279            }
280        } else {
281            None
282        }
283    })
284}
285
286
287fn check_second_arg_is_context(method: &syn::ImplItemFn) {
288    if let Some(FnArg::Typed(pat_type)) = method.sig.inputs.iter().nth(1) {
289        if let Type::Reference(type_reference) = pat_type.ty.as_ref() {
290            let elem_type = type_reference.elem.as_ref();
291            if let Type::Path(type_path) = elem_type {
292                if let Some(segment) = type_path.path.segments.first() {
293                    if segment.ident == "Context" && type_reference.mutability.is_none() {
294                        return;
295                    }
296                }
297            }
298        }
299    }
300    panic!("The second argument must be of type &Context");
301}