1pub mod composite_gen;
2pub mod crud_gen;
3pub mod domain_gen;
4pub mod entity_parser;
5pub mod enum_gen;
6pub mod struct_gen;
7
8use std::collections::{BTreeSet, HashMap};
9
10use proc_macro2::TokenStream;
11
12use crate::cli::DatabaseKind;
13use crate::introspect::SchemaInfo;
14
/// Words that cannot be used as plain Rust identifiers.
///
/// Contains the strict keywords (2018 edition and later), the reserved keywords, and `gen`,
/// which is reserved starting with the 2024 edition, so a name is flagged regardless of the
/// edition the generated code is compiled with. Entries are case-sensitive (`self` and `Self`
/// are listed separately).
const RUST_KEYWORDS: &[&str] = &[
    // Strict keywords.
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while",
    // Reserved keywords.
    "abstract", "become", "box", "do", "final", "macro", "override", "priv", "try", "typeof",
    "unsized", "virtual", "yield",
    // Reserved since the 2024 edition.
    "gen",
];

/// Returns `true` if `name` is one of the entries in `RUST_KEYWORDS`.
///
/// The comparison is exact and case-sensitive: `"type"` is a keyword, `"Type"` is not.
pub fn is_rust_keyword(name: &str) -> bool {
    RUST_KEYWORDS.contains(&name)
}
28
/// Returns the `use` lines required by the user-requested extra derives.
///
/// Only serde's `Serialize` and `Deserialize` produce an import. When either is requested, a
/// single line `use serde::{...};` is returned, listing `Serialize` before `Deserialize`.
/// Any other derive name contributes nothing, so the result is either empty or has one entry.
pub fn imports_for_derives(extra_derives: &[String]) -> Vec<String> {
    let requested = |wanted: &str| extra_derives.iter().any(|derive| derive == wanted);

    let serde_traits: Vec<&str> = ["Serialize", "Deserialize"]
        .iter()
        .copied()
        .filter(|&wanted| requested(wanted))
        .collect();

    if serde_traits.is_empty() {
        return Vec::new();
    }
    vec![format!("use serde::{{{}}};", serde_traits.join(", "))]
}
45
/// Collapses every run of consecutive underscores in `name` into a single underscore.
///
/// All other characters, including a single leading or trailing underscore, are kept as-is:
/// `"a__b"` becomes `"a_b"` and `"__x"` becomes `"_x"`.
pub fn normalize_module_name(name: &str) -> String {
    let mut normalized = String::with_capacity(name.len());
    for ch in name.chars() {
        // The last pushed char is `_` exactly when the previous input char was `_`,
        // so skipping here drops every underscore after the first in a run.
        if ch == '_' && normalized.ends_with('_') {
            continue;
        }
        normalized.push(ch);
    }
    normalized
}
64
/// Schema names treated as a database's default namespace (presumably PostgreSQL's
/// `public`, SQLite's `main` and SQL Server's `dbo`).
const DEFAULT_SCHEMAS: &[&str] = &["public", "main", "dbo"];

