Skip to main content

enum2contract_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Ident, Span};
3use quote::quote;
4use syn::{Data, DeriveInput, Fields, LitStr, Variant, parse_macro_input, spanned::Spanned};
5
6#[proc_macro_derive(EnumContract, attributes(topic))]
7pub fn derive_enum2contract(input: TokenStream) -> TokenStream {
8    let input: DeriveInput = parse_macro_input!(input as DeriveInput);
9
10    let name = &input.ident;
11
12    let data = match input.data {
13        Data::Enum(data) => data,
14        _ => {
15            return syn::Error::new(input.span(), "enum2contract only supports enums")
16                .to_compile_error()
17                .into();
18        }
19    };
20
21    let mut message_functions = proc_macro2::TokenStream::new();
22    let mut payloads = proc_macro2::TokenStream::new();
23
24    for variant in data.variants.iter() {
25        match expand_variant(variant) {
26            Ok((payload, message_function)) => {
27                payloads.extend(payload);
28                message_functions.extend(message_function);
29            }
30            Err(error) => return error.to_compile_error().into(),
31        }
32    }
33
34    let expanded = quote! {
35        #payloads
36
37        impl #name {
38            #message_functions
39        }
40    };
41
42    TokenStream::from(expanded)
43}
44
45type VariantTokens = (proc_macro2::TokenStream, proc_macro2::TokenStream);
46
47fn expand_variant(variant: &Variant) -> Result<VariantTokens, syn::Error> {
48    let topic = parse_topic_attribute(variant)?;
49    let payload_name = Ident::new(&format!("{}Payload", variant.ident), variant.ident.span());
50
51    let payload_struct = match &variant.fields {
52        Fields::Unit => quote! {
53            #[derive(
54                Default,
55                Debug,
56                Clone,
57                PartialEq,
58                enum2contract::serde::Serialize,
59                enum2contract::serde::Deserialize,
60            )]
61            #[serde(crate = "enum2contract::serde")]
62            pub struct #payload_name;
63        },
64        Fields::Named(named_fields) => {
65            let mut fields = proc_macro2::TokenStream::new();
66            for field in named_fields.named.iter() {
67                fields.extend(quote! { pub #field, });
68            }
69            quote! {
70                #[derive(
71                    Default,
72                    Debug,
73                    Clone,
74                    PartialEq,
75                    enum2contract::serde::Serialize,
76                    enum2contract::serde::Deserialize,
77                )]
78                #[serde(crate = "enum2contract::serde")]
79                pub struct #payload_name {
80                    #fields
81                }
82            }
83        }
84        Fields::Unnamed(_) => {
85            return Err(syn::Error::new(
86                variant.span(),
87                "enum2contract is only implemented for unit and named-field enum variants",
88            ));
89        }
90    };
91
92    let payload = quote! {
93        #payload_struct
94
95        impl #payload_name {
96            pub fn to_json(&self) -> Result<String, enum2contract::serde_json::Error> {
97                enum2contract::serde_json::to_string(self)
98            }
99
100            pub fn from_json(json: &str) -> Result<Self, enum2contract::serde_json::Error> {
101                enum2contract::serde_json::from_str(json)
102            }
103
104            pub fn to_bytes(&self) -> Result<Vec<u8>, enum2contract::postcard::Error> {
105                enum2contract::postcard::to_allocvec(self)
106            }
107
108            pub fn from_bytes(bytes: &[u8]) -> Result<Self, enum2contract::postcard::Error> {
109                enum2contract::postcard::from_bytes(bytes)
110            }
111        }
112    };
113
114    let ident_name = to_snake_case(&variant.ident.to_string());
115    let create_message = Ident::new(&ident_name, variant.ident.span());
116    let create_topic = Ident::new(&format!("{ident_name}_topic"), variant.ident.span());
117    let placeholders = extract_placeholders(&topic)?;
118    let format_string = remove_placeholders(&topic.value(), &placeholders);
119    let parameters: Vec<Ident> = placeholders
120        .iter()
121        .map(|placeholder| Ident::new(placeholder, Span::call_site()))
122        .collect();
123
124    let message_function = quote! {
125        pub fn #create_message(#(#parameters: &str),*) -> (String, #payload_name) {
126            (Self::#create_topic(#(#parameters),*), #payload_name::default())
127        }
128
129        pub fn #create_topic(#(#parameters: &str),*) -> String {
130            format!(#format_string, #(#parameters),*)
131        }
132    };
133
134    Ok((payload, message_function))
135}
136
137fn parse_topic_attribute(variant: &Variant) -> Result<LitStr, syn::Error> {
138    let mut topic = None;
139    for attr in &variant.attrs {
140        if attr.path().is_ident("topic") {
141            match attr.parse_args::<LitStr>() {
142                Ok(literal) => topic = Some(literal),
143                Err(_) => {
144                    return Err(syn::Error::new(
145                        attr.path().span(),
146                        r#"The 'topic' attribute is missing a String argument. Example: #[topic("system/{id}/start")] "#,
147                    ));
148                }
149            }
150        }
151    }
152    topic.ok_or_else(|| {
153        syn::Error::new(
154            variant.span(),
155            r#"The 'topic' attribute is required. Example: #[topic("system/{id}/start")]"#,
156        )
157    })
158}
159
160fn extract_placeholders(topic: &LitStr) -> Result<Vec<String>, syn::Error> {
161    let value = topic.value();
162    let mut placeholders = Vec::new();
163    for segment in value.split('{').skip(1) {
164        let Some((placeholder, _)) = segment.split_once('}') else {
165            continue;
166        };
167        if placeholder.is_empty() {
168            return Err(syn::Error::new(
169                topic.span(),
170                "topic placeholders must be named, like {id}",
171            ));
172        }
173        if syn::parse_str::<Ident>(placeholder).is_err() {
174            return Err(syn::Error::new(
175                topic.span(),
176                format!("topic placeholder '{{{placeholder}}}' is not a valid identifier"),
177            ));
178        }
179        if placeholders.iter().any(|existing| existing == placeholder) {
180            return Err(syn::Error::new(
181                topic.span(),
182                format!("topic placeholder '{{{placeholder}}}' appears more than once"),
183            ));
184        }
185        placeholders.push(placeholder.to_string());
186    }
187    Ok(placeholders)
188}
189
/// Turns a `#[topic]` string into a `format!` string by replacing every `{name}` whose
/// `name` is listed in `placeholders` with a positional `{}`.
///
/// The topic is scanned once, left to right. Doubled braces (`{{`, `}}`) are `format!`
/// escapes for a literal brace and are copied through unchanged, so an escaped
/// `{{name}}` is never rewritten into a placeholder. Any other brace, including a
/// `{name}` whose `name` is not listed, is also copied through verbatim.
fn remove_placeholders(topic: &str, placeholders: &[String]) -> String {
    let mut result = String::with_capacity(topic.len());
    let mut rest = topic;

    while let Some(brace) = rest.find(['{', '}']) {
        result.push_str(&rest[..brace]);
        let tail = &rest[brace..];

        // Escaped brace: keep both characters so `format!` still sees the escape.
        if tail.starts_with("{{") || tail.starts_with("}}") {
            result.push_str(&tail[..2]);
            rest = &tail[2..];
            continue;
        }

        if let Some((name, after)) = tail.strip_prefix('{').and_then(|inner| inner.split_once('}')) {
            if placeholders.iter().any(|placeholder| placeholder == name) {
                result.push_str("{}");
                rest = after;
                continue;
            }
        }

        // A lone brace or an unlisted group: copy this single brace and keep scanning.
        result.push_str(&tail[..1]);
        rest = &tail[1..];
    }

    result.push_str(rest);
    result
}
197
/// Converts a CamelCase identifier such as a variant name into snake_case.
///
/// Each uppercase character is lowercased. An underscore is inserted before it when it
/// starts a new word:
/// - after a lowercase letter or an ASCII digit (`StartV2` → `start_v2`, `V2Start` →
///   `v2_start`);
/// - as the last capital of an acronym that is followed by a lowercase letter
///   (`HTTPServer` → `http_server`).
///
/// Non-uppercase characters are copied unchanged.
fn to_snake_case(input: &str) -> String {
    let chars: Vec<char> = input.chars().collect();
    let mut snake = String::with_capacity(input.len());

    for (position, &current) in chars.iter().enumerate() {
        if !current.is_uppercase() {
            snake.push(current);
            continue;
        }

        let previous = position.checked_sub(1).and_then(|p| chars.get(p)).copied();
        let next = chars.get(position + 1).copied();
        let starts_new_word = match previous {
            None => false,
            Some(prev) if prev.is_lowercase() || prev.is_ascii_digit() => true,
            Some(prev) => prev.is_uppercase() && next.is_some_and(char::is_lowercase),
        };

        if starts_new_word {
            snake.push('_');
        }
        snake.extend(current.to_lowercase());
    }

    snake
}