prax_cli/commands/generate.rs

1//! `prax generate` command - Generate Rust client code from schema.
2
3use std::path::PathBuf;
4
5use crate::cli::GenerateArgs;
6use crate::config::{CONFIG_FILE_NAME, Config, SCHEMA_FILE_NAME};
7use crate::error::{CliError, CliResult};
8use crate::output::{self, success};
9
10/// Run the generate command
11pub async fn run(args: GenerateArgs) -> CliResult<()> {
12    output::header("Generate Prax Client");
13
14    let cwd = std::env::current_dir()?;
15
16    // Load config
17    let config_path = cwd.join(CONFIG_FILE_NAME);
18    let config = if config_path.exists() {
19        Config::load(&config_path)?
20    } else {
21        Config::default()
22    };
23
24    // Resolve schema path
25    let schema_path = args
26        .schema
27        .clone()
28        .unwrap_or_else(|| cwd.join(SCHEMA_FILE_NAME));
29    if !schema_path.exists() {
30        return Err(
31            CliError::Config(format!("Schema file not found: {}", schema_path.display())).into(),
32        );
33    }
34
35    // Resolve output directory
36    let output_dir = args
37        .output
38        .clone()
39        .unwrap_or_else(|| PathBuf::from(&config.generator.output));
40
41    output::kv("Schema", &schema_path.display().to_string());
42    output::kv("Output", &output_dir.display().to_string());
43    output::newline();
44
45    output::step(1, 4, "Reading schema...");
46
47    // Parse schema
48    let schema_content = std::fs::read_to_string(&schema_path)?;
49    let schema = parse_schema(&schema_content)?;
50
51    output::step(2, 4, "Validating schema...");
52
53    // Validate schema
54    validate_schema(&schema)?;
55
56    output::step(3, 4, "Generating code...");
57
58    // Create output directory
59    std::fs::create_dir_all(&output_dir)?;
60
61    // Generate code
62    let generated_files = generate_code(&schema, &output_dir, &args, &config)?;
63
64    output::step(4, 4, "Writing files...");
65
66    // Print generated files
67    output::newline();
68    output::section("Generated files");
69
70    for file in &generated_files {
71        let relative_path = file
72            .strip_prefix(&cwd)
73            .unwrap_or(file)
74            .display()
75            .to_string();
76        output::list_item(&relative_path);
77    }
78
79    output::newline();
80    success(&format!(
81        "Generated {} files in {:.2}s",
82        generated_files.len(),
83        0.0 // TODO: Add timing
84    ));
85
86    Ok(())
87}
88
89/// Parse the schema file
90fn parse_schema(content: &str) -> CliResult<prax_schema::Schema> {
91    prax_schema::parse_schema(content)
92        .map_err(|e| CliError::Schema(format!("Failed to parse schema: {}", e)))
93}
94
95/// Validate the schema
96fn validate_schema(schema: &prax_schema::Schema) -> CliResult<()> {
97    // Run schema validation using the Validator
98    let mut validator = prax_schema::Validator::new();
99    // Re-validate by cloning (validator takes ownership)
100    if let Err(e) = validator.validate(schema.clone()) {
101        return Err(CliError::Validation(format!(
102            "Schema validation failed: {}",
103            e
104        )));
105    }
106    Ok(())
107}
108
109/// Generate code from the schema
110fn generate_code(
111    schema: &prax_schema::ast::Schema,
112    output_dir: &PathBuf,
113    args: &GenerateArgs,
114    config: &Config,
115) -> CliResult<Vec<PathBuf>> {
116    let mut generated_files = Vec::new();
117
118    // Determine which features to generate
119    let features = if !args.features.is_empty() {
120        args.features.clone()
121    } else {
122        config
123            .generator
124            .features
125            .clone()
126            .unwrap_or_else(|| vec!["client".to_string()])
127    };
128
129    // Generate main client module
130    let client_path = output_dir.join("mod.rs");
131    let client_code = generate_client_module(schema, &features)?;
132    std::fs::write(&client_path, client_code)?;
133    generated_files.push(client_path);
134
135    // Generate model modules
136    for model in schema.models.values() {
137        let model_path = output_dir.join(format!("{}.rs", to_snake_case(model.name())));
138        let model_code = generate_model_module(model, &features)?;
139        std::fs::write(&model_path, model_code)?;
140        generated_files.push(model_path);
141    }
142
143    // Generate enum modules
144    for enum_def in schema.enums.values() {
145        let enum_path = output_dir.join(format!("{}.rs", to_snake_case(enum_def.name())));
146        let enum_code = generate_enum_module(enum_def)?;
147        std::fs::write(&enum_path, enum_code)?;
148        generated_files.push(enum_path);
149    }
150
151    // Generate type definitions
152    let types_path = output_dir.join("types.rs");
153    let types_code = generate_types_module(schema)?;
154    std::fs::write(&types_path, types_code)?;
155    generated_files.push(types_path);
156
157    // Generate filters
158    let filters_path = output_dir.join("filters.rs");
159    let filters_code = generate_filters_module(schema)?;
160    std::fs::write(&filters_path, filters_code)?;
161    generated_files.push(filters_path);
162
163    Ok(generated_files)
164}
165
166/// Generate the main client module
167fn generate_client_module(
168    schema: &prax_schema::ast::Schema,
169    _features: &[String],
170) -> CliResult<String> {
171    let mut code = String::new();
172
173    code.push_str("//! Auto-generated by Prax - DO NOT EDIT\n");
174    code.push_str("//!\n");
175    code.push_str("//! This module contains the generated Prax client.\n\n");
176
177    // Module declarations
178    code.push_str("pub mod types;\n");
179    code.push_str("pub mod filters;\n\n");
180
181    for model in schema.models.values() {
182        code.push_str(&format!("pub mod {};\n", to_snake_case(model.name())));
183    }
184
185    for enum_def in schema.enums.values() {
186        code.push_str(&format!("pub mod {};\n", to_snake_case(enum_def.name())));
187    }
188
189    code.push_str("\n");
190
191    // Re-exports
192    code.push_str("pub use types::*;\n");
193    code.push_str("pub use filters::*;\n\n");
194
195    for model in schema.models.values() {
196        code.push_str(&format!(
197            "pub use {}::{};\n",
198            to_snake_case(model.name()),
199            model.name()
200        ));
201    }
202
203    for enum_def in schema.enums.values() {
204        code.push_str(&format!(
205            "pub use {}::{};\n",
206            to_snake_case(enum_def.name()),
207            enum_def.name()
208        ));
209    }
210
211    code.push_str("\n");
212
213    // Client struct
214    code.push_str("/// The Prax database client\n");
215    code.push_str("pub struct PraxClient<E: prax_query::QueryEngine> {\n");
216    code.push_str("    engine: E,\n");
217    code.push_str("}\n\n");
218
219    code.push_str("impl<E: prax_query::QueryEngine> PraxClient<E> {\n");
220    code.push_str("    /// Create a new Prax client with the given query engine\n");
221    code.push_str("    pub fn new(engine: E) -> Self {\n");
222    code.push_str("        Self { engine }\n");
223    code.push_str("    }\n\n");
224
225    for model in schema.models.values() {
226        let snake_name = to_snake_case(model.name());
227        code.push_str(&format!("    /// Access {} operations\n", model.name()));
228        code.push_str(&format!(
229            "    pub fn {}(&self) -> {}::{}Operations<E> {{\n",
230            snake_name,
231            snake_name,
232            model.name()
233        ));
234        code.push_str(&format!(
235            "        {}::{}Operations::new(&self.engine)\n",
236            snake_name,
237            model.name()
238        ));
239        code.push_str("    }\n\n");
240    }
241
242    code.push_str("}\n");
243
244    Ok(code)
245}
246
247/// Generate a model module
248fn generate_model_module(
249    model: &prax_schema::ast::Model,
250    features: &[String],
251) -> CliResult<String> {
252    let mut code = String::new();
253
254    code.push_str(&format!(
255        "//! Auto-generated module for {} model\n\n",
256        model.name()
257    ));
258
259    // Derive macros based on features
260    let mut derives = vec!["Debug", "Clone"];
261    if features.contains(&"serde".to_string()) {
262        derives.push("serde::Serialize");
263        derives.push("serde::Deserialize");
264    }
265
266    // Model struct
267    code.push_str(&format!("#[derive({})]\n", derives.join(", ")));
268    code.push_str(&format!("pub struct {} {{\n", model.name()));
269
270    for field in model.fields.values() {
271        let rust_type = field_type_to_rust(&field.field_type, field.modifier);
272        let field_name = to_snake_case(field.name());
273
274        // Add serde rename if mapped
275        if let Some(attr) = field.get_attribute("map") {
276            if features.contains(&"serde".to_string()) {
277                if let Some(value) = attr.first_arg().and_then(|v| v.as_string()) {
278                    code.push_str(&format!("    #[serde(rename = \"{}\")]\n", value));
279                }
280            }
281        }
282
283        code.push_str(&format!("    pub {}: {},\n", field_name, rust_type));
284    }
285
286    code.push_str("}\n\n");
287
288    // Operations struct
289    code.push_str(&format!("/// Operations for the {} model\n", model.name()));
290    code.push_str(&format!(
291        "pub struct {}Operations<'a, E: prax_query::QueryEngine> {{\n",
292        model.name()
293    ));
294    code.push_str("    engine: &'a E,\n");
295    code.push_str("}\n\n");
296
297    code.push_str(&format!(
298        "impl<'a, E: prax_query::QueryEngine> {}Operations<'a, E> {{\n",
299        model.name()
300    ));
301    code.push_str("    pub fn new(engine: &'a E) -> Self {\n");
302    code.push_str("        Self { engine }\n");
303    code.push_str("    }\n\n");
304
305    let table_name = model.table_name();
306
307    // CRUD methods
308    code.push_str("    /// Find many records\n");
309    code.push_str(&format!(
310        "    pub fn find_many(&self) -> prax_query::FindManyOperation<'a, E, {}> {{\n",
311        model.name()
312    ));
313    code.push_str(&format!(
314        "        prax_query::FindManyOperation::new(self.engine, \"{}\")\n",
315        table_name
316    ));
317    code.push_str("    }\n\n");
318
319    code.push_str("    /// Find a unique record\n");
320    code.push_str(&format!(
321        "    pub fn find_unique(&self) -> prax_query::FindUniqueOperation<'a, E, {}> {{\n",
322        model.name()
323    ));
324    code.push_str(&format!(
325        "        prax_query::FindUniqueOperation::new(self.engine, \"{}\")\n",
326        table_name
327    ));
328    code.push_str("    }\n\n");
329
330    code.push_str("    /// Find the first matching record\n");
331    code.push_str(&format!(
332        "    pub fn find_first(&self) -> prax_query::FindFirstOperation<'a, E, {}> {{\n",
333        model.name()
334    ));
335    code.push_str(&format!(
336        "        prax_query::FindFirstOperation::new(self.engine, \"{}\")\n",
337        table_name
338    ));
339    code.push_str("    }\n\n");
340
341    code.push_str("    /// Create a new record\n");
342    code.push_str(&format!(
343        "    pub fn create(&self) -> prax_query::CreateOperation<'a, E, {}> {{\n",
344        model.name()
345    ));
346    code.push_str(&format!(
347        "        prax_query::CreateOperation::new(self.engine, \"{}\")\n",
348        table_name
349    ));
350    code.push_str("    }\n\n");
351
352    code.push_str("    /// Update a record\n");
353    code.push_str(&format!(
354        "    pub fn update(&self) -> prax_query::UpdateOperation<'a, E, {}> {{\n",
355        model.name()
356    ));
357    code.push_str(&format!(
358        "        prax_query::UpdateOperation::new(self.engine, \"{}\")\n",
359        table_name
360    ));
361    code.push_str("    }\n\n");
362
363    code.push_str("    /// Delete a record\n");
364    code.push_str(&format!(
365        "    pub fn delete(&self) -> prax_query::DeleteOperation<'a, E, {}> {{\n",
366        model.name()
367    ));
368    code.push_str(&format!(
369        "        prax_query::DeleteOperation::new(self.engine, \"{}\")\n",
370        table_name
371    ));
372    code.push_str("    }\n\n");
373
374    code.push_str("    /// Count records\n");
375    code.push_str("    pub fn count(&self) -> prax_query::CountOperation<'a, E> {\n");
376    code.push_str(&format!(
377        "        prax_query::CountOperation::new(self.engine, \"{}\")\n",
378        table_name
379    ));
380    code.push_str("    }\n");
381
382    code.push_str("}\n");
383
384    Ok(code)
385}
386
387/// Generate an enum module
388fn generate_enum_module(enum_def: &prax_schema::ast::Enum) -> CliResult<String> {
389    let mut code = String::new();
390
391    code.push_str(&format!(
392        "//! Auto-generated module for {} enum\n\n",
393        enum_def.name()
394    ));
395
396    code.push_str(
397        "#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]\n",
398    );
399    code.push_str(&format!("pub enum {} {{\n", enum_def.name()));
400
401    for variant in &enum_def.variants {
402        // Check for @map attribute
403        if let Some(attr) = variant.attributes.iter().find(|a| a.is("map")) {
404            if let Some(value) = attr.first_arg().and_then(|v| v.as_string()) {
405                code.push_str(&format!("    #[serde(rename = \"{}\")]\n", value));
406            }
407        }
408        code.push_str(&format!("    {},\n", variant.name()));
409    }
410
411    code.push_str("}\n\n");
412
413    // Default implementation
414    if let Some(default_variant) = enum_def.variants.first() {
415        code.push_str(&format!("impl Default for {} {{\n", enum_def.name()));
416        code.push_str(&format!(
417            "    fn default() -> Self {{\n        Self::{}\n    }}\n",
418            default_variant.name()
419        ));
420        code.push_str("}\n");
421    }
422
423    Ok(code)
424}
425
426/// Generate types module
427fn generate_types_module(schema: &prax_schema::ast::Schema) -> CliResult<String> {
428    let mut code = String::new();
429
430    code.push_str("//! Common type definitions\n\n");
431    code.push_str("pub use chrono::{DateTime, Utc};\n");
432    code.push_str("pub use uuid::Uuid;\n");
433    code.push_str("pub use serde_json::Value as Json;\n");
434    code.push_str("\n");
435
436    // Add any custom types from composite types
437    for composite in schema.types.values() {
438        code.push_str("#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]\n");
439        code.push_str(&format!("pub struct {} {{\n", composite.name()));
440        for field in composite.fields.values() {
441            let rust_type = field_type_to_rust(&field.field_type, field.modifier);
442            let field_name = to_snake_case(field.name());
443            code.push_str(&format!("    pub {}: {},\n", field_name, rust_type));
444        }
445        code.push_str("}\n\n");
446    }
447
448    Ok(code)
449}
450
451/// Generate filters module
452fn generate_filters_module(schema: &prax_schema::ast::Schema) -> CliResult<String> {
453    let mut code = String::new();
454
455    code.push_str("//! Filter types for queries\n\n");
456    code.push_str("use prax_query::filter::{Filter, ScalarFilter};\n\n");
457
458    for model in schema.models.values() {
459        // Where input
460        code.push_str(&format!("/// Filter input for {} queries\n", model.name()));
461        code.push_str("#[derive(Debug, Default, Clone)]\n");
462        code.push_str(&format!("pub struct {}WhereInput {{\n", model.name()));
463
464        for field in model.fields.values() {
465            if !field.is_relation() {
466                let filter_type = field_to_filter_type(&field.field_type);
467                let field_name = to_snake_case(field.name());
468                code.push_str(&format!(
469                    "    pub {}: Option<{}>,\n",
470                    field_name, filter_type
471                ));
472            }
473        }
474
475        code.push_str("    pub and: Option<Vec<Self>>,\n");
476        code.push_str("    pub or: Option<Vec<Self>>,\n");
477        code.push_str("    pub not: Option<Box<Self>>,\n");
478        code.push_str("}\n\n");
479
480        // OrderBy input
481        code.push_str(&format!(
482            "/// Order by input for {} queries\n",
483            model.name()
484        ));
485        code.push_str("#[derive(Debug, Default, Clone)]\n");
486        code.push_str(&format!("pub struct {}OrderByInput {{\n", model.name()));
487
488        for field in model.fields.values() {
489            if !field.is_relation() {
490                let field_name = to_snake_case(field.name());
491                code.push_str(&format!(
492                    "    pub {}: Option<prax_query::SortOrder>,\n",
493                    field_name
494                ));
495            }
496        }
497
498        code.push_str("}\n\n");
499    }
500
501    Ok(code)
502}
503
504/// Convert a field type to Rust type
505fn field_type_to_rust(
506    field_type: &prax_schema::ast::FieldType,
507    modifier: prax_schema::ast::TypeModifier,
508) -> String {
509    use prax_schema::ast::{FieldType, ScalarType, TypeModifier};
510
511    let base_type = match field_type {
512        FieldType::Scalar(scalar) => match scalar {
513            ScalarType::Int => "i32".to_string(),
514            ScalarType::BigInt => "i64".to_string(),
515            ScalarType::Float => "f64".to_string(),
516            ScalarType::String => "String".to_string(),
517            ScalarType::Boolean => "bool".to_string(),
518            ScalarType::DateTime => "chrono::DateTime<chrono::Utc>".to_string(),
519            ScalarType::Date => "chrono::NaiveDate".to_string(),
520            ScalarType::Time => "chrono::NaiveTime".to_string(),
521            ScalarType::Json => "serde_json::Value".to_string(),
522            ScalarType::Bytes => "Vec<u8>".to_string(),
523            ScalarType::Decimal => "rust_decimal::Decimal".to_string(),
524            ScalarType::Uuid => "uuid::Uuid".to_string(),
525            ScalarType::Cuid => "String".to_string(),
526            ScalarType::Cuid2 => "String".to_string(),
527            ScalarType::NanoId => "String".to_string(),
528            ScalarType::Ulid => "String".to_string(),
529        },
530        FieldType::Model(name) => name.to_string(),
531        FieldType::Enum(name) => name.to_string(),
532        FieldType::Composite(name) => name.to_string(),
533        FieldType::Unsupported(_) => "serde_json::Value".to_string(),
534    };
535
536    match modifier {
537        TypeModifier::Optional | TypeModifier::OptionalList => format!("Option<{}>", base_type),
538        TypeModifier::List => format!("Vec<{}>", base_type),
539        TypeModifier::Required => base_type,
540    }
541}
542
543/// Convert a field type to filter type
544fn field_to_filter_type(field_type: &prax_schema::ast::FieldType) -> String {
545    use prax_schema::ast::{FieldType, ScalarType};
546
547    match field_type {
548        FieldType::Scalar(scalar) => match scalar {
549            ScalarType::Int | ScalarType::BigInt => "ScalarFilter<i64>".to_string(),
550            ScalarType::Float | ScalarType::Decimal => "ScalarFilter<f64>".to_string(),
551            ScalarType::String
552            | ScalarType::Uuid
553            | ScalarType::Cuid
554            | ScalarType::Cuid2
555            | ScalarType::NanoId
556            | ScalarType::Ulid => "ScalarFilter<String>".to_string(),
557            ScalarType::Boolean => "ScalarFilter<bool>".to_string(),
558            ScalarType::DateTime => "ScalarFilter<chrono::DateTime<chrono::Utc>>".to_string(),
559            ScalarType::Date => "ScalarFilter<chrono::NaiveDate>".to_string(),
560            ScalarType::Time => "ScalarFilter<chrono::NaiveTime>".to_string(),
561            ScalarType::Json => "ScalarFilter<serde_json::Value>".to_string(),
562            ScalarType::Bytes => "ScalarFilter<Vec<u8>>".to_string(),
563        },
564        FieldType::Enum(name) => format!("ScalarFilter<{}>", name),
565        _ => "Filter".to_string(),
566    }
567}
568
/// Convert a PascalCase / camelCase identifier to snake_case.
///
/// A `_` is inserted before an uppercase character when it starts a new word. That is the
/// case when the previous character is lowercase or numeric (`createdAt` -> `created_at`),
/// or when it ends an acronym: the previous character is uppercase and the next is lowercase
/// (`HTTPServer` -> `http_server`). Runs of capitals stay together (`userID` -> `user_id`),
/// and existing underscores are not doubled (`User_Name` -> `user_name`). Uppercase
/// characters are lowercased in full, which matters for characters whose lowercase form
/// spans several `char`s.
fn to_snake_case(name: &str) -> String {
    let chars: Vec<char> = name.chars().collect();
    let mut result = String::with_capacity(name.len() + chars.len() / 2);

    for (i, &c) in chars.iter().enumerate() {
        if !c.is_uppercase() {
            result.push(c);
            continue;
        }

        if i > 0 {
            let prev = chars[i - 1];
            let next_is_lower = matches!(chars.get(i + 1), Some(next) if next.is_lowercase());
            let starts_word = prev.is_lowercase()
                || prev.is_numeric()
                || (prev.is_uppercase() && next_is_lower);
            if starts_word {
                result.push('_');
            }
        }

        result.extend(c.to_lowercase());
    }

    result
}