/// Returns `true` when `schema` is one of `DEFAULT_SCHEMAS`.
///
/// The comparison is exact and case-sensitive. Relations in a default schema never get a
/// schema prefix in their module name.
pub fn is_default_schema(schema: &str) -> bool {
    DEFAULT_SCHEMAS.iter().any(|&candidate| candidate == schema)
}
72
73pub fn build_module_name(schema_name: &str, table_name: &str, name_collides: bool) -> String {
76 if name_collides && !is_default_schema(schema_name) {
77 normalize_module_name(&format!("{}_{}", schema_name, table_name))
78 } else {
79 normalize_module_name(table_name)
80 }
81}
82
83fn find_colliding_names(schema_info: &SchemaInfo) -> BTreeSet<&str> {
85 let mut seen: HashMap<&str, BTreeSet<&str>> = HashMap::new();
86 for t in &schema_info.tables {
87 seen.entry(t.name.as_str()).or_default().insert(t.schema_name.as_str());
88 }
89 for v in &schema_info.views {
90 seen.entry(v.name.as_str()).or_default().insert(v.schema_name.as_str());
91 }
92 seen.into_iter()
93 .filter(|(_, schemas)| schemas.len() > 1)
94 .map(|(name, _)| name)
95 .collect()
96}
97
/// One generated Rust source file, as returned by `generate`.
#[derive(Debug, Clone)]
pub struct GeneratedFile {
    /// File name including the `.rs` extension (for example `users.rs` or `types.rs`).
    pub filename: String,
    /// Optional provenance information. `generate` in this module always sets it to `None`;
    /// NOTE(review): other producers may populate it — confirm against callers.
    pub origin: Option<String>,
    /// Rust source text of the file.
    pub code: String,
}
106
107pub fn generate(
109 schema_info: &SchemaInfo,
110 db_kind: DatabaseKind,
111 extra_derives: &[String],
112 type_overrides: &HashMap<String, String>,
113 single_file: bool,
114) -> Vec<GeneratedFile> {
115 let mut files = Vec::new();
116
117 let colliding_names = find_colliding_names(schema_info);
119
120 for table in &schema_info.tables {
122 let (tokens, imports) =
123 struct_gen::generate_struct(table, db_kind, schema_info, extra_derives, type_overrides, false);
124 let imports = filter_imports(&imports, single_file);
125 let code = format_tokens_with_imports(&tokens, &imports);
126 let module_name = build_module_name(&table.schema_name, &table.name, colliding_names.contains(table.name.as_str()));
127 files.push(GeneratedFile {
128 filename: format!("{}.rs", module_name),
129 origin: None,
130 code,
131 });
132 }
133
134 for view in &schema_info.views {
136 let (tokens, imports) =
137 struct_gen::generate_struct(view, db_kind, schema_info, extra_derives, type_overrides, true);
138 let imports = filter_imports(&imports, single_file);
139 let code = format_tokens_with_imports(&tokens, &imports);
140 let module_name = build_module_name(&view.schema_name, &view.name, colliding_names.contains(view.name.as_str()));
141 files.push(GeneratedFile {
142 filename: format!("{}.rs", module_name),
143 origin: None,
144 code,
145 });
146 }
147
148 let mut types_blocks: Vec<String> = Vec::new();
151 let mut types_imports = BTreeSet::new();
152
153 let enum_defaults = extract_enum_defaults(schema_info);
155 for enum_info in &schema_info.enums {
156 let mut enriched = enum_info.clone();
157 if enriched.default_variant.is_none() {
158 if let Some(default) = enum_defaults.get(&enum_info.name) {
159 enriched.default_variant = Some(default.clone());
160 }
161 }
162 let (tokens, imports) = enum_gen::generate_enum(&enriched, db_kind, extra_derives);
163 types_blocks.push(format_tokens(&tokens));
164 types_imports.extend(imports);
165 }
166
167 for composite in &schema_info.composite_types {
168 let (tokens, imports) = composite_gen::generate_composite(
169 composite,
170 db_kind,
171 schema_info,
172 extra_derives,
173 type_overrides,
174 );
175 types_blocks.push(format_tokens(&tokens));
176 types_imports.extend(imports);
177 }
178
179 for domain in &schema_info.domains {
180 let (tokens, imports) =
181 domain_gen::generate_domain(domain, db_kind, schema_info, type_overrides);
182 types_blocks.push(format_tokens(&tokens));
183 types_imports.extend(imports);
184 }
185
186 if !types_blocks.is_empty() {
187 let import_lines: String = types_imports
188 .iter()
189 .map(|i| format!("{}\n", i))
190 .collect();
191 let body = types_blocks.join("\n");
192 let code = if import_lines.is_empty() {
193 body
194 } else {
195 format!("{}\n\n{}", import_lines.trim_end(), body)
196 };
197 files.push(GeneratedFile {
198 filename: "types.rs".to_string(),
199 origin: None,
200 code,
201 });
202 }
203
204 files
205}
206
207fn extract_enum_defaults(schema_info: &SchemaInfo) -> HashMap<String, String> {
210 let mut defaults: HashMap<String, String> = HashMap::new();
211
212 let all_columns = schema_info
213 .tables
214 .iter()
215 .chain(schema_info.views.iter())
216 .flat_map(|t| t.columns.iter());
217
218 for col in all_columns {
219 let default_expr = match &col.column_default {
220 Some(d) => d,
221 None => continue,
222 };
223
224 let base_udt = col.udt_name.strip_prefix('_').unwrap_or(&col.udt_name);
226
227 let enum_match = schema_info.enums.iter().find(|e| e.name == base_udt);
229 if enum_match.is_none() {
230 continue;
231 }
232
233 if let Some(variant) = parse_pg_enum_default(default_expr) {
235 defaults.entry(base_udt.to_string()).or_insert(variant);
236 }
237 }
238
239 defaults
240}
241
/// Extracts the literal value from a PostgreSQL default of the form `'value'::type_name`.
///
/// Surrounding whitespace is ignored, and a doubled quote (`''`) inside the literal is
/// unescaped to a single `'`. The closing quote must be followed immediately by `::` (a cast).
///
/// Returns `None` when:
/// - there is no leading quote;
/// - the literal is unterminated;
/// - the literal is not followed by a cast.
///
/// The cast target is not inspected.
fn parse_pg_enum_default(default_expr: &str) -> Option<String> {
    let literal = default_expr.trim().strip_prefix('\'')?;

    let mut value = String::new();
    let mut chars = literal.char_indices();
    while let Some((idx, ch)) = chars.next() {
        if ch != '\'' {
            value.push(ch);
            continue;
        }
        // `'` is one byte, so `idx + 1` is a char boundary.
        let after = &literal[idx + 1..];
        if after.starts_with('\'') {
            // Doubled quote: an escaped `'` inside the literal.
            value.push('\'');
            chars.next();
            continue;
        }
        // Closing quote: only a cast (`::`) makes this an enum-style default.
        return if after.starts_with("::") { Some(value) } else { None };
    }

    // Unterminated literal.
    None
}
259
/// Returns a copy of `imports`, adjusted for the output mode.
///
/// When `single_file` is true, every line containing `super::types::` is removed. Otherwise
/// all lines are kept unchanged.
fn filter_imports(imports: &BTreeSet<String>, single_file: bool) -> BTreeSet<String> {
    let mut kept = imports.clone();
    if single_file {
        kept.retain(|line| !line.contains("super::types::"));
    }
    kept
}
272
273pub(crate) fn parse_and_format(tokens: &TokenStream) -> String {
275 let file = syn::parse2::<syn::File>(tokens.clone()).unwrap_or_else(|e| {
276 log::error!("Failed to parse generated code: {}", e);
277 log::error!("This is a bug in sqlx-gen. Raw tokens:\n {}", tokens);
278 std::process::exit(1);
279 });
280 let raw = prettyplease::unparse(&file);
281 add_blank_lines_between_items(&raw)
282}
283
284pub(crate) fn format_tokens(tokens: &TokenStream) -> String {
286 parse_and_format(tokens)
287}
288
289pub fn format_tokens_with_imports(tokens: &TokenStream, imports: &BTreeSet<String>) -> String {
290 let formatted = parse_and_format(tokens);
291
292 let used_imports: Vec<&String> = imports
293 .iter()
294 .filter(|imp| is_import_used(imp, &formatted))
295 .collect();
296
297 if used_imports.is_empty() {
298 formatted
299 } else {
300 let import_lines: String = used_imports
301 .iter()
302 .map(|i| format!("{}\n", i))
303 .collect();
304 format!("{}\n\n{}", import_lines.trim_end(), formatted)
305 }
306}
307
/// Decides whether a single `use` line is needed by `code`.
///
/// - Glob imports (`...::*`) are always kept.
/// - Otherwise each imported item is reduced to the name it binds: the alias after ` as `, or
///   else the last `::` segment. For `a::{B, C}` forms, the items inside the outermost braces
///   are used.
/// - The line counts as used if any bound name occurs in `code` as a whole identifier.
///   `NaiveDate` therefore does not match inside `NaiveDateTime`.
/// - Bound names `self`, `_` and `*` cannot be checked by name, so they keep the line.
fn is_import_used(import: &str, code: &str) -> bool {
    let trimmed = import.trim().trim_end_matches(';');
    let path = trimmed.strip_prefix("use ").unwrap_or(trimmed);

    if path.ends_with("::*") {
        return true;
    }

    // For `a::{B, C}` only the part inside the outermost braces lists the imported items.
    // The `open < close` guard avoids a slice panic on malformed input.
    let items = match (path.find('{'), path.rfind('}')) {
        (Some(open), Some(close)) if open < close => &path[open + 1..close],
        _ => path,
    };

    items
        .split(',')
        .filter_map(imported_name)
        .any(|name| name == "self" || name == "_" || name == "*" || contains_identifier(code, name))
}

