enum2contract_derive/
lib.rs1#![no_std]
2
3extern crate alloc;
4
5use alloc::{
6 format,
7 string::{String, ToString},
8 vec::Vec,
9};
10use core::iter;
11use proc_macro::TokenStream;
12use proc_macro2::{Ident, Span};
13use quote::quote;
14use syn::{
15 parse_macro_input, spanned::Spanned, Data, DeriveInput, Fields, FieldsNamed, LitStr, Variant,
16};
17
18#[proc_macro_derive(EnumContract, attributes(topic))]
19pub fn derive_enum2contract(input: TokenStream) -> TokenStream {
20 let input: DeriveInput = parse_macro_input!(input as DeriveInput);
21
22 let name = &input.ident;
23
24 let data = match input.data {
25 Data::Enum(data) => data,
26 _ => {
27 return syn::Error::new(input.span(), "enum2contract only supports enums")
28 .to_compile_error()
29 .into()
30 }
31 };
32
33 let mut message_functions = proc_macro2::TokenStream::new();
34 let mut payloads = proc_macro2::TokenStream::new();
35
36 for variant in data.variants.iter() {
37 match variant.fields {
38 Fields::Unit => {
39 let topic = match parse_topic_attribute(variant) {
40 Ok(value) => value,
41 Err(error) => return error.to_compile_error().into(),
42 };
43
44 let payload_name =
45 Ident::new(&format!("{}Payload", variant.ident), variant.ident.span());
46 let payload_struct = quote!(
47 #[derive(Default, Debug, PartialEq, Serialize, Deserialize)]
48 pub struct #payload_name;
49 );
50 payloads.extend(payload_struct);
51
52 #[cfg(feature = "json")]
53 {
54 let json_conversions = quote!(
55 impl #payload_name {
56 pub fn to_json(&self) -> Result<String, serde_json::Error> {
57 serde_json::to_string(self)
58 }
59
60 pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
61 serde_json::from_str(json)
62 }
63 }
64 );
65 payloads.extend(json_conversions);
66 }
67
68 #[cfg(feature = "binary")]
69 {
70 let binary_conversions = quote!(
71 impl #payload_name {
72 pub fn to_bytes(&self) -> Result<Vec<u8>, postcard::Error> {
73 postcard::to_allocvec(self)
74 }
75
76 pub fn from_binary(bytes: &[u8]) -> Result<Self, postcard::Error> {
77 postcard::from_bytes(bytes)
78 }
79 }
80 );
81
82 payloads.extend(binary_conversions);
83 }
84
85 let payload_type = quote! { #payload_name };
86 let payload_default = quote! { #payload_name::default() };
87 let ident_name = &to_snake_case(&variant.ident.to_string());
88 let create_message = Ident::new(ident_name, variant.ident.span());
89 let create_topic =
90 Ident::new(&format!("{}_topic", ident_name), variant.ident.span());
91 let topic_string = &topic.value();
92 let args = extract_substrings(topic_string);
93 let topic_string = remove_substrings(&topic.value(), &args);
94 let args: Vec<_> = args
95 .iter()
96 .map(|arg| Ident::new(arg, Span::call_site()))
97 .collect();
98
99 let message_function = quote! {
100 pub fn #create_message(#(#args: &str),*) -> (String, #payload_type) {
101 (Self::#create_topic(#(#args),*), #payload_default)
102 }
103
104 pub fn #create_topic(#(#args: &str),*) -> String {
105 format!(#topic_string, #(#args),*)
106 }
107 };
108 message_functions.extend(message_function);
109 }
110
111 Fields::Named(FieldsNamed { ref named, .. }) => {
112 let topic = match parse_topic_attribute(variant) {
113 Ok(value) => value,
114 Err(error) => return error.to_compile_error().into(),
115 };
116
117 let mut fields = proc_macro2::TokenStream::new();
118
119 for field in named.iter() {
120 fields.extend(quote! {
121 pub #field,
122 })
123 }
124
125 let payload_name =
126 Ident::new(&format!("{}Payload", variant.ident), variant.ident.span());
127
128 let payload_struct = quote! {
129 use serde::{Serialize, Deserialize};
130
131 #[derive(Default, Debug, PartialEq, Serialize, Deserialize)]
132 pub struct #payload_name {
133 #fields
134 }
135 };
136 payloads.extend(payload_struct);
137
138 #[cfg(feature = "json")]
139 {
140 let json_conversions = quote!(
141 impl #payload_name {
142 pub fn to_json(&self) -> Result<String, serde_json::Error> {
143 serde_json::to_string(self)
144 }
145
146 pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
147 serde_json::from_str(json)
148 }
149 }
150 );
151 payloads.extend(json_conversions);
152 }
153
154 #[cfg(feature = "binary")]
155 {
156 let binary_conversions = quote!(
157 impl #payload_name {
158 pub fn to_bytes(&self) -> Result<Vec<u8>, postcard::Error> {
159 postcard::to_allocvec(self)
160 }
161
162 pub fn from_binary(bytes: &[u8]) -> Result<Self, postcard::Error> {
163 postcard::from_bytes(bytes)
164 }
165 }
166 );
167
168 payloads.extend(binary_conversions);
169 }
170
171 let payload_type = quote! { #payload_name };
172 let payload_default = quote! { #payload_name::default() };
173 let ident_name = &to_snake_case(&variant.ident.to_string());
174 let create_message = Ident::new(ident_name, variant.ident.span());
175 let create_topic =
176 Ident::new(&format!("{}_topic", ident_name), variant.ident.span());
177 let topic_string = &topic.value();
178 let args = extract_substrings(topic_string);
179 let topic_string = remove_substrings(&topic.value(), &args);
180 let args: Vec<_> = args
181 .iter()
182 .map(|arg| Ident::new(arg, Span::call_site()))
183 .collect();
184
185 let message_function = quote! {
186 pub fn #create_message(#(#args: &str),*) -> (String, #payload_type) {
187 (Self::#create_topic(#(#args),*), #payload_default)
188 }
189
190 pub fn #create_topic(#(#args: &str),*) -> String {
191 format!(#topic_string, #(#args),*)
192 }
193 };
194 message_functions.extend(message_function);
195 }
196
197 _ => {
198 return syn::Error::new(
199 variant.span(),
200 "enum2contract is only implemented for unit and named-field enums",
201 )
202 .to_compile_error()
203 .into()
204 }
205 };
206 }
207
208 let expanded = quote! {
209 #payloads
210
211 impl #name {
212 #message_functions
213 }
214 };
215
216 TokenStream::from(expanded)
217}
218
219fn parse_topic_attribute(variant: &Variant) -> Result<LitStr, syn::Error> {
220 let mut topic = None;
221 for attr in &variant.attrs {
222 if attr.path.is_ident("topic") {
223 match attr.parse_args::<LitStr>() {
224 Ok(literal) => topic = Some(literal),
225 Err(_) => {
226 return Err(syn::Error::new(
227 attr.path.span(),
228 r#"The 'topic' attribute is missing a String argument. Example: #[topic("system/{id}/start")] "#,
229 ));
230 }
231 }
232 }
233 }
234 topic.ok_or_else(|| {
235 syn::Error::new(
236 variant.span(),
237 r#"The 'topic' attribute is required. Example: #[topic("system/{id}/start")]"#,
238 )
239 })
240}
241
/// Collects the names between `{` and the next `}` in `s`, in order of appearance.
///
/// A `{` with no following `}` contributes nothing; `{}` yields an empty name.
fn extract_substrings(s: &str) -> Vec<&str> {
    let mut names = Vec::new();
    // Everything before the first `{` is plain text, so it is skipped.
    for segment in s.split('{').skip(1) {
        if let Some((name, _rest)) = segment.split_once('}') {
            names.push(name);
        }
    }
    names
}
249
/// Replaces every `{name}` placeholder listed in `substrings` with a bare `{}`.
///
/// The result is usable as a positional `format!` string.
fn remove_substrings(s: &str, substrings: &[&str]) -> String {
    substrings.iter().fold(String::from(s), |text, name| {
        let placeholder = format!("{{{}}}", name);
        text.replace(&placeholder, "{}")
    })
}
257
/// Converts an UpperCamelCase identifier (e.g. an enum variant) to snake_case.
///
/// Each uppercase character is lowercased and, unless it is the first
/// character or directly follows an `_`, prefixed with `_`. Consecutive
/// capitals are split individually (`HTTPRequest` -> `h_t_t_p_request`), as
/// before.
///
/// The previous implementation tested the *first* character of the whole
/// input instead of the preceding one. As a result, `fooBar` became `foobar`,
/// and `Foo_Bar` became `foo__bar`, which trips the `non_snake_case` lint on
/// the generated function.
fn to_snake_case(input: &str) -> String {
    // Slack for a few inserted underscores avoids most reallocations.
    let mut out = String::with_capacity(input.len() + 4);
    let mut prev: Option<char> = None;

    for c in input.chars() {
        if c.is_uppercase() {
            if matches!(prev, Some(p) if p != '_') {
                out.push('_');
            }
            // `to_lowercase` may yield several chars for some Unicode letters.
            out.extend(c.to_lowercase());
        } else {
            out.push(c);
        }
        prev = Some(c);
    }

    out
}