horfimbor_eventsource_derive/
lib.rs

1#![deny(missing_docs)]
2#![doc = include_str!("../README.md")]
3
4use proc_macro::{self, TokenStream};
5
6use convert_case::{Case, Casing};
7use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
8use quote::{quote, quote_spanned};
9use syn::spanned::Spanned;
10use syn::{Data, DeriveInput, Error, Fields, parse_macro_input};
11
// Turns a message into a ready-to-emit `compile_error!` token stream anchored at the derive's
// call site. The trailing `.into()` lets the same macro produce either a
// `proc_macro::TokenStream` or a `proc_macro2::TokenStream`, depending on where it is used.
macro_rules! derive_error {
    ($string: tt) => {{
        let error = Error::new(Span::call_site(), $string);
        error.to_compile_error().into()
    }};
}
19
20/// `derive_command` generate the boilerplate to get the `CommandName` from the command enum
21/// the attribute `state` give the prefix for the name
22#[proc_macro_derive(Command, attributes(state))]
23pub fn derive_command(input: TokenStream) -> TokenStream {
24    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
25
26    let state_name = match get_state_name(&input) {
27        Ok(value) => value,
28        Err(value) => return value,
29    };
30
31    // get enum name
32    let name = &input.ident;
33    let data = &input.data;
34    let mut fn_core;
35
36    match data {
37        Data::Enum(data_enum) => {
38            fn_core = TokenStream2::new();
39
40            // Iterate over enum variants
41            // `variants` if of type `Punctuated` which implements IntoIterator
42            for variant in &data_enum.variants {
43                // Variant's name
44                let variant_name = &variant.ident;
45
46                // Variant can have unnamed fields like `Variant(i32, i64)`
47                // Variant can have named fields like `Variant {x: i32, y: i32}`
48                // Variant can be named Unit like `Variant`
49                let fields_in_variant = match &variant.fields {
50                    Fields::Unnamed(_) => quote_spanned! {variant.span()=> (..) },
51                    Fields::Unit => quote_spanned! { variant.span()=> },
52                    Fields::Named(_) => quote_spanned! {variant.span()=> {..} },
53                };
54
55                // Here we construct the function for the current variant
56                let result = format!(".CMD.{variant_name}");
57                fn_core.extend(quote! {
58                    #name::#variant_name #fields_in_variant => {
59
60                        const SUFFIX: &str = #result;
61
62                        const LEN: usize = #state_name.len() + SUFFIX.len();
63                        const BYTES: [u8; LEN] = {
64                            let mut bytes = [0; LEN];
65
66                            let mut i = 0;
67                            while i < #state_name.len() {
68                                bytes[i] = #state_name.as_bytes()[i];
69                                i += 1;
70                            }
71
72                            let mut j = 0;
73                            while j < SUFFIX.len() {
74                                bytes[#state_name.len() + j] = SUFFIX.as_bytes()[j];
75                                j += 1;
76                            }
77
78                            bytes
79                        };
80
81                        match std::str::from_utf8(&BYTES) {
82                            Ok(s) => s,
83                            Err(_) => unreachable!(),
84                        }
85                    },
86                });
87            }
88        }
89        _ => return derive_error!("Command is only implemented for enums"),
90    }
91
92    let output = quote! {
93        impl Command for #name {
94            fn command_name(&self) -> CommandName {
95
96                match self {
97                    #fn_core
98                }
99            }
100        }
101    };
102    output.into()
103}
104
105/// `derive_event` generate the boilerplate to get the `EventName`
106///
107/// it generates it from the event enum :
108/// the attribute `state` give the prefix for the name
109/// unless the attribute `composite_state` in which case the current enum level is skip
110#[proc_macro_derive(Event, attributes(state, composite_state))]
111pub fn derive_event(input: TokenStream) -> TokenStream {
112    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
113
114    let mut state_name = None;
115
116    let is_composite_state = &input
117        .attrs
118        .iter()
119        .any(|attr| attr.path().is_ident("composite_state"));
120
121    if !is_composite_state {
122        state_name = match get_state_name(&input) {
123            Ok(value) => Some(value),
124            Err(value) => return value,
125        };
126    }
127
128    // get enum name
129    let name = &input.ident;
130    let data = &input.data;
131    let mut fn_core;
132
133    match data {
134        Data::Enum(data_enum) => {
135            fn_core = TokenStream2::new();
136
137            // Iterate over enum variants
138            // `variants` if of type `Punctuated` which implements IntoIterator
139            for variant in &data_enum.variants {
140                // Variant's name
141                let variant_name = &variant.ident;
142
143                if *is_composite_state {
144                    // Variant can have unnamed fields like `Variant(i32, i64)`
145                    // Variant can have named fields like `Variant {x: i32, y: i32}`
146                    // Variant can be named Unit like `Variant`
147                    let fields_in_variant = match &variant.fields {
148                        Fields::Unnamed(_) => quote_spanned! {variant.span()=> (event) },
149                        _ => {
150                            return derive_error!(
151                                "composite variants can only have one unnamed fields "
152                            );
153                        }
154                    };
155
156                    fn_core.extend(quote! {
157                        #name::#variant_name #fields_in_variant => {
158
159                            event.event_name()
160                        },
161                    });
162                } else {
163                    // Variant can have unnamed fields like `Variant(i32, i64)`
164                    // Variant can have named fields like `Variant {x: i32, y: i32}`
165                    // Variant can be named Unit like `Variant`
166                    let fields_in_variant = match &variant.fields {
167                        Fields::Unnamed(_) => quote_spanned! {variant.span()=> (..) },
168                        Fields::Unit => quote_spanned! { variant.span()=> },
169                        Fields::Named(_) => quote_spanned! {variant.span()=> {..} },
170                    };
171
172                    // Here we construct the function for the current variant
173                    let result = format!(".evt.{}", variant_name.to_string().to_case(Case::Snake));
174                    fn_core.extend(quote! {
175                        #name::#variant_name #fields_in_variant => {
176
177                            const SUFFIX: &str = #result;
178
179                            const LEN: usize = #state_name.len() + SUFFIX.len();
180                            const BYTES: [u8; LEN] = {
181                                let mut bytes = [0; LEN];
182
183                                let mut i = 0;
184                                while i < #state_name.len() {
185                                    bytes[i] = #state_name.as_bytes()[i];
186                                    i += 1;
187                                }
188
189                                let mut j = 0;
190                                while j < SUFFIX.len() {
191                                    bytes[#state_name.len() + j] = SUFFIX.as_bytes()[j];
192                                    j += 1;
193                                }
194
195                                bytes
196                            };
197
198                            match std::str::from_utf8(&BYTES) {
199                                Ok(s) => s,
200                                Err(_) => unreachable!(),
201                            }
202                        },
203                    });
204                }
205            }
206        }
207        _ => return derive_error!("Event is only implemented for enums"),
208    }
209
210    let output = quote! {
211        impl Event for #name {
212            fn event_name(&self) -> EventName {
213                match self {
214                    #fn_core
215                }
216            }
217        }
218    };
219    output.into()
220}
221
222/// # Panics
223///
224/// Will panic if attribute "state" is not parsable
225#[proc_macro_derive(StateNamed, attributes(state))]
226pub fn derive_state(input: TokenStream) -> TokenStream {
227    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
228
229    let attrs = &input.attrs;
230    let name = &input.ident;
231
232    let state = attrs.iter().find(|attr| attr.path().is_ident("state"));
233
234    let output = match state {
235        Some(s) => {
236            let state_name: syn::Ident = match s.parse_args() {
237                Ok(s) => s,
238                Err(_) => {
239                    return derive_error!("attribute 'state' cannot be parsed");
240                }
241            };
242            quote! {
243                impl StateNamed for #name {
244                    fn state_name() -> StateName {
245                        #state_name
246                    }
247                }
248            }
249        }
250        None => {
251            return derive_error!("attribute 'state' is mandatory");
252        }
253    };
254
255    output.into()
256}
257
258fn get_state_name(input: &DeriveInput) -> Result<Ident, TokenStream> {
259    let attrs = &input.attrs;
260
261    let state = attrs.iter().find(|attr| attr.path().is_ident("state"));
262
263    let Some(state) = state else {
264        return Err(derive_error!("attribute 'state' is mandatory"));
265    };
266
267    let state_name: syn::Ident = match state.parse_args() {
268        Ok(s) => s,
269        Err(_) => {
270            return Err(derive_error!("attribute 'state' cannot be parsed"));
271        }
272    };
273    Ok(state_name)
274}