Skip to main content

sqlx_gen/codegen/
mod.rs

1pub mod composite_gen;
2pub mod domain_gen;
3pub mod enum_gen;
4pub mod struct_gen;
5
6use std::collections::{BTreeSet, HashMap};
7
8use proc_macro2::TokenStream;
9
10use crate::cli::DatabaseKind;
11use crate::introspect::SchemaInfo;
12
/// Rust reserved keywords that cannot be used as identifiers.
///
/// Covers the strict keywords plus the keywords reserved for future use across
/// editions, including `gen` (reserved starting with the 2024 edition), so that
/// generated code stays valid regardless of the edition it is compiled under.
const RUST_KEYWORDS: &[&str] = &[
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while", "yield", "abstract", "become", "box", "do",
    "final", "gen", "macro", "override", "priv", "try", "typeof", "unsized", "virtual",
];
21
22/// Returns true if the given name is a Rust reserved keyword.
23pub fn is_rust_keyword(name: &str) -> bool {
24    RUST_KEYWORDS.contains(&name)
25}
26
/// Returns the `use` lines required by well-known extra derives.
///
/// Only the serde derives are recognised: when `Serialize` and/or `Deserialize`
/// appear in `extra_derives`, a single `use serde::{...};` line naming the
/// requested items (always `Serialize` before `Deserialize`) is returned.
/// Any other derive name contributes nothing.
pub fn imports_for_derives(extra_derives: &[String]) -> Vec<String> {
    // Fixed order here determines the order inside the braces of the import.
    let serde_items: Vec<&str> = ["Serialize", "Deserialize"]
        .iter()
        .copied()
        .filter(|wanted| extra_derives.iter().any(|derive| derive == wanted))
        .collect();

    if serde_items.is_empty() {
        return Vec::new();
    }
    vec![format!("use serde::{{{}}};", serde_items.join(", "))]
}
43
/// Normalize a table name for use as a Rust module/filename by collapsing
/// every run of consecutive underscores into a single underscore.
///
/// All other characters are copied through unchanged; leading and trailing
/// underscores are kept (as a single `_`).
pub fn normalize_module_name(name: &str) -> String {
    let mut collapsed = String::with_capacity(name.len());
    for ch in name.chars() {
        // An underscore is redundant only when the output already ends with one.
        let repeats_underscore = ch == '_' && collapsed.ends_with('_');
        if !repeats_underscore {
            collapsed.push(ch);
        }
    }
    collapsed
}
62
/// A generated code file: its target filename, an optional origin label, and
/// the rendered source text.
///
/// Imports are not stored separately; `generate` renders any `use` lines
/// directly at the top of `code`.
#[derive(Debug, Clone)]
pub struct GeneratedFile {
    /// File name including the `.rs` extension (e.g. `users.rs`, `types.rs`).
    pub filename: String,
    /// Optional origin comment (e.g. "Table: schema.name")
    pub origin: Option<String>,
    /// Formatted Rust source for this file, with any import lines prepended.
    pub code: String,
}
71
72/// Generate all code for a given schema.
73pub fn generate(
74    schema_info: &SchemaInfo,
75    db_kind: DatabaseKind,
76    extra_derives: &[String],
77    type_overrides: &HashMap<String, String>,
78    single_file: bool,
79) -> Vec<GeneratedFile> {
80    let mut files = Vec::new();
81
82    // Generate struct files for each table
83    for table in &schema_info.tables {
84        let (tokens, imports) =
85            struct_gen::generate_struct(table, db_kind, schema_info, extra_derives, type_overrides);
86        let imports = filter_imports(&imports, single_file);
87        let code = format_tokens_with_imports(&tokens, &imports);
88        let module_name = normalize_module_name(&table.name);
89        let origin = format!("Table: {}.{}", table.schema_name, table.name);
90        files.push(GeneratedFile {
91            filename: format!("{}.rs", module_name),
92            origin: Some(origin),
93            code,
94        });
95    }
96
97    // Generate struct files for each view
98    for view in &schema_info.views {
99        let (tokens, imports) =
100            struct_gen::generate_struct(view, db_kind, schema_info, extra_derives, type_overrides);
101        let imports = filter_imports(&imports, single_file);
102        let code = format_tokens_with_imports(&tokens, &imports);
103        let module_name = normalize_module_name(&view.name);
104        let origin = format!("View: {}.{}", view.schema_name, view.name);
105        files.push(GeneratedFile {
106            filename: format!("{}.rs", module_name),
107            origin: Some(origin),
108            code,
109        });
110    }
111
112    // Generate types file (enums, composites, domains)
113    // Each item is formatted individually so we can insert blank lines between them.
114    let mut types_blocks: Vec<String> = Vec::new();
115    let mut types_imports = BTreeSet::new();
116
117    for enum_info in &schema_info.enums {
118        let (tokens, imports) = enum_gen::generate_enum(enum_info, db_kind, extra_derives);
119        types_blocks.push(format_tokens(&tokens));
120        types_imports.extend(imports);
121    }
122
123    for composite in &schema_info.composite_types {
124        let (tokens, imports) = composite_gen::generate_composite(
125            composite,
126            db_kind,
127            schema_info,
128            extra_derives,
129            type_overrides,
130        );
131        types_blocks.push(format_tokens(&tokens));
132        types_imports.extend(imports);
133    }
134
135    for domain in &schema_info.domains {
136        let (tokens, imports) =
137            domain_gen::generate_domain(domain, db_kind, schema_info, type_overrides);
138        types_blocks.push(format_tokens(&tokens));
139        types_imports.extend(imports);
140    }
141
142    if !types_blocks.is_empty() {
143        let import_lines: String = types_imports
144            .iter()
145            .map(|i| format!("{}\n", i))
146            .collect();
147        let body = types_blocks.join("\n");
148        let code = if import_lines.is_empty() {
149            body
150        } else {
151            format!("{}\n\n{}", import_lines.trim_end(), body)
152        };
153        files.push(GeneratedFile {
154            filename: "types.rs".to_string(),
155            origin: None,
156            code,
157        });
158    }
159
160    files
161}
162
/// Returns the import set to emit for a table/view file.
///
/// In single-file mode every import line containing `super::types::` is
/// dropped, since the referenced types live in the same file. In multi-file
/// mode the set is returned unchanged (as a copy).
fn filter_imports(imports: &BTreeSet<String>, single_file: bool) -> BTreeSet<String> {
    let mut kept = imports.clone();
    if single_file {
        kept.retain(|line| !line.contains("super::types::"));
    }
    kept
}
175
176/// Parse and format a TokenStream via prettyplease, then post-process spacing.
177pub(crate) fn parse_and_format(tokens: &TokenStream) -> String {
178    let file = syn::parse2::<syn::File>(tokens.clone()).unwrap_or_else(|e| {
179        eprintln!("ERROR: failed to parse generated code: {}", e);
180        eprintln!("  This is a bug in sqlx-gen. Raw tokens:\n  {}", tokens);
181        std::process::exit(1);
182    });
183    let raw = prettyplease::unparse(&file);
184    add_blank_lines_between_variants(&raw)
185}
186
/// Format a single TokenStream block (no imports).
///
/// Thin wrapper around `parse_and_format`, used for items that are collected
/// into a shared file where the import header is assembled separately.
pub(crate) fn format_tokens(tokens: &TokenStream) -> String {
    parse_and_format(tokens)
}
191
192pub(crate) fn format_tokens_with_imports(tokens: &TokenStream, imports: &BTreeSet<String>) -> String {
193    let import_lines: String = imports
194        .iter()
195        .map(|i| format!("{}\n", i))
196        .collect();
197
198    let formatted = parse_and_format(tokens);
199
200    if import_lines.is_empty() {
201        formatted
202    } else {
203        format!("{}\n\n{}", import_lines.trim_end(), formatted)
204    }
205}
206
/// Post-process formatted code to visually separate renamed items.
///
/// A blank line is inserted before every line that starts (after trimming)
/// with `#[sqlx(rename` when the preceding line ends with `,` — i.e. when it
/// follows another variant or field rather than the opening brace.
/// prettyplease doesn't insert these blank lines itself. Items without a
/// `#[sqlx(rename ...)]` attribute are left untouched.
///
/// The result is rebuilt from the input's lines joined with `\n`, so a
/// trailing newline in `code` is not preserved.
fn add_blank_lines_between_variants(code: &str) -> String {
    let mut out: Vec<&str> = Vec::new();
    let mut previous: Option<&str> = None;

    for current in code.lines() {
        // `previous` is `None` only for the first line, which never gets a separator.
        let follows_item = previous.map_or(false, |prev| prev.trim().ends_with(','));
        if follows_item && current.trim().starts_with("#[sqlx(rename") {
            out.push("");
        }
        out.push(current);
        previous = Some(current);
    }

    out.join("\n")
}
227
#[cfg(test)]
mod tests {
    //! Unit tests for the codegen helpers and the `generate` orchestrator.

