#![no_std]
extern crate alloc;
extern crate proc_macro;
use alloc::fmt::format;
use parse::{
Input,
LabeledStringInput,
};
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{
quote,
ToTokens,
};
use syn::{
parse_macro_input,
Ident,
};
mod parse;
/// Derives `serde::Serialize` for the annotated type.
///
/// The generated `serialize` method hands `self` to
/// `Serializer::collect_str`, so the annotated type must implement
/// `core::fmt::Display`; this derive does not generate that impl.
#[proc_macro_derive(SerializeStringEnum)]
pub fn derive_serialize(input: TokenStream) -> TokenStream {
    let parsed = parse_macro_input!(input as Input);
    let target = parsed.ident;

    let expanded = quote! {
        impl serde::Serialize for #target {
            fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                serializer.collect_str(self)
            }
        }
    };

    expanded.into()
}
/// Derives `serde::Deserialize` for a type that is parsed from a string.
///
/// The expansion contains:
///
/// * a unit struct named `<Type>Visitor` implementing `serde::de::Visitor`,
///   whose `visit_str` calls `<Type>::from_str` and maps any parse failure to
///   `serde::de::Error::invalid_value`;
/// * `impl serde::Deserialize for <Type>`, which calls
///   `deserializer.deserialize_str` with that visitor.
///
/// This derive does not generate `from_str`; the annotated type has to provide
/// it (normally through `core::str::FromStr`).
#[proc_macro_derive(DeserializeStringEnum)]
pub fn derive_deserialize(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as Input);
    let ident = input.ident;
    let visitor_ident = Ident::new(&format(format_args!("{ident}Visitor")), Span::call_site());
    TokenStream::from(quote! {
        struct #visitor_ident;

        impl<'de> serde::de::Visitor<'de> for #visitor_ident {
            type Value = #ident;

            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
                formatter.write_fmt(format_args!("a valid {} string value", stringify!(#ident)))
            }

            fn visit_str<E>(self, v: &str) -> core::result::Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                // Bring the trait into scope locally so the expansion compiles
                // even when the caller's module has not imported `FromStr`.
                use core::str::FromStr;
                match Self::Value::from_str(&v) {
                    Ok(v) => Ok(v),
                    Err(_) => Err(E::invalid_value(serde::de::Unexpected::Str(&v), &self)),
                }
            }
        }

        impl<'de> serde::Deserialize<'de> for #ident {
            fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                deserializer.deserialize_str(#visitor_ident)
            }
        }
    })
}
/// Derives `core::fmt::Display` and `serde::Serialize` for an enum whose
/// variants carry a `string` label.
///
/// * `Display::fmt` matches on `self` and writes the label of the matching
///   variant.
/// * `Serialize::serialize` forwards to `Serializer::collect_str(self)`, i.e.
///   it serializes whatever `Display` writes.
///
/// # Panics
///
/// Panics during macro expansion if a parsed variant has no `string` label
/// (`attrs.string` is `None`).
#[proc_macro_derive(SerializeLabeledStringEnum, attributes(string))]
pub fn derive_labeled_serialize(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as LabeledStringInput);
    let ident = input.ident;
    let match_variants = input.variants.iter().map(|variant| {
        let string = variant
            .attrs
            .string
            .as_ref()
            .expect("variant is missing its `string` label");
        let variant = &variant.ident;
        // Emit the label through `write_str` instead of using it as a `write!`
        // format string: a label containing `{` or `}` must be written
        // verbatim, matching the raw comparison done when deserializing.
        quote! {
            Self::#variant => f.write_str(#string),
        }
    });
    TokenStream::from(quote! {
        impl core::fmt::Display for #ident {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                match self {
                    #(#match_variants)*
                }
            }
        }

        impl serde::Serialize for #ident {
            fn serialize<S>(&self, serializer: S) -> core::result::Result<S::Ok, S::Error>
            where
                S: serde::Serializer,
            {
                serializer.collect_str(self)
            }
        }
    })
}
/// Wraps the tokens of `t` in `unicase::UniCase::new(..)` when this crate is
/// compiled with its `unicase` feature; otherwise returns the tokens of `t`
/// unchanged.
fn wrap_unicase<T>(t: &T) -> proc_macro2::TokenStream
where
    T: ToTokens,
{
    // Guard clause: without the feature the value is passed through as-is.
    if !cfg!(feature = "unicase") {
        return quote! {
            #t
        };
    }

    quote! {
        unicase::UniCase::new(#t)
    }
}
/// Derives `core::str::FromStr` and `serde::Deserialize` for an enum whose
/// variants carry a `string` label and optional `alias` values.
///
/// The expansion contains:
///
/// * `impl core::str::FromStr`, which compares the input against each
///   variant's label and then its aliases, in declaration order, and returns
///   the first match. Both sides of every comparison go through
///   [`wrap_unicase`], so they are wrapped in `unicase::UniCase` when this
///   crate's `unicase` feature is enabled.
/// * a unit struct named `<Type>Visitor` implementing `serde::de::Visitor`,
///   whose `visit_str` delegates to `from_str` and reports failures with
///   `serde::de::Error::invalid_value`.
/// * `impl serde::Deserialize`, which calls `deserializer.deserialize_str`
///   with that visitor.
///
/// The `FromStr::Err` type depends on this crate's features: with `std` it is
/// `std::string::String`, with `alloc` it is `alloc::string::String` (both
/// carrying an `"invalid <Type>: <input>"` message), and otherwise it is
/// `&'static str` with the fixed message `"invalid value"`.
///
/// # Panics
///
/// Panics during macro expansion if a parsed variant has no `string` label
/// (`attrs.string` is `None`).
#[proc_macro_derive(DeserializeLabeledStringEnum, attributes(string, alias))]
pub fn derive_labeled_deserialize(input: TokenStream) -> TokenStream {
    let input = parse_macro_input!(input as LabeledStringInput);
    let call_site = Span::call_site();
    let ident = input.ident;
    let visitor_ident = Ident::new(&format(format_args!("{ident}Visitor")), call_site);
    // Name of the `from_str` parameter; every generated comparison refers to
    // it through this single identifier.
    let input_ident = Ident::new("s", call_site);

    let match_variants = input.variants.iter().map(|variant| {
        let variant_ident = &variant.ident;
        let alias_match = variant.attrs.aliases.iter().map(|alias| {
            let alias = wrap_unicase(alias);
            // Interpolate the parameter identifier rather than hard-coding
            // `s`, so this stays in sync with the label comparison below.
            quote! {
                if #input_ident == #alias {
                    return Ok(Self::#variant_ident)
                }
            }
        });
        let string = variant
            .attrs
            .string
            .as_ref()
            .expect("variant is missing its `string` label");
        let string = wrap_unicase(string);
        quote! {
            if #input_ident == #string {
                return Ok(Self::#variant_ident)
            }
            #(#alias_match)*
        }
    });

    // The error type and the expression that builds it are chosen together so
    // that they can never disagree.
    let (error_type, error) = if cfg!(feature = "std") {
        (
            quote! {
                std::string::String
            },
            quote! {
                std::format!("invalid {}: {}", stringify!(#ident), #input_ident)
            },
        )
    } else if cfg!(feature = "alloc") {
        (
            quote! {
                alloc::string::String
            },
            quote! {
                alloc::fmt::format(format_args!("invalid {}: {}", stringify!(#ident), #input_ident))
            },
        )
    } else {
        (
            quote! {
                &'static str
            },
            quote! {
                "invalid value"
            },
        )
    };

    let unicase_input = wrap_unicase(&input_ident);
    TokenStream::from(quote! {
        impl core::str::FromStr for #ident {
            type Err = #error_type;

            fn from_str(#input_ident: &str) -> core::result::Result<Self, Self::Err> {
                let #input_ident = #unicase_input;
                #(#match_variants)*
                Err(#error)
            }
        }

        struct #visitor_ident;

        impl<'de> serde::de::Visitor<'de> for #visitor_ident {
            type Value = #ident;

            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
                formatter.write_fmt(format_args!("a valid {} string value", stringify!(#ident)))
            }

            fn visit_str<E>(self, v: &str) -> core::result::Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                use core::str::FromStr;
                match Self::Value::from_str(v) {
                    Ok(v) => Ok(v),
                    Err(_) => Err(E::invalid_value(serde::de::Unexpected::Str(v), &self)),
                }
            }
        }

        impl<'de> serde::Deserialize<'de> for #ident {
            fn deserialize<D>(deserializer: D) -> core::result::Result<Self, D::Error>
            where
                D: serde::Deserializer<'de>,
            {
                deserializer.deserialize_str(#visitor_ident)
            }
        }
    })
}