1#[cfg(feature = "bevy_app")]
4mod app;
5mod util;
6
7use bevy_macro_utils::BevyManifest;
8use proc_macro::TokenStream;
9use quote::quote;
10use syn::{
11 DeriveInput, Error, Meta, Path, Result, Token, Type, parse_macro_input, parse_str,
12 punctuated::Punctuated,
13};
14
15use crate::util::concat;
16
17#[proc_macro_derive(State, attributes(state))]
18pub fn derive_state(input: TokenStream) -> TokenStream {
19 let input = parse_macro_input!(input as DeriveInput);
21 let attrs = parse_state_attrs(&input).expect("Failed to parse state attributes");
22
23 let impl_state = derive_state_helper(&input, &attrs);
25
26 #[cfg(not(feature = "bevy_app"))]
28 let impl_register_state = quote! {};
29 #[cfg(feature = "bevy_app")]
30 let impl_register_state = app::derive_register_state_helper(&input, &attrs);
31
32 let impl_resource = derive_resource_helper(&input);
34
35 quote! {
36 #impl_state
37 #impl_register_state
38 #impl_resource
39 }
40 .into()
41}
42
43fn derive_resource_helper(input: &DeriveInput) -> proc_macro2::TokenStream {
44 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
45 let ty_name = &input.ident;
46
47 let bevy_ecs_path = BevyManifest::shared().get_path("bevy_ecs");
49 let bevy_ecs_resource_path = concat(&bevy_ecs_path, "resource");
50 let resource_trait = concat(&bevy_ecs_resource_path, "Resource");
51
52 quote! {
53 impl #impl_generics #resource_trait for #ty_name #ty_generics #where_clause {}
54 }
55 .into()
56}
57
58fn derive_state_helper(input: &DeriveInput, attrs: &StateAttrs) -> proc_macro2::TokenStream {
59 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
60 let ty_name = &input.ident;
61
62 let crate_path = parse_str::<Path>("pyri_state").unwrap();
65 let crate_state_path = concat(&crate_path, "state");
66 let state_trait = concat(&crate_state_path, "State");
67
68 let next_ty = if let Some(next) = attrs.next.as_ref() {
70 quote! {
71 #next
72 }
73 } else {
74 let crate_next_state_path = concat(&crate_path, "next_state");
75 let crate_buffer_path = concat(&crate_next_state_path, "buffer");
76 let state_buffer_ty = concat(&crate_buffer_path, "NextStateBuffer");
77
78 quote! {
79 #state_buffer_ty<Self>
80 }
81 };
82
83 quote! {
85 impl #impl_generics #state_trait for #ty_name #ty_generics #where_clause {
86 type Next = #next_ty;
87 }
88 }
89 .into()
90}
91
92#[derive(Default)]
93struct StateAttrs {
94 next: Option<Type>,
95 local: bool,
96 after: Punctuated<Type, Token![,]>,
97 before: Punctuated<Type, Token![,]>,
98 no_defaults: bool,
99 detect_change: bool,
100 flush_message: bool,
101 log_flush: bool,
102 bevy_state: bool,
103 react: bool,
104 apply_flush: bool,
105}
106
107fn parse_state_attrs(input: &DeriveInput) -> Result<StateAttrs> {
109 let mut state_attrs = StateAttrs::default();
110
111 for attr in &input.attrs {
112 if !attr.path().is_ident("state") {
113 continue;
114 }
115
116 let nested = attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
117 for meta in nested {
118 match meta {
119 Meta::List(meta) if meta.path.is_ident("after") => {
120 state_attrs.after = meta
121 .parse_args_with(Punctuated::<Type, Token![,]>::parse_terminated)
122 .expect("invalid `after` states");
123 }
124
125 Meta::List(meta) if meta.path.is_ident("before") => {
126 state_attrs.before = meta
127 .parse_args_with(Punctuated::<Type, Token![,]>::parse_terminated)
128 .expect("invalid `before` states");
129 }
130
131 Meta::List(meta) if meta.path.is_ident("next") => {
132 state_attrs.next = Some(meta.parse_args().expect("invalid `next` type"));
133 }
134
135 Meta::Path(path) => {
136 let Some(ident) = path.get_ident() else {
137 return Err(Error::new_spanned(path, "invalid state attribute"));
138 };
139
140 match ident.to_string().as_str() {
141 "no_defaults" => state_attrs.no_defaults = true,
142 "local" => state_attrs.local = true,
143 "detect_change" => state_attrs.detect_change = true,
144 "flush_message" => state_attrs.flush_message = true,
145 "log_flush" => state_attrs.log_flush = true,
146 "bevy_state" => state_attrs.bevy_state = true,
147 "react" => state_attrs.react = true,
148 "apply_flush" => state_attrs.apply_flush = true,
149 _ => return Err(Error::new_spanned(ident, "invalid state attribute")),
150 }
151 }
152
153 _ => return Err(Error::new_spanned(meta, "invalid state attribute")),
154 }
155 }
156 }
157
158 if !state_attrs.no_defaults {
160 state_attrs.detect_change = true;
161 state_attrs.flush_message = true;
162 state_attrs.apply_flush = true;
163 }
164
165 Ok(state_attrs)
166}