use convert_case::{Case, Casing};
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{DeriveInput, LitStr, parse_macro_input};
pub fn derive(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
match generate(&input) {
Ok(tokens) => tokens.into(),
Err(err) => err.to_compile_error().into()
}
}
/// Builds the `Display`, `FromStr`, `AsRef<str>` and `TryFrom<&str>` impls for a
/// `ValueObject` enum.
///
/// Each variant maps to the snake_case form of its identifier (via `convert_case`),
/// e.g. `InProgress` -> `"in_progress"`. `Display` and `AsRef<str>` emit that string.
/// `FromStr` lowercases its input before matching, so parsing is case-insensitive.
/// `TryFrom<&str>` delegates to `FromStr`.
///
/// # Errors
///
/// Returns a `syn::Error` when:
/// - the input is not an enum;
/// - the `#[value_object(pg_type = "...")]` attribute is missing or malformed;
/// - the enum has no variants;
/// - a variant has fields (only unit variants can be mapped to a fixed string);
/// - two variants map to the same string (the mapping would not round-trip).
fn generate(input: &DeriveInput) -> syn::Result<TokenStream2> {
    let name = &input.ident;
    let syn::Data::Enum(data) = &input.data else {
        return Err(syn::Error::new_spanned(
            input,
            "ValueObject can only be derived for enums"
        ));
    };
    // Only validated here: the generated impls below do not use the pg_type value.
    extract_pg_type(&input.attrs)?;
    let variants = named_unit_variants(name, &data.variants)?;

    let display_arms: Vec<TokenStream2> = variants
        .iter()
        .map(|(ident, text)| quote! { Self::#ident => write!(f, #text) })
        .collect();
    let fromstr_arms: Vec<TokenStream2> = variants
        .iter()
        .map(|(ident, text)| {
            // The input is lowercased before matching, so the arm must be lowercase too.
            let lowered = text.to_lowercase();
            quote! { #lowered => Ok(Self::#ident) }
        })
        .collect();
    let asref_arms: Vec<TokenStream2> = variants
        .iter()
        .map(|(ident, text)| quote! { Self::#ident => #text })
        .collect();

    Ok(quote! {
        impl std::fmt::Display for #name {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                match self {
                    #(#display_arms),*
                }
            }
        }
        impl std::str::FromStr for #name {
            type Err = String;
            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match s.to_lowercase().as_str() {
                    // Separator emitted per arm so the fallback arm never follows a bare `,`.
                    #(#fromstr_arms,)*
                    other => Err(format!("unknown variant `{other}`"))
                }
            }
        }
        impl AsRef<str> for #name {
            fn as_ref(&self) -> &str {
                match self {
                    #(#asref_arms),*
                }
            }
        }
        impl std::convert::TryFrom<&str> for #name {
            type Error = String;
            fn try_from(s: &str) -> Result<Self, Self::Error> {
                s.parse()
            }
        }
    })
}