/// Returns the name a single import item binds in scope, or `None` for an empty item.
///
/// `B as C` binds `C`; `x::y::B` binds `B`. Stray braces left over from nested groups are
/// trimmed away.
fn imported_name(item: &str) -> Option<&str> {
    let item = item.trim();
    let name = match item.split_once(" as ") {
        Some((_, alias)) => alias,
        None => item.rsplit("::").next().unwrap_or(item),
    };
    let name = name.trim_matches(|c: char| c == '{' || c == '}' || c.is_whitespace());
    if name.is_empty() {
        None
    } else {
        Some(name)
    }
}

/// Returns `true` if `ident` occurs in `code` with no identifier character directly before or
/// after it.
fn contains_identifier(code: &str, ident: &str) -> bool {
    if ident.is_empty() {
        return false;
    }
    let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';
    code.match_indices(ident).any(|(start, _)| {
        let end = start + ident.len();
        let clear_before = code[..start]
            .chars()
            .next_back()
            .map_or(true, |c| !is_ident_char(c));
        let clear_after = code[end..].chars().next().map_or(true, |c| !is_ident_char(c));
        clear_before && clear_after
    })
}
340
/// Inserts blank lines into pretty-printed code to separate logical groups.
///
/// The rules are listed on `needs_blank_line_before`. At most one blank line is inserted
/// before any line, even when several rules match it. The output is joined with `\n`, so a
/// trailing newline in `code` is not preserved.
fn add_blank_lines_between_items(code: &str) -> String {
    let lines: Vec<&str> = code.lines().collect();
    let mut result: Vec<&str> = Vec::with_capacity(lines.len());

    for (i, &line) in lines.iter().enumerate() {
        if i > 0 && needs_blank_line_before(lines[i - 1].trim(), line.trim()) {
            result.push("");
        }
        result.push(line);
    }

    result.join("\n")
}

