1use std::{fs, path::Path};
4
5use anyhow::Context;
6use fraiseql_core::{
7 db::dialect::{MySqlDialect, PostgresDialect, SqlDialect, SqlServerDialect, SqliteDialect},
8 schema::CompiledSchema,
9};
10
11use crate::{
12 codegen::{proto_gen, row_views},
13 output::OutputFormatter,
14};
15
16fn resolve_dialect(name: &str) -> anyhow::Result<Box<dyn SqlDialect>> {
22 match name {
23 "postgres" | "postgresql" => Ok(Box::new(PostgresDialect)),
24 "mysql" => Ok(Box::new(MySqlDialect)),
25 "sqlite" => Ok(Box::new(SqliteDialect)),
26 "sqlserver" => Ok(Box::new(SqlServerDialect)),
27 other => Err(anyhow::anyhow!(
28 "Unknown dialect '{other}'. Expected: postgres, mysql, sqlite, sqlserver"
29 )),
30 }
31}
32
33fn build_file_descriptor_set(proto_source: &str, package: &str) -> anyhow::Result<Vec<u8>> {
43 use prost::Message;
44 use prost_types::{FileDescriptorProto, FileDescriptorSet};
45
46 let mut file = FileDescriptorProto {
47 name: Some("service.proto".to_string()),
48 package: Some(package.to_string()),
49 syntax: Some("proto3".to_string()),
50 ..FileDescriptorProto::default()
51 };
52
53 if proto_source.contains("google/protobuf/timestamp.proto") {
55 file.dependency.push("google/protobuf/timestamp.proto".to_string());
56 }
57 if proto_source.contains("google/protobuf/struct.proto") {
58 file.dependency.push("google/protobuf/struct.proto".to_string());
59 }
60
61 let fds = FileDescriptorSet { file: vec![file] };
62
63 let mut buf = Vec::with_capacity(fds.encoded_len());
64 fds.encode(&mut buf).context("Failed to encode FileDescriptorSet")?;
65 Ok(buf)
66}
67
68pub fn run(
80 schema_path: &str,
81 output_dir: &str,
82 package: &str,
83 dialect_name: &str,
84 formatter: &OutputFormatter,
85) -> anyhow::Result<()> {
86 formatter.progress("Loading compiled schema...");
87
88 let content = fs::read_to_string(schema_path).context("Failed to read compiled schema file")?;
89 let schema: CompiledSchema =
90 serde_json::from_str(&content).context("Failed to parse compiled schema JSON")?;
91
92 let dialect = resolve_dialect(dialect_name)?;
93
94 let (include_types, exclude_types) = schema
96 .grpc_config
97 .as_ref()
98 .map(|g| (g.include_types.clone(), g.exclude_types.clone()))
99 .unwrap_or_default();
100
101 formatter.progress("Generating service.proto...");
103 let proto_source =
104 proto_gen::generate_proto_file(&schema, package, &include_types, &exclude_types);
105
106 formatter.progress("Generating vr_migrations.sql...");
108 let row_view_ddl = row_views::generate_all_row_views(
109 dialect.as_ref(),
110 &schema.types,
111 &include_types,
112 &exclude_types,
113 );
114
115 formatter.progress("Building descriptor.binpb...");
117 let descriptor_bytes = build_file_descriptor_set(&proto_source, package)?;
118
119 let out_path = Path::new(output_dir);
121 fs::create_dir_all(out_path).context("Failed to create output directory")?;
122
123 let proto_path = out_path.join("service.proto");
124 fs::write(&proto_path, &proto_source)
125 .with_context(|| format!("Failed to write {}", proto_path.display()))?;
126
127 let sql_path = out_path.join("vr_migrations.sql");
128 fs::write(&sql_path, &row_view_ddl)
129 .with_context(|| format!("Failed to write {}", sql_path.display()))?;
130
131 let desc_path = out_path.join("descriptor.binpb");
132 fs::write(&desc_path, &descriptor_bytes)
133 .with_context(|| format!("Failed to write {}", desc_path.display()))?;
134
135 formatter.section("Generated files");
136 formatter.progress(&format!(" {}", proto_path.display()));
137 formatter.progress(&format!(" {}", sql_path.display()));
138 formatter.progress(&format!(" {}", desc_path.display()));
139
140 Ok(())
141}
142
#[cfg(test)]
mod tests {
    use std::io::Write as _;

    use fraiseql_core::schema::{
        CompiledSchema, EnumDefinition, EnumValueDefinition, FieldDefinition, FieldDenyPolicy,
        FieldType, TypeDefinition,
    };
    use prost::Message;
    use prost_types::{FileDescriptorProto, FileDescriptorSet};
    use tempfile::TempDir;

    use super::*;

    fn make_field(name: &str, ft: FieldType, nullable: bool) -> FieldDefinition {
        FieldDefinition {
            name: name.into(),
            field_type: ft,
            nullable,
            description: None,
            default_value: None,
            vector_config: None,
            alias: None,
            deprecation: None,
            requires_scope: None,
            on_deny: FieldDenyPolicy::default(),
            encryption: None,
        }
    }

