Skip to main content

sqlx_gen/codegen/
mod.rs

1pub mod composite_gen;
2pub mod crud_gen;
3pub mod domain_gen;
4pub mod entity_parser;
5pub mod enum_gen;
6pub mod struct_gen;
7
8use std::collections::{BTreeSet, HashMap};
9
10use proc_macro2::TokenStream;
11
12use crate::cli::DatabaseKind;
13use crate::introspect::SchemaInfo;
14
/// Rust reserved keywords that cannot be used as plain identifiers.
///
/// Contains the strict keywords, the keywords reserved for future use, and `gen`,
/// which is reserved starting with the Rust 2024 edition. Weak/contextual keywords
/// such as `union` or `macro_rules` are not listed: they remain valid identifiers.
const RUST_KEYWORDS: &[&str] = &[
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while", "yield", "abstract", "become", "box", "do",
    "final", "macro", "override", "priv", "try", "typeof", "unsized", "virtual", "gen",
];

/// Returns true if the given name is a Rust reserved keyword.
///
/// The comparison is case-sensitive: `Self` is a keyword, `Type` is not.
pub fn is_rust_keyword(name: &str) -> bool {
    RUST_KEYWORDS.contains(&name)
}
28
/// Returns the `use` lines required by well-known extra derives.
///
/// Only serde is recognised today: when `Serialize` and/or `Deserialize` appear in
/// `extra_derives`, one grouped `use serde::{...};` line is returned, always listing
/// `Serialize` before `Deserialize`. Other derive names contribute nothing, so the
/// result is empty when no serde derive is requested.
pub fn imports_for_derives(extra_derives: &[String]) -> Vec<String> {
    const SERDE_DERIVES: [&str; 2] = ["Serialize", "Deserialize"];

    let requested: Vec<&str> = SERDE_DERIVES
        .iter()
        .copied()
        .filter(|derive| extra_derives.iter().any(|d| d.as_str() == *derive))
        .collect();

    if requested.is_empty() {
        Vec::new()
    } else {
        vec![format!("use serde::{{{}}};", requested.join(", "))]
    }
}
45
/// Normalizes a table name for use as a Rust module/file name by collapsing every
/// run of consecutive underscores into a single `_`.
///
/// Leading and trailing underscores are kept (as a single one); all other
/// characters are copied unchanged.
pub fn normalize_module_name(name: &str) -> String {
    let mut collapsed = String::with_capacity(name.len());
    for ch in name.chars() {
        // `collapsed` ends with '_' exactly when the previous input char was '_'.
        let repeated_underscore = ch == '_' && collapsed.ends_with('_');
        if !repeated_underscore {
            collapsed.push(ch);
        }
    }
    collapsed
}
64
/// Schema names treated as the database default (e.g. PostgreSQL's `public`);
/// relations in these schemas never get a schema prefix in their filename.
const DEFAULT_SCHEMAS: &[&str] = &["public", "main", "dbo"];

