use std::collections::HashMap;
use anyhow::{Context, Result};
use prost::Message;
use prost_types::{DescriptorProto, FileDescriptorSet};
/// Coarse classification of a protobuf field's value type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FieldKind {
    /// Any scalar type other than `bytes`, e.g. numbers, bools, strings, enums.
    Primitive,
    /// The `bytes` scalar type.
    Bytes,
    /// A message (or proto2 group) field. The payload is the referenced type's
    /// fully-qualified name without a leading `.`.
    Message(String),
}
/// Fully-qualified names of the `google.protobuf.*Value` wrapper message types
/// recognised by `ProtoSchema::is_wrapper_type`.
///
/// Names carry no leading `.`, matching how field type names are stored after
/// stripping. The list is alphabetical.
const WRAPPER_TYPES: &[&str] = &[
    "google.protobuf.BoolValue",
    "google.protobuf.BytesValue",
    "google.protobuf.DoubleValue",
    "google.protobuf.FloatValue",
    "google.protobuf.Int32Value",
    "google.protobuf.Int64Value",
    "google.protobuf.StringValue",
    "google.protobuf.UInt32Value",
    "google.protobuf.UInt64Value",
];
/// Index of protobuf message schemas decoded from a `FileDescriptorSet`.
///
/// Built by `ProtoSchema::from_descriptor_set`.
#[derive(Clone, Debug)]
pub struct ProtoSchema {
    /// Message schemas keyed by fully-qualified name without a leading `.`,
    /// e.g. `pkg.Outer.Inner`. Nested messages are included.
    messages: HashMap<String, MessageSchema>,
}
/// Schema of a single message type.
#[derive(Clone, Debug)]
struct MessageSchema {
    /// Field schemas keyed by the field's declared name.
    fields: HashMap<String, FieldSchema>,
    /// Mirrors the descriptor's `options.map_entry` flag (false when unset).
    /// It marks the synthetic entry message used to encode a `map<K, V>` field.
    is_map_entry: bool,
}
/// Schema of a single field within a message.
#[derive(Clone, Debug)]
struct FieldSchema {
    /// Declared field name (empty if the descriptor omitted it).
    name: String,
    /// Referenced type name with any leading `.` removed. This is `None` when
    /// the descriptor has no `type_name`; presumably only message, group and
    /// enum fields set it.
    type_name: Option<String>,
    /// Field number from the descriptor (0 if it was missing).
    number: u32,
    /// Coarse value-type classification of the field.
    kind: FieldKind,
    /// Whether the field's label is `repeated`. Map fields are also repeated.
    repeated: bool,
}
impl ProtoSchema {
pub fn from_descriptor_set(bytes: &[u8]) -> Result<Self> {
let fds = FileDescriptorSet::decode(bytes)
.context("Failed to decode FileDescriptorSet. Make sure the file was generated by protoc with --descriptor_set_out")?;
let mut messages = HashMap::new();
for file in fds.file {
let package = file.package.as_deref().unwrap_or("");
Self::process_messages(&mut messages, package, "", &file.message_type);
}
Ok(Self { messages })
}
fn process_messages(
messages: &mut HashMap<String, MessageSchema>,
package: &str,
parent_path: &str,
message_types: &[DescriptorProto],
) {
for message in message_types {
let message_name = message.name.as_deref().unwrap_or("");
let full_name = if parent_path.is_empty() {
if package.is_empty() {
message_name.to_string()
} else {
format!("{}.{}", package, message_name)
}
} else {
format!("{}.{}", parent_path, message_name)
};
let is_map_entry = message
.options
.as_ref()
.and_then(|o| o.map_entry)
.unwrap_or(false);
let mut fields = HashMap::new();
for field in &message.field {
let field_name = field.name.as_deref().unwrap_or("").to_string();
let type_name = field.type_name.as_ref().map(|tn| {
tn.strip_prefix('.').unwrap_or(tn).to_string()
});
let number = field.number.unwrap_or(0) as u32;
use prost_types::field_descriptor_proto::Label as FLabel;
let repeated = field.label() == FLabel::Repeated;
use prost_types::field_descriptor_proto::Type as FType;
let kind = match field.r#type() {
FType::Message | FType::Group => {
let msg_type = type_name.clone().unwrap_or_default();
FieldKind::Message(msg_type)
}
FType::Bytes => FieldKind::Bytes,
_ => FieldKind::Primitive,
};
fields.insert(
field_name.clone(),
FieldSchema {
name: field_name,
type_name,
number,
kind,
repeated,
},
);
}
messages.insert(
full_name.clone(),
MessageSchema {
fields,
is_map_entry,
},
);
Self::process_messages(messages, package, &full_name, &message.nested_type);
}
}
pub fn is_wrapper_type(type_name: &str) -> bool {
WRAPPER_TYPES.contains(&type_name)
}
pub fn get_wrapper_fields(&self, message_type: &str) -> Vec<String> {
self.messages
.get(message_type)
.map(|msg| {
msg.fields
.values()
.filter(|field| {
field
.type_name
.as_ref()
.map(|tn| Self::is_wrapper_type(tn))
.unwrap_or(false)
})
.map(|field| field.name.clone())
.collect()
})
.unwrap_or_default()
}
pub fn has_message_type(&self, message_type: &str) -> bool {
self.messages.contains_key(message_type)
}
pub fn get_field_default_kinds(&self, message_type: &str) -> HashMap<String, String> {
self.messages
.get(message_type)
.map(|msg| {
msg.fields
.values()
.filter_map(|field| {
if !field.repeated {
return None; }
let kind_str = match &field.kind {
FieldKind::Message(type_name)
if self
.messages
.get(type_name.as_str())
.map(|m| m.is_map_entry)
.unwrap_or(false) =>
{
"map"
}
_ => "list",
};
Some((field.name.clone(), kind_str.to_string()))
})
.collect()
})
.unwrap_or_default()
}
pub fn get_any_field_schema(&self, message_type: &str) -> HashMap<u32, String> {
self.messages
.get(message_type)
.map(|msg| {
msg.fields
.values()
.map(|f| {
let kind_str = match &f.kind {
FieldKind::Primitive => "primitive".to_string(),
FieldKind::Bytes => "bytes".to_string(),
FieldKind::Message(fqn) => format!("message:{}", fqn),
};
(f.number, kind_str)
})
.collect()
})
.unwrap_or_default()
}
pub fn message_types(&self) -> Vec<&str> {
self.messages.keys().map(|s| s.as_str()).collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use prost_types::field_descriptor_proto::{Label, Type};
    use prost_types::{FieldDescriptorProto, FileDescriptorProto, MessageOptions};

    /// Builds a field descriptor with the given label, type and optional type name.
    fn field(
        name: &str,
        number: i32,
        label: Label,
        ty: Type,
        type_name: Option<&str>,
    ) -> FieldDescriptorProto {
        FieldDescriptorProto {
            name: Some(name.to_string()),
            number: Some(number),
            label: Some(label as i32),
            r#type: Some(ty as i32),
            type_name: type_name.map(str::to_string),
            ..Default::default()
        }
    }

    /// Encodes a one-file descriptor set declaring `messages` under `package`.
    fn encode_file(package: &str, messages: Vec<DescriptorProto>) -> Vec<u8> {
        FileDescriptorSet {
            file: vec![FileDescriptorProto {
                name: Some("test.proto".to_string()),
                package: Some(package.to_string()),
                message_type: messages,
                ..Default::default()
            }],
        }
        .encode_to_vec()
    }

    #[test]
    fn test_is_wrapper_type() {
        let wrappers = [
            "google.protobuf.BoolValue",
            "google.protobuf.BytesValue",
            "google.protobuf.DoubleValue",
            "google.protobuf.FloatValue",
            "google.protobuf.Int32Value",
            "google.protobuf.Int64Value",
            "google.protobuf.StringValue",
            "google.protobuf.UInt32Value",
            "google.protobuf.UInt64Value",
        ];
        for name in wrappers {
            assert!(ProtoSchema::is_wrapper_type(name), "{} should be a wrapper", name);
        }
        let non_wrappers = [
            "google.protobuf.Timestamp",
            "google.protobuf.Duration",
            "google.protobuf.Any",
            "google.protobuf.Empty",
            "my.custom.Message",
        ];
        for name in non_wrappers {
            assert!(!ProtoSchema::is_wrapper_type(name), "{} should not be a wrapper", name);
        }
    }

    #[test]
    fn test_empty_descriptor_set() {
        let fds = FileDescriptorSet { file: vec![] };
        let bytes = fds.encode_to_vec();
        let schema = ProtoSchema::from_descriptor_set(&bytes).unwrap();
        assert_eq!(schema.message_types().len(), 0);
    }

    #[test]
    fn test_invalid_bytes_are_rejected() {
        // A lone varint continuation byte is a truncated, undecodable message.
        assert!(ProtoSchema::from_descriptor_set(&[0xff]).is_err());
    }

    #[test]
    fn test_unknown_message_type() {
        let schema = ProtoSchema::from_descriptor_set(&[]).unwrap();
        assert!(!schema.has_message_type("nope.Missing"));
        assert!(schema.get_wrapper_fields("nope.Missing").is_empty());
        assert!(schema.get_field_default_kinds("nope.Missing").is_empty());
        assert!(schema.get_any_field_schema("nope.Missing").is_empty());
    }

    #[test]
    fn test_nested_messages_and_field_kinds() {
        let entry = DescriptorProto {
            name: Some("LabelsEntry".to_string()),
            field: vec![
                field("key", 1, Label::Optional, Type::String, None),
                field("value", 2, Label::Optional, Type::String, None),
            ],
            options: Some(MessageOptions {
                map_entry: Some(true),
                ..Default::default()
            }),
            ..Default::default()
        };
        let outer = DescriptorProto {
            name: Some("Outer".to_string()),
            field: vec![
                field("labels", 1, Label::Repeated, Type::Message, Some(".pkg.Outer.LabelsEntry")),
                field("tags", 2, Label::Repeated, Type::String, None),
                field("count", 3, Label::Optional, Type::Message, Some(".google.protobuf.Int32Value")),
                field("blob", 4, Label::Optional, Type::Bytes, None),
            ],
            nested_type: vec![entry],
            ..Default::default()
        };
        let schema = ProtoSchema::from_descriptor_set(&encode_file("pkg", vec![outer])).unwrap();

        let mut types = schema.message_types();
        types.sort_unstable();
        assert_eq!(types, vec!["pkg.Outer", "pkg.Outer.LabelsEntry"]);
        assert!(schema.has_message_type("pkg.Outer.LabelsEntry"));

        assert_eq!(schema.get_wrapper_fields("pkg.Outer"), vec!["count".to_string()]);

        let kinds = schema.get_field_default_kinds("pkg.Outer");
        assert_eq!(kinds.len(), 2);
        assert_eq!(kinds["labels"], "map");
        assert_eq!(kinds["tags"], "list");

        let any = schema.get_any_field_schema("pkg.Outer");
        assert_eq!(any.len(), 4);
        assert_eq!(any[&1], "message:pkg.Outer.LabelsEntry");
        assert_eq!(any[&2], "primitive");
        assert_eq!(any[&3], "message:google.protobuf.Int32Value");
        assert_eq!(any[&4], "bytes");
    }
}