/// Decides whether a blank line belongs between two adjacent (trimmed) lines.
///
/// Any one of these rules is enough:
/// - a `#[sqlx(rename` attribute directly after a line ending in `,` (next enum variant);
/// - an item (`pub struct`, `impl `, `#[derive`, `pub async fn`, `pub fn`) directly after a
///   line that is exactly `}`;
/// - a `let ` or `Ok(` line after a line ending in `.await?;` or `.await?`, or a `;`-terminated
///   line containing `.unwrap_or(`;
/// - a `let ` line mentioning `sqlx::` after a `let ` line that does not mention `sqlx::`.
fn needs_blank_line_before(prev: &str, current: &str) -> bool {
    let renamed_variant = current.starts_with("#[sqlx(rename") && prev.ends_with(',');

    let item_after_block = prev == "}"
        && ["pub struct", "impl ", "#[derive", "pub async fn", "pub fn"]
            .iter()
            .any(|&prefix| current.starts_with(prefix));

    let prev_is_await_end = prev.ends_with(".await?;")
        || prev.ends_with(".await?")
        || (prev.ends_with(';') && prev.contains(".unwrap_or("));
    let statement_after_await =
        prev_is_await_end && (current.starts_with("let ") || current.starts_with("Ok("));

    let query_after_bindings = current.starts_with("let ")
        && current.contains("sqlx::")
        && prev.starts_with("let ")
        && !prev.contains("sqlx::");

    renamed_variant || item_after_block || statement_after_await || query_after_bindings
}
402
// Unit tests for the helpers and the `generate` entry point of this module.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::introspect::{
        ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo,
    };
    use std::collections::HashMap;

    // ----- is_rust_keyword: strict and reserved keywords, case-sensitive -----

    #[test]
    fn test_keyword_type() {
        assert!(is_rust_keyword("type"));
    }

    #[test]
    fn test_keyword_fn() {
        assert!(is_rust_keyword("fn"));
    }

    #[test]
    fn test_keyword_let() {
        assert!(is_rust_keyword("let"));
    }

    #[test]
    fn test_keyword_match() {
        assert!(is_rust_keyword("match"));
    }

    #[test]
    fn test_keyword_async() {
        assert!(is_rust_keyword("async"));
    }

    #[test]
    fn test_keyword_await() {
        assert!(is_rust_keyword("await"));
    }

    #[test]
    fn test_keyword_yield() {
        assert!(is_rust_keyword("yield"));
    }

    #[test]
    fn test_keyword_abstract() {
        assert!(is_rust_keyword("abstract"));
    }

    #[test]
    fn test_keyword_try() {
        assert!(is_rust_keyword("try"));
    }

    #[test]
    fn test_not_keyword_name() {
        assert!(!is_rust_keyword("name"));
    }

    #[test]
    fn test_not_keyword_id() {
        assert!(!is_rust_keyword("id"));
    }

    // Matching is case-sensitive: only lowercase `type` is a keyword.
    #[test]
    fn test_not_keyword_uppercase_type() {
        assert!(!is_rust_keyword("Type"));
    }

    // ----- normalize_module_name: runs of underscores collapse to one -----

    #[test]
    fn test_normalize_no_underscores() {
        assert_eq!(normalize_module_name("users"), "users");
    }

    #[test]
    fn test_normalize_single_underscore() {
        assert_eq!(normalize_module_name("user_roles"), "user_roles");
    }

    #[test]
    fn test_normalize_double_underscore() {
        assert_eq!(normalize_module_name("user__roles"), "user_roles");
    }

    #[test]
    fn test_normalize_triple_underscore() {
        assert_eq!(normalize_module_name("a___b"), "a_b");
    }

    #[test]
    fn test_normalize_leading_underscore() {
        assert_eq!(normalize_module_name("_private"), "_private");
    }

    #[test]
    fn test_normalize_trailing_underscore() {
        assert_eq!(normalize_module_name("name_"), "name_");
    }

    #[test]
    fn test_normalize_double_leading() {
        assert_eq!(normalize_module_name("__double_leading"), "_double_leading");
    }

    #[test]
    fn test_normalize_multiple_groups() {
        assert_eq!(normalize_module_name("a__b__c"), "a_b_c");
    }

    // ----- build_module_name: schema prefix only for collisions outside default schemas -----

    #[test]
    fn test_build_no_collision_no_prefix() {
        assert_eq!(build_module_name("public", "users", false), "users");
    }

    #[test]
    fn test_build_no_collision_non_default_no_prefix() {
        assert_eq!(build_module_name("billing", "invoices", false), "invoices");
    }

    #[test]
    fn test_build_collision_prefixed() {
        assert_eq!(build_module_name("billing", "users", true), "billing_users");
    }

    #[test]
    fn test_build_collision_default_schema_no_prefix() {
        assert_eq!(build_module_name("public", "users", true), "users");
    }

    // The prefixed name is normalized as a whole.
    #[test]
    fn test_build_collision_normalizes_double_underscore() {
        assert_eq!(build_module_name("billing", "agent__connector", true), "billing_agent_connector");
    }

    // ----- is_default_schema -----

    #[test]
    fn test_default_schema_public() {
        assert!(is_default_schema("public"));
    }

    #[test]
    fn test_default_schema_main() {
        assert!(is_default_schema("main"));
    }

    #[test]
    fn test_non_default_schema() {
        assert!(!is_default_schema("billing"));
    }

    // ----- imports_for_derives: only serde derives produce an import -----

    #[test]
    fn test_imports_empty() {
        let result = imports_for_derives(&[]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_imports_serialize_only() {
        let derives = vec!["Serialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Serialize};"]);
    }

    #[test]
    fn test_imports_deserialize_only() {
        let derives = vec!["Deserialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Deserialize};"]);
    }

    #[test]
    fn test_imports_both_serde() {
        let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
        let result = imports_for_derives(&derives);
        assert_eq!(result, vec!["use serde::{Serialize, Deserialize};"]);
    }

    #[test]
    fn test_imports_non_serde() {
        let derives = vec!["Hash".to_string()];
        let result = imports_for_derives(&derives);
        assert!(result.is_empty());
    }

    #[test]
    fn test_imports_non_serde_multiple() {
        let derives = vec!["PartialEq".to_string(), "Eq".to_string()];
        let result = imports_for_derives(&derives);
        assert!(result.is_empty());
    }

    // Serde derives are merged into a single `use` line; other derives are ignored.
    #[test]
    fn test_imports_mixed_serde_and_others() {
        let derives = vec![
            "Serialize".to_string(),
            "Hash".to_string(),
            "Deserialize".to_string(),
        ];
        let result = imports_for_derives(&derives);
        assert_eq!(result.len(), 1);
        assert!(result[0].contains("Serialize"));
        assert!(result[0].contains("Deserialize"));
    }

    // ----- add_blank_lines_between_items -----

    // A blank line separates consecutive `#[sqlx(rename ...)]` variants.
    #[test]
    fn test_blank_lines_between_renamed_variants() {
        let input = "pub enum Foo {\n #[sqlx(rename = \"a\")]\n A,\n #[sqlx(rename = \"b\")]\n B,\n}";
        let result = add_blank_lines_between_items(input);
        assert!(result.contains("A,\n\n #[sqlx(rename = \"b\")]"));
    }

    // The first variant follows `{`, not `,`, so no blank line is added before it.
    #[test]
    fn test_no_blank_line_for_first_variant() {
        let input = "pub enum Foo {\n #[sqlx(rename = \"a\")]\n A,\n}";
        let result = add_blank_lines_between_items(input);
        assert!(!result.contains("{\n\n"));
    }

    #[test]
    fn test_no_change_without_rename() {
        let input = "pub enum Foo {\n A,\n B,\n}";
        let result = add_blank_lines_between_items(input);
        assert_eq!(result, input);
    }

    #[test]
    fn test_no_change_for_struct() {
        let input = "pub struct Foo {\n pub a: i32,\n pub b: String,\n}";
        let result = add_blank_lines_between_items(input);
        assert_eq!(result, input);
    }

    // ----- filter_imports: `super::types::` lines are dropped only in single-file mode -----

    #[test]
    fn test_filter_single_file_strips_super_types() {
        let mut imports = BTreeSet::new();
        imports.insert("use super::types::Foo;".to_string());
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, true);
        assert!(!result.contains("use super::types::Foo;"));
        assert!(result.contains("use chrono::NaiveDateTime;"));
    }

    #[test]
    fn test_filter_single_file_keeps_other_imports() {
        let mut imports = BTreeSet::new();
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, true);
        assert!(result.contains("use chrono::NaiveDateTime;"));
    }

    #[test]
    fn test_filter_multi_file_keeps_all() {
        let mut imports = BTreeSet::new();
        imports.insert("use super::types::Foo;".to_string());
        imports.insert("use chrono::NaiveDateTime;".to_string());
        let result = filter_imports(&imports, false);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_filter_empty_set() {
        let imports = BTreeSet::new();
        let result = filter_imports(&imports, true);
        assert!(result.is_empty());
    }

    // ----- fixtures -----

    // Builds a table in the `public` schema with the given columns.
    fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    // Builds a non-null, non-primary-key column in `public` with no default;
    // `data_type` is set to the same string as `udt_name`.
    fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
        ColumnInfo {
            name: name.to_string(),
            data_type: udt_name.to_string(),
            udt_name: udt_name.to_string(),
            is_nullable: false,
            is_primary_key: false,
            ordinal_position: 0,
            schema_name: "public".to_string(),
            column_default: None,
        }
    }

    // ----- generate: file layout -----

    #[test]
    fn test_generate_empty_schema() {
        let schema = SchemaInfo::default();
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert!(files.is_empty());
    }

    #[test]
    fn test_generate_one_table() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "users.rs");
    }

    #[test]
    fn test_generate_two_tables() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                make_table("posts", vec![make_col("id", "int4")]),
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2);
    }

    #[test]
    fn test_generate_enum_creates_types_file() {
        let schema = SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string(), "inactive".to_string()],
                default_variant: None,
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "types.rs");
    }

    // Enums, composites and domains all share one `types.rs`.
    #[test]
    fn test_generate_enums_composites_domains_single_types_file() {
        let schema = SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
                default_variant: None,
            }],
            composite_types: vec![CompositeTypeInfo {
                schema_name: "public".to_string(),
                name: "address".to_string(),
                fields: vec![make_col("street", "text")],
            }],
            domains: vec![DomainInfo {
                schema_name: "public".to_string(),
                name: "email".to_string(),
                base_type: "text".to_string(),
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let types_files: Vec<_> = files.iter().filter(|f| f.filename == "types.rs").collect();
        assert_eq!(types_files.len(), 1);
    }

    #[test]
    fn test_generate_tables_and_enums() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
                default_variant: None,
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2);
    }

    #[test]
    fn test_generate_filename_normalized() {
        let schema = SchemaInfo {
            tables: vec![make_table("user__data", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].filename, "user_data.rs");
    }

    #[test]
    fn test_generate_no_origin_for_tables() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, None);
    }

    #[test]
    fn test_generate_types_no_origin() {
        let schema = SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
                default_variant: None,
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, None);
    }

    // ----- generate: single-file vs multi-file imports -----

    #[test]
    fn test_generate_single_file_filters_super_types_imports() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
                default_variant: None,
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
        assert!(!struct_file.code.contains("super::types::"));
    }

    // The `status` column uses the `status` enum, so the struct needs a `super::types::` import.
    #[test]
    fn test_generate_multi_file_keeps_super_types_imports() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("status", "status")])],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
                default_variant: None,
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
        assert!(struct_file.code.contains("super::types::"));
    }

    // ----- generate: extra derives and type overrides -----

    #[test]
    fn test_generate_extra_derives_in_struct() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let derives = vec!["Serialize".to_string()];
        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
        assert!(files[0].code.contains("Serialize"));
    }

    #[test]
    fn test_generate_extra_derives_in_enum() {
        let schema = SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string()],
                default_variant: None,
            }],
            ..Default::default()
        };
        let derives = vec!["Serialize".to_string()];
        let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false);
        assert!(files[0].code.contains("Serialize"));
    }

    #[test]
    fn test_generate_type_overrides_in_struct() {
        let mut overrides = HashMap::new();
        overrides.insert("jsonb".to_string(), "MyJson".to_string());
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("data", "jsonb")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &overrides, false);
        assert!(files[0].code.contains("MyJson"));
    }

    // Every generated file must parse as a complete Rust file.
    #[test]
    fn test_generate_valid_rust_syntax() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![
                make_col("id", "int4"),
                make_col("name", "text"),
            ])],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "status".to_string(),
                variants: vec!["active".to_string(), "inactive".to_string()],
                default_variant: None,
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        for f in &files {
            let parse_result = syn::parse_file(&f.code);
            assert!(parse_result.is_ok(), "Failed to parse {}: {:?}", f.filename, parse_result.err());
        }
    }

    // ----- generate: views -----

    // Views are represented with the same `TableInfo` type as tables.
    fn make_view(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
        TableInfo {
            schema_name: "public".to_string(),
            name: name.to_string(),
            columns,
        }
    }

    #[test]
    fn test_generate_one_view() {
        let schema = SchemaInfo {
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].filename, "active_users.rs");
    }

    #[test]
    fn test_generate_no_origin_for_views() {
        let schema = SchemaInfo {
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].origin, None);
    }

    #[test]
    fn test_generate_tables_and_views() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files.len(), 2);
    }

    #[test]
    fn test_generate_view_valid_rust() {
        let schema = SchemaInfo {
            views: vec![make_view("active_users", vec![
                make_col("id", "int4"),
                make_col("name", "text"),
            ])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let parse_result = syn::parse_file(&files[0].code);
        assert!(parse_result.is_ok(), "Failed to parse: {:?}", parse_result.err());
    }

    // A nullable `text` column in a view becomes `Option<String>`.
    #[test]
    fn test_generate_view_nullable_column() {
        let schema = SchemaInfo {
            views: vec![make_view("v", vec![ColumnInfo {
                name: "email".to_string(),
                data_type: "text".to_string(),
                udt_name: "text".to_string(),
                is_nullable: true,
                is_primary_key: false,
                ordinal_position: 0,
                schema_name: "public".to_string(),
                column_default: None,
            }])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert!(files[0].code.contains("Option<String>"));
    }

    // ----- generate: cross-schema name collisions -----

    // `public.users` keeps its plain name; `billing.users` gets the schema prefix.
    #[test]
    fn test_generate_collision_both_prefixed() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                TableInfo {
                    schema_name: "billing".to_string(),
                    name: "users".to_string(),
                    columns: vec![make_col("id", "int4")],
                },
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
        assert!(filenames.contains(&"users.rs"));
        assert!(filenames.contains(&"billing_users.rs"));
    }

    #[test]
    fn test_generate_no_collision_no_prefix() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                TableInfo {
                    schema_name: "billing".to_string(),
                    name: "invoices".to_string(),
                    columns: vec![make_col("id", "int4")],
                },
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
        assert!(filenames.contains(&"users.rs"));
        assert!(filenames.contains(&"invoices.rs"));
    }

    // Also relies on files being emitted in table order.
    #[test]
    fn test_generate_single_schema_no_prefix() {
        let schema = SchemaInfo {
            tables: vec![
                make_table("users", vec![make_col("id", "int4")]),
                make_table("posts", vec![make_col("id", "int4")]),
            ],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        assert_eq!(files[0].filename, "users.rs");
        assert_eq!(files[1].filename, "posts.rs");
    }

    #[test]
    fn test_generate_view_single_file_mode() {
        let schema = SchemaInfo {
            tables: vec![make_table("users", vec![make_col("id", "int4")])],
            views: vec![make_view("active_users", vec![make_col("id", "int4")])],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true);
        assert_eq!(files.len(), 2);
    }

    // ----- parse_pg_enum_default -----

    #[test]
    fn test_parse_pg_enum_default_simple() {
        assert_eq!(
            parse_pg_enum_default("'idle'::task_status"),
            Some("idle".to_string())
        );
    }

    #[test]
    fn test_parse_pg_enum_default_schema_qualified() {
        assert_eq!(
            parse_pg_enum_default("'active'::public.task_status"),
            Some("active".to_string())
        );
    }

    // Function-call defaults do not start with a quote.
    #[test]
    fn test_parse_pg_enum_default_not_enum() {
        assert_eq!(parse_pg_enum_default("nextval('users_id_seq')"), None);
    }

    // A literal without a `::` cast is rejected.
    #[test]
    fn test_parse_pg_enum_default_no_cast() {
        assert_eq!(parse_pg_enum_default("'hello'"), None);
    }

    #[test]
    fn test_parse_pg_enum_default_empty() {
        assert_eq!(parse_pg_enum_default(""), None);
    }

    // ----- extract_enum_defaults and default-variant propagation -----

    #[test]
    fn test_extract_enum_defaults_from_column() {
        let schema = SchemaInfo {
            tables: vec![TableInfo {
                schema_name: "public".to_string(),
                name: "tasks".to_string(),
                columns: vec![ColumnInfo {
                    name: "status".to_string(),
                    data_type: "USER-DEFINED".to_string(),
                    udt_name: "task_status".to_string(),
                    is_nullable: false,
                    is_primary_key: false,
                    ordinal_position: 0,
                    schema_name: "public".to_string(),
                    column_default: Some("'idle'::task_status".to_string()),
                }],
            }],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "task_status".to_string(),
                variants: vec!["idle".to_string(), "running".to_string()],
                default_variant: None,
            }],
            ..Default::default()
        };
        let defaults = extract_enum_defaults(&schema);
        assert_eq!(defaults.get("task_status"), Some(&"idle".to_string()));
    }

    #[test]
    fn test_extract_enum_defaults_no_default() {
        let schema = SchemaInfo {
            tables: vec![TableInfo {
                schema_name: "public".to_string(),
                name: "tasks".to_string(),
                columns: vec![ColumnInfo {
                    name: "status".to_string(),
                    data_type: "USER-DEFINED".to_string(),
                    udt_name: "task_status".to_string(),
                    is_nullable: false,
                    is_primary_key: false,
                    ordinal_position: 0,
                    schema_name: "public".to_string(),
                    column_default: None,
                }],
            }],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "task_status".to_string(),
                variants: vec!["idle".to_string()],
                default_variant: None,
            }],
            ..Default::default()
        };
        let defaults = extract_enum_defaults(&schema);
        assert!(defaults.is_empty());
    }

    // A quoted-and-cast default on a non-enum column must not produce an entry.
    #[test]
    fn test_extract_enum_defaults_non_enum_column_ignored() {
        let schema = SchemaInfo {
            tables: vec![TableInfo {
                schema_name: "public".to_string(),
                name: "users".to_string(),
                columns: vec![ColumnInfo {
                    name: "name".to_string(),
                    data_type: "character varying".to_string(),
                    udt_name: "varchar".to_string(),
                    is_nullable: false,
                    is_primary_key: false,
                    ordinal_position: 0,
                    schema_name: "public".to_string(),
                    column_default: Some("'hello'::character varying".to_string()),
                }],
            }],
            enums: vec![],
            ..Default::default()
        };
        let defaults = extract_enum_defaults(&schema);
        assert!(defaults.is_empty());
    }

    // An inferred default shows up as `impl Default` in the generated `types.rs`.
    #[test]
    fn test_generate_enum_with_default() {
        let schema = SchemaInfo {
            tables: vec![TableInfo {
                schema_name: "public".to_string(),
                name: "tasks".to_string(),
                columns: vec![ColumnInfo {
                    name: "status".to_string(),
                    data_type: "USER-DEFINED".to_string(),
                    udt_name: "task_status".to_string(),
                    is_nullable: false,
                    is_primary_key: false,
                    ordinal_position: 0,
                    schema_name: "public".to_string(),
                    column_default: Some("'idle'::task_status".to_string()),
                }],
            }],
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: "task_status".to_string(),
                variants: vec!["idle".to_string(), "running".to_string()],
                default_variant: None,
            }],
            ..Default::default()
        };
        let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false);
        let types_file = files.iter().find(|f| f.filename == "types.rs").unwrap();
        assert!(types_file.code.contains("impl Default for TaskStatus"));
        assert!(types_file.code.contains("Self::Idle"));
    }
}