/// Returns true if `schema` is exactly one of `public`, `main` or `dbo`
/// (case-sensitive).
pub fn is_default_schema(schema: &str) -> bool {
    DEFAULT_SCHEMAS.iter().any(|&known| known == schema)
}
72
73/// Build a module name, prefixing with schema only when the name collides
74/// (same table name exists in multiple schemas).
75pub fn build_module_name(schema_name: &str, table_name: &str, name_collides: bool) -> String {
76    if name_collides && !is_default_schema(schema_name) {
77        normalize_module_name(&format!("{}_{}", schema_name, table_name))
78    } else {
79        normalize_module_name(table_name)
80    }
81}
82
83/// Find table/view names that appear in more than one schema.
84fn find_colliding_names(schema_info: &SchemaInfo) -> BTreeSet<&str> {
85    let mut seen: HashMap<&str, BTreeSet<&str>> = HashMap::new();
86    for t in &schema_info.tables {
87        seen.entry(t.name.as_str()).or_default().insert(t.schema_name.as_str());
88    }
89    for v in &schema_info.views {
90        seen.entry(v.name.as_str()).or_default().insert(v.schema_name.as_str());
91    }
92    seen.into_iter()
93        .filter(|(_, schemas)| schemas.len() > 1)
94        .map(|(name, _)| name)
95        .collect()
96}
97
/// A generated code file with its content and required imports.
///
/// `code` is the complete file text: any needed `use` lines followed by the
/// formatted item definitions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedFile {
    /// File name including the `.rs` extension, e.g. `users.rs` or `types.rs`.
    pub filename: String,
    /// Optional origin comment (e.g. "Table: schema.name").
    pub origin: Option<String>,
    /// Complete Rust source text of the file.
    pub code: String,
}
106
107/// Generate all code for a given schema.
108pub fn generate(
109    schema_info: &SchemaInfo,
110    db_kind: DatabaseKind,
111    extra_derives: &[String],
112    type_overrides: &HashMap<String, String>,
113    single_file: bool,
114) -> Vec<GeneratedFile> {
115    let mut files = Vec::new();
116
117    // Detect table/view names that appear in multiple schemas (collisions)
118    let colliding_names = find_colliding_names(schema_info);
119
120    // Generate struct files for each table
121    for table in &schema_info.tables {
122        let (tokens, imports) =
123            struct_gen::generate_struct(table, db_kind, schema_info, extra_derives, type_overrides, false);
124        let imports = filter_imports(&imports, single_file);
125        let code = format_tokens_with_imports(&tokens, &imports);
126        let module_name = build_module_name(&table.schema_name, &table.name, colliding_names.contains(table.name.as_str()));
127        files.push(GeneratedFile {
128            filename: format!("{}.rs", module_name),
129            origin: None,
130            code,
131        });
132    }
133
134    // Generate struct files for each view
135    for view in &schema_info.views {
136        let (tokens, imports) =
137            struct_gen::generate_struct(view, db_kind, schema_info, extra_derives, type_overrides, true);
138        let imports = filter_imports(&imports, single_file);
139        let code = format_tokens_with_imports(&tokens, &imports);
140        let module_name = build_module_name(&view.schema_name, &view.name, colliding_names.contains(view.name.as_str()));
141        files.push(GeneratedFile {
142            filename: format!("{}.rs", module_name),
143            origin: None,
144            code,
145        });
146    }
147
148    // Generate types file (enums, composites, domains)
149    // Each item is formatted individually so we can insert blank lines between them.
150    let mut types_blocks: Vec<String> = Vec::new();
151    let mut types_imports = BTreeSet::new();
152
153    // Enrich enums with default variants extracted from column defaults
154    let enum_defaults = extract_enum_defaults(schema_info);
155    for enum_info in &schema_info.enums {
156        let mut enriched = enum_info.clone();
157        if enriched.default_variant.is_none() {
158            if let Some(default) = enum_defaults.get(&enum_info.name) {
159                enriched.default_variant = Some(default.clone());
160            }
161        }
162        let (tokens, imports) = enum_gen::generate_enum(&enriched, db_kind, extra_derives);
163        types_blocks.push(format_tokens(&tokens));
164        types_imports.extend(imports);
165    }
166
167    for composite in &schema_info.composite_types {
168        let (tokens, imports) = composite_gen::generate_composite(
169            composite,
170            db_kind,
171            schema_info,
172            extra_derives,
173            type_overrides,
174        );
175        types_blocks.push(format_tokens(&tokens));
176        types_imports.extend(imports);
177    }
178
179    for domain in &schema_info.domains {
180        let (tokens, imports) =
181            domain_gen::generate_domain(domain, db_kind, schema_info, type_overrides);
182        types_blocks.push(format_tokens(&tokens));
183        types_imports.extend(imports);
184    }
185
186    if !types_blocks.is_empty() {
187        let import_lines: String = types_imports
188            .iter()
189            .map(|i| format!("{}\n", i))
190            .collect();
191        let body = types_blocks.join("\n");
192        let code = if import_lines.is_empty() {
193            body
194        } else {
195            format!("{}\n\n{}", import_lines.trim_end(), body)
196        };
197        files.push(GeneratedFile {
198            filename: "types.rs".to_string(),
199            origin: None,
200            code,
201        });
202    }
203
204    files
205}
206
207/// Extract default variant values for enums by scanning column defaults across all tables and views.
208/// PostgreSQL column defaults look like `'idle'::task_status` or `'active'::public.task_status`.
209fn extract_enum_defaults(schema_info: &SchemaInfo) -> HashMap<String, String> {
210    let mut defaults: HashMap<String, String> = HashMap::new();
211
212    let all_columns = schema_info
213        .tables
214        .iter()
215        .chain(schema_info.views.iter())
216        .flat_map(|t| t.columns.iter());
217
218    for col in all_columns {
219        let default_expr = match &col.column_default {
220            Some(d) => d,
221            None => continue,
222        };
223
224        // Strip leading underscore for array types to get the base enum name
225        let base_udt = col.udt_name.strip_prefix('_').unwrap_or(&col.udt_name);
226
227        // Check if this column references a known enum
228        let enum_match = schema_info.enums.iter().find(|e| e.name == base_udt);
229        if enum_match.is_none() {
230            continue;
231        }
232
233        // Parse PG default: 'variant'::type_name
234        if let Some(variant) = parse_pg_enum_default(default_expr) {
235            defaults.entry(base_udt.to_string()).or_insert(variant);
236        }
237    }
238
239    defaults
240}
241
/// Parse a PostgreSQL column default expression to extract the enum variant.
/// Handles formats like `'idle'::task_status`, `'idle'::public.task_status`.
///
/// The (trimmed) expression must be a single-quoted literal immediately followed by
/// a `::` cast; anything else yields `None`. A doubled quote (`''`) inside the literal
/// is SQL's escape for one quote character, so `'it''s'::mood` yields `it's`.
fn parse_pg_enum_default(default_expr: &str) -> Option<String> {
    let literal = default_expr.trim().strip_prefix('\'')?;
    let mut value = String::new();
    let mut chars = literal.char_indices().peekable();

    while let Some((idx, ch)) = chars.next() {
        if ch != '\'' {
            value.push(ch);
            continue;
        }
        // `''` is an escaped quote, not the end of the literal.
        if let Some(&(_, '\'')) = chars.peek() {
            chars.next();
            value.push('\'');
            continue;
        }
        // Closing quote: the literal must be followed by a `::` type cast.
        let rest = &literal[idx + ch.len_utf8()..];
        return if rest.starts_with("::") { Some(value) } else { None };
    }

    // Unterminated literal.
    None
}
259
/// Returns the imports to emit for one generated struct file.
///
/// In single-file mode every import mentioning `super::types::` is dropped, because
/// the referenced types live in the same output file. Otherwise all imports are
/// returned unchanged.
fn filter_imports(imports: &BTreeSet<String>, single_file: bool) -> BTreeSet<String> {
    imports
        .iter()
        .filter(|import| !(single_file && import.contains("super::types::")))
        .cloned()
        .collect()
}
272
273/// Parse and format a TokenStream via prettyplease, then post-process spacing.
274pub(crate) fn parse_and_format(tokens: &TokenStream) -> String {
275    let file = syn::parse2::<syn::File>(tokens.clone()).unwrap_or_else(|e| {
276        log::error!("Failed to parse generated code: {}", e);
277        log::error!("This is a bug in sqlx-gen. Raw tokens:\n  {}", tokens);
278        std::process::exit(1);
279    });
280    let raw = prettyplease::unparse(&file);
281    add_blank_lines_between_items(&raw)
282}
283
284/// Format a single TokenStream block (no imports).
285pub(crate) fn format_tokens(tokens: &TokenStream) -> String {
286    parse_and_format(tokens)
287}
288
289pub fn format_tokens_with_imports(tokens: &TokenStream, imports: &BTreeSet<String>) -> String {
290    let formatted = parse_and_format(tokens);
291
292    let used_imports: Vec<&String> = imports
293        .iter()
294        .filter(|imp| is_import_used(imp, &formatted))
295        .collect();
296
297    if used_imports.is_empty() {
298        formatted
299    } else {
300        let import_lines: String = used_imports
301            .iter()
302            .map(|i| format!("{}\n", i))
303            .collect();
304        format!("{}\n\n{}", import_lines.trim_end(), formatted)
305    }
306}
307
/// Check if an import is actually used in the generated code.
///
/// Determines the name(s) an import binds and looks for each one as a whole
/// identifier in `code`:
///
/// * `use foo::bar::Baz;` → `Baz`
/// * `use foo::{A, B};` → `A` or `B` (a group item `a::B` binds `B`)
/// * `use foo::Bar as Baz;` → `Baz`
/// * `use foo::bar::*;`, `use foo::Trait as _;`, `self` items and anything this
///   simple parser cannot understand (e.g. nested groups) are always kept.
///
/// Matching is textual: a name that only appears inside a comment or string literal
/// still counts as used, which errs on the side of keeping the import.
fn is_import_used(import: &str, code: &str) -> bool {
    let trimmed = import.trim().trim_end_matches(';');
    let path = trimmed.strip_prefix("use ").unwrap_or(trimmed).trim();

    if path.ends_with("::*") {
        return true;
    }

    // Handle grouped imports: use foo::{A, B, C};
    if let (Some(start), Some(end)) = (path.find('{'), path.rfind('}')) {
        if start < end {
            return path[start + 1..end]
                .split(',')
                .map(str::trim)
                .filter(|item| !item.is_empty())
                .any(|item| import_item_used(item, code));
        }
    }

    // Simple import: use foo::Bar;
    import_item_used(path, code)
}

