Skip to main content

fraiseql_cli/codegen/
proto_gen.rs

1//! GraphQL → Protobuf type mapping and `.proto` file generation.
2
3use std::collections::BTreeSet;
4
5use fraiseql_core::{
6    db::dialect::RowViewColumnType,
7    schema::{CompiledSchema, FieldDefinition, FieldType},
8};
9
/// Translate a GraphQL scalar name into the matching Protobuf type name.
///
/// Built-in scalars resolve to a proto3 scalar or to a `google.protobuf`
/// well-known type. Any name that is not recognised (for example a custom
/// scalar) is represented as `"string"`.
///
/// # Examples
///
/// ```
/// use fraiseql_cli::codegen::proto_gen::graphql_to_proto_type;
///
/// assert_eq!(graphql_to_proto_type("String"), "string");
/// assert_eq!(graphql_to_proto_type("Int"), "int32");
/// assert_eq!(graphql_to_proto_type("DateTime"), "google.protobuf.Timestamp");
/// ```
#[must_use]
pub fn graphql_to_proto_type(graphql_type: &str) -> &'static str {
    match graphql_type {
        // Integer and floating-point scalars.
        "Int" => "int32",
        "BigInt" => "int64",
        "Float" => "double",
        "Boolean" => "bool",
        // Scalars that need a well-known type import.
        "DateTime" => "google.protobuf.Timestamp",
        "JSON" => "google.protobuf.Struct",
        // Text-like built-ins, listed explicitly to document the mapping.
        "String" | "ID" | "Date" => "string",
        // Custom scalars are carried as strings as well.
        _ => "string",
    }
}
39
40/// Map a GraphQL type name to a [`RowViewColumnType`] for SQL view generation.
41///
42/// Used by the row-shaped view DDL generator to determine the target SQL type
43/// for each field extracted from the JSON column.
44///
45/// # Examples
46///
47/// ```
48/// use fraiseql_cli::codegen::proto_gen::graphql_to_row_view_type;
49/// use fraiseql_core::db::dialect::RowViewColumnType;
50///
51/// assert_eq!(graphql_to_row_view_type("String"), RowViewColumnType::Text);
52/// assert_eq!(graphql_to_row_view_type("Int"), RowViewColumnType::Int32);
53/// ```
54#[must_use]
55pub fn graphql_to_row_view_type(graphql_type: &str) -> RowViewColumnType {
56    match graphql_type {
57        "String" => RowViewColumnType::Text,
58        "Date" => RowViewColumnType::Date,
59        "Int" => RowViewColumnType::Int32,
60        "BigInt" => RowViewColumnType::Int64,
61        "Float" => RowViewColumnType::Float64,
62        "Boolean" => RowViewColumnType::Boolean,
63        "ID" => RowViewColumnType::Uuid,
64        "DateTime" => RowViewColumnType::Timestamptz,
65        "JSON" => RowViewColumnType::Json,
66        _ => RowViewColumnType::Text, // Custom scalars → text
67    }
68}
69
/// Reports whether `proto_type` is one of the `google.protobuf` well-known
/// types emitted by this generator, i.e. whether using it requires importing
/// a well-known type `.proto` file.
#[must_use]
pub fn needs_well_known_import(proto_type: &str) -> bool {
    // The only well-known types this module ever produces.
    const WELL_KNOWN_TYPES: [&str; 2] = ["google.protobuf.Timestamp", "google.protobuf.Struct"];

    WELL_KNOWN_TYPES.iter().any(|known| *known == proto_type)
}
76
77/// Map a [`FieldType`] to a protobuf type string.
78///
79/// Handles scalars, lists (`repeated`), enums (referenced by name),
80/// and object references (referenced by message name).
81fn field_type_to_proto(ft: &FieldType) -> ProtoFieldType {
82    match ft {
83        FieldType::String => ProtoFieldType::scalar("string"),
84        FieldType::Int => ProtoFieldType::scalar("int32"),
85        FieldType::Float => ProtoFieldType::scalar("double"),
86        FieldType::Boolean => ProtoFieldType::scalar("bool"),
87        FieldType::Id | FieldType::Uuid => ProtoFieldType::scalar("string"),
88        FieldType::DateTime => ProtoFieldType::scalar("google.protobuf.Timestamp"),
89        FieldType::Date | FieldType::Time | FieldType::Decimal => ProtoFieldType::scalar("string"),
90        FieldType::Json => ProtoFieldType::scalar("google.protobuf.Struct"),
91        FieldType::Vector => ProtoFieldType::repeated("double"),
92        FieldType::Scalar(_) => ProtoFieldType::scalar("string"),
93        FieldType::Enum(name) => ProtoFieldType::scalar(name),
94        FieldType::Object(name) | FieldType::Interface(name) | FieldType::Union(name) => {
95            ProtoFieldType::scalar(name)
96        },
97        FieldType::Input(name) => ProtoFieldType::scalar(name),
98        FieldType::List(inner) => {
99            let inner_proto = field_type_to_proto(inner);
100            ProtoFieldType::repeated(&inner_proto.type_name)
101        },
102        _ => ProtoFieldType::scalar("string"),
103    }
104}
105
/// Protobuf type of a single field, split into its type name and its label.
struct ProtoFieldType {
    /// Protobuf type name as written in the `.proto` file.
    type_name: String,
    /// Whether the field carries the `repeated` label.
    repeated:  bool,
}

