//! protograph-codegen 0.1.0 — code generation for protograph (proto + Rust).
//!
//! This module renders a parsed `ProtographSchema` as a proto3 `.proto` document.
use crate::utils::{pluralize, to_pascal_case, to_snake_case};
use protograph_core::{EntityType, FieldType, ProtographSchema, Relationship};
use std::fmt::Write;

pub fn generate_proto(schema: &ProtographSchema, package_name: &str) -> String {
    let mut output = String::new();

    writeln!(output, "syntax = \"proto3\";").unwrap();
    writeln!(output).unwrap();
    writeln!(output, "package {};", package_name).unwrap();
    writeln!(output).unwrap();

    for (name, entity) in &schema.types {
        if entity.is_entity && !entity.is_private {
            generate_entity_messages(&mut output, entity, schema);
        }
    }

    for (name, input_type) in &schema.input_types {
        generate_input_message(&mut output, input_type);
    }

    for (name, enum_type) in &schema.enums {
        generate_enum(&mut output, enum_type);
    }

    for (name, entity) in &schema.types {
        if entity.is_entity && !entity.is_private {
            generate_entity_service(&mut output, entity, schema);
        }
    }

    output
}

/// Appends the data message and lookup messages for one entity to `output`.
///
/// Emits, in order:
/// - `message <Entity>`: every public, non-relationship field, followed by a
///   `string` field for each distinct `@belongsTo` foreign key that is not
///   already present as a public field of the entity;
/// - `Get<Entity>Request`, `BatchGet<Entities>Request` and
///   `BatchGet<Entities>Response`;
/// - the messages for each `@hasMany` field (via `generate_has_many_messages`).
///
/// Field numbers are assigned sequentially in declaration order, so inserting
/// or reordering fields in the schema renumbers the wire format.
fn generate_entity_messages(
    output: &mut String,
    entity: &EntityType,
    schema: &ProtographSchema,
) {
    let name = &entity.name;
    let plural_name = pluralize(name);

    writeln!(output, "message {} {{", name).unwrap();
    let mut field_num = 1;
    for field in &entity.fields {
        // Private fields stay server-side; relationship fields are resolved
        // through the generated lookup RPCs rather than embedded here.
        if field.is_private || field.relationship.is_some() {
            continue;
        }

        let proto_type = graphql_to_proto_type(&field.field_type);
        writeln!(
            output,
            "  {} {} = {};",
            proto_type,
            to_snake_case(&field.name),
            field_num
        )
        .unwrap();
        field_num += 1;
    }

    // A `@belongsTo` key is a column of *this* entity and is needed to resolve
    // the parent, so it is exposed even when the schema marks it private.
    // A `@hasMany` key is a column of the related type (it is what
    // `Get<Relateds>By<Fk>Request` looks up), so it must not be added here.
    let mut emitted_fks: Vec<&str> = Vec::new();
    for field in &entity.fields {
        if let Some(Relationship::BelongsTo { foreign_key }) = &field.relationship {
            let fk: &str = foreign_key;
            // Already written by the loop above as a regular public field.
            let exposed_as_field = entity
                .fields
                .iter()
                .any(|f| f.name == fk && !f.is_private);
            // Several relationships may share one key; a proto field name may
            // only appear once per message.
            if exposed_as_field || emitted_fks.contains(&fk) {
                continue;
            }
            emitted_fks.push(fk);
            writeln!(output, "  string {} = {};", to_snake_case(fk), field_num).unwrap();
            field_num += 1;
        }
    }
    writeln!(output, "}}").unwrap();
    writeln!(output).unwrap();

    writeln!(output, "message Get{}Request {{", name).unwrap();
    writeln!(output, "  string id = 1;").unwrap();
    writeln!(output, "}}").unwrap();
    writeln!(output).unwrap();

    writeln!(output, "message BatchGet{}Request {{", plural_name).unwrap();
    writeln!(output, "  repeated string ids = 1;").unwrap();
    writeln!(output, "}}").unwrap();
    writeln!(output).unwrap();

    writeln!(output, "message BatchGet{}Response {{", plural_name).unwrap();
    writeln!(output, "  repeated {} {} = 1;", name, to_snake_case(&plural_name)).unwrap();
    writeln!(output, "}}").unwrap();
    writeln!(output).unwrap();

    for field in &entity.fields {
        if let Some(Relationship::HasMany { foreign_key }) = &field.relationship {
            generate_has_many_messages(output, entity, field, foreign_key, schema);
        }
    }
}

