use crate::generated::descriptor::field_descriptor_proto::Type;
use crate::generated::descriptor::{
DescriptorProto, FieldDescriptorProto, FileDescriptorProto, OneofDescriptorProto,
};
use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote};
use crate::context::CodeGenContext;
use crate::features::ResolvedFeatures;
use crate::impl_message::{field_string_repr, field_uses_bytes};
use crate::message::scalar_or_message_type_nested;
use crate::CodeGenError;
/// Reports whether `field`'s `type_name` is exactly
/// `.google.protobuf.NullValue`.
pub(crate) fn is_null_value_field(field: &FieldDescriptorProto) -> bool {
    matches!(
        field.type_name.as_deref(),
        Some(".google.protobuf.NullValue")
    )
}
/// Reports whether `field`'s `type_name` is one of the two well-known types
/// listed below (`.google.protobuf.NullValue` or `.google.protobuf.Value`).
///
/// A field without a `type_name` yields `false`.
pub(crate) fn null_is_valid_value(field: &FieldDescriptorProto) -> bool {
    const NULL_ACCEPTING_TYPES: [&str; 2] =
        [".google.protobuf.NullValue", ".google.protobuf.Value"];
    field
        .type_name
        .as_deref()
        .is_some_and(|name| NULL_ACCEPTING_TYPES.contains(&name))
}
/// Reports whether `ty` is a message or group type — the only field types
/// whose oneof payload is a candidate for boxing.
pub(crate) fn is_boxed_variant(ty: Type) -> bool {
    ty == Type::TYPE_MESSAGE || ty == Type::TYPE_GROUP
}
/// Decides whether the oneof variant `variant_fqn` of type `ty` is stored
/// behind a `Box`.
///
/// Only message/group variants can be boxed; among those, a variant is boxed
/// unless the context reports it as unboxed via `ctx.oneof_unboxed`.
pub(crate) fn variant_boxed(ctx: &CodeGenContext, ty: Type, variant_fqn: &str) -> bool {
    if !is_boxed_variant(ty) {
        return false;
    }
    !ctx.oneof_unboxed(variant_fqn)
}
/// Computes the set of oneof variants that may be stored inline (unboxed).
///
/// Every message/group-typed oneof variant found in `files` is a candidate.
/// A candidate is kept when its dotted name (`.pkg.Msg.oneof.field`) matches
/// one of `rules` and inlining it would not make the enclosing message
/// contain itself by value (see `unboxing_is_recursive`).
///
/// Returns an empty set without scanning `files` when `rules` is empty.
pub(crate) fn resolve_unboxed_variants(
    files: &[FileDescriptorProto],
    rules: &[String],
) -> std::collections::HashSet<String> {
    let mut unboxed = std::collections::HashSet::new();
    if !rules.is_empty() {
        let messages = message_index(files);
        for (owner_fqn, owner) in messages.iter() {
            for_each_message_variant(owner, owner_fqn, |candidate, target| {
                let selected = rule_matches(rules, &candidate);
                if selected && !unboxing_is_recursive(&messages, rules, owner_fqn, target) {
                    unboxed.insert(candidate);
                }
            });
        }
    }
    unboxed
}
/// Returns `true` when at least one entry of `rules` matches `variant_fqn`
/// according to `crate::context::matches_proto_prefix`.
fn rule_matches(rules: &[String], variant_fqn: &str) -> bool {
    for rule in rules {
        if crate::context::matches_proto_prefix(rule, variant_fqn) {
            return true;
        }
    }
    false
}
/// Invokes `f` once for every message/group-typed real oneof member of `msg`.
///
/// `f` receives the variant's dotted name, built as
/// `.{msg_fqn}.{oneof_name}.{field_name}`, and the field's `type_name` with
/// any leading dots removed. Fields lacking an oneof index, a name, a type
/// name, or a resolvable named oneof declaration are silently skipped.
fn for_each_message_variant(msg: &DescriptorProto, msg_fqn: &str, mut f: impl FnMut(String, &str)) {
    for field in &msg.field {
        let eligible = crate::impl_message::is_real_oneof_member(field)
            && is_boxed_variant(field.r#type.unwrap_or_default());
        if !eligible {
            continue;
        }
        let (raw_index, field_name, type_name) = match (
            field.oneof_index,
            field.name.as_deref(),
            field.type_name.as_deref(),
        ) {
            (Some(index), Some(name), Some(ty)) => (index, name, ty),
            _ => continue,
        };
        // A negative or out-of-range index simply yields no declaration.
        let decl = match usize::try_from(raw_index) {
            Ok(pos) => msg.oneof_decl.get(pos),
            Err(_) => None,
        };
        let Some(oneof_name) = decl.and_then(|d| d.name.as_deref()) else {
            continue;
        };
        let variant_fqn = format!(".{msg_fqn}.{oneof_name}.{field_name}");
        f(variant_fqn, type_name.trim_start_matches('.'));
    }
}
/// Builds a lookup table from fully-qualified message name to descriptor.
///
/// Keys have no leading dot: `package.Outer.Inner`, or just `Outer.Inner`
/// when the file declares no package. Nested messages are included
/// recursively; messages without a name (and everything nested in them) are
/// left out.
fn message_index(
    files: &[FileDescriptorProto],
) -> std::collections::HashMap<String, &DescriptorProto> {
    /// Registers `msg` and all of its nested types under `prefix`.
    fn walk<'a>(
        map: &mut std::collections::HashMap<String, &'a DescriptorProto>,
        prefix: &str,
        msg: &'a DescriptorProto,
    ) {
        if let Some(name) = msg.name.as_deref() {
            let fqn = match prefix {
                "" => name.to_string(),
                scope => format!("{scope}.{name}"),
            };
            // Children are registered before the parent itself.
            for child in &msg.nested_type {
                walk(map, &fqn, child);
            }
            map.insert(fqn, msg);
        }
    }

    let mut by_fqn = std::collections::HashMap::new();
    for file in files {
        let package = file.package.as_deref().unwrap_or_default();
        for top_level in &file.message_type {
            walk(&mut by_fqn, package, top_level);
        }
    }
    by_fqn
}
/// Reports whether inlining a variant of type `target` into `enclosing`
/// would make `enclosing` contain itself by value.
///
/// Performs a depth-first walk starting at `target`. From each visited
/// message it follows only the message/group-typed oneof variants whose
/// dotted name matches `rules` (i.e. the ones that would also be inlined).
/// Reaching `enclosing` means the inline chain loops back on itself.
///
/// * `index` — message lookup keyed by dot-free fully-qualified name.
/// * `rules` — the unboxing rules, tested with `rule_matches`.
/// * `enclosing` — fully-qualified name of the message owning the oneof.
/// * `target` — fully-qualified name of the variant's message type.
///
/// Type names absent from `index` are treated as leaves.
fn unboxing_is_recursive(
    index: &std::collections::HashMap<String, &DescriptorProto>,
    rules: &[String],
    enclosing: &str,
    target: &str,
) -> bool {
    let mut seen = std::collections::HashSet::new();
    let mut stack = vec![target.to_string()];
    while let Some(current) = stack.pop() {
        if current == enclosing {
            return true;
        }
        // Each message is expanded at most once, so cycles that do not
        // involve `enclosing` cannot loop forever.
        if !seen.insert(current.clone()) {
            continue;
        }
        let Some(msg) = index.get(current.as_str()) else {
            continue;
        };
        for_each_message_variant(msg, &current, |variant_fqn, type_name| {
            if rule_matches(rules, &variant_fqn) {
                stack.push(type_name.to_string());
            }
        });
    }
    false
}
/// Per-variant data gathered by `collect_variant_info` and consumed when
/// emitting the oneof enum and its `Serialize` impl.
struct VariantInfo {
/// Rust identifier of the enum variant (from `oneof_variant_ident`).
variant_ident: proc_macro2::Ident,
/// Tokens of the payload type, *without* any `Box` wrapper.
rust_type: TokenStream,
/// JSON key used when serializing: the field's `json_name`, falling back
/// to its proto name.
json_name: String,
/// Field type as returned by `crate::impl_message::effective_type`.
field_type: Type,
/// `true` when the field's type is `.google.protobuf.NullValue`; such
/// variants are serialized as `()`.
is_null_value: bool,
/// `true` when the payload is stored as `Box<rust_type>` in the enum.
is_boxed: bool,
/// User-configured attributes matched for this variant, emitted in front
/// of the variant declaration.
custom_attrs: TokenStream,
/// `true` when this is a bytes field for which `field_uses_bytes` holds;
/// the payload type is then `::buffa::bytes::Bytes`.
use_bytes: bool,
/// String representation chosen for string fields;
/// `crate::StringRepr::String` for every other field type.
string_repr: crate::StringRepr,
/// `true` when the variant's payload is replaced by a placeholder in the
/// generated `Debug` impl.
debug_redact: bool,
}
/// Gathers a [`VariantInfo`] for every real member of the oneof named
/// `oneof_name` in `msg`, in field declaration order.
///
/// Fields marked `proto3_optional` are not considered oneof members.
///
/// # Errors
/// * `CodeGenError::Other` if `msg` declares no oneof called `oneof_name`.
/// * `CodeGenError::MissingField("field.name")` if a member has no name.
/// * Any error from resolving the payload type or the custom attributes.
/// * `CodeGenError::Other` if a variant ends up boxed although its exact
///   dotted name is listed in `ctx.config.unboxed_oneof_fields`.
#[allow(clippy::too_many_arguments)]
fn collect_variant_info(
    ctx: &CodeGenContext,
    msg: &DescriptorProto,
    oneof_name: &str,
    current_package: &str,
    proto_fqn: &str,
    features: &ResolvedFeatures,
    resolver: &crate::imports::ImportResolver,
    nesting: usize,
) -> Result<Vec<VariantInfo>, CodeGenError> {
    let Some(oneof_index) = msg
        .oneof_decl
        .iter()
        .position(|decl| decl.name.as_deref() == Some(oneof_name))
    else {
        return Err(CodeGenError::Other(format!(
            "oneof '{oneof_name}' not found in message"
        )));
    };
    let wanted_index = Some(oneof_index as i32);
    let members = msg
        .field
        .iter()
        .filter(|f| f.oneof_index == wanted_index && !f.proto3_optional.unwrap_or(false));

    let mut infos = Vec::new();
    for field in members {
        let Some(proto_name) = field.name.as_deref() else {
            return Err(CodeGenError::MissingField("field.name"));
        };
        let json_name = match field.json_name.as_deref() {
            Some(explicit) => explicit.to_string(),
            None => proto_name.to_string(),
        };
        let variant_ident = oneof_variant_ident(proto_name);
        let field_type = crate::impl_message::effective_type(ctx, field, features);

        let use_bytes =
            field_type == Type::TYPE_BYTES && field_uses_bytes(ctx, proto_fqn, proto_name);
        let string_repr = if field_type == Type::TYPE_STRING {
            field_string_repr(ctx, proto_fqn, proto_name)
        } else {
            crate::StringRepr::String
        };
        let custom_string = field_type == Type::TYPE_STRING && !string_repr.is_default();

        let rust_type = match (use_bytes, custom_string) {
            (true, _) => quote! { ::buffa::bytes::Bytes },
            (false, true) => string_repr.type_path(resolver),
            (false, false) => scalar_or_message_type_nested(
                ctx,
                field,
                current_package,
                nesting + 3,
                features,
                resolver,
            )?,
        };

        let variant_fqn = format!("{proto_fqn}.{oneof_name}.{proto_name}");
        let custom_attrs =
            CodeGenContext::matching_attributes(&ctx.config.field_attributes, &variant_fqn)?;
        let dotted_fqn = format!(".{variant_fqn}");
        let is_boxed = variant_boxed(ctx, field_type, &dotted_fqn);

        // The variant was named explicitly in the unboxing list but is still
        // boxed: report it rather than silently keeping the box.
        if is_boxed
            && ctx
                .config
                .unboxed_oneof_fields
                .iter()
                .any(|rule| rule == &dotted_fqn)
        {
            return Err(CodeGenError::Other(format!(
                "oneof variant `{variant_fqn}` is recursive and cannot be \
                 stored inline: it would make the generated enum unsized. \
                 Remove `\"{dotted_fqn}\"` from unbox_oneof_in, or use a \
                 broader prefix (or unbox_oneof()) to keep this variant \
                 boxed while inlining the rest."
            )));
        }

        infos.push(VariantInfo {
            variant_ident,
            rust_type,
            json_name,
            field_type,
            is_null_value: is_null_value_field(field),
            is_boxed,
            custom_attrs,
            use_bytes,
            string_repr,
            debug_redact: crate::message::is_debug_redacted(field),
        });
    }
    Ok(infos)
}
#[allow(clippy::too_many_arguments)]
pub fn generate_oneof_enum(
ctx: &CodeGenContext,
msg: &DescriptorProto,
idx: usize,
oneof: &OneofDescriptorProto,
current_package: &str,
proto_fqn: &str,
features: &ResolvedFeatures,
resolver: &crate::imports::ImportResolver,
oneof_idents: &std::collections::HashMap<usize, proc_macro2::Ident>,
nesting: usize,
) -> Result<TokenStream, CodeGenError> {
let rust_enum_ident = match oneof_idents.get(&idx) {
Some(id) => id.clone(),
None => return Ok(TokenStream::new()),
};
let oneof_name = oneof
.name
.as_deref()
.ok_or(CodeGenError::MissingField("oneof.name"))?;
let variants_info = collect_variant_info(
ctx,
msg,
oneof_name,
current_package,
proto_fqn,
features,
resolver,
nesting,
)?;
if variants_info.is_empty() {
return Ok(TokenStream::new());
}
let variants: Vec<_> = variants_info
.iter()
.map(|v| {
let ident = &v.variant_ident;
let ty = &v.rust_type;
let attrs = &v.custom_attrs;
let arbitrary_field_attr = if ctx.config.generate_arbitrary && v.use_bytes {
quote! { #[cfg_attr(feature = "arbitrary", arbitrary(with = ::buffa::__private::arbitrary_bytes))] }
} else if ctx.config.generate_arbitrary
&& v.string_repr == crate::StringRepr::EcoString
{
quote! { #[cfg_attr(feature = "arbitrary", arbitrary(with = ::buffa::__private::arbitrary_ecow))] }
} else {
quote! {}
};
if v.is_boxed {
debug_assert!(!v.use_bytes, "boxed oneof variant cannot be bytes_fields-typed");
quote! { #attrs #ident(::buffa::alloc::boxed::Box<#ty>) }
} else {
quote! { #attrs #ident(#arbitrary_field_attr #ty) }
}
})
.collect();
let mut type_counts: std::collections::HashMap<String, usize> =
std::collections::HashMap::new();
for v in variants_info
.iter()
.filter(|v| is_boxed_variant(v.field_type))
{
*type_counts.entry(v.rust_type.to_string()).or_insert(0) += 1;
}
let from_impls: Vec<_> = variants_info
.iter()
.filter(|v| is_boxed_variant(v.field_type) && type_counts[&v.rust_type.to_string()] == 1)
.map(|v| {
let ident = &v.variant_ident;
let ty = &v.rust_type;
let ty_str = ty.to_string();
let ty_is_extern = ty_str.trim_start().starts_with("::");
let wrapped = if v.is_boxed {
quote! { ::buffa::alloc::boxed::Box::new(v) }
} else {
quote! { v }
};
let from_oneof = quote! {
impl From<#ty> for #rust_enum_ident {
fn from(v: #ty) -> Self {
Self::#ident(#wrapped)
}
}
};
let from_option = if ty_is_extern {
quote! {}
} else {
quote! {
impl From<#ty> for ::core::option::Option<#rust_enum_ident> {
fn from(v: #ty) -> Self {
Self::Some(#rust_enum_ident::from(v))
}
}
}
};
quote! { #from_oneof #from_option }
})
.collect();
let serde_impls = if ctx.config.generate_json {
crate::feature_gates::cfg_block(
generate_oneof_serialize(&rust_enum_ident, &variants_info),
ctx.config.feature_gates().json,
)
} else {
quote! {}
};
let arbitrary_derive = if ctx.config.generate_arbitrary {
quote! { #[cfg_attr(feature = "arbitrary", derive(::arbitrary::Arbitrary))] }
} else {
quote! {}
};
let oneof_fqn = format!("{}.{}", proto_fqn, oneof_name);
let oneof_doc =
crate::comments::doc_attrs_resolved(ctx.comment(&oneof_fqn), proto_fqn, &ctx.type_map);
let custom_type_attrs =
CodeGenContext::matching_attributes(&ctx.config.type_attributes, &oneof_fqn)?;
let large_variant_allow = if variants_info
.iter()
.any(|v| is_boxed_variant(v.field_type) && !v.is_boxed)
{
quote! { #[allow(clippy::large_enum_variant)] }
} else {
quote! {}
};
let any_redacted = variants_info.iter().any(|v| v.debug_redact);
let (debug_derive, debug_impl) = if any_redacted {
let placeholder = crate::message::DEBUG_REDACT_PLACEHOLDER;
let arms: Vec<TokenStream> = variants_info
.iter()
.map(|v| {
let ident = &v.variant_ident;
let name = ident.to_string();
if v.debug_redact {
quote! {
Self::#ident(_) => f
.debug_tuple(#name)
.field(&::core::format_args!(#placeholder))
.finish(),
}
} else {
quote! {
Self::#ident(value) => f.debug_tuple(#name).field(value).finish(),
}
}
})
.collect();
(
quote! { #[derive(Clone, PartialEq)] },
quote! {
impl ::core::fmt::Debug for #rust_enum_ident {
fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
match self {
#(#arms)*
}
}
}
},
)
} else {
(quote! { #[derive(Clone, PartialEq, Debug)] }, quote! {})
};
Ok(quote! {
#oneof_doc
#debug_derive
#arbitrary_derive
#large_variant_allow
#custom_type_attrs
pub enum #rust_enum_ident {
#(#variants,)*
}
#debug_impl
impl ::buffa::Oneof for #rust_enum_ident {}
#(#from_impls)*
#serde_impls
})
}
/// Maps a scalar field type to the `::buffa::json_helpers` module used to
/// (de)serialize it, or `None` for types that have no such helper.
///
/// All 32-bit signed kinds share `int32`, all 64-bit signed kinds share
/// `int64`, and likewise for the unsigned kinds; `float`, `double` and
/// `bytes` each have their own helper.
pub(crate) fn serde_helper_path(field_type: Type) -> Option<TokenStream> {
    let helper = match field_type {
        Type::TYPE_INT32 | Type::TYPE_SINT32 | Type::TYPE_SFIXED32 => {
            quote! { ::buffa::json_helpers::int32 }
        }
        Type::TYPE_UINT32 | Type::TYPE_FIXED32 => quote! { ::buffa::json_helpers::uint32 },
        Type::TYPE_INT64 | Type::TYPE_SINT64 | Type::TYPE_SFIXED64 => {
            quote! { ::buffa::json_helpers::int64 }
        }
        Type::TYPE_UINT64 | Type::TYPE_FIXED64 => quote! { ::buffa::json_helpers::uint64 },
        Type::TYPE_FLOAT => quote! { ::buffa::json_helpers::float },
        Type::TYPE_DOUBLE => quote! { ::buffa::json_helpers::double },
        Type::TYPE_BYTES => quote! { ::buffa::json_helpers::bytes },
        _ => return None,
    };
    Some(helper)
}
/// Emits `impl serde::Serialize` for the oneof enum `enum_ident`.
///
/// The generated impl writes a single-entry map keyed by the active
/// variant's JSON name. Null-value variants serialize `()` as the value;
/// types with a `serde_helper_path` helper are routed through it via a local
/// wrapper struct; everything else uses the payload's own `Serialize`.
fn generate_oneof_serialize(
    enum_ident: &proc_macro2::Ident,
    variants: &[VariantInfo],
) -> TokenStream {
    let mut arms: Vec<TokenStream> = Vec::with_capacity(variants.len());
    for info in variants {
        let ident = &info.variant_ident;
        let json_name = &info.json_name;
        let arm = if info.is_null_value {
            quote! {
                Self::#ident(_) => {
                    map.serialize_entry(#json_name, &())?;
                }
            }
        } else {
            let rust_type = &info.rust_type;
            match serde_helper_path(info.field_type) {
                Some(helper) => quote! {
                    Self::#ident(v) => {
                        struct _W<'a>(&'a #rust_type);
                        impl serde::Serialize for _W<'_> {
                            fn serialize<S2: serde::Serializer>(&self, s: S2) -> ::core::result::Result<S2::Ok, S2::Error> {
                                #helper::serialize(self.0, s)
                            }
                        }
                        map.serialize_entry(#json_name, &_W(v))?;
                    }
                },
                None => quote! {
                    Self::#ident(v) => {
                        map.serialize_entry(#json_name, v)?;
                    }
                },
            }
        };
        arms.push(arm);
    }

    quote! {
        impl serde::Serialize for #enum_ident {
            fn serialize<S: serde::Serializer>(&self, s: S) -> ::core::result::Result<S::Ok, S::Error> {
                use serde::ser::SerializeMap;
                let mut map = s.serialize_map(Some(1))?;
                match self {
                    #(#arms)*
                }
                map.end()
            }
        }
    }
}
/// Inputs for [`oneof_variant_deser_arm`], which emits one `match` arm of a
/// map-based deserializer handling a single oneof variant.
pub(crate) struct OneofVariantDeserInput<'a> {
/// Enum variant to construct.
pub variant_ident: &'a Ident,
/// Type of the value read from the map (before any boxing).
pub variant_type: &'a TokenStream,
/// JSON key matched by the arm.
pub json_name: &'a str,
/// Proto field name; also accepted as a key when it differs from
/// `json_name`.
pub proto_name: &'a str,
/// Field type, used to pick a `serde_helper_path` helper.
pub field_type: Type,
/// When `true`, the value is read directly and always assigned; when
/// `false`, it is read through `NullableDeserializeSeed` and only assigned
/// if that yields `Some` (presumably JSON `null` maps to `None` — confirm
/// against `::buffa::json_helpers`).
pub null_forward: bool,
/// When `true`, the value is wrapped in `Box::new` before being stored.
pub is_boxed: bool,
/// Tokens naming the oneof enum type.
pub enum_ident: &'a TokenStream,
/// Local variable (an `Option`) that receives the parsed variant and is
/// checked to detect a second variant of the same oneof.
pub result_var: &'a Ident,
/// Oneof name, used only in the duplicate-field error message.
pub oneof_name: &'a str,
}
/// Emits the `match` arm that deserializes one oneof variant from a map.
///
/// The arm matches the variant's JSON name (and its proto name, when
/// different), reads the value from `map`, and stores the constructed
/// variant in `result_var`. If `result_var` is already set, the generated
/// code returns a custom serde error naming the oneof.
///
/// With `null_forward` the value is read directly and always stored;
/// otherwise it is read through `NullableDeserializeSeed` and stored only
/// when that produces `Some`.
pub(crate) fn oneof_variant_deser_arm(input: &OneofVariantDeserInput<'_>) -> TokenStream {
    let variant_ident = input.variant_ident;
    let variant_type = input.variant_type;
    let json_name = input.json_name;
    let proto_name = input.proto_name;
    let enum_ident = input.enum_ident;
    let result_var = input.result_var;
    let dup_err_msg = format!("multiple oneof fields set for '{}'", input.oneof_name);

    let wrapped_v = if input.is_boxed {
        quote! { ::buffa::alloc::boxed::Box::new(v) }
    } else {
        quote! { v }
    };

    // Duplicate check + assignment, shared by both deserialization modes.
    let assign = quote! {
        if #result_var.is_some() {
            return Err(serde::de::Error::custom(#dup_err_msg));
        }
        #result_var = Some(#enum_ident::#variant_ident(#wrapped_v));
    };

    let deser = if input.null_forward {
        quote! {
            let v: #variant_type = map.next_value_seed(
                ::buffa::json_helpers::DefaultDeserializeSeed::<#variant_type>::new()
            )?;
        }
    } else {
        match serde_helper_path(input.field_type) {
            Some(helper) => quote! {
                struct _DeserSeed;
                impl<'de> serde::de::DeserializeSeed<'de> for _DeserSeed {
                    type Value = #variant_type;
                    fn deserialize<D: serde::Deserializer<'de>>(self, d: D) -> ::core::result::Result<#variant_type, D::Error> {
                        #helper::deserialize(d)
                    }
                }
                let v: ::core::option::Option<#variant_type> = map.next_value_seed(
                    ::buffa::json_helpers::NullableDeserializeSeed(_DeserSeed)
                )?;
            },
            None => quote! {
                let v: ::core::option::Option<#variant_type> = map.next_value_seed(
                    ::buffa::json_helpers::NullableDeserializeSeed(
                        ::buffa::json_helpers::DefaultDeserializeSeed::<#variant_type>::new()
                    )
                )?;
            },
        }
    };

    let set_result = if input.null_forward {
        assign
    } else {
        quote! {
            if let Some(v) = v {
                #assign
            }
        }
    };

    // Accept the proto name as an alternative key only when it differs.
    let keys = if json_name == proto_name {
        quote! { #json_name }
    } else {
        quote! { #json_name | #proto_name }
    };

    quote! {
        #keys => {
            #deser
            #set_result
        }
    }
}
fn oneof_enum_ident(oneof_name: &str) -> proc_macro2::Ident {
format_ident!("{}", to_pascal_case(oneof_name))
}
/// Maps each oneof declaration index of `msg` to the identifier of the enum
/// generated for it.
///
/// A oneof gets an entry only if at least one field is a real member of it
/// (per `is_real_oneof_member`) and the declaration has a name.
pub(crate) fn resolve_oneof_idents(
    msg: &DescriptorProto,
) -> std::collections::HashMap<usize, Ident> {
    msg.oneof_decl
        .iter()
        .enumerate()
        .filter(|(idx, _)| {
            let wanted = Some(*idx as i32);
            msg.field.iter().any(|field| {
                crate::impl_message::is_real_oneof_member(field) && field.oneof_index == wanted
            })
        })
        .filter_map(|(idx, decl)| {
            decl.name
                .as_deref()
                .map(|name| (idx, oneof_enum_ident(name)))
        })
        .collect()
}
/// Builds the enum variant identifier for a oneof field: the proto field
/// name in PascalCase, passed through `crate::idents::make_field_ident`.
pub(crate) fn oneof_variant_ident(proto_name: &str) -> proc_macro2::Ident {
    let pascal = to_pascal_case(proto_name);
    crate::idents::make_field_ident(&pascal)
}
/// Converts a `snake_case` name to `PascalCase`.
///
/// The input is split on `_`; the first character of each non-empty segment
/// is uppercased and the rest is copied unchanged. Empty segments (from
/// leading, trailing or repeated underscores) contribute nothing.
pub(crate) fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split('_') {
        let mut rest = segment.chars();
        if let Some(first) = rest.next() {
            pascal.extend(first.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
/// Converts a `PascalCase`/`camelCase` name to `snake_case`.
///
/// An underscore is inserted before an uppercase character (other than the
/// first character) when it follows a lowercase character, or when it
/// follows an uppercase character and is itself followed by a lowercase one
/// — so acronym runs stay together (`HTTPResponse` → `http_response`).
/// Every character is lowercased.
pub(crate) fn to_snake_case(s: &str) -> String {
    let mut snake = String::with_capacity(s.len() + 4);
    let mut previous: Option<char> = None;
    let mut remaining = s.chars().peekable();
    while let Some(c) = remaining.next() {
        if let Some(before) = previous {
            if c.is_uppercase() {
                let lower_follows = remaining.peek().is_some_and(|next| next.is_lowercase());
                let word_boundary =
                    before.is_lowercase() || (before.is_uppercase() && lower_follows);
                if word_boundary {
                    snake.push('_');
                }
            }
        }
        snake.extend(c.to_lowercase());
        previous = Some(c);
    }
    snake
}
#[cfg(test)]
mod tests {
    use super::{to_pascal_case, to_snake_case};

    /// Asserts `to_pascal_case(input) == expected` for every pair.
    fn check_pascal(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(to_pascal_case(input), expected);
        }
    }

    /// Asserts `to_snake_case(input) == expected` for every pair.
    fn check_snake(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(to_snake_case(input), expected);
        }
    }

    #[test]
    fn test_to_pascal_case_basic() {
        check_pascal(&[
            ("foo_bar", "FooBar"),
            ("hello_world_baz", "HelloWorldBaz"),
            ("single", "Single"),
        ]);
    }

    #[test]
    fn test_to_pascal_case_leading_underscore() {
        check_pascal(&[("_foo", "Foo"), ("_foo_bar", "FooBar")]);
    }

    #[test]
    fn test_to_pascal_case_consecutive_underscores() {
        check_pascal(&[("foo__bar", "FooBar"), ("a___b", "AB")]);
    }

    #[test]
    fn test_to_pascal_case_empty() {
        check_pascal(&[("", "")]);
    }

    #[test]
    fn test_to_snake_case_basic() {
        check_snake(&[
            ("FooBar", "foo_bar"),
            ("HelloWorldBaz", "hello_world_baz"),
            ("Single", "single"),
        ]);
    }

    #[test]
    fn test_to_snake_case_acronym_run() {
        check_snake(&[
            ("XMLHttpRequest", "xml_http_request"),
            ("HTTPResponse", "http_response"),
            ("IOError", "io_error"),
        ]);
    }

    #[test]
    fn test_to_snake_case_already_lower() {
        check_snake(&[("foo", "foo")]);
    }

    #[test]
    fn test_to_snake_case_all_caps() {
        check_snake(&[("XML", "xml"), ("IO", "io")]);
    }

    #[test]
    fn test_to_snake_case_proto_names() {
        check_snake(&[
            ("TestAllTypesProto3", "test_all_types_proto3"),
            ("NestedMessage", "nested_message"),
            ("ForeignMessage", "foreign_message"),
            ("ExtensionWithOneof", "extension_with_oneof"),
        ]);
    }

    #[test]
    fn test_to_snake_case_empty() {
        check_snake(&[("", "")]);
    }
}