use crate::utils::{pluralize, to_pascal_case, to_snake_case};
use protograph_core::{EntityType, FieldType, ProtographSchema, Relationship};
use std::fmt::Write;
/// Renders `schema` as a complete proto3 file declared in package `package_name`.
///
/// Sections are written in this order:
/// 1. the `syntax` and `package` header;
/// 2. messages for every type flagged `is_entity` and not `is_private`;
/// 3. a message for every input type;
/// 4. an `enum` for every schema enum;
/// 5. one `service` per public entity.
///
/// Within a section, items follow the iteration order of the corresponding
/// schema collection.
pub fn generate_proto(schema: &ProtographSchema, package_name: &str) -> String {
    let mut output = String::from("syntax = \"proto3\";\n\n");
    writeln!(output, "package {};", package_name).unwrap();
    output.push('\n');

    // Messages are emitted before services so every RPC references a type
    // that has already been declared above it.
    for (_, entity) in &schema.types {
        if !entity.is_entity || entity.is_private {
            continue;
        }
        generate_entity_messages(&mut output, entity, schema);
    }
    for (_, input_type) in &schema.input_types {
        generate_input_message(&mut output, input_type);
    }
    for (_, enum_type) in &schema.enums {
        generate_enum(&mut output, enum_type);
    }
    for (_, entity) in &schema.types {
        if !entity.is_entity || entity.is_private {
            continue;
        }
        generate_entity_service(&mut output, entity, schema);
    }

    output
}
/// Emits every message owned by one entity:
/// - the entity message itself;
/// - `Get{Entity}Request`;
/// - `BatchGet{Plural}Request` and `BatchGet{Plural}Response`;
/// - the lookup messages for each `@hasMany` field.
///
/// The entity message contains, in order:
/// 1. each non-private, non-relationship field, numbered from 1 in declaration order;
/// 2. one `string` field per distinct `@belongsTo` foreign key that is not already
///    present as a public field. A private foreign-key field is re-exposed this way.
///
/// A `@hasMany` foreign key is deliberately *not* added here. It names a column on
/// the related entity and serves as the lookup key of `Get{Related}By{Fk}Request`.
/// Adding it would give, e.g., `User` a spurious `author_id` field.
fn generate_entity_messages(
    output: &mut String,
    entity: &EntityType,
    schema: &ProtographSchema,
) {
    let name = &entity.name;
    let plural_name = pluralize(name);

    writeln!(output, "message {} {{", name).unwrap();
    let mut field_num = 1;
    for field in &entity.fields {
        // Relationship fields are resolved through RPCs, not stored on the message.
        if field.is_private || field.relationship.is_some() {
            continue;
        }
        writeln!(
            output,
            " {} {} = {};",
            graphql_to_proto_type(&field.field_type),
            to_snake_case(&field.name),
            field_num
        )
        .unwrap();
        field_num += 1;
    }

    // Foreign keys already written. Two relations sharing one key must not
    // produce duplicate proto field names.
    let mut emitted_fks: Vec<&str> = Vec::new();
    for field in &entity.fields {
        if let Some(Relationship::BelongsTo { foreign_key }) = &field.relationship {
            let already_public = entity
                .fields
                .iter()
                .any(|f| &f.name == foreign_key && !f.is_private);
            let fk: &str = foreign_key;
            if already_public || emitted_fks.contains(&fk) {
                continue;
            }
            emitted_fks.push(fk);
            writeln!(output, " string {} = {};", to_snake_case(fk), field_num).unwrap();
            field_num += 1;
        }
    }
    writeln!(output, "}}").unwrap();
    writeln!(output).unwrap();

    writeln!(output, "message Get{}Request {{", name).unwrap();
    writeln!(output, " string id = 1;").unwrap();
    writeln!(output, "}}").unwrap();
    writeln!(output).unwrap();

    writeln!(output, "message BatchGet{}Request {{", plural_name).unwrap();
    writeln!(output, " repeated string ids = 1;").unwrap();
    writeln!(output, "}}").unwrap();
    writeln!(output).unwrap();

    writeln!(output, "message BatchGet{}Response {{", plural_name).unwrap();
    writeln!(output, " repeated {} {} = 1;", name, to_snake_case(&plural_name)).unwrap();
    writeln!(output, "}}").unwrap();
    writeln!(output).unwrap();

    for field in &entity.fields {
        if let Some(Relationship::HasMany { foreign_key }) = &field.relationship {
            generate_has_many_messages(output, entity, field, foreign_key, schema);
        }
    }
}
/// Emits the lookup messages for one `@hasMany` field whose target is the
/// field's base type (`Related`) and whose key is `foreign_key` (`Fk`):
/// - `Get{Relateds}By{Fk}Request`: a single key;
/// - `BatchGet{Relateds}By{Fk}sRequest`: many keys;
/// - `{Related}List`: a repeated wrapper around `Related`;
/// - `BatchGet{Relateds}By{Fk}sResponse`: a key → `{Related}List` map.
///
/// Each message is skipped when `output` already defines a message with the same
/// name. `{Related}List` in particular is needed by every `@hasMany` that targets
/// `Related`, and protoc rejects duplicate message definitions.
fn generate_has_many_messages(
    output: &mut String,
    _parent: &EntityType,
    field: &protograph_core::Field,
    foreign_key: &str,
    _schema: &ProtographSchema,
) {
    /// Writes `message {name} {` and returns `true`. If `output` already has
    /// that exact header line, it writes nothing and returns `false`.
    fn open_message(output: &mut String, name: &str) -> bool {
        let header = format!("message {} {{", name);
        if output.lines().any(|line| line == header) {
            return false;
        }
        output.push_str(&header);
        output.push('\n');
        true
    }

    let related_type = field.field_type.base_type();
    let fk_pascal = to_pascal_case(foreign_key);
    let fk_snake = to_snake_case(foreign_key);
    let plural_related = pluralize(related_type);

    if open_message(output, &format!("Get{}By{}Request", plural_related, fk_pascal)) {
        writeln!(output, " string {} = 1;", fk_snake).unwrap();
        output.push_str("}\n\n");
    }

    if open_message(output, &format!("BatchGet{}By{}sRequest", plural_related, fk_pascal)) {
        writeln!(output, " repeated string {}s = 1;", fk_snake).unwrap();
        output.push_str("}\n\n");
    }

    if open_message(output, &format!("{}List", related_type)) {
        writeln!(
            output,
            " repeated {} {} = 1;",
            related_type,
            to_snake_case(&plural_related)
        )
        .unwrap();
        output.push_str("}\n\n");
    }

    if open_message(output, &format!("BatchGet{}By{}sResponse", plural_related, fk_pascal)) {
        writeln!(
            output,
            " map<string, {}List> {} = 1;",
            related_type,
            to_snake_case(&format!("{}_by_{}", plural_related, fk_snake))
        )
        .unwrap();
        output.push_str("}\n\n");
    }
}
/// Emits a proto3 `message` for a GraphQL input type.
///
/// Every field is included; no privacy filtering is applied. Fields are numbered
/// 1, 2, … in declaration order, and their names are converted to snake_case.
fn generate_input_message(output: &mut String, input: &protograph_core::InputType) {
    writeln!(output, "message {} {{", input.name).unwrap();
    for (tag, field) in (1..).zip(input.fields.iter()) {
        let proto_type = graphql_to_proto_type(&field.field_type);
        let field_name = to_snake_case(&field.name);
        writeln!(output, " {} {} = {};", proto_type, field_name, tag).unwrap();
    }
    output.push_str("}\n\n");
}
/// Emits a proto3 `enum` definition for a GraphQL enum type.
///
/// Values are written verbatim and numbered sequentially from 0 in declaration
/// order, so the first GraphQL value becomes the proto3 default.
///
/// NOTE(review): proto3 enum values are scoped to the enclosing package, not to
/// the enum. Two GraphQL enums sharing a value name (e.g. `ACTIVE`) would
/// therefore be rejected by protoc. Consider prefixing values with the enum name.
fn generate_enum(output: &mut String, enum_type: &protograph_core::EnumType) {
    writeln!(output, "enum {} {{", enum_type.name).unwrap();
    for (number, variant) in (0..).zip(enum_type.values.iter()) {
        writeln!(output, " {} = {};", variant, number).unwrap();
    }
    output.push_str("}\n\n");
}
/// Emits the gRPC `service {Entity}Service` for one entity.
///
/// The service always declares `Get{Entity}` and `BatchGet{Plural}`. Each
/// `@hasMany` field additionally contributes:
/// - `Get{Relateds}By{Fk}`, which returns `{Related}List`;
/// - `BatchGet{Relateds}By{Fk}s`.
fn generate_entity_service(
    output: &mut String,
    entity: &EntityType,
    _schema: &ProtographSchema,
) {
    let entity_name = &entity.name;
    let plural = pluralize(entity_name);

    writeln!(output, "service {}Service {{", entity_name).unwrap();
    writeln!(output, " rpc Get{0}(Get{0}Request) returns ({0});", entity_name).unwrap();
    writeln!(
        output,
        " rpc BatchGet{0}(BatchGet{0}Request) returns (BatchGet{0}Response);",
        plural
    )
    .unwrap();

    let has_many = entity.fields.iter().filter_map(|f| match &f.relationship {
        Some(Relationship::HasMany { foreign_key }) => Some((f, foreign_key)),
        _ => None,
    });
    for (field, foreign_key) in has_many {
        let target = field.field_type.base_type();
        let target_plural = pluralize(target);
        let key = to_pascal_case(foreign_key);
        writeln!(
            output,
            " rpc Get{0}By{1}(Get{0}By{1}Request) returns ({2}List);",
            target_plural, key, target
        )
        .unwrap();
        writeln!(
            output,
            " rpc BatchGet{0}By{1}s(BatchGet{0}By{1}sRequest) returns (BatchGet{0}By{1}sResponse);",
            target_plural, key
        )
        .unwrap();
    }

    output.push_str("}\n\n");
}
/// Maps a GraphQL type reference to the proto3 type spelled in a message field.
///
/// - Built-in scalars: `ID` and `String` → `string`, `Int` → `int32`,
///   `Float` → `double`, `Boolean` → `bool`.
/// - Any other named type is written verbatim.
/// - Non-null wrappers are dropped, since proto3 has no required fields.
/// - Lists become `repeated <element>`.
///
/// NOTE(review): a nested list such as `[[Int]]` yields `repeated repeated int32`,
/// which protoc rejects. Such shapes would need a wrapper message. Custom scalars
/// (e.g. `DateTime`) are also passed through unmapped.
fn graphql_to_proto_type(gql_type: &FieldType) -> String {
    match gql_type {
        FieldType::List(element) => {
            let mut rendered = String::from("repeated ");
            rendered.push_str(&graphql_to_proto_type(element));
            rendered
        }
        FieldType::NonNull(wrapped) => graphql_to_proto_type(wrapped),
        FieldType::Named(type_name) => {
            let proto = match type_name.as_str() {
                "ID" | "String" => "string",
                "Int" => "int32",
                "Float" => "double",
                "Boolean" => "bool",
                custom => custom,
            };
            proto.to_owned()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use protograph_core::parse_schema_file;

    /// End-to-end check of `generate_proto` on a small User/Post schema with a
    /// `@hasMany`/`@belongsTo` pair and private fields.
    #[test]
    fn test_generate_proto() {
        let schema = r#"
type User @entity {
id: ID!
name: String!
email: String! @private
posts: [Post!]! @hasMany(field: "authorId")
}
type Post @entity {
id: ID!
title: String!
author: User! @belongsTo(field: "authorId")
authorId: ID! @private
}
"#;
        let parsed = parse_schema_file(schema).unwrap();
        let proto = generate_proto(&parsed, "protograph");

        // File header.
        assert!(proto.starts_with("syntax = \"proto3\";\n\npackage protograph;\n"));

        // Entity messages and services.
        assert!(proto.contains("message User {"));
        assert!(proto.contains("message Post {"));
        assert!(proto.contains("service UserService {"));
        assert!(proto.contains("rpc GetUser(GetUserRequest) returns (User);"));
        assert!(proto.contains("rpc BatchGetUsers("));
        assert!(proto.contains("message BatchGetUsersResponse {"));

        // Private fields never reach the generated schema.
        assert!(!proto.contains("email"));

        // hasMany lookups, backed by exactly one shared list message.
        assert!(proto.contains(
            "rpc GetPostsByAuthorId(GetPostsByAuthorIdRequest) returns (PostList);"
        ));
        assert!(proto.contains("rpc BatchGetPostsByAuthorIds("));
        assert_eq!(proto.matches("message PostList {").count(), 1);
    }
}