use std::collections::HashMap;
use convert_case::{Case, Casing};
use prost::Message as _;
use protobuf::Message;
use protobuf::descriptor::field_descriptor_proto::{Label, Type};
use protobuf::descriptor::{DescriptorProto, FieldDescriptorProto, MessageOptions, SourceCodeInfo};
use super::{CodeGenMetadata, MessageField, MessageInfo, OneofVariant, extract_documentation};
use crate::google::api::{FieldBehavior, ResourceDescriptor, ResourceReference};
use crate::parsing::types::{BaseType, UnifiedType};
use crate::{Error, Result};
/// Records `message` and, recursively, all of its nested messages in
/// `codegen_metadata.messages`, keyed by fully-qualified name.
///
/// Regular fields become [`MessageField`]s in declaration order. Fields that belong to a real
/// `oneof` are collected per oneof and appended after the regular fields. This excludes proto3
/// `optional` fields, which sit in a synthetic oneof and are treated as regular fields. Each
/// oneof becomes one `BaseType::OneOf` field, emitted in `oneof_decl` declaration order.
///
/// * `message` - the message descriptor to process.
/// * `_file_name` - not used here; forwarded unchanged to the nested-message calls.
/// * `codegen_metadata` - receives one [`MessageInfo`] per processed message.
/// * `type_prefix` - fully-qualified name of the enclosing scope. The message is recorded as
///   `{type_prefix}.{name}`, or as `.{name}` when the prefix is empty.
/// * `source_code_info` / `path_prefix` - used to look up comments. `path_prefix` is the
///   `SourceCodeInfo` location path of `message` itself.
///
/// # Errors
/// Propagates failures to parse `google.api` field or message annotations, including those
/// of nested messages.
pub(super) fn process_message(
    message: &DescriptorProto,
    _file_name: &str,
    codegen_metadata: &mut CodeGenMetadata,
    type_prefix: &str,
    source_code_info: Option<&SourceCodeInfo>,
    path_prefix: &[i32],
) -> Result<()> {
    let message_name = message.name();
    let full_type_name = if type_prefix.is_empty() {
        format!(".{}", message_name)
    } else {
        format!("{}.{}", type_prefix, message_name)
    };
    let map_entries = collect_map_entries(message, &full_type_name);

    let mut fields = Vec::new();
    // One bucket of variants per `oneof_decl` entry, indexed by oneof index. A Vec is used
    // instead of a HashMap, whose iteration order is randomized per process. This keeps the
    // generated oneof fields in declaration order, so the output is identical across runs.
    let mut oneof_buckets: Vec<Vec<OneofVariant>> =
        (0..message.oneof_decl.len()).map(|_| Vec::new()).collect();

    for (field_index, field) in message.field.iter().enumerate() {
        // `SourceCodeInfo` path of this field: DescriptorProto.field has tag 2.
        let field_path = [path_prefix, &[2, field_index as i32]].concat();
        let documentation = extract_documentation(source_code_info, &field_path);

        // proto3 `optional` fields live in a synthetic oneof; they are handled as regular
        // fields below, as are fields whose oneof index has no matching declaration.
        let bucket = if field.has_oneof_index() && !field.proto3_optional() {
            usize::try_from(field.oneof_index())
                .ok()
                .and_then(|index| oneof_buckets.get_mut(index))
        } else {
            None
        };

        if let Some(bucket) = bucket {
            // Oneof members are never optional or repeated. Enum/message members keep their
            // fully-qualified type name without the leading `.`.
            let base_type = if field.has_type_name() {
                let type_name = field.type_name().trim_start_matches('.').to_string();
                if field.type_() == Type::TYPE_ENUM {
                    BaseType::Enum(type_name)
                } else {
                    BaseType::Message(type_name)
                }
            } else {
                match field.type_() {
                    Type::TYPE_STRING => BaseType::String,
                    Type::TYPE_INT32 => BaseType::Int32,
                    Type::TYPE_INT64 => BaseType::Int64,
                    Type::TYPE_BOOL => BaseType::Bool,
                    Type::TYPE_DOUBLE => BaseType::Float64,
                    Type::TYPE_FLOAT => BaseType::Float32,
                    Type::TYPE_BYTES => BaseType::Bytes,
                    _ => BaseType::String,
                }
            };
            let field_name = field.name().to_string();
            bucket.push(OneofVariant {
                variant_name: field_name.to_case(Case::Pascal),
                field_name,
                field_type: UnifiedType {
                    base_type,
                    is_optional: false,
                    is_repeated: false,
                },
                documentation,
            });
            continue;
        }

        fields.push(MessageField {
            name: field.name().to_string(),
            unified_type: parse_field_to_unified_type(field, &map_entries),
            documentation,
            field_behavior: extract_field_behavior_option(field)?,
            oneof_variants: None,
            is_sensitive: extract_debug_redact(field),
            resource_reference: extract_resource_reference(field)?,
        });
    }

    for (oneof_desc, variants) in message.oneof_decl.iter().zip(oneof_buckets) {
        // Oneofs that collected no variants (e.g. the synthetic oneofs of proto3 `optional`
        // fields) produce no field.
        if variants.is_empty() {
            continue;
        }
        let oneof_field_name = oneof_desc.name().to_string();
        // Path of the generated enum for this oneof, e.g. `my_message::Kind`.
        let enum_type_name = format!(
            "{}::{}",
            message_name.to_case(Case::Snake),
            oneof_field_name.to_case(Case::Pascal)
        );
        fields.push(MessageField {
            name: oneof_field_name,
            unified_type: UnifiedType {
                base_type: BaseType::OneOf(enum_type_name),
                is_optional: true,
                is_repeated: false,
            },
            oneof_variants: Some(variants),
            documentation: None,
            field_behavior: vec![],
            is_sensitive: false,
            resource_reference: None,
        });
    }

    let resource_descriptor = extract_message_resource_option(message)?;
    let documentation = extract_documentation(source_code_info, path_prefix);
    let message_info = MessageInfo {
        name: full_type_name.clone(),
        fields,
        resource_descriptor,
        documentation,
    };
    codegen_metadata
        .messages
        .insert(full_type_name.clone(), message_info);

    for (nested_index, nested_message) in message.nested_type.iter().enumerate() {
        // `SourceCodeInfo` path of the nested message: DescriptorProto.nested_type has tag 3.
        let nested_path = [path_prefix, &[3, nested_index as i32]].concat();
        process_message(
            nested_message,
            _file_name,
            codegen_metadata,
            &full_type_name,
            source_code_info,
            &nested_path,
        )?;
    }
    Ok(())
}
/// Reads the repeated `google.api.field_behavior` annotation from `field`'s options.
///
/// The extension is read from the options' unknown fields. Both wire encodings are accepted:
/// one varint per value, or a packed length-delimited run of varints. Values that do not
/// convert to a known [`FieldBehavior`] are skipped. Decoding of a packed run stops at the
/// first malformed varint.
///
/// Returns an empty vector when the field has no options or carries no such annotation.
/// No error is currently produced; the `Result` is kept for interface compatibility.
fn extract_field_behavior_option(field: &FieldDescriptorProto) -> Result<Vec<FieldBehavior>> {
    let Some(options) = field.options.as_ref() else {
        return Ok(vec![]);
    };
    let mut behaviors = Vec::new();
    for (field_number, field_value) in options.unknown_fields().iter() {
        if field_number != super::GOOGLE_API_FIELD_BEHAVIOR_EXTENSION {
            continue;
        }
        match field_value {
            protobuf::UnknownValueRef::Varint(value) => {
                // Enum values travel as (sign-extended) varints; truncating to i32 recovers them.
                behaviors.extend(FieldBehavior::try_from(value as i32).ok());
            }
            protobuf::UnknownValueRef::LengthDelimited(bytes) => {
                // Packed encoding: a run of varints inside a single length-delimited record.
                let mut cursor = std::io::Cursor::new(bytes);
                while cursor.position() < bytes.len() as u64 {
                    let Ok(value) = decode_varint(&mut cursor) else {
                        break;
                    };
                    behaviors.extend(FieldBehavior::try_from(value as i32).ok());
                }
            }
            // Fixed-width wire types are not valid encodings for an enum extension.
            _ => {}
        }
    }
    Ok(behaviors)
}
/// Field number of `debug_redact` in `google.protobuf.FieldOptions`.
const DEBUG_REDACT_FIELD_NUMBER: u32 = 16;

