1pub mod composite_gen;
2pub mod crud_gen;
3pub mod domain_gen;
4pub mod entity_parser;
5pub mod enum_gen;
6pub mod struct_gen;
7
8use std::collections::{BTreeSet, HashMap};
9
10use proc_macro2::TokenStream;
11
12use crate::cli::DatabaseKind;
13use crate::introspect::SchemaInfo;
14
15const RUST_KEYWORDS: &[&str] = &[
17 "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
18 "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
19 "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
20 "type", "unsafe", "use", "where", "while", "yield", "abstract", "become", "box", "do",
21 "final", "macro", "override", "priv", "try", "typeof", "unsized", "virtual",
22];
23
24pub fn is_rust_keyword(name: &str) -> bool {
26 RUST_KEYWORDS.contains(&name)
27}
28
/// Builds the `use` lines required by the user-supplied extra derives.
///
/// Only serde's `Serialize` / `Deserialize` need an import. When either is
/// requested, a single `use serde::{...};` line is returned, listing
/// `Serialize` before `Deserialize` regardless of the order in
/// `extra_derives`. Any other derive yields no import.
pub fn imports_for_derives(extra_derives: &[String]) -> Vec<String> {
    let requested_serde: Vec<&str> = ["Serialize", "Deserialize"]
        .iter()
        .copied()
        .filter(|wanted| extra_derives.iter().any(|derive| derive == wanted))
        .collect();

    if requested_serde.is_empty() {
        return Vec::new();
    }

    vec![format!("use serde::{{{}}};", requested_serde.join(", "))]
}
45
/// Collapses every run of consecutive underscores in `name` into one.
///
/// Leading and trailing underscores are kept (collapsed to one), so
/// `"__a___b_"` becomes `"_a_b_"`. All other characters pass through unchanged.
pub fn normalize_module_name(name: &str) -> String {
    let mut normalized = String::with_capacity(name.len());
    for ch in name.chars() {
        // Skip an underscore when the output already ends with one: that means
        // the previous input character was an underscore too.
        if ch == '_' && normalized.ends_with('_') {
            continue;
        }
        normalized.push(ch);
    }
    normalized
}
64
/// Schema names treated as a database's implicit default schema.
///
/// Presumably these are the defaults of PostgreSQL (`public`), SQLite
/// (`main`) and SQL Server (`dbo`). Relations in these schemas never get a
/// schema prefix in their module name.
const DEFAULT_SCHEMAS: &[&str] = &["public", "main", "dbo"];

