use crate::registry::SchemaRegistry;
use crate::types::{
EnumDefinition, PrimitiveType, SchemaDefinition, SchemaType, StructDefinition,
VariantData,
};
use std::fmt::Write;
/// Renders a single schema definition as Protocol Buffers source.
///
/// Only the type definition itself is produced: unlike `generate_bundle`, no
/// `syntax`/`package`/`import` preamble and no description comment is written.
pub fn generate(definition: &SchemaDefinition) -> String {
    generate_type_definition(&definition.name, &definition.schema_type)
}
/// Renders every definition in `registry` as one proto3 file.
///
/// The file is laid out in this order:
///
/// * `syntax = "proto3";`
/// * an optional `package` line built from the registry namespace, with dots
///   replaced by `_` (so `a.b` becomes `a_b`)
/// * a `//` header with the registry title and description
/// * the well-known-type imports
/// * one type per definition, each preceded by its description as `//`
///   comments and followed by a blank line
///
/// Definitions are emitted in the order returned by
/// `registry.topological_sort()`. If that fails, registry iteration order is
/// used instead; this is safe because proto files may reference types declared
/// later in the same file.
pub fn generate_bundle(registry: &SchemaRegistry) -> String {
    let config = registry.config();
    let mut output = String::new();
    writeln!(output, "syntax = \"proto3\";").unwrap();
    writeln!(output).unwrap();
    if let Some(ref namespace) = config.namespace {
        writeln!(output, "package {};", namespace.replace('.', "_")).unwrap();
        writeln!(output).unwrap();
    }
    writeln!(output, "// Auto-generated Protocol Buffers definitions").unwrap();
    // Title and description may span several lines. Every line needs its own
    // `//`; otherwise the continuation lines are parsed as proto syntax.
    if let Some(ref title) = config.title {
        for line in title.lines() {
            writeln!(output, "// {}", line).unwrap();
        }
    }
    if let Some(ref description) = config.description {
        for line in description.lines() {
            writeln!(output, "// {}", line).unwrap();
        }
    }
    writeln!(output).unwrap();
    // `google.protobuf.Empty`, which `get_proto_type` emits for unit types and
    // empty tuples, is declared in empty.proto and must be imported for the
    // file to compile. `get_proto_type` never references timestamp.proto or
    // wrappers.proto; they are kept so existing output does not lose imports.
    writeln!(output, "import \"google/protobuf/any.proto\";").unwrap();
    writeln!(output, "import \"google/protobuf/empty.proto\";").unwrap();
    writeln!(output, "import \"google/protobuf/timestamp.proto\";").unwrap();
    writeln!(output, "import \"google/protobuf/wrappers.proto\";").unwrap();
    writeln!(output).unwrap();
    let definitions: Vec<_> = match registry.topological_sort() {
        Ok(sorted) => sorted,
        Err(_) => registry.definitions().map(|(_, def)| def).collect(),
    };
    for definition in definitions {
        if let Some(ref desc) = definition.description {
            for line in desc.lines() {
                writeln!(output, "// {}", line).unwrap();
            }
        }
        output.push_str(&generate_type_definition(
            &definition.name,
            &definition.schema_type,
        ));
        writeln!(output).unwrap();
    }
    output
}
/// Renders one named schema type, converting `name` to PascalCase first.
///
/// How each kind of type is rendered:
///
/// * Structs are delegated to `generate_message`.
/// * Enums are delegated to `generate_enum`.
/// * Top-level primitive definitions produce no output at all.
/// * Every other type (newtypes, options, collections, maps, tuples,
///   references, unit, any) is boxed in a wrapper of the form
///   `message Name { T value = 1; }`.
fn generate_type_definition(name: &str, schema_type: &SchemaType) -> String {
    let message_name = to_pascal_case(name);
    let value_type = match schema_type {
        SchemaType::Struct(def) => return generate_message(&message_name, def),
        SchemaType::Enum(def) => return generate_enum(&message_name, def),
        // NOTE(review): `get_proto_type` renders a `Reference` as the
        // PascalCase name, so a field referring to a primitive definition
        // names a type that is never declared here. Confirm primitives are
        // never referenced by name.
        SchemaType::Primitive(_) => return String::new(),
        SchemaType::Newtype(def) => get_proto_type(&def.inner_type, false),
        other => get_proto_type(other, false),
    };
    let mut wrapper = String::new();
    writeln!(wrapper, "message {} {{", message_name).unwrap();
    writeln!(wrapper, " {} value = 1;", value_type).unwrap();
    writeln!(wrapper, "}}").unwrap();
    wrapper
}
fn generate_message(name: &str, def: &StructDefinition) -> String {
let mut output = String::new();
writeln!(output, "message {} {{", name).unwrap();
let mut field_number = 1;
for (field_name, field) in &def.fields {
if let Some(ref desc) = field.description {
writeln!(output, " // {}", desc).unwrap();
}
if field.deprecated {
writeln!(output, " // @deprecated").unwrap();
}
let proto_name = to_snake_case(field_name);
let proto_type = get_proto_type(&field.schema_type, !field.required);
writeln!(output, " {} {} = {};", proto_type, proto_name, field_number).unwrap();
field_number += 1;
}
writeln!(output, "}}").unwrap();
output
}
fn generate_enum(name: &str, def: &EnumDefinition) -> String {
let mut output = String::new();
if def.is_simple_enum() {
writeln!(output, "enum {} {{", name).unwrap();
let mut value = 0;
for variant in &def.variants {
if let Some(ref desc) = variant.description {
writeln!(output, " // {}", desc).unwrap();
}
let variant_name = format!(
"{}_{}",
to_screaming_snake_case(name),
to_screaming_snake_case(&variant.name)
);
writeln!(output, " {} = {};", variant_name, value).unwrap();
value += 1;
}
writeln!(output, "}}").unwrap();
return output;
}
writeln!(output, "message {} {{", name).unwrap();
writeln!(output, " oneof variant {{").unwrap();
let mut field_number = 1;
for variant in &def.variants {
let variant_name = to_snake_case(&variant.name);
match &variant.data {
VariantData::Unit => {
writeln!(output, " bool {} = {};", variant_name, field_number).unwrap();
}
VariantData::Newtype(inner) => {
let proto_type = get_proto_type(inner, false);
writeln!(
output,
" {} {} = {};",
proto_type, variant_name, field_number
)
.unwrap();
}
VariantData::Tuple(_types) => {
let tuple_msg_name = format!("{}_{}", name, to_pascal_case(&variant.name));
writeln!(
output,
" {} {} = {};",
tuple_msg_name, variant_name, field_number
)
.unwrap();
}
VariantData::Struct(_fields) => {
let struct_msg_name = format!("{}_{}", name, to_pascal_case(&variant.name));
writeln!(
output,
" {} {} = {};",
struct_msg_name, variant_name, field_number
)
.unwrap();
}
}
field_number += 1;
}
writeln!(output, " }}").unwrap();
writeln!(output, "}}").unwrap();
for variant in &def.variants {
match &variant.data {
VariantData::Tuple(types) => {
let tuple_msg_name = format!("{}_{}", name, to_pascal_case(&variant.name));
writeln!(output).unwrap();
writeln!(output, "message {} {{", tuple_msg_name).unwrap();
for (i, t) in types.iter().enumerate() {
writeln!(
output,
" {} field_{} = {};",
get_proto_type(t, false),
i,
i + 1
)
.unwrap();
}
writeln!(output, "}}").unwrap();
}
VariantData::Struct(fields) => {
let struct_msg_name = format!("{}_{}", name, to_pascal_case(&variant.name));
writeln!(output).unwrap();
writeln!(output, "message {} {{", struct_msg_name).unwrap();
let mut field_num = 1;
for (field_name, field) in fields {
let proto_name = to_snake_case(field_name);
let proto_type = get_proto_type(&field.schema_type, !field.required);
writeln!(output, " {} {} = {};", proto_type, proto_name, field_num).unwrap();
field_num += 1;
}
writeln!(output, "}}").unwrap();
}
_ => {}
}
}
output
}
/// Maps a schema type to the proto3 type used in a field declaration,
/// including any field label.
///
/// * `Option<T>` becomes `optional T`. If `T` already renders as a repeated,
///   map or optional type, the inner rendering is returned unchanged, because
///   proto3 allows at most one label (`Option<Vec<T>>` gives `repeated T`).
/// * Arrays and sets become `repeated T`; maps become `map<string, T>`. An
///   `Option` element or value is rendered without its `optional` label,
///   which proto3 does not allow there.
/// * Newtypes and single-element tuples are transparent.
/// * Empty tuples and the unit type map to `google.protobuf.Empty`.
/// * Wider tuples and inline structs fall back to `bytes`; inline enums to
///   `int32`.
/// * References become the PascalCase type name; `Any` becomes
///   `google.protobuf.Any`.
///
/// When `optional` is true, an unlabeled singular result is prefixed with
/// `optional`. Option, collection and map types ignore the flag.
///
/// Limitation: nested collections (`Vec<Vec<T>>`, maps of lists, …) still
/// render as invalid proto (`repeated repeated T`). They would need a wrapper
/// message, which this function does not generate.
fn get_proto_type(schema_type: &SchemaType, optional: bool) -> String {
    let base_type = match schema_type {
        SchemaType::Primitive(prim) => get_primitive_type(prim),
        SchemaType::Option(inner) => {
            let inner_type = get_proto_type(inner, false);
            return if has_proto_label(&inner_type) {
                inner_type
            } else {
                format!("optional {}", inner_type)
            };
        }
        SchemaType::Array(inner) => {
            return format!("repeated {}", proto_element_type(inner));
        }
        SchemaType::Set(inner) => {
            return format!("repeated {}", proto_element_type(inner));
        }
        SchemaType::Map(value_type) => {
            return format!("map<string, {}>", proto_element_type(value_type));
        }
        SchemaType::Tuple(types) => match types.len() {
            0 => "google.protobuf.Empty".to_string(),
            1 => return get_proto_type(&types[0], optional),
            _ => "bytes".to_string(),
        },
        SchemaType::Struct(_) => "bytes".to_string(),
        SchemaType::Enum(_) => "int32".to_string(),
        SchemaType::Newtype(def) => {
            return get_proto_type(&def.inner_type, optional);
        }
        SchemaType::Reference(name) => to_pascal_case(name),
        SchemaType::Unit => "google.protobuf.Empty".to_string(),
        SchemaType::Any => "google.protobuf.Any".to_string(),
    };
    if optional {
        format!("optional {}", base_type)
    } else {
        base_type
    }
}

