Skip to main content

fraiseql_cli/commands/
generate_proto.rs

1//! `generate-proto` command: produce service.proto, vr_migrations.sql, and descriptor.binpb.
2
3use std::{fs, path::Path};
4
5use anyhow::Context;
6use fraiseql_core::{
7    db::dialect::{MySqlDialect, PostgresDialect, SqlDialect, SqlServerDialect, SqliteDialect},
8    schema::CompiledSchema,
9};
10
11use crate::{
12    codegen::{proto_gen, row_views},
13    output::OutputFormatter,
14};
15
16/// Resolve a SQL dialect from its CLI string name.
17///
18/// # Errors
19///
20/// Returns an error if the dialect name is not recognised.
21fn resolve_dialect(name: &str) -> anyhow::Result<Box<dyn SqlDialect>> {
22    match name {
23        "postgres" | "postgresql" => Ok(Box::new(PostgresDialect)),
24        "mysql" => Ok(Box::new(MySqlDialect)),
25        "sqlite" => Ok(Box::new(SqliteDialect)),
26        "sqlserver" => Ok(Box::new(SqlServerDialect)),
27        other => Err(anyhow::anyhow!(
28            "Unknown dialect '{other}'. Expected: postgres, mysql, sqlite, sqlserver"
29        )),
30    }
31}
32
33/// Build a serialized `FileDescriptorSet` from the generated proto source.
34///
35/// Constructs a [`prost_types::FileDescriptorProto`] with package, syntax,
36/// and dependency information, then encodes it into a binary protobuf that
37/// gRPC reflection servers can serve at runtime.
38///
39/// # Errors
40///
41/// Returns an error if protobuf encoding fails.
42fn build_file_descriptor_set(proto_source: &str, package: &str) -> anyhow::Result<Vec<u8>> {
43    use prost::Message;
44    use prost_types::{FileDescriptorProto, FileDescriptorSet};
45
46    let mut file = FileDescriptorProto {
47        name: Some("service.proto".to_string()),
48        package: Some(package.to_string()),
49        syntax: Some("proto3".to_string()),
50        ..FileDescriptorProto::default()
51    };
52
53    // Add well-known type dependencies detected in the proto source.
54    if proto_source.contains("google/protobuf/timestamp.proto") {
55        file.dependency.push("google/protobuf/timestamp.proto".to_string());
56    }
57    if proto_source.contains("google/protobuf/struct.proto") {
58        file.dependency.push("google/protobuf/struct.proto".to_string());
59    }
60
61    let fds = FileDescriptorSet { file: vec![file] };
62
63    let mut buf = Vec::with_capacity(fds.encoded_len());
64    fds.encode(&mut buf).context("Failed to encode FileDescriptorSet")?;
65    Ok(buf)
66}
67
68/// Run the `generate-proto` command.
69///
70/// Reads a compiled schema and writes three files to the output directory:
71/// - `service.proto` — proto3 service definition
72/// - `vr_migrations.sql` — row-shaped view DDL for the gRPC transport
73/// - `descriptor.binpb` — serialized `FileDescriptorSet` for gRPC reflection
74///
75/// # Errors
76///
77/// Returns an error if the schema cannot be loaded, the dialect is unknown,
78/// or the output files cannot be written.
79pub fn run(
80    schema_path: &str,
81    output_dir: &str,
82    package: &str,
83    dialect_name: &str,
84    formatter: &OutputFormatter,
85) -> anyhow::Result<()> {
86    formatter.progress("Loading compiled schema...");
87
88    let content = fs::read_to_string(schema_path).context("Failed to read compiled schema file")?;
89    let schema: CompiledSchema =
90        serde_json::from_str(&content).context("Failed to parse compiled schema JSON")?;
91
92    let dialect = resolve_dialect(dialect_name)?;
93
94    // Resolve include/exclude from grpc config if present
95    let (include_types, exclude_types) = schema
96        .grpc_config
97        .as_ref()
98        .map(|g| (g.include_types.clone(), g.exclude_types.clone()))
99        .unwrap_or_default();
100
101    // 1. Generate service.proto
102    formatter.progress("Generating service.proto...");
103    let proto_source =
104        proto_gen::generate_proto_file(&schema, package, &include_types, &exclude_types);
105
106    // 2. Generate vr_migrations.sql
107    formatter.progress("Generating vr_migrations.sql...");
108    let row_view_ddl = row_views::generate_all_row_views(
109        dialect.as_ref(),
110        &schema.types,
111        &include_types,
112        &exclude_types,
113    );
114
115    // 3. Build descriptor.binpb
116    formatter.progress("Building descriptor.binpb...");
117    let descriptor_bytes = build_file_descriptor_set(&proto_source, package)?;
118
119    // Write output files
120    let out_path = Path::new(output_dir);
121    fs::create_dir_all(out_path).context("Failed to create output directory")?;
122
123    let proto_path = out_path.join("service.proto");
124    fs::write(&proto_path, &proto_source)
125        .with_context(|| format!("Failed to write {}", proto_path.display()))?;
126
127    let sql_path = out_path.join("vr_migrations.sql");
128    fs::write(&sql_path, &row_view_ddl)
129        .with_context(|| format!("Failed to write {}", sql_path.display()))?;
130
131    let desc_path = out_path.join("descriptor.binpb");
132    fs::write(&desc_path, &descriptor_bytes)
133        .with_context(|| format!("Failed to write {}", desc_path.display()))?;
134
135    formatter.section("Generated files");
136    formatter.progress(&format!("  {}", proto_path.display()));
137    formatter.progress(&format!("  {}", sql_path.display()));
138    formatter.progress(&format!("  {}", desc_path.display()));
139
140    Ok(())
141}
142
#[cfg(test)]
mod tests {
    use fraiseql_core::schema::{
        CompiledSchema, EnumDefinition, EnumValueDefinition, FieldDefinition, FieldDenyPolicy,
        FieldType, TypeDefinition,
    };
    use prost::Message;
    use prost_types::FileDescriptorSet;
    use tempfile::TempDir;

