Skip to main content

fraiseql_cli/codegen/
proto_gen.rs

1//! GraphQL → Protobuf type mapping and `.proto` file generation.
2
3use std::collections::BTreeSet;
4
5use fraiseql_core::{
6    db::dialect::RowViewColumnType,
7    schema::{CompiledSchema, FieldDefinition, FieldType},
8};
9
/// Translate a GraphQL scalar type name into its Protobuf counterpart.
///
/// The result is either a protobuf scalar (`string`, `int32`, …) or a
/// fully-qualified well-known type name. Names that are not recognised —
/// custom scalars included — are represented as `"string"`.
///
/// # Examples
///
/// ```
/// use fraiseql_cli::codegen::proto_gen::graphql_to_proto_type;
///
/// assert_eq!(graphql_to_proto_type("String"), "string");
/// assert_eq!(graphql_to_proto_type("Int"), "int32");
/// assert_eq!(graphql_to_proto_type("DateTime"), "google.protobuf.Timestamp");
/// ```
#[must_use]
pub fn graphql_to_proto_type(graphql_type: &str) -> &'static str {
    match graphql_type {
        // Numeric and boolean scalars.
        "Int" => "int32",
        "BigInt" => "int64",
        "Float" => "double",
        "Boolean" => "bool",
        // Well-known protobuf message types.
        "DateTime" => "google.protobuf.Timestamp",
        "JSON" => "google.protobuf.Struct",
        // Scalars that travel as text.
        "String" | "ID" | "Date" => "string",
        // Anything else (custom scalars) is carried as text too.
        _ => "string",
    }
}
39
40/// Map a GraphQL type name to a [`RowViewColumnType`] for SQL view generation.
41///
42/// Used by the row-shaped view DDL generator to determine the target SQL type
43/// for each field extracted from the JSON column.
44///
45/// # Examples
46///
47/// ```
48/// use fraiseql_cli::codegen::proto_gen::graphql_to_row_view_type;
49/// use fraiseql_core::db::dialect::RowViewColumnType;
50///
51/// assert_eq!(graphql_to_row_view_type("String"), RowViewColumnType::Text);
52/// assert_eq!(graphql_to_row_view_type("Int"), RowViewColumnType::Int32);
53/// ```
54#[must_use]
55pub fn graphql_to_row_view_type(graphql_type: &str) -> RowViewColumnType {
56    match graphql_type {
57        "String" => RowViewColumnType::Text,
58        "Date" => RowViewColumnType::Date,
59        "Int" => RowViewColumnType::Int32,
60        "BigInt" => RowViewColumnType::Int64,
61        "Float" => RowViewColumnType::Float64,
62        "Boolean" => RowViewColumnType::Boolean,
63        "ID" => RowViewColumnType::Uuid,
64        "DateTime" => RowViewColumnType::Timestamptz,
65        "JSON" => RowViewColumnType::Json,
66        _ => RowViewColumnType::Text, // Custom scalars → text
67    }
68}
69
/// Reports whether `proto_type` is a well-known protobuf type, i.e. one whose
/// use requires importing a well-known type `.proto` file.
#[must_use]
pub fn needs_well_known_import(proto_type: &str) -> bool {
    // The well-known types this generator can emit.
    const WELL_KNOWN_TYPES: [&str; 2] = ["google.protobuf.Timestamp", "google.protobuf.Struct"];

    WELL_KNOWN_TYPES.contains(&proto_type)
}
76
77/// Map a [`FieldType`] to a protobuf type string.
78///
79/// Handles scalars, lists (`repeated`), enums (referenced by name),
80/// and object references (referenced by message name).
81fn field_type_to_proto(ft: &FieldType) -> ProtoFieldType {
82    match ft {
83        FieldType::String => ProtoFieldType::scalar("string"),
84        FieldType::Int => ProtoFieldType::scalar("int32"),
85        FieldType::Float => ProtoFieldType::scalar("double"),
86        FieldType::Boolean => ProtoFieldType::scalar("bool"),
87        FieldType::Id | FieldType::Uuid => ProtoFieldType::scalar("string"),
88        FieldType::DateTime => ProtoFieldType::scalar("google.protobuf.Timestamp"),
89        FieldType::Date | FieldType::Time | FieldType::Decimal => ProtoFieldType::scalar("string"),
90        FieldType::Json => ProtoFieldType::scalar("google.protobuf.Struct"),
91        FieldType::Vector => ProtoFieldType::repeated("double"),
92        FieldType::Scalar(_) => ProtoFieldType::scalar("string"),
93        FieldType::Enum(name) => ProtoFieldType::scalar(name),
94        FieldType::Object(name) | FieldType::Interface(name) | FieldType::Union(name) => {
95            ProtoFieldType::scalar(name)
96        },
97        FieldType::Input(name) => ProtoFieldType::scalar(name),
98        FieldType::List(inner) => {
99            let inner_proto = field_type_to_proto(inner);
100            ProtoFieldType::repeated(&inner_proto.type_name)
101        },
102        _ => ProtoFieldType::scalar("string"),
103    }
104}
105
/// A protobuf field type as the generator sees it: a type name plus whether
/// the field is declared `repeated`.
struct ProtoFieldType {
    /// Type name exactly as it is written into the `.proto` output.
    type_name: String,
    /// `true` when the field carries the `repeated` label.
    repeated:  bool,
}