/// Returns `true` if `schema` exactly matches (case-sensitively) one of [`DEFAULT_SCHEMAS`].
pub fn is_default_schema(schema: &str) -> bool {
    DEFAULT_SCHEMAS.iter().any(|&candidate| candidate == schema)
}
71}
72
73pub fn build_module_name(schema_name: &str, table_name: &str, has_multiple_schemas: bool) -> String {
76 if !has_multiple_schemas || DEFAULT_SCHEMAS.contains(&schema_name) {
77 normalize_module_name(table_name)
78 } else {
79 normalize_module_name(&format!("{}_{}", schema_name, table_name))
80 }
81}
82
/// One generated source file, ready to be written out by the caller.
///
/// Derives `PartialEq`/`Eq` so generated output can be compared directly
/// (e.g. in tests) without field-by-field assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedFile {
    /// File name including the `.rs` extension, e.g. `users.rs` or `types.rs`.
    pub filename: String,
    /// Human-readable provenance such as `"Table: public.users"` or
    /// `"View: public.active_users"`; `None` for the shared `types.rs`.
    pub origin: Option<String>,
    /// Formatted Rust source code of the file.
    pub code: String,
}
91
92pub fn generate(
94 schema_info: &SchemaInfo,
95 db_kind: DatabaseKind,
96 extra_derives: &[String],
97 type_overrides: &HashMap<String, String>,
98 single_file: bool,
99) -> Vec<GeneratedFile> {
100 let mut files = Vec::new();
101
102 let mut schemas = BTreeSet::new();
104 for t in &schema_info.tables {
105 schemas.insert(t.schema_name.as_str());
106 }
107 for v in &schema_info.views {
108 schemas.insert(v.schema_name.as_str());
109 }
110 let has_multiple_schemas = schemas.len() > 1;
111
112 for table in &schema_info.tables {
114 let (tokens, imports) =
115 struct_gen::generate_struct(table, db_kind, schema_info, extra_derives, type_overrides, false);
116 let imports = filter_imports(&imports, single_file);
117 let code = format_tokens_with_imports(&tokens, &imports);
118 let module_name = build_module_name(&table.schema_name, &table.name, has_multiple_schemas);
119 let origin = format!("Table: {}.{}", table.schema_name, table.name);
120 files.push(GeneratedFile {
121 filename: format!("{}.rs", module_name),
122 origin: Some(origin),
123 code,
124 });
125 }
126
127 for view in &schema_info.views {
129 let (tokens, imports) =
130 struct_gen::generate_struct(view, db_kind, schema_info, extra_derives, type_overrides, true);
131 let imports = filter_imports(&imports, single_file);
132 let code = format_tokens_with_imports(&tokens, &imports);
133 let module_name = build_module_name(&view.schema_name, &view.name, has_multiple_schemas);
134 let origin = format!("View: {}.{}", view.schema_name, view.name);
135 files.push(GeneratedFile {
136 filename: format!("{}.rs", module_name),
137 origin: Some(origin),
138 code,
139 });
140 }
141
142 let mut types_blocks: Vec<String> = Vec::new();
145 let mut types_imports = BTreeSet::new();
146
147 for enum_info in &schema_info.enums {
148 let (tokens, imports) = enum_gen::generate_enum(enum_info, db_kind, extra_derives);
149 types_blocks.push(format_tokens(&tokens));
150 types_imports.extend(imports);
151 }
152
153 for composite in &schema_info.composite_types {
154 let (tokens, imports) = composite_gen::generate_composite(
155 composite,
156 db_kind,
157 schema_info,
158 extra_derives,
159 type_overrides,
160 );
161 types_blocks.push(format_tokens(&tokens));
162 types_imports.extend(imports);
163 }
164
165 for domain in &schema_info.domains {
166 let (tokens, imports) =
167 domain_gen::generate_domain(domain, db_kind, schema_info, type_overrides);
168 types_blocks.push(format_tokens(&tokens));
169 types_imports.extend(imports);
170 }
171
172 if !types_blocks.is_empty() {
173 let import_lines: String = types_imports
174 .iter()
175 .map(|i| format!("{}\n", i))
176 .collect();
177 let body = types_blocks.join("\n");
178 let code = if import_lines.is_empty() {
179 body
180 } else {
181 format!("{}\n\n{}", import_lines.trim_end(), body)
182 };
183 files.push(GeneratedFile {
184 filename: "types.rs".to_string(),
185 origin: None,
186 code,
187 });
188 }
189
190 files
191}
192
/// Returns the imports to emit for a per-relation file.
///
/// In single-file mode every type lives in the same output, so any import
/// containing `super::types::` is dropped. Otherwise the set is returned
/// unchanged (as a fresh copy).
fn filter_imports(imports: &BTreeSet<String>, single_file: bool) -> BTreeSet<String> {
    imports
        .iter()
        .filter(|import| !single_file || !import.contains("super::types::"))
        .cloned()
        .collect()
}
205
206pub(crate) fn parse_and_format(tokens: &TokenStream) -> String {
208 let file = syn::parse2::<syn::File>(tokens.clone()).unwrap_or_else(|e| {
209 log::error!("Failed to parse generated code: {}", e);
210 log::error!("This is a bug in sqlx-gen. Raw tokens:\n {}", tokens);
211 std::process::exit(1);
212 });
213 let raw = prettyplease::unparse(&file);
214 add_blank_lines_between_items(&raw)
215}
216
217pub(crate) fn format_tokens(tokens: &TokenStream) -> String {
219 parse_and_format(tokens)
220}
221
222pub fn format_tokens_with_imports(tokens: &TokenStream, imports: &BTreeSet<String>) -> String {
223 let formatted = parse_and_format(tokens);
224
225 let used_imports: Vec<&String> = imports
226 .iter()
227 .filter(|imp| is_import_used(imp, &formatted))
228 .collect();
229
230 if used_imports.is_empty() {
231 formatted
232 } else {
233 let import_lines: String = used_imports
234 .iter()
235 .map(|i| format!("{}\n", i))
236 .collect();
237 format!("{}\n\n{}", import_lines.trim_end(), formatted)
238 }
239}
240
/// Heuristically decides whether a `use ...;` line is needed by `code`.
///
/// The name each import binds is looked up in `code` as a whole identifier
/// (not a bare substring), so `NaiveDate` is not considered used just because
/// `NaiveDateTime` appears. Specifically:
/// - Glob imports (`a::*`) are always kept.
/// - For braced lists (`a::{B, c::D, E as F}`), the import is kept if any item
///   is used. Each item is checked by the name it binds (`B`, `D`, `F`).
/// - `X as Y` aliases are checked by the alias `Y`.
/// - Items whose use cannot be detected textually (`self`, `_`, empty) are
///   conservatively kept.
fn is_import_used(import: &str, code: &str) -> bool {
    let trimmed = import.trim().trim_end_matches(';');
    let path = trimmed.strip_prefix("use ").unwrap_or(trimmed).trim();

    if path.ends_with("::*") {
        return true;
    }

    match (path.find('{'), path.rfind('}')) {
        (Some(open), Some(close)) if open < close => path[open + 1..close]
            .split(',')
            .map(str::trim)
            .filter(|item| !item.is_empty())
            .any(|item| import_item_is_used(item, code)),
        _ => import_item_is_used(path, code),
    }
}