    use super::*;
    use crate::introspect::{
        ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo,
    };
    use std::collections::HashMap;

    // ========== is_rust_keyword ==========

    #[test]
    fn test_keyword_type() {
        assert!(is_rust_keyword("type"));
    }

    #[test]
    fn test_keyword_fn() {
        assert!(is_rust_keyword("fn"));
    }

    #[test]
    fn test_keyword_let() {
        assert!(is_rust_keyword("let"));
    }

    #[test]
    fn test_keyword_match() {
        assert!(is_rust_keyword("match"));
    }

    #[test]
    fn test_keyword_async() {
        assert!(is_rust_keyword("async"));
    }

    #[test]
    fn test_keyword_await() {
        assert!(is_rust_keyword("await"));
    }

    #[test]
    fn test_keyword_yield() {
        assert!(is_rust_keyword("yield"));
    }

    #[test]
    fn test_keyword_abstract() {
        assert!(is_rust_keyword("abstract"));
    }

    #[test]
    fn test_keyword_try() {
        assert!(is_rust_keyword("try"));
    }

    #[test]
    fn test_not_keyword_name() {
        assert!(!is_rust_keyword("name"));
    }

    #[test]
    fn test_not_keyword_id() {
        assert!(!is_rust_keyword("id"));
    }

    #[test]
    fn test_not_keyword_uppercase_type() {
        // Keyword matching is case-sensitive.
        assert!(!is_rust_keyword("Type"));
    }

