1use std::path::PathBuf;
4
5use crate::cli::GenerateArgs;
6use crate::config::{CONFIG_FILE_NAME, Config, SCHEMA_FILE_PATH};
7use crate::error::{CliError, CliResult};
8use crate::output::{self, success};
9
10pub async fn run(args: GenerateArgs) -> CliResult<()> {
12 output::header("Generate Prax Client");
13
14 let cwd = std::env::current_dir()?;
15
16 let config_path = cwd.join(CONFIG_FILE_NAME);
18 let config = if config_path.exists() {
19 Config::load(&config_path)?
20 } else {
21 Config::default()
22 };
23
24 let schema_path = args
26 .schema
27 .clone()
28 .unwrap_or_else(|| cwd.join(SCHEMA_FILE_PATH));
29 if !schema_path.exists() {
30 return Err(
31 CliError::Config(format!("Schema file not found: {}", schema_path.display())).into(),
32 );
33 }
34
35 let output_dir = args
37 .output
38 .clone()
39 .unwrap_or_else(|| PathBuf::from(&config.generator.output));
40
41 output::kv("Schema", &schema_path.display().to_string());
42 output::kv("Output", &output_dir.display().to_string());
43 output::newline();
44
45 output::step(1, 4, "Reading schema...");
46
47 let schema_content = std::fs::read_to_string(&schema_path)?;
49 let schema = parse_schema(&schema_content)?;
50
51 output::step(2, 4, "Validating schema...");
52
53 validate_schema(&schema)?;
55
56 output::step(3, 4, "Generating code...");
57
58 std::fs::create_dir_all(&output_dir)?;
60
61 let generated_files = generate_code(&schema, &output_dir, &args, &config)?;
63
64 output::step(4, 4, "Writing files...");
65
66 output::newline();
68 output::section("Generated files");
69
70 for file in &generated_files {
71 let relative_path = file
72 .strip_prefix(&cwd)
73 .unwrap_or(file)
74 .display()
75 .to_string();
76 output::list_item(&relative_path);
77 }
78
79 output::newline();
80 success(&format!(
81 "Generated {} files in {:.2}s",
82 generated_files.len(),
83 0.0 ));
85
86 Ok(())
87}
88
89fn parse_schema(content: &str) -> CliResult<prax_schema::Schema> {
91 prax_schema::validate_schema(content)
94 .map_err(|e| CliError::Schema(format!("Failed to parse/validate schema: {}", e)))
95}
96
97fn validate_schema(_schema: &prax_schema::Schema) -> CliResult<()> {
99 Ok(())
101}
102
103fn generate_code(
105 schema: &prax_schema::ast::Schema,
106 output_dir: &PathBuf,
107 args: &GenerateArgs,
108 config: &Config,
109) -> CliResult<Vec<PathBuf>> {
110 let mut generated_files = Vec::new();
111
112 let features = if !args.features.is_empty() {
114 args.features.clone()
115 } else {
116 config
117 .generator
118 .features
119 .clone()
120 .unwrap_or_else(|| vec!["client".to_string()])
121 };
122
123 let client_path = output_dir.join("mod.rs");
125 let client_code = generate_client_module(schema, &features)?;
126 std::fs::write(&client_path, client_code)?;
127 generated_files.push(client_path);
128
129 for model in schema.models.values() {
131 let model_path = output_dir.join(format!("{}.rs", to_snake_case(model.name())));
132 let model_code = generate_model_module(model, &features)?;
133 std::fs::write(&model_path, model_code)?;
134 generated_files.push(model_path);
135 }
136
137 for enum_def in schema.enums.values() {
139 let enum_path = output_dir.join(format!("{}.rs", to_snake_case(enum_def.name())));
140 let enum_code = generate_enum_module(enum_def)?;
141 std::fs::write(&enum_path, enum_code)?;
142 generated_files.push(enum_path);
143 }
144
145 let types_path = output_dir.join("types.rs");
147 let types_code = generate_types_module(schema)?;
148 std::fs::write(&types_path, types_code)?;
149 generated_files.push(types_path);
150
151 let filters_path = output_dir.join("filters.rs");
153 let filters_code = generate_filters_module(schema)?;
154 std::fs::write(&filters_path, filters_code)?;
155 generated_files.push(filters_path);
156
157 Ok(generated_files)
158}
159
160fn generate_client_module(
162 schema: &prax_schema::ast::Schema,
163 _features: &[String],
164) -> CliResult<String> {
165 let mut code = String::new();
166
167 code.push_str("//! Auto-generated by Prax - DO NOT EDIT\n");
168 code.push_str("//!\n");
169 code.push_str("//! This module contains the generated Prax client.\n\n");
170
171 code.push_str("pub mod types;\n");
173 code.push_str("pub mod filters;\n\n");
174
175 for model in schema.models.values() {
176 code.push_str(&format!("pub mod {};\n", to_snake_case(model.name())));
177 }
178
179 for enum_def in schema.enums.values() {
180 code.push_str(&format!("pub mod {};\n", to_snake_case(enum_def.name())));
181 }
182
183 code.push_str("\n");
184
185 code.push_str("pub use types::*;\n");
187 code.push_str("pub use filters::*;\n\n");
188
189 for model in schema.models.values() {
190 code.push_str(&format!(
191 "pub use {}::{};\n",
192 to_snake_case(model.name()),
193 model.name()
194 ));
195 }
196
197 for enum_def in schema.enums.values() {
198 code.push_str(&format!(
199 "pub use {}::{};\n",
200 to_snake_case(enum_def.name()),
201 enum_def.name()
202 ));
203 }
204
205 code.push_str("\n");
206
207 code.push_str("/// The Prax database client\n");
209 code.push_str("pub struct PraxClient<E: prax_query::QueryEngine> {\n");
210 code.push_str(" engine: E,\n");
211 code.push_str("}\n\n");
212
213 code.push_str("impl<E: prax_query::QueryEngine> PraxClient<E> {\n");
214 code.push_str(" /// Create a new Prax client with the given query engine\n");
215 code.push_str(" pub fn new(engine: E) -> Self {\n");
216 code.push_str(" Self { engine }\n");
217 code.push_str(" }\n\n");
218
219 for model in schema.models.values() {
220 let snake_name = to_snake_case(model.name());
221 code.push_str(&format!(" /// Access {} operations\n", model.name()));
222 code.push_str(&format!(
223 " pub fn {}(&self) -> {}::{}Operations<E> {{\n",
224 snake_name,
225 snake_name,
226 model.name()
227 ));
228 code.push_str(&format!(
229 " {}::{}Operations::new(&self.engine)\n",
230 snake_name,
231 model.name()
232 ));
233 code.push_str(" }\n\n");
234 }
235
236 code.push_str("}\n");
237
238 Ok(code)
239}
240
241fn generate_model_module(
243 model: &prax_schema::ast::Model,
244 features: &[String],
245) -> CliResult<String> {
246 let mut code = String::new();
247
248 code.push_str(&format!(
249 "//! Auto-generated module for {} model\n\n",
250 model.name()
251 ));
252
253 let mut derives = vec!["Debug", "Clone"];
255 if features.contains(&"serde".to_string()) {
256 derives.push("serde::Serialize");
257 derives.push("serde::Deserialize");
258 }
259
260 code.push_str(&format!("#[derive({})]\n", derives.join(", ")));
262 code.push_str(&format!("pub struct {} {{\n", model.name()));
263
264 for field in model.fields.values() {
265 let rust_type = field_type_to_rust(&field.field_type, field.modifier);
266 let field_name = to_snake_case(field.name());
267
268 if let Some(attr) = field.get_attribute("map") {
270 if features.contains(&"serde".to_string()) {
271 if let Some(value) = attr.first_arg().and_then(|v| v.as_string()) {
272 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", value));
273 }
274 }
275 }
276
277 code.push_str(&format!(" pub {}: {},\n", field_name, rust_type));
278 }
279
280 code.push_str("}\n\n");
281
282 code.push_str(&format!("/// Operations for the {} model\n", model.name()));
284 code.push_str(&format!(
285 "pub struct {}Operations<'a, E: prax_query::QueryEngine> {{\n",
286 model.name()
287 ));
288 code.push_str(" engine: &'a E,\n");
289 code.push_str("}\n\n");
290
291 code.push_str(&format!(
292 "impl<'a, E: prax_query::QueryEngine> {}Operations<'a, E> {{\n",
293 model.name()
294 ));
295 code.push_str(" pub fn new(engine: &'a E) -> Self {\n");
296 code.push_str(" Self { engine }\n");
297 code.push_str(" }\n\n");
298
299 let table_name = model.table_name();
300
301 code.push_str(" /// Find many records\n");
303 code.push_str(&format!(
304 " pub fn find_many(&self) -> prax_query::FindManyOperation<'a, E, {}> {{\n",
305 model.name()
306 ));
307 code.push_str(&format!(
308 " prax_query::FindManyOperation::new(self.engine, \"{}\")\n",
309 table_name
310 ));
311 code.push_str(" }\n\n");
312
313 code.push_str(" /// Find a unique record\n");
314 code.push_str(&format!(
315 " pub fn find_unique(&self) -> prax_query::FindUniqueOperation<'a, E, {}> {{\n",
316 model.name()
317 ));
318 code.push_str(&format!(
319 " prax_query::FindUniqueOperation::new(self.engine, \"{}\")\n",
320 table_name
321 ));
322 code.push_str(" }\n\n");
323
324 code.push_str(" /// Find the first matching record\n");
325 code.push_str(&format!(
326 " pub fn find_first(&self) -> prax_query::FindFirstOperation<'a, E, {}> {{\n",
327 model.name()
328 ));
329 code.push_str(&format!(
330 " prax_query::FindFirstOperation::new(self.engine, \"{}\")\n",
331 table_name
332 ));
333 code.push_str(" }\n\n");
334
335 code.push_str(" /// Create a new record\n");
336 code.push_str(&format!(
337 " pub fn create(&self) -> prax_query::CreateOperation<'a, E, {}> {{\n",
338 model.name()
339 ));
340 code.push_str(&format!(
341 " prax_query::CreateOperation::new(self.engine, \"{}\")\n",
342 table_name
343 ));
344 code.push_str(" }\n\n");
345
346 code.push_str(" /// Update a record\n");
347 code.push_str(&format!(
348 " pub fn update(&self) -> prax_query::UpdateOperation<'a, E, {}> {{\n",
349 model.name()
350 ));
351 code.push_str(&format!(
352 " prax_query::UpdateOperation::new(self.engine, \"{}\")\n",
353 table_name
354 ));
355 code.push_str(" }\n\n");
356
357 code.push_str(" /// Delete a record\n");
358 code.push_str(&format!(
359 " pub fn delete(&self) -> prax_query::DeleteOperation<'a, E, {}> {{\n",
360 model.name()
361 ));
362 code.push_str(&format!(
363 " prax_query::DeleteOperation::new(self.engine, \"{}\")\n",
364 table_name
365 ));
366 code.push_str(" }\n\n");
367
368 code.push_str(" /// Count records\n");
369 code.push_str(" pub fn count(&self) -> prax_query::CountOperation<'a, E> {\n");
370 code.push_str(&format!(
371 " prax_query::CountOperation::new(self.engine, \"{}\")\n",
372 table_name
373 ));
374 code.push_str(" }\n");
375
376 code.push_str("}\n");
377
378 Ok(code)
379}
380
381fn generate_enum_module(enum_def: &prax_schema::ast::Enum) -> CliResult<String> {
383 let mut code = String::new();
384
385 code.push_str(&format!(
386 "//! Auto-generated module for {} enum\n\n",
387 enum_def.name()
388 ));
389
390 code.push_str(
391 "#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]\n",
392 );
393 code.push_str(&format!("pub enum {} {{\n", enum_def.name()));
394
395 for variant in &enum_def.variants {
396 if let Some(attr) = variant.attributes.iter().find(|a| a.is("map")) {
398 if let Some(value) = attr.first_arg().and_then(|v| v.as_string()) {
399 code.push_str(&format!(" #[serde(rename = \"{}\")]\n", value));
400 }
401 }
402 code.push_str(&format!(" {},\n", variant.name()));
403 }
404
405 code.push_str("}\n\n");
406
407 if let Some(default_variant) = enum_def.variants.first() {
409 code.push_str(&format!("impl Default for {} {{\n", enum_def.name()));
410 code.push_str(&format!(
411 " fn default() -> Self {{\n Self::{}\n }}\n",
412 default_variant.name()
413 ));
414 code.push_str("}\n");
415 }
416
417 Ok(code)
418}
419
420fn generate_types_module(schema: &prax_schema::ast::Schema) -> CliResult<String> {
422 let mut code = String::new();
423
424 code.push_str("//! Common type definitions\n\n");
425 code.push_str("pub use chrono::{DateTime, Utc};\n");
426 code.push_str("pub use uuid::Uuid;\n");
427 code.push_str("pub use serde_json::Value as Json;\n");
428 code.push_str("\n");
429
430 for composite in schema.types.values() {
432 code.push_str("#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]\n");
433 code.push_str(&format!("pub struct {} {{\n", composite.name()));
434 for field in composite.fields.values() {
435 let rust_type = field_type_to_rust(&field.field_type, field.modifier);
436 let field_name = to_snake_case(field.name());
437 code.push_str(&format!(" pub {}: {},\n", field_name, rust_type));
438 }
439 code.push_str("}\n\n");
440 }
441
442 Ok(code)
443}
444
445fn generate_filters_module(schema: &prax_schema::ast::Schema) -> CliResult<String> {
447 let mut code = String::new();
448
449 code.push_str("//! Filter types for queries\n\n");
450 code.push_str("use prax_query::filter::{Filter, ScalarFilter};\n\n");
451
452 for model in schema.models.values() {
453 code.push_str(&format!("/// Filter input for {} queries\n", model.name()));
455 code.push_str("#[derive(Debug, Default, Clone)]\n");
456 code.push_str(&format!("pub struct {}WhereInput {{\n", model.name()));
457
458 for field in model.fields.values() {
459 if !field.is_relation() {
460 let filter_type = field_to_filter_type(&field.field_type);
461 let field_name = to_snake_case(field.name());
462 code.push_str(&format!(
463 " pub {}: Option<{}>,\n",
464 field_name, filter_type
465 ));
466 }
467 }
468
469 code.push_str(" pub and: Option<Vec<Self>>,\n");
470 code.push_str(" pub or: Option<Vec<Self>>,\n");
471 code.push_str(" pub not: Option<Box<Self>>,\n");
472 code.push_str("}\n\n");
473
474 code.push_str(&format!(
476 "/// Order by input for {} queries\n",
477 model.name()
478 ));
479 code.push_str("#[derive(Debug, Default, Clone)]\n");
480 code.push_str(&format!("pub struct {}OrderByInput {{\n", model.name()));
481
482 for field in model.fields.values() {
483 if !field.is_relation() {
484 let field_name = to_snake_case(field.name());
485 code.push_str(&format!(
486 " pub {}: Option<prax_query::SortOrder>,\n",
487 field_name
488 ));
489 }
490 }
491
492 code.push_str("}\n\n");
493 }
494
495 Ok(code)
496}
497
498fn field_type_to_rust(
500 field_type: &prax_schema::ast::FieldType,
501 modifier: prax_schema::ast::TypeModifier,
502) -> String {
503 use prax_schema::ast::{FieldType, ScalarType, TypeModifier};
504
505 let base_type = match field_type {
506 FieldType::Scalar(scalar) => match scalar {
507 ScalarType::Int => "i32".to_string(),
508 ScalarType::BigInt => "i64".to_string(),
509 ScalarType::Float => "f64".to_string(),
510 ScalarType::String => "String".to_string(),
511 ScalarType::Boolean => "bool".to_string(),
512 ScalarType::DateTime => "chrono::DateTime<chrono::Utc>".to_string(),
513 ScalarType::Date => "chrono::NaiveDate".to_string(),
514 ScalarType::Time => "chrono::NaiveTime".to_string(),
515 ScalarType::Json => "serde_json::Value".to_string(),
516 ScalarType::Bytes => "Vec<u8>".to_string(),
517 ScalarType::Decimal => "rust_decimal::Decimal".to_string(),
518 ScalarType::Uuid => "uuid::Uuid".to_string(),
519 ScalarType::Cuid => "String".to_string(),
520 ScalarType::Cuid2 => "String".to_string(),
521 ScalarType::NanoId => "String".to_string(),
522 ScalarType::Ulid => "String".to_string(),
523 ScalarType::Vector(_) | ScalarType::HalfVector(_) => "Vec<f32>".to_string(),
524 ScalarType::SparseVector(_) => "Vec<(u32, f32)>".to_string(),
525 ScalarType::Bit(_) => "Vec<u8>".to_string(),
526 },
527 FieldType::Model(name) => name.to_string(),
528 FieldType::Enum(name) => name.to_string(),
529 FieldType::Composite(name) => name.to_string(),
530 FieldType::Unsupported(_) => "serde_json::Value".to_string(),
531 };
532
533 match modifier {
534 TypeModifier::Optional | TypeModifier::OptionalList => format!("Option<{}>", base_type),
535 TypeModifier::List => format!("Vec<{}>", base_type),
536 TypeModifier::Required => base_type,
537 }
538}
539
540fn field_to_filter_type(field_type: &prax_schema::ast::FieldType) -> String {
542 use prax_schema::ast::{FieldType, ScalarType};
543
544 match field_type {
545 FieldType::Scalar(scalar) => match scalar {
546 ScalarType::Int | ScalarType::BigInt => "ScalarFilter<i64>".to_string(),
547 ScalarType::Float | ScalarType::Decimal => "ScalarFilter<f64>".to_string(),
548 ScalarType::String
549 | ScalarType::Uuid
550 | ScalarType::Cuid
551 | ScalarType::Cuid2
552 | ScalarType::NanoId
553 | ScalarType::Ulid => "ScalarFilter<String>".to_string(),
554 ScalarType::Boolean => "ScalarFilter<bool>".to_string(),
555 ScalarType::DateTime => "ScalarFilter<chrono::DateTime<chrono::Utc>>".to_string(),
556 ScalarType::Date => "ScalarFilter<chrono::NaiveDate>".to_string(),
557 ScalarType::Time => "ScalarFilter<chrono::NaiveTime>".to_string(),
558 ScalarType::Json => "ScalarFilter<serde_json::Value>".to_string(),
559 ScalarType::Bytes => "ScalarFilter<Vec<u8>>".to_string(),
560 ScalarType::Vector(_) | ScalarType::HalfVector(_) => "VectorFilter".to_string(),
562 ScalarType::SparseVector(_) => "SparseVectorFilter".to_string(),
563 ScalarType::Bit(_) => "BitFilter".to_string(),
564 },
565 FieldType::Enum(name) => format!("ScalarFilter<{}>", name),
566 _ => "Filter".to_string(),
567 }
568}
569
/// Converts a PascalCase or camelCase identifier to snake_case.
///
/// An underscore goes before an uppercase letter when:
/// - it follows a lowercase letter or a digit, or
/// - it starts a new word after an acronym (the `S` in `HTTPServer`).
///
/// So runs of capitals stay together (`userID` → `user_id`, `HTTPServer` →
/// `http_server`) instead of being split letter by letter. No second
/// underscore is added after an existing `_` (`my_Field` → `my_field`).
/// Lowercasing uses the full Unicode mapping, so multi-character lowercase
/// forms are kept intact.
fn to_snake_case(name: &str) -> String {
    let mut result = String::with_capacity(name.len() + name.len() / 2);
    let mut prev: Option<char> = None;
    let mut chars = name.chars().peekable();

    while let Some(c) = chars.next() {
        if c.is_uppercase() {
            let next_is_lower = matches!(chars.peek(), Some(n) if n.is_lowercase());
            let boundary = match prev {
                Some(p) => {
                    p.is_lowercase() || p.is_numeric() || (p.is_uppercase() && next_is_lower)
                }
                None => false,
            };
            if boundary {
                result.push('_');
            }
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
        prev = Some(c);
    }

    result
}