pyri_state_derive/
lib.rs

1//! Derive macros for `pyri_state`.
2
3#[cfg(feature = "bevy_app")]
4mod app;
5mod util;
6
7use bevy_macro_utils::BevyManifest;
8use proc_macro::TokenStream;
9use quote::quote;
10use syn::{
11    DeriveInput, Error, Meta, Path, Result, Token, Type, parse_macro_input, parse_str,
12    punctuated::Punctuated,
13};
14
15use crate::util::concat;
16
17#[proc_macro_derive(State, attributes(state))]
18pub fn derive_state(input: TokenStream) -> TokenStream {
19    // Parse the type and `#[state(...)]` attributes.
20    let input = parse_macro_input!(input as DeriveInput);
21    let attrs = parse_state_attrs(&input).expect("Failed to parse state attributes");
22
23    // Construct `State` impl.
24    let impl_state = derive_state_helper(&input, &attrs);
25
26    // Construct `RegisterState` impl.
27    #[cfg(not(feature = "bevy_app"))]
28    let impl_register_state = quote! {};
29    #[cfg(feature = "bevy_app")]
30    let impl_register_state = app::derive_register_state_helper(&input, &attrs);
31
32    // Construct `Resource` impl.
33    let impl_resource = derive_resource_helper(&input);
34
35    quote! {
36        #impl_state
37        #impl_register_state
38        #impl_resource
39    }
40    .into()
41}
42
43fn derive_resource_helper(input: &DeriveInput) -> proc_macro2::TokenStream {
44    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
45    let ty_name = &input.ident;
46
47    // Construct paths.
48    let bevy_ecs_path = BevyManifest::shared().get_path("bevy_ecs");
49    let bevy_ecs_resource_path = concat(&bevy_ecs_path, "resource");
50    let resource_trait = concat(&bevy_ecs_resource_path, "Resource");
51
52    quote! {
53        impl #impl_generics #resource_trait for #ty_name #ty_generics #where_clause {}
54    }
55    .into()
56}
57
58fn derive_state_helper(input: &DeriveInput, attrs: &StateAttrs) -> proc_macro2::TokenStream {
59    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
60    let ty_name = &input.ident;
61
62    // Construct paths.
63    // TODO: This is not 100% portable I guess, but probably good enough.
64    let crate_path = parse_str::<Path>("pyri_state").unwrap();
65    let crate_state_path = concat(&crate_path, "state");
66    let state_trait = concat(&crate_state_path, "State");
67
68    // Construct `NextState` type.
69    let next_ty = if let Some(next) = attrs.next.as_ref() {
70        quote! {
71            #next
72        }
73    } else {
74        let crate_next_state_path = concat(&crate_path, "next_state");
75        let crate_buffer_path = concat(&crate_next_state_path, "buffer");
76        let state_buffer_ty = concat(&crate_buffer_path, "NextStateBuffer");
77
78        quote! {
79            #state_buffer_ty<Self>
80        }
81    };
82
83    // Construct `State` impl.
84    quote! {
85        impl #impl_generics #state_trait for #ty_name #ty_generics #where_clause {
86            type Next = #next_ty;
87        }
88    }
89    .into()
90}
91
92#[derive(Default)]
93struct StateAttrs {
94    next: Option<Type>,
95    local: bool,
96    after: Punctuated<Type, Token![,]>,
97    before: Punctuated<Type, Token![,]>,
98    no_defaults: bool,
99    detect_change: bool,
100    flush_message: bool,
101    log_flush: bool,
102    bevy_state: bool,
103    react: bool,
104    apply_flush: bool,
105}
106
107// Parse `#[state(...)]` attributes.
108fn parse_state_attrs(input: &DeriveInput) -> Result<StateAttrs> {
109    let mut state_attrs = StateAttrs::default();
110
111    for attr in &input.attrs {
112        if !attr.path().is_ident("state") {
113            continue;
114        }
115
116        let nested = attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
117        for meta in nested {
118            match meta {
119                Meta::List(meta) if meta.path.is_ident("after") => {
120                    state_attrs.after = meta
121                        .parse_args_with(Punctuated::<Type, Token![,]>::parse_terminated)
122                        .expect("invalid `after` states");
123                }
124
125                Meta::List(meta) if meta.path.is_ident("before") => {
126                    state_attrs.before = meta
127                        .parse_args_with(Punctuated::<Type, Token![,]>::parse_terminated)
128                        .expect("invalid `before` states");
129                }
130
131                Meta::List(meta) if meta.path.is_ident("next") => {
132                    state_attrs.next = Some(meta.parse_args().expect("invalid `next` type"));
133                }
134
135                Meta::Path(path) => {
136                    let Some(ident) = path.get_ident() else {
137                        return Err(Error::new_spanned(path, "invalid state attribute"));
138                    };
139
140                    match ident.to_string().as_str() {
141                        "no_defaults" => state_attrs.no_defaults = true,
142                        "local" => state_attrs.local = true,
143                        "detect_change" => state_attrs.detect_change = true,
144                        "flush_message" => state_attrs.flush_message = true,
145                        "log_flush" => state_attrs.log_flush = true,
146                        "bevy_state" => state_attrs.bevy_state = true,
147                        "react" => state_attrs.react = true,
148                        "apply_flush" => state_attrs.apply_flush = true,
149                        _ => return Err(Error::new_spanned(ident, "invalid state attribute")),
150                    }
151                }
152
153                _ => return Err(Error::new_spanned(meta, "invalid state attribute")),
154            }
155        }
156    }
157
158    // Enable default options.
159    if !state_attrs.no_defaults {
160        state_attrs.detect_change = true;
161        state_attrs.flush_message = true;
162        state_attrs.apply_flush = true;
163    }
164
165    Ok(state_attrs)
166}