    // ========== normalize_module_name ==========

    #[test]
    fn test_normalize_no_underscores() {
        assert_eq!(normalize_module_name("users"), "users");
    }

    #[test]
    fn test_normalize_single_underscore() {
        assert_eq!(normalize_module_name("user_roles"), "user_roles");
    }

    #[test]
    fn test_normalize_double_underscore() {
        assert_eq!(normalize_module_name("user__roles"), "user_roles");
    }

    #[test]
    fn test_normalize_triple_underscore() {
        assert_eq!(normalize_module_name("a___b"), "a_b");
    }

    #[test]
    fn test_normalize_leading_underscore() {
        assert_eq!(normalize_module_name("_private"), "_private");
    }

    #[test]
    fn test_normalize_trailing_underscore() {
        assert_eq!(normalize_module_name("name_"), "name_");
    }

    #[test]
    fn test_normalize_double_leading() {
        assert_eq!(normalize_module_name("__double_leading"), "_double_leading");
    }

    #[test]
    fn test_normalize_multiple_groups() {
        assert_eq!(normalize_module_name("a__b__c"), "a_b_c");
    }

    // ========== imports_for_derives ==========

    #[test]
    fn test_imports_empty() {
        let result = imports_for_derives(&[]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_imports_serialize_only() {
        let derives = vec!["Serialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Serialize};"]);
    }

    #[test]
    fn test_imports_deserialize_only() {
        let derives = vec!["Deserialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Deserialize};"]);
    }