/// Checks a single import item (e.g. `chrono::NaiveDate` or `Json as J`).
fn import_item_is_used(item: &str, code: &str) -> bool {
    match import_bound_name(item) {
        // Nothing textually checkable (trait `as _`, `self`, nested globs):
        // keep the import rather than risk breaking compilation.
        "" | "self" | "_" | "*" => true,
        name => contains_identifier(code, name),
    }
}

/// Returns the identifier an import item introduces: the alias for
/// `path as Alias`, otherwise the last `::` segment (stray braces from
/// nested groups are trimmed).
fn import_bound_name(item: &str) -> &str {
    let item = item.trim();
    let name = match item.rsplit_once(" as ") {
        Some((_, alias)) => alias,
        None => item.rsplit("::").next().unwrap_or(item),
    };
    name.trim_matches(|c: char| c == '{' || c == '}' || c.is_whitespace())
}

/// Returns `true` if `ident` occurs in `code` with no identifier character
/// (alphanumeric or `_`) immediately before or after it.
fn contains_identifier(code: &str, ident: &str) -> bool {
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';
    code.match_indices(ident).any(|(start, _)| {
        let before = code[..start].chars().next_back();
        let after = code[start + ident.len()..].chars().next();
        !matches!(before, Some(c) if is_ident_char(c))
            && !matches!(after, Some(c) if is_ident_char(c))
    })
}
273
/// Inserts blank lines into pretty-printed code to improve readability.
///
/// `prettyplease` output is re-split into lines, and a single empty line is
/// inserted before a line whenever `needs_blank_line_before` says so for the
/// (trimmed previous, trimmed current) pair. Lines are re-joined with `\n`,
/// so the result has no trailing newline.
fn add_blank_lines_between_items(code: &str) -> String {
    let lines: Vec<&str> = code.lines().collect();
    let mut result: Vec<&str> = Vec::with_capacity(lines.len());

    for (i, &line) in lines.iter().enumerate() {
        if i > 0 && needs_blank_line_before(lines[i - 1].trim(), line.trim()) {
            // At most one blank line, even when several rules match.
            result.push("");
        }
        result.push(line);
    }

    result.join("\n")
}

