// bevy_simple_subsecond_system_macros/lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::spanned::Spanned;
4use syn::{
5    FnArg, Ident, ItemFn, LitBool, Pat, PatIdent, ReturnType, Token, Type, TypePath, TypeReference,
6    parse::{Parse, ParseStream},
7};
8
/// Parsed arguments of `#[hot(...)]`.
///
/// Each field is `None` when the corresponding key was not written in the attribute.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
struct HotArgs {
    /// Value of `rerun_on_hot_patch = <bool>`, if given.
    rerun_on_hot_patch: Option<bool>,
    /// Value of `hot_patch_signature = <bool>`, if given.
    hot_patch_signature: Option<bool>,
}
13
14impl Parse for HotArgs {
15    fn parse(input: ParseStream) -> std::result::Result<HotArgs, syn::Error> {
16        let mut rerun_on_hot_patch = None;
17        let mut hot_patch_signature = None;
18
19        while !input.is_empty() {
20            let ident: Ident = input.parse()?;
21            input.parse::<Token![=]>()?;
22
23            if ident == "rerun_on_hot_patch" {
24                let value: LitBool = input.parse()?;
25                rerun_on_hot_patch = Some(value.value);
26            } else if ident == "hot_patch_signature" {
27                let value: LitBool = input.parse()?;
28                hot_patch_signature = Some(value.value);
29            } else {
30                return Err(syn::Error::new_spanned(ident, "Unknown attribute key"));
31            }
32
33            if input.peek(Token![,]) {
34                input.parse::<Token![,]>()?;
35            }
36        }
37
38        Ok(HotArgs {
39            rerun_on_hot_patch,
40            hot_patch_signature,
41        })
42    }
43}
44
45#[proc_macro_attribute]
46pub fn hot(attr: TokenStream, item: TokenStream) -> TokenStream {
47    // Parse the attribute as a Meta
48    let args = syn::parse::<HotArgs>(attr.clone());
49    let args = match args {
50        Ok(parsed) => parsed,
51        Err(_) => return item, // If parsing the attributes fails, just return the original function.
52    };
53    let rerun_on_hot_patch = args.rerun_on_hot_patch.unwrap_or(false);
54    let hot_patch_signature = args.hot_patch_signature.unwrap_or(false);
55
56    let input_fn = syn::parse::<ItemFn>(item.clone());
57    let input_fn = match input_fn {
58        Ok(parsed) => parsed,
59        Err(_) => return item, // If parsing the function fails, return it unchanged.
60    };
61
62    let vis = &input_fn.vis;
63    let sig = &input_fn.sig;
64    let original_output = &sig.output;
65    let original_fn_name = &sig.ident;
66    let block = &input_fn.block;
67    let inputs = &sig.inputs;
68    let generics = &sig.generics;
69
70    // Generate new identifiers
71    let hotpatched_fn = format_ident!("__{}_hotpatched", original_fn_name);
72    let original_wrapper_fn = format_ident!("__{}_original", original_fn_name);
73
74    let newlines = if let Some(source_text) = block.span().unwrap().source_text() {
75        source_text.chars().filter(|ch| *ch == '\n').count() as u32
76    } else {
77        0
78    };
79
80    // Capture parameter types, names, and mutability
81    let mut param_types = Vec::new();
82    let mut param_idents = Vec::new();
83    let mut param_mutability = Vec::new();
84
85    for input in inputs {
86        match input {
87            FnArg::Typed(pat_type) => {
88                param_types.push(&pat_type.ty);
89                if let Pat::Ident(PatIdent {
90                    ident, mutability, ..
91                }) = &*pat_type.pat
92                {
93                    param_idents.push(ident);
94                    param_mutability.push(mutability.is_some());
95                } else {
96                    panic!("`#[hot]` only supports simple identifiers in parameter patterns.");
97                }
98            }
99            FnArg::Receiver(_) => {
100                panic!("`#[hot]` does not support `self` methods.");
101            }
102        }
103    }
104
105    // Generate correct destructuring pattern for parameters
106    let destructure = param_idents
107        .iter()
108        .zip(param_mutability.iter())
109        .map(|(ident, is_mut)| {
110            if *is_mut {
111                quote! { mut #ident }
112            } else {
113                quote! { #ident }
114            }
115        });
116    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
117    let maybe_generics = if generics.params.is_empty() {
118        quote! {}
119    } else {
120        quote! { ::#ty_generics }
121    };
122
123    let hot_fn = quote! {
124        ::bevy_simple_subsecond_system::dioxus_devtools::subsecond::HotFn::current(#hotpatched_fn #maybe_generics)
125    };
126
127    if !hot_patch_signature && !rerun_on_hot_patch {
128        let result = quote! {
129            #[cfg(any(target_family = "wasm", not(debug_assertions)))]
130            #vis fn #original_fn_name #impl_generics(#inputs) #where_clause #original_output {
131                #block
132            }
133
134
135            #[cfg(all(not(target_family = "wasm"), debug_assertions))]
136            #[allow(unused_mut)]
137            #vis fn #original_fn_name #impl_generics(#inputs) #where_clause #original_output {
138                #hot_fn.call((#(#param_idents,)*))
139            }
140
141
142            #[cfg(all(not(target_family = "wasm"), debug_assertions))]
143            #vis fn #hotpatched_fn #impl_generics(#inputs) #where_clause #original_output {
144                #block
145            }
146        };
147        return result.into();
148    }
149
150    let maybe_run_call = if rerun_on_hot_patch {
151        quote! {
152            let name = ::bevy_simple_subsecond_system::__macros_internal::IntoSystem::into_system(#original_fn_name #maybe_generics).name();
153            ::bevy_simple_subsecond_system::__macros_internal::debug!("Hot-patched and rerunning system {name}");
154            #hot_fn.call((world,))
155        }
156    } else {
157        quote! {
158            let name = ::bevy_simple_subsecond_system::__macros_internal::IntoSystem::into_system(#original_fn_name #maybe_generics).name();
159            bevy::prelude::debug!("Hot-patched system {name}");
160        }
161    };
162
163    let early_return = if is_result_unit(original_output) {
164        quote! {
165            return Ok(());
166        }
167    } else {
168        quote! {
169            return;
170        }
171    };
172
173    let hotpatched_fn_definition = match has_single_world_param(sig) {
174        WorldParam::Mut | WorldParam::Ref => quote! {
175            #vis fn #hotpatched_fn #impl_generics(world: &mut ::bevy_simple_subsecond_system::__macros_internal::World) #where_clause #original_output {
176                if let Some(mut reload_positions) = world.get_resource_mut::<::bevy_simple_subsecond_system::__macros_internal::__ReloadPositions>() {
177                    reload_positions.insert((file!(), line!(), line!() + #newlines));
178                }
179                #original_wrapper_fn #maybe_generics(world)
180            }
181        },
182        WorldParam::None => quote! {
183            #vis fn #hotpatched_fn #impl_generics(world: &mut ::bevy_simple_subsecond_system::__macros_internal::World) #where_clause #original_output {
184                if let Some(mut reload_positions) = world.get_resource_mut::<::bevy_simple_subsecond_system::__macros_internal::__ReloadPositions>() {
185                    reload_positions.insert((file!(), line!(), line!() + #newlines));
186                }
187                use ::bevy_simple_subsecond_system::__macros_internal::SystemState;
188                let mut __system_state: SystemState<(#(#param_types),*)> = SystemState::new(world);
189                let __unsafe_world = world.as_unsafe_world_cell_readonly();
190
191                let __validation = unsafe { SystemState::validate_param(&__system_state, __unsafe_world) };
192
193                match __validation {
194                    Ok(()) => (),
195                    Err(e) => {
196                        if e.skipped {
197                            #early_return
198                        }
199                    }
200                }
201
202                let (#(#destructure),*) = __system_state.get_mut(world);
203                let __result = #original_wrapper_fn(#(#param_idents),*);
204                __system_state.apply(world);
205                #[allow(clippy::unused_unit)]
206                __result
207            }
208        },
209    };
210
211    let result = quote! {
212        #[cfg(any(target_family = "wasm", not(debug_assertions)))]
213        #vis fn #original_fn_name #impl_generics(#inputs) #where_clause #original_output {
214            #block
215        }
216        // Outer entry point: stable ABI, hot-reload safe
217        #[cfg(all(not(target_family = "wasm"), debug_assertions))]
218        #vis fn #original_fn_name #impl_generics(world: &mut ::bevy_simple_subsecond_system::__macros_internal::World) #where_clause #original_output {
219            use std::any::Any as _;
220            let type_id = #hotpatched_fn #maybe_generics.type_id();
221            let contains_system = world.get_resource::<::bevy_simple_subsecond_system::__macros_internal::__HotPatchedSystems>().unwrap().0.contains_key(&type_id);
222            if !contains_system {
223                let hot_fn_ptr = #hot_fn.ptr_address();
224                world.resource_mut::<::bevy_simple_subsecond_system::__macros_internal::Schedules>().add_systems(::bevy_simple_subsecond_system::__macros_internal::PreUpdate, move |world: &mut ::bevy_simple_subsecond_system::__macros_internal::World| {
225                    let needs_update = {
226                        let mut hot_patched_systems = world.get_resource_mut::<::bevy_simple_subsecond_system::__macros_internal::__HotPatchedSystems>().unwrap();
227                        let mut hot_patched_system = hot_patched_systems.0.get_mut(&type_id).unwrap();
228                        hot_patched_system.current_ptr = #hot_fn.ptr_address();
229                        let needs_update = hot_patched_system.current_ptr != hot_patched_system.last_ptr;
230                        hot_patched_system.last_ptr = hot_patched_system.current_ptr;
231                        needs_update
232                    };
233                    if !needs_update {
234                        return;
235                    }
236                    // TODO: we simply ignore the `Result` here, but we should be propagating it
237                    let _ = {#maybe_run_call};
238                });
239                let system = ::bevy_simple_subsecond_system::__macros_internal::__HotPatchedSystem {
240                    current_ptr: hot_fn_ptr,
241                    last_ptr: hot_fn_ptr,
242                };
243                world.get_resource_mut::<::bevy_simple_subsecond_system::__macros_internal::__HotPatchedSystems>().unwrap().0.insert(type_id, system);
244            }
245
246            #hot_fn.call((world,))
247        }
248
249        // Hotpatched version with stable signature
250        #[cfg(all(not(target_family = "wasm"), debug_assertions))]
251        #hotpatched_fn_definition
252
253        // Original function body moved into a standalone fn
254        #[cfg(all(not(target_family = "wasm"), debug_assertions))]
255        #vis fn #original_wrapper_fn #impl_generics(#inputs) #where_clause #original_output {
256            #block
257        }
258    };
259
260    result.into()
261}
262
/// Shape of a system signature with respect to the ECS `World`, as classified by
/// `has_single_world_param`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WorldParam {
    /// The only parameter is `&World` (matched by the last path segment being `World`).
    Ref,
    /// The only parameter is `&mut World`.
    Mut,
    /// Any other signature: zero or several parameters, a `self` receiver, or a single
    /// parameter that is not a reference to `World`.
    None,
}
268
269fn has_single_world_param(sig: &syn::Signature) -> WorldParam {
270    if sig.inputs.len() != 1 {
271        return WorldParam::None;
272    }
273
274    let param = sig.inputs.first().unwrap();
275
276    let pat_type = match param {
277        FnArg::Typed(pt) => pt,
278        _ => return WorldParam::None,
279    };
280
281    match &*pat_type.ty {
282        Type::Reference(TypeReference {
283            mutability, elem, ..
284        }) => {
285            match &**elem {
286                Type::Path(type_path) => {
287                    let segments = &type_path.path.segments;
288
289                    let Some(last_segment) = segments.last().cloned() else {
290                        return WorldParam::None;
291                    };
292
293                    // TODO: Make this more robust :D
294                    if last_segment.ident != "World" {
295                        return WorldParam::None;
296                    }
297
298                    if mutability.is_some() {
299                        WorldParam::Mut
300                    } else {
301                        WorldParam::Ref
302                    }
303                }
304                _ => WorldParam::None,
305            }
306        }
307        _ => WorldParam::None,
308    }
309}
310
311fn is_result_unit(output: &ReturnType) -> bool {
312    match output {
313        ReturnType::Default => false, // no return type, i.e., returns ()
314        ReturnType::Type(_, ty) => match &**ty {
315            Type::Path(TypePath { path, .. }) => {
316                // Match on the outer type
317                let Some(seg) = path.segments.last() else {
318                    return false;
319                };
320                if seg.ident != "Result" {
321                    return false;
322                }
323
324                // Match on the generic args: Result<(), BevyError>
325                match seg.arguments {
326                    syn::PathArguments::AngleBracketed(ref generics) => {
327                        let args = &generics.args;
328
329                        let Some(first) = args.first() else {
330                            // Not sure this case can even happen
331                            return true;
332                        };
333
334                        // Check first generic arg is ()
335                        matches!(
336                            first,
337                            syn::GenericArgument::Type(Type::Tuple(t)) if t.elems.is_empty()
338                        )
339                    }
340                    syn::PathArguments::Parenthesized(_) => false,
341                    // TODO: This could also be a result that has a non-unit Ok variant
342                    syn::PathArguments::None => true,
343                }
344            }
345            _ => false,
346        },
347    }
348}