    #[test]
    fn test_imports_both_serde() {
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Serialize, Deserialize};"]);
    }

    #[test]
    fn test_imports_non_serde() {
        let derives = vec!["Hash".to_string()];
        let result = imports_for_derives(&derives);
        assert!(result.is_empty());
    }

    #[test]
    fn test_imports_non_serde_multiple() {
        let derives = vec!["PartialEq".to_string(), "Eq".to_string()];
        let result = imports_for_derives(&derives);
        assert!(result.is_empty());
    }

    #[test]
    fn test_imports_mixed_serde_and_others() {
        let derives = vec![
            "Serialize".to_string(),
            "Hash".to_string(),
            "Deserialize".to_string(),
        ];
        let result = imports_for_derives(&derives);
        assert_eq!(result.len(), 1);
        assert!(result[0].contains("Serialize"));
        assert!(result[0].contains("Deserialize"));
    }

    // ========== add_blank_lines_between_variants ==========

    #[test]
    fn test_blank_lines_between_renamed_variants() {
        let input = "pub enum Foo {\n    #[sqlx(rename = \"a\")]\n    A,\n    #[sqlx(rename = \"b\")]\n    B,\n}";
        let result = add_blank_lines_between_variants(input);
        assert!(result.contains("A,\n\n    #[sqlx(rename = \"b\")]"));
    }

    #[test]
    fn test_no_blank_line_for_first_variant() {
        let input = "pub enum Foo {\n    #[sqlx(rename = \"a\")]\n    A,\n}";
        let result = add_blank_lines_between_variants(input);
        // No blank line before first #[sqlx(rename because previous line is `{`
        assert!(!result.contains("{\n\n"));
    }

    #[test]
    fn test_no_change_without_rename() {
        let input = "pub enum Foo {\n    A,\n    B,\n}";
        let result = add_blank_lines_between_variants(input);
        assert_eq!(result, input);
    }

    #[test]
    fn test_no_change_for_struct() {
        let input = "pub struct Foo {\n    pub a: i32,\n    pub b: String,\n}";
        let result = add_blank_lines_between_variants(input);
        assert_eq!(result, input);
    }

    // ========== filter_imports ==========

    #[test]
    fn test_filter_single_file_strips_super_types() {
        let mut imports = BTreeSet::new();
        imports.insert("use super::types::Foo;".to_string());
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, true);
        assert!(!result.contains("use super::types::Foo;"));
        assert!(result.contains("use chrono::NaiveDateTime;"));
    }

    #[test]
    fn test_filter_single_file_keeps_other_imports() {
        let mut imports = BTreeSet::new();
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, true);
        assert!(result.contains("use chrono::NaiveDateTime;"));
    }

    #[test]
    fn test_filter_multi_file_keeps_all() {
        let mut imports = BTreeSet::new();
        imports.insert("use super::types::Foo;".to_string());
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, false);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_filter_empty_set() {
        let imports = BTreeSet::new();
        let result = filter_imports(&imports, true);
        assert!(result.is_empty());
    }

    // ========== generate() orchestrator ==========

    /// Builds a table in the `public` schema with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// Builds a non-nullable `public` column whose `data_type` and `udt_name`
    /// are both `udt_name`.
    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
        }
    }

    /// Builds an enum in the `public` schema with the given variant labels.
    fn make_enum(name: &str, variants: &[&str]) -> EnumInfo {
        EnumInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            variants: variants.iter().map(|v| v.to_string()).collect(),
        }
    }

    #[test]
    fn test_generate_empty_schema() {
        let schema = SchemaInfo::default();
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert!(files.is_empty());
    }

    #[test]
    fn test_generate_one_table() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "users.rs");
    }

    #[test]
    fn test_generate_two_tables() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                make_table("posts", vec![make_col("id", "int4")]),
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2);
    }

    #[test]
    fn test_generate_enum_creates_types_file() {
        let schema = SchemaInfo {
            enums: vec![make_enum("status", &["active", "inactive"])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "types.rs");
    }

    #[test]
    fn test_generate_enums_composites_domains_single_types_file() {
        let schema = SchemaInfo {
            enums: vec![make_enum("status", &["active"])],
            composite_types: vec![CompositeTypeInfo {
                schema_name: "public".to_string(),
                name: "address".to_string(),
                fields: vec![make_col("street", "text")],
            }],
            domains: vec![DomainInfo {
                schema_name: "public".to_string(),
                name: "email".to_string(),
                base_type: "text".to_string(),
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        // Should produce exactly 1 types.rs
        let types_files: Vec<_> = files.iter().filter(|f| f.filename == "types.rs").collect();
        assert_eq!(types_files.len(), 1);
    }

    #[test]
    fn test_generate_tables_and_enums() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            enums: vec![make_enum("status", &["active"])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2); // users.rs + types.rs
    }

    #[test]
    fn test_generate_filename_normalized() {
        let schema = SchemaInfo {
            tables: vec![make_table("user__data", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].filename, "user_data.rs");
    }

    #[test]
    fn test_generate_origin_correct() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, Some("Table: public.users".to_string()));
    }

    #[test]
    fn test_generate_types_no_origin() {
        let schema = SchemaInfo {
            enums: vec![make_enum("status", &["active"])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, None);
    }

    #[test]
    fn test_generate_single_file_filters_super_types_imports() {
        // The column must reference the enum: with only built-in column types no
        // `super::types::` import is ever produced and the assertion would pass
        // without the filter doing anything.
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("status", "status")])],
            enums: vec![make_enum("status", &["active"])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
        // struct file should not have super::types:: imports
        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
        assert!(!struct_file.code.contains("super::types::"));
    }

    #[test]
    fn test_generate_multi_file_keeps_super_types_imports() {
        // Table with a column referencing an enum
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("status", "status")])],
            enums: vec![make_enum("status", &["active"])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
        assert!(struct_file.code.contains("super::types::"));
    }

    #[test]
    fn test_generate_extra_derives_in_struct() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let derives = vec!["Serialize".to_string()];
        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
        assert!(files[0].code.contains("Serialize"));
    }

    #[test]
    fn test_generate_extra_derives_in_enum() {
        let schema = SchemaInfo {
            enums: vec![make_enum("status", &["active"])],
            ..Default::default()
        };
        let derives = vec!["Serialize".to_string()];
        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
        assert!(files[0].code.contains("Serialize"));
    }

    #[test]
    fn test_generate_type_overrides_in_struct() {
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("data", "jsonb")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &overrides, false);
        assert!(files[0].code.contains("MyJson"));
    }

    #[test]
    fn test_generate_valid_rust_syntax() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![
                make_col("id", "int4"),
                make_col("name", "text"),
            ])],
            enums: vec![make_enum("status", &["active", "inactive"])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        for f in &files {
            // Should be parseable as valid Rust
            let parse_result = syn::parse_file(&f.code);
            assert!(parse_result.is_ok(), "Failed to parse {}: {:?}", f.filename, parse_result.err());
        }
    }

    // ========== generate() — views ==========

    /// Views share the table representation; only the `SchemaInfo` field they
    /// are placed in differs.
    fn make_view(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        make_table(name, columns)
    }

    #[test]
    fn test_generate_one_view() {
        let schema = SchemaInfo {
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "active_users.rs");
    }

    #[test]
    fn test_generate_view_origin() {
        let schema = SchemaInfo {
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, Some("View: public.active_users".to_string()));
    }

    #[test]
    fn test_generate_tables_and_views() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2);
    }

    #[test]
    fn test_generate_view_valid_rust() {
        let schema = SchemaInfo {
            views: vec![make_view("active_users", vec![
                make_col("id", "int4"),
                make_col("name", "text"),
            ])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let parse_result = syn::parse_file(&files[0].code);
        assert!(parse_result.is_ok(), "Failed to parse: {:?}", parse_result.err());
    }

    #[test]
    fn test_generate_view_nullable_column() {
        let schema = SchemaInfo {
            views: vec![make_view("v", vec![ColumnInfo {
                name: "email".to_string(),
                data_type: "text".to_string(),
                udt_name: "text".to_string(),
                is_nullable: true,
                ordinal_position: 0,
                schema_name: "public".to_string(),
            }])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert!(files[0].code.contains("Option<String>"));
    }

    #[test]
    fn test_generate_view_single_file_mode() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
        assert_eq!(files.len(), 2);
    }
}