/// Validates the enum's variants and pairs each identifier with its snake_case name.
///
/// # Errors
///
/// Returns a spanned error if:
/// - the enum is empty;
/// - any variant has fields;
/// - two variants map to the same (lowercased) name.
fn named_unit_variants<'a>(
    enum_ident: &syn::Ident,
    variants: &'a syn::punctuated::Punctuated<syn::Variant, syn::token::Comma>,
) -> syn::Result<Vec<(&'a syn::Ident, String)>> {
    if variants.is_empty() {
        return Err(syn::Error::new_spanned(
            enum_ident,
            "ValueObject requires at least one variant",
        ));
    }
    let mut seen: std::collections::HashMap<String, &'a syn::Ident> =
        std::collections::HashMap::with_capacity(variants.len());
    let mut named = Vec::with_capacity(variants.len());
    for variant in variants {
        if !matches!(variant.fields, syn::Fields::Unit) {
            return Err(syn::Error::new_spanned(
                &variant.fields,
                format!(
                    "ValueObject variant `{}` must be a unit variant (no fields)",
                    variant.ident
                ),
            ));
        }
        let text = variant.ident.to_string().to_case(Case::Snake);
        // FromStr compares lowercased strings, so collisions are checked the same way.
        let key = text.to_lowercase();
        if let Some(previous) = seen.get(&key) {
            return Err(syn::Error::new_spanned(
                &variant.ident,
                format!(
                    "ValueObject variant `{}` maps to \"{}\", which is already used by `{}`",
                    variant.ident, text, previous
                ),
            ));
        }
        seen.insert(key, &variant.ident);
        named.push((&variant.ident, text));
    }
    Ok(named)
}
/// Reads the Postgres type name from `#[value_object(pg_type = "...")]`.
///
/// Every `#[value_object(...)]` attribute is scanned. If `pg_type` appears more than
/// once, the last occurrence wins.
///
/// # Errors
///
/// Returns a `syn::Error` if any of the following holds:
/// - a `value_object` attribute is not in list form;
/// - it contains a key other than `pg_type`;
/// - its `pg_type` value is not a string literal.
///
/// These errors are spanned at the offending tokens. If no `pg_type` is found at all,
/// the error is reported at the macro call site.
fn extract_pg_type(attrs: &[syn::Attribute]) -> syn::Result<String> {
    let mut pg_type: Option<String> = None;
    for attr in attrs.iter().filter(|a| a.path().is_ident("value_object")) {
        // Propagate parse errors: swallowing them would turn e.g. `pg_type = 42`
        // into a misleading "missing attribute" error.
        attr.parse_nested_meta(|meta| {
            if meta.path.is_ident("pg_type") {
                let lit: LitStr = meta.value()?.parse()?;
                pg_type = Some(lit.value());
                Ok(())
            } else {
                Err(meta.error("unsupported `value_object` key; expected `pg_type`"))
            }
        })?;
    }
    pg_type.ok_or_else(|| {
        syn::Error::new(
            proc_macro2::Span::call_site(),
            "missing #[value_object(pg_type = \"...\")] attribute"
        )
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` as a derive input; panics if the fixture itself is malformed.
    fn parse_input(input: &str) -> DeriveInput {
        syn::parse_str(input).expect("test fixture must parse as a DeriveInput")
    }

    /// Removes all whitespace so assertions are independent of token spacing.
    fn normalize(s: &str) -> String {
        s.chars().filter(|c| !c.is_whitespace()).collect()
    }

    /// Runs `generate` on `input` and returns the whitespace-free expansion.
    fn expand(input: &str) -> String {
        let tokens = generate(&parse_input(input)).expect("expansion should succeed");
        normalize(&tokens.to_string())
    }

    /// Asserts every fragment occurs in `output`, reporting the first missing one.
    fn assert_contains_all(output: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(
                output.contains(fragment),
                "expected `{fragment}` in expansion:\n{output}"
            );
        }
    }

    #[test]
    fn extract_pg_type_basic() {
        let input: DeriveInput = syn::parse_quote! {
            #[value_object(pg_type = "order_status")]
            enum OrderStatus { Pending }
        };
        assert_eq!(extract_pg_type(&input.attrs).unwrap(), "order_status");
    }

    #[test]
    fn extract_pg_type_missing_fails() {
        let input: DeriveInput = syn::parse_quote! {
            enum OrderStatus { Pending }
        };
        assert!(extract_pg_type(&input.attrs).is_err());
    }

    #[test]
    fn extract_pg_type_other_value() {
        let input: DeriveInput = syn::parse_quote! {
            #[value_object(pg_type = "user_role")]
            enum UserRole { Admin }
        };
        assert_eq!(extract_pg_type(&input.attrs).unwrap(), "user_role");
    }

    #[test]
    fn extract_pg_type_non_string_value_fails() {
        let input: DeriveInput = syn::parse_quote! {
            #[value_object(pg_type = 42)]
            enum Status { Active }
        };
        assert!(extract_pg_type(&input.attrs).is_err());
    }

    #[test]
    fn multiple_pg_type_attributes_use_last() {
        let input: DeriveInput = syn::parse_quote! {
            #[value_object(pg_type = "first")]
            #[value_object(pg_type = "second")]
            enum Status { Active }
        };
        assert_eq!(extract_pg_type(&input.attrs).unwrap(), "second");
    }

    #[test]
    fn generate_basic_enum() {
        let output = expand(
            r#"
            #[value_object(pg_type = "order_status")]
            enum OrderStatus {
                Pending,
                Confirmed,
                Cancelled,
            }
            "#,
        );
        assert_contains_all(
            &output,
            &[
                "DisplayforOrderStatus",
                "FromStrforOrderStatus",
                "AsRef<str>forOrderStatus",
                "TryFrom<&str>forOrderStatus",
            ],
        );
    }

    #[test]
    fn generate_for_non_enum_fails() {
        let input = parse_input(
            r#"
            struct NotAnEnum {
                field: String,
            }
            "#,
        );
        assert!(generate(&input).is_err());
    }

    #[test]
    fn display_output_lowercase() {
        let output = expand(
            r#"
            #[value_object(pg_type = "status")]
            enum Status {
                Pending,
                Confirmed,
            }
            "#,
        );
        assert_contains_all(
            &output,
            &["write!(f,\"pending\")", "write!(f,\"confirmed\")"],
        );
    }

    #[test]
    fn display_output_underscore_variant() {
        let output = expand(
            r#"
            #[value_object(pg_type = "status")]
            enum Status {
                InProgress,
            }
            "#,
        );
        assert_contains_all(&output, &["write!(f,\"in_progress\")"]);
    }

    #[test]
    fn variant_with_numbers() {
        let output = expand(
            r#"
            #[value_object(pg_type = "status")]
            enum Status {
                V2Active,
                V3Inactive,
            }
            "#,
        );
        assert_contains_all(
            &output,
            &["write!(f,\"v_2_active\")", "write!(f,\"v_3_inactive\")"],
        );
    }

    #[test]
    fn fromstr_case_insensitive() {
        let output = expand(
            r#"
            #[value_object(pg_type = "status")]
            enum Status {
                Active,
                Inactive,
            }
            "#,
        );
        assert_contains_all(
            &output,
            &[
                "\"active\"=>Ok(Self::Active)",
                "\"inactive\"=>Ok(Self::Inactive)",
                "s.to_lowercase().as_str()",
            ],
        );
    }

    #[test]
    fn fromstr_error_unknown_variant() {
        let output = expand(
            r#"
            #[value_object(pg_type = "status")]
            enum Status { Active }
            "#,
        );
        assert_contains_all(&output, &["unknownvariant"]);
    }

    #[test]
    fn fromstr_error_type_string() {
        let output = expand(
            r#"
            #[value_object(pg_type = "status")]
            enum Status { Active }
            "#,
        );
        assert_contains_all(&output, &["typeErr=String"]);
    }

    #[test]
    fn asref_matches_display() {
        let output = expand(
            r#"
            #[value_object(pg_type = "status")]
            enum Status {
                Pending,
            }
            "#,
        );
        assert_contains_all(&output, &["Self::Pending=>\"pending\""]);
    }

    #[test]
    fn tryfrom_delegates_to_parse() {
        let output = expand(
            r#"
            #[value_object(pg_type = "status")]
            enum Status { Active }
            "#,
        );
        assert_contains_all(&output, &["s.parse()", "typeError=String"]);
    }

    #[test]
    fn tryfrom_error_type_string() {
        let output = expand(
            r#"
            #[value_object(pg_type = "status")]
            enum Status { Active }
            "#,
        );
        assert_contains_all(&output, &["typeError=String"]);
    }

    /// Display and FromStr must use the same string for every variant.
    #[test]
    fn display_and_fromstr_use_same_names() {
        let output = expand(
            r#"
            #[value_object(pg_type = "order_status")]
            enum OrderStatus {
                Pending,
                Confirmed,
                Cancelled,
            }
            "#,
        );
        assert_contains_all(
            &output,
            &[
                "write!(f,\"pending\")",
                "write!(f,\"confirmed\")",
                "write!(f,\"cancelled\")",
                "\"pending\"=>Ok(Self::Pending)",
                "\"confirmed\"=>Ok(Self::Confirmed)",
                "\"cancelled\"=>Ok(Self::Cancelled)",
            ],
        );
    }

    #[test]
    fn single_variant_enum() {
        let output = expand(
            r#"
            #[value_object(pg_type = "status")]
            enum Status { Only }
            "#,
        );
        assert_contains_all(
            &output,
            &[
                "DisplayforStatus",
                "write!(f,\"only\")",
                "\"only\"=>Ok(Self::Only)",
            ],
        );
    }

    #[test]
    fn all_traits_implemented() {
        let output = expand(
            r#"
            #[value_object(pg_type = "status")]
            enum Status { Active, Inactive }
            "#,
        );
        assert_contains_all(
            &output,
            &[
                "DisplayforStatus",
                "FromStrforStatus",
                "AsRef<str>forStatus",
                "TryFrom<&str>forStatus",
            ],
        );
    }
}