Skip to main content

prax_cli/commands/
generate.rs

1//! `prax generate` command - Generate Rust client code from schema.
2
3use std::path::PathBuf;
4
5use crate::cli::GenerateArgs;
6use crate::config::{CONFIG_FILE_NAME, Config, SCHEMA_FILE_PATH};
7use crate::error::{CliError, CliResult};
8use crate::output::{self, success};
9
10/// Run the generate command
11pub async fn run(args: GenerateArgs) -> CliResult<()> {
12    output::header("Generate Prax Client");
13
14    let cwd = std::env::current_dir()?;
15
16    // Load config
17    let config_path = cwd.join(CONFIG_FILE_NAME);
18    let config = if config_path.exists() {
19        Config::load(&config_path)?
20    } else {
21        Config::default()
22    };
23
24    // Resolve schema path
25    let schema_path = args
26        .schema
27        .clone()
28        .unwrap_or_else(|| cwd.join(SCHEMA_FILE_PATH));
29    if !schema_path.exists() {
30        return Err(
31            CliError::Config(format!("Schema file not found: {}", schema_path.display())).into(),
32        );
33    }
34
35    // Resolve output directory
36    let output_dir = args
37        .output
38        .clone()
39        .unwrap_or_else(|| PathBuf::from(&config.generator.output));
40
41    output::kv("Schema", &schema_path.display().to_string());
42    output::kv("Output", &output_dir.display().to_string());
43    output::newline();
44
45    output::step(1, 4, "Reading schema...");
46
47    // Parse schema
48    let schema_content = std::fs::read_to_string(&schema_path)?;
49    let schema = parse_schema(&schema_content)?;
50
51    output::step(2, 4, "Validating schema...");
52
53    // Validate schema
54    validate_schema(&schema)?;
55
56    output::step(3, 4, "Generating code...");
57
58    // Create output directory
59    std::fs::create_dir_all(&output_dir)?;
60
61    // Generate code
62    let generated_files = generate_code(&schema, &output_dir, &args, &config)?;
63
64    output::step(4, 4, "Writing files...");
65
66    // Print generated files
67    output::newline();
68    output::section("Generated files");
69
70    for file in &generated_files {
71        let relative_path = file
72            .strip_prefix(&cwd)
73            .unwrap_or(file)
74            .display()
75            .to_string();
76        output::list_item(&relative_path);
77    }
78
79    output::newline();
80    success(&format!(
81        "Generated {} files in {:.2}s",
82        generated_files.len(),
83        0.0 // TODO: Add timing
84    ));
85
86    Ok(())
87}
88
89/// Parse and validate the schema file
90fn parse_schema(content: &str) -> CliResult<prax_schema::Schema> {
91    // Use validate_schema to ensure field types are properly resolved
92    // (e.g., FieldType::Model -> FieldType::Enum for enum references)
93    prax_schema::validate_schema(content)
94        .map_err(|e| CliError::Schema(format!("Failed to parse/validate schema: {}", e)))
95}
96
97/// Validate the schema (now a no-op since parse_schema does validation)
98fn validate_schema(_schema: &prax_schema::Schema) -> CliResult<()> {
99    // Validation is now done in parse_schema via validate_schema()
100    Ok(())
101}
102
103/// Generate code from the schema
104fn generate_code(
105    schema: &prax_schema::ast::Schema,
106    output_dir: &PathBuf,
107    args: &GenerateArgs,
108    config: &Config,
109) -> CliResult<Vec<PathBuf>> {
110    let mut generated_files = Vec::new();
111
112    // Determine which features to generate
113    let features = if !args.features.is_empty() {
114        args.features.clone()
115    } else {
116        config
117            .generator
118            .features
119            .clone()
120            .unwrap_or_else(|| vec!["client".to_string()])
121    };
122
123    // Generate main client module
124    let client_path = output_dir.join("mod.rs");
125    let client_code = generate_client_module(schema, &features)?;
126    std::fs::write(&client_path, client_code)?;
127    generated_files.push(client_path);
128
129    // Generate model modules
130    for model in schema.models.values() {
131        let model_path = output_dir.join(format!("{}.rs", to_snake_case(model.name())));
132        let model_code = generate_model_module(model, &features)?;
133        std::fs::write(&model_path, model_code)?;
134        generated_files.push(model_path);
135    }
136
137    // Generate enum modules
138    for enum_def in schema.enums.values() {
139        let enum_path = output_dir.join(format!("{}.rs", to_snake_case(enum_def.name())));
140        let enum_code = generate_enum_module(enum_def)?;
141        std::fs::write(&enum_path, enum_code)?;
142        generated_files.push(enum_path);
143    }
144
145    // Generate type definitions
146    let types_path = output_dir.join("types.rs");
147    let types_code = generate_types_module(schema)?;
148    std::fs::write(&types_path, types_code)?;
149    generated_files.push(types_path);
150
151    // Generate filters
152    let filters_path = output_dir.join("filters.rs");
153    let filters_code = generate_filters_module(schema)?;
154    std::fs::write(&filters_path, filters_code)?;
155    generated_files.push(filters_path);
156
157    Ok(generated_files)
158}
159
160/// Generate the main client module
161fn generate_client_module(
162    schema: &prax_schema::ast::Schema,
163    _features: &[String],
164) -> CliResult<String> {
165    let mut code = String::new();
166
167    code.push_str("//! Auto-generated by Prax - DO NOT EDIT\n");
168    code.push_str("//!\n");
169    code.push_str("//! This module contains the generated Prax client.\n\n");
170
171    // Module declarations
172    code.push_str("pub mod types;\n");
173    code.push_str("pub mod filters;\n\n");
174
175    for model in schema.models.values() {
176        code.push_str(&format!("pub mod {};\n", to_snake_case(model.name())));
177    }
178
179    for enum_def in schema.enums.values() {
180        code.push_str(&format!("pub mod {};\n", to_snake_case(enum_def.name())));
181    }
182
183    code.push_str("\n");
184
185    // Re-exports
186    code.push_str("pub use types::*;\n");
187    code.push_str("pub use filters::*;\n\n");
188
189    for model in schema.models.values() {
190        code.push_str(&format!(
191            "pub use {}::{};\n",
192            to_snake_case(model.name()),
193            model.name()
194        ));
195    }
196
197    for enum_def in schema.enums.values() {
198        code.push_str(&format!(
199            "pub use {}::{};\n",
200            to_snake_case(enum_def.name()),
201            enum_def.name()
202        ));
203    }
204
205    code.push_str("\n");
206
207    // Client struct
208    code.push_str("/// The Prax database client\n");
209    code.push_str("pub struct PraxClient<E: prax_query::QueryEngine> {\n");
210    code.push_str("    engine: E,\n");
211    code.push_str("}\n\n");
212
213    code.push_str("impl<E: prax_query::QueryEngine> PraxClient<E> {\n");
214    code.push_str("    /// Create a new Prax client with the given query engine\n");
215    code.push_str("    pub fn new(engine: E) -> Self {\n");
216    code.push_str("        Self { engine }\n");
217    code.push_str("    }\n\n");
218
219    for model in schema.models.values() {
220        let snake_name = to_snake_case(model.name());
221        code.push_str(&format!("    /// Access {} operations\n", model.name()));
222        code.push_str(&format!(
223            "    pub fn {}(&self) -> {}::{}Operations<E> {{\n",
224            snake_name,
225            snake_name,
226            model.name()
227        ));
228        code.push_str(&format!(
229            "        {}::{}Operations::new(&self.engine)\n",
230            snake_name,
231            model.name()
232        ));
233        code.push_str("    }\n\n");
234    }
235
236    code.push_str("}\n");
237
238    Ok(code)
239}
240
241/// Generate a model module
242fn generate_model_module(
243    model: &prax_schema::ast::Model,
244    features: &[String],
245) -> CliResult<String> {
246    let mut code = String::new();
247
248    code.push_str(&format!(
249        "//! Auto-generated module for {} model\n\n",
250        model.name()
251    ));
252
253    // Derive macros based on features
254    let mut derives = vec!["Debug", "Clone"];
255    if features.contains(&"serde".to_string()) {
256        derives.push("serde::Serialize");
257        derives.push("serde::Deserialize");
258    }
259
260    // Model struct
261    code.push_str(&format!("#[derive({})]\n", derives.join(", ")));
262    code.push_str(&format!("pub struct {} {{\n", model.name()));
263
264    for field in model.fields.values() {
265        let rust_type = field_type_to_rust(&field.field_type, field.modifier);
266        let field_name = to_snake_case(field.name());
267
268        // Add serde rename if mapped
269        if let Some(attr) = field.get_attribute("map") {
270            if features.contains(&"serde".to_string()) {
271                if let Some(value) = attr.first_arg().and_then(|v| v.as_string()) {
272                    code.push_str(&format!("    #[serde(rename = \"{}\")]\n", value));
273                }
274            }
275        }
276
277        code.push_str(&format!("    pub {}: {},\n", field_name, rust_type));
278    }
279
280    code.push_str("}\n\n");
281
282    // Operations struct
283    code.push_str(&format!("/// Operations for the {} model\n", model.name()));
284    code.push_str(&format!(
285        "pub struct {}Operations<'a, E: prax_query::QueryEngine> {{\n",
286        model.name()
287    ));
288    code.push_str("    engine: &'a E,\n");
289    code.push_str("}\n\n");
290
291    code.push_str(&format!(
292        "impl<'a, E: prax_query::QueryEngine> {}Operations<'a, E> {{\n",
293        model.name()
294    ));
295    code.push_str("    pub fn new(engine: &'a E) -> Self {\n");
296    code.push_str("        Self { engine }\n");
297    code.push_str("    }\n\n");
298
299    let table_name = model.table_name();
300
301    // CRUD methods
302    code.push_str("    /// Find many records\n");
303    code.push_str(&format!(
304        "    pub fn find_many(&self) -> prax_query::FindManyOperation<'a, E, {}> {{\n",
305        model.name()
306    ));
307    code.push_str(&format!(
308        "        prax_query::FindManyOperation::new(self.engine, \"{}\")\n",
309        table_name
310    ));
311    code.push_str("    }\n\n");
312
313    code.push_str("    /// Find a unique record\n");
314    code.push_str(&format!(
315        "    pub fn find_unique(&self) -> prax_query::FindUniqueOperation<'a, E, {}> {{\n",
316        model.name()
317    ));
318    code.push_str(&format!(
319        "        prax_query::FindUniqueOperation::new(self.engine, \"{}\")\n",
320        table_name
321    ));
322    code.push_str("    }\n\n");
323
324    code.push_str("    /// Find the first matching record\n");
325    code.push_str(&format!(
326        "    pub fn find_first(&self) -> prax_query::FindFirstOperation<'a, E, {}> {{\n",
327        model.name()
328    ));
329    code.push_str(&format!(
330        "        prax_query::FindFirstOperation::new(self.engine, \"{}\")\n",
331        table_name
332    ));
333    code.push_str("    }\n\n");
334
335    code.push_str("    /// Create a new record\n");
336    code.push_str(&format!(
337        "    pub fn create(&self) -> prax_query::CreateOperation<'a, E, {}> {{\n",
338        model.name()
339    ));
340    code.push_str(&format!(
341        "        prax_query::CreateOperation::new(self.engine, \"{}\")\n",
342        table_name
343    ));
344    code.push_str("    }\n\n");
345
346    code.push_str("    /// Update a record\n");
347    code.push_str(&format!(
348        "    pub fn update(&self) -> prax_query::UpdateOperation<'a, E, {}> {{\n",
349        model.name()
350    ));
351    code.push_str(&format!(
352        "        prax_query::UpdateOperation::new(self.engine, \"{}\")\n",
353        table_name
354    ));
355    code.push_str("    }\n\n");
356
357    code.push_str("    /// Delete a record\n");
358    code.push_str(&format!(
359        "    pub fn delete(&self) -> prax_query::DeleteOperation<'a, E, {}> {{\n",
360        model.name()
361    ));
362    code.push_str(&format!(
363        "        prax_query::DeleteOperation::new(self.engine, \"{}\")\n",
364        table_name
365    ));
366    code.push_str("    }\n\n");
367
368    code.push_str("    /// Count records\n");
369    code.push_str("    pub fn count(&self) -> prax_query::CountOperation<'a, E> {\n");
370    code.push_str(&format!(
371        "        prax_query::CountOperation::new(self.engine, \"{}\")\n",
372        table_name
373    ));
374    code.push_str("    }\n");
375
376    code.push_str("}\n");
377
378    Ok(code)
379}
380
381/// Generate an enum module
382fn generate_enum_module(enum_def: &prax_schema::ast::Enum) -> CliResult<String> {
383    let mut code = String::new();
384
385    code.push_str(&format!(
386        "//! Auto-generated module for {} enum\n\n",
387        enum_def.name()
388    ));
389
390    code.push_str(
391        "#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]\n",
392    );
393    code.push_str(&format!("pub enum {} {{\n", enum_def.name()));
394
395    for variant in &enum_def.variants {
396        // Check for @map attribute
397        if let Some(attr) = variant.attributes.iter().find(|a| a.is("map")) {
398            if let Some(value) = attr.first_arg().and_then(|v| v.as_string()) {
399                code.push_str(&format!("    #[serde(rename = \"{}\")]\n", value));
400            }
401        }
402        code.push_str(&format!("    {},\n", variant.name()));
403    }
404
405    code.push_str("}\n\n");
406
407    // Default implementation
408    if let Some(default_variant) = enum_def.variants.first() {
409        code.push_str(&format!("impl Default for {} {{\n", enum_def.name()));
410        code.push_str(&format!(
411            "    fn default() -> Self {{\n        Self::{}\n    }}\n",
412            default_variant.name()
413        ));
414        code.push_str("}\n");
415    }
416
417    Ok(code)
418}
419
420/// Generate types module
421fn generate_types_module(schema: &prax_schema::ast::Schema) -> CliResult<String> {
422    let mut code = String::new();
423
424    code.push_str("//! Common type definitions\n\n");
425    code.push_str("pub use chrono::{DateTime, Utc};\n");
426    code.push_str("pub use uuid::Uuid;\n");
427    code.push_str("pub use serde_json::Value as Json;\n");
428    code.push_str("\n");
429
430    // Add any custom types from composite types
431    for composite in schema.types.values() {
432        code.push_str("#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]\n");
433        code.push_str(&format!("pub struct {} {{\n", composite.name()));
434        for field in composite.fields.values() {
435            let rust_type = field_type_to_rust(&field.field_type, field.modifier);
436            let field_name = to_snake_case(field.name());
437            code.push_str(&format!("    pub {}: {},\n", field_name, rust_type));
438        }
439        code.push_str("}\n\n");
440    }
441
442    Ok(code)
443}
444
445/// Generate filters module
446fn generate_filters_module(schema: &prax_schema::ast::Schema) -> CliResult<String> {
447    let mut code = String::new();
448
449    code.push_str("//! Filter types for queries\n\n");
450    code.push_str("use prax_query::filter::{Filter, ScalarFilter};\n\n");
451
452    for model in schema.models.values() {
453        // Where input
454        code.push_str(&format!("/// Filter input for {} queries\n", model.name()));
455        code.push_str("#[derive(Debug, Default, Clone)]\n");
456        code.push_str(&format!("pub struct {}WhereInput {{\n", model.name()));
457
458        for field in model.fields.values() {
459            if !field.is_relation() {
460                let filter_type = field_to_filter_type(&field.field_type);
461                let field_name = to_snake_case(field.name());
462                code.push_str(&format!(
463                    "    pub {}: Option<{}>,\n",
464                    field_name, filter_type
465                ));
466            }
467        }
468
469        code.push_str("    pub and: Option<Vec<Self>>,\n");
470        code.push_str("    pub or: Option<Vec<Self>>,\n");
471        code.push_str("    pub not: Option<Box<Self>>,\n");
472        code.push_str("}\n\n");
473
474        // OrderBy input
475        code.push_str(&format!(
476            "/// Order by input for {} queries\n",
477            model.name()
478        ));
479        code.push_str("#[derive(Debug, Default, Clone)]\n");
480        code.push_str(&format!("pub struct {}OrderByInput {{\n", model.name()));
481
482        for field in model.fields.values() {
483            if !field.is_relation() {
484                let field_name = to_snake_case(field.name());
485                code.push_str(&format!(
486                    "    pub {}: Option<prax_query::SortOrder>,\n",
487                    field_name
488                ));
489            }
490        }
491
492        code.push_str("}\n\n");
493    }
494
495    Ok(code)
496}
497
498/// Convert a field type to Rust type
499fn field_type_to_rust(
500    field_type: &prax_schema::ast::FieldType,
501    modifier: prax_schema::ast::TypeModifier,
502) -> String {
503    use prax_schema::ast::{FieldType, ScalarType, TypeModifier};
504
505    let base_type = match field_type {
506        FieldType::Scalar(scalar) => match scalar {
507            ScalarType::Int => "i32".to_string(),
508            ScalarType::BigInt => "i64".to_string(),
509            ScalarType::Float => "f64".to_string(),
510            ScalarType::String => "String".to_string(),
511            ScalarType::Boolean => "bool".to_string(),
512            ScalarType::DateTime => "chrono::DateTime<chrono::Utc>".to_string(),
513            ScalarType::Date => "chrono::NaiveDate".to_string(),
514            ScalarType::Time => "chrono::NaiveTime".to_string(),
515            ScalarType::Json => "serde_json::Value".to_string(),
516            ScalarType::Bytes => "Vec<u8>".to_string(),
517            ScalarType::Decimal => "rust_decimal::Decimal".to_string(),
518            ScalarType::Uuid => "uuid::Uuid".to_string(),
519            ScalarType::Cuid => "String".to_string(),
520            ScalarType::Cuid2 => "String".to_string(),
521            ScalarType::NanoId => "String".to_string(),
522            ScalarType::Ulid => "String".to_string(),
523            ScalarType::Vector(_) | ScalarType::HalfVector(_) => "Vec<f32>".to_string(),
524            ScalarType::SparseVector(_) => "Vec<(u32, f32)>".to_string(),
525            ScalarType::Bit(_) => "Vec<u8>".to_string(),
526        },
527        FieldType::Model(name) => name.to_string(),
528        FieldType::Enum(name) => name.to_string(),
529        FieldType::Composite(name) => name.to_string(),
530        FieldType::Unsupported(_) => "serde_json::Value".to_string(),
531    };
532
533    match modifier {
534        TypeModifier::Optional | TypeModifier::OptionalList => format!("Option<{}>", base_type),
535        TypeModifier::List => format!("Vec<{}>", base_type),
536        TypeModifier::Required => base_type,
537    }
538}
539
540/// Convert a field type to filter type
541fn field_to_filter_type(field_type: &prax_schema::ast::FieldType) -> String {
542    use prax_schema::ast::{FieldType, ScalarType};
543
544    match field_type {
545        FieldType::Scalar(scalar) => match scalar {
546            ScalarType::Int | ScalarType::BigInt => "ScalarFilter<i64>".to_string(),
547            ScalarType::Float | ScalarType::Decimal => "ScalarFilter<f64>".to_string(),
548            ScalarType::String
549            | ScalarType::Uuid
550            | ScalarType::Cuid
551            | ScalarType::Cuid2
552            | ScalarType::NanoId
553            | ScalarType::Ulid => "ScalarFilter<String>".to_string(),
554            ScalarType::Boolean => "ScalarFilter<bool>".to_string(),
555            ScalarType::DateTime => "ScalarFilter<chrono::DateTime<chrono::Utc>>".to_string(),
556            ScalarType::Date => "ScalarFilter<chrono::NaiveDate>".to_string(),
557            ScalarType::Time => "ScalarFilter<chrono::NaiveTime>".to_string(),
558            ScalarType::Json => "ScalarFilter<serde_json::Value>".to_string(),
559            ScalarType::Bytes => "ScalarFilter<Vec<u8>>".to_string(),
560            // Vector types don't have standard scalar filters
561            ScalarType::Vector(_) | ScalarType::HalfVector(_) => "VectorFilter".to_string(),
562            ScalarType::SparseVector(_) => "SparseVectorFilter".to_string(),
563            ScalarType::Bit(_) => "BitFilter".to_string(),
564        },
565        FieldType::Enum(name) => format!("ScalarFilter<{}>", name),
566        _ => "Filter".to_string(),
567    }
568}
569
/// Convert a PascalCase or camelCase identifier to snake_case.
///
/// An underscore is inserted before an uppercase character only when it starts a new word:
/// - after a lowercase letter, digit or other non-uppercase character (`userId` -> `user_id`);
/// - at the end of an acronym run, when the next character is lowercase
///   (`HTTPServer` -> `http_server`).
///
/// Other rules:
/// - Capitals inside an acronym run stay together (`userID` -> `user_id`).
/// - No underscore is added right after an existing `_`, so snake_case input is unchanged.
/// - Every character of a multi-char lowercase mapping is kept.
fn to_snake_case(name: &str) -> String {
    let chars: Vec<char> = name.chars().collect();
    let mut result = String::with_capacity(name.len() + chars.len() / 2);

    for (i, &c) in chars.iter().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                let prev = chars[i - 1];
                let next_is_lower = matches!(chars.get(i + 1), Some(n) if n.is_lowercase());
                // Word boundary: not already separated, and either leaving a
                // non-uppercase run or ending an acronym before a lowercase letter.
                if prev != '_' && (!prev.is_uppercase() || next_is_lower) {
                    result.push('_');
                }
            }
            // `to_lowercase` can yield more than one char (e.g. for 'İ'); keep them all.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }

    result
}