enum2contract-derive 0.1.7

A pub/sub style message contract library.
Documentation
#![no_std]

extern crate alloc;

use alloc::{
    format,
    string::{String, ToString},
    vec::Vec,
};
use core::iter;
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::quote;
use syn::{
    parse_macro_input, spanned::Spanned, Data, DeriveInput, Fields, FieldsNamed, LitStr, Variant,
};

#[proc_macro_derive(EnumContract, attributes(topic))]
pub fn derive_enum2contract(input: TokenStream) -> TokenStream {
    let input: DeriveInput = parse_macro_input!(input as DeriveInput);

    let name = &input.ident;

    let data = match input.data {
        Data::Enum(data) => data,
        _ => {
            return syn::Error::new(input.span(), "enum2contract only supports enums")
                .to_compile_error()
                .into()
        }
    };

    let mut message_functions = proc_macro2::TokenStream::new();
    let mut payloads = proc_macro2::TokenStream::new();

    for variant in data.variants.iter() {
        match variant.fields {
            Fields::Unit => {
                let topic = match parse_topic_attribute(variant) {
                    Ok(value) => value,
                    Err(error) => return error.to_compile_error().into(),
                };

                let payload_name =
                    Ident::new(&format!("{}Payload", variant.ident), variant.ident.span());
                let payload_struct = quote!(
                    #[derive(Default, Debug, PartialEq, Serialize, Deserialize)]
                    pub struct #payload_name;
                );
                payloads.extend(payload_struct);

                #[cfg(feature = "json")]
                {
                    let json_conversions = quote!(
                        impl #payload_name {
                            pub fn to_json(&self) -> Result<String, serde_json::Error> {
                                serde_json::to_string(self)
                            }

                            pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
                                serde_json::from_str(json)
                            }
                        }
                    );
                    payloads.extend(json_conversions);
                }

                #[cfg(feature = "binary")]
                {
                    let binary_conversions = quote!(
                        impl #payload_name {
                            pub fn to_bytes(&self) -> Result<Vec<u8>, postcard::Error> {
                                postcard::to_allocvec(self)
                            }

                            pub fn from_binary(bytes: &[u8]) -> Result<Self, postcard::Error> {
                                postcard::from_bytes(bytes)
                            }
                        }
                    );

                    payloads.extend(binary_conversions);
                }

                let payload_type = quote! { #payload_name };
                let payload_default = quote! { #payload_name::default() };
                let ident_name = &to_snake_case(&variant.ident.to_string());
                let create_message = Ident::new(ident_name, variant.ident.span());
                let create_topic =
                    Ident::new(&format!("{}_topic", ident_name), variant.ident.span());
                let topic_string = &topic.value();
                let args = extract_substrings(topic_string);
                let topic_string = remove_substrings(&topic.value(), &args);
                let args: Vec<_> = args
                    .iter()
                    .map(|arg| Ident::new(arg, Span::call_site()))
                    .collect();

                let message_function = quote! {
                    pub fn #create_message(#(#args: &str),*) -> (String, #payload_type) {
                        (Self::#create_topic(#(#args),*), #payload_default)
                    }

                    pub fn #create_topic(#(#args: &str),*) -> String {
                        format!(#topic_string, #(#args),*)
                    }
                };
                message_functions.extend(message_function);
            }

            Fields::Named(FieldsNamed { ref named, .. }) => {
                let topic = match parse_topic_attribute(variant) {
                    Ok(value) => value,
                    Err(error) => return error.to_compile_error().into(),
                };

                let mut fields = proc_macro2::TokenStream::new();

                for field in named.iter() {
                    fields.extend(quote! {
                        pub #field,
                    })
                }

                let payload_name =
                    Ident::new(&format!("{}Payload", variant.ident), variant.ident.span());

                let payload_struct = quote! {
                    use serde::{Serialize, Deserialize};

                    #[derive(Default, Debug, PartialEq, Serialize, Deserialize)]
                    pub struct #payload_name {
                        #fields
                    }
                };
                payloads.extend(payload_struct);

                #[cfg(feature = "json")]
                {
                    let json_conversions = quote!(
                        impl #payload_name {
                            pub fn to_json(&self) -> Result<String, serde_json::Error> {
                                serde_json::to_string(self)
                            }

                            pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
                                serde_json::from_str(json)
                            }
                        }
                    );
                    payloads.extend(json_conversions);
                }

                #[cfg(feature = "binary")]
                {
                    let binary_conversions = quote!(
                        impl #payload_name {
                            pub fn to_bytes(&self) -> Result<Vec<u8>, postcard::Error> {
                                postcard::to_allocvec(self)
                            }

                            pub fn from_binary(bytes: &[u8]) -> Result<Self, postcard::Error> {
                                postcard::from_bytes(bytes)
                            }
                        }
                    );

                    payloads.extend(binary_conversions);
                }

                let payload_type = quote! { #payload_name };
                let payload_default = quote! { #payload_name::default() };
                let ident_name = &to_snake_case(&variant.ident.to_string());
                let create_message = Ident::new(ident_name, variant.ident.span());
                let create_topic =
                    Ident::new(&format!("{}_topic", ident_name), variant.ident.span());
                let topic_string = &topic.value();
                let args = extract_substrings(topic_string);
                let topic_string = remove_substrings(&topic.value(), &args);
                let args: Vec<_> = args
                    .iter()
                    .map(|arg| Ident::new(arg, Span::call_site()))
                    .collect();

                let message_function = quote! {
                    pub fn #create_message(#(#args: &str),*) -> (String, #payload_type) {
                        (Self::#create_topic(#(#args),*), #payload_default)
                    }

                    pub fn #create_topic(#(#args: &str),*) -> String {
                        format!(#topic_string, #(#args),*)
                    }
                };
                message_functions.extend(message_function);
            }

            _ => {
                return syn::Error::new(
                    variant.span(),
                    "enum2contract is only implemented for unit and named-field enums",
                )
                .to_compile_error()
                .into()
            }
        };
    }

    let expanded = quote! {
        #payloads

        impl #name {
            #message_functions
        }
    };

    TokenStream::from(expanded)
}