/// Returns true if one import item (a path, optionally `as`-renamed) is referenced in `code`.
fn import_item_used(item: &str, code: &str) -> bool {
    // Nested groups / globs inside a group are not analysed: keep the import.
    if item.contains('{') || item.contains('}') || item.ends_with('*') {
        return true;
    }
    let bound = match item.rsplit_once(" as ") {
        Some((_, alias)) => alias.trim(),
        None => item.rsplit("::").next().unwrap_or(item).trim(),
    };
    // `as _` binds no name (trait imported for its methods); `self` binds the
    // parent module, which this parser does not resolve.
    if bound.is_empty() || bound == "_" || bound == "self" {
        return true;
    }
    contains_identifier(code, bound)
}

/// Returns true if `ident` occurs in `code` with no identifier character directly before or after it.
fn contains_identifier(code: &str, ident: &str) -> bool {
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';
    code.match_indices(ident).any(|(pos, _)| {
        let before = code[..pos].chars().next_back();
        let after = code[pos + ident.len()..].chars().next();
        !matches!(before, Some(c) if is_ident_char(c))
            && !matches!(after, Some(c) if is_ident_char(c))
    })
}
340
/// Post-process formatted code to:
/// - Add blank lines between enum variants with `#[sqlx(rename`
/// - Add blank lines between top-level items (structs, impls)
/// - Add blank lines between logical blocks inside async methods
///
/// At most one blank line is inserted before any given line, even when several
/// rules match (e.g. a `let … = sqlx::…` right after `let x = y.unwrap_or(…);`).
/// Lines are re-joined with `\n` and no trailing newline is emitted.
fn add_blank_lines_between_items(code: &str) -> String {
    let lines: Vec<&str> = code.lines().collect();
    let mut result: Vec<&str> = Vec::with_capacity(lines.len());

    for (i, line) in lines.iter().enumerate() {
        if i > 0 && needs_blank_line_between(lines[i - 1].trim(), line.trim()) {
            result.push("");
        }
        result.push(line);
    }

    result.join("\n")
}

