enum2contract_derive/
lib.rs

1#![no_std]
2
3extern crate alloc;
4
5use alloc::{
6    format,
7    string::{String, ToString},
8    vec::Vec,
9};
10use core::iter;
11use proc_macro::TokenStream;
12use proc_macro2::{Ident, Span};
13use quote::quote;
14use syn::{
15    parse_macro_input, spanned::Spanned, Data, DeriveInput, Fields, FieldsNamed, LitStr, Variant,
16};
17
18#[proc_macro_derive(EnumContract, attributes(topic))]
19pub fn derive_enum2contract(input: TokenStream) -> TokenStream {
20    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
21
22    let name = &input.ident;
23
24    let data = match input.data {
25        Data::Enum(data) => data,
26        _ => {
27            return syn::Error::new(input.span(), "enum2contract only supports enums")
28                .to_compile_error()
29                .into()
30        }
31    };
32
33    let mut message_functions = proc_macro2::TokenStream::new();
34    let mut payloads = proc_macro2::TokenStream::new();
35
36    for variant in data.variants.iter() {
37        match variant.fields {
38            Fields::Unit => {
39                let topic = match parse_topic_attribute(variant) {
40                    Ok(value) => value,
41                    Err(error) => return error.to_compile_error().into(),
42                };
43
44                let payload_name =
45                    Ident::new(&format!("{}Payload", variant.ident), variant.ident.span());
46                let payload_struct = quote!(
47                    #[derive(Default, Debug, PartialEq, Serialize, Deserialize)]
48                    pub struct #payload_name;
49                );
50                payloads.extend(payload_struct);
51
52                #[cfg(feature = "json")]
53                {
54                    let json_conversions = quote!(
55                        impl #payload_name {
56                            pub fn to_json(&self) -> Result<String, serde_json::Error> {
57                                serde_json::to_string(self)
58                            }
59
60                            pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
61                                serde_json::from_str(json)
62                            }
63                        }
64                    );
65                    payloads.extend(json_conversions);
66                }
67
68                #[cfg(feature = "binary")]
69                {
70                    let binary_conversions = quote!(
71                        impl #payload_name {
72                            pub fn to_bytes(&self) -> Result<Vec<u8>, postcard::Error> {
73                                postcard::to_allocvec(self)
74                            }
75
76                            pub fn from_binary(bytes: &[u8]) -> Result<Self, postcard::Error> {
77                                postcard::from_bytes(bytes)
78                            }
79                        }
80                    );
81
82                    payloads.extend(binary_conversions);
83                }
84
85                let payload_type = quote! { #payload_name };
86                let payload_default = quote! { #payload_name::default() };
87                let ident_name = &to_snake_case(&variant.ident.to_string());
88                let create_message = Ident::new(ident_name, variant.ident.span());
89                let create_topic =
90                    Ident::new(&format!("{}_topic", ident_name), variant.ident.span());
91                let topic_string = &topic.value();
92                let args = extract_substrings(topic_string);
93                let topic_string = remove_substrings(&topic.value(), &args);
94                let args: Vec<_> = args
95                    .iter()
96                    .map(|arg| Ident::new(arg, Span::call_site()))
97                    .collect();
98
99                let message_function = quote! {
100                    pub fn #create_message(#(#args: &str),*) -> (String, #payload_type) {
101                        (Self::#create_topic(#(#args),*), #payload_default)
102                    }
103
104                    pub fn #create_topic(#(#args: &str),*) -> String {
105                        format!(#topic_string, #(#args),*)
106                    }
107                };
108                message_functions.extend(message_function);
109            }
110
111            Fields::Named(FieldsNamed { ref named, .. }) => {
112                let topic = match parse_topic_attribute(variant) {
113                    Ok(value) => value,
114                    Err(error) => return error.to_compile_error().into(),
115                };
116
117                let mut fields = proc_macro2::TokenStream::new();
118
119                for field in named.iter() {
120                    fields.extend(quote! {
121                        pub #field,
122                    })
123                }
124
125                let payload_name =
126                    Ident::new(&format!("{}Payload", variant.ident), variant.ident.span());
127
128                let payload_struct = quote! {
129                    use serde::{Serialize, Deserialize};
130
131                    #[derive(Default, Debug, PartialEq, Serialize, Deserialize)]
132                    pub struct #payload_name {
133                        #fields
134                    }
135                };
136                payloads.extend(payload_struct);
137
138                #[cfg(feature = "json")]
139                {
140                    let json_conversions = quote!(
141                        impl #payload_name {
142                            pub fn to_json(&self) -> Result<String, serde_json::Error> {
143                                serde_json::to_string(self)
144                            }
145
146                            pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
147                                serde_json::from_str(json)
148                            }
149                        }
150                    );
151                    payloads.extend(json_conversions);
152                }
153
154                #[cfg(feature = "binary")]
155                {
156                    let binary_conversions = quote!(
157                        impl #payload_name {
158                            pub fn to_bytes(&self) -> Result<Vec<u8>, postcard::Error> {
159                                postcard::to_allocvec(self)
160                            }
161
162                            pub fn from_binary(bytes: &[u8]) -> Result<Self, postcard::Error> {
163                                postcard::from_bytes(bytes)
164                            }
165                        }
166                    );
167
168                    payloads.extend(binary_conversions);
169                }
170
171                let payload_type = quote! { #payload_name };
172                let payload_default = quote! { #payload_name::default() };
173                let ident_name = &to_snake_case(&variant.ident.to_string());
174                let create_message = Ident::new(ident_name, variant.ident.span());
175                let create_topic =
176                    Ident::new(&format!("{}_topic", ident_name), variant.ident.span());
177                let topic_string = &topic.value();
178                let args = extract_substrings(topic_string);
179                let topic_string = remove_substrings(&topic.value(), &args);
180                let args: Vec<_> = args
181                    .iter()
182                    .map(|arg| Ident::new(arg, Span::call_site()))
183                    .collect();
184
185                let message_function = quote! {
186                    pub fn #create_message(#(#args: &str),*) -> (String, #payload_type) {
187                        (Self::#create_topic(#(#args),*), #payload_default)
188                    }
189
190                    pub fn #create_topic(#(#args: &str),*) -> String {
191                        format!(#topic_string, #(#args),*)
192                    }
193                };
194                message_functions.extend(message_function);
195            }
196
197            _ => {
198                return syn::Error::new(
199                    variant.span(),
200                    "enum2contract is only implemented for unit and named-field enums",
201                )
202                .to_compile_error()
203                .into()
204            }
205        };
206    }
207
208    let expanded = quote! {
209        #payloads
210
211        impl #name {
212            #message_functions
213        }
214    };
215
216    TokenStream::from(expanded)
217}
218
219fn parse_topic_attribute(variant: &Variant) -> Result<LitStr, syn::Error> {
220    let mut topic = None;
221    for attr in &variant.attrs {
222        if attr.path.is_ident("topic") {
223            match attr.parse_args::<LitStr>() {
224                Ok(literal) => topic = Some(literal),
225                Err(_) => {
226                    return Err(syn::Error::new(
227                        attr.path.span(),
228                        r#"The 'topic' attribute is missing a String argument. Example: #[topic("system/{id}/start")] "#,
229                    ));
230                }
231            }
232        }
233    }
234    topic.ok_or_else(|| {
235        syn::Error::new(
236            variant.span(),
237            r#"The 'topic' attribute is required. Example: #[topic("system/{id}/start")]"#,
238        )
239    })
240}
241
/// Returns the names of the `{name}` placeholders in a topic string, in order of
/// appearance (repeats included).
///
/// A placeholder is the text between an unescaped `{` and the first `}` found
/// before the next `{`. A `{` with no such `}` is ignored. The topic is later used
/// as a `format!` string, so `{{` is treated as an escaped literal brace and never
/// opens a placeholder.
fn extract_substrings(s: &str) -> Vec<&str> {
    let mut names = Vec::new();
    let mut rest = s;
    while let Some(open) = rest.find('{') {
        let after = &rest[open + 1..];
        if let Some(escaped_tail) = after.strip_prefix('{') {
            // `{{` is a literal brace in format strings, not a placeholder.
            rest = escaped_tail;
            continue;
        }
        // The closing `}` must appear before the next `{`.
        let segment_end = after.find('{').unwrap_or(after.len());
        if let Some(close) = after[..segment_end].find('}') {
            names.push(&after[..close]);
        }
        rest = &after[segment_end..];
    }
    names
}
249
/// Rewrites a topic into a positional `format!` string by replacing every
/// `{name}` whose `name` is listed in `substrings` with `{}`.
///
/// Escaped braces (`{{`) are copied through untouched, so a literal `{{name}}`
/// stays literal even when `name` is also a real placeholder. `{...}` sequences
/// whose content is not listed are left as-is. This is a single left-to-right pass
/// rather than one full `replace` (and allocation) per placeholder.
fn remove_substrings(s: &str, substrings: &[&str]) -> String {
    let mut result = String::with_capacity(s.len());
    let mut rest = s;
    while let Some(open) = rest.find('{') {
        result.push_str(&rest[..open]);
        let after = &rest[open + 1..];
        if let Some(tail) = after.strip_prefix('{') {
            result.push_str("{{");
            rest = tail;
            continue;
        }
        match after.split_once('}') {
            Some((name, tail)) if substrings.contains(&name) => {
                result.push_str("{}");
                rest = tail;
            }
            _ => {
                // Not a listed placeholder: keep the brace and keep scanning after it.
                result.push('{');
                rest = after;
            }
        }
    }
    result.push_str(rest);
    result
}
257
/// Converts a variant name such as `StartMotor` into `start_motor`.
///
/// Every uppercase character is lowercased. An `_` is inserted before each
/// uppercase character except the first, but only when the input itself starts
/// with an uppercase character. So `startMotor` becomes `startmotor`, and every
/// letter of an acronym gets its own separator (`HTTPRequest` -> `h_t_t_p_request`).
/// All other characters, including digits and underscores, are copied unchanged.
/// The result names generated public functions, so this mapping must stay stable.
fn to_snake_case(input: &str) -> String {
    let starts_uppercase = input.chars().next().map_or(false, char::is_uppercase);
    let mut snake = String::with_capacity(input.len());
    for (position, ch) in input.chars().enumerate() {
        if !ch.is_uppercase() {
            snake.push(ch);
        } else if starts_uppercase && position != 0 {
            snake.extend(iter::once('_').chain(ch.to_lowercase()));
        } else {
            snake.extend(ch.to_lowercase());
        }
    }
    snake
}