/// Returns the string literal given to the variant's `#[topic("...")]` attribute.
///
/// # Errors
///
/// - the attribute's argument is not a single string literal;
/// - the attribute appears more than once on the same variant, rather than silently
///   letting the last occurrence win;
/// - the attribute is absent.
fn parse_topic_attribute(variant: &Variant) -> Result<LitStr, syn::Error> {
    let mut topic: Option<LitStr> = None;

    for attr in variant.attrs.iter().filter(|attr| attr.path.is_ident("topic")) {
        let literal = attr.parse_args::<LitStr>().map_err(|_| {
            syn::Error::new(
                attr.path.span(),
                r#"The 'topic' attribute is missing a String argument. Example: #[topic("system/{id}/start")] "#,
            )
        })?;

        if topic.is_some() {
            return Err(syn::Error::new(
                attr.path.span(),
                "The 'topic' attribute may only be specified once per variant",
            ));
        }
        topic = Some(literal);
    }

    topic.ok_or_else(|| {
        syn::Error::new(
            variant.span(),
            r#"The 'topic' attribute is required. Example: #[topic("system/{id}/start")]"#,
        )
    })
}

/// Returns the placeholder names of a topic template, in order of appearance.
///
/// A placeholder is the text between a `{` and the next `}` with no other `{` in between,
/// so `"system/{id}/start"` yields `["id"]`. `{{` is `format!`'s escape for a literal
/// brace and never starts a placeholder, so `"a{{b}}"` yields nothing. An unmatched `{`
/// is ignored. Repeated names are returned once per occurrence.
fn extract_substrings(s: &str) -> Vec<&str> {
    let mut names = Vec::new();
    let mut rest = s;

    while let Some(open) = rest.find('{') {
        let after_open = &rest[open + 1..];

        // `{{` escapes a literal `{`; skip both braces.
        if let Some(escaped) = after_open.strip_prefix('{') {
            rest = escaped;
            continue;
        }

        match after_open.find(|c: char| c == '{' || c == '}') {
            Some(end) if after_open[end..].starts_with('}') => {
                names.push(&after_open[..end]);
                rest = &after_open[end + 1..];
            }
            // Another `{` comes before any `}`: the current brace is unmatched,
            // so resume scanning from the newer one.
            Some(next_open) => rest = &after_open[next_open..],
            None => break,
        }
    }

    names
}

/// Turns a named topic template into a positional `format!` template.
///
/// Every occurrence of `{name}`, for each `name` in `substrings`, is replaced by `{}`;
/// all other text is copied unchanged.
fn remove_substrings(s: &str, substrings: &[&str]) -> String {
    substrings.iter().fold(s.to_string(), |template, name| {
        let named_placeholder = format!("{{{}}}", name);
        template.replace(named_placeholder.as_str(), "{}")
    })
}

/// Converts a variant name such as `FooBar` into `foo_bar`.
///
/// Upper-case characters are lower-cased. An `_` is inserted before each of them, except
/// the first character, but only when the input itself starts with an upper-case
/// character. For example, `fooBar` becomes `foobar`, and every letter of an acronym
/// gets its own separator (`HTTPRequest` → `h_t_t_p_request`).
fn to_snake_case(input: &str) -> String {
    let starts_upper = matches!(input.chars().next(), Some(first) if first.is_uppercase());

    let pieces: Vec<String> = input
        .chars()
        .enumerate()
        .map(|(index, ch)| match (ch.is_uppercase(), index != 0 && starts_upper) {
            (false, _) => ch.to_string(),
            (true, false) => ch.to_lowercase().collect(),
            (true, true) => iter::once('_').chain(ch.to_lowercase()).collect(),
        })
        .collect();

    pieces.concat()
}