1use std::path::PathBuf;
4
5use crate::cli::GenerateArgs;
6use crate::config::{CONFIG_FILE_NAME, Config, SCHEMA_FILE_NAME};
7use crate::error::{CliError, CliResult};
8use crate::output::{self, success};
9
10pub async fn run(args: GenerateArgs) -> CliResult<()> {
12 output::header("Generate Prax Client");
13
14 let cwd = std::env::current_dir()?;
15
16 let config_path = cwd.join(CONFIG_FILE_NAME);
18 let config = if config_path.exists() {
19 Config::load(&config_path)?
20 } else {
21 Config::default()
22 };
23
24 let schema_path = args
26 .schema
27 .clone()
28 .unwrap_or_else(|| cwd.join(SCHEMA_FILE_NAME));
29 if !schema_path.exists() {
30 return Err(
31 CliError::Config(format!("Schema file not found: {}", schema_path.display())).into(),
32 );
33 }
34
35 let output_dir = args
37 .output
38 .clone()
39 .unwrap_or_else(|| PathBuf::from(&config.generator.output));
40
41 output::kv("Schema", &schema_path.display().to_string());
42 output::kv("Output", &output_dir.display().to_string());
43 output::newline();
44
45 output::step(1, 4, "Reading schema...");
46
47 let schema_content = std::fs::read_to_string(&schema_path)?;
49 let schema = parse_schema(&schema_content)?;
50
51 output::step(2, 4, "Validating schema...");
52
53 validate_schema(&schema)?;
55
56 output::step(3, 4, "Generating code...");
57
58 std::fs::create_dir_all(&output_dir)?;
60
61 let generated_files = generate_code(&schema, &output_dir, &args, &config)?;
63
64 output::step(4, 4, "Writing files...");
65
66 output::newline();
68 output::section("Generated files");
69
70 for file in &generated_files {
71 let relative_path = file
72 .strip_prefix(&cwd)
73 .unwrap_or(file)
74 .display()
75 .to_string();
76 output::list_item(&relative_path);
77 }
78
79 output::newline();
80 success(&format!(
81 "Generated {} files in {:.2}s",
82 generated_files.len(),
83 0.0 ));
85
86 Ok(())
87}
88
89fn parse_schema(content: &str) -> CliResult<prax_schema::Schema> {
91 prax_schema::parse_schema(content)
92 .map_err(|e| CliError::Schema(format!("Failed to parse schema: {}", e)))
93}
94
95fn validate_schema(schema: &prax_schema::Schema) -> CliResult<()> {
97 let mut validator = prax_schema::Validator::new();
99 if let Err(e) = validator.validate(schema.clone()) {
101 return Err(CliError::Validation(format!(
102 "Schema validation failed: {}",
103 e
104 )));
105 }
106 Ok(())
107}
108
109fn generate_code(
111 schema: &prax_schema::ast::Schema,
112 output_dir: &PathBuf,
113 args: &GenerateArgs,
114 config: &Config,
115) -> CliResult<Vec<PathBuf>> {
116 let mut generated_files = Vec::new();
117
118 let features = if !args.features.is_empty() {
120 args.features.clone()
121 } else {
122 config
123 .generator
124 .features
125 .clone()
126 .unwrap_or_else(|| vec!["client".to_string()])
127 };
128
129 let client_path = output_dir.join("mod.rs");
131 let client_code = generate_client_module(schema, &features)?;
132 std::fs::write(&client_path, client_code)?;
133 generated_files.push(client_path);
134
135 for model in schema.models.values() {
137 let model_path = output_dir.join(format!("{}.rs", to_snake_case(model.name())));
138 let model_code = generate_model_module(model, &features)?;
139 std::fs::write(&model_path, model_code)?;
140 generated_files.push(model_path);
141 }
142
143 for enum_def in schema.enums.values() {
145 let enum_path = output_dir.join(format!("{}.rs", to_snake_case(enum_def.name())));
146 let enum_code = generate_enum_module(enum_def)?;
147 std::fs::write(&enum_path, enum_code)?;
148 generated_files.push(enum_path);
149 }
150
151 let types_path = output_dir.join("types.rs");
153 let types_code = generate_types_module(schema)?;
154 std::fs::write(&types_path, types_code)?;
155 generated_files.push(types_path);
156
157 let filters_path = output_dir.join("filters.rs");
159 let filters_code = generate_filters_module(schema)?;
160 std::fs::write(&filters_path, filters_code)?;
161 generated_files.push(filters_path);
162
163 Ok(generated_files)
164}
165
166fn generate_client_module(
168 schema: &prax_schema::ast::Schema,
169 _features: &[String],
170) -> CliResult<String> {
171 let mut code = String::new();
172
173 code.push_str("//! Auto-generated by Prax - DO NOT EDIT\n");
174 code.push_str("//!\n");
175 code.push_str("//! This module contains the generated Prax client.\n\n");
176
177 code.push_str("pub mod types;\n");
179 code.push_str("pub mod filters;\n\n");
180
181 for model in schema.models.values() {
182 code.push_str(&format!("pub mod {};\n", to_snake_case(model.name())));
183 }
184
185 for enum_def in schema.enums.values() {
186 code.push_str(&format!("pub mod {};\n", to_snake_case(enum_def.name())));
187 }
188
189 code.push_str("\n");
190
191 code.push_str("pub use types::*;\n");
193 code.push_str("pub use filters::*;\n\n");
194
195 for model in schema.models.values() {
196 code.push_str(&format!(
197 "pub use {}::{};\n",
198 to_snake_case(model.name()),
199 model.name()
200 ));
201 }
202
203 for enum_def in schema.enums.values() {
204 code.push_str(&format!(
205 "pub use {}::{};\n",
206 to_snake_case(enum_def.name()),
207 enum_def.name()
208 ));
209 }
210
211 code.push_str("\n");
212
213 code.push_str("/// The Prax database client\n");
215 code.push_str("pub struct PraxClient<E: prax_query::QueryEngine> {\n");
216 code.push_str(" engine: E,\n");
217 code.push_str("}\n\n");
218
219 code.push_str("impl<E: prax_query::QueryEngine> PraxClient<E> {\n");
220 code.push_str(" /// Create a new Prax client with the given query engine\n");
221 code.push_str(" pub fn new(engine: E) -> Self {\n");
222 code.push_str(" Self { engine }\n");
223 code.push_str(" }\n\n");
224
225 for model in schema.models.values() {
226 let snake_name = to_snake_case(model.name());
227 code.push_str(&format!(" /// Access {} operations\n", model.name()));
228 code.push_str(&format!(
229 " pub fn {}(&self) -> {}::{}Operations<E> {{\n",
230 snake_name,
231 snake_name,
232 model.name()
233 ));
234 code.push_str(&format!(
235 " {}::{}Operations::new(&self.engine)\n",
236 snake_name,
237 model.name()
238 ));
239 code.push_str(" }\n\n");
240 }
241
242 code.push_str("}\n");
243
244 Ok(code)
245}
246
247fn generate_model_module(
249 model: &prax_schema::ast::Model,
250 features: &[String],
251) -> CliResult<String> {
252 let mut code = String::new();
253
254 code.push_str(&format!(
255 "//! Auto-generated module for {} model\n\n",
256 model.name()
257 ));
258
259 let mut derives = vec!["Debug", "Clone"];
261 if features.contains(&"serde".to_string()) {
262 derives.push("serde::Serialize");
263 derives.push("serde::Deserialize");
264 }
265
266 code.push_str(&format!("#[derive({})]\n", derives.join(", ")));
268 code.push_str(&format!("pub struct {} {{\n", model.name()));
269
270 for field in model.fields.values() {
271 let rust_type = field_type_to_rust(&field.field_type, field.modifier);
272 let field_name = to_snake_case(field.name());
273
274 if let Some(attr) = field.get_attribute("map") {
276 if features.contains(&"serde".to_string()) {
277 if let Some(value) = attr.first_arg().and_then(|v| v.as_string()) {
278 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", value));
279 }
280 }
281 }
282
283 code.push_str(&format!(" pub {}: {},\n", field_name, rust_type));
284 }
285
286 code.push_str("}\n\n");
287
288 code.push_str(&format!("/// Operations for the {} model\n", model.name()));
290 code.push_str(&format!(
291 "pub struct {}Operations<'a, E: prax_query::QueryEngine> {{\n",
292 model.name()
293 ));
294 code.push_str(" engine: &'a E,\n");
295 code.push_str("}\n\n");
296
297 code.push_str(&format!(
298 "impl<'a, E: prax_query::QueryEngine> {}Operations<'a, E> {{\n",
299 model.name()
300 ));
301 code.push_str(" pub fn new(engine: &'a E) -> Self {\n");
302 code.push_str(" Self { engine }\n");
303 code.push_str(" }\n\n");
304
305 let table_name = model.table_name();
306
307 code.push_str(" /// Find many records\n");
309 code.push_str(&format!(
310 " pub fn find_many(&self) -> prax_query::FindManyOperation<'a, E, {}> {{\n",
311 model.name()
312 ));
313 code.push_str(&format!(
314 " prax_query::FindManyOperation::new(self.engine, \"{}\")\n",
315 table_name
316 ));
317 code.push_str(" }\n\n");
318
319 code.push_str(" /// Find a unique record\n");
320 code.push_str(&format!(
321 " pub fn find_unique(&self) -> prax_query::FindUniqueOperation<'a, E, {}> {{\n",
322 model.name()
323 ));
324 code.push_str(&format!(
325 " prax_query::FindUniqueOperation::new(self.engine, \"{}\")\n",
326 table_name
327 ));
328 code.push_str(" }\n\n");
329
330 code.push_str(" /// Find the first matching record\n");
331 code.push_str(&format!(
332 " pub fn find_first(&self) -> prax_query::FindFirstOperation<'a, E, {}> {{\n",
333 model.name()
334 ));
335 code.push_str(&format!(
336 " prax_query::FindFirstOperation::new(self.engine, \"{}\")\n",
337 table_name
338 ));
339 code.push_str(" }\n\n");
340
341 code.push_str(" /// Create a new record\n");
342 code.push_str(&format!(
343 " pub fn create(&self) -> prax_query::CreateOperation<'a, E, {}> {{\n",
344 model.name()
345 ));
346 code.push_str(&format!(
347 " prax_query::CreateOperation::new(self.engine, \"{}\")\n",
348 table_name
349 ));
350 code.push_str(" }\n\n");
351
352 code.push_str(" /// Update a record\n");
353 code.push_str(&format!(
354 " pub fn update(&self) -> prax_query::UpdateOperation<'a, E, {}> {{\n",
355 model.name()
356 ));
357 code.push_str(&format!(
358 " prax_query::UpdateOperation::new(self.engine, \"{}\")\n",
359 table_name
360 ));
361 code.push_str(" }\n\n");
362
363 code.push_str(" /// Delete a record\n");
364 code.push_str(&format!(
365 " pub fn delete(&self) -> prax_query::DeleteOperation<'a, E, {}> {{\n",
366 model.name()
367 ));
368 code.push_str(&format!(
369 " prax_query::DeleteOperation::new(self.engine, \"{}\")\n",
370 table_name
371 ));
372 code.push_str(" }\n\n");
373
374 code.push_str(" /// Count records\n");
375 code.push_str(" pub fn count(&self) -> prax_query::CountOperation<'a, E> {\n");
376 code.push_str(&format!(
377 " prax_query::CountOperation::new(self.engine, \"{}\")\n",
378 table_name
379 ));
380 code.push_str(" }\n");
381
382 code.push_str("}\n");
383
384 Ok(code)
385}
386
387fn generate_enum_module(enum_def: &prax_schema::ast::Enum) -> CliResult<String> {
389 let mut code = String::new();
390
391 code.push_str(&format!(
392 "//! Auto-generated module for {} enum\n\n",
393 enum_def.name()
394 ));
395
396 code.push_str(
397 "#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]\n",
398 );
399 code.push_str(&format!("pub enum {} {{\n", enum_def.name()));
400
401 for variant in &enum_def.variants {
402 if let Some(attr) = variant.attributes.iter().find(|a| a.is("map")) {
404 if let Some(value) = attr.first_arg().and_then(|v| v.as_string()) {
405 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", value));
406 }
407 }
408 code.push_str(&format!(" {},\n", variant.name()));
409 }
410
411 code.push_str("}\n\n");
412
413 if let Some(default_variant) = enum_def.variants.first() {
415 code.push_str(&format!("impl Default for {} {{\n", enum_def.name()));
416 code.push_str(&format!(
417 " fn default() -> Self {{\n Self::{}\n }}\n",
418 default_variant.name()
419 ));
420 code.push_str("}\n");
421 }
422
423 Ok(code)
424}
425
426fn generate_types_module(schema: &prax_schema::ast::Schema) -> CliResult<String> {
428 let mut code = String::new();
429
430 code.push_str("//! Common type definitions\n\n");
431 code.push_str("pub use chrono::{DateTime, Utc};\n");
432 code.push_str("pub use uuid::Uuid;\n");
433 code.push_str("pub use serde_json::Value as Json;\n");
434 code.push_str("\n");
435
436 for composite in schema.types.values() {
438 code.push_str("#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]\n");
439 code.push_str(&format!("pub struct {} {{\n", composite.name()));
440 for field in composite.fields.values() {
441 let rust_type = field_type_to_rust(&field.field_type, field.modifier);
442 let field_name = to_snake_case(field.name());
443 code.push_str(&format!(" pub {}: {},\n", field_name, rust_type));
444 }
445 code.push_str("}\n\n");
446 }
447
448 Ok(code)
449}
450
451fn generate_filters_module(schema: &prax_schema::ast::Schema) -> CliResult<String> {
453 let mut code = String::new();
454
455 code.push_str("//! Filter types for queries\n\n");
456 code.push_str("use prax_query::filter::{Filter, ScalarFilter};\n\n");
457
458 for model in schema.models.values() {
459 code.push_str(&format!("/// Filter input for {} queries\n", model.name()));
461 code.push_str("#[derive(Debug, Default, Clone)]\n");
462 code.push_str(&format!("pub struct {}WhereInput {{\n", model.name()));
463
464 for field in model.fields.values() {
465 if !field.is_relation() {
466 let filter_type = field_to_filter_type(&field.field_type);
467 let field_name = to_snake_case(field.name());
468 code.push_str(&format!(
469 " pub {}: Option<{}>,\n",
470 field_name, filter_type
471 ));
472 }
473 }
474
475 code.push_str(" pub and: Option<Vec<Self>>,\n");
476 code.push_str(" pub or: Option<Vec<Self>>,\n");
477 code.push_str(" pub not: Option<Box<Self>>,\n");
478 code.push_str("}\n\n");
479
480 code.push_str(&format!(
482 "/// Order by input for {} queries\n",
483 model.name()
484 ));
485 code.push_str("#[derive(Debug, Default, Clone)]\n");
486 code.push_str(&format!("pub struct {}OrderByInput {{\n", model.name()));
487
488 for field in model.fields.values() {
489 if !field.is_relation() {
490 let field_name = to_snake_case(field.name());
491 code.push_str(&format!(
492 " pub {}: Option<prax_query::SortOrder>,\n",
493 field_name
494 ));
495 }
496 }
497
498 code.push_str("}\n\n");
499 }
500
501 Ok(code)
502}
503
504fn field_type_to_rust(
506 field_type: &prax_schema::ast::FieldType,
507 modifier: prax_schema::ast::TypeModifier,
508) -> String {
509 use prax_schema::ast::{FieldType, ScalarType, TypeModifier};
510
511 let base_type = match field_type {
512 FieldType::Scalar(scalar) => match scalar {
513 ScalarType::Int => "i32".to_string(),
514 ScalarType::BigInt => "i64".to_string(),
515 ScalarType::Float => "f64".to_string(),
516 ScalarType::String => "String".to_string(),
517 ScalarType::Boolean => "bool".to_string(),
518 ScalarType::DateTime => "chrono::DateTime<chrono::Utc>".to_string(),
519 ScalarType::Date => "chrono::NaiveDate".to_string(),
520 ScalarType::Time => "chrono::NaiveTime".to_string(),
521 ScalarType::Json => "serde_json::Value".to_string(),
522 ScalarType::Bytes => "Vec<u8>".to_string(),
523 ScalarType::Decimal => "rust_decimal::Decimal".to_string(),
524 ScalarType::Uuid => "uuid::Uuid".to_string(),
525 ScalarType::Cuid => "String".to_string(),
526 ScalarType::Cuid2 => "String".to_string(),
527 ScalarType::NanoId => "String".to_string(),
528 ScalarType::Ulid => "String".to_string(),
529 },
530 FieldType::Model(name) => name.to_string(),
531 FieldType::Enum(name) => name.to_string(),
532 FieldType::Composite(name) => name.to_string(),
533 FieldType::Unsupported(_) => "serde_json::Value".to_string(),
534 };
535
536 match modifier {
537 TypeModifier::Optional | TypeModifier::OptionalList => format!("Option<{}>", base_type),
538 TypeModifier::List => format!("Vec<{}>", base_type),
539 TypeModifier::Required => base_type,
540 }
541}
542
543fn field_to_filter_type(field_type: &prax_schema::ast::FieldType) -> String {
545 use prax_schema::ast::{FieldType, ScalarType};
546
547 match field_type {
548 FieldType::Scalar(scalar) => match scalar {
549 ScalarType::Int | ScalarType::BigInt => "ScalarFilter<i64>".to_string(),
550 ScalarType::Float | ScalarType::Decimal => "ScalarFilter<f64>".to_string(),
551 ScalarType::String
552 | ScalarType::Uuid
553 | ScalarType::Cuid
554 | ScalarType::Cuid2
555 | ScalarType::NanoId
556 | ScalarType::Ulid => "ScalarFilter<String>".to_string(),
557 ScalarType::Boolean => "ScalarFilter<bool>".to_string(),
558 ScalarType::DateTime => "ScalarFilter<chrono::DateTime<chrono::Utc>>".to_string(),
559 ScalarType::Date => "ScalarFilter<chrono::NaiveDate>".to_string(),
560 ScalarType::Time => "ScalarFilter<chrono::NaiveTime>".to_string(),
561 ScalarType::Json => "ScalarFilter<serde_json::Value>".to_string(),
562 ScalarType::Bytes => "ScalarFilter<Vec<u8>>".to_string(),
563 },
564 FieldType::Enum(name) => format!("ScalarFilter<{}>", name),
565 _ => "Filter".to_string(),
566 }
567}
568
/// Converts a `PascalCase` / `camelCase` identifier to `snake_case`.
///
/// An underscore is inserted before an uppercase letter only when it starts a new
/// word:
/// - after a lowercase letter or a digit (`userId` → `user_id`, `Post2Tag` → `post2_tag`);
/// - at the end of an acronym, i.e. an uppercase letter preceded by an uppercase
///   letter and followed by a lowercase one (`HTTPServer` → `http_server`).
///
/// No underscore is added after an existing `_` (`user_Id` → `user_id`). Uppercase
/// letters are lowered with the full Unicode mapping, which may produce more than
/// one `char`.
fn to_snake_case(name: &str) -> String {
    let chars: Vec<char> = name.chars().collect();
    let mut result = String::with_capacity(name.len() + chars.len() / 2);

    for (i, &c) in chars.iter().enumerate() {
        if !c.is_uppercase() {
            result.push(c);
            continue;
        }

        if i > 0 {
            let prev = chars[i - 1];
            let next_is_lower = matches!(chars.get(i + 1), Some(next) if next.is_lowercase());
            let starts_new_word = prev.is_lowercase()
                || prev.is_numeric()
                || (prev.is_uppercase() && next_is_lower);
            if starts_new_word {
                result.push('_');
            }
        }

        // `to_lowercase` can yield several chars (e.g. 'İ'); keep all of them.
        result.extend(c.to_lowercase());
    }

    result
}