/// Appends the request/response messages backing one `@hasMany` relationship.
///
/// For a field of type `[Related]` keyed by `foreign_key`, emits
/// `Get<Relateds>By<Fk>Request`, `BatchGet<Relateds>By<Fk>sRequest`,
/// `<Related>List`, and `BatchGet<Relateds>By<Fk>sResponse` (a map from each
/// requested key to its `<Related>List`).
///
/// A message is skipped if `output` already defines one with the same name.
/// Without this, two `@hasMany` fields targeting the same type (on one entity
/// or on different entities) would emit `<Related>List` twice, and protoc
/// rejects duplicate message definitions.
///
/// `_parent` and `_schema` are currently unused.
fn generate_has_many_messages(
    output: &mut String,
    _parent: &EntityType,
    field: &protograph_core::Field,
    foreign_key: &str,
    _schema: &ProtographSchema,
) {
    let related_type = field.field_type.base_type();
    let fk_pascal = to_pascal_case(foreign_key);
    let fk_snake = to_snake_case(foreign_key);
    let plural_related = pluralize(related_type);

    write_message_once(
        output,
        &format!("Get{}By{}Request", plural_related, fk_pascal),
        &[format!("string {} = 1;", fk_snake)],
    );

    write_message_once(
        output,
        &format!("BatchGet{}By{}sRequest", plural_related, fk_pascal),
        &[format!("repeated string {}s = 1;", fk_snake)],
    );

    write_message_once(
        output,
        &format!("{}List", related_type),
        &[format!(
            "repeated {} {} = 1;",
            related_type,
            to_snake_case(&plural_related)
        )],
    );

    write_message_once(
        output,
        &format!("BatchGet{}By{}sResponse", plural_related, fk_pascal),
        &[format!(
            "map<string, {}List> {} = 1;",
            related_type,
            to_snake_case(&format!("{}_by_{}", plural_related, fk_snake))
        )],
    );
}

/// Writes `message <name> { ... }` (one indented line per entry of `fields`,
/// followed by a blank line) unless `output` already contains a
/// `message <name> {` header line.
fn write_message_once(output: &mut String, name: &str, fields: &[String]) {
    let header = format!("message {} {{", name);
    if output.lines().any(|line| line == header) {
        return;
    }
    writeln!(output, "{}", header).unwrap();
    for field in fields {
        writeln!(output, "  {}", field).unwrap();
    }
    writeln!(output, "}}").unwrap();
    writeln!(output).unwrap();
}

/// Appends a proto3 `message` mirroring a GraphQL input type.
///
/// Every field is emitted (no private/relationship filtering), with field
/// numbers 1, 2, 3, ... in declaration order and snake_case field names.
fn generate_input_message(output: &mut String, input: &protograph_core::InputType) {
    writeln!(output, "message {} {{", input.name).unwrap();
    for (field, tag) in input.fields.iter().zip(1usize..) {
        let proto_type = graphql_to_proto_type(&field.field_type);
        let field_name = to_snake_case(&field.name);
        writeln!(output, "  {} {} = {};", proto_type, field_name, tag).unwrap();
    }
    output.push_str("}\n\n");
}

/// Appends a proto3 `enum` definition for `enum_type` to `output`.
///
/// Values keep their schema spelling and are numbered 0, 1, 2, ... in
/// declaration order, so the first declared value receives number 0.
fn generate_enum(output: &mut String, enum_type: &protograph_core::EnumType) {
    writeln!(output, "enum {} {{", enum_type.name).unwrap();
    let numbered = enum_type.values.iter().zip(0usize..);
    for (value, number) in numbered {
        writeln!(output, "  {} = {};", value, number).unwrap();
    }
    output.push_str("}\n\n");
}