/// Decides whether a blank line belongs between two adjacent (trimmed) lines.
fn needs_blank_line_before(prev: &str, current: &str) -> bool {
    // Separate `#[sqlx(rename = ...)]`-annotated enum variants from the
    // preceding variant (which ends with a comma).
    let renamed_variant = current.starts_with("#[sqlx(rename") && prev.ends_with(',');

    // A new struct/impl/derive/function directly after a closing brace.
    let item_after_block = prev == "}"
        && ["pub struct", "impl ", "#[derive", "pub async fn", "pub fn"]
            .iter()
            .any(|prefix| current.starts_with(prefix));

    // A `let`/`Ok(` right after an awaited (or `unwrap_or`-defaulted) statement.
    let prev_is_await_end = prev.ends_with(".await?;")
        || prev.ends_with(".await?")
        || (prev.ends_with(';') && prev.contains(".unwrap_or("));
    let after_await =
        prev_is_await_end && (current.starts_with("let ") || current.starts_with("Ok("));

    // A `let` that builds an sqlx query after a non-sqlx `let` setup line.
    let query_after_setup = current.starts_with("let ")
        && current.contains("sqlx::")
        && prev.starts_with("let ")
        && !prev.contains("sqlx::");

    renamed_variant || item_after_block || after_await || query_after_setup
}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::introspect::{
        ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo,
    };
    use std::collections::HashMap;

    // --- is_rust_keyword ---

    #[test]
    fn test_keyword_type() {
        assert!(is_rust_keyword("type"));
    }

    #[test]
    fn test_keyword_fn() {
        assert!(is_rust_keyword("fn"));
    }

    #[test]
    fn test_keyword_let() {
        assert!(is_rust_keyword("let"));
    }

    #[test]
    fn test_keyword_match() {
        assert!(is_rust_keyword("match"));
    }

    #[test]
    fn test_keyword_async() {
        assert!(is_rust_keyword("async"));
    }

    #[test]
    fn test_keyword_await() {
        assert!(is_rust_keyword("await"));
    }

    #[test]
    fn test_keyword_yield() {
        assert!(is_rust_keyword("yield"));
    }

    #[test]
    fn test_keyword_abstract() {
        assert!(is_rust_keyword("abstract"));
    }

    #[test]
    fn test_keyword_try() {
        assert!(is_rust_keyword("try"));
    }

    #[test]
    fn test_not_keyword_name() {
        assert!(!is_rust_keyword("name"));
    }

    #[test]
    fn test_not_keyword_id() {
        assert!(!is_rust_keyword("id"));
    }

    // Keyword matching is case-sensitive.
    #[test]
    fn test_not_keyword_uppercase_type() {
        assert!(!is_rust_keyword("Type"));
    }

    // --- normalize_module_name ---

    #[test]
    fn test_normalize_no_underscores() {
        assert_eq!(normalize_module_name("users"), "users");
    }

    #[test]
    fn test_normalize_single_underscore() {
        assert_eq!(normalize_module_name("user_roles"), "user_roles");
    }

    #[test]
    fn test_normalize_double_underscore() {
        assert_eq!(normalize_module_name("user__roles"), "user_roles");
    }

    #[test]
    fn test_normalize_triple_underscore() {
        assert_eq!(normalize_module_name("a___b"), "a_b");
    }

    #[test]
    fn test_normalize_leading_underscore() {
        assert_eq!(normalize_module_name("_private"), "_private");
    }

    #[test]
    fn test_normalize_trailing_underscore() {
        assert_eq!(normalize_module_name("name_"), "name_");
    }

    #[test]
    fn test_normalize_double_leading() {
        assert_eq!(normalize_module_name("__double_leading"), "_double_leading");
    }

    #[test]
    fn test_normalize_multiple_groups() {
        assert_eq!(normalize_module_name("a__b__c"), "a_b_c");
    }

    // --- build_module_name ---

    #[test]
    fn test_build_single_schema_no_prefix() {
        assert_eq!(build_module_name("public", "users", false), "users");
    }

    #[test]
    fn test_build_multi_schema_default_no_prefix() {
        assert_eq!(build_module_name("public", "users", true), "users");
    }

    #[test]
    fn test_build_multi_schema_non_default_prefixed() {
        assert_eq!(build_module_name("billing", "users", true), "billing_users");
    }

    #[test]
    fn test_build_multi_schema_dbo_no_prefix() {
        assert_eq!(build_module_name("dbo", "users", true), "users");
    }

    #[test]
    fn test_build_multi_schema_main_no_prefix() {
        assert_eq!(build_module_name("main", "users", true), "users");
    }

    #[test]
    fn test_build_normalizes_double_underscore() {
        assert_eq!(build_module_name("billing", "agent__connector", true), "billing_agent_connector");
    }

    // --- is_default_schema ---

    #[test]
    fn test_default_schema_public() {
        assert!(is_default_schema("public"));
    }

    #[test]
    fn test_default_schema_main() {
        assert!(is_default_schema("main"));
    }

    #[test]
    fn test_non_default_schema() {
        assert!(!is_default_schema("billing"));
    }

    // --- imports_for_derives ---

    #[test]
    fn test_imports_empty() {
        let result = imports_for_derives(&[]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_imports_serialize_only() {
        let derives = vec!["Serialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Serialize};"]);
    }

    #[test]
    fn test_imports_deserialize_only() {
        let derives = vec!["Deserialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Deserialize};"]);
    }

    #[test]
    fn test_imports_both_serde() {
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Serialize, Deserialize};"]);
    }

    #[test]
    fn test_imports_non_serde() {
        let derives = vec!["Hash".to_string()];
        let result = imports_for_derives(&derives);
        assert!(result.is_empty());
    }

    #[test]
    fn test_imports_non_serde_multiple() {
        let derives = vec!["PartialEq".to_string(), "Eq".to_string()];
        let result = imports_for_derives(&derives);
        assert!(result.is_empty());
    }

    #[test]
    fn test_imports_mixed_serde_and_others() {
        let derives = vec![
            "Serialize".to_string(),
            "Hash".to_string(),
            "Deserialize".to_string(),
        ];
        let result = imports_for_derives(&derives);
        assert_eq!(result.len(), 1);
        assert!(result[0].contains("Serialize"));
        assert!(result[0].contains("Deserialize"));
    }

    // --- add_blank_lines_between_items ---

    #[test]
    fn test_blank_lines_between_renamed_variants() {
        let input = "pub enum Foo {\n #[sqlx(rename = \"a\")]\n A,\n #[sqlx(rename = \"b\")]\n B,\n}";
        let result = add_blank_lines_between_items(input);
        assert!(result.contains("A,\n\n #[sqlx(rename = \"b\")]"));
    }

    // The first renamed variant follows `{`, not a comma, so no blank line is added.
    #[test]
    fn test_no_blank_line_for_first_variant() {
        let input = "pub enum Foo {\n #[sqlx(rename = \"a\")]\n A,\n}";
        let result = add_blank_lines_between_items(input);
        assert!(!result.contains("{\n\n"));
    }

    #[test]
    fn test_no_change_without_rename() {
        let input = "pub enum Foo {\n A,\n B,\n}";
        let result = add_blank_lines_between_items(input);
        assert_eq!(result, input);
    }

    #[test]
    fn test_no_change_for_struct() {
        let input = "pub struct Foo {\n pub a: i32,\n pub b: String,\n}";
        let result = add_blank_lines_between_items(input);
        assert_eq!(result, input);
    }

    // --- filter_imports ---

    #[test]
    fn test_filter_single_file_strips_super_types() {
        let mut imports = BTreeSet::new();
        imports.insert("use super::types::Foo;".to_string());
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, true);
        assert!(!result.contains("use super::types::Foo;"));
        assert!(result.contains("use chrono::NaiveDateTime;"));
    }

    #[test]
    fn test_filter_single_file_keeps_other_imports() {
        let mut imports = BTreeSet::new();
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, true);
        assert!(result.contains("use chrono::NaiveDateTime;"));
    }

    #[test]
    fn test_filter_multi_file_keeps_all() {
        let mut imports = BTreeSet::new();
        imports.insert("use super::types::Foo;".to_string());
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, false);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_filter_empty_set() {
        let imports = BTreeSet::new();
        let result = filter_imports(&imports, true);
        assert!(result.is_empty());
    }

    // --- generate: fixtures ---

    /// A table in the `public` schema with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    /// A non-nullable, non-key column whose `data_type` and `udt_name` are both `udt_name`.
    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: false,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
        }
    }

    // --- generate: tables and types ---

    #[test]
    fn test_generate_empty_schema() {
        let schema = SchemaInfo::default();
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert!(files.is_empty());
    }

    #[test]
    fn test_generate_one_table() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "users.rs");
    }

    #[test]
    fn test_generate_two_tables() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                make_table("posts", vec![make_col("id", "int4")]),
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2);
    }

    #[test]
    fn test_generate_enum_creates_types_file() {
        let schema = SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string(), "inactive".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "types.rs");
    }

    // Enums, composites and domains all share one types.rs.
    #[test]
    fn test_generate_enums_composites_domains_single_types_file() {
        let schema = SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            composite_types: vec![CompositeTypeInfo {
                schema_name: "public".to_string(),
                name: "address".to_string(),
                fields: vec![make_col("street", "text")],
            }],
            domains: vec![DomainInfo {
                schema_name: "public".to_string(),
                name: "email".to_string(),
                base_type: "text".to_string(),
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let types_files: Vec<_> = files.iter().filter(|f| f.filename == "types.rs").collect();
        assert_eq!(types_files.len(), 1);
    }

    #[test]
    fn test_generate_tables_and_enums() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2);
    }

    #[test]
    fn test_generate_filename_normalized() {
        let schema = SchemaInfo {
            tables: vec![make_table("user__data", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].filename, "user_data.rs");
    }

    #[test]
    fn test_generate_origin_correct() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, Some("Table: public.users".to_string()));
    }

    #[test]
    fn test_generate_types_no_origin() {
        let schema = SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, None);
    }

    // The column references the enum, so a `super::types::` import would be
    // produced; single-file mode must strip it. (Mirrors the multi-file test
    // below; an unrelated `int4` column would make this check vacuous.)
    #[test]
    fn test_generate_single_file_filters_super_types_imports() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("status", "status")])],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
        assert!(!struct_file.code.contains("super::types::"));
    }

    // A column typed with the enum must keep its `super::types::` import in multi-file mode.
    #[test]
    fn test_generate_multi_file_keeps_super_types_imports() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("status", "status")])],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
        assert!(struct_file.code.contains("super::types::"));
    }

    #[test]
    fn test_generate_extra_derives_in_struct() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let derives = vec!["Serialize".to_string()];
        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
        assert!(files[0].code.contains("Serialize"));
    }

    #[test]
    fn test_generate_extra_derives_in_enum() {
        let schema = SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
            }],
            ..Default::default()
        };
        let derives = vec!["Serialize".to_string()];
        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
        assert!(files[0].code.contains("Serialize"));
    }

    #[test]
    fn test_generate_type_overrides_in_struct() {
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("data", "jsonb")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &overrides, false);
        assert!(files[0].code.contains("MyJson"));
    }

    // Every generated file must round-trip through syn.
    #[test]
    fn test_generate_valid_rust_syntax() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![
                make_col("id", "int4"),
                make_col("name", "text"),
            ])],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string(), "inactive".to_string()],
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        for f in &files {
            let parse_result = syn::parse_file(&f.code);
            assert!(parse_result.is_ok(), "Failed to parse {}: {:?}", f.filename, parse_result.err());
        }
    }

    // --- generate: views ---

    /// Views share `TableInfo` with tables, so the fixture is identical.
    fn make_view(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        make_table(name, columns)
    }

    #[test]
    fn test_generate_one_view() {
        let schema = SchemaInfo {
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "active_users.rs");
    }

    #[test]
    fn test_generate_view_origin() {
        let schema = SchemaInfo {
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, Some("View: public.active_users".to_string()));
    }

    #[test]
    fn test_generate_tables_and_views() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2);
    }

    #[test]
    fn test_generate_view_valid_rust() {
        let schema = SchemaInfo {
            views: vec![make_view("active_users", vec![
                make_col("id", "int4"),
                make_col("name", "text"),
            ])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let parse_result = syn::parse_file(&files[0].code);
        assert!(parse_result.is_ok(), "Failed to parse: {:?}", parse_result.err());
    }

    #[test]
    fn test_generate_view_nullable_column() {
        let schema = SchemaInfo {
            views: vec![make_view("v", vec![ColumnInfo {
                name: "email".to_string(),
                data_type: "text".to_string(),
                udt_name: "text".to_string(),
                is_nullable: true,
                is_primary_key: false,
                ordinal_position: 0,
                schema_name: "public".to_string(),
            }])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert!(files[0].code.contains("Option<String>"));
    }

    // --- generate: schema prefixes ---

    #[test]
    fn test_generate_multi_schema_prefixes_non_default() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                TableInfo {
                    schema_name: "billing".to_string(),
                    name: "users".to_string(),
                    columns: vec![make_col("id", "int4")],
                },
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
        assert!(filenames.contains(&"users.rs"));
        assert!(filenames.contains(&"billing_users.rs"));
    }

    #[test]
    fn test_generate_single_schema_no_prefix() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                make_table("posts", vec![make_col("id", "int4")]),
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].filename, "users.rs");
        assert_eq!(files[1].filename, "posts.rs");
    }

    #[test]
    fn test_generate_view_single_file_mode() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
        assert_eq!(files.len(), 2);
    }
}