impl ProtoFieldType {
    /// Build a singular (non-repeated) field type named `name`.
    fn scalar(name: &str) -> Self {
        Self::with_label(name, false)
    }

    /// Build a `repeated` field type whose elements are named `name`.
    fn repeated(name: &str) -> Self {
        Self::with_label(name, true)
    }

    /// Shared constructor behind [`Self::scalar`] and [`Self::repeated`].
    fn with_label(name: &str, repeated: bool) -> Self {
        Self {
            type_name: String::from(name),
            repeated,
        }
    }
}
127
128/// Generate a complete `.proto` file from a compiled schema.
129///
130/// Produces a proto3 service definition with:
131/// - One message per GraphQL type (fields sorted alphabetically for stable numbering)
132/// - One RPC per query (Get for single, List for list queries)
133/// - One RPC per mutation (returns `MutationResponse`)
134/// - Enum definitions from the schema
135/// - Request/response wrapper messages
136///
137/// # Errors
138///
139/// Returns an error if the schema contains no types to expose.
140pub fn generate_proto_file(
141    schema: &CompiledSchema,
142    package: &str,
143    include_types: &[String],
144    exclude_types: &[String],
145) -> String {
146    let mut out = String::new();
147    let mut imports = BTreeSet::new();
148
149    // Collect which types to expose
150    let types: Vec<_> = schema
151        .types
152        .iter()
153        .filter(|t| should_include_type(t.name.as_ref(), include_types, exclude_types))
154        .collect();
155
156    // Pre-scan for needed imports
157    for td in &types {
158        for field in &td.fields {
159            let proto = field_type_to_proto(&field.field_type);
160            if needs_well_known_import(&proto.type_name) {
161                add_import_for_type(&proto.type_name, &mut imports);
162            }
163        }
164    }
165    // Scan query/mutation arguments too
166    for q in &schema.queries {
167        for arg in &q.arguments {
168            let proto = field_type_to_proto(&arg.arg_type);
169            if needs_well_known_import(&proto.type_name) {
170                add_import_for_type(&proto.type_name, &mut imports);
171            }
172        }
173    }
174
175    // Header
176    out.push_str("syntax = \"proto3\";\n\n");
177    out.push_str(&format!("package {package};\n\n"));
178
179    // Imports
180    for imp in &imports {
181        out.push_str(&format!("import \"{imp}\";\n"));
182    }
183    if !imports.is_empty() {
184        out.push('\n');
185    }
186
187    // Enum definitions
188    for enum_def in &schema.enums {
189        generate_enum(&mut out, &enum_def.name, &enum_def.values);
190    }
191
192    // Type messages
193    for td in &types {
194        generate_message(&mut out, td.name.as_ref(), &td.fields);
195    }
196
197    // MutationResponse message (if any mutations exist)
198    if !schema.mutations.is_empty() {
199        out.push_str("message MutationResponse {\n");
200        out.push_str("  bool success = 1;\n");
201        out.push_str("  optional string id = 2;\n");
202        out.push_str("  optional string error = 3;\n");
203        out.push_str("}\n\n");
204    }
205
206    // Request/response messages for queries
207    for q in &schema.queries {
208        if !types.iter().any(|t| t.name == q.return_type) {
209            continue;
210        }
211        generate_query_messages(&mut out, q);
212    }
213
214    // Request messages for mutations
215    for m in &schema.mutations {
216        generate_mutation_request_message(&mut out, m);
217    }
218
219    // Service definition
220    let service_name = package_to_service(package);
221    out.push_str(&format!("service {service_name} {{\n"));
222
223    for q in &schema.queries {
224        if !types.iter().any(|t| t.name == q.return_type) {
225            continue;
226        }
227        let rpc_name = to_pascal_case(&q.name);
228        let req = format!("{rpc_name}Request");
229        if q.returns_list {
230            // Server-streaming RPC: each response frame is a single entity message.
231            out.push_str(&format!("  rpc {rpc_name}({req}) returns (stream {});\n", q.return_type));
232        } else {
233            out.push_str(&format!("  rpc {rpc_name}({req}) returns ({});\n", q.return_type));
234        }
235    }
236
237    for m in &schema.mutations {
238        let rpc_name = to_pascal_case(&m.name);
239        let req = format!("{rpc_name}Request");
240        out.push_str(&format!("  rpc {rpc_name}({req}) returns (MutationResponse);\n"));
241    }
242
243    out.push_str("}\n");
244
245    out
246}
247
248/// Generate a protobuf message from a type's fields.
249///
250/// Fields are sorted alphabetically for deterministic field numbering.
251fn generate_message(out: &mut String, name: &str, fields: &[FieldDefinition]) {
252    out.push_str(&format!("message {name} {{\n"));
253
254    let mut sorted_fields: Vec<&FieldDefinition> = fields.iter().collect();
255    sorted_fields.sort_by(|a, b| a.name.as_ref().cmp(b.name.as_ref()));
256
257    for (i, field) in sorted_fields.iter().enumerate() {
258        let proto = field_type_to_proto(&field.field_type);
259        let field_num = i + 1;
260        let optional = if field.nullable && !proto.repeated {
261            "optional "
262        } else {
263            ""
264        };
265        let repeated = if proto.repeated { "repeated " } else { "" };
266        out.push_str(&format!(
267            "  {optional}{repeated}{} {} = {field_num};\n",
268            proto.type_name, field.name
269        ));
270    }
271
272    out.push_str("}\n\n");
273}
274
275/// Generate a protobuf enum definition.
276fn generate_enum(
277    out: &mut String,
278    name: &str,
279    values: &[fraiseql_core::schema::EnumValueDefinition],
280) {
281    out.push_str(&format!("enum {name} {{\n"));
282    out.push_str(&format!("  {}_UNSPECIFIED = 0;\n", to_screaming_snake(name)));
283
284    for (i, val) in values.iter().enumerate() {
285        out.push_str(&format!("  {} = {};\n", val.name, i + 1));
286    }
287
288    out.push_str("}\n\n");
289}
290
291/// Generate request/response messages for a query.
292fn generate_query_messages(out: &mut String, q: &fraiseql_core::schema::QueryDefinition) {
293    let rpc_name = to_pascal_case(&q.name);
294
295    // Request message
296    out.push_str(&format!("message {rpc_name}Request {{\n"));
297
298    let mut sorted_args: Vec<_> = q.arguments.iter().collect();
299    sorted_args.sort_by(|a, b| a.name.cmp(&b.name));
300
301    for (i, arg) in sorted_args.iter().enumerate() {
302        let proto = field_type_to_proto(&arg.arg_type);
303        let optional = if arg.nullable && !proto.repeated {
304            "optional "
305        } else {
306            ""
307        };
308        let repeated = if proto.repeated { "repeated " } else { "" };
309        out.push_str(&format!(
310            "  {optional}{repeated}{} {} = {};\n",
311            proto.type_name,
312            arg.name,
313            i + 1,
314        ));
315    }
316
317    // Add standard pagination fields for list queries
318    if q.returns_list {
319        let next_num = sorted_args.len() + 1;
320        out.push_str(&format!("  optional int32 limit = {next_num};\n"));
321        out.push_str(&format!("  optional int32 offset = {};\n", next_num + 1));
322    }
323
324    out.push_str("}\n\n");
325
326    // Note: list queries use server-streaming RPCs and do not need a
327    // response wrapper message — each streamed frame is the entity type directly.
328}
329
330/// Generate a request message for a mutation.
331fn generate_mutation_request_message(
332    out: &mut String,
333    m: &fraiseql_core::schema::MutationDefinition,
334) {
335    let rpc_name = to_pascal_case(&m.name);
336
337    out.push_str(&format!("message {rpc_name}Request {{\n"));
338
339    let mut sorted_args: Vec<_> = m.arguments.iter().collect();
340    sorted_args.sort_by(|a, b| a.name.cmp(&b.name));
341
342    for (i, arg) in sorted_args.iter().enumerate() {
343        let proto = field_type_to_proto(&arg.arg_type);
344        let optional = if arg.nullable && !proto.repeated {
345            "optional "
346        } else {
347            ""
348        };
349        let repeated = if proto.repeated { "repeated " } else { "" };
350        out.push_str(&format!(
351            "  {optional}{repeated}{} {} = {};\n",
352            proto.type_name,
353            arg.name,
354            i + 1,
355        ));
356    }
357
358    out.push_str("}\n\n");
359}
360
/// Decide whether the type called `name` passes the include/exclude filter.
///
/// An empty `include_types` list allows every name; a non-empty one allows
/// only the names it contains. A name listed in `exclude_types` is always
/// rejected, even when it is also included.
pub(crate) fn should_include_type(
    name: &str,
    include_types: &[String],
    exclude_types: &[String],
) -> bool {
    let listed_in = |list: &[String]| list.iter().any(|entry| entry == name);

    let allowed = include_types.is_empty() || listed_in(include_types);
    allowed && !listed_in(exclude_types)
}
372
/// Convert a snake_case or camelCase name to PascalCase.
///
/// Underscores are removed and the character following each one (as well as
/// the very first character) is upper-cased; all other characters are copied
/// unchanged.
pub(crate) fn to_pascal_case(name: &str) -> String {
    let mut pascal = String::with_capacity(name.len());
    let mut at_word_start = true;

    for ch in name.chars() {
        if ch == '_' {
            // Separator: drop it and capitalise whatever comes next.
            at_word_start = true;
        } else if at_word_start {
            // Upper-casing may expand to several characters.
            pascal.extend(ch.to_uppercase());
            at_word_start = false;
        } else {
            pascal.push(ch);
        }
    }

    pascal
}
389
/// Convert a PascalCase name to SCREAMING_SNAKE_CASE.
///
/// An underscore is inserted before every uppercase character except the
/// first character of the name, and every character is ASCII-upper-cased.
/// Consecutive capitals are therefore separated individually
/// (`"HTTPCode"` becomes `"H_T_T_P_CODE"`).
pub(crate) fn to_screaming_snake(name: &str) -> String {
    let mut screaming = String::new();

    for ch in name.chars() {
        // Exactly one character is pushed per iteration, so a non-empty
        // buffer means this is not the first character.
        if ch.is_uppercase() && !screaming.is_empty() {
            screaming.push('_');
        }
        screaming.push(ch.to_ascii_uppercase());
    }

    screaming
}
401
402/// Extract service name from package (e.g., "fraiseql.v1" → "FraiseQLService").
403fn package_to_service(package: &str) -> String {
404    let parts: Vec<&str> = package.split('.').collect();
405    let base = parts.first().copied().unwrap_or("FraiseQL");
406    let mut service = to_pascal_case(base);
407    service.push_str("Service");
408    service
409}
410
/// Record the `.proto` import required by `proto_type`, if any.
///
/// Only `google.protobuf.Timestamp` and `google.protobuf.Struct` have an
/// import path; every other type name leaves `imports` untouched.
fn add_import_for_type(proto_type: &str, imports: &mut BTreeSet<String>) {
    let import_path = match proto_type {
        "google.protobuf.Timestamp" => Some("google/protobuf/timestamp.proto"),
        "google.protobuf.Struct" => Some("google/protobuf/struct.proto"),
        _ => None,
    };

    if let Some(path) = import_path {
        imports.insert(path.to_string());
    }
}
422}