impl ProtoFieldType {
    /// Shared constructor behind [`Self::scalar`] and [`Self::repeated`].
    fn with_label(name: &str, repeated: bool) -> Self {
        Self {
            type_name: name.to_owned(),
            repeated,
        }
    }

    /// A singular (non-repeated) field of type `name`.
    fn scalar(name: &str) -> Self {
        Self::with_label(name, false)
    }

    /// A `repeated` field whose element type is `name`.
    fn repeated(name: &str) -> Self {
        Self::with_label(name, true)
    }
}
127
128/// Generate a complete `.proto` file from a compiled schema.
129///
130/// Produces a proto3 service definition with:
131/// - One message per GraphQL type (fields sorted alphabetically for stable numbering)
132/// - One RPC per query (Get for single, List for list queries)
133/// - One RPC per mutation (returns `MutationResponse`)
134/// - Enum definitions from the schema
135/// - Request/response wrapper messages
136///
137/// # Errors
138///
139/// Returns an error if the schema contains no types to expose.
140pub fn generate_proto_file(
141    schema: &CompiledSchema,
142    package: &str,
143    include_types: &[String],
144    exclude_types: &[String],
145) -> String {
146    let mut out = String::new();
147    let mut imports = BTreeSet::new();
148
149    // Collect which types to expose
150    let types: Vec<_> = schema
151        .types
152        .iter()
153        .filter(|t| should_include_type(t.name.as_ref(), include_types, exclude_types))
154        .collect();
155
156    // Pre-scan for needed imports
157    for td in &types {
158        for field in &td.fields {
159            let proto = field_type_to_proto(&field.field_type);
160            if needs_well_known_import(&proto.type_name) {
161                add_import_for_type(&proto.type_name, &mut imports);
162            }
163        }
164    }
165    // Scan query/mutation arguments too
166    for q in &schema.queries {
167        for arg in &q.arguments {
168            let proto = field_type_to_proto(&arg.arg_type);
169            if needs_well_known_import(&proto.type_name) {
170                add_import_for_type(&proto.type_name, &mut imports);
171            }
172        }
173    }
174
175    // Header
176    out.push_str("syntax = \"proto3\";\n\n");
177    out.push_str(&format!("package {package};\n\n"));
178
179    // Imports
180    for imp in &imports {
181        out.push_str(&format!("import \"{imp}\";\n"));
182    }
183    if !imports.is_empty() {
184        out.push('\n');
185    }
186
187    // Enum definitions
188    for enum_def in &schema.enums {
189        generate_enum(&mut out, &enum_def.name, &enum_def.values);
190    }
191
192    // Type messages
193    for td in &types {
194        generate_message(&mut out, td.name.as_ref(), &td.fields);
195    }
196
197    // MutationResponse message (if any mutations exist)
198    if !schema.mutations.is_empty() {
199        out.push_str("message MutationResponse {\n");
200        out.push_str("  bool success = 1;\n");
201        out.push_str("  optional string id = 2;\n");
202        out.push_str("  optional string error = 3;\n");
203        out.push_str("}\n\n");
204    }
205
206    // Request/response messages for queries
207    for q in &schema.queries {
208        if !types.iter().any(|t| t.name == q.return_type) {
209            continue;
210        }
211        generate_query_messages(&mut out, q);
212    }
213
214    // Request messages for mutations
215    for m in &schema.mutations {
216        generate_mutation_request_message(&mut out, m);
217    }
218
219    // Service definition
220    let service_name = package_to_service(package);
221    out.push_str(&format!("service {service_name} {{\n"));
222
223    for q in &schema.queries {
224        if !types.iter().any(|t| t.name == q.return_type) {
225            continue;
226        }
227        let rpc_name = to_pascal_case(&q.name);
228        let req = format!("{rpc_name}Request");
229        if q.returns_list {
230            // Server-streaming RPC: each response frame is a single entity message.
231            out.push_str(&format!("  rpc {rpc_name}({req}) returns (stream {});\n", q.return_type));
232        } else {
233            out.push_str(&format!("  rpc {rpc_name}({req}) returns ({});\n", q.return_type));
234        }
235    }
236
237    for m in &schema.mutations {
238        let rpc_name = to_pascal_case(&m.name);
239        let req = format!("{rpc_name}Request");
240        out.push_str(&format!("  rpc {rpc_name}({req}) returns (MutationResponse);\n"));
241    }
242
243    out.push_str("}\n");
244
245    out
246}
247
248/// Generate a protobuf message from a type's fields.
249///
250/// Fields are sorted alphabetically for deterministic field numbering.
251fn generate_message(out: &mut String, name: &str, fields: &[FieldDefinition]) {
252    out.push_str(&format!("message {name} {{\n"));
253
254    let mut sorted_fields: Vec<&FieldDefinition> = fields.iter().collect();
255    sorted_fields.sort_by(|a, b| a.name.as_ref().cmp(b.name.as_ref()));
256
257    for (i, field) in sorted_fields.iter().enumerate() {
258        let proto = field_type_to_proto(&field.field_type);
259        let field_num = i + 1;
260        let optional = if field.nullable && !proto.repeated {
261            "optional "
262        } else {
263            ""
264        };
265        let repeated = if proto.repeated { "repeated " } else { "" };
266        out.push_str(&format!(
267            "  {optional}{repeated}{} {} = {field_num};\n",
268            proto.type_name, field.name
269        ));
270    }
271
272    out.push_str("}\n\n");
273}
274
275/// Generate a protobuf enum definition.
276fn generate_enum(
277    out: &mut String,
278    name: &str,
279    values: &[fraiseql_core::schema::EnumValueDefinition],
280) {
281    out.push_str(&format!("enum {name} {{\n"));
282    out.push_str(&format!("  {}_UNSPECIFIED = 0;\n", to_screaming_snake(name)));
283
284    for (i, val) in values.iter().enumerate() {
285        out.push_str(&format!("  {} = {};\n", val.name, i + 1));
286    }
287
288    out.push_str("}\n\n");
289}
290
291/// Generate request/response messages for a query.
292fn generate_query_messages(out: &mut String, q: &fraiseql_core::schema::QueryDefinition) {
293    let rpc_name = to_pascal_case(&q.name);
294
295    // Request message
296    out.push_str(&format!("message {rpc_name}Request {{\n"));
297
298    let mut sorted_args: Vec<_> = q.arguments.iter().collect();
299    sorted_args.sort_by(|a, b| a.name.cmp(&b.name));
300
301    for (i, arg) in sorted_args.iter().enumerate() {
302        let proto = field_type_to_proto(&arg.arg_type);
303        let optional = if arg.nullable && !proto.repeated {
304            "optional "
305        } else {
306            ""
307        };
308        let repeated = if proto.repeated { "repeated " } else { "" };
309        out.push_str(&format!(
310            "  {optional}{repeated}{} {} = {};\n",
311            proto.type_name,
312            arg.name,
313            i + 1,
314        ));
315    }
316
317    // Add standard pagination fields for list queries
318    if q.returns_list {
319        let next_num = sorted_args.len() + 1;
320        out.push_str(&format!("  optional int32 limit = {next_num};\n"));
321        out.push_str(&format!("  optional int32 offset = {};\n", next_num + 1));
322    }
323
324    out.push_str("}\n\n");
325
326    // Note: list queries use server-streaming RPCs and do not need a
327    // response wrapper message — each streamed frame is the entity type directly.
328}
329
330/// Generate a request message for a mutation.
331fn generate_mutation_request_message(
332    out: &mut String,
333    m: &fraiseql_core::schema::MutationDefinition,
334) {
335    let rpc_name = to_pascal_case(&m.name);
336
337    out.push_str(&format!("message {rpc_name}Request {{\n"));
338
339    let mut sorted_args: Vec<_> = m.arguments.iter().collect();
340    sorted_args.sort_by(|a, b| a.name.cmp(&b.name));
341
342    for (i, arg) in sorted_args.iter().enumerate() {
343        let proto = field_type_to_proto(&arg.arg_type);
344        let optional = if arg.nullable && !proto.repeated {
345            "optional "
346        } else {
347            ""
348        };
349        let repeated = if proto.repeated { "repeated " } else { "" };
350        out.push_str(&format!(
351            "  {optional}{repeated}{} {} = {};\n",
352            proto.type_name,
353            arg.name,
354            i + 1,
355        ));
356    }
357
358    out.push_str("}\n\n");
359}
360
/// Decide whether the type `name` passes the include/exclude filter.
///
/// An empty `include_types` list admits every type; otherwise the name must
/// appear in it. A name listed in `exclude_types` is always rejected.
fn should_include_type(name: &str, include_types: &[String], exclude_types: &[String]) -> bool {
    let listed_in = |names: &[String]| names.iter().any(|candidate| candidate == name);

    let admitted = include_types.is_empty() || listed_in(include_types);
    admitted && !listed_in(exclude_types)
}
368
/// Turn a snake_case or camelCase identifier into PascalCase.
///
/// The name is split on `_`; the first character of each segment is
/// upper-cased and the rest of the segment is kept as written. Empty segments
/// (from leading, trailing or doubled underscores) contribute nothing.
fn to_pascal_case(name: &str) -> String {
    let mut pascal = String::with_capacity(name.len());

    for segment in name.split('_') {
        let mut rest = segment.chars();
        if let Some(initial) = rest.next() {
            pascal.extend(initial.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }

    pascal
}
385
/// Convert a PascalCase name to SCREAMING_SNAKE_CASE.
///
/// A word boundary (and therefore an underscore) is placed before an
/// uppercase letter when either:
/// - the previous character is a lowercase letter or an ASCII digit
///   (`OrderStatus` → `ORDER_STATUS`), or
/// - the previous character is uppercase and the next one is lowercase, i.e.
///   the letter starts a new word after an acronym
///   (`HTTPStatus` → `HTTP_STATUS`).
///
/// Consecutive uppercase letters are otherwise kept together, and an
/// underscore already present in the input is not doubled. Only ASCII letters
/// are upper-cased; other characters are copied through unchanged.
fn to_screaming_snake(name: &str) -> String {
    let chars: Vec<char> = name.chars().collect();
    let mut result = String::with_capacity(name.len() + 4);

    for (i, &c) in chars.iter().enumerate() {
        if c.is_uppercase() && i > 0 {
            let prev = chars[i - 1];
            let next_is_lowercase = matches!(chars.get(i + 1), Some(next) if next.is_lowercase());
            let after_word = prev.is_lowercase() || prev.is_ascii_digit();
            let after_acronym = prev.is_uppercase() && next_is_lowercase;
            if after_word || after_acronym {
                result.push('_');
            }
        }
        result.push(c.to_ascii_uppercase());
    }

    result
}
397
398/// Extract service name from package (e.g., "fraiseql.v1" → "FraiseQLService").
399fn package_to_service(package: &str) -> String {
400    let parts: Vec<&str> = package.split('.').collect();
401    let base = parts.first().copied().unwrap_or("FraiseQL");
402    let mut service = to_pascal_case(base);
403    service.push_str("Service");
404    service
405}
406
/// Record the `.proto` import path required by a well-known protobuf type.
///
/// Types that are not well-known leave `imports` untouched.
fn add_import_for_type(proto_type: &str, imports: &mut BTreeSet<String>) {
    let import_path = match proto_type {
        "google.protobuf.Timestamp" => "google/protobuf/timestamp.proto",
        "google.protobuf.Struct" => "google/protobuf/struct.proto",
        _ => return,
    };

    imports.insert(import_path.to_owned());
}
419
// Unit tests for the type-mapping helpers and for the `.proto` text produced
// by `generate_proto_file`. Output is checked with substring assertions
// rather than by comparing whole files.
#[cfg(test)]
mod tests {
    use fraiseql_core::schema::{
        CompiledSchema, EnumDefinition, EnumValueDefinition, FieldDenyPolicy, FieldType,
        TypeDefinition,
    };

    use super::*;

    /// Build a field with the given name, type and nullability; every other
    /// attribute is left at `None` / its default.
    fn make_field(name: &str, ft: FieldType, nullable: bool) -> FieldDefinition {
        FieldDefinition {
            name: name.into(),
            field_type: ft,
            nullable,
            description: None,
            default_value: None,
            vector_config: None,
            alias: None,
            deprecation: None,
            requires_scope: None,
            on_deny: FieldDenyPolicy::default(),
            encryption: None,
        }
    }

    /// Build a type definition with the given name and fields.
    fn make_type(name: &str, fields: Vec<FieldDefinition>) -> TypeDefinition {
        TypeDefinition {
            name: name.into(),
            sql_source: String::new().into(),
            jsonb_column: "data".to_string(),
            fields,
            description: None,
            sql_projection_hint: None,
            implements: vec![],
            requires_role: None,
            is_error: false,
            relay: false,
            relationships: Vec::new(),
        }
    }

    /// Build a query via JSON deserialization to leverage `#[serde(default)]`.
    fn make_query(
        name: &str,
        return_type: &str,
        returns_list: bool,
    ) -> fraiseql_core::schema::QueryDefinition {
        let json = serde_json::json!({
            "name": name,
            "return_type": return_type,
            "returns_list": returns_list,
        });
        serde_json::from_value(json).expect("test query definition")
    }

    /// Build a mutation via JSON deserialization.
    fn make_mutation(
        name: &str,
        args: Vec<fraiseql_core::schema::ArgumentDefinition>,
    ) -> fraiseql_core::schema::MutationDefinition {
        let mut m: fraiseql_core::schema::MutationDefinition =
            serde_json::from_value(serde_json::json!({
                "name": name,
                "return_type": "MutationResponse",
            }))
            .expect("test mutation definition");
        // Arguments are set after deserialization rather than through JSON.
        m.arguments = args;
        m
    }

    /// Build an argument with the given name, type and nullability.
    fn make_arg(
        name: &str,
        ft: FieldType,
        nullable: bool,
    ) -> fraiseql_core::schema::ArgumentDefinition {
        fraiseql_core::schema::ArgumentDefinition {
            name: name.to_string(),
            arg_type: ft,
            nullable,
            default_value: None,
            description: None,
            deprecation: None,
        }
    }

    // ── graphql_to_proto_type ───────────────────────────────────────────

    #[test]
    fn test_proto_type_string() {
        assert_eq!(graphql_to_proto_type("String"), "string");
    }

    #[test]
    fn test_proto_type_int() {
        assert_eq!(graphql_to_proto_type("Int"), "int32");
    }

    #[test]
    fn test_proto_type_float() {
        assert_eq!(graphql_to_proto_type("Float"), "double");
    }

    #[test]
    fn test_proto_type_boolean() {
        assert_eq!(graphql_to_proto_type("Boolean"), "bool");
    }

    #[test]
    fn test_proto_type_id() {
        assert_eq!(graphql_to_proto_type("ID"), "string");
    }

    #[test]
    fn test_proto_type_datetime() {
        assert_eq!(graphql_to_proto_type("DateTime"), "google.protobuf.Timestamp");
    }

    #[test]
    fn test_proto_type_date() {
        assert_eq!(graphql_to_proto_type("Date"), "string");
    }

    #[test]
    fn test_proto_type_bigint() {
        assert_eq!(graphql_to_proto_type("BigInt"), "int64");
    }

    #[test]
    fn test_proto_type_json() {
        assert_eq!(graphql_to_proto_type("JSON"), "google.protobuf.Struct");
    }

    #[test]
    fn test_proto_type_custom_scalar_fallback() {
        assert_eq!(graphql_to_proto_type("Email"), "string");
        assert_eq!(graphql_to_proto_type("PhoneNumber"), "string");
    }

    // ── graphql_to_row_view_type ────────────────────────────────────────

    #[test]
    fn test_row_view_type_string() {
        assert_eq!(graphql_to_row_view_type("String"), RowViewColumnType::Text);
    }

    #[test]
    fn test_row_view_type_int() {
        assert_eq!(graphql_to_row_view_type("Int"), RowViewColumnType::Int32);
    }

    #[test]
    fn test_row_view_type_bigint() {
        assert_eq!(graphql_to_row_view_type("BigInt"), RowViewColumnType::Int64);
    }

    #[test]
    fn test_row_view_type_float() {
        assert_eq!(graphql_to_row_view_type("Float"), RowViewColumnType::Float64);
    }

    #[test]
    fn test_row_view_type_boolean() {
        assert_eq!(graphql_to_row_view_type("Boolean"), RowViewColumnType::Boolean);
    }

    #[test]
    fn test_row_view_type_id() {
        assert_eq!(graphql_to_row_view_type("ID"), RowViewColumnType::Uuid);
    }

    #[test]
    fn test_row_view_type_datetime() {
        assert_eq!(graphql_to_row_view_type("DateTime"), RowViewColumnType::Timestamptz);
    }

    #[test]
    fn test_row_view_type_json() {
        assert_eq!(graphql_to_row_view_type("JSON"), RowViewColumnType::Json);
    }

    #[test]
    fn test_row_view_type_date() {
        assert_eq!(graphql_to_row_view_type("Date"), RowViewColumnType::Date);
    }

    #[test]
    fn test_row_view_type_custom_scalar_fallback() {
        assert_eq!(graphql_to_row_view_type("Email"), RowViewColumnType::Text);
    }

    // ── needs_well_known_import ─────────────────────────────────────────

    #[test]
    fn test_needs_import_timestamp() {
        assert!(needs_well_known_import("google.protobuf.Timestamp"));
    }

    #[test]
    fn test_needs_import_struct() {
        assert!(needs_well_known_import("google.protobuf.Struct"));
    }

    #[test]
    fn test_no_import_for_scalars() {
        assert!(!needs_well_known_import("string"));
        assert!(!needs_well_known_import("int32"));
        assert!(!needs_well_known_import("bool"));
    }

    // ── to_pascal_case ──────────────────────────────────────────────────

    #[test]
    fn test_pascal_case_snake() {
        assert_eq!(to_pascal_case("get_user"), "GetUser");
    }

    #[test]
    fn test_pascal_case_single() {
        assert_eq!(to_pascal_case("users"), "Users");
    }

    #[test]
    fn test_pascal_case_already() {
        assert_eq!(to_pascal_case("User"), "User");
    }

    // ── to_screaming_snake ──────────────────────────────────────────────

    #[test]
    fn test_screaming_snake() {
        assert_eq!(to_screaming_snake("OrderStatus"), "ORDER_STATUS");
    }

    // ── should_include_type ─────────────────────────────────────────────

    #[test]
    fn test_include_all_when_empty() {
        assert!(should_include_type("User", &[], &[]));
    }

    #[test]
    fn test_include_whitelist() {
        assert!(should_include_type("User", &["User".to_string()], &[]));
        assert!(!should_include_type("Post", &["User".to_string()], &[]));
    }

    #[test]
    fn test_exclude_blacklist() {
        assert!(!should_include_type("Secret", &[], &["Secret".to_string()]));
        assert!(should_include_type("User", &[], &["Secret".to_string()]));
    }

    // ── generate_proto_file ─────────────────────────────────────────────

    #[test]
    fn test_generate_proto_basic_type() {
        let mut schema = CompiledSchema::new();
        schema.types.push(make_type(
            "User",
            vec![
                make_field("id", FieldType::Id, false),
                make_field("name", FieldType::String, false),
                make_field("email", FieldType::String, true),
            ],
        ));
        schema.queries.push(make_query("get_user", "User", false));
        schema.queries.push(make_query("list_users", "User", true));

        let proto = generate_proto_file(&schema, "fraiseql.v1", &[], &[]);

        assert!(proto.contains("syntax = \"proto3\";"));
        assert!(proto.contains("package fraiseql.v1;"));
        assert!(proto.contains("message User {"));
        // Fields sorted alphabetically: email=1, id=2, name=3
        assert!(proto.contains("optional string email = 1;"));
        assert!(proto.contains("string id = 2;"));
        assert!(proto.contains("string name = 3;"));
        // Service
        assert!(proto.contains("service FraiseqlService {"));
        assert!(proto.contains("rpc GetUser(GetUserRequest) returns (User);"));
        assert!(proto.contains("rpc ListUsers(ListUsersRequest) returns (stream User);"));
    }

    #[test]
    fn test_generate_proto_with_datetime_import() {
        let mut schema = CompiledSchema::new();
        schema.types.push(make_type(
            "Post",
            vec![
                make_field("id", FieldType::Id, false),
                make_field("created_at", FieldType::DateTime, false),
            ],
        ));
        schema.queries.push(make_query("get_post", "Post", false));

        let proto = generate_proto_file(&schema, "fraiseql.v1", &[], &[]);

        // A DateTime field must pull in the Timestamp well-known type.
        assert!(proto.contains("import \"google/protobuf/timestamp.proto\";"));
        assert!(proto.contains("google.protobuf.Timestamp created_at = 1;"));
    }

    #[test]
    fn test_generate_proto_with_mutations() {
        let mut schema = CompiledSchema::new();
        schema
            .types
            .push(make_type("User", vec![make_field("id", FieldType::Id, false)]));
        schema.mutations.push(make_mutation(
            "create_user",
            vec![
                make_arg("name", FieldType::String, false),
                make_arg("email", FieldType::String, false),
            ],
        ));

        let proto = generate_proto_file(&schema, "fraiseql.v1", &[], &[]);

        assert!(proto.contains("message MutationResponse {"));
        assert!(proto.contains("message CreateUserRequest {"));
        // Args sorted: email=1, name=2
        assert!(proto.contains("string email = 1;"));
        assert!(proto.contains("string name = 2;"));
        assert!(proto.contains("rpc CreateUser(CreateUserRequest) returns (MutationResponse);"));
    }

    #[test]
    fn test_generate_proto_with_enum() {
        let mut schema = CompiledSchema::new();
        schema.enums.push(EnumDefinition {
            name:        "OrderStatus".to_string(),
            values:      vec![
                EnumValueDefinition {
                    name:        "PENDING".to_string(),
                    description: None,
                    deprecation: None,
                },
                EnumValueDefinition {
                    name:        "SHIPPED".to_string(),
                    description: None,
                    deprecation: None,
                },
            ],
            description: None,
        });
        schema.types.push(make_type(
            "Order",
            vec![
                make_field("id", FieldType::Id, false),
                make_field("status", FieldType::Enum("OrderStatus".to_string()), false),
            ],
        ));
        schema.queries.push(make_query("get_order", "Order", false));

        let proto = generate_proto_file(&schema, "fraiseql.v1", &[], &[]);

        assert!(proto.contains("enum OrderStatus {"));
        // Synthetic zero value, then schema values numbered from 1.
        assert!(proto.contains("ORDER_STATUS_UNSPECIFIED = 0;"));
        assert!(proto.contains("PENDING = 1;"));
        assert!(proto.contains("SHIPPED = 2;"));
        // The enum-typed field references the enum by name (id=1, status=2).
        assert!(proto.contains("OrderStatus status = 2;"));
    }

    #[test]
    fn test_generate_proto_exclude_types() {
        let mut schema = CompiledSchema::new();
        schema
            .types
            .push(make_type("User", vec![make_field("id", FieldType::Id, false)]));
        schema
            .types
            .push(make_type("Secret", vec![make_field("id", FieldType::Id, false)]));
        schema.queries.push(make_query("get_user", "User", false));
        schema.queries.push(make_query("get_secret", "Secret", false));

        let proto = generate_proto_file(&schema, "fraiseql.v1", &[], &["Secret".to_string()]);

        // Excluding a type removes both its message and the RPCs returning it.
        assert!(proto.contains("message User {"));
        assert!(!proto.contains("message Secret {"));
        assert!(proto.contains("rpc GetUser"));
        assert!(!proto.contains("rpc GetSecret"));
    }

    #[test]
    fn test_generate_proto_list_query_pagination() {
        let mut schema = CompiledSchema::new();
        schema
            .types
            .push(make_type("User", vec![make_field("id", FieldType::Id, false)]));
        schema.queries.push(make_query("list_users", "User", true));

        let proto = generate_proto_file(&schema, "fraiseql.v1", &[], &[]);

        // Pagination fields added to list request
        assert!(proto.contains("optional int32 limit = 1;"));
        assert!(proto.contains("optional int32 offset = 2;"));
        // Server-streaming: no ListUsersResponse wrapper, returns stream User
        assert!(proto.contains("rpc ListUsers(ListUsersRequest) returns (stream User);"));
        assert!(!proto.contains("ListUsersResponse"), "No response wrapper for streaming RPCs");
    }

    #[test]
    fn test_generate_proto_nullable_field() {
        let mut schema = CompiledSchema::new();
        schema.types.push(make_type(
            "User",
            vec![
                make_field("name", FieldType::String, false),
                make_field("bio", FieldType::String, true),
            ],
        ));
        schema.queries.push(make_query("get_user", "User", false));

        let proto = generate_proto_file(&schema, "fraiseql.v1", &[], &[]);

        // Nullable fields get `optional`; alphabetical order puts bio first.
        assert!(proto.contains("optional string bio = 1;"));
        assert!(proto.contains("string name = 2;"));
    }

    #[test]
    fn test_generate_proto_list_field() {
        let mut schema = CompiledSchema::new();
        schema.types.push(make_type(
            "User",
            vec![make_field(
                "tags",
                FieldType::List(Box::new(FieldType::String)),
                false,
            )],
        ));
        schema.queries.push(make_query("get_user", "User", false));

        let proto = generate_proto_file(&schema, "fraiseql.v1", &[], &[]);

        // List fields are emitted with the `repeated` label.
        assert!(proto.contains("repeated string tags = 1;"));
    }
}
851        let proto = generate_proto_file(&schema, "fraiseql.v1", &[], &[]);
852
853        assert!(proto.contains("repeated string tags = 1;"));
854    }
855}