use std::collections::HashMap;
use crate::generated::descriptor::EnumDescriptorProto;
use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote};
use crate::context::CodeGenContext;
use crate::features::ResolvedFeatures;
use crate::CodeGenError;
/// Builds the serde / proto-JSON glue emitted alongside a generated enum.
///
/// The generated code:
/// - serializes a value as its proto name (`Enumeration::proto_name`);
/// - deserializes via `deserialize_any`, accepting a proto name string, an
///   integer that fits in `i32` and matches a known number (`from_i32`), or a
///   unit value (`visit_unit`, which yields the type's `Default`);
/// - implements `ProtoElemJson` by delegating to those `Serialize` /
///   `Deserialize` impls.
fn generate_enum_serde(name_ident: &Ident) -> TokenStream {
    let ty = name_ident;

    // Body shared by `visit_i64` and `visit_u64`: in both, the visitor
    // argument is named `v`. It must fit in `i32` and name a known value.
    let int_lookup = quote! {
        let v32 = i32::try_from(v).map_err(|_| {
            ::serde::de::Error::custom(
                ::buffa::alloc::format!("enum value {v} out of i32 range")
            )
        })?;
        <#ty as ::buffa::Enumeration>::from_i32(v32).ok_or_else(|| {
            ::serde::de::Error::custom(
                ::buffa::alloc::format!("unknown enum value {v32}")
            )
        })
    };

    // NOTE(review): `unknown_variant` receives an empty `expected` list, so
    // serde's default message reads as if the enum had no variants. Consider
    // a more descriptive error. It is left unchanged here because downstream
    // code may depend on the current message.
    let visitor = quote! {
        struct _V;
        impl ::serde::de::Visitor<'_> for _V {
            type Value = #ty;
            fn expecting(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                f.write_str(concat!("a string, integer, or null for ", stringify!(#ty)))
            }
            fn visit_str<E: ::serde::de::Error>(self, v: &str) -> ::core::result::Result<#ty, E> {
                <#ty as ::buffa::Enumeration>::from_proto_name(v).ok_or_else(|| {
                    ::serde::de::Error::unknown_variant(v, &[])
                })
            }
            fn visit_i64<E: ::serde::de::Error>(self, v: i64) -> ::core::result::Result<#ty, E> {
                #int_lookup
            }
            fn visit_u64<E: ::serde::de::Error>(self, v: u64) -> ::core::result::Result<#ty, E> {
                #int_lookup
            }
            fn visit_unit<E: ::serde::de::Error>(self) -> ::core::result::Result<#ty, E> {
                ::core::result::Result::Ok(::core::default::Default::default())
            }
        }
    };

    let serialize_impl = quote! {
        impl ::serde::Serialize for #ty {
            fn serialize<S: ::serde::Serializer>(&self, s: S) -> ::core::result::Result<S::Ok, S::Error> {
                s.serialize_str(::buffa::Enumeration::proto_name(self))
            }
        }
    };

    let deserialize_impl = quote! {
        impl<'de> ::serde::Deserialize<'de> for #ty {
            fn deserialize<D: ::serde::Deserializer<'de>>(d: D) -> ::core::result::Result<Self, D::Error> {
                #visitor
                d.deserialize_any(_V)
            }
        }
    };

    let elem_json_impl = quote! {
        impl ::buffa::json_helpers::ProtoElemJson for #ty {
            fn serialize_proto_json<S: ::serde::Serializer>(
                v: &Self,
                s: S,
            ) -> ::core::result::Result<S::Ok, S::Error> {
                ::serde::Serialize::serialize(v, s)
            }
            fn deserialize_proto_json<'de, D: ::serde::Deserializer<'de>>(
                d: D,
            ) -> ::core::result::Result<Self, D::Error> {
                <Self as ::serde::Deserialize>::deserialize(d)
            }
        }
    };

    quote! {
        #serialize_impl
        #deserialize_impl
        #elem_json_impl
    }
}
pub fn generate_enum(
ctx: &CodeGenContext,
enum_desc: &EnumDescriptorProto,
rust_name: &str,
proto_fqn: &str,
features: &ResolvedFeatures,
_resolver: &crate::imports::ImportResolver,
) -> Result<TokenStream, CodeGenError> {
let name_ident = format_ident!("{}", rust_name);
let mut seen: HashMap<i32, &str> = HashMap::new();
let mut variants = Vec::new();
let mut alias_consts = Vec::new();
let mut from_i32_arms = Vec::new();
let mut from_proto_name_arms: Vec<TokenStream> = Vec::new();
let mut proto_name_arms = Vec::new();
let mut value_idents: Vec<Ident> = Vec::new();
let mut zero_variant: Option<Ident> = None;
let mut first_variant: Option<Ident> = None;
let mut value_records: Vec<(String, Ident, String)> = Vec::new();
for v in &enum_desc.value {
let value_name = v
.name
.as_deref()
.ok_or(CodeGenError::MissingField("enum_value.name"))?;
let number = v
.number
.ok_or(CodeGenError::MissingField("enum_value.number"))?;
let variant_ident = crate::message::make_field_ident(value_name);
let value_fqn = format!("{}.{}", proto_fqn, value_name);
let variant_doc =
crate::comments::doc_attrs_resolved(ctx.comment(&value_fqn), proto_fqn, &ctx.type_map);
if let Some(&primary_name) = seen.get(&number) {
let primary_ident = crate::message::make_field_ident(primary_name);
alias_consts.push(quote! {
#variant_doc
#[allow(non_upper_case_globals)]
pub const #variant_ident: Self = Self::#primary_ident;
});
from_proto_name_arms.push(quote! {
#value_name => ::core::option::Option::Some(Self::#primary_ident)
});
if ctx.config.idiomatic_enum_aliases {
value_records.push((
value_name.to_string(),
primary_ident,
variant_ident.to_string(),
));
}
} else {
seen.insert(number, value_name);
if first_variant.is_none() {
first_variant = Some(variant_ident.clone());
}
if number == 0 && zero_variant.is_none() {
zero_variant = Some(variant_ident.clone());
}
variants.push(quote! { #variant_doc #variant_ident = #number });
from_i32_arms.push(quote! {
#number => ::core::option::Option::Some(Self::#variant_ident)
});
from_proto_name_arms.push(quote! {
#value_name => ::core::option::Option::Some(Self::#variant_ident)
});
proto_name_arms.push(quote! {
Self::#variant_ident => #value_name
});
if ctx.config.idiomatic_enum_aliases {
value_records.push((
value_name.to_string(),
variant_ident.clone(),
variant_ident.to_string(),
));
}
value_idents.push(variant_ident);
}
}
let enum_simple_name = enum_desc.name.as_deref().unwrap_or(rust_name);
let (idiomatic_consts, idiomatic_doc_note) =
idiomatic_aliases(ctx, rust_name, enum_simple_name, value_records);
let alias_block = if alias_consts.is_empty() && idiomatic_consts.is_empty() {
quote! {}
} else {
quote! {
impl #name_ident {
#(#alias_consts)*
#(#idiomatic_consts)*
}
}
};
let default_variant = if features.enum_type == crate::features::EnumType::Closed {
first_variant
} else {
zero_variant.or(first_variant)
};
let default_block = match default_variant {
Some(v) => quote! {
impl ::core::default::Default for #name_ident {
fn default() -> Self {
Self::#v
}
}
},
None => quote! {},
};
let serde_impls = if ctx.config.generate_json {
crate::feature_gates::cfg_const_block(
generate_enum_serde(&name_ident),
ctx.config.feature_gates().json,
)
} else {
quote! {}
};
let arbitrary_derive = if ctx.config.generate_arbitrary {
quote! { #[cfg_attr(feature = "arbitrary", derive(::arbitrary::Arbitrary))] }
} else {
quote! {}
};
let reflect_element_impl = if ctx.config.generate_reflection
&& ctx.config.generate_reflection_vtable
&& crate::message::is_closed_enum(features)
{
crate::feature_gates::cfg_block(
quote! {
impl ::buffa_descriptor::reflect::ReflectElement for #name_ident {
fn as_value_ref(&self) -> ::buffa_descriptor::reflect::ValueRef<'_> {
::buffa_descriptor::reflect::ValueRef::EnumNumber(
::buffa::Enumeration::to_i32(self),
)
}
}
},
ctx.config.feature_gates().reflect,
)
} else {
quote! {}
};
let enum_doc = {
let base =
crate::comments::doc_attrs_resolved(ctx.comment(proto_fqn), proto_fqn, &ctx.type_map);
quote! { #base #idiomatic_doc_note }
};
let custom_type_attrs = crate::context::CodeGenContext::matching_attributes(
&ctx.config.type_attributes,
proto_fqn,
)?;
let custom_enum_attrs = crate::context::CodeGenContext::matching_attributes(
&ctx.config.enum_attributes,
proto_fqn,
)?;
Ok(quote! {
#enum_doc
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
#arbitrary_derive
#custom_type_attrs
#custom_enum_attrs
#[repr(i32)]
pub enum #name_ident {
#(#variants,)*
}
#alias_block
#default_block
#serde_impls
impl ::buffa::Enumeration for #name_ident {
fn from_i32(value: i32) -> ::core::option::Option<Self> {
match value {
#(#from_i32_arms,)*
_ => ::core::option::Option::None,
}
}
fn to_i32(&self) -> i32 {
*self as i32
}
fn proto_name(&self) -> &'static str {
match self {
#(#proto_name_arms,)*
}
}
fn from_proto_name(name: &str) -> ::core::option::Option<Self> {
match name {
#(#from_proto_name_arms,)*
_ => ::core::option::Option::None,
}
}
fn values() -> &'static [Self] {
&[#(Self::#value_idents),*]
}
}
#reflect_element_impl
})
}
fn idiomatic_aliases(
ctx: &CodeGenContext,
rust_name: &str,
enum_simple_name: &str,
records: Vec<(String, Ident, String)>,
) -> (Vec<TokenStream>, TokenStream) {
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::fmt::Write;
if records.is_empty() {
return (Vec::new(), quote! {});
}
let is_valid = |c: &str| !c.is_empty() && !c.starts_with(|ch: char| ch.is_ascii_digit());
let prefix = format!("{}_", crate::idents::to_shouty_snake_case(enum_simple_name));
let strip = records.iter().all(|(name, ..)| {
name.strip_prefix(&prefix)
.is_some_and(|base| is_valid(&crate::idents::to_upper_camel_case(base)))
});
let camel = |name: &str| {
let base = if strip {
name.strip_prefix(&prefix).unwrap_or(name)
} else {
name
};
crate::idents::to_upper_camel_case(base)
};
let existing: HashSet<String> = records.iter().map(|(_, _, own)| own.clone()).collect();
let mut buckets: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
let mut invalid: BTreeSet<String> = BTreeSet::new();
{
let owner: HashMap<&str, &str> = records
.iter()
.map(|(name, _, own)| (own.as_str(), name.as_str()))
.collect();
for (name, _, _) in &records {
let candidate = camel(name);
if !is_valid(&candidate) {
invalid.insert(name.clone());
continue;
}
let escaped = crate::idents::make_field_ident(&candidate).to_string();
if let Some(&existing_owner) = owner.get(escaped.as_str()) {
buckets
.entry(escaped.clone())
.or_default()
.insert(existing_owner.to_string());
}
buckets.entry(escaped).or_default().insert(name.clone());
}
}
let conflicts: Vec<(&String, &BTreeSet<String>)> = buckets
.iter()
.filter(|(_, claimants)| claimants.len() > 1)
.collect();
if conflicts.is_empty() && invalid.is_empty() {
let consts = records
.into_iter()
.filter_map(|(name, target, _own)| {
let escaped = crate::idents::make_field_ident(&camel(&name));
if existing.contains(&escaped.to_string()) {
return None;
}
let target_name = target.to_string();
let alias_doc = if let Some(stripped) = target_name.strip_prefix("r#") {
format!("Idiomatic alias for `{stripped}`; `Debug` prints the variant name.")
} else {
format!(
"Idiomatic alias for [`Self::{target_name}`]; `Debug` prints the variant name."
)
};
Some(quote! {
#[doc = #alias_doc]
#[allow(non_upper_case_globals)]
pub const #escaped: Self = Self::#target;
})
})
.collect();
return (consts, quote! {});
}
let conflict_data: Vec<crate::AliasConflict> = conflicts
.iter()
.map(|(camel_ident, claimants)| crate::AliasConflict {
camel_target: (*camel_ident).clone(),
proto_values: claimants.iter().cloned().collect(),
})
.collect();
let invalid_data: Vec<String> = invalid.into_iter().collect();
let mut note = String::from(
"Idiomatic CamelCase aliases are not generated for this enum: two or more proto values \
collide after conversion (or would be invalid identifiers). Use the `SHOUTY_SNAKE_CASE` \
variants directly. Collisions:\n",
);
for conflict in &conflict_data {
let joined = conflict
.proto_values
.iter()
.map(|n| format!("`{n}`"))
.collect::<Vec<_>>()
.join(", ");
let _ = writeln!(note, "- {joined} → `{}`", conflict.camel_target);
}
for name in &invalid_data {
let _ = writeln!(note, "- `{name}` produces an invalid identifier");
}
ctx.warn(crate::CodeGenWarning::IdiomaticAliasesSuppressed {
enum_name: rust_name.to_string(),
conflicts: conflict_data,
invalid: invalid_data,
});
(Vec::new(), quote! { #[doc = #note] })
}