enum2contract_derive/
lib.rs1use proc_macro::TokenStream;
2use proc_macro2::{Ident, Span};
3use quote::quote;
4use syn::{Data, DeriveInput, Fields, LitStr, Variant, parse_macro_input, spanned::Spanned};
5
6#[proc_macro_derive(EnumContract, attributes(topic))]
7pub fn derive_enum2contract(input: TokenStream) -> TokenStream {
8 let input: DeriveInput = parse_macro_input!(input as DeriveInput);
9
10 let name = &input.ident;
11
12 let data = match input.data {
13 Data::Enum(data) => data,
14 _ => {
15 return syn::Error::new(input.span(), "enum2contract only supports enums")
16 .to_compile_error()
17 .into();
18 }
19 };
20
21 let mut message_functions = proc_macro2::TokenStream::new();
22 let mut payloads = proc_macro2::TokenStream::new();
23
24 for variant in data.variants.iter() {
25 match expand_variant(variant) {
26 Ok((payload, message_function)) => {
27 payloads.extend(payload);
28 message_functions.extend(message_function);
29 }
30 Err(error) => return error.to_compile_error().into(),
31 }
32 }
33
34 let expanded = quote! {
35 #payloads
36
37 impl #name {
38 #message_functions
39 }
40 };
41
42 TokenStream::from(expanded)
43}
44
45type VariantTokens = (proc_macro2::TokenStream, proc_macro2::TokenStream);
46
47fn expand_variant(variant: &Variant) -> Result<VariantTokens, syn::Error> {
48 let topic = parse_topic_attribute(variant)?;
49 let payload_name = Ident::new(&format!("{}Payload", variant.ident), variant.ident.span());
50
51 let payload_struct = match &variant.fields {
52 Fields::Unit => quote! {
53 #[derive(
54 Default,
55 Debug,
56 Clone,
57 PartialEq,
58 enum2contract::serde::Serialize,
59 enum2contract::serde::Deserialize,
60 )]
61 #[serde(crate = "enum2contract::serde")]
62 pub struct #payload_name;
63 },
64 Fields::Named(named_fields) => {
65 let mut fields = proc_macro2::TokenStream::new();
66 for field in named_fields.named.iter() {
67 fields.extend(quote! { pub #field, });
68 }
69 quote! {
70 #[derive(
71 Default,
72 Debug,
73 Clone,
74 PartialEq,
75 enum2contract::serde::Serialize,
76 enum2contract::serde::Deserialize,
77 )]
78 #[serde(crate = "enum2contract::serde")]
79 pub struct #payload_name {
80 #fields
81 }
82 }
83 }
84 Fields::Unnamed(_) => {
85 return Err(syn::Error::new(
86 variant.span(),
87 "enum2contract is only implemented for unit and named-field enum variants",
88 ));
89 }
90 };
91
92 let payload = quote! {
93 #payload_struct
94
95 impl #payload_name {
96 pub fn to_json(&self) -> Result<String, enum2contract::serde_json::Error> {
97 enum2contract::serde_json::to_string(self)
98 }
99
100 pub fn from_json(json: &str) -> Result<Self, enum2contract::serde_json::Error> {
101 enum2contract::serde_json::from_str(json)
102 }
103
104 pub fn to_bytes(&self) -> Result<Vec<u8>, enum2contract::postcard::Error> {
105 enum2contract::postcard::to_allocvec(self)
106 }
107
108 pub fn from_bytes(bytes: &[u8]) -> Result<Self, enum2contract::postcard::Error> {
109 enum2contract::postcard::from_bytes(bytes)
110 }
111 }
112 };
113
114 let ident_name = to_snake_case(&variant.ident.to_string());
115 let create_message = Ident::new(&ident_name, variant.ident.span());
116 let create_topic = Ident::new(&format!("{ident_name}_topic"), variant.ident.span());
117 let placeholders = extract_placeholders(&topic)?;
118 let format_string = remove_placeholders(&topic.value(), &placeholders);
119 let parameters: Vec<Ident> = placeholders
120 .iter()
121 .map(|placeholder| Ident::new(placeholder, Span::call_site()))
122 .collect();
123
124 let message_function = quote! {
125 pub fn #create_message(#(#parameters: &str),*) -> (String, #payload_name) {
126 (Self::#create_topic(#(#parameters),*), #payload_name::default())
127 }
128
129 pub fn #create_topic(#(#parameters: &str),*) -> String {
130 format!(#format_string, #(#parameters),*)
131 }
132 };
133
134 Ok((payload, message_function))
135}
136
137fn parse_topic_attribute(variant: &Variant) -> Result<LitStr, syn::Error> {
138 let mut topic = None;
139 for attr in &variant.attrs {
140 if attr.path().is_ident("topic") {
141 match attr.parse_args::<LitStr>() {
142 Ok(literal) => topic = Some(literal),
143 Err(_) => {
144 return Err(syn::Error::new(
145 attr.path().span(),
146 r#"The 'topic' attribute is missing a String argument. Example: #[topic("system/{id}/start")] "#,
147 ));
148 }
149 }
150 }
151 }
152 topic.ok_or_else(|| {
153 syn::Error::new(
154 variant.span(),
155 r#"The 'topic' attribute is required. Example: #[topic("system/{id}/start")]"#,
156 )
157 })
158}
159
160fn extract_placeholders(topic: &LitStr) -> Result<Vec<String>, syn::Error> {
161 let value = topic.value();
162 let mut placeholders = Vec::new();
163 for segment in value.split('{').skip(1) {
164 let Some((placeholder, _)) = segment.split_once('}') else {
165 continue;
166 };
167 if placeholder.is_empty() {
168 return Err(syn::Error::new(
169 topic.span(),
170 "topic placeholders must be named, like {id}",
171 ));
172 }
173 if syn::parse_str::<Ident>(placeholder).is_err() {
174 return Err(syn::Error::new(
175 topic.span(),
176 format!("topic placeholder '{{{placeholder}}}' is not a valid identifier"),
177 ));
178 }
179 if placeholders.iter().any(|existing| existing == placeholder) {
180 return Err(syn::Error::new(
181 topic.span(),
182 format!("topic placeholder '{{{placeholder}}}' appears more than once"),
183 ));
184 }
185 placeholders.push(placeholder.to_string());
186 }
187 Ok(placeholders)
188}
189
/// Rewrites a topic into a `format!` string by replacing each `{name}` placeholder with a
/// positional `{}`, in the order the placeholders are given.
fn remove_placeholders(topic: &str, placeholders: &[String]) -> String {
    placeholders.iter().fold(topic.to_owned(), |format_string, name| {
        format_string.replace(&format!("{{{name}}}"), "{}")
    })
}
197
/// Converts an UpperCamelCase identifier to snake_case.
///
/// An uppercase letter starts a new word, and is preceded by `_`, in either of two cases:
/// - it follows a lowercase letter or an ASCII digit;
/// - it ends a run of capitals and is followed by a lowercase letter, so `HTTPServer`
///   becomes `http_server`.
fn to_snake_case(input: &str) -> String {
    let chars: Vec<char> = input.chars().collect();
    let mut snake = String::with_capacity(input.len() + 4);

    for (position, &current) in chars.iter().enumerate() {
        if !current.is_uppercase() {
            snake.push(current);
            continue;
        }

        let before = position.checked_sub(1).map(|i| chars[i]);
        let after = chars.get(position + 1).copied();
        let starts_word = match before {
            Some(prev) if prev.is_lowercase() || prev.is_ascii_digit() => true,
            Some(prev) if prev.is_uppercase() => after.is_some_and(char::is_lowercase),
            _ => false,
        };

        if starts_word {
            snake.push('_');
        }
        snake.extend(current.to_lowercase());
    }

    snake
}