    fn make_type(name: &str, fields: Vec<FieldDefinition>) -> TypeDefinition {
        TypeDefinition {
            name: name.into(),
            sql_source: name.to_lowercase().into(),
            jsonb_column: "data".to_string(),
            fields,
            description: None,
            sql_projection_hint: None,
            implements: vec![],
            requires_role: None,
            is_error: false,
            relay: false,
            relationships: Vec::new(),
        }
    }

    fn make_query(
        name: &str,
        return_type: &str,
        returns_list: bool,
    ) -> fraiseql_core::schema::QueryDefinition {
        serde_json::from_value(serde_json::json!({
            "name": name,
            "return_type": return_type,
            "returns_list": returns_list,
        }))
        .expect("test query definition")
    }

    fn write_schema_file(dir: &Path, schema: &CompiledSchema) -> String {
        let json = serde_json::to_string_pretty(schema).expect("serialize schema");
        let path = dir.join("schema.compiled.json");
        let mut f = fs::File::create(&path).expect("create schema file");
        f.write_all(json.as_bytes()).expect("write schema file");
        path.to_string_lossy().into_owned()
    }

    /// Decodes encoded descriptor bytes and returns the single file they describe.
    ///
    /// Decoding (instead of searching raw bytes) lets tests assert exact field values.
    fn decode_single_file(bytes: &[u8]) -> FileDescriptorProto {
        let set = FileDescriptorSet::decode(bytes).expect("decode FileDescriptorSet");
        assert_eq!(set.file.len(), 1, "expected exactly one file descriptor");
        set.file.into_iter().next().expect("one file descriptor")
    }

    // --- resolve_dialect ---

    #[test]
    fn test_resolve_dialect_postgres() {
        assert!(resolve_dialect("postgres").is_ok());
        assert!(resolve_dialect("postgresql").is_ok());
    }

    #[test]
    fn test_resolve_dialect_mysql() {
        assert!(resolve_dialect("mysql").is_ok());
    }

    #[test]
    fn test_resolve_dialect_sqlite() {
        assert!(resolve_dialect("sqlite").is_ok());
    }

    #[test]
    fn test_resolve_dialect_sqlserver() {
        assert!(resolve_dialect("sqlserver").is_ok());
    }

    #[test]
    fn test_resolve_dialect_unknown() {
        match resolve_dialect("oracle") {
            Ok(_) => panic!("expected error for oracle"),
            Err(e) => {
                let msg = e.to_string();
                assert!(msg.contains("Unknown dialect"));
                assert!(msg.contains("oracle"), "error should name the bad dialect: {msg}");
            }
        }
    }

    // --- build_file_descriptor_set ---

    #[test]
    fn test_descriptor_bytes_non_empty() {
        let proto = "syntax = \"proto3\";\npackage test.v1;\n";
        let bytes = build_file_descriptor_set(proto, "test.v1").expect("encode");
        assert!(!bytes.is_empty());

        let file = decode_single_file(&bytes);
        assert_eq!(file.name.as_deref(), Some("service.proto"));
        assert_eq!(file.package.as_deref(), Some("test.v1"));
        assert_eq!(file.syntax.as_deref(), Some("proto3"));
    }

    #[test]
    fn test_descriptor_includes_timestamp_dep() {
        let proto = "import \"google/protobuf/timestamp.proto\";\n";
        let bytes = build_file_descriptor_set(proto, "test.v1").expect("encode");
        let file = decode_single_file(&bytes);
        assert_eq!(file.dependency, vec!["google/protobuf/timestamp.proto".to_string()]);
    }

    #[test]
    fn test_descriptor_includes_struct_dep() {
        let proto = "import \"google/protobuf/struct.proto\";\n";
        let bytes = build_file_descriptor_set(proto, "test.v1").expect("encode");
        let file = decode_single_file(&bytes);
        assert_eq!(file.dependency, vec!["google/protobuf/struct.proto".to_string()]);
    }

    #[test]
    fn test_descriptor_includes_both_deps() {
        let proto = "import \"google/protobuf/timestamp.proto\";\n\
                     import \"google/protobuf/struct.proto\";\n";
        let bytes = build_file_descriptor_set(proto, "test.v1").expect("encode");
        let file = decode_single_file(&bytes);
        assert_eq!(
            file.dependency,
            vec![
                "google/protobuf/timestamp.proto".to_string(),
                "google/protobuf/struct.proto".to_string(),
            ]
        );
    }

    #[test]
    fn test_descriptor_no_deps_when_absent() {
        let proto = "syntax = \"proto3\";\n";
        let bytes = build_file_descriptor_set(proto, "test.v1").expect("encode");
        let file = decode_single_file(&bytes);
        assert!(file.dependency.is_empty(), "unexpected deps: {:?}", file.dependency);
    }