    use super::*;

    const TIMESTAMP_PROTO: &str = "google/protobuf/timestamp.proto";
    const STRUCT_PROTO: &str = "google/protobuf/struct.proto";

    fn make_field(name: &str, ft: FieldType, nullable: bool) -> FieldDefinition {
        FieldDefinition {
            name: name.into(),
            field_type: ft,
            nullable,
            description: None,
            default_value: None,
            vector_config: None,
            alias: None,
            deprecation: None,
            requires_scope: None,
            on_deny: FieldDenyPolicy::default(),
            encryption: None,
        }
    }

    fn make_type(name: &str, fields: Vec<FieldDefinition>) -> TypeDefinition {
        TypeDefinition {
            name: name.into(),
            sql_source: name.to_lowercase().into(),
            jsonb_column: "data".to_string(),
            fields,
            description: None,
            sql_projection_hint: None,
            implements: vec![],
            requires_role: None,
            is_error: false,
            relay: false,
            relationships: Vec::new(),
        }
    }

    fn make_query(
        name: &str,
        return_type: &str,
        returns_list: bool,
    ) -> fraiseql_core::schema::QueryDefinition {
        serde_json::from_value(serde_json::json!({
            "name": name,
            "return_type": return_type,
            "returns_list": returns_list,
        }))
        .expect("test query definition")
    }

    /// A schema with a single `User { id, name }` type and a `get_user` query.
    fn user_schema() -> CompiledSchema {
        let mut schema = CompiledSchema::new();
        schema.types.push(make_type(
            "User",
            vec![
                make_field("id", FieldType::Id, false),
                make_field("name", FieldType::String, false),
            ],
        ));
        schema.queries.push(make_query("get_user", "User", false));
        schema
    }

    /// Serialize `schema` to `dir/schema.compiled.json` and return that path as a string.
    fn write_schema_file(dir: &Path, schema: &CompiledSchema) -> String {
        let json = serde_json::to_string_pretty(schema).expect("serialize schema");
        let path = dir.join("schema.compiled.json");
        fs::write(&path, json).expect("write schema file");
        path.to_string_lossy().into_owned()
    }