/// Appends `service <Entity>Service` with the lookup RPCs for one entity.
///
/// Always declares `Get<Entity>` and `BatchGet<Entities>`. Each `@hasMany`
/// field adds `Get<Relateds>By<Fk>` (returning `<Related>List`) and its batch
/// variant `BatchGet<Relateds>By<Fk>s`. The request/response messages named
/// here are the ones written by `generate_entity_messages`.
///
/// `_schema` is currently unused.
fn generate_entity_service(
    output: &mut String,
    entity: &EntityType,
    _schema: &ProtographSchema,
) {
    let name = &entity.name;
    let plural_name = pluralize(name);

    writeln!(output, "service {}Service {{", name).unwrap();
    writeln!(output, "  rpc Get{0}(Get{0}Request) returns ({0});", name).unwrap();
    writeln!(
        output,
        "  rpc BatchGet{0}(BatchGet{0}Request) returns (BatchGet{0}Response);",
        plural_name
    )
    .unwrap();

    for field in &entity.fields {
        if let Some(Relationship::HasMany { foreign_key }) = &field.relationship {
            let related_type = field.field_type.base_type();
            // e.g. "PostsByAuthorId" — shared stem of every RPC/message name below.
            let lookup = format!(
                "{}By{}",
                pluralize(related_type),
                to_pascal_case(foreign_key)
            );

            writeln!(
                output,
                "  rpc Get{0}(Get{0}Request) returns ({1}List);",
                lookup, related_type
            )
            .unwrap();
            writeln!(
                output,
                "  rpc BatchGet{0}s(BatchGet{0}sRequest) returns (BatchGet{0}sResponse);",
                lookup
            )
            .unwrap();
        }
    }

    writeln!(output, "}}").unwrap();
    writeln!(output).unwrap();
}

/// Maps a GraphQL field type to the proto3 type text used in a field declaration.
///
/// - Built-in scalars: `ID`/`String` → `string`, `Int` → `int32`,
///   `Float` → `double`, `Boolean` → `bool`.
/// - Any other named type is emitted verbatim (presumably a message or enum
///   generated in the same file — TODO confirm for custom scalars).
/// - A non-null wrapper maps to the same text as its inner type.
/// - A list becomes `repeated <element>`.
///
/// NOTE(review): a nested list such as `[[Int]]` renders as
/// `repeated repeated int32`, which is not valid proto3.
fn graphql_to_proto_type(gql_type: &FieldType) -> String {
    match gql_type {
        FieldType::NonNull(inner) => graphql_to_proto_type(inner),
        FieldType::List(inner) => {
            let element = graphql_to_proto_type(inner);
            format!("repeated {}", element)
        }
        FieldType::Named(name) => {
            let proto = match name.as_str() {
                "ID" | "String" => "string",
                "Int" => "int32",
                "Float" => "double",
                "Boolean" => "bool",
                custom => custom,
            };
            proto.to_string()
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use protograph_core::parse_schema_file;

    /// End-to-end check of `generate_proto` on a two-entity schema with a
    /// `@hasMany` / `@belongsTo` pair and a private field.
    #[test]
    fn test_generate_proto() {
        let schema = r#"
            type User @entity {
                id: ID!
                name: String!
                email: String! @private
                posts: [Post!]! @hasMany(field: "authorId")
            }

            type Post @entity {
                id: ID!
                title: String!
                author: User! @belongsTo(field: "authorId")
                authorId: ID! @private
            }
        "#;

        let parsed = parse_schema_file(schema).expect("test schema should parse");
        let proto = generate_proto(&parsed, "protograph");

        // File header.
        assert!(proto.starts_with("syntax = \"proto3\";\n\npackage protograph;\n"));

        // One data message and one service per exported entity.
        assert!(proto.contains("message User {"));
        assert!(proto.contains("message Post {"));
        assert!(proto.contains("service UserService {"));
        assert!(proto.contains("service PostService {"));

        // Per-entity lookups.
        assert!(proto.contains("rpc GetUser(GetUserRequest) returns (User);"));
        assert!(proto.contains("rpc BatchGetUsers("));
        assert!(proto.contains("message BatchGetUsersResponse {"));

        // @hasMany lookups keyed by the foreign key on the related type.
        assert!(proto.contains("message PostList {"));
        assert!(proto.contains("rpc GetPostsByAuthorId(GetPostsByAuthorIdRequest) returns (PostList);"));
        assert!(proto.contains("rpc BatchGetPostsByAuthorIds("));

        // Fields marked @private must never reach the wire format.
        assert!(
            !proto.contains("email"),
            "private field leaked into generated proto:\n{}",
            proto
        );
    }
}