    // --- run ---

    #[test]
    fn test_run_generates_three_files() {
        let tmp = TempDir::new().expect("temp dir");
        let mut schema = CompiledSchema::new();
        schema.types.push(make_type(
            "User",
            vec![
                make_field("id", FieldType::Id, false),
                make_field("name", FieldType::String, false),
            ],
        ));
        schema.queries.push(make_query("get_user", "User", false));

        let schema_path = write_schema_file(tmp.path(), &schema);
        let out_dir = tmp.path().join("out");
        let formatter = OutputFormatter::new(false, true);

        run(&schema_path, &out_dir.to_string_lossy(), "test.v1", "postgres", &formatter)
            .expect("run should succeed");

        assert!(out_dir.join("service.proto").exists());
        assert!(out_dir.join("vr_migrations.sql").exists());
        assert!(out_dir.join("descriptor.binpb").exists());

        let proto = fs::read_to_string(out_dir.join("service.proto")).expect("read proto");
        assert!(proto.contains("package test.v1;"));
        assert!(proto.contains("message User {"));
        assert!(proto.contains("service TestService {"));
    }

    #[test]
    fn test_run_with_enum_and_datetime() {
        let tmp = TempDir::new().expect("temp dir");
        let mut schema = CompiledSchema::new();
        schema.enums.push(EnumDefinition {
            name: "Status".to_string(),
            values: vec![EnumValueDefinition {
                name: "ACTIVE".to_string(),
                description: None,
                deprecation: None,
            }],
            description: None,
        });
        schema.types.push(make_type(
            "Event",
            vec![
                make_field("id", FieldType::Id, false),
                make_field("created_at", FieldType::DateTime, false),
                make_field("status", FieldType::Enum("Status".to_string()), false),
            ],
        ));
        schema.queries.push(make_query("get_event", "Event", false));

        let schema_path = write_schema_file(tmp.path(), &schema);
        let out_dir = tmp.path().join("out");
        let formatter = OutputFormatter::new(false, true);

        run(&schema_path, &out_dir.to_string_lossy(), "fraiseql.v1", "postgres", &formatter)
            .expect("run should succeed");

        let proto = fs::read_to_string(out_dir.join("service.proto")).expect("read proto");
        assert!(proto.contains("import \"google/protobuf/timestamp.proto\""));
        assert!(proto.contains("enum Status {"));

        let desc = fs::read(out_dir.join("descriptor.binpb")).expect("read descriptor");
        let file = decode_single_file(&desc);
        assert_eq!(file.package.as_deref(), Some("fraiseql.v1"));
        assert!(
            file.dependency.iter().any(|d| d == "google/protobuf/timestamp.proto"),
            "timestamp dependency missing: {:?}",
            file.dependency
        );
    }

    #[test]
    fn test_run_mysql_dialect() {
        let tmp = TempDir::new().expect("temp dir");
        let mut schema = CompiledSchema::new();
        schema.types.push(make_type(
            "User",
            vec![
                make_field("id", FieldType::Id, false),
                make_field("name", FieldType::String, false),
            ],
        ));
        schema.queries.push(make_query("get_user", "User", false));

        let schema_path = write_schema_file(tmp.path(), &schema);
        let out_dir = tmp.path().join("out");
        let formatter = OutputFormatter::new(false, true);

        run(&schema_path, &out_dir.to_string_lossy(), "test.v1", "mysql", &formatter)
            .expect("run with mysql should succeed");

        let sql = fs::read_to_string(out_dir.join("vr_migrations.sql")).expect("read sql");
        assert!(sql.contains("JSON_EXTRACT"));
    }

    #[test]
    fn test_run_bad_schema_path() {
        let tmp = TempDir::new().expect("temp dir");
        let out_dir = tmp.path().join("out");
        let formatter = OutputFormatter::new(false, true);

        let result = run(
            "/nonexistent/schema.compiled.json",
            &out_dir.to_string_lossy(),
            "test.v1",
            "postgres",
            &formatter,
        );
        match result {
            Ok(()) => panic!("expected error for missing schema file"),
            Err(e) => assert!(e.to_string().contains("Failed to read compiled schema file")),
        }
    }

    #[test]
    fn test_run_bad_dialect() {
        let tmp = TempDir::new().expect("temp dir");
        let mut schema = CompiledSchema::new();
        schema
            .types
            .push(make_type("User", vec![make_field("id", FieldType::Id, false)]));

        let schema_path = write_schema_file(tmp.path(), &schema);
        let out_dir = tmp.path().join("out");
        let formatter = OutputFormatter::new(false, true);

        let result = run(&schema_path, &out_dir.to_string_lossy(), "test.v1", "oracle", &formatter);
        match result {
            Ok(()) => panic!("expected error for oracle dialect"),
            Err(e) => assert!(e.to_string().contains("Unknown dialect")),
        }
    }
}