    /// Decode descriptor bytes so tests assert on structure rather than raw bytes.
    fn decode_descriptor(bytes: &[u8]) -> FileDescriptorSet {
        FileDescriptorSet::decode(bytes).expect("decode FileDescriptorSet")
    }

    // ── resolve_dialect ──────────────────────────────────────────────────

    #[test]
    fn test_resolve_dialect_postgres() {
        assert!(resolve_dialect("postgres").is_ok());
        assert!(resolve_dialect("postgresql").is_ok());
    }

    #[test]
    fn test_resolve_dialect_mysql() {
        assert!(resolve_dialect("mysql").is_ok());
    }

    #[test]
    fn test_resolve_dialect_sqlite() {
        assert!(resolve_dialect("sqlite").is_ok());
    }

    #[test]
    fn test_resolve_dialect_sqlserver() {
        assert!(resolve_dialect("sqlserver").is_ok());
    }

    #[test]
    fn test_resolve_dialect_unknown() {
        // `Box<dyn SqlDialect>` is not `Debug`, so match instead of `unwrap_err`.
        match resolve_dialect("oracle") {
            Ok(_) => panic!("expected error for oracle"),
            Err(e) => assert!(e.to_string().contains("Unknown dialect")),
        }
    }

    // ── build_file_descriptor_set ───────────────────────────────────────

    #[test]
    fn test_descriptor_bytes_non_empty() {
        let proto = "syntax = \"proto3\";\npackage test.v1;\n";
        let bytes = build_file_descriptor_set(proto, "test.v1").expect("encode");
        assert!(!bytes.is_empty());
    }

    #[test]
    fn test_descriptor_file_metadata() {
        let proto = "syntax = \"proto3\";\npackage test.v1;\n";
        let fds = decode_descriptor(&build_file_descriptor_set(proto, "test.v1").expect("encode"));
        assert_eq!(fds.file.len(), 1);
        let file = &fds.file[0];
        assert_eq!(file.name.as_deref(), Some("service.proto"));
        assert_eq!(file.package.as_deref(), Some("test.v1"));
        assert_eq!(file.syntax.as_deref(), Some("proto3"));
    }

    #[test]
    fn test_descriptor_includes_timestamp_dep() {
        let proto = "import \"google/protobuf/timestamp.proto\";\n";
        let fds = decode_descriptor(&build_file_descriptor_set(proto, "test.v1").expect("encode"));
        assert_eq!(fds.file[0].dependency, vec![TIMESTAMP_PROTO.to_string()]);
    }

    #[test]
    fn test_descriptor_includes_struct_dep() {
        let proto = "import \"google/protobuf/struct.proto\";\n";
        let fds = decode_descriptor(&build_file_descriptor_set(proto, "test.v1").expect("encode"));
        assert_eq!(fds.file[0].dependency, vec![STRUCT_PROTO.to_string()]);
    }

    #[test]
    fn test_descriptor_includes_both_deps_in_order() {
        let proto = "import \"google/protobuf/timestamp.proto\";\n\
                     import \"google/protobuf/struct.proto\";\n";
        let fds = decode_descriptor(&build_file_descriptor_set(proto, "test.v1").expect("encode"));
        assert_eq!(
            fds.file[0].dependency,
            vec![TIMESTAMP_PROTO.to_string(), STRUCT_PROTO.to_string()]
        );
    }

    #[test]
    fn test_descriptor_no_deps_when_absent() {
        let proto = "syntax = \"proto3\";\n";
        let fds = decode_descriptor(&build_file_descriptor_set(proto, "test.v1").expect("encode"));
        assert!(fds.file[0].dependency.is_empty());
    }

    // ── run (integration) ────────────────────────────────────────────────

