Skip to main content

prax_cli/commands/
generate.rs

1//! `prax generate` command - Generate Rust client code from schema.
2
3use std::collections::{HashMap, HashSet};
4use std::path::PathBuf;
5
6use crate::cli::GenerateArgs;
7use crate::config::{CONFIG_FILE_NAME, Config, SCHEMA_FILE_PATH};
8use crate::error::{CliError, CliResult};
9use crate::output::{self, success};
10
11/// Run the generate command
12pub async fn run(args: GenerateArgs) -> CliResult<()> {
13    output::header("Generate Prax Client");
14
15    let cwd = std::env::current_dir()?;
16
17    // Load config
18    let config_path = cwd.join(CONFIG_FILE_NAME);
19    let config = if config_path.exists() {
20        Config::load(&config_path)?
21    } else {
22        Config::default()
23    };
24
25    // Resolve schema path
26    let schema_path = args
27        .schema
28        .clone()
29        .unwrap_or_else(|| cwd.join(SCHEMA_FILE_PATH));
30    if !schema_path.exists() {
31        return Err(
32            CliError::Config(format!("Schema file not found: {}", schema_path.display())).into(),
33        );
34    }
35
36    // Resolve output directory
37    let output_dir = args
38        .output
39        .clone()
40        .unwrap_or_else(|| PathBuf::from(&config.generator.output));
41
42    output::kv("Schema", &schema_path.display().to_string());
43    output::kv("Output", &output_dir.display().to_string());
44    output::newline();
45
46    output::step(1, 4, "Reading schema...");
47
48    // Parse schema
49    let schema_content = std::fs::read_to_string(&schema_path)?;
50    let schema = parse_schema(&schema_content)?;
51
52    output::step(2, 4, "Validating schema...");
53
54    // Validate schema
55    validate_schema(&schema)?;
56
57    output::step(3, 4, "Generating code...");
58
59    // Create output directory
60    std::fs::create_dir_all(&output_dir)?;
61
62    // Generate code
63    let generated_files = generate_code(&schema, &output_dir, &args, &config)?;
64
65    output::step(4, 4, "Writing files...");
66
67    // Print generated files
68    output::newline();
69    output::section("Generated files");
70
71    for file in &generated_files {
72        let relative_path = file
73            .strip_prefix(&cwd)
74            .unwrap_or(file)
75            .display()
76            .to_string();
77        output::list_item(&relative_path);
78    }
79
80    output::newline();
81    success(&format!(
82        "Generated {} files in {:.2}s",
83        generated_files.len(),
84        0.0 // TODO: Add timing
85    ));
86
87    Ok(())
88}
89
90/// Parse and validate the schema file
91fn parse_schema(content: &str) -> CliResult<prax_schema::Schema> {
92    // Use validate_schema to ensure field types are properly resolved
93    // (e.g., FieldType::Model -> FieldType::Enum for enum references)
94    prax_schema::validate_schema(content)
95        .map_err(|e| CliError::Schema(format!("Failed to parse/validate schema: {}", e)))
96}
97
98/// Validate the schema (now a no-op since parse_schema does validation)
99fn validate_schema(_schema: &prax_schema::Schema) -> CliResult<()> {
100    // Validation is now done in parse_schema via validate_schema()
101    Ok(())
102}
103
104/// Generate code from the schema
105fn generate_code(
106    schema: &prax_schema::ast::Schema,
107    output_dir: &PathBuf,
108    args: &GenerateArgs,
109    config: &Config,
110) -> CliResult<Vec<PathBuf>> {
111    let mut generated_files = Vec::new();
112
113    // Determine which features to generate
114    let features = if !args.features.is_empty() {
115        args.features.clone()
116    } else {
117        config
118            .generator
119            .features
120            .clone()
121            .unwrap_or_else(|| vec!["client".to_string()])
122    };
123
124    // Build relation graph for cycle detection
125    let relation_graph = build_relation_graph(schema);
126
127    // Generate main client module
128    let client_path = output_dir.join("mod.rs");
129    let client_code = generate_client_module(schema, &features)?;
130    std::fs::write(&client_path, client_code)?;
131    generated_files.push(client_path);
132
133    // Generate model modules
134    for model in schema.models.values() {
135        let model_path = output_dir.join(format!("{}.rs", to_snake_case(model.name())));
136        let model_code = generate_model_module(model, &features, &relation_graph)?;
137        std::fs::write(&model_path, model_code)?;
138        generated_files.push(model_path);
139    }
140
141    // Generate enum modules
142    for enum_def in schema.enums.values() {
143        let enum_path = output_dir.join(format!("{}.rs", to_snake_case(enum_def.name())));
144        let enum_code = generate_enum_module(enum_def)?;
145        std::fs::write(&enum_path, enum_code)?;
146        generated_files.push(enum_path);
147    }
148
149    // Generate type definitions
150    let types_path = output_dir.join("types.rs");
151    let types_code = generate_types_module(schema)?;
152    std::fs::write(&types_path, types_code)?;
153    generated_files.push(types_path);
154
155    // Generate filters
156    let filters_path = output_dir.join("filters.rs");
157    let filters_code = generate_filters_module(schema)?;
158    std::fs::write(&filters_path, filters_code)?;
159    generated_files.push(filters_path);
160
161    Ok(generated_files)
162}
163
164/// Build a graph of model relations for cycle detection.
165/// Returns a map from model name to the set of model names it references
166/// (non-list relations only, since Vec<T> doesn't cause infinite size).
167fn build_relation_graph(
168    schema: &prax_schema::ast::Schema,
169) -> HashMap<String, HashSet<String>> {
170    let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
171
172    for model in schema.models.values() {
173        let entry = graph.entry(model.name().to_string()).or_default();
174        for field in model.fields.values() {
175            if let prax_schema::ast::FieldType::Model(ref target) = field.field_type {
176                if !field.is_list() {
177                    entry.insert(target.to_string());
178                }
179            }
180        }
181    }
182
183    graph
184}
185
/// Decide whether a non-list relation `source_model -> target_model` must be boxed.
///
/// The field needs a `Box<T>` when it closes a cycle. That is the case when
/// `source_model` can be reached from `target_model` by following non-list
/// edges in `graph`, including a direct self-relation where both names are
/// equal. Without a box on such a cycle the generated structs would be
/// infinitely sized.
fn needs_boxing(
    source_model: &str,
    target_model: &str,
    graph: &HashMap<String, HashSet<String>>,
) -> bool {
    let mut explored: HashSet<&str> = HashSet::new();
    let mut worklist: Vec<&str> = vec![target_model];

    while let Some(model) = worklist.pop() {
        if model == source_model {
            return true;
        }
        // Skip models whose neighbours were already queued.
        if !explored.insert(model) {
            continue;
        }
        if let Some(targets) = graph.get(model) {
            worklist.extend(targets.iter().map(String::as_str));
        }
    }

    false
}
213
214/// Generate the main client module
215fn generate_client_module(
216    schema: &prax_schema::ast::Schema,
217    _features: &[String],
218) -> CliResult<String> {
219    let mut code = String::new();
220
221    code.push_str("//! Auto-generated by Prax - DO NOT EDIT\n");
222    code.push_str("//!\n");
223    code.push_str("//! This module contains the generated Prax client.\n\n");
224
225    // Module declarations
226    code.push_str("pub mod types;\n");
227    code.push_str("pub mod filters;\n\n");
228
229    for model in schema.models.values() {
230        code.push_str(&format!("pub mod {};\n", to_snake_case(model.name())));
231    }
232
233    for enum_def in schema.enums.values() {
234        code.push_str(&format!("pub mod {};\n", to_snake_case(enum_def.name())));
235    }
236
237    code.push_str("\n");
238
239    // Re-exports
240    code.push_str("#[allow(unused_imports)]\npub use types::*;\n");
241    code.push_str("#[allow(unused_imports)]\npub use filters::*;\n\n");
242
243    for model in schema.models.values() {
244        code.push_str(&format!(
245            "#[allow(unused_imports)]\npub use {}::{};\n",
246            to_snake_case(model.name()),
247            model.name()
248        ));
249    }
250
251    for enum_def in schema.enums.values() {
252        code.push_str(&format!(
253            "#[allow(unused_imports)]\npub use {}::{};\n",
254            to_snake_case(enum_def.name()),
255            enum_def.name()
256        ));
257    }
258
259    code.push_str("\n");
260
261    // Client struct with Clone bound and derive
262    code.push_str("#[allow(dead_code)]\n");
263    code.push_str("/// The Prax database client\n");
264    code.push_str("#[derive(Clone)]\n");
265    code.push_str("pub struct PraxClient<E: prax_query::QueryEngine> {\n");
266    code.push_str("    engine: E,\n");
267    code.push_str("}\n\n");
268
269    code.push_str("impl<E: prax_query::QueryEngine> PraxClient<E> {\n");
270    code.push_str("    /// Create a new Prax client with the given query engine\n");
271    code.push_str("    pub fn new(engine: E) -> Self {\n");
272    code.push_str("        Self { engine }\n");
273    code.push_str("    }\n\n");
274
275    for model in schema.models.values() {
276        let snake_name = to_snake_case(model.name());
277        code.push_str(&format!("    /// Access {} operations\n", model.name()));
278        code.push_str(&format!(
279            "    pub fn {}(&self) -> {}::{}Operations<E> {{\n",
280            snake_name,
281            snake_name,
282            model.name()
283        ));
284        code.push_str(&format!(
285            "        {}::{}Operations::new(self.engine.clone())\n",
286            snake_name,
287            model.name()
288        ));
289        code.push_str("    }\n\n");
290    }
291
292    code.push_str("}\n");
293
294    Ok(code)
295}
296
297/// Generate a model module
298fn generate_model_module(
299    model: &prax_schema::ast::Model,
300    features: &[String],
301    relation_graph: &HashMap<String, HashSet<String>>,
302) -> CliResult<String> {
303    let mut code = String::new();
304
305    code.push_str(&format!(
306        "//! Auto-generated module for {} model\n\n",
307        model.name()
308    ));
309
310    // Import sibling types for relation fields
311    code.push_str("#[allow(unused_imports)]\n");
312    code.push_str("use super::*;\n");
313    code.push_str("#[allow(unused_imports)]\n");
314    code.push_str("use prax_query::traits::Model;\n\n");
315
316    // Derive macros based on features
317    let mut derives = vec!["Debug", "Clone"];
318    if features.contains(&"serde".to_string()) {
319        derives.push("serde::Serialize");
320        derives.push("serde::Deserialize");
321    }
322
323    // Model struct
324    code.push_str("#[allow(dead_code)]\n");
325    code.push_str(&format!("#[derive({})]\n", derives.join(", ")));
326    code.push_str(&format!("pub struct {} {{\n", model.name()));
327
328    for field in model.fields.values() {
329        let field_name = to_snake_case(field.name());
330
331        // Add serde rename if mapped
332        if let Some(attr) = field.get_attribute("map") {
333            if features.contains(&"serde".to_string()) {
334                if let Some(value) = attr.first_arg().and_then(|v| v.as_string()) {
335                    code.push_str(&format!("    #[serde(rename = \"{}\")]\n", value));
336                }
337            }
338        }
339
340        let rust_type = field_type_to_rust_with_boxing(
341            &field.field_type,
342            field.modifier,
343            model.name(),
344            relation_graph,
345        );
346        code.push_str(&format!("    pub {}: {},\n", field_name, rust_type));
347    }
348
349    code.push_str("}\n\n");
350
351    // Model trait implementation
352    let table_name = model.table_name();
353    let id_fields: Vec<&str> = model.id_fields().iter().map(|f| f.name()).collect();
354    let scalar_columns: Vec<String> = model
355        .scalar_fields()
356        .iter()
357        .map(|f| {
358            // Use @map name if present, otherwise snake_case the field name
359            f.get_attribute("map")
360                .and_then(|a| a.first_arg())
361                .and_then(|v| v.as_string())
362                .map(|s| s.to_string())
363                .unwrap_or_else(|| to_snake_case(f.name()))
364        })
365        .collect();
366
367    code.push_str(&format!("impl Model for {} {{\n", model.name()));
368    code.push_str(&format!(
369        "    const MODEL_NAME: &'static str = \"{}\";\n",
370        model.name()
371    ));
372    code.push_str(&format!(
373        "    const TABLE_NAME: &'static str = \"{}\";\n",
374        table_name
375    ));
376    code.push_str(&format!(
377        "    const PRIMARY_KEY: &'static [&'static str] = &[{}];\n",
378        id_fields
379            .iter()
380            .map(|f| format!("\"{}\"", to_snake_case(f)))
381            .collect::<Vec<_>>()
382            .join(", ")
383    ));
384    code.push_str(&format!(
385        "    const COLUMNS: &'static [&'static str] = &[{}];\n",
386        scalar_columns
387            .iter()
388            .map(|c| format!("\"{}\"", c))
389            .collect::<Vec<_>>()
390            .join(", ")
391    ));
392    code.push_str("}\n\n");
393
394    // Operations struct (owned engine, no lifetime)
395    code.push_str("#[allow(dead_code)]\n");
396    code.push_str(&format!("/// Operations for the {} model\n", model.name()));
397    code.push_str(&format!(
398        "pub struct {}Operations<E: prax_query::QueryEngine> {{\n",
399        model.name()
400    ));
401    code.push_str("    engine: E,\n");
402    code.push_str("}\n\n");
403
404    code.push_str(&format!(
405        "impl<E: prax_query::QueryEngine> {}Operations<E> {{\n",
406        model.name()
407    ));
408    code.push_str("    pub fn new(engine: E) -> Self {\n");
409    code.push_str("        Self { engine }\n");
410    code.push_str("    }\n\n");
411
412    // CRUD methods (1-arg constructors, no lifetime on return types)
413    code.push_str("    /// Find many records\n");
414    code.push_str(&format!(
415        "    pub fn find_many(&self) -> prax_query::FindManyOperation<E, {}> {{\n",
416        model.name()
417    ));
418    code.push_str("        prax_query::FindManyOperation::new(self.engine.clone())\n");
419    code.push_str("    }\n\n");
420
421    code.push_str("    /// Find a unique record\n");
422    code.push_str(&format!(
423        "    pub fn find_unique(&self) -> prax_query::FindUniqueOperation<E, {}> {{\n",
424        model.name()
425    ));
426    code.push_str("        prax_query::FindUniqueOperation::new(self.engine.clone())\n");
427    code.push_str("    }\n\n");
428
429    code.push_str("    /// Find the first matching record\n");
430    code.push_str(&format!(
431        "    pub fn find_first(&self) -> prax_query::FindFirstOperation<E, {}> {{\n",
432        model.name()
433    ));
434    code.push_str("        prax_query::FindFirstOperation::new(self.engine.clone())\n");
435    code.push_str("    }\n\n");
436
437    code.push_str("    /// Create a new record\n");
438    code.push_str(&format!(
439        "    pub fn create(&self) -> prax_query::CreateOperation<E, {}> {{\n",
440        model.name()
441    ));
442    code.push_str("        prax_query::CreateOperation::new(self.engine.clone())\n");
443    code.push_str("    }\n\n");
444
445    code.push_str("    /// Update a record\n");
446    code.push_str(&format!(
447        "    pub fn update(&self) -> prax_query::UpdateOperation<E, {}> {{\n",
448        model.name()
449    ));
450    code.push_str("        prax_query::UpdateOperation::new(self.engine.clone())\n");
451    code.push_str("    }\n\n");
452
453    code.push_str("    /// Delete a record\n");
454    code.push_str(&format!(
455        "    pub fn delete(&self) -> prax_query::DeleteOperation<E, {}> {{\n",
456        model.name()
457    ));
458    code.push_str("        prax_query::DeleteOperation::new(self.engine.clone())\n");
459    code.push_str("    }\n\n");
460
461    code.push_str("    /// Count records\n");
462    code.push_str(&format!(
463        "    pub fn count(&self) -> prax_query::CountOperation<E, {}> {{\n",
464        model.name()
465    ));
466    code.push_str("        prax_query::CountOperation::new(self.engine.clone())\n");
467    code.push_str("    }\n");
468
469    code.push_str("}\n");
470
471    Ok(code)
472}
473
474/// Generate an enum module
475fn generate_enum_module(enum_def: &prax_schema::ast::Enum) -> CliResult<String> {
476    let mut code = String::new();
477
478    code.push_str(&format!(
479        "//! Auto-generated module for {} enum\n\n",
480        enum_def.name()
481    ));
482
483    code.push_str("#[allow(dead_code)]\n");
484    code.push_str(
485        "#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]\n",
486    );
487    code.push_str(&format!("pub enum {} {{\n", enum_def.name()));
488
489    for variant in &enum_def.variants {
490        let raw_name = variant.name();
491        let pascal_name = to_pascal_case(raw_name);
492
493        // Check for explicit @map attribute first
494        if let Some(attr) = variant.attributes.iter().find(|a| a.is("map")) {
495            if let Some(value) = attr.first_arg().and_then(|v| v.as_string()) {
496                code.push_str(&format!("    #[serde(rename = \"{}\")]\n", value));
497                code.push_str(&format!("    {},\n", pascal_name));
498                continue;
499            }
500        }
501
502        // If variant name differs from PascalCase form, add serde rename
503        if raw_name != pascal_name {
504            code.push_str(&format!("    #[serde(rename = \"{}\")]\n", raw_name));
505        }
506        code.push_str(&format!("    {},\n", pascal_name));
507    }
508
509    code.push_str("}\n\n");
510
511    // Display implementation for SQL serialization
512    code.push_str(&format!("impl std::fmt::Display for {} {{\n", enum_def.name()));
513    code.push_str("    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n");
514    code.push_str("        match self {\n");
515    for variant in &enum_def.variants {
516        let raw_name = variant.name();
517        let pascal_name = to_pascal_case(raw_name);
518        let db_value = variant.db_value();
519        code.push_str(&format!(
520            "            Self::{} => write!(f, \"{}\"),\n",
521            pascal_name, db_value
522        ));
523    }
524    code.push_str("        }\n");
525    code.push_str("    }\n");
526    code.push_str("}\n\n");
527
528    // Default implementation
529    if let Some(default_variant) = enum_def.variants.first() {
530        let pascal_name = to_pascal_case(default_variant.name());
531        code.push_str(&format!("impl Default for {} {{\n", enum_def.name()));
532        code.push_str(&format!(
533            "    fn default() -> Self {{\n        Self::{}\n    }}\n",
534            pascal_name
535        ));
536        code.push_str("}\n");
537    }
538
539    Ok(code)
540}
541
542/// Generate types module
543fn generate_types_module(schema: &prax_schema::ast::Schema) -> CliResult<String> {
544    let mut code = String::new();
545
546    code.push_str("//! Common type definitions\n\n");
547    code.push_str("#[allow(unused_imports)]\npub use chrono::{DateTime, Utc};\n");
548    code.push_str("#[allow(unused_imports)]\npub use uuid::Uuid;\n");
549    code.push_str("#[allow(unused_imports)]\npub use serde_json::Value as Json;\n");
550    code.push_str("\n");
551
552    // Add any custom types from composite types
553    for composite in schema.types.values() {
554        code.push_str("#[allow(dead_code)]\n");
555        code.push_str("#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]\n");
556        code.push_str(&format!("pub struct {} {{\n", composite.name()));
557        for field in composite.fields.values() {
558            let rust_type = field_type_to_rust(&field.field_type, field.modifier);
559            let field_name = to_snake_case(field.name());
560            code.push_str(&format!("    pub {}: {},\n", field_name, rust_type));
561        }
562        code.push_str("}\n\n");
563    }
564
565    Ok(code)
566}
567
568/// Generate filters module
569fn generate_filters_module(schema: &prax_schema::ast::Schema) -> CliResult<String> {
570    let mut code = String::new();
571
572    code.push_str("//! Filter types for queries\n\n");
573    code.push_str("#[allow(unused_imports)]\n");
574    code.push_str("use prax_query::filter::{Filter, ScalarFilter};\n");
575
576    // Collect all enum types referenced by model scalar fields
577    let mut referenced_enums = HashSet::new();
578    for model in schema.models.values() {
579        for field in model.fields.values() {
580            if !field.is_relation() {
581                if let prax_schema::ast::FieldType::Enum(ref name) = field.field_type {
582                    referenced_enums.insert(name.to_string());
583                }
584            }
585        }
586    }
587
588    // Import enum types
589    for enum_name in &referenced_enums {
590        code.push_str(&format!(
591            "#[allow(unused_imports)]\nuse super::{}::{};\n",
592            to_snake_case(enum_name),
593            enum_name
594        ));
595    }
596
597    code.push_str("\n");
598
599    for model in schema.models.values() {
600        // Where input
601        code.push_str("#[allow(dead_code)]\n");
602        code.push_str(&format!("/// Filter input for {} queries\n", model.name()));
603        code.push_str("#[derive(Debug, Default, Clone)]\n");
604        code.push_str(&format!("pub struct {}WhereInput {{\n", model.name()));
605
606        for field in model.fields.values() {
607            if !field.is_relation() {
608                let filter_type = field_to_filter_type(&field.field_type);
609                let field_name = to_snake_case(field.name());
610                code.push_str(&format!(
611                    "    pub {}: Option<{}>,\n",
612                    field_name, filter_type
613                ));
614            }
615        }
616
617        code.push_str("    pub and: Option<Vec<Self>>,\n");
618        code.push_str("    pub or: Option<Vec<Self>>,\n");
619        code.push_str("    pub not: Option<Box<Self>>,\n");
620        code.push_str("}\n\n");
621
622        // OrderBy input
623        code.push_str("#[allow(dead_code)]\n");
624        code.push_str(&format!(
625            "/// Order by input for {} queries\n",
626            model.name()
627        ));
628        code.push_str("#[derive(Debug, Default, Clone)]\n");
629        code.push_str(&format!("pub struct {}OrderByInput {{\n", model.name()));
630
631        for field in model.fields.values() {
632            if !field.is_relation() {
633                let field_name = to_snake_case(field.name());
634                code.push_str(&format!(
635                    "    pub {}: Option<prax_query::SortOrder>,\n",
636                    field_name
637                ));
638            }
639        }
640
641        code.push_str("}\n\n");
642    }
643
644    Ok(code)
645}
646
647/// Convert a field type to Rust type (basic, without boxing)
648fn field_type_to_rust(
649    field_type: &prax_schema::ast::FieldType,
650    modifier: prax_schema::ast::TypeModifier,
651) -> String {
652    use prax_schema::ast::{FieldType, ScalarType, TypeModifier};
653
654    let base_type = match field_type {
655        FieldType::Scalar(scalar) => match scalar {
656            ScalarType::Int => "i32".to_string(),
657            ScalarType::BigInt => "i64".to_string(),
658            ScalarType::Float => "f64".to_string(),
659            ScalarType::String => "String".to_string(),
660            ScalarType::Boolean => "bool".to_string(),
661            ScalarType::DateTime => "chrono::DateTime<chrono::Utc>".to_string(),
662            ScalarType::Date => "chrono::NaiveDate".to_string(),
663            ScalarType::Time => "chrono::NaiveTime".to_string(),
664            ScalarType::Json => "serde_json::Value".to_string(),
665            ScalarType::Bytes => "Vec<u8>".to_string(),
666            ScalarType::Decimal => "rust_decimal::Decimal".to_string(),
667            ScalarType::Uuid => "uuid::Uuid".to_string(),
668            ScalarType::Cuid => "String".to_string(),
669            ScalarType::Cuid2 => "String".to_string(),
670            ScalarType::NanoId => "String".to_string(),
671            ScalarType::Ulid => "String".to_string(),
672            ScalarType::Vector(_) | ScalarType::HalfVector(_) => "Vec<f32>".to_string(),
673            ScalarType::SparseVector(_) => "Vec<(u32, f32)>".to_string(),
674            ScalarType::Bit(_) => "Vec<u8>".to_string(),
675        },
676        FieldType::Model(name) => name.to_string(),
677        FieldType::Enum(name) => name.to_string(),
678        FieldType::Composite(name) => name.to_string(),
679        FieldType::Unsupported(_) => "serde_json::Value".to_string(),
680    };
681
682    match modifier {
683        TypeModifier::Optional | TypeModifier::OptionalList => format!("Option<{}>", base_type),
684        TypeModifier::List => format!("Vec<{}>", base_type),
685        TypeModifier::Required => base_type,
686    }
687}
688
689/// Convert a field type to Rust type with Box<T> wrapping for cyclic relations.
690fn field_type_to_rust_with_boxing(
691    field_type: &prax_schema::ast::FieldType,
692    modifier: prax_schema::ast::TypeModifier,
693    source_model: &str,
694    relation_graph: &HashMap<String, HashSet<String>>,
695) -> String {
696    use prax_schema::ast::{FieldType, TypeModifier};
697
698    // For model references (non-list), check if boxing is needed to break cycles
699    if let FieldType::Model(target) = field_type {
700        if !matches!(modifier, TypeModifier::List) {
701            let should_box = needs_boxing(source_model, target, relation_graph);
702            let base = target.to_string();
703            return match modifier {
704                TypeModifier::Optional | TypeModifier::OptionalList => {
705                    if should_box {
706                        format!("Option<Box<{}>>", base)
707                    } else {
708                        format!("Option<{}>", base)
709                    }
710                }
711                TypeModifier::Required => {
712                    if should_box {
713                        format!("Box<{}>", base)
714                    } else {
715                        base
716                    }
717                }
718                TypeModifier::List => unreachable!(),
719            };
720        }
721    }
722
723    // Fallback to basic conversion for non-cyclic fields
724    field_type_to_rust(field_type, modifier)
725}
726
727/// Convert a field type to filter type
728fn field_to_filter_type(field_type: &prax_schema::ast::FieldType) -> String {
729    use prax_schema::ast::{FieldType, ScalarType};
730
731    match field_type {
732        FieldType::Scalar(scalar) => match scalar {
733            ScalarType::Int | ScalarType::BigInt => "ScalarFilter<i64>".to_string(),
734            ScalarType::Float | ScalarType::Decimal => "ScalarFilter<f64>".to_string(),
735            ScalarType::String
736            | ScalarType::Uuid
737            | ScalarType::Cuid
738            | ScalarType::Cuid2
739            | ScalarType::NanoId
740            | ScalarType::Ulid => "ScalarFilter<String>".to_string(),
741            ScalarType::Boolean => "ScalarFilter<bool>".to_string(),
742            ScalarType::DateTime => "ScalarFilter<chrono::DateTime<chrono::Utc>>".to_string(),
743            ScalarType::Date => "ScalarFilter<chrono::NaiveDate>".to_string(),
744            ScalarType::Time => "ScalarFilter<chrono::NaiveTime>".to_string(),
745            ScalarType::Json => "ScalarFilter<serde_json::Value>".to_string(),
746            ScalarType::Bytes => "ScalarFilter<Vec<u8>>".to_string(),
747            // Vector types don't have standard scalar filters
748            ScalarType::Vector(_) | ScalarType::HalfVector(_) => "VectorFilter".to_string(),
749            ScalarType::SparseVector(_) => "SparseVectorFilter".to_string(),
750            ScalarType::Bit(_) => "BitFilter".to_string(),
751        },
752        FieldType::Enum(name) => format!("ScalarFilter<{}>", name),
753        _ => "Filter".to_string(),
754    }
755}
756
/// Convert a PascalCase / camelCase identifier to snake_case.
///
/// Every uppercase character starts a new word. An underscore is inserted
/// before it (except at the very start), and the character is replaced by its
/// complete lowercase mapping. All other characters, including existing
/// underscores, are copied unchanged, so snake_case input round-trips.
/// Consecutive capitals are split one by one (`"ID"` -> `"i_d"`).
fn to_snake_case(name: &str) -> String {
    let mut result = String::with_capacity(name.len() + name.len() / 2);
    for (i, c) in name.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // Some characters lowercase to several chars (e.g. 'İ' -> "i\u{307}");
            // keep the whole mapping rather than only its first char.
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
772
/// Convert snake_case, SCREAMING_SNAKE_CASE, camelCase or PascalCase to PascalCase.
///
/// The name is split on `_`, empty segments are dropped, and each segment's
/// first character is uppercased. The rest of a segment is lowercased when the
/// segment uses a single case (`card`/`CARD` -> `Card`). It is kept as-is when
/// the segment mixes cases, so existing word boundaries survive
/// (`inProgress` -> `InProgress`, `SuperAdmin` stays `SuperAdmin`).
/// An empty name, or one made only of underscores, yields an empty string.
fn to_pascal_case(name: &str) -> String {
    name.split('_')
        .filter(|segment| !segment.is_empty())
        .map(|segment| {
            let mixed_case = segment.chars().any(char::is_uppercase)
                && segment.chars().any(char::is_lowercase);
            let mut chars = segment.chars();
            match chars.next() {
                None => String::new(),
                Some(first) => {
                    let rest = chars.as_str();
                    let rest = if mixed_case {
                        rest.to_string()
                    } else {
                        rest.to_lowercase()
                    };
                    format!("{}{}", first.to_uppercase(), rest)
                }
            }
        })
        .collect()
}
800
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a relation graph from `(model, targets)` adjacency pairs.
    fn graph(adjacency: &[(&str, Vec<&str>)]) -> HashMap<String, HashSet<String>> {
        adjacency
            .iter()
            .map(|(model, targets)| {
                let targets: HashSet<String> = targets.iter().map(|t| t.to_string()).collect();
                (model.to_string(), targets)
            })
            .collect()
    }

    #[test]
    fn test_to_snake_case() {
        for (input, expected) in [
            ("BoardMember", "board_member"),
            ("User", "user"),
            ("JiraImportConfig", "jira_import_config"),
        ] {
            assert_eq!(to_snake_case(input), expected, "input: {}", input);
        }
    }

    #[test]
    fn test_to_pascal_case_from_snake() {
        for (input, expected) in [
            ("card_created", "CardCreated"),
            ("branch_deleted", "BranchDeleted"),
            ("pr_merged", "PrMerged"),
        ] {
            assert_eq!(to_pascal_case(input), expected, "input: {}", input);
        }
    }

    #[test]
    fn test_to_pascal_case_from_screaming() {
        for (input, expected) in [("CARD_CREATED", "CardCreated"), ("PR_MERGED", "PrMerged")] {
            assert_eq!(to_pascal_case(input), expected, "input: {}", input);
        }
    }

    #[test]
    fn test_to_pascal_case_already_pascal() {
        for input in ["Admin", "SuperAdmin", "Low"] {
            assert_eq!(to_pascal_case(input), input);
        }
    }

    #[test]
    fn test_to_pascal_case_single_word() {
        for input in ["active", "ACTIVE"] {
            assert_eq!(to_pascal_case(input), "Active", "input: {}", input);
        }
    }

    #[test]
    fn test_needs_boxing_direct_cycle() {
        let relations = graph(&[("Board", vec!["JiraConfig"]), ("JiraConfig", vec!["Board"])]);

        assert!(needs_boxing("Board", "JiraConfig", &relations));
        assert!(needs_boxing("JiraConfig", "Board", &relations));
    }

    #[test]
    fn test_needs_boxing_no_cycle() {
        let relations = graph(&[("Post", vec!["User"]), ("User", vec![])]);

        assert!(!needs_boxing("Post", "User", &relations));
    }

    #[test]
    fn test_needs_boxing_indirect_cycle() {
        let relations = graph(&[("A", vec!["B"]), ("B", vec!["C"]), ("C", vec!["A"])]);

        for (source, target) in [("A", "B"), ("B", "C"), ("C", "A")] {
            assert!(
                needs_boxing(source, target, &relations),
                "{} -> {} should need boxing",
                source,
                target
            );
        }
    }
}