/// Decides whether a blank line belongs between two consecutive, already-trimmed lines.
fn needs_blank_line_between(prev: &str, current: &str) -> bool {
    // Separate `#[sqlx(rename …)]` variants, but not the first one in the enum
    // (that one follows `{`, not a variant line ending in `,`).
    let renamed_variant = current.starts_with("#[sqlx(rename") && prev.ends_with(',');

    // Top-level items and methods inside impl blocks that follow a closing brace.
    let item_after_block = prev == "}"
        && (current.starts_with("pub struct")
            || current.starts_with("impl ")
            || current.starts_with("#[derive")
            || current.starts_with("pub async fn")
            || current.starts_with("pub fn"));

    // A new `let` / `Ok(` after an awaited call or an `.unwrap_or(…);` statement.
    let prev_is_await_end = prev.ends_with(".await?;")
        || prev.ends_with(".await?")
        || (prev.ends_with(';') && prev.contains(".unwrap_or("));
    let step_after_await =
        prev_is_await_end && (current.starts_with("let ") || current.starts_with("Ok("));

    // A sqlx query `let` following a simple (non-sqlx) `let` assignment.
    let query_after_lets = current.starts_with("let ")
        && current.contains("sqlx::")
        && prev.starts_with("let ")
        && !prev.contains("sqlx::");

    renamed_variant || item_after_block || step_after_await || query_after_lets
}
402
403#[cfg(test)]
404mod tests {
405    use super::*;
406    use crate::introspect::{
407        ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo,
408    };
409    use std::collections::HashMap;
410
411    // ========== is_rust_keyword ==========
412
413    #[test]
414    fn test_keyword_type() {
415        assert!(is_rust_keyword("type"));
416    }
417
418    #[test]
419    fn test_keyword_fn() {
420        assert!(is_rust_keyword("fn"));
421    }
422
423    #[test]
424    fn test_keyword_let() {
425        assert!(is_rust_keyword("let"));
426    }
427
428    #[test]
429    fn test_keyword_match() {
430        assert!(is_rust_keyword("match"));
431    }
432
433    #[test]
434    fn test_keyword_async() {
435        assert!(is_rust_keyword("async"));
436    }
437
438    #[test]
439    fn test_keyword_await() {
440        assert!(is_rust_keyword("await"));
441    }
442
443    #[test]
444    fn test_keyword_yield() {
445        assert!(is_rust_keyword("yield"));
446    }
447
448    #[test]
449    fn test_keyword_abstract() {
450        assert!(is_rust_keyword("abstract"));
451    }
452
453    #[test]
454    fn test_keyword_try() {
455        assert!(is_rust_keyword("try"));
456    }
457
458    #[test]
459    fn test_not_keyword_name() {
460        assert!(!is_rust_keyword("name"));
461    }
462
463    #[test]
464    fn test_not_keyword_id() {
465        assert!(!is_rust_keyword("id"));
466    }
467
468    #[test]
469    fn test_not_keyword_uppercase_type() {
470        assert!(!is_rust_keyword("Type"));
471    }
472
473    // ========== normalize_module_name ==========
474
475    #[test]
476    fn test_normalize_no_underscores() {
477        assert_eq!(normalize_module_name("users"), "users");
478    }
479
480    #[test]
481    fn test_normalize_single_underscore() {
482        assert_eq!(normalize_module_name("user_roles"), "user_roles");
483    }
484
485    #[test]
486    fn test_normalize_double_underscore() {
487        assert_eq!(normalize_module_name("user__roles"), "user_roles");
488    }
489
490    #[test]
491    fn test_normalize_triple_underscore() {
492        assert_eq!(normalize_module_name("a___b"), "a_b");
493    }
494
495    #[test]
496    fn test_normalize_leading_underscore() {
497        assert_eq!(normalize_module_name("_private"), "_private");
498    }
499
500    #[test]
501    fn test_normalize_trailing_underscore() {
502        assert_eq!(normalize_module_name("name_"), "name_");
503    }
504
505    #[test]
506    fn test_normalize_double_leading() {
507        assert_eq!(normalize_module_name("__double_leading"), "_double_leading");
508    }
509
510    #[test]
511    fn test_normalize_multiple_groups() {
512        assert_eq!(normalize_module_name("a__b__c"), "a_b_c");
513    }
514
515    // ========== build_module_name ==========
516
517    #[test]
518    fn test_build_no_collision_no_prefix() {
519        assert_eq!(build_module_name("public", "users", false), "users");
520    }
521
522    #[test]
523    fn test_build_no_collision_non_default_no_prefix() {
524        assert_eq!(build_module_name("billing", "invoices", false), "invoices");
525    }
526
527    #[test]
528    fn test_build_collision_prefixed() {
529        assert_eq!(build_module_name("billing", "users", true), "billing_users");
530    }
531
532    #[test]
533    fn test_build_collision_default_schema_no_prefix() {
534        assert_eq!(build_module_name("public", "users", true), "users");
535    }
536
537    #[test]
538    fn test_build_collision_normalizes_double_underscore() {
539        assert_eq!(build_module_name("billing", "agent__connector", true), "billing_agent_connector");
540    }
541
542    // ========== is_default_schema ==========
543
544    #[test]
545    fn test_default_schema_public() {
546        assert!(is_default_schema("public"));
547    }
548
549    #[test]
550    fn test_default_schema_main() {
551        assert!(is_default_schema("main"));
552    }
553
554    #[test]
555    fn test_non_default_schema() {
556        assert!(!is_default_schema("billing"));
557    }
558
559    // ========== imports_for_derives ==========
560
561    #[test]
562    fn test_imports_empty() {
563        let result = imports_for_derives(&[]);
564        assert!(result.is_empty());
565    }
566
567    #[test]
568    fn test_imports_serialize_only() {
569        let derives = vec!["Serialize".to_string()];
570        let result = imports_for_derives(&derives);
571        assert_eq!(result, vec!["use serde::{Serialize};"]);
572    }
573
574    #[test]
575    fn test_imports_deserialize_only() {
576        let derives = vec!["Deserialize".to_string()];
577        let result = imports_for_derives(&derives);
578        assert_eq!(result, vec!["use serde::{Deserialize};"]);
579    }
580
581    #[test]
582    fn test_imports_both_serde() {
583        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
584        let result = imports_for_derives(&derives);
585        assert_eq!(result, vec!["use serde::{Serialize, Deserialize};"]);
586    }
587
588    #[test]
589    fn test_imports_non_serde() {
590        let derives = vec!["Hash".to_string()];
591        let result = imports_for_derives(&derives);
592        assert!(result.is_empty());
593    }
594
595    #[test]
596    fn test_imports_non_serde_multiple() {
597        let derives = vec!["PartialEq".to_string(), "Eq".to_string()];
598        let result = imports_for_derives(&derives);
599        assert!(result.is_empty());
600    }
601
602    #[test]
603    fn test_imports_mixed_serde_and_others() {
604        let derives = vec![
605            "Serialize".to_string(),
606            "Hash".to_string(),
607            "Deserialize".to_string(),
608        ];
609        let result = imports_for_derives(&derives);
610        assert_eq!(result.len(), 1);
611        assert!(result[0].contains("Serialize"));
612        assert!(result[0].contains("Deserialize"));
613    }
614
615    // ========== add_blank_lines_between_items ==========
616
617    #[test]
618    fn test_blank_lines_between_renamed_variants() {
619        let input = "pub enum Foo {\n    #[sqlx(rename = \"a\")]\n    A,\n    #[sqlx(rename = \"b\")]\n    B,\n}";
620        let result = add_blank_lines_between_items(input);
621        assert!(result.contains("A,\n\n    #[sqlx(rename = \"b\")]"));
622    }
623
624    #[test]
625    fn test_no_blank_line_for_first_variant() {
626        let input = "pub enum Foo {\n    #[sqlx(rename = \"a\")]\n    A,\n}";
627        let result = add_blank_lines_between_items(input);
628        // No blank line before first #[sqlx(rename because previous line is `{`
629        assert!(!result.contains("{\n\n"));
630    }
631
632    #[test]
633    fn test_no_change_without_rename() {
634        let input = "pub enum Foo {\n    A,\n    B,\n}";
635        let result = add_blank_lines_between_items(input);
636        assert_eq!(result, input);
637    }
638
639    #[test]
640    fn test_no_change_for_struct() {
641        let input = "pub struct Foo {\n    pub a: i32,\n    pub b: String,\n}";
642        let result = add_blank_lines_between_items(input);
643        assert_eq!(result, input);
644    }
645
646    // ========== filter_imports ==========
647
648    #[test]
649    fn test_filter_single_file_strips_super_types() {
650        let mut imports = BTreeSet::new();
651        imports.insert("use super::types::Foo;".to_string());
652        imports.insert("use chrono::NaiveDateTime;".to_string());
653        let result = filter_imports(&imports, true);
654        assert!(!result.contains("use super::types::Foo;"));
655        assert!(result.contains("use chrono::NaiveDateTime;"));
656    }
657
658    #[test]
659    fn test_filter_single_file_keeps_other_imports() {
660        let mut imports = BTreeSet::new();
661        imports.insert("use chrono::NaiveDateTime;".to_string());
662        let result = filter_imports(&imports, true);
663        assert!(result.contains("use chrono::NaiveDateTime;"));
664    }
665
666    #[test]
667    fn test_filter_multi_file_keeps_all() {
668        let mut imports = BTreeSet::new();
669        imports.insert("use super::types::Foo;".to_string());
670        imports.insert("use chrono::NaiveDateTime;".to_string());
671        let result = filter_imports(&imports, false);
672        assert_eq!(result.len(), 2);
673    }
674
675    #[test]
676    fn test_filter_empty_set() {
677        let imports = BTreeSet::new();
678        let result = filter_imports(&imports, true);
679        assert!(result.is_empty());
680    }
681
682    // ========== generate() orchestrator ==========
683
684    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
685        TableInfo {
686            schema_name: "public".to_string(),
687            name: name.to_string(),
688            columns,
689        }
690    }
691
692    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
693        ColumnInfo {
694            name: name.to_string(),
695            data_type: udt_name.to_string(),
696            udt_name: udt_name.to_string(),
697            is_nullable: false,
698            is_primary_key: false,
699            ordinal_position: 0,
700            schema_name: "public".to_string(),
701            column_default: None,
702        }
703    }
704
705    #[test]
706    fn test_generate_empty_schema() {
707        let schema = SchemaInfo::default();
708        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
709        assert!(files.is_empty());
710    }
711
712    #[test]
713    fn test_generate_one_table() {
714        let schema = SchemaInfo {
715            tables: vec![make_table("users", vec![make_col("id", "int4")])],
716            ..Default::default()
717        };
718        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
719        assert_eq!(files.len(), 1);
720        assert_eq!(files[0].filename, "users.rs");
721    }
722
723    #[test]
724    fn test_generate_two_tables() {
725        let schema = SchemaInfo {
726            tables: vec![
727                make_table("users", vec![make_col("id", "int4")]),
728                make_table("posts", vec![make_col("id", "int4")]),
729            ],
730            ..Default::default()
731        };
732        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
733        assert_eq!(files.len(), 2);
734    }
735
736    #[test]
737    fn test_generate_enum_creates_types_file() {
738        let schema = SchemaInfo {
739            enums: vec![EnumInfo {
740                schema_name: "public".to_string(),
741                name: "status".to_string(),
742                variants: vec!["active".to_string(), "inactive".to_string()],
743                default_variant: None,
744            }],
745            ..Default::default()
746        };
747        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
748        assert_eq!(files.len(), 1);
749        assert_eq!(files[0].filename, "types.rs");
750    }
751
752    #[test]
753    fn test_generate_enums_composites_domains_single_types_file() {
754        let schema = SchemaInfo {
755            enums: vec![EnumInfo {
756                schema_name: "public".to_string(),
757                name: "status".to_string(),
758                variants: vec!["active".to_string()],
759                default_variant: None,
760            }],
761            composite_types: vec![CompositeTypeInfo {
762                schema_name: "public".to_string(),
763                name: "address".to_string(),
764                fields: vec![make_col("street", "text")],
765            }],
766            domains: vec![DomainInfo {
767                schema_name: "public".to_string(),
768                name: "email".to_string(),
769                base_type: "text".to_string(),
770            }],
771            ..Default::default()
772        };
773        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
774        // Should produce exactly 1 types.rs
775        let types_files: Vec<_> = files.iter().filter(|f| f.filename == "types.rs").collect();
776        assert_eq!(types_files.len(), 1);
777    }
778
779    #[test]
780    fn test_generate_tables_and_enums() {
781        let schema = SchemaInfo {
782            tables: vec![make_table("users", vec![make_col("id", "int4")])],
783            enums: vec![EnumInfo {
784                schema_name: "public".to_string(),
785                name: "status".to_string(),
786                variants: vec!["active".to_string()],
787                default_variant: None,
788            }],
789            ..Default::default()
790        };
791        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
792        assert_eq!(files.len(), 2); // users.rs + types.rs
793    }
794
795    #[test]
796    fn test_generate_filename_normalized() {
797        let schema = SchemaInfo {
798            tables: vec![make_table("user__data", vec![make_col("id", "int4")])],
799            ..Default::default()
800        };
801        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
802        assert_eq!(files[0].filename, "user_data.rs");
803    }
804
805    #[test]
806    fn test_generate_no_origin_for_tables() {
807        let schema = SchemaInfo {
808            tables: vec![make_table("users", vec![make_col("id", "int4")])],
809            ..Default::default()
810        };
811        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
812        assert_eq!(files[0].origin, None);
813    }
814
815    #[test]
816    fn test_generate_types_no_origin() {
817        let schema = SchemaInfo {
818            enums: vec![EnumInfo {
819                schema_name: "public".to_string(),
820                name: "status".to_string(),
821                variants: vec!["active".to_string()],
822                default_variant: None,
823            }],
824            ..Default::default()
825        };
826        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
827        assert_eq!(files[0].origin, None);
828    }
829
830    #[test]
831    fn test_generate_single_file_filters_super_types_imports() {
832        let schema = SchemaInfo {
833            tables: vec![make_table("users", vec![make_col("id", "int4")])],
834            enums: vec![EnumInfo {
835                schema_name: "public".to_string(),
836                name: "status".to_string(),
837                variants: vec!["active".to_string()],
838                default_variant: None,
839            }],
840            ..Default::default()
841        };
842        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
843        // struct file should not have super::types:: imports
844        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
845        assert!(!struct_file.code.contains("super::types::"));
846    }
847
848    #[test]
849    fn test_generate_multi_file_keeps_super_types_imports() {
850        // Table with a column referencing an enum
851        let schema = SchemaInfo {
852            tables: vec![make_table("users", vec![make_col("status", "status")])],
853            enums: vec![EnumInfo {
854                schema_name: "public".to_string(),
855                name: "status".to_string(),
856                variants: vec!["active".to_string()],
857                default_variant: None,
858            }],
859            ..Default::default()
860        };
861        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
862        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
863        assert!(struct_file.code.contains("super::types::"));
864    }
865
866    #[test]
867    fn test_generate_extra_derives_in_struct() {
868        let schema = SchemaInfo {
869            tables: vec![make_table("users", vec![make_col("id", "int4")])],
870            ..Default::default()
871        };
872        let derives = vec!["Serialize".to_string()];
873        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
874        assert!(files[0].code.contains("Serialize"));
875    }
876
877    #[test]
878    fn test_generate_extra_derives_in_enum() {
879        let schema = SchemaInfo {
880            enums: vec![EnumInfo {
881                schema_name: "public".to_string(),
882                name: "status".to_string(),
883                variants: vec!["active".to_string()],
884                default_variant: None,
885            }],
886            ..Default::default()
887        };
888        let derives = vec!["Serialize".to_string()];
889        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
890        assert!(files[0].code.contains("Serialize"));
891    }
892
893    #[test]
894    fn test_generate_type_overrides_in_struct() {
895        let mut overrides = HashMap::new();
896        overrides.insert("jsonb".to_string(), "MyJson".to_string());
897        let schema = SchemaInfo {
898            tables: vec![make_table("users", vec![make_col("data", "jsonb")])],
899            ..Default::default()
900        };
901        let files = generate(&schema, DatabaseKind::Postgres, &[], &overrides, false);
902        assert!(files[0].code.contains("MyJson"));
903    }
904
905    #[test]
906    fn test_generate_valid_rust_syntax() {
907        let schema = SchemaInfo {
908            tables: vec![make_table("users", vec![
909                make_col("id", "int4"),
910                make_col("name", "text"),
911            ])],
912            enums: vec![EnumInfo {
913                schema_name: "public".to_string(),
914                name: "status".to_string(),
915                variants: vec!["active".to_string(), "inactive".to_string()],
916                default_variant: None,
917            }],
918            ..Default::default()
919        };
920        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
921        for f in &files {
922            // Should be parseable as valid Rust
923            let parse_result = syn::parse_file(&f.code);
924            assert!(parse_result.is_ok(), "Failed to parse {}: {:?}", f.filename, parse_result.err());
925        }
926    }
927
928    // ========== generate() — views ==========
929
930    fn make_view(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
931        TableInfo {
932            schema_name: "public".to_string(),
933            name: name.to_string(),
934            columns,
935        }
936    }
937
938    #[test]
939    fn test_generate_one_view() {
940        let schema = SchemaInfo {
941            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
942            ..Default::default()
943        };
944        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
945        assert_eq!(files.len(), 1);
946        assert_eq!(files[0].filename, "active_users.rs");
947    }
948
949    #[test]
950    fn test_generate_no_origin_for_views() {
951        let schema = SchemaInfo {
952            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
953            ..Default::default()
954        };
955        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
956        assert_eq!(files[0].origin, None);
957    }
958
959    #[test]
960    fn test_generate_tables_and_views() {
961        let schema = SchemaInfo {
962            tables: vec![make_table("users", vec![make_col("id", "int4")])],
963            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
964            ..Default::default()
965        };
966        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
967        assert_eq!(files.len(), 2);
968    }
969
970    #[test]
971    fn test_generate_view_valid_rust() {
972        let schema = SchemaInfo {
973            views: vec![make_view("active_users", vec![
974                make_col("id", "int4"),
975                make_col("name", "text"),
976            ])],
977            ..Default::default()
978        };
979        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
980        let parse_result = syn::parse_file(&files[0].code);
981        assert!(parse_result.is_ok(), "Failed to parse: {:?}", parse_result.err());
982    }
983
984    #[test]
985    fn test_generate_view_nullable_column() {
986        let schema = SchemaInfo {
987            views: vec![make_view("v", vec![ColumnInfo {
988                name: "email".to_string(),
989                data_type: "text".to_string(),
990                udt_name: "text".to_string(),
991                is_nullable: true,
992                is_primary_key: false,
993                ordinal_position: 0,
994                schema_name: "public".to_string(),
995                column_default: None,
996            }])],
997            ..Default::default()
998        };
999        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
1000        assert!(files[0].code.contains("Option<String>"));
1001    }
1002
1003    #[test]
1004    fn test_generate_collision_both_prefixed() {
1005        let schema = SchemaInfo {
1006            tables: vec![
1007                make_table("users", vec![make_col("id", "int4")]),
1008                TableInfo {
1009                    schema_name: "billing".to_string(),
1010                    name: "users".to_string(),
1011                    columns: vec![make_col("id", "int4")],
1012                },
1013            ],
1014            ..Default::default()
1015        };
1016        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
1017        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
1018        assert!(filenames.contains(&"users.rs"));
1019        assert!(filenames.contains(&"billing_users.rs"));
1020    }
1021
1022    #[test]
1023    fn test_generate_no_collision_no_prefix() {
1024        let schema = SchemaInfo {
1025            tables: vec![
1026                make_table("users", vec![make_col("id", "int4")]),
1027                TableInfo {
1028                    schema_name: "billing".to_string(),
1029                    name: "invoices".to_string(),
1030                    columns: vec![make_col("id", "int4")],
1031                },
1032            ],
1033            ..Default::default()
1034        };
1035        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
1036        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
1037        assert!(filenames.contains(&"users.rs"));
1038        assert!(filenames.contains(&"invoices.rs"));
1039    }
1040
1041    #[test]
1042    fn test_generate_single_schema_no_prefix() {
1043        let schema = SchemaInfo {
1044            tables: vec![
1045                make_table("users", vec![make_col("id", "int4")]),
1046                make_table("posts", vec![make_col("id", "int4")]),
1047            ],
1048            ..Default::default()
1049        };
1050        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
1051        assert_eq!(files[0].filename, "users.rs");
1052        assert_eq!(files[1].filename, "posts.rs");
1053    }
1054
1055    #[test]
1056    fn test_generate_view_single_file_mode() {
1057        let schema = SchemaInfo {
1058            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1059            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1060            ..Default::default()
1061        };
1062        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
1063        assert_eq!(files.len(), 2);
1064    }
1065
1066    // ========== parse_pg_enum_default ==========
1067
1068    #[test]
1069    fn test_parse_pg_enum_default_simple() {
1070        assert_eq!(
1071            parse_pg_enum_default("'idle'::task_status"),
1072            Some("idle".to_string())
1073        );
1074    }
1075
1076    #[test]
1077    fn test_parse_pg_enum_default_schema_qualified() {
1078        assert_eq!(
1079            parse_pg_enum_default("'active'::public.task_status"),
1080            Some("active".to_string())
1081        );
1082    }
1083
1084    #[test]
1085    fn test_parse_pg_enum_default_not_enum() {
1086        // No single-quote pattern
1087        assert_eq!(parse_pg_enum_default("nextval('users_id_seq')"), None);
1088    }
1089
1090    #[test]
1091    fn test_parse_pg_enum_default_no_cast() {
1092        assert_eq!(parse_pg_enum_default("'hello'"), None);
1093    }
1094
1095    #[test]
1096    fn test_parse_pg_enum_default_empty() {
1097        assert_eq!(parse_pg_enum_default(""), None);
1098    }
1099
1100    // ========== extract_enum_defaults ==========
1101
1102    #[test]
1103    fn test_extract_enum_defaults_from_column() {
1104        let schema = SchemaInfo {
1105            tables: vec![TableInfo {
1106                schema_name: "public".to_string(),
1107                name: "tasks".to_string(),
1108                columns: vec![ColumnInfo {
1109                    name: "status".to_string(),
1110                    data_type: "USER-DEFINED".to_string(),
1111                    udt_name: "task_status".to_string(),
1112                    is_nullable: false,
1113                    is_primary_key: false,
1114                    ordinal_position: 0,
1115                    schema_name: "public".to_string(),
1116                    column_default: Some("'idle'::task_status".to_string()),
1117                }],
1118            }],
1119            enums: vec![EnumInfo {
1120                schema_name: "public".to_string(),
1121                name: "task_status".to_string(),
1122                variants: vec!["idle".to_string(), "running".to_string()],
1123                default_variant: None,
1124            }],
1125            ..Default::default()
1126        };
1127        let defaults = extract_enum_defaults(&schema);
1128        assert_eq!(defaults.get("task_status"), Some(&"idle".to_string()));
1129    }
1130
1131    #[test]
1132    fn test_extract_enum_defaults_no_default() {
1133        let schema = SchemaInfo {
1134            tables: vec![TableInfo {
1135                schema_name: "public".to_string(),
1136                name: "tasks".to_string(),
1137                columns: vec![ColumnInfo {
1138                    name: "status".to_string(),
1139                    data_type: "USER-DEFINED".to_string(),
1140                    udt_name: "task_status".to_string(),
1141                    is_nullable: false,
1142                    is_primary_key: false,
1143                    ordinal_position: 0,
1144                    schema_name: "public".to_string(),
1145                    column_default: None,
1146                }],
1147            }],
1148            enums: vec![EnumInfo {
1149                schema_name: "public".to_string(),
1150                name: "task_status".to_string(),
1151                variants: vec!["idle".to_string()],
1152                default_variant: None,
1153            }],
1154            ..Default::default()
1155        };
1156        let defaults = extract_enum_defaults(&schema);
1157        assert!(defaults.is_empty());
1158    }
1159
1160    #[test]
1161    fn test_extract_enum_defaults_non_enum_column_ignored() {
1162        let schema = SchemaInfo {
1163            tables: vec![TableInfo {
1164                schema_name: "public".to_string(),
1165                name: "users".to_string(),
1166                columns: vec![ColumnInfo {
1167                    name: "name".to_string(),
1168                    data_type: "character varying".to_string(),
1169                    udt_name: "varchar".to_string(),
1170                    is_nullable: false,
1171                    is_primary_key: false,
1172                    ordinal_position: 0,
1173                    schema_name: "public".to_string(),
1174                    column_default: Some("'hello'::character varying".to_string()),
1175                }],
1176            }],
1177            enums: vec![],
1178            ..Default::default()
1179        };
1180        let defaults = extract_enum_defaults(&schema);
1181        assert!(defaults.is_empty());
1182    }
1183
1184    #[test]
1185    fn test_generate_enum_with_default() {
1186        let schema = SchemaInfo {
1187            tables: vec![TableInfo {
1188                schema_name: "public".to_string(),
1189                name: "tasks".to_string(),
1190                columns: vec![ColumnInfo {
1191                    name: "status".to_string(),
1192                    data_type: "USER-DEFINED".to_string(),
1193                    udt_name: "task_status".to_string(),
1194                    is_nullable: false,
1195                    is_primary_key: false,
1196                    ordinal_position: 0,
1197                    schema_name: "public".to_string(),
1198                    column_default: Some("'idle'::task_status".to_string()),
1199                }],
1200            }],
1201            enums: vec![EnumInfo {
1202                schema_name: "public".to_string(),
1203                name: "task_status".to_string(),
1204                variants: vec!["idle".to_string(), "running".to_string()],
1205                default_variant: None,
1206            }],
1207            ..Default::default()
1208        };
1209        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
1210        let types_file = files.iter().find(|f| f.filename == "types.rs").unwrap();
1211        assert!(types_file.code.contains("impl Default for TaskStatus"));
1212        assert!(types_file.code.contains("Self::Idle"));
1213    }
1214}