    #[test]
    fn test_run_generates_three_files() {
        let tmp = TempDir::new().expect("temp dir");
        let schema_path = write_schema_file(tmp.path(), &user_schema());
        let out_dir = tmp.path().join("out");
        let formatter = OutputFormatter::new(false, true);

        run(&schema_path, &out_dir.to_string_lossy(), "test.v1", "postgres", &formatter)
            .expect("run should succeed");

        assert!(out_dir.join("service.proto").exists());
        assert!(out_dir.join("vr_migrations.sql").exists());
        assert!(out_dir.join("descriptor.binpb").exists());

        // Verify proto content
        let proto = fs::read_to_string(out_dir.join("service.proto")).expect("read proto");
        assert!(proto.contains("package test.v1;"));
        assert!(proto.contains("message User {"));
        assert!(proto.contains("service TestService {"));

        // Verify descriptor metadata
        let desc = fs::read(out_dir.join("descriptor.binpb")).expect("read descriptor");
        let fds = decode_descriptor(&desc);
        assert_eq!(fds.file.len(), 1);
        assert_eq!(fds.file[0].package.as_deref(), Some("test.v1"));
    }

    #[test]
    fn test_run_with_enum_and_datetime() {
        let tmp = TempDir::new().expect("temp dir");
        let mut schema = CompiledSchema::new();
        schema.enums.push(EnumDefinition {
            name:        "Status".to_string(),
            values:      vec![EnumValueDefinition {
                name:        "ACTIVE".to_string(),
                description: None,
                deprecation: None,
            }],
            description: None,
        });
        schema.types.push(make_type(
            "Event",
            vec![
                make_field("id", FieldType::Id, false),
                make_field("created_at", FieldType::DateTime, false),
                make_field("status", FieldType::Enum("Status".to_string()), false),
            ],
        ));
        schema.queries.push(make_query("get_event", "Event", false));

        let schema_path = write_schema_file(tmp.path(), &schema);
        let out_dir = tmp.path().join("out");
        let formatter = OutputFormatter::new(false, true);

        run(&schema_path, &out_dir.to_string_lossy(), "fraiseql.v1", "postgres", &formatter)
            .expect("run should succeed");

        let proto = fs::read_to_string(out_dir.join("service.proto")).expect("read proto");
        assert!(proto.contains("import \"google/protobuf/timestamp.proto\""));
        assert!(proto.contains("enum Status {"));

        // Descriptor should include timestamp dependency
        let desc = fs::read(out_dir.join("descriptor.binpb")).expect("read descriptor");
        let fds = decode_descriptor(&desc);
        assert!(fds.file[0].dependency.iter().any(|dep| dep == TIMESTAMP_PROTO));
    }

    #[test]
    fn test_run_mysql_dialect() {
        let tmp = TempDir::new().expect("temp dir");
        let schema_path = write_schema_file(tmp.path(), &user_schema());
        let out_dir = tmp.path().join("out");
        let formatter = OutputFormatter::new(false, true);

        run(&schema_path, &out_dir.to_string_lossy(), "test.v1", "mysql", &formatter)
            .expect("run with mysql should succeed");

        let sql = fs::read_to_string(out_dir.join("vr_migrations.sql")).expect("read sql");
        assert!(sql.contains("JSON_EXTRACT"));
    }

    #[test]
    fn test_run_bad_schema_path() {
        let tmp = TempDir::new().expect("temp dir");
        // A path inside the temp dir is guaranteed not to exist on any platform.
        let missing = tmp.path().join("missing.compiled.json");
        let out_dir = tmp.path().join("out");
        let formatter = OutputFormatter::new(false, true);

        let err = run(
            &missing.to_string_lossy(),
            &out_dir.to_string_lossy(),
            "test.v1",
            "postgres",
            &formatter,
        )
        .expect_err("missing schema file should fail");
        assert!(err.to_string().contains("Failed to read compiled schema file"));
    }

    #[test]
    fn test_run_bad_dialect() {
        let tmp = TempDir::new().expect("temp dir");
        let mut schema = CompiledSchema::new();
        schema
            .types
            .push(make_type("User", vec![make_field("id", FieldType::Id, false)]));

        let schema_path = write_schema_file(tmp.path(), &schema);
        let out_dir = tmp.path().join("out");
        let formatter = OutputFormatter::new(false, true);

        let err = run(&schema_path, &out_dir.to_string_lossy(), "test.v1", "oracle", &formatter)
            .expect_err("oracle dialect should be rejected");
        assert!(err.to_string().contains("Unknown dialect"));
        assert!(!out_dir.exists(), "no output should be written for an unknown dialect");
    }
}