Skip to main content

sqlx_gen/codegen/
mod.rs

1pub mod composite_gen;
2pub mod crud_gen;
3pub mod domain_gen;
4pub mod entity_parser;
5pub mod enum_gen;
6pub mod struct_gen;
7
8use std::collections::{BTreeSet, HashMap};
9use std::path::Path;
10
11use proc_macro2::TokenStream;
12
13use crate::cli::{DatabaseKind, TimeCrate};
14use crate::introspect::SchemaInfo;
15
/// Rust keywords that cannot be used as plain identifiers.
///
/// Covers the strict keywords (including the 2018 additions `async`, `await`
/// and `dyn`) and the keywords reserved for future use, including `try`
/// (reserved since 2018) and `gen` (reserved since the 2024 edition).
const RUST_KEYWORDS: &[&str] = &[
    // Strict keywords
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while",
    // Reserved for future use
    "yield", "abstract", "become", "box", "do", "final", "gen", "macro", "override", "priv",
    "try", "typeof", "unsized", "virtual",
];
24
25/// Returns true if the given name is a Rust reserved keyword.
26pub fn is_rust_keyword(name: &str) -> bool {
27    RUST_KEYWORDS.contains(&name)
28}
29
/// Builds the `use` lines required by the well-known derives in `extra_derives`.
///
/// Only `Serialize` and `Deserialize` are recognised; when at least one of
/// them is requested a single grouped `use serde::{...};` line is returned,
/// always listing `Serialize` before `Deserialize`. Any other derive name
/// contributes nothing.
pub fn imports_for_derives(extra_derives: &[String]) -> Vec<String> {
    let requested: Vec<&str> = ["Serialize", "Deserialize"]
        .iter()
        .copied()
        .filter(|wanted| extra_derives.iter().any(|d| d.as_str() == *wanted))
        .collect();

    if requested.is_empty() {
        Vec::new()
    } else {
        vec![format!("use serde::{{{}}};", requested.join(", "))]
    }
}
46
/// Normalizes a table name for use as a Rust module / file name by collapsing
/// every run of consecutive underscores into a single underscore.
///
/// Leading and trailing underscores are kept (a leading `__` becomes `_`);
/// all other characters are copied through unchanged.
pub fn normalize_module_name(name: &str) -> String {
    name.chars()
        .fold(String::with_capacity(name.len()), |mut out, ch| {
            // An underscore is dropped only when the output already ends with one.
            if ch != '_' || !out.ends_with('_') {
                out.push(ch);
            }
            out
        })
}
65
/// Schema names treated as "default": tables in these schemas never get a
/// schema prefix in generated file names.
/// (Presumably the PostgreSQL, SQLite and SQL Server defaults, respectively.)
const DEFAULT_SCHEMAS: &[&str] = &[
    "public",
    "main",
    "dbo",
];
68
69/// Returns true if the schema is a well-known default (public, main, dbo).
70pub fn is_default_schema(schema: &str) -> bool {
71    DEFAULT_SCHEMAS.contains(&schema)
72}
73
74/// Build a module name, prefixing with schema only when the name collides
75/// (same table name exists in multiple schemas).
76pub fn build_module_name(schema_name: &str, table_name: &str, name_collides: bool) -> String {
77    if name_collides && !is_default_schema(schema_name) {
78        normalize_module_name(&format!("{}_{}", schema_name, table_name))
79    } else {
80        normalize_module_name(table_name)
81    }
82}
83
84/// Find table/view names that appear in more than one schema.
85fn find_colliding_names(schema_info: &SchemaInfo) -> BTreeSet<&str> {
86    let mut seen: HashMap<&str, BTreeSet<&str>> = HashMap::new();
87    for t in &schema_info.tables {
88        seen.entry(t.name.as_str()).or_default().insert(t.schema_name.as_str());
89    }
90    for v in &schema_info.views {
91        seen.entry(v.name.as_str()).or_default().insert(v.schema_name.as_str());
92    }
93    seen.into_iter()
94        .filter(|(_, schemas)| schemas.len() > 1)
95        .map(|(name, _)| name)
96        .collect()
97}
98
/// One generated source file: its target file name and its full text.
///
/// Any `use` lines the code needs are already part of `code`; there is no
/// separate imports field.
#[derive(Clone, Debug)]
pub struct GeneratedFile {
    /// File name of the generated module (e.g. `users.rs`, `types.rs`).
    pub filename: String,
    /// Optional origin comment (e.g. "Table: schema.name").
    pub origin: Option<String>,
    /// Formatted Rust source, including its import lines.
    pub code: String,
}
107
108/// Generate all code for a given schema.
109pub fn generate(
110    schema_info: &SchemaInfo,
111    db_kind: DatabaseKind,
112    extra_derives: &[String],
113    type_overrides: &HashMap<String, String>,
114    single_file: bool,
115    time_crate: TimeCrate,
116) -> Vec<GeneratedFile> {
117    let mut files = Vec::new();
118
119    // Detect table/view names that appear in multiple schemas (collisions)
120    let colliding_names = find_colliding_names(schema_info);
121
122    // Generate struct files for each table
123    for table in &schema_info.tables {
124        let (tokens, imports) =
125            struct_gen::generate_struct(table, db_kind, schema_info, extra_derives, type_overrides, false, time_crate);
126        let imports = filter_imports(&imports, single_file);
127        let code = format_tokens_with_imports(&tokens, &imports);
128        let module_name = build_module_name(&table.schema_name, &table.name, colliding_names.contains(table.name.as_str()));
129        files.push(GeneratedFile {
130            filename: format!("{}.rs", module_name),
131            origin: None,
132            code,
133        });
134    }
135
136    // Generate struct files for each view
137    for view in &schema_info.views {
138        let (tokens, imports) =
139            struct_gen::generate_struct(view, db_kind, schema_info, extra_derives, type_overrides, true, time_crate);
140        let imports = filter_imports(&imports, single_file);
141        let code = format_tokens_with_imports(&tokens, &imports);
142        let module_name = build_module_name(&view.schema_name, &view.name, colliding_names.contains(view.name.as_str()));
143        files.push(GeneratedFile {
144            filename: format!("{}.rs", module_name),
145            origin: None,
146            code,
147        });
148    }
149
150    // Generate types file (enums, composites, domains)
151    // Each item is formatted individually so we can insert blank lines between them.
152    let mut types_blocks: Vec<String> = Vec::new();
153    let mut types_imports = BTreeSet::new();
154
155    // Enrich enums with default variants extracted from column defaults
156    let enum_defaults = extract_enum_defaults(schema_info);
157    for enum_info in &schema_info.enums {
158        let mut enriched = enum_info.clone();
159        if enriched.default_variant.is_none() {
160            if let Some(default) = enum_defaults.get(&enum_info.name) {
161                enriched.default_variant = Some(default.clone());
162            }
163        }
164        let (tokens, imports) = enum_gen::generate_enum(&enriched, db_kind, extra_derives);
165        types_blocks.push(format_tokens(&tokens));
166        types_imports.extend(imports);
167    }
168
169    for composite in &schema_info.composite_types {
170        let (tokens, imports) = composite_gen::generate_composite(
171            composite,
172            db_kind,
173            schema_info,
174            extra_derives,
175            type_overrides,
176            time_crate,
177        );
178        types_blocks.push(format_tokens(&tokens));
179        types_imports.extend(imports);
180    }
181
182    for domain in &schema_info.domains {
183        let (tokens, imports) =
184            domain_gen::generate_domain(domain, db_kind, schema_info, type_overrides, time_crate);
185        types_blocks.push(format_tokens(&tokens));
186        types_imports.extend(imports);
187    }
188
189    if !types_blocks.is_empty() {
190        let import_lines: String = types_imports
191            .iter()
192            .map(|i| format!("{}\n", i))
193            .collect();
194        let body = types_blocks.join("\n");
195        let code = if import_lines.is_empty() {
196            body
197        } else {
198            format!("{}\n\n{}", import_lines.trim_end(), body)
199        };
200        files.push(GeneratedFile {
201            filename: "types.rs".to_string(),
202            origin: None,
203            code,
204        });
205    }
206
207    files
208}
209
210/// Extract default variant values for enums by scanning column defaults across all tables and views.
211/// PostgreSQL column defaults look like `'idle'::task_status` or `'active'::public.task_status`.
212fn extract_enum_defaults(schema_info: &SchemaInfo) -> HashMap<String, String> {
213    let mut defaults: HashMap<String, String> = HashMap::new();
214
215    let all_columns = schema_info
216        .tables
217        .iter()
218        .chain(schema_info.views.iter())
219        .flat_map(|t| t.columns.iter());
220
221    for col in all_columns {
222        let default_expr = match &col.column_default {
223            Some(d) => d,
224            None => continue,
225        };
226
227        // Strip leading underscore for array types to get the base enum name
228        let base_udt = col.udt_name.strip_prefix('_').unwrap_or(&col.udt_name);
229
230        // Check if this column references a known enum
231        let enum_match = schema_info.enums.iter().find(|e| e.name == base_udt);
232        if enum_match.is_none() {
233            continue;
234        }
235
236        // Parse PG default: 'variant'::type_name
237        if let Some(variant) = parse_pg_enum_default(default_expr) {
238            defaults.entry(base_udt.to_string()).or_insert(variant);
239        }
240    }
241
242    defaults
243}
244
/// Parses a PostgreSQL column default expression and extracts the enum variant.
///
/// Accepts a single-quoted literal immediately followed by a `::` cast, e.g.
/// `'idle'::task_status` or `'idle'::public.task_status`. Surrounding
/// whitespace is ignored. A doubled quote inside the literal (`''`) is the
/// SQL escape for one quote and is unescaped, so `'it''s'::mood` yields `it's`.
///
/// Returns `None` when:
/// - the expression does not start with a quoted literal,
/// - the literal is not terminated or not followed by `::`,
/// - the cast target is an array type (`'{a,b}'::task_status[]`), because an
///   array literal is not a single enum variant.
fn parse_pg_enum_default(default_expr: &str) -> Option<String> {
    let body = default_expr.trim().strip_prefix('\'')?;

    let mut value = String::new();
    let mut chars = body.char_indices().peekable();
    while let Some((idx, ch)) = chars.next() {
        if ch != '\'' {
            value.push(ch);
            continue;
        }

        // `''` inside a literal is an escaped single quote, not the terminator.
        if matches!(chars.peek(), Some((_, '\''))) {
            chars.next();
            value.push('\'');
            continue;
        }

        // Closing quote: a `::` cast must follow directly (`'` is one byte wide).
        let cast_target = body[idx + 1..].strip_prefix("::")?;
        if cast_target.trim_end().ends_with("[]") {
            return None;
        }
        return Some(value);
    }

    // Unterminated literal.
    None
}
262
/// Returns the imports to emit for one generated struct file.
///
/// In single-file mode every import containing `super::types::` is dropped,
/// since the types live in the same file; otherwise the set is returned unchanged.
fn filter_imports(imports: &BTreeSet<String>, single_file: bool) -> BTreeSet<String> {
    let mut kept = imports.clone();
    if single_file {
        kept.retain(|import| !import.contains("super::types::"));
    }
    kept
}
275
/// Detects `tab_spaces` from `rustfmt.toml` or `.rustfmt.toml` by walking up
/// the directory tree from `start_dir`.
///
/// If `start_dir` is a file, the search starts in its parent directory. The
/// first readable config file found ends the search: its first parsable
/// `tab_spaces = N` line provides the result (a trailing `# comment` on that
/// line is ignored). Returns 4 (the rustfmt default) when the config has no
/// usable `tab_spaces` entry or when no config file is found at all.
pub fn detect_tab_spaces(start_dir: &Path) -> usize {
    const RUSTFMT_DEFAULT: usize = 4;
    const CONFIG_NAMES: [&str; 2] = ["rustfmt.toml", ".rustfmt.toml"];

    /// Extracts the value from one `tab_spaces = N` config line, if it is one.
    fn tab_spaces_from_line(line: &str) -> Option<usize> {
        let rest = line.trim().strip_prefix("tab_spaces")?;
        let rest = rest.trim_start().strip_prefix('=').unwrap_or(rest);
        // TOML allows `key = value # comment`; only the part before `#` is the value.
        let value = rest.split('#').next().unwrap_or(rest);
        value.trim().parse::<usize>().ok()
    }

    let mut dir = if start_dir.is_file() {
        start_dir.parent().unwrap_or(start_dir)
    } else {
        start_dir
    };

    loop {
        for name in CONFIG_NAMES.iter() {
            if let Ok(content) = std::fs::read_to_string(dir.join(name)) {
                // A config file stops the upward search even without a tab_spaces key.
                return content
                    .lines()
                    .find_map(tab_spaces_from_line)
                    .unwrap_or(RUSTFMT_DEFAULT);
            }
        }
        match dir.parent() {
            Some(parent) => dir = parent,
            None => return RUSTFMT_DEFAULT,
        }
    }
}
307
308/// Parse and format a TokenStream via prettyplease, then post-process spacing.
309/// `tab_spaces` controls how many spaces per indentation level for SQL inside raw strings.
310pub(crate) fn parse_and_format(tokens: &TokenStream) -> String {
311    parse_and_format_with_tab_spaces(tokens, 4)
312}
313
314pub(crate) fn parse_and_format_with_tab_spaces(tokens: &TokenStream, tab_spaces: usize) -> String {
315    let file = syn::parse2::<syn::File>(tokens.clone()).unwrap_or_else(|e| {
316        log::error!("Failed to parse generated code: {}", e);
317        log::error!("This is a bug in sqlx-gen. Raw tokens:\n  {}", tokens);
318        std::process::exit(1);
319    });
320    let raw = prettyplease::unparse(&file);
321    let raw = indent_multiline_raw_strings(&raw, tab_spaces);
322    add_blank_lines_between_items(&raw)
323}
324
325/// Format a single TokenStream block (no imports).
326pub(crate) fn format_tokens(tokens: &TokenStream) -> String {
327    parse_and_format(tokens)
328}
329
330pub fn format_tokens_with_imports(tokens: &TokenStream, imports: &BTreeSet<String>) -> String {
331    format_tokens_with_imports_and_tab_spaces(tokens, imports, 4)
332}
333
334pub fn format_tokens_with_imports_and_tab_spaces(tokens: &TokenStream, imports: &BTreeSet<String>, tab_spaces: usize) -> String {
335    let formatted = parse_and_format_with_tab_spaces(tokens, tab_spaces);
336
337    let used_imports: Vec<&String> = imports
338        .iter()
339        .filter(|imp| is_import_used(imp, &formatted))
340        .collect();
341
342    if used_imports.is_empty() {
343        formatted
344    } else {
345        let import_lines: String = used_imports
346            .iter()
347            .map(|i| format!("{}\n", i))
348            .collect();
349        format!("{}\n\n{}", import_lines.trim_end(), formatted)
350    }
351}
352
/// Checks whether an import line is actually used by the generated code.
///
/// The name each import brings into scope is looked up in `code` as a whole
/// identifier, so `NaiveDate` is not considered used just because
/// `NaiveDateTime` appears.
///
/// - `use foo::bar::Baz;`   → used if `Baz` appears
/// - `use foo::{A, B};`     → used if `A` or `B` appears
/// - `use foo::Bar as Baz;` → used if the alias `Baz` appears
/// - `use foo::bar::*;`     → always kept (glob imports cannot be checked)
///
/// Imports whose bound name cannot be determined are kept.
fn is_import_used(import: &str, code: &str) -> bool {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }

    /// Name that one imported item binds locally: the alias after `as` if
    /// present, otherwise the last path segment.
    fn bound_name(item: &str) -> &str {
        let item = item.trim();
        let item = match item.rsplit_once(" as ") {
            Some((_, alias)) => alias.trim(),
            None => item,
        };
        item.rsplit("::")
            .next()
            .unwrap_or(item)
            .trim_matches(|c: char| c == '{' || c == '}' || c.is_whitespace())
    }

    /// True if `name` occurs in `code` without identifier characters on either side.
    fn contains_ident(code: &str, name: &str) -> bool {
        // An undeterminable name or a glob is kept rather than risk dropping a needed import.
        if name.is_empty() || name == "*" {
            return true;
        }
        code.match_indices(name).any(|(start, matched)| {
            let before = code[..start].chars().next_back();
            let after = code[start + matched.len()..].chars().next();
            !before.map_or(false, is_ident_char) && !after.map_or(false, is_ident_char)
        })
    }

    let trimmed = import.trim().trim_end_matches(';');
    let path = trimmed.strip_prefix("use ").unwrap_or(trimmed).trim();

    if path.ends_with("::*") {
        return true;
    }

    // Grouped import: use foo::{A, B, C};
    if let (Some(open), Some(close)) = (path.find('{'), path.rfind('}')) {
        if open < close {
            return path[open + 1..close]
                .split(',')
                .map(bound_name)
                .filter(|name| !name.is_empty())
                .any(|name| contains_ident(code, name));
        }
    }

    // Simple import: use foo::Bar;
    contains_ident(code, bound_name(path))
}
385
/// Indents the content of multi-line raw string literals (`r#"..."#`) so SQL
/// reads naturally in generated code.
///
/// A raw string is treated as multi-line when a line contains `r#"` with no
/// `"#` after it on the same line; it ends at the first following line that
/// starts (after leading whitespace) with `"#`. Content lines are re-indented
/// to `4 + 2 * tab_spaces` columns plus their indentation relative to the
/// least-indented non-empty content line, so relative indentation between SQL
/// lines is preserved (e.g. SET items indented under the SET keyword). Blank
/// content lines become empty lines, and the closing `"#` line is placed at
/// `4 + tab_spaces` columns.
///
/// If the input ends before a closing `"#` line is found, the buffered lines
/// are emitted unchanged rather than dropped.
///
/// Blank-line insertion between items is a separate pass
/// (`add_blank_lines_between_items`).
fn indent_multiline_raw_strings(code: &str, tab_spaces: usize) -> String {
    // Raw string content is NOT reformatted by external formatters, so the
    // right indentation must be baked in at generation time.
    let close_indent = 4 + tab_spaces; // impl(4) + fn_arg(tab)
    let sql_indent = 4 + 2 * tab_spaces; // impl(4) + fn_arg(tab) + sql(tab)

    let mut result: Vec<String> = Vec::new();
    let mut inside_raw = false;
    let mut raw_lines: Vec<&str> = Vec::new();

    for line in code.lines() {
        if !inside_raw {
            if let Some(pos) = line.find("r#\"") {
                let after = &line[pos + 3..];
                // Only a raw string left open on this line starts a multi-line block.
                if !after.contains("\"#") {
                    inside_raw = true;
                    raw_lines.clear();
                }
            }
            result.push(line.to_string());
        } else if line.trim_start().starts_with("\"#") {
            // Smallest indentation among the non-empty content lines.
            let min_indent = raw_lines
                .iter()
                .filter(|l| !l.trim().is_empty())
                .map(|l| l.len() - l.trim_start().len())
                .min()
                .unwrap_or(0);
            for raw_line in &raw_lines {
                let trimmed = raw_line.trim();
                if trimmed.is_empty() {
                    result.push(String::new());
                } else {
                    let original_indent = raw_line.len() - raw_line.trim_start().len();
                    let relative = original_indent.saturating_sub(min_indent);
                    result.push(format!(
                        "{}{}{}",
                        " ".repeat(sql_indent),
                        " ".repeat(relative),
                        trimmed
                    ));
                }
            }
            // Closing "# line
            result.push(format!("{}{}", " ".repeat(close_indent), line.trim()));
            inside_raw = false;
        } else {
            raw_lines.push(line);
        }
    }

    // No closing line was found: keep the buffered lines instead of losing them.
    if inside_raw {
        result.extend(raw_lines.iter().map(|l| l.to_string()));
    }

    result.join("\n")
}
452
/// Post-processes formatted code by inserting a blank line:
/// - between enum variants, before a `#[sqlx(rename` attribute that follows a
///   variant line (ending with `,`) — never before the first variant;
/// - before top-level items and methods (`pub struct`, `impl `, `#[derive`,
///   `pub async fn`, `pub fn`) when the previous line is a lone `}`;
/// - before a new logical block inside methods: before `let` or `Ok(` when the
///   previous line ends with `.await?;` / `.await?` or is an `.unwrap_or(…);`
///   statement, and before a `let … sqlx::` query that follows a plain `let`
///   without `sqlx::`.
///
/// At most one blank line is inserted before any given line, even when
/// several rules match. Lines are compared after trimming whitespace.
fn add_blank_lines_between_items(code: &str) -> String {
    /// Decides whether a blank line belongs between two adjacent (trimmed) lines.
    fn needs_blank_line_before(prev: &str, current: &str) -> bool {
        const ITEM_STARTS: [&str; 5] =
            ["pub struct", "impl ", "#[derive", "pub async fn", "pub fn"];

        // Renamed enum variant following another variant.
        let renamed_variant = current.starts_with("#[sqlx(rename") && prev.ends_with(',');

        // New item or method right after a closing brace.
        let new_item = prev == "}" && ITEM_STARTS.iter().any(|start| current.starts_with(*start));

        // New logical block after an awaited call or an `.unwrap_or(…);` statement.
        let prev_ends_block = prev.ends_with(".await?;")
            || prev.ends_with(".await?")
            || (prev.ends_with(';') && prev.contains(".unwrap_or("));
        let new_logical_block =
            prev_ends_block && (current.starts_with("let ") || current.starts_with("Ok("));

        // sqlx query `let` following simple `let` assignments.
        let query_after_plain_let = current.starts_with("let ")
            && current.contains("sqlx::")
            && prev.starts_with("let ")
            && !prev.contains("sqlx::");

        renamed_variant || new_item || new_logical_block || query_after_plain_let
    }

    let lines: Vec<&str> = code.lines().collect();
    let mut result: Vec<&str> = Vec::with_capacity(lines.len());

    for (i, line) in lines.iter().enumerate() {
        if i > 0 && needs_blank_line_before(lines[i - 1].trim(), line.trim()) {
            result.push("");
        }
        result.push(*line);
    }

    result.join("\n")
}
510
511#[cfg(test)]
512mod tests {
513    use super::*;
514    use crate::introspect::{
515        ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo,
516    };
517    use std::collections::HashMap;
518
519    // ========== is_rust_keyword ==========
520
521    #[test]
522    fn test_keyword_type() {
523        assert!(is_rust_keyword("type"));
524    }
525
526    #[test]
527    fn test_keyword_fn() {
528        assert!(is_rust_keyword("fn"));
529    }
530
531    #[test]
532    fn test_keyword_let() {
533        assert!(is_rust_keyword("let"));
534    }
535
536    #[test]
537    fn test_keyword_match() {
538        assert!(is_rust_keyword("match"));
539    }
540
541    #[test]
542    fn test_keyword_async() {
543        assert!(is_rust_keyword("async"));
544    }
545
546    #[test]
547    fn test_keyword_await() {
548        assert!(is_rust_keyword("await"));
549    }
550
551    #[test]
552    fn test_keyword_yield() {
553        assert!(is_rust_keyword("yield"));
554    }
555
556    #[test]
557    fn test_keyword_abstract() {
558        assert!(is_rust_keyword("abstract"));
559    }
560
561    #[test]
562    fn test_keyword_try() {
563        assert!(is_rust_keyword("try"));
564    }
565
566    #[test]
567    fn test_not_keyword_name() {
568        assert!(!is_rust_keyword("name"));
569    }
570
571    #[test]
572    fn test_not_keyword_id() {
573        assert!(!is_rust_keyword("id"));
574    }
575
576    #[test]
577    fn test_not_keyword_uppercase_type() {
578        assert!(!is_rust_keyword("Type"));
579    }
580
581    // ========== normalize_module_name ==========
582
583    #[test]
584    fn test_normalize_no_underscores() {
585        assert_eq!(normalize_module_name("users"), "users");
586    }
587
588    #[test]
589    fn test_normalize_single_underscore() {
590        assert_eq!(normalize_module_name("user_roles"), "user_roles");
591    }
592
593    #[test]
594    fn test_normalize_double_underscore() {
595        assert_eq!(normalize_module_name("user__roles"), "user_roles");
596    }
597
598    #[test]
599    fn test_normalize_triple_underscore() {
600        assert_eq!(normalize_module_name("a___b"), "a_b");
601    }
602
603    #[test]
604    fn test_normalize_leading_underscore() {
605        assert_eq!(normalize_module_name("_private"), "_private");
606    }
607
608    #[test]
609    fn test_normalize_trailing_underscore() {
610        assert_eq!(normalize_module_name("name_"), "name_");
611    }
612
613    #[test]
614    fn test_normalize_double_leading() {
615        assert_eq!(normalize_module_name("__double_leading"), "_double_leading");
616    }
617
618    #[test]
619    fn test_normalize_multiple_groups() {
620        assert_eq!(normalize_module_name("a__b__c"), "a_b_c");
621    }
622
623    // ========== build_module_name ==========
624
625    #[test]
626    fn test_build_no_collision_no_prefix() {
627        assert_eq!(build_module_name("public", "users", false), "users");
628    }
629
630    #[test]
631    fn test_build_no_collision_non_default_no_prefix() {
632        assert_eq!(build_module_name("billing", "invoices", false), "invoices");
633    }
634
635    #[test]
636    fn test_build_collision_prefixed() {
637        assert_eq!(build_module_name("billing", "users", true), "billing_users");
638    }
639
640    #[test]
641    fn test_build_collision_default_schema_no_prefix() {
642        assert_eq!(build_module_name("public", "users", true), "users");
643    }
644
645    #[test]
646    fn test_build_collision_normalizes_double_underscore() {
647        assert_eq!(build_module_name("billing", "agent__connector", true), "billing_agent_connector");
648    }
649
650    // ========== is_default_schema ==========
651
652    #[test]
653    fn test_default_schema_public() {
654        assert!(is_default_schema("public"));
655    }
656
657    #[test]
658    fn test_default_schema_main() {
659        assert!(is_default_schema("main"));
660    }
661
662    #[test]
663    fn test_non_default_schema() {
664        assert!(!is_default_schema("billing"));
665    }
666
667    // ========== imports_for_derives ==========
668
669    #[test]
670    fn test_imports_empty() {
671        let result = imports_for_derives(&[]);
672        assert!(result.is_empty());
673    }
674
675    #[test]
676    fn test_imports_serialize_only() {
677        let derives = vec!["Serialize".to_string()];
678        let result = imports_for_derives(&derives);
679        assert_eq!(result, vec!["use serde::{Serialize};"]);
680    }
681
682    #[test]
683    fn test_imports_deserialize_only() {
684        let derives = vec!["Deserialize".to_string()];
685        let result = imports_for_derives(&derives);
686        assert_eq!(result, vec!["use serde::{Deserialize};"]);
687    }
688
689    #[test]
690    fn test_imports_both_serde() {
691        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
692        let result = imports_for_derives(&derives);
693        assert_eq!(result, vec!["use serde::{Serialize, Deserialize};"]);
694    }
695
696    #[test]
697    fn test_imports_non_serde() {
698        let derives = vec!["Hash".to_string()];
699        let result = imports_for_derives(&derives);
700        assert!(result.is_empty());
701    }
702
703    #[test]
704    fn test_imports_non_serde_multiple() {
705        let derives = vec!["PartialEq".to_string(), "Eq".to_string()];
706        let result = imports_for_derives(&derives);
707        assert!(result.is_empty());
708    }
709
710    #[test]
711    fn test_imports_mixed_serde_and_others() {
712        let derives = vec![
713            "Serialize".to_string(),
714            "Hash".to_string(),
715            "Deserialize".to_string(),
716        ];
717        let result = imports_for_derives(&derives);
718        assert_eq!(result.len(), 1);
719        assert!(result[0].contains("Serialize"));
720        assert!(result[0].contains("Deserialize"));
721    }
722
723    // ========== add_blank_lines_between_items ==========
724
725    #[test]
726    fn test_blank_lines_between_renamed_variants() {
727        let input = "pub enum Foo {\n    #[sqlx(rename = \"a\")]\n    A,\n    #[sqlx(rename = \"b\")]\n    B,\n}";
728        let result = add_blank_lines_between_items(input);
729        assert!(result.contains("A,\n\n    #[sqlx(rename = \"b\")]"));
730    }
731
732    #[test]
733    fn test_no_blank_line_for_first_variant() {
734        let input = "pub enum Foo {\n    #[sqlx(rename = \"a\")]\n    A,\n}";
735        let result = add_blank_lines_between_items(input);
736        // No blank line before first #[sqlx(rename because previous line is `{`
737        assert!(!result.contains("{\n\n"));
738    }
739
740    #[test]
741    fn test_no_change_without_rename() {
742        let input = "pub enum Foo {\n    A,\n    B,\n}";
743        let result = add_blank_lines_between_items(input);
744        assert_eq!(result, input);
745    }
746
747    #[test]
748    fn test_no_change_for_struct() {
749        let input = "pub struct Foo {\n    pub a: i32,\n    pub b: String,\n}";
750        let result = add_blank_lines_between_items(input);
751        assert_eq!(result, input);
752    }
753
754    // ========== filter_imports ==========
755
756    #[test]
757    fn test_filter_single_file_strips_super_types() {
758        let mut imports = BTreeSet::new();
759        imports.insert("use super::types::Foo;".to_string());
760        imports.insert("use chrono::NaiveDateTime;".to_string());
761        let result = filter_imports(&imports, true);
762        assert!(!result.contains("use super::types::Foo;"));
763        assert!(result.contains("use chrono::NaiveDateTime;"));
764    }
765
766    #[test]
767    fn test_filter_single_file_keeps_other_imports() {
768        let mut imports = BTreeSet::new();
769        imports.insert("use chrono::NaiveDateTime;".to_string());
770        let result = filter_imports(&imports, true);
771        assert!(result.contains("use chrono::NaiveDateTime;"));
772    }
773
774    #[test]
775    fn test_filter_multi_file_keeps_all() {
776        let mut imports = BTreeSet::new();
777        imports.insert("use super::types::Foo;".to_string());
778        imports.insert("use chrono::NaiveDateTime;".to_string());
779        let result = filter_imports(&imports, false);
780        assert_eq!(result.len(), 2);
781    }
782
783    #[test]
784    fn test_filter_empty_set() {
785        let imports = BTreeSet::new();
786        let result = filter_imports(&imports, true);
787        assert!(result.is_empty());
788    }
789
790    // ========== generate() orchestrator ==========
791
792    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
793        TableInfo {
794            schema_name: "public".to_string(),
795            name: name.to_string(),
796            columns,
797        }
798    }
799
800    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
801        ColumnInfo {
802            name: name.to_string(),
803            data_type: udt_name.to_string(),
804            udt_name: udt_name.to_string(),
805            is_nullable: false,
806            is_primary_key: false,
807            ordinal_position: 0,
808            schema_name: "public".to_string(),
809            column_default: None,
810        }
811    }
812
813    #[test]
814    fn test_generate_empty_schema() {
815        let schema = SchemaInfo::default();
816        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
817        assert!(files.is_empty());
818    }
819
820    #[test]
821    fn test_generate_one_table() {
822        let schema = SchemaInfo {
823            tables: vec![make_table("users", vec![make_col("id", "int4")])],
824            ..Default::default()
825        };
826        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
827        assert_eq!(files.len(), 1);
828        assert_eq!(files[0].filename, "users.rs");
829    }
830
831    #[test]
832    fn test_generate_two_tables() {
833        let schema = SchemaInfo {
834            tables: vec![
835                make_table("users", vec![make_col("id", "int4")]),
836                make_table("posts", vec![make_col("id", "int4")]),
837            ],
838            ..Default::default()
839        };
840        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
841        assert_eq!(files.len(), 2);
842    }
843
844    #[test]
845    fn test_generate_enum_creates_types_file() {
846        let schema = SchemaInfo {
847            enums: vec![EnumInfo {
848                schema_name: "public".to_string(),
849                name: "status".to_string(),
850                variants: vec!["active".to_string(), "inactive".to_string()],
851                default_variant: None,
852            }],
853            ..Default::default()
854        };
855        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
856        assert_eq!(files.len(), 1);
857        assert_eq!(files[0].filename, "types.rs");
858    }
859
860    #[test]
861    fn test_generate_enums_composites_domains_single_types_file() {
862        let schema = SchemaInfo {
863            enums: vec![EnumInfo {
864                schema_name: "public".to_string(),
865                name: "status".to_string(),
866                variants: vec!["active".to_string()],
867                default_variant: None,
868            }],
869            composite_types: vec![CompositeTypeInfo {
870                schema_name: "public".to_string(),
871                name: "address".to_string(),
872                fields: vec![make_col("street", "text")],
873            }],
874            domains: vec![DomainInfo {
875                schema_name: "public".to_string(),
876                name: "email".to_string(),
877                base_type: "text".to_string(),
878            }],
879            ..Default::default()
880        };
881        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
882        // Should produce exactly 1 types.rs
883        let types_files: Vec<_> = files.iter().filter(|f| f.filename == "types.rs").collect();
884        assert_eq!(types_files.len(), 1);
885    }
886
887    #[test]
888    fn test_generate_tables_and_enums() {
889        let schema = SchemaInfo {
890            tables: vec![make_table("users", vec![make_col("id", "int4")])],
891            enums: vec![EnumInfo {
892                schema_name: "public".to_string(),
893                name: "status".to_string(),
894                variants: vec!["active".to_string()],
895                default_variant: None,
896            }],
897            ..Default::default()
898        };
899        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
900        assert_eq!(files.len(), 2); // users.rs + types.rs
901    }
902
903    #[test]
904    fn test_generate_filename_normalized() {
905        let schema = SchemaInfo {
906            tables: vec![make_table("user__data", vec![make_col("id", "int4")])],
907            ..Default::default()
908        };
909        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
910        assert_eq!(files[0].filename, "user_data.rs");
911    }
912
913    #[test]
914    fn test_generate_no_origin_for_tables() {
915        let schema = SchemaInfo {
916            tables: vec![make_table("users", vec![make_col("id", "int4")])],
917            ..Default::default()
918        };
919        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
920        assert_eq!(files[0].origin, None);
921    }
922
923    #[test]
924    fn test_generate_types_no_origin() {
925        let schema = SchemaInfo {
926            enums: vec![EnumInfo {
927                schema_name: "public".to_string(),
928                name: "status".to_string(),
929                variants: vec!["active".to_string()],
930                default_variant: None,
931            }],
932            ..Default::default()
933        };
934        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
935        assert_eq!(files[0].origin, None);
936    }
937
938    #[test]
939    fn test_generate_single_file_filters_super_types_imports() {
940        let schema = SchemaInfo {
941            tables: vec![make_table("users", vec![make_col("id", "int4")])],
942            enums: vec![EnumInfo {
943                schema_name: "public".to_string(),
944                name: "status".to_string(),
945                variants: vec!["active".to_string()],
946                default_variant: None,
947            }],
948            ..Default::default()
949        };
950        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true, TimeCrate::Chrono);
951        // struct file should not have super::types:: imports
952        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
953        assert!(!struct_file.code.contains("super::types::"));
954    }
955
956    #[test]
957    fn test_generate_multi_file_keeps_super_types_imports() {
958        // Table with a column referencing an enum
959        let schema = SchemaInfo {
960            tables: vec![make_table("users", vec![make_col("status", "status")])],
961            enums: vec![EnumInfo {
962                schema_name: "public".to_string(),
963                name: "status".to_string(),
964                variants: vec!["active".to_string()],
965                default_variant: None,
966            }],
967            ..Default::default()
968        };
969        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
970        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
971        assert!(struct_file.code.contains("super::types::"));
972    }
973
974    #[test]
975    fn test_generate_extra_derives_in_struct() {
976        let schema = SchemaInfo {
977            tables: vec![make_table("users", vec![make_col("id", "int4")])],
978            ..Default::default()
979        };
980        let derives = vec!["Serialize".to_string()];
981        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false, TimeCrate::Chrono);
982        assert!(files[0].code.contains("Serialize"));
983    }
984
985    #[test]
986    fn test_generate_extra_derives_in_enum() {
987        let schema = SchemaInfo {
988            enums: vec![EnumInfo {
989                schema_name: "public".to_string(),
990                name: "status".to_string(),
991                variants: vec!["active".to_string()],
992                default_variant: None,
993            }],
994            ..Default::default()
995        };
996        let derives = vec!["Serialize".to_string()];
997        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false, TimeCrate::Chrono);
998        assert!(files[0].code.contains("Serialize"));
999    }
1000
1001    #[test]
1002    fn test_generate_type_overrides_in_struct() {
1003        let mut overrides = HashMap::new();
1004        overrides.insert("jsonb".to_string(), "MyJson".to_string());
1005        let schema = SchemaInfo {
1006            tables: vec![make_table("users", vec![make_col("data", "jsonb")])],
1007            ..Default::default()
1008        };
1009        let files = generate(&schema, DatabaseKind::Postgres, &[], &overrides, false, TimeCrate::Chrono);
1010        assert!(files[0].code.contains("MyJson"));
1011    }
1012
1013    #[test]
1014    fn test_generate_valid_rust_syntax() {
1015        let schema = SchemaInfo {
1016            tables: vec![make_table("users", vec![
1017                make_col("id", "int4"),
1018                make_col("name", "text"),
1019            ])],
1020            enums: vec![EnumInfo {
1021                schema_name: "public".to_string(),
1022                name: "status".to_string(),
1023                variants: vec!["active".to_string(), "inactive".to_string()],
1024                default_variant: None,
1025            }],
1026            ..Default::default()
1027        };
1028        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1029        for f in &files {
1030            // Should be parseable as valid Rust
1031            let parse_result = syn::parse_file(&f.code);
1032            assert!(parse_result.is_ok(), "Failed to parse {}: {:?}", f.filename, parse_result.err());
1033        }
1034    }
1035
1036    // ========== generate() — views ==========
1037
1038    fn make_view(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
1039        TableInfo {
1040            schema_name: "public".to_string(),
1041            name: name.to_string(),
1042            columns,
1043        }
1044    }
1045
1046    #[test]
1047    fn test_generate_one_view() {
1048        let schema = SchemaInfo {
1049            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1050            ..Default::default()
1051        };
1052        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1053        assert_eq!(files.len(), 1);
1054        assert_eq!(files[0].filename, "active_users.rs");
1055    }
1056
1057    #[test]
1058    fn test_generate_no_origin_for_views() {
1059        let schema = SchemaInfo {
1060            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1061            ..Default::default()
1062        };
1063        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1064        assert_eq!(files[0].origin, None);
1065    }
1066
1067    #[test]
1068    fn test_generate_tables_and_views() {
1069        let schema = SchemaInfo {
1070            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1071            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1072            ..Default::default()
1073        };
1074        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1075        assert_eq!(files.len(), 2);
1076    }
1077
1078    #[test]
1079    fn test_generate_view_valid_rust() {
1080        let schema = SchemaInfo {
1081            views: vec![make_view("active_users", vec![
1082                make_col("id", "int4"),
1083                make_col("name", "text"),
1084            ])],
1085            ..Default::default()
1086        };
1087        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1088        let parse_result = syn::parse_file(&files[0].code);
1089        assert!(parse_result.is_ok(), "Failed to parse: {:?}", parse_result.err());
1090    }
1091
1092    #[test]
1093    fn test_generate_view_nullable_column() {
1094        let schema = SchemaInfo {
1095            views: vec![make_view("v", vec![ColumnInfo {
1096                name: "email".to_string(),
1097                data_type: "text".to_string(),
1098                udt_name: "text".to_string(),
1099                is_nullable: true,
1100                is_primary_key: false,
1101                ordinal_position: 0,
1102                schema_name: "public".to_string(),
1103                column_default: None,
1104            }])],
1105            ..Default::default()
1106        };
1107        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1108        assert!(files[0].code.contains("Option<String>"));
1109    }
1110
1111    #[test]
1112    fn test_generate_collision_both_prefixed() {
1113        let schema = SchemaInfo {
1114            tables: vec![
1115                make_table("users", vec![make_col("id", "int4")]),
1116                TableInfo {
1117                    schema_name: "billing".to_string(),
1118                    name: "users".to_string(),
1119                    columns: vec![make_col("id", "int4")],
1120                },
1121            ],
1122            ..Default::default()
1123        };
1124        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1125        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
1126        assert!(filenames.contains(&"users.rs"));
1127        assert!(filenames.contains(&"billing_users.rs"));
1128    }
1129
1130    #[test]
1131    fn test_generate_no_collision_no_prefix() {
1132        let schema = SchemaInfo {
1133            tables: vec![
1134                make_table("users", vec![make_col("id", "int4")]),
1135                TableInfo {
1136                    schema_name: "billing".to_string(),
1137                    name: "invoices".to_string(),
1138                    columns: vec![make_col("id", "int4")],
1139                },
1140            ],
1141            ..Default::default()
1142        };
1143        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1144        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
1145        assert!(filenames.contains(&"users.rs"));
1146        assert!(filenames.contains(&"invoices.rs"));
1147    }
1148
1149    #[test]
1150    fn test_generate_single_schema_no_prefix() {
1151        let schema = SchemaInfo {
1152            tables: vec![
1153                make_table("users", vec![make_col("id", "int4")]),
1154                make_table("posts", vec![make_col("id", "int4")]),
1155            ],
1156            ..Default::default()
1157        };
1158        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1159        assert_eq!(files[0].filename, "users.rs");
1160        assert_eq!(files[1].filename, "posts.rs");
1161    }
1162
1163    #[test]
1164    fn test_generate_view_single_file_mode() {
1165        let schema = SchemaInfo {
1166            tables: vec![make_table("users", vec![make_col("id", "int4")])],
1167            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1168            ..Default::default()
1169        };
1170        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true, TimeCrate::Chrono);
1171        assert_eq!(files.len(), 2);
1172    }
1173
1174    // ========== parse_pg_enum_default ==========
1175
1176    #[test]
1177    fn test_parse_pg_enum_default_simple() {
1178        assert_eq!(
1179            parse_pg_enum_default("'idle'::task_status"),
1180            Some("idle".to_string())
1181        );
1182    }
1183
1184    #[test]
1185    fn test_parse_pg_enum_default_schema_qualified() {
1186        assert_eq!(
1187            parse_pg_enum_default("'active'::public.task_status"),
1188            Some("active".to_string())
1189        );
1190    }
1191
1192    #[test]
1193    fn test_parse_pg_enum_default_not_enum() {
1194        // No single-quote pattern
1195        assert_eq!(parse_pg_enum_default("nextval('users_id_seq')"), None);
1196    }
1197
1198    #[test]
1199    fn test_parse_pg_enum_default_no_cast() {
1200        assert_eq!(parse_pg_enum_default("'hello'"), None);
1201    }
1202
1203    #[test]
1204    fn test_parse_pg_enum_default_empty() {
1205        assert_eq!(parse_pg_enum_default(""), None);
1206    }
1207
1208    // ========== extract_enum_defaults ==========
1209
1210    #[test]
1211    fn test_extract_enum_defaults_from_column() {
1212        let schema = SchemaInfo {
1213            tables: vec![TableInfo {
1214                schema_name: "public".to_string(),
1215                name: "tasks".to_string(),
1216                columns: vec![ColumnInfo {
1217                    name: "status".to_string(),
1218                    data_type: "USER-DEFINED".to_string(),
1219                    udt_name: "task_status".to_string(),
1220                    is_nullable: false,
1221                    is_primary_key: false,
1222                    ordinal_position: 0,
1223                    schema_name: "public".to_string(),
1224                    column_default: Some("'idle'::task_status".to_string()),
1225                }],
1226            }],
1227            enums: vec![EnumInfo {
1228                schema_name: "public".to_string(),
1229                name: "task_status".to_string(),
1230                variants: vec!["idle".to_string(), "running".to_string()],
1231                default_variant: None,
1232            }],
1233            ..Default::default()
1234        };
1235        let defaults = extract_enum_defaults(&schema);
1236        assert_eq!(defaults.get("task_status"), Some(&"idle".to_string()));
1237    }
1238
1239    #[test]
1240    fn test_extract_enum_defaults_no_default() {
1241        let schema = SchemaInfo {
1242            tables: vec![TableInfo {
1243                schema_name: "public".to_string(),
1244                name: "tasks".to_string(),
1245                columns: vec![ColumnInfo {
1246                    name: "status".to_string(),
1247                    data_type: "USER-DEFINED".to_string(),
1248                    udt_name: "task_status".to_string(),
1249                    is_nullable: false,
1250                    is_primary_key: false,
1251                    ordinal_position: 0,
1252                    schema_name: "public".to_string(),
1253                    column_default: None,
1254                }],
1255            }],
1256            enums: vec![EnumInfo {
1257                schema_name: "public".to_string(),
1258                name: "task_status".to_string(),
1259                variants: vec!["idle".to_string()],
1260                default_variant: None,
1261            }],
1262            ..Default::default()
1263        };
1264        let defaults = extract_enum_defaults(&schema);
1265        assert!(defaults.is_empty());
1266    }
1267
1268    #[test]
1269    fn test_extract_enum_defaults_non_enum_column_ignored() {
1270        let schema = SchemaInfo {
1271            tables: vec![TableInfo {
1272                schema_name: "public".to_string(),
1273                name: "users".to_string(),
1274                columns: vec![ColumnInfo {
1275                    name: "name".to_string(),
1276                    data_type: "character varying".to_string(),
1277                    udt_name: "varchar".to_string(),
1278                    is_nullable: false,
1279                    is_primary_key: false,
1280                    ordinal_position: 0,
1281                    schema_name: "public".to_string(),
1282                    column_default: Some("'hello'::character varying".to_string()),
1283                }],
1284            }],
1285            enums: vec![],
1286            ..Default::default()
1287        };
1288        let defaults = extract_enum_defaults(&schema);
1289        assert!(defaults.is_empty());
1290    }
1291
1292    #[test]
1293    fn test_generate_enum_with_default() {
1294        let schema = SchemaInfo {
1295            tables: vec![TableInfo {
1296                schema_name: "public".to_string(),
1297                name: "tasks".to_string(),
1298                columns: vec![ColumnInfo {
1299                    name: "status".to_string(),
1300                    data_type: "USER-DEFINED".to_string(),
1301                    udt_name: "task_status".to_string(),
1302                    is_nullable: false,
1303                    is_primary_key: false,
1304                    ordinal_position: 0,
1305                    schema_name: "public".to_string(),
1306                    column_default: Some("'idle'::task_status".to_string()),
1307                }],
1308            }],
1309            enums: vec![EnumInfo {
1310                schema_name: "public".to_string(),
1311                name: "task_status".to_string(),
1312                variants: vec!["idle".to_string(), "running".to_string()],
1313                default_variant: None,
1314            }],
1315            ..Default::default()
1316        };
1317        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1318        let types_file = files.iter().find(|f| f.filename == "types.rs").unwrap();
1319        assert!(types_file.code.contains("impl Default for TaskStatus"));
1320        assert!(types_file.code.contains("Self::Idle"));
1321    }
1322}