anchor_yard_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    FnArg, GenericArgument, Ident, ItemFn, LitInt, PatType, PathArguments, Token, Type, TypePath,
5    parse::Parse, parse_macro_input,
6};
7
8struct SnapshotArgs {
9    threshold_ms: u64,
10}
11
12impl Parse for SnapshotArgs {
13    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
14        let mut threshold_ms = 1u64;
15
16        if !input.is_empty() {
17            let ident: Ident = input.parse()?;
18            if ident == "threshold_ms" {
19                input.parse::<Token![=]>()?;
20                let lit: LitInt = input.parse()?;
21                threshold_ms = lit.base10_parse()?;
22            }
23        }
24
25        Ok(SnapshotArgs { threshold_ms })
26    }
27}
28
29fn extract_component_types(
30    inputs: &syn::punctuated::Punctuated<FnArg, Token![,]>,
31) -> Vec<TypePath> {
32    let mut component_types = Vec::new();
33
34    for input in inputs {
35        if let FnArg::Typed(PatType { ty, .. }) = input {
36            if let Type::Path(type_path) = &**ty {
37                if let Some(last_segment) = type_path.path.segments.last() {
38                    if last_segment.ident == "View" || last_segment.ident == "ViewMut" {
39                        if let PathArguments::AngleBracketed(args) = &last_segment.arguments {
40                            if let Some(GenericArgument::Type(Type::Path(component_type))) =
41                                args.args.first()
42                            {
43                                component_types.push(component_type.clone());
44                            }
45                        }
46                    }
47                }
48            }
49        }
50    }
51
52    component_types
53}
54
55#[proc_macro_attribute]
56pub fn snapshot_system(args: TokenStream, input: TokenStream) -> TokenStream {
57    let args = parse_macro_input!(args as SnapshotArgs);
58    let input_fn = parse_macro_input!(input as ItemFn);
59
60    let threshold_ms = args.threshold_ms;
61    let fn_name = &input_fn.sig.ident;
62    let fn_inputs = &input_fn.sig.inputs;
63    let fn_output = &input_fn.sig.output;
64    let fn_block = &input_fn.block;
65    let vis = &input_fn.vis;
66    let component_types = extract_component_types(fn_inputs);
67
68    let register_components = component_types.iter().map(|ty| {
69        quote! {
70            anchor_yard::REGISTRY.lock().unwrap().register::<#ty>();
71        }
72    });
73
74    let expanded = quote! {
75        #vis fn #fn_name(#fn_inputs) #fn_output {
76            {
77                #(#register_components)*
78            }
79
80            use std::time::Instant;
81
82            let snapshot = anchor_yard_core::with_current_world(|world| {
83                anchor_yard_core::SystemSnapshot::capture_world(
84                world,
85                stringify!(#fn_name),
86                #threshold_ms,
87                )
88            }).flatten();
89
90            let start = Instant::now();
91            let result = (|| #fn_block)();
92            let elapsed = start.elapsed();
93
94            if elapsed.as_millis() as u64 > #threshold_ms && let Some(snapshot) = snapshot {
95                #[cfg(feature = "tracing")]
96                tracing::info!("System '{}' took {}ms (threshold: {}ms). Saving snapshot...", stringify!(#fn_name), elapsed.as_millis(), #threshold_ms);
97                let _ = snapshot.save_to_file();
98                #[cfg(feature = "tracing")]
99                tracing::info!("Snapshot saved!");
100            }
101
102            result
103        }
104    };
105
106    TokenStream::from(expanded)
107}