protograph_codegen/
proto_gen.rs

1use crate::utils::{pluralize, to_pascal_case, to_snake_case};
2use protograph_core::{EntityType, FieldType, ProtographSchema, Relationship};
3use std::fmt::Write;
4
5pub fn generate_proto(schema: &ProtographSchema, package_name: &str) -> String {
6    let mut output = String::new();
7
8    writeln!(output, "syntax = \"proto3\";").unwrap();
9    writeln!(output).unwrap();
10    writeln!(output, "package {};", package_name).unwrap();
11    writeln!(output).unwrap();
12
13    for (name, entity) in &schema.types {
14        if entity.is_entity && !entity.is_private {
15            generate_entity_messages(&mut output, entity, schema);
16        }
17    }
18
19    for (name, input_type) in &schema.input_types {
20        generate_input_message(&mut output, input_type);
21    }
22
23    for (name, enum_type) in &schema.enums {
24        generate_enum(&mut output, enum_type);
25    }
26
27    for (name, entity) in &schema.types {
28        if entity.is_entity && !entity.is_private {
29            generate_entity_service(&mut output, entity, schema);
30        }
31    }
32
33    output
34}
35
36fn generate_entity_messages(
37    output: &mut String,
38    entity: &EntityType,
39    schema: &ProtographSchema,
40) {
41    let name = &entity.name;
42    let snake_name = to_snake_case(name);
43    let plural_name = pluralize(name);
44
45    writeln!(output, "message {} {{", name).unwrap();
46    let mut field_num = 1;
47    for field in &entity.fields {
48        if field.is_private {
49            continue;
50        }
51        if field.relationship.is_some() {
52            continue;
53        }
54
55        let proto_type = graphql_to_proto_type(&field.field_type);
56        writeln!(
57            output,
58            "  {} {} = {};",
59            proto_type,
60            to_snake_case(&field.name),
61            field_num
62        )
63        .unwrap();
64        field_num += 1;
65    }
66    for field in &entity.fields {
67        if let Some(Relationship::BelongsTo { foreign_key }) | Some(Relationship::HasMany { foreign_key }) = &field.relationship {
68            if entity.fields.iter().any(|f| &f.name == foreign_key && !f.is_private) {
69                continue;
70            }
71            let proto_type = "string";
72            writeln!(
73                output,
74                "  {} {} = {};",
75                proto_type,
76                to_snake_case(foreign_key),
77                field_num
78            )
79            .unwrap();
80            field_num += 1;
81        }
82    }
83    writeln!(output, "}}").unwrap();
84    writeln!(output).unwrap();
85
86    writeln!(output, "message Get{}Request {{", name).unwrap();
87    writeln!(output, "  string id = 1;").unwrap();
88    writeln!(output, "}}").unwrap();
89    writeln!(output).unwrap();
90
91    writeln!(output, "message BatchGet{}Request {{", plural_name).unwrap();
92    writeln!(output, "  repeated string ids = 1;").unwrap();
93    writeln!(output, "}}").unwrap();
94    writeln!(output).unwrap();
95
96    writeln!(output, "message BatchGet{}Response {{", plural_name).unwrap();
97    writeln!(output, "  repeated {} {} = 1;", name, to_snake_case(&plural_name)).unwrap();
98    writeln!(output, "}}").unwrap();
99    writeln!(output).unwrap();
100
101    for field in &entity.fields {
102        if let Some(Relationship::HasMany { foreign_key }) = &field.relationship {
103            generate_has_many_messages(output, entity, field, foreign_key, schema);
104        }
105    }
106}
107
108fn generate_has_many_messages(
109    output: &mut String,
110    parent: &EntityType,
111    field: &protograph_core::Field,
112    foreign_key: &str,
113    schema: &ProtographSchema,
114) {
115    let related_type = field.field_type.base_type();
116    let fk_pascal = to_pascal_case(foreign_key);
117    let fk_snake = to_snake_case(foreign_key);
118    let plural_related = pluralize(related_type);
119
120    writeln!(
121        output,
122        "message Get{}By{}Request {{",
123        plural_related, fk_pascal
124    )
125    .unwrap();
126    writeln!(output, "  string {} = 1;", fk_snake).unwrap();
127    writeln!(output, "}}").unwrap();
128    writeln!(output).unwrap();
129
130    writeln!(
131        output,
132        "message BatchGet{}By{}sRequest {{",
133        plural_related, fk_pascal
134    )
135    .unwrap();
136    writeln!(output, "  repeated string {}s = 1;", fk_snake).unwrap();
137    writeln!(output, "}}").unwrap();
138    writeln!(output).unwrap();
139
140    writeln!(output, "message {}List {{", related_type).unwrap();
141    writeln!(
142        output,
143        "  repeated {} {} = 1;",
144        related_type,
145        to_snake_case(&plural_related)
146    )
147    .unwrap();
148    writeln!(output, "}}").unwrap();
149    writeln!(output).unwrap();
150
151    writeln!(
152        output,
153        "message BatchGet{}By{}sResponse {{",
154        plural_related, fk_pascal
155    )
156    .unwrap();
157    writeln!(
158        output,
159        "  map<string, {}List> {} = 1;",
160        related_type,
161        to_snake_case(&format!("{}_by_{}", plural_related, fk_snake))
162    )
163    .unwrap();
164    writeln!(output, "}}").unwrap();
165    writeln!(output).unwrap();
166}
167
168fn generate_input_message(output: &mut String, input: &protograph_core::InputType) {
169    writeln!(output, "message {} {{", input.name).unwrap();
170    for (i, field) in input.fields.iter().enumerate() {
171        let proto_type = graphql_to_proto_type(&field.field_type);
172        writeln!(
173            output,
174            "  {} {} = {};",
175            proto_type,
176            to_snake_case(&field.name),
177            i + 1
178        )
179        .unwrap();
180    }
181    writeln!(output, "}}").unwrap();
182    writeln!(output).unwrap();
183}
184
185fn generate_enum(output: &mut String, enum_type: &protograph_core::EnumType) {
186    writeln!(output, "enum {} {{", enum_type.name).unwrap();
187    for (i, value) in enum_type.values.iter().enumerate() {
188        writeln!(output, "  {} = {};", value, i).unwrap();
189    }
190    writeln!(output, "}}").unwrap();
191    writeln!(output).unwrap();
192}
193
194fn generate_entity_service(
195    output: &mut String,
196    entity: &EntityType,
197    schema: &ProtographSchema,
198) {
199    let name = &entity.name;
200    let plural_name = pluralize(name);
201
202    writeln!(output, "service {}Service {{", name).unwrap();
203    writeln!(
204        output,
205        "  rpc Get{}(Get{}Request) returns ({});",
206        name, name, name
207    )
208    .unwrap();
209    writeln!(
210        output,
211        "  rpc BatchGet{}(BatchGet{}Request) returns (BatchGet{}Response);",
212        plural_name, plural_name, plural_name
213    )
214    .unwrap();
215
216    for field in &entity.fields {
217        if let Some(Relationship::HasMany { foreign_key }) = &field.relationship {
218            let related_type = field.field_type.base_type();
219            let fk_pascal = to_pascal_case(foreign_key);
220            let plural_related = pluralize(related_type);
221
222            writeln!(
223                output,
224                "  rpc Get{}By{}(Get{}By{}Request) returns ({}List);",
225                plural_related, fk_pascal, plural_related, fk_pascal, related_type
226            )
227            .unwrap();
228            writeln!(
229                output,
230                "  rpc BatchGet{}By{}s(BatchGet{}By{}sRequest) returns (BatchGet{}By{}sResponse);",
231                plural_related, fk_pascal, plural_related, fk_pascal, plural_related, fk_pascal
232            )
233            .unwrap();
234        }
235    }
236
237    writeln!(output, "}}").unwrap();
238    writeln!(output).unwrap();
239}
240
241fn graphql_to_proto_type(gql_type: &FieldType) -> String {
242    match gql_type {
243        FieldType::Named(name) => match name.as_str() {
244            "ID" | "String" => "string".to_string(),
245            "Int" => "int32".to_string(),
246            "Float" => "double".to_string(),
247            "Boolean" => "bool".to_string(),
248            other => other.to_string(),
249        },
250        FieldType::NonNull(inner) => graphql_to_proto_type(inner),
251        FieldType::List(inner) => format!("repeated {}", graphql_to_proto_type(inner)),
252    }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use protograph_core::parse_schema_file;

    /// Generates proto for a two-entity schema joined by a has-many /
    /// belongs-to pair and checks that the key messages, service and RPCs
    /// appear in the output.
    #[test]
    fn test_generate_proto() {
        let schema = r#"
            type User @entity {
                id: ID!
                name: String!
                email: String! @private
                posts: [Post!]! @hasMany(field: "authorId")
            }

            type Post @entity {
                id: ID!
                title: String!
                author: User! @belongsTo(field: "authorId")
                authorId: ID! @private
            }
        "#;

        let parsed = parse_schema_file(schema).unwrap();
        let proto = generate_proto(&parsed, "protograph");

        // Fragments that must all be present in the generated proto.
        let expected_fragments = [
            "message User {",
            "message Post {",
            "service UserService {",
            "rpc BatchGetUsers(",
            "rpc BatchGetPostsByAuthorIds(",
        ];

        for fragment in expected_fragments.iter() {
            assert!(
                proto.contains(*fragment),
                "generated proto is missing `{}`",
                fragment
            );
        }
    }
}