/// Returns whether `field` is marked `[debug_redact = true]`.
///
/// The option is looked up among the unknown fields of the field's options. Only the first
/// varint occurrence counts, and any non-zero value means `true`. A field without options,
/// or without such an occurrence, is not redacted.
fn extract_debug_redact(field: &FieldDescriptorProto) -> bool {
    field.options.as_ref().is_some_and(|options| {
        options
            .unknown_fields()
            .iter()
            .find_map(|(number, value)| match value {
                protobuf::UnknownValueRef::Varint(raw) if number == DEBUG_REDACT_FIELD_NUMBER => {
                    Some(raw != 0)
                }
                _ => None,
            })
            .unwrap_or(false)
    })
}
fn extract_resource_reference(field: &FieldDescriptorProto) -> Result<Option<ResourceReference>> {
let Some(options) = field.options.as_ref() else {
return Ok(None);
};
for (field_number, field_value) in options.unknown_fields().iter() {
if field_number == super::GOOGLE_API_RESOURCE_REFERENCE_EXTENSION {
let data = match field_value {
protobuf::UnknownValueRef::LengthDelimited(bytes) => bytes,
_ => continue,
};
match ResourceReference::decode(data) {
Ok(rr) => {
if rr.r#type.is_empty() && rr.child_type.is_empty() {
return Ok(None);
}
return Ok(Some(rr));
}
Err(e) => {
return Err(Error::InvalidAnnotation {
object: field.name().to_string(),
message: format!("Failed to parse google.api.resource_reference: {}", e),
});
}
}
}
}
Ok(None)
}
/// Decodes one base-128 varint starting at `cursor`'s position and advances past it.
///
/// Each byte contributes its low 7 bits, least-significant group first. A set high bit
/// means another byte follows. At most ten bytes are read.
///
/// # Errors
/// * `UnexpectedEof` if the data ends before a byte with a clear high bit is seen.
/// * `InvalidData` if the tenth byte still has its continuation bit set.
fn decode_varint(cursor: &mut std::io::Cursor<&[u8]>) -> Result<u64, std::io::Error> {
    let mut value = 0u64;
    // Shifts 0, 7, ..., 63: ten groups are enough for any 64-bit value.
    for group_shift in (0u32..64).step_by(7) {
        let offset = cursor.position();
        let Some(&byte) = usize::try_from(offset)
            .ok()
            .and_then(|index| cursor.get_ref().get(index))
        else {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "Unexpected end of data while reading varint",
            ));
        };
        cursor.set_position(offset + 1);
        value |= u64::from(byte & 0x7F) << group_shift;
        if byte & 0x80 == 0 {
            return Ok(value);
        }
    }
    Err(std::io::Error::new(
        std::io::ErrorKind::InvalidData,
        "Varint too long",
    ))
}
fn extract_message_resource_option(
message: &DescriptorProto,
) -> Result<Option<ResourceDescriptor>> {
if message.options.is_none() {
return Ok(None);
}
let options = message.options.as_ref().unwrap();
let unknown_fields = options.unknown_fields();
for (field_number, field_value) in unknown_fields.iter() {
if field_number == super::GOOGLE_API_RESOURCE_EXTENSION {
let data = match field_value {
protobuf::UnknownValueRef::LengthDelimited(bytes) => bytes,
_ => {
tracing::warn!("Skipping non-length-delimited google.api.resource field");
continue;
}
};
match ResourceDescriptor::decode(data) {
Ok(resource_descriptor) => {
return Ok(Some(resource_descriptor));
}
Err(e) => {
return Err(Error::InvalidAnnotation {
object: message.name().to_string(),
message: format!("Failed to parse google.api.resource: {}", e),
});
}
}
}
}
Ok(None)
}
/// Collects the synthetic map-entry messages nested directly inside `message`.
///
/// A nested type counts as a map entry when its options set `map_entry`. Each entry is
/// keyed by `{parent_full_name}.{entry_name}`. The value is the base type of the entry's key
/// field (number 1) and its value field (number 2). A missing key or value field falls back
/// to `BaseType::String`.
fn collect_map_entries(
    message: &DescriptorProto,
    parent_full_name: &str,
) -> HashMap<String, (BaseType, BaseType)> {
    // Base type of the entry field with the given number, or `String` if it is absent.
    let entry_field_type = |entry: &DescriptorProto, number: i32| {
        entry
            .field
            .iter()
            .find(|candidate| candidate.number() == number)
            .map_or(BaseType::String, proto_field_to_base_type)
    };

    message
        .nested_type
        .iter()
        .filter(|nested| {
            nested
                .options
                .as_ref()
                .is_some_and(|opts: &MessageOptions| opts.map_entry())
        })
        .map(|entry| {
            let qualified_name = format!("{}.{}", parent_full_name, entry.name());
            let key_and_value = (entry_field_type(entry, 1), entry_field_type(entry, 2));
            (qualified_name, key_and_value)
        })
        .collect()
}
/// Maps a field's declared protobuf type to a [`BaseType`], ignoring its label.
///
/// Message and enum types keep their fully-qualified name without the leading `.`.
/// All other kinds without a dedicated variant fall back to `BaseType::String`. That
/// includes the unsigned, zigzag (`sint*`) and fixed-width integer kinds.
/// NOTE(review): confirm that mapping those integer kinds to `String` is intended.
fn proto_field_to_base_type(field: &FieldDescriptorProto) -> BaseType {
    let referenced_name = || field.type_name().trim_start_matches('.').to_string();
    match field.type_() {
        Type::TYPE_MESSAGE => BaseType::Message(referenced_name()),
        Type::TYPE_ENUM => BaseType::Enum(referenced_name()),
        Type::TYPE_STRING => BaseType::String,
        Type::TYPE_BYTES => BaseType::Bytes,
        Type::TYPE_BOOL => BaseType::Bool,
        Type::TYPE_INT32 => BaseType::Int32,
        Type::TYPE_INT64 => BaseType::Int64,
        Type::TYPE_FLOAT => BaseType::Float32,
        Type::TYPE_DOUBLE => BaseType::Float64,
        _ => BaseType::String,
    }
}
/// Converts a (non-oneof) field descriptor into a [`UnifiedType`].
///
/// Some message-typed fields name a map entry: their `type_name` is a key of `map_entries`,
/// which is keyed by fully-qualified entry name. These become a map from the entry's key type
/// to its value type. Other message and enum fields keep their type name without the leading
/// `.`, and unsupported scalar kinds fall back to `BaseType::String`.
///
/// Map fields are not flagged as repeated even though they carry the repeated label. A field
/// is flagged optional only when it is a proto3 `optional` field.
fn parse_field_to_unified_type(
    field: &FieldDescriptorProto,
    map_entries: &HashMap<String, (BaseType, BaseType)>,
) -> UnifiedType {
    // A required, non-repeated element type wrapping `base_type`.
    let plain = |base_type: BaseType| UnifiedType {
        base_type,
        is_optional: false,
        is_repeated: false,
    };
    let referenced_name = || field.type_name().trim_start_matches('.').to_string();

    let base_type = match field.type_() {
        Type::TYPE_MESSAGE => match map_entries.get(field.type_name()) {
            Some((key_type, value_type)) => {
                UnifiedType::map(plain(key_type.clone()), plain(value_type.clone())).base_type
            }
            None => BaseType::Message(referenced_name()),
        },
        Type::TYPE_ENUM => BaseType::Enum(referenced_name()),
        Type::TYPE_STRING => BaseType::String,
        Type::TYPE_BYTES => BaseType::Bytes,
        Type::TYPE_BOOL => BaseType::Bool,
        Type::TYPE_INT32 => BaseType::Int32,
        Type::TYPE_INT64 => BaseType::Int64,
        Type::TYPE_FLOAT => BaseType::Float32,
        Type::TYPE_DOUBLE => BaseType::Float64,
        _ => BaseType::String,
    };

    let is_map = matches!(base_type, BaseType::Map(_, _));
    let label = field.label();
    UnifiedType {
        is_repeated: label == Label::LABEL_REPEATED && !is_map,
        is_optional: label == Label::LABEL_OPTIONAL && field.proto3_optional(),
        base_type,
    }
}
#[cfg(test)]
mod tests {
    use protobuf::descriptor::{DescriptorProto, FieldDescriptorProto};