/// Returns `true` when a rendered proto type cannot take another label.
///
/// That is the case when it already carries `optional` or `repeated`, or is
/// a `map<…>`.
fn has_proto_label(proto_type: &str) -> bool {
    ["optional ", "repeated ", "map<"]
        .iter()
        .any(|prefix| proto_type.starts_with(*prefix))
}

/// Renders the type of a repeated element or map value.
///
/// proto3 forbids the `optional` label in these positions, so an `Option<T>`
/// element degrades to plain `T`.
fn proto_element_type(schema_type: &SchemaType) -> String {
    let rendered = get_proto_type(schema_type, false);
    match rendered.strip_prefix("optional ") {
        Some(rest) => rest.to_string(),
        None => rendered,
    }
}
/// Maps a primitive type to its proto3 scalar type.
///
/// Narrow integers widen to the 32-bit proto types. 128-bit integers have no
/// proto scalar and are carried as `bytes`. `char` is carried as `string`.
fn get_primitive_type(prim: &PrimitiveType) -> String {
    let scalar = match prim {
        PrimitiveType::Bool => "bool",
        PrimitiveType::I8 | PrimitiveType::I16 | PrimitiveType::I32 => "int32",
        PrimitiveType::I64 | PrimitiveType::Isize => "int64",
        PrimitiveType::U8 | PrimitiveType::U16 | PrimitiveType::U32 => "uint32",
        PrimitiveType::U64 | PrimitiveType::Usize => "uint64",
        PrimitiveType::F32 => "float",
        PrimitiveType::F64 => "double",
        PrimitiveType::Char | PrimitiveType::String => "string",
        PrimitiveType::I128 | PrimitiveType::U128 | PrimitiveType::Bytes => "bytes",
    };
    scalar.to_owned()
}
/// Converts `s` to PascalCase.
///
/// `_`, `-` and ` ` act as word separators and are dropped. The first
/// character of every word is uppercased and the rest of the word is kept
/// verbatim, so `user_name` and `userName` both become `UserName`.
fn to_pascal_case(s: &str) -> String {
    let is_separator = |c: char| c == '_' || c == '-' || c == ' ';
    s.split(is_separator)
        .filter(|word| !word.is_empty())
        .fold(String::with_capacity(s.len()), |mut acc, word| {
            let mut chars = word.chars();
            if let Some(first) = chars.next() {
                acc.extend(first.to_uppercase());
                acc.push_str(chars.as_str());
            }
            acc
        })
}
/// Converts `s` to snake_case.
///
/// `-` and ` ` become `_`, and an existing `_` is kept. Uppercase letters are
/// lowercased. An `_` is inserted before an uppercase letter only when the
/// previous character was neither a separator nor uppercase. So `userName`
/// becomes `user_name`, while consecutive capitals stay together
/// (`HTTPServer` becomes `httpserver`).
fn to_snake_case(s: &str) -> String {
    /// Classification of the previously processed character.
    enum Prev {
        /// Start of input, `_`, `-` or ` `.
        Boundary,
        /// An uppercase letter.
        Upper,
        /// Anything else (lowercase letters, digits, punctuation).
        Other,
    }
    let mut snake = String::with_capacity(s.len() + 4);
    let mut prev = Prev::Boundary;
    for ch in s.chars() {
        prev = match ch {
            '-' | ' ' | '_' => {
                snake.push('_');
                Prev::Boundary
            }
            _ if ch.is_uppercase() => {
                if matches!(prev, Prev::Other) {
                    snake.push('_');
                }
                snake.extend(ch.to_lowercase());
                Prev::Upper
            }
            _ => {
                snake.push(ch);
                Prev::Other
            }
        };
    }
    snake
}
/// Converts `s` to SCREAMING_SNAKE_CASE.
///
/// This is `to_snake_case` followed by a per-character uppercase mapping.
fn to_screaming_snake_case(s: &str) -> String {
    to_snake_case(s)
        .chars()
        .flat_map(char::to_uppercase)
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{EnumRepresentation, EnumVariant, NewtypeDefinition, StructField};
    use indexmap::IndexMap;

    /// Asserts that `needle` occurs in `haystack`, printing the whole generated
    /// output on failure.
    fn assert_contains(haystack: &str, needle: &str) {
        assert!(
            haystack.contains(needle),
            "expected {:?} in generated output:\n{}",
            needle,
            haystack
        );
    }

    /// Shorthand for a boxed primitive, as taken by container schema types.
    fn boxed(prim: PrimitiveType) -> Box<SchemaType> {
        Box::new(SchemaType::Primitive(prim))
    }

    #[test]
    fn test_primitive_types() {
        let expectations = [
            (PrimitiveType::Bool, "bool"),
            (PrimitiveType::I32, "int32"),
            (PrimitiveType::I64, "int64"),
            (PrimitiveType::U32, "uint32"),
            (PrimitiveType::F32, "float"),
            (PrimitiveType::F64, "double"),
            (PrimitiveType::String, "string"),
            (PrimitiveType::Bytes, "bytes"),
        ];
        for (prim, expected) in expectations.iter() {
            assert_eq!(get_primitive_type(prim), *expected);
        }
    }

    #[test]
    fn test_optional_type() {
        let opt = SchemaType::Option(boxed(PrimitiveType::String));
        assert_eq!(get_proto_type(&opt, false), "optional string");
    }

    #[test]
    fn test_array_type() {
        let arr = SchemaType::Array(boxed(PrimitiveType::I32));
        assert_eq!(get_proto_type(&arr, false), "repeated int32");
    }

    #[test]
    fn test_map_type() {
        let map = SchemaType::Map(boxed(PrimitiveType::String));
        assert_eq!(get_proto_type(&map, false), "map<string, string>");
    }

    #[test]
    fn test_reference_type() {
        let ref_type = SchemaType::Reference("user".to_string());
        assert_eq!(get_proto_type(&ref_type, false), "User");
    }

    #[test]
    fn test_simple_message() {
        let name = StructField::new(SchemaType::Primitive(PrimitiveType::String), "name");
        let age = StructField::new(SchemaType::Option(boxed(PrimitiveType::I32)), "age");
        let def = StructDefinition::new()
            .with_field("name", name)
            .with_field("age", age);
        let output = generate_message("User", &def);
        for expected in &[
            "message User {",
            "string name = 1;",
            "optional int32 age = 2;",
        ] {
            assert_contains(&output, expected);
        }
    }

    #[test]
    fn test_simple_enum() {
        let def = EnumDefinition::new(EnumRepresentation::External)
            .with_variant(EnumVariant::unit("Active"))
            .with_variant(EnumVariant::unit("Inactive"));
        let output = generate_enum("Status", &def);
        for expected in &["enum Status {", "STATUS_ACTIVE = 0;", "STATUS_INACTIVE = 1;"] {
            assert_contains(&output, expected);
        }
    }

    #[test]
    fn test_complex_enum() {
        let mut fields = IndexMap::new();
        fields.insert(
            "reason".to_string(),
            StructField::new(SchemaType::Primitive(PrimitiveType::String), "reason"),
        );
        let def = EnumDefinition::new(EnumRepresentation::External)
            .with_variant(EnumVariant::unit("Active"))
            .with_variant(EnumVariant::struct_variant("Suspended", fields));
        let output = generate_enum("Status", &def);
        for expected in &[
            "message Status {",
            "oneof variant {",
            "bool active = 1;",
            "Status_Suspended suspended = 2;",
            "message Status_Suspended {",
        ] {
            assert_contains(&output, expected);
        }
    }

    #[test]
    fn test_case_conversions() {
        assert_eq!(to_pascal_case("user_name"), "UserName");
        for input in &["userName", "UserName"] {
            assert_eq!(to_snake_case(input), "user_name");
        }
        assert_eq!(to_screaming_snake_case("userName"), "USER_NAME");
    }

    #[test]
    fn test_newtype() {
        let newtype = NewtypeDefinition::new(SchemaType::Primitive(PrimitiveType::String));
        let def = SchemaDefinition::new("Email", SchemaType::Newtype(newtype));
        let output = generate(&def);
        assert_contains(&output, "message Email {");
        assert_contains(&output, "string value = 1;");
    }

    #[test]
    fn test_full_definition() {
        let id = StructField::new(SchemaType::Primitive(PrimitiveType::U64), "id");
        let def = SchemaDefinition::new(
            "User",
            SchemaType::Struct(StructDefinition::new().with_field("id", id)),
        );
        let output = generate(&def);
        assert_contains(&output, "message User {");
        assert_contains(&output, "uint64 id = 1;");
    }
}