Skip to main content

prax_cli/commands/
generate.rs

1//! `prax generate` command - Generate Rust client code from schema.
2
3use std::collections::{HashMap, HashSet};
4use std::path::PathBuf;
5
6use crate::cli::GenerateArgs;
7use crate::config::{CONFIG_FILE_NAME, Config, SCHEMA_FILE_PATH};
8use crate::error::{CliError, CliResult};
9use crate::output::{self, success};
10
11/// Run the generate command
12pub async fn run(args: GenerateArgs) -> CliResult<()> {
13    output::header("Generate Prax Client");
14
15    let cwd = std::env::current_dir()?;
16
17    // Load config
18    let config_path = cwd.join(CONFIG_FILE_NAME);
19    let config = if config_path.exists() {
20        Config::load(&config_path)?
21    } else {
22        Config::default()
23    };
24
25    // Resolve schema path
26    let schema_path = args
27        .schema
28        .clone()
29        .unwrap_or_else(|| cwd.join(SCHEMA_FILE_PATH));
30    if !schema_path.exists() {
31        return Err(
32            CliError::Config(format!("Schema file not found: {}", schema_path.display())).into(),
33        );
34    }
35
36    // Resolve output directory
37    let output_dir = args
38        .output
39        .clone()
40        .unwrap_or_else(|| PathBuf::from(&config.generator.output));
41
42    output::kv("Schema", &schema_path.display().to_string());
43    output::kv("Output", &output_dir.display().to_string());
44    output::newline();
45
46    output::step(1, 4, "Reading schema...");
47
48    // Parse schema
49    let schema_content = std::fs::read_to_string(&schema_path)?;
50    let schema = parse_schema(&schema_content)?;
51
52    output::step(2, 4, "Validating schema...");
53
54    // Validate schema
55    validate_schema(&schema)?;
56
57    output::step(3, 4, "Generating code...");
58
59    // Create output directory
60    std::fs::create_dir_all(&output_dir)?;
61
62    // Generate code
63    let generated_files = generate_code(&schema, &output_dir, &args, &config)?;
64
65    output::step(4, 4, "Writing files...");
66
67    // Print generated files
68    output::newline();
69    output::section("Generated files");
70
71    for file in &generated_files {
72        let relative_path = file
73            .strip_prefix(&cwd)
74            .unwrap_or(file)
75            .display()
76            .to_string();
77        output::list_item(&relative_path);
78    }
79
80    output::newline();
81    success(&format!(
82        "Generated {} files in {:.2}s",
83        generated_files.len(),
84        0.0 // TODO: Add timing
85    ));
86
87    Ok(())
88}
89
90/// Parse and validate the schema file
91fn parse_schema(content: &str) -> CliResult<prax_schema::Schema> {
92    // Use validate_schema to ensure field types are properly resolved
93    // (e.g., FieldType::Model -> FieldType::Enum for enum references)
94    prax_schema::validate_schema(content)
95        .map_err(|e| CliError::Schema(format!("Failed to parse/validate schema: {}", e)))
96}
97
/// Validate the schema.
///
/// This is a no-op that always returns `Ok(())`: `parse_schema` already runs
/// validation through `prax_schema::validate_schema`, so the argument is
/// ignored here.
fn validate_schema(_schema: &prax_schema::Schema) -> CliResult<()> {
    // Validation is now done in parse_schema via validate_schema()
    Ok(())
}
103
104/// Generate code from the schema
105fn generate_code(
106    schema: &prax_schema::ast::Schema,
107    output_dir: &PathBuf,
108    args: &GenerateArgs,
109    config: &Config,
110) -> CliResult<Vec<PathBuf>> {
111    let mut generated_files = Vec::new();
112
113    // Determine which features to generate
114    let features = if !args.features.is_empty() {
115        args.features.clone()
116    } else {
117        config
118            .generator
119            .features
120            .clone()
121            .unwrap_or_else(|| vec!["client".to_string()])
122    };
123
124    // Build relation graph for cycle detection
125    let relation_graph = build_relation_graph(schema);
126
127    // Generate main client module
128    let client_path = output_dir.join("mod.rs");
129    let client_code = generate_client_module(schema, &features)?;
130    std::fs::write(&client_path, client_code)?;
131    generated_files.push(client_path);
132
133    // Generate model modules
134    for model in schema.models.values() {
135        let model_path = output_dir.join(format!("{}.rs", to_snake_case(model.name())));
136        let model_code = generate_model_module(model, &features, &relation_graph)?;
137        std::fs::write(&model_path, model_code)?;
138        generated_files.push(model_path);
139    }
140
141    // Generate enum modules
142    for enum_def in schema.enums.values() {
143        let enum_path = output_dir.join(format!("{}.rs", to_snake_case(enum_def.name())));
144        let enum_code = generate_enum_module(enum_def)?;
145        std::fs::write(&enum_path, enum_code)?;
146        generated_files.push(enum_path);
147    }
148
149    // Generate type definitions
150    let types_path = output_dir.join("types.rs");
151    let types_code = generate_types_module(schema)?;
152    std::fs::write(&types_path, types_code)?;
153    generated_files.push(types_path);
154
155    // Generate filters
156    let filters_path = output_dir.join("filters.rs");
157    let filters_code = generate_filters_module(schema)?;
158    std::fs::write(&filters_path, filters_code)?;
159    generated_files.push(filters_path);
160
161    Ok(generated_files)
162}
163
164/// Build a graph of model relations for cycle detection.
165/// Returns a map from model name to the set of model names it references
166/// (non-list relations only, since Vec<T> doesn't cause infinite size).
167fn build_relation_graph(
168    schema: &prax_schema::ast::Schema,
169) -> HashMap<String, HashSet<String>> {
170    let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
171
172    for model in schema.models.values() {
173        let entry = graph.entry(model.name().to_string()).or_default();
174        for field in model.fields.values() {
175            if let prax_schema::ast::FieldType::Model(ref target) = field.field_type {
176                if !field.is_list() {
177                    entry.insert(target.to_string());
178                }
179            }
180        }
181    }
182
183    graph
184}
185
/// Decide whether a non-list relation from `source_model` to `target_model`
/// must be wrapped in `Box<T>`.
///
/// Performs a depth-first walk over `graph` starting at `target_model` and
/// returns `true` as soon as `source_model` is reached, i.e. when the
/// relation is part of a cycle of non-list relations. A self reference
/// (`source_model == target_model`) therefore also returns `true`. Models
/// missing from `graph` are treated as having no outgoing relations.
fn needs_boxing(
    source_model: &str,
    target_model: &str,
    graph: &HashMap<String, HashSet<String>>,
) -> bool {
    let mut seen: HashSet<&str> = HashSet::new();
    let mut pending: Vec<&str> = vec![target_model];

    while let Some(node) = pending.pop() {
        if node == source_model {
            return true;
        }
        // Only expand a node the first time it is encountered.
        if seen.insert(node) {
            if let Some(successors) = graph.get(node) {
                pending.extend(successors.iter().map(String::as_str));
            }
        }
    }

    false
}
213
214/// Generate the main client module
215fn generate_client_module(
216    schema: &prax_schema::ast::Schema,
217    _features: &[String],
218) -> CliResult<String> {
219    let mut code = String::new();
220
221    code.push_str("//! Auto-generated by Prax - DO NOT EDIT\n");
222    code.push_str("//!\n");
223    code.push_str("//! This module contains the generated Prax client.\n\n");
224
225    // Module declarations
226    code.push_str("pub mod types;\n");
227    code.push_str("pub mod filters;\n\n");
228
229    for model in schema.models.values() {
230        code.push_str(&format!("pub mod {};\n", to_snake_case(model.name())));
231    }
232
233    for enum_def in schema.enums.values() {
234        code.push_str(&format!("pub mod {};\n", to_snake_case(enum_def.name())));
235    }
236
237    code.push_str("\n");
238
239    // Re-exports
240    code.push_str("pub use types::*;\n");
241    code.push_str("pub use filters::*;\n\n");
242
243    for model in schema.models.values() {
244        code.push_str(&format!(
245            "pub use {}::{};\n",
246            to_snake_case(model.name()),
247            model.name()
248        ));
249    }
250
251    for enum_def in schema.enums.values() {
252        code.push_str(&format!(
253            "pub use {}::{};\n",
254            to_snake_case(enum_def.name()),
255            enum_def.name()
256        ));
257    }
258
259    code.push_str("\n");
260
261    // Client struct with Clone bound and derive
262    code.push_str("/// The Prax database client\n");
263    code.push_str("#[derive(Clone)]\n");
264    code.push_str("pub struct PraxClient<E: prax_query::QueryEngine> {\n");
265    code.push_str("    engine: E,\n");
266    code.push_str("}\n\n");
267
268    code.push_str("impl<E: prax_query::QueryEngine> PraxClient<E> {\n");
269    code.push_str("    /// Create a new Prax client with the given query engine\n");
270    code.push_str("    pub fn new(engine: E) -> Self {\n");
271    code.push_str("        Self { engine }\n");
272    code.push_str("    }\n\n");
273
274    for model in schema.models.values() {
275        let snake_name = to_snake_case(model.name());
276        code.push_str(&format!("    /// Access {} operations\n", model.name()));
277        code.push_str(&format!(
278            "    pub fn {}(&self) -> {}::{}Operations<E> {{\n",
279            snake_name,
280            snake_name,
281            model.name()
282        ));
283        code.push_str(&format!(
284            "        {}::{}Operations::new(self.engine.clone())\n",
285            snake_name,
286            model.name()
287        ));
288        code.push_str("    }\n\n");
289    }
290
291    code.push_str("}\n");
292
293    Ok(code)
294}
295
296/// Generate a model module
297fn generate_model_module(
298    model: &prax_schema::ast::Model,
299    features: &[String],
300    relation_graph: &HashMap<String, HashSet<String>>,
301) -> CliResult<String> {
302    let mut code = String::new();
303
304    code.push_str(&format!(
305        "//! Auto-generated module for {} model\n\n",
306        model.name()
307    ));
308
309    // Import sibling types for relation fields
310    code.push_str("use super::*;\n");
311    code.push_str("use prax_query::traits::Model;\n\n");
312
313    // Derive macros based on features
314    let mut derives = vec!["Debug", "Clone"];
315    if features.contains(&"serde".to_string()) {
316        derives.push("serde::Serialize");
317        derives.push("serde::Deserialize");
318    }
319
320    // Model struct
321    code.push_str(&format!("#[derive({})]\n", derives.join(", ")));
322    code.push_str(&format!("pub struct {} {{\n", model.name()));
323
324    for field in model.fields.values() {
325        let field_name = to_snake_case(field.name());
326
327        // Add serde rename if mapped
328        if let Some(attr) = field.get_attribute("map") {
329            if features.contains(&"serde".to_string()) {
330                if let Some(value) = attr.first_arg().and_then(|v| v.as_string()) {
331                    code.push_str(&format!("    #[serde(rename = \"{}\")]\n", value));
332                }
333            }
334        }
335
336        let rust_type = field_type_to_rust_with_boxing(
337            &field.field_type,
338            field.modifier,
339            model.name(),
340            relation_graph,
341        );
342        code.push_str(&format!("    pub {}: {},\n", field_name, rust_type));
343    }
344
345    code.push_str("}\n\n");
346
347    // Model trait implementation
348    let table_name = model.table_name();
349    let id_fields: Vec<&str> = model.id_fields().iter().map(|f| f.name()).collect();
350    let scalar_columns: Vec<String> = model
351        .scalar_fields()
352        .iter()
353        .map(|f| {
354            // Use @map name if present, otherwise snake_case the field name
355            f.get_attribute("map")
356                .and_then(|a| a.first_arg())
357                .and_then(|v| v.as_string())
358                .map(|s| s.to_string())
359                .unwrap_or_else(|| to_snake_case(f.name()))
360        })
361        .collect();
362
363    code.push_str(&format!("impl Model for {} {{\n", model.name()));
364    code.push_str(&format!(
365        "    const MODEL_NAME: &'static str = \"{}\";\n",
366        model.name()
367    ));
368    code.push_str(&format!(
369        "    const TABLE_NAME: &'static str = \"{}\";\n",
370        table_name
371    ));
372    code.push_str(&format!(
373        "    const PRIMARY_KEY: &'static [&'static str] = &[{}];\n",
374        id_fields
375            .iter()
376            .map(|f| format!("\"{}\"", to_snake_case(f)))
377            .collect::<Vec<_>>()
378            .join(", ")
379    ));
380    code.push_str(&format!(
381        "    const COLUMNS: &'static [&'static str] = &[{}];\n",
382        scalar_columns
383            .iter()
384            .map(|c| format!("\"{}\"", c))
385            .collect::<Vec<_>>()
386            .join(", ")
387    ));
388    code.push_str("}\n\n");
389
390    // Operations struct (owned engine, no lifetime)
391    code.push_str(&format!("/// Operations for the {} model\n", model.name()));
392    code.push_str(&format!(
393        "pub struct {}Operations<E: prax_query::QueryEngine> {{\n",
394        model.name()
395    ));
396    code.push_str("    engine: E,\n");
397    code.push_str("}\n\n");
398
399    code.push_str(&format!(
400        "impl<E: prax_query::QueryEngine> {}Operations<E> {{\n",
401        model.name()
402    ));
403    code.push_str("    pub fn new(engine: E) -> Self {\n");
404    code.push_str("        Self { engine }\n");
405    code.push_str("    }\n\n");
406
407    // CRUD methods (1-arg constructors, no lifetime on return types)
408    code.push_str("    /// Find many records\n");
409    code.push_str(&format!(
410        "    pub fn find_many(&self) -> prax_query::FindManyOperation<E, {}> {{\n",
411        model.name()
412    ));
413    code.push_str("        prax_query::FindManyOperation::new(self.engine.clone())\n");
414    code.push_str("    }\n\n");
415
416    code.push_str("    /// Find a unique record\n");
417    code.push_str(&format!(
418        "    pub fn find_unique(&self) -> prax_query::FindUniqueOperation<E, {}> {{\n",
419        model.name()
420    ));
421    code.push_str("        prax_query::FindUniqueOperation::new(self.engine.clone())\n");
422    code.push_str("    }\n\n");
423
424    code.push_str("    /// Find the first matching record\n");
425    code.push_str(&format!(
426        "    pub fn find_first(&self) -> prax_query::FindFirstOperation<E, {}> {{\n",
427        model.name()
428    ));
429    code.push_str("        prax_query::FindFirstOperation::new(self.engine.clone())\n");
430    code.push_str("    }\n\n");
431
432    code.push_str("    /// Create a new record\n");
433    code.push_str(&format!(
434        "    pub fn create(&self) -> prax_query::CreateOperation<E, {}> {{\n",
435        model.name()
436    ));
437    code.push_str("        prax_query::CreateOperation::new(self.engine.clone())\n");
438    code.push_str("    }\n\n");
439
440    code.push_str("    /// Update a record\n");
441    code.push_str(&format!(
442        "    pub fn update(&self) -> prax_query::UpdateOperation<E, {}> {{\n",
443        model.name()
444    ));
445    code.push_str("        prax_query::UpdateOperation::new(self.engine.clone())\n");
446    code.push_str("    }\n\n");
447
448    code.push_str("    /// Delete a record\n");
449    code.push_str(&format!(
450        "    pub fn delete(&self) -> prax_query::DeleteOperation<E, {}> {{\n",
451        model.name()
452    ));
453    code.push_str("        prax_query::DeleteOperation::new(self.engine.clone())\n");
454    code.push_str("    }\n\n");
455
456    code.push_str("    /// Count records\n");
457    code.push_str(&format!(
458        "    pub fn count(&self) -> prax_query::CountOperation<E, {}> {{\n",
459        model.name()
460    ));
461    code.push_str("        prax_query::CountOperation::new(self.engine.clone())\n");
462    code.push_str("    }\n");
463
464    code.push_str("}\n");
465
466    Ok(code)
467}
468
469/// Generate an enum module
470fn generate_enum_module(enum_def: &prax_schema::ast::Enum) -> CliResult<String> {
471    let mut code = String::new();
472
473    code.push_str(&format!(
474        "//! Auto-generated module for {} enum\n\n",
475        enum_def.name()
476    ));
477
478    code.push_str(
479        "#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]\n",
480    );
481    code.push_str(&format!("pub enum {} {{\n", enum_def.name()));
482
483    for variant in &enum_def.variants {
484        let raw_name = variant.name();
485        let pascal_name = to_pascal_case(raw_name);
486
487        // Check for explicit @map attribute first
488        if let Some(attr) = variant.attributes.iter().find(|a| a.is("map")) {
489            if let Some(value) = attr.first_arg().and_then(|v| v.as_string()) {
490                code.push_str(&format!("    #[serde(rename = \"{}\")]\n", value));
491                code.push_str(&format!("    {},\n", pascal_name));
492                continue;
493            }
494        }
495
496        // If variant name differs from PascalCase form, add serde rename
497        if raw_name != pascal_name {
498            code.push_str(&format!("    #[serde(rename = \"{}\")]\n", raw_name));
499        }
500        code.push_str(&format!("    {},\n", pascal_name));
501    }
502
503    code.push_str("}\n\n");
504
505    // Display implementation for SQL serialization
506    code.push_str(&format!("impl std::fmt::Display for {} {{\n", enum_def.name()));
507    code.push_str("    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {\n");
508    code.push_str("        match self {\n");
509    for variant in &enum_def.variants {
510        let raw_name = variant.name();
511        let pascal_name = to_pascal_case(raw_name);
512        let db_value = variant.db_value();
513        code.push_str(&format!(
514            "            Self::{} => write!(f, \"{}\"),\n",
515            pascal_name, db_value
516        ));
517    }
518    code.push_str("        }\n");
519    code.push_str("    }\n");
520    code.push_str("}\n\n");
521
522    // Default implementation
523    if let Some(default_variant) = enum_def.variants.first() {
524        let pascal_name = to_pascal_case(default_variant.name());
525        code.push_str(&format!("impl Default for {} {{\n", enum_def.name()));
526        code.push_str(&format!(
527            "    fn default() -> Self {{\n        Self::{}\n    }}\n",
528            pascal_name
529        ));
530        code.push_str("}\n");
531    }
532
533    Ok(code)
534}
535
536/// Generate types module
537fn generate_types_module(schema: &prax_schema::ast::Schema) -> CliResult<String> {
538    let mut code = String::new();
539
540    code.push_str("//! Common type definitions\n\n");
541    code.push_str("pub use chrono::{DateTime, Utc};\n");
542    code.push_str("pub use uuid::Uuid;\n");
543    code.push_str("pub use serde_json::Value as Json;\n");
544    code.push_str("\n");
545
546    // Add any custom types from composite types
547    for composite in schema.types.values() {
548        code.push_str("#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]\n");
549        code.push_str(&format!("pub struct {} {{\n", composite.name()));
550        for field in composite.fields.values() {
551            let rust_type = field_type_to_rust(&field.field_type, field.modifier);
552            let field_name = to_snake_case(field.name());
553            code.push_str(&format!("    pub {}: {},\n", field_name, rust_type));
554        }
555        code.push_str("}\n\n");
556    }
557
558    Ok(code)
559}
560
561/// Generate filters module
562fn generate_filters_module(schema: &prax_schema::ast::Schema) -> CliResult<String> {
563    let mut code = String::new();
564
565    code.push_str("//! Filter types for queries\n\n");
566    code.push_str("use prax_query::filter::{Filter, ScalarFilter};\n");
567
568    // Collect all enum types referenced by model scalar fields
569    let mut referenced_enums = HashSet::new();
570    for model in schema.models.values() {
571        for field in model.fields.values() {
572            if !field.is_relation() {
573                if let prax_schema::ast::FieldType::Enum(ref name) = field.field_type {
574                    referenced_enums.insert(name.to_string());
575                }
576            }
577        }
578    }
579
580    // Import enum types
581    for enum_name in &referenced_enums {
582        code.push_str(&format!(
583            "use super::{}::{};\n",
584            to_snake_case(enum_name),
585            enum_name
586        ));
587    }
588
589    code.push_str("\n");
590
591    for model in schema.models.values() {
592        // Where input
593        code.push_str(&format!("/// Filter input for {} queries\n", model.name()));
594        code.push_str("#[derive(Debug, Default, Clone)]\n");
595        code.push_str(&format!("pub struct {}WhereInput {{\n", model.name()));
596
597        for field in model.fields.values() {
598            if !field.is_relation() {
599                let filter_type = field_to_filter_type(&field.field_type);
600                let field_name = to_snake_case(field.name());
601                code.push_str(&format!(
602                    "    pub {}: Option<{}>,\n",
603                    field_name, filter_type
604                ));
605            }
606        }
607
608        code.push_str("    pub and: Option<Vec<Self>>,\n");
609        code.push_str("    pub or: Option<Vec<Self>>,\n");
610        code.push_str("    pub not: Option<Box<Self>>,\n");
611        code.push_str("}\n\n");
612
613        // OrderBy input
614        code.push_str(&format!(
615            "/// Order by input for {} queries\n",
616            model.name()
617        ));
618        code.push_str("#[derive(Debug, Default, Clone)]\n");
619        code.push_str(&format!("pub struct {}OrderByInput {{\n", model.name()));
620
621        for field in model.fields.values() {
622            if !field.is_relation() {
623                let field_name = to_snake_case(field.name());
624                code.push_str(&format!(
625                    "    pub {}: Option<prax_query::SortOrder>,\n",
626                    field_name
627                ));
628            }
629        }
630
631        code.push_str("}\n\n");
632    }
633
634    Ok(code)
635}
636
637/// Convert a field type to Rust type (basic, without boxing)
638fn field_type_to_rust(
639    field_type: &prax_schema::ast::FieldType,
640    modifier: prax_schema::ast::TypeModifier,
641) -> String {
642    use prax_schema::ast::{FieldType, ScalarType, TypeModifier};
643
644    let base_type = match field_type {
645        FieldType::Scalar(scalar) => match scalar {
646            ScalarType::Int => "i32".to_string(),
647            ScalarType::BigInt => "i64".to_string(),
648            ScalarType::Float => "f64".to_string(),
649            ScalarType::String => "String".to_string(),
650            ScalarType::Boolean => "bool".to_string(),
651            ScalarType::DateTime => "chrono::DateTime<chrono::Utc>".to_string(),
652            ScalarType::Date => "chrono::NaiveDate".to_string(),
653            ScalarType::Time => "chrono::NaiveTime".to_string(),
654            ScalarType::Json => "serde_json::Value".to_string(),
655            ScalarType::Bytes => "Vec<u8>".to_string(),
656            ScalarType::Decimal => "rust_decimal::Decimal".to_string(),
657            ScalarType::Uuid => "uuid::Uuid".to_string(),
658            ScalarType::Cuid => "String".to_string(),
659            ScalarType::Cuid2 => "String".to_string(),
660            ScalarType::NanoId => "String".to_string(),
661            ScalarType::Ulid => "String".to_string(),
662            ScalarType::Vector(_) | ScalarType::HalfVector(_) => "Vec<f32>".to_string(),
663            ScalarType::SparseVector(_) => "Vec<(u32, f32)>".to_string(),
664            ScalarType::Bit(_) => "Vec<u8>".to_string(),
665        },
666        FieldType::Model(name) => name.to_string(),
667        FieldType::Enum(name) => name.to_string(),
668        FieldType::Composite(name) => name.to_string(),
669        FieldType::Unsupported(_) => "serde_json::Value".to_string(),
670    };
671
672    match modifier {
673        TypeModifier::Optional | TypeModifier::OptionalList => format!("Option<{}>", base_type),
674        TypeModifier::List => format!("Vec<{}>", base_type),
675        TypeModifier::Required => base_type,
676    }
677}
678
679/// Convert a field type to Rust type with Box<T> wrapping for cyclic relations.
680fn field_type_to_rust_with_boxing(
681    field_type: &prax_schema::ast::FieldType,
682    modifier: prax_schema::ast::TypeModifier,
683    source_model: &str,
684    relation_graph: &HashMap<String, HashSet<String>>,
685) -> String {
686    use prax_schema::ast::{FieldType, TypeModifier};
687
688    // For model references (non-list), check if boxing is needed to break cycles
689    if let FieldType::Model(target) = field_type {
690        if !matches!(modifier, TypeModifier::List) {
691            let should_box = needs_boxing(source_model, target, relation_graph);
692            let base = target.to_string();
693            return match modifier {
694                TypeModifier::Optional | TypeModifier::OptionalList => {
695                    if should_box {
696                        format!("Option<Box<{}>>", base)
697                    } else {
698                        format!("Option<{}>", base)
699                    }
700                }
701                TypeModifier::Required => {
702                    if should_box {
703                        format!("Box<{}>", base)
704                    } else {
705                        base
706                    }
707                }
708                TypeModifier::List => unreachable!(),
709            };
710        }
711    }
712
713    // Fallback to basic conversion for non-cyclic fields
714    field_type_to_rust(field_type, modifier)
715}
716
717/// Convert a field type to filter type
718fn field_to_filter_type(field_type: &prax_schema::ast::FieldType) -> String {
719    use prax_schema::ast::{FieldType, ScalarType};
720
721    match field_type {
722        FieldType::Scalar(scalar) => match scalar {
723            ScalarType::Int | ScalarType::BigInt => "ScalarFilter<i64>".to_string(),
724            ScalarType::Float | ScalarType::Decimal => "ScalarFilter<f64>".to_string(),
725            ScalarType::String
726            | ScalarType::Uuid
727            | ScalarType::Cuid
728            | ScalarType::Cuid2
729            | ScalarType::NanoId
730            | ScalarType::Ulid => "ScalarFilter<String>".to_string(),
731            ScalarType::Boolean => "ScalarFilter<bool>".to_string(),
732            ScalarType::DateTime => "ScalarFilter<chrono::DateTime<chrono::Utc>>".to_string(),
733            ScalarType::Date => "ScalarFilter<chrono::NaiveDate>".to_string(),
734            ScalarType::Time => "ScalarFilter<chrono::NaiveTime>".to_string(),
735            ScalarType::Json => "ScalarFilter<serde_json::Value>".to_string(),
736            ScalarType::Bytes => "ScalarFilter<Vec<u8>>".to_string(),
737            // Vector types don't have standard scalar filters
738            ScalarType::Vector(_) | ScalarType::HalfVector(_) => "VectorFilter".to_string(),
739            ScalarType::SparseVector(_) => "SparseVectorFilter".to_string(),
740            ScalarType::Bit(_) => "BitFilter".to_string(),
741        },
742        FieldType::Enum(name) => format!("ScalarFilter<{}>", name),
743        _ => "Filter".to_string(),
744    }
745}
746
/// Convert a PascalCase (or camelCase) identifier to snake_case.
///
/// Every uppercase character is lowercased and, unless it is the first
/// character or directly follows an underscore, preceded by `_`. All other
/// characters are copied unchanged.
///
/// The full lowercase mapping of each character is kept, so characters whose
/// lowercase form is more than one `char` are not truncated.
fn to_snake_case(name: &str) -> String {
    // A few spare bytes for the separators that are usually inserted.
    let mut result = String::with_capacity(name.len() + 4);

    for c in name.chars() {
        if c.is_uppercase() {
            // No separator at the very start or right after an existing `_`,
            // so `Foo_Bar` becomes `foo_bar` rather than `foo__bar`.
            if !result.is_empty() && !result.ends_with('_') {
                result.push('_');
            }
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }

    result
}
762
/// Convert snake_case, SCREAMING_SNAKE_CASE, camelCase or PascalCase to PascalCase.
///
/// The name is split on underscores and empty segments are dropped. Each
/// segment gets its first character uppercased. The rest of the segment is
/// lowercased only when the segment contains no lowercase character at all
/// (e.g. `CREATED` -> `Created`); a mixed-case segment keeps its inner
/// capitals so word boundaries survive (`cardCreated` -> `CardCreated`).
///
/// An empty name, or one made only of underscores, yields an empty string.
fn to_pascal_case(name: &str) -> String {
    name.split('_')
        .filter(|segment| !segment.is_empty())
        .map(|segment| {
            let mut chars = segment.chars();
            match chars.next() {
                None => String::new(),
                Some(first) => {
                    let rest = chars.as_str();
                    let mut word: String = first.to_uppercase().collect();
                    if segment.chars().any(char::is_lowercase) {
                        // Mixed case: the inner capitals mark word boundaries.
                        word.push_str(rest);
                    } else {
                        // All caps: flatten to a single capitalised word.
                        word.push_str(&rest.to_lowercase());
                    }
                    word
                }
            }
        })
        .collect()
}
790
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a relation graph from `(from, to)` edges; every mentioned model
    /// gets an entry, even when it has no outgoing edge.
    fn graph_from_edges(edges: &[(&str, &str)]) -> HashMap<String, HashSet<String>> {
        let mut graph: HashMap<String, HashSet<String>> = HashMap::new();
        for (from, to) in edges {
            graph
                .entry(from.to_string())
                .or_default()
                .insert(to.to_string());
            graph.entry(to.to_string()).or_default();
        }
        graph
    }

    /// Assert `to_pascal_case(input) == expected` for every pair.
    fn assert_pascal(cases: &[(&str, &str)]) {
        for (input, expected) in cases {
            assert_eq!(to_pascal_case(input), *expected);
        }
    }

    #[test]
    fn test_to_snake_case() {
        let cases = [
            ("BoardMember", "board_member"),
            ("User", "user"),
            ("JiraImportConfig", "jira_import_config"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected);
        }
    }

    #[test]
    fn test_to_pascal_case_from_snake() {
        assert_pascal(&[
            ("card_created", "CardCreated"),
            ("branch_deleted", "BranchDeleted"),
            ("pr_merged", "PrMerged"),
        ]);
    }

    #[test]
    fn test_to_pascal_case_from_screaming() {
        assert_pascal(&[("CARD_CREATED", "CardCreated"), ("PR_MERGED", "PrMerged")]);
    }

    #[test]
    fn test_to_pascal_case_already_pascal() {
        assert_pascal(&[
            ("Admin", "Admin"),
            ("SuperAdmin", "SuperAdmin"),
            ("Low", "Low"),
        ]);
    }

    #[test]
    fn test_to_pascal_case_single_word() {
        assert_pascal(&[("active", "Active"), ("ACTIVE", "Active")]);
    }

    #[test]
    fn test_needs_boxing_direct_cycle() {
        let graph = graph_from_edges(&[("Board", "JiraConfig"), ("JiraConfig", "Board")]);

        assert!(needs_boxing("Board", "JiraConfig", &graph));
        assert!(needs_boxing("JiraConfig", "Board", &graph));
    }

    #[test]
    fn test_needs_boxing_no_cycle() {
        let graph = graph_from_edges(&[("Post", "User")]);

        assert!(!needs_boxing("Post", "User", &graph));
    }

    #[test]
    fn test_needs_boxing_indirect_cycle() {
        let graph = graph_from_edges(&[("A", "B"), ("B", "C"), ("C", "A")]);

        for (source, target) in [("A", "B"), ("B", "C"), ("C", "A")] {
            assert!(needs_boxing(source, target, &graph));
        }
    }
}