    use super::*;

    #[test]
    fn test_extract_message_resource_option_no_options() {
        let mut message = DescriptorProto::new();
        message.set_name("TestMessage".to_string());
        let result = extract_message_resource_option(&message).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_extract_field_behavior_option_no_options() {
        let field = FieldDescriptorProto::new();
        let result = extract_field_behavior_option(&field).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_extract_field_behavior_option_function_exists() {
        let field = FieldDescriptorProto::new();
        let result = extract_field_behavior_option(&field);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_message_info_stores_resource_descriptor() {
        let resource_descriptor = crate::google::api::ResourceDescriptor {
            r#type: "example.io/Schema".to_string(),
            pattern: vec!["catalogs/{catalog}/schemas/{schema}".to_string()],
            name_field: "name".to_string(),
            ..Default::default()
        };
        let message_info = MessageInfo {
            name: ".example.catalog.v1.Schema".to_string(),
            fields: vec![],
            resource_descriptor: Some(resource_descriptor.clone()),
            documentation: None,
        };
        assert!(message_info.resource_descriptor.is_some());
        let stored = message_info.resource_descriptor.unwrap();
        assert_eq!(stored.r#type, "example.io/Schema");
        assert_eq!(stored.pattern, vec!["catalogs/{catalog}/schemas/{schema}"]);
    }

    #[test]
    fn test_extract_message_documentation() {
        let mut sci = SourceCodeInfo::new();
        let mut location = protobuf::descriptor::source_code_info::Location::new();
        location.path = vec![4, 0];
        location.set_leading_comments("This is a test message for documentation.".to_string());
        sci.location.push(location);
        let message_path = [4, 0];
        let result = extract_documentation(Some(&sci), &message_path);
        assert!(result.is_some());
        assert_eq!(result.unwrap(), "This is a test message for documentation.");
    }

    #[test]
    fn test_decode_varint_single_and_multi_byte() {
        let data: &[u8] = &[0x05, 0x96, 0x01];
        let mut cursor = std::io::Cursor::new(data);
        assert_eq!(decode_varint(&mut cursor).unwrap(), 5);
        assert_eq!(decode_varint(&mut cursor).unwrap(), 150);
        assert_eq!(cursor.position(), 3);
    }

    #[test]
    fn test_decode_varint_max_value() {
        let data: &[u8] = &[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01];
        let mut cursor = std::io::Cursor::new(data);
        assert_eq!(decode_varint(&mut cursor).unwrap(), u64::MAX);
    }

    #[test]
    fn test_decode_varint_errors() {
        let truncated: &[u8] = &[0x80];
        let err = decode_varint(&mut std::io::Cursor::new(truncated)).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);

        let too_long: &[u8] = &[0xFF; 11];
        let err = decode_varint(&mut std::io::Cursor::new(too_long)).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }
}