1pub mod composite_gen;
2pub mod crud_gen;
3pub mod domain_gen;
4pub mod entity_parser;
5pub mod enum_gen;
6pub mod struct_gen;
7
8use std::collections::{BTreeSet, HashMap};
9use std::path::Path;
10
11use proc_macro2::TokenStream;
12
13use crate::cli::{DatabaseKind, TimeCrate};
14use crate::introspect::SchemaInfo;
15
/// Words Rust does not accept as plain identifiers: the strict keywords and
/// the keywords reserved for future use, including `gen`, which is reserved
/// starting with the 2024 edition.
const RUST_KEYWORDS: &[&str] = &[
    // Strict keywords.
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
    "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
    "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true",
    "type", "unsafe", "use", "where", "while",
    // Reserved keywords (some only in newer editions).
    "yield", "abstract", "become", "box", "do", "final", "macro", "override", "priv", "try",
    "typeof", "unsized", "virtual", "gen",
];

/// Returns `true` when `name` exactly (case-sensitively) matches an entry of
/// `RUST_KEYWORDS`, so `"type"` is a keyword but `"Type"` is not.
///
/// NOTE(review): `self`, `Self`, `crate` and `super` cannot be written as raw
/// identifiers (`r#self` is rejected by rustc); confirm callers do not rely on
/// `r#` escaping for those four.
pub fn is_rust_keyword(name: &str) -> bool {
    RUST_KEYWORDS.contains(&name)
}
29
/// Returns the `use` lines needed by the user-requested extra derives.
///
/// Only serde derives need an import. When `Serialize` and/or `Deserialize`
/// occur in `extra_derives` (exact match), a single `use serde::{...};` line
/// listing them, `Serialize` first, is returned. Otherwise the result is
/// empty.
pub fn imports_for_derives(extra_derives: &[String]) -> Vec<String> {
    let serde_derives: Vec<&str> = ["Serialize", "Deserialize"]
        .iter()
        .copied()
        .filter(|wanted| extra_derives.iter().any(|d| d.as_str() == *wanted))
        .collect();

    if serde_derives.is_empty() {
        Vec::new()
    } else {
        vec![format!("use serde::{{{}}};", serde_derives.join(", "))]
    }
}
46
/// Collapses every run of consecutive underscores in `name` into a single
/// underscore and leaves all other characters untouched. For example,
/// `"user__roles"` becomes `"user_roles"` and `"__x"` becomes `"_x"`.
pub fn normalize_module_name(name: &str) -> String {
    let mut collapsed = String::with_capacity(name.len());
    for ch in name.chars() {
        // An underscore right after an emitted underscore continues a run.
        if ch == '_' && collapsed.ends_with('_') {
            continue;
        }
        collapsed.push(ch);
    }
    collapsed
}
64}
65
/// Conventional default schema names (e.g. `public` in PostgreSQL, `main` in
/// SQLite, `dbo` in SQL Server). `build_module_name` never schema-prefixes
/// tables living in one of these.
const DEFAULT_SCHEMAS: &[&str] = &["public", "main", "dbo"];

/// Returns `true` if `schema` is one of `DEFAULT_SCHEMAS` (exact,
/// case-sensitive comparison).
pub fn is_default_schema(schema: &str) -> bool {
    DEFAULT_SCHEMAS.iter().any(|&default| default == schema)
}
73
74pub fn build_module_name(schema_name: &str, table_name: &str, name_collides: bool) -> String {
77 if name_collides && !is_default_schema(schema_name) {
78 normalize_module_name(&format!("{}_{}", schema_name, table_name))
79 } else {
80 normalize_module_name(table_name)
81 }
82}
83
84fn find_colliding_names(schema_info: &SchemaInfo) -> BTreeSet<&str> {
86 let mut seen: HashMap<&str, BTreeSet<&str>> = HashMap::new();
87 for t in &schema_info.tables {
88 seen.entry(t.name.as_str()).or_default().insert(t.schema_name.as_str());
89 }
90 for v in &schema_info.views {
91 seen.entry(v.name.as_str()).or_default().insert(v.schema_name.as_str());
92 }
93 seen.into_iter()
94 .filter(|(_, schemas)| schemas.len() > 1)
95 .map(|(name, _)| name)
96 .collect()
97}
98
/// One generated Rust source file: its name and its formatted contents.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GeneratedFile {
    /// File name including the `.rs` extension, e.g. `users.rs` or `types.rs`.
    pub filename: String,
    /// Optional origin label. `generate` always sets this to `None`.
    /// NOTE(review): other producers may populate it; confirm the meaning.
    pub origin: Option<String>,
    /// Formatted Rust source, including any leading `use` lines.
    pub code: String,
}
107
108pub fn generate(
110 schema_info: &SchemaInfo,
111 db_kind: DatabaseKind,
112 extra_derives: &[String],
113 type_overrides: &HashMap<String, String>,
114 single_file: bool,
115 time_crate: TimeCrate,
116) -> Vec<GeneratedFile> {
117 let mut files = Vec::new();
118
119 let colliding_names = find_colliding_names(schema_info);
121
122 for table in &schema_info.tables {
124 let (tokens, imports) =
125 struct_gen::generate_struct(table, db_kind, schema_info, extra_derives, type_overrides, false, time_crate);
126 let imports = filter_imports(&imports, single_file);
127 let code = format_tokens_with_imports(&tokens, &imports);
128 let module_name = build_module_name(&table.schema_name, &table.name, colliding_names.contains(table.name.as_str()));
129 files.push(GeneratedFile {
130 filename: format!("{}.rs", module_name),
131 origin: None,
132 code,
133 });
134 }
135
136 for view in &schema_info.views {
138 let (tokens, imports) =
139 struct_gen::generate_struct(view, db_kind, schema_info, extra_derives, type_overrides, true, time_crate);
140 let imports = filter_imports(&imports, single_file);
141 let code = format_tokens_with_imports(&tokens, &imports);
142 let module_name = build_module_name(&view.schema_name, &view.name, colliding_names.contains(view.name.as_str()));
143 files.push(GeneratedFile {
144 filename: format!("{}.rs", module_name),
145 origin: None,
146 code,
147 });
148 }
149
150 let mut types_blocks: Vec<String> = Vec::new();
153 let mut types_imports = BTreeSet::new();
154
155 let enum_defaults = extract_enum_defaults(schema_info);
157 for enum_info in &schema_info.enums {
158 let mut enriched = enum_info.clone();
159 if enriched.default_variant.is_none() {
160 if let Some(default) = enum_defaults.get(&enum_info.name) {
161 enriched.default_variant = Some(default.clone());
162 }
163 }
164 let (tokens, imports) = enum_gen::generate_enum(&enriched, db_kind, extra_derives);
165 types_blocks.push(format_tokens(&tokens));
166 types_imports.extend(imports);
167 }
168
169 for composite in &schema_info.composite_types {
170 let (tokens, imports) = composite_gen::generate_composite(
171 composite,
172 db_kind,
173 schema_info,
174 extra_derives,
175 type_overrides,
176 time_crate,
177 );
178 types_blocks.push(format_tokens(&tokens));
179 types_imports.extend(imports);
180 }
181
182 for domain in &schema_info.domains {
183 let (tokens, imports) =
184 domain_gen::generate_domain(domain, db_kind, schema_info, type_overrides, time_crate);
185 types_blocks.push(format_tokens(&tokens));
186 types_imports.extend(imports);
187 }
188
189 if !types_blocks.is_empty() {
190 let import_lines: String = types_imports
191 .iter()
192 .map(|i| format!("{}\n", i))
193 .collect();
194 let body = types_blocks.join("\n");
195 let code = if import_lines.is_empty() {
196 body
197 } else {
198 format!("{}\n\n{}", import_lines.trim_end(), body)
199 };
200 files.push(GeneratedFile {
201 filename: "types.rs".to_string(),
202 origin: None,
203 code,
204 });
205 }
206
207 files
208}
209
210fn extract_enum_defaults(schema_info: &SchemaInfo) -> HashMap<String, String> {
213 let mut defaults: HashMap<String, String> = HashMap::new();
214
215 let all_columns = schema_info
216 .tables
217 .iter()
218 .chain(schema_info.views.iter())
219 .flat_map(|t| t.columns.iter());
220
221 for col in all_columns {
222 let default_expr = match &col.column_default {
223 Some(d) => d,
224 None => continue,
225 };
226
227 let base_udt = col.udt_name.strip_prefix('_').unwrap_or(&col.udt_name);
229
230 let enum_match = schema_info.enums.iter().find(|e| e.name == base_udt);
232 if enum_match.is_none() {
233 continue;
234 }
235
236 if let Some(variant) = parse_pg_enum_default(default_expr) {
238 defaults.entry(base_udt.to_string()).or_insert(variant);
239 }
240 }
241
242 defaults
243}
244
/// Extracts the label from a PostgreSQL enum column default such as
/// `'idle'::task_status` or `'active'::public.task_status`.
///
/// The (trimmed) expression must start with a single-quoted literal that is
/// immediately followed by a `::` cast. Doubled quotes inside the literal
/// (`''`) are SQL escapes and decode to one `'`, so `'it''s'::mood` yields
/// `it's`. Returns `None` for anything else: function calls, uncast literals,
/// unterminated literals, or empty input.
fn parse_pg_enum_default(default_expr: &str) -> Option<String> {
    let body = default_expr.trim().strip_prefix('\'')?;
    let mut label = String::new();
    let mut chars = body.char_indices();

    while let Some((idx, ch)) = chars.next() {
        if ch != '\'' {
            label.push(ch);
            continue;
        }
        // `'` is one byte, so `idx + 1` is a valid char boundary.
        let rest = &body[idx + 1..];
        if rest.starts_with('\'') {
            // `''` is an escaped quote inside the literal.
            label.push('\'');
            chars.next();
            continue;
        }
        // Closing quote: only a literal followed by a cast is an enum default.
        return if rest.starts_with("::") { Some(label) } else { None };
    }

    None
}
262
/// Returns the imports to emit for a table or view file.
///
/// In single-file mode, every import mentioning `super::types::` is dropped.
/// Presumably this is because the shared types end up in the same output in
/// that mode. Otherwise the set is returned unchanged.
fn filter_imports(imports: &BTreeSet<String>, single_file: bool) -> BTreeSet<String> {
    if !single_file {
        return imports.clone();
    }

    let mut kept = BTreeSet::new();
    for import in imports {
        if !import.contains("super::types::") {
            kept.insert(import.clone());
        }
    }
    kept
}
275
/// Finds the `tab_spaces` setting that applies to files under `start_dir`.
///
/// The search starts at `start_dir`, or at its parent when it is a file. In
/// each directory, up to the filesystem root, it looks for `rustfmt.toml` and
/// then `.rustfmt.toml`. The first config file that can be read decides the
/// result: its `tab_spaces` value if one parses as a number, otherwise 4.
/// Returns 4 when no config file is found.
pub fn detect_tab_spaces(start_dir: &Path) -> usize {
    const DEFAULT_TAB_SPACES: usize = 4;

    let mut dir = if start_dir.is_file() {
        start_dir.parent().unwrap_or(start_dir)
    } else {
        start_dir
    };

    loop {
        for name in &["rustfmt.toml", ".rustfmt.toml"] {
            if let Ok(content) = std::fs::read_to_string(dir.join(name)) {
                return parse_tab_spaces(&content).unwrap_or(DEFAULT_TAB_SPACES);
            }
        }
        match dir.parent() {
            Some(parent) => dir = parent,
            None => return DEFAULT_TAB_SPACES,
        }
    }
}

/// Returns the first numeric `tab_spaces` value in rustfmt TOML text.
///
/// `#` comments are ignored, so both `# tab_spaces = 2` (commented out) and
/// `tab_spaces = 2 # note` are handled correctly.
fn parse_tab_spaces(content: &str) -> Option<usize> {
    content.lines().find_map(|line| {
        let line = line.split('#').next().unwrap_or("").trim();
        let rest = line.strip_prefix("tab_spaces")?;
        let rest = rest.trim_start().strip_prefix('=').unwrap_or(rest);
        rest.trim().parse::<usize>().ok()
    })
}
307
308pub(crate) fn parse_and_format(tokens: &TokenStream) -> String {
311 parse_and_format_with_tab_spaces(tokens, 4)
312}
313
314pub(crate) fn parse_and_format_with_tab_spaces(tokens: &TokenStream, tab_spaces: usize) -> String {
315 let file = syn::parse2::<syn::File>(tokens.clone()).unwrap_or_else(|e| {
316 log::error!("Failed to parse generated code: {}", e);
317 log::error!("This is a bug in sqlx-gen. Raw tokens:\n {}", tokens);
318 std::process::exit(1);
319 });
320 let raw = prettyplease::unparse(&file);
321 let raw = indent_multiline_raw_strings(&raw, tab_spaces);
322 add_blank_lines_between_items(&raw)
323}
324
325pub(crate) fn format_tokens(tokens: &TokenStream) -> String {
327 parse_and_format(tokens)
328}
329
330pub fn format_tokens_with_imports(tokens: &TokenStream, imports: &BTreeSet<String>) -> String {
331 format_tokens_with_imports_and_tab_spaces(tokens, imports, 4)
332}
333
334pub fn format_tokens_with_imports_and_tab_spaces(tokens: &TokenStream, imports: &BTreeSet<String>, tab_spaces: usize) -> String {
335 let formatted = parse_and_format_with_tab_spaces(tokens, tab_spaces);
336
337 let used_imports: Vec<&String> = imports
338 .iter()
339 .filter(|imp| is_import_used(imp, &formatted))
340 .collect();
341
342 if used_imports.is_empty() {
343 formatted
344 } else {
345 let import_lines: String = used_imports
346 .iter()
347 .map(|i| format!("{}\n", i))
348 .collect();
349 format!("{}\n\n{}", import_lines.trim_end(), formatted)
350 }
351}
352
/// Heuristically decides whether a generated `use` statement is referenced by
/// `code`.
///
/// - Glob imports (`…::*`) are always kept.
/// - For brace lists (`use a::{B, C};`), the import counts as used when any
///   listed entry is used.
/// - Otherwise the name bound by the path is checked.
///
/// Names match only as whole identifiers, so `NaiveDate` is not "used"
/// merely because `NaiveDateTime` appears. For `Name as Alias`, the alias
/// must appear. `as _` imports and `self` are always kept because their use
/// cannot be detected textually.
fn is_import_used(import: &str, code: &str) -> bool {
    let trimmed = import.trim().trim_end_matches(';');
    let path = trimmed.strip_prefix("use ").unwrap_or(trimmed);

    if path.ends_with("::*") {
        return true;
    }

    if let (Some(start), Some(end)) = (path.find('{'), path.rfind('}')) {
        // Guard against malformed input where `}` precedes `{`.
        if start < end {
            return path[start + 1..end]
                .split(',')
                .map(str::trim)
                .filter(|entry| !entry.is_empty())
                .any(|entry| imported_name_is_used(entry, code));
        }
    }

    imported_name_is_used(path, code)
}

/// Returns true when the name bound by one import entry (e.g.
/// `chrono::NaiveDate`, `Json`, `Foo as Bar`) appears in `code`.
fn imported_name_is_used(entry: &str, code: &str) -> bool {
    // `Foo as Bar` binds `Bar`.
    let bound = entry.rsplit(" as ").next().unwrap_or(entry);
    // Last path segment, minus stray braces left over from nested groups.
    let name = bound
        .rsplit("::")
        .next()
        .unwrap_or(bound)
        .trim_matches(|c: char| c == '{' || c == '}' || c.is_whitespace());
    match name {
        "" | "*" | "_" | "self" => true,
        _ => contains_identifier(code, name),
    }
}

/// Returns true when `ident` occurs in `code` with no identifier character
/// immediately before or after it.
fn contains_identifier(code: &str, ident: &str) -> bool {
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || c == '_'
    }
    code.match_indices(ident).any(|(start, _)| {
        let end = start + ident.len();
        let joined_before = code[..start].chars().next_back().map_or(false, is_ident_char);
        let joined_after = code[end..].chars().next().map_or(false, is_ident_char);
        !joined_before && !joined_after
    })
}
385
/// Re-indents the contents of multi-line raw string literals (`r#"…"#`) in
/// pretty-printed code.
///
/// A line containing `r#"` with no `"#` after it opens a multi-line literal.
/// Its content lines are collected until a line whose trimmed start is `"#`.
/// That closing line is re-indented to `4 + tab_spaces` columns. The content
/// lines are re-indented to `4 + 2 * tab_spaces` columns, keeping their
/// indentation relative to the least-indented non-blank content line, and
/// blank content lines become empty.
///
/// If the terminator instead shares a line with content, or never appears
/// before the end of the input, the collected lines are emitted unchanged
/// rather than dropped. A raw string's content cannot contain `"#`, so any
/// `"#` inside it is the terminator. Output lines are joined with `\n`
/// (no trailing newline).
fn indent_multiline_raw_strings(code: &str, tab_spaces: usize) -> String {
    let close_indent = 4 + tab_spaces;
    let sql_indent = 4 + 2 * tab_spaces;
    let lines: Vec<&str> = code.lines().collect();
    let mut result = Vec::with_capacity(lines.len());
    let mut inside_raw = false;
    let mut raw_lines: Vec<&str> = Vec::new();

    for line in &lines {
        if !inside_raw {
            if let Some(pos) = line.find("r#\"") {
                let after = &line[pos + 3..];
                if !after.contains("\"#") {
                    inside_raw = true;
                    raw_lines.clear();
                }
            }
            result.push(line.to_string());
        } else if line.trim_start().starts_with("\"#") {
            let min_indent = raw_lines
                .iter()
                .filter(|l| !l.trim().is_empty())
                .map(|l| l.len() - l.trim_start().len())
                .min()
                .unwrap_or(0);
            for raw_line in &raw_lines {
                let trimmed = raw_line.trim();
                if trimmed.is_empty() {
                    result.push(String::new());
                } else {
                    let original_indent = raw_line.len() - raw_line.trim_start().len();
                    let relative = original_indent.saturating_sub(min_indent);
                    result.push(format!(
                        "{}{}{}",
                        " ".repeat(sql_indent),
                        " ".repeat(relative),
                        trimmed
                    ));
                }
            }
            let trimmed = line.trim();
            result.push(format!("{}{}", " ".repeat(close_indent), trimmed));
            inside_raw = false;
        } else if line.contains("\"#") {
            // Terminator shares a line with content: leave this literal as
            // printed instead of swallowing the code that follows it.
            result.extend(raw_lines.drain(..).map(str::to_string));
            result.push(line.to_string());
            inside_raw = false;
        } else {
            raw_lines.push(line);
        }
    }

    // Unterminated literal: keep its lines rather than silently dropping them.
    if inside_raw {
        result.extend(raw_lines.drain(..).map(str::to_string));
    }

    result.join("\n")
}
452
/// Inserts blank lines into pretty-printed code to separate logical groups.
///
/// At most one blank line is inserted before a line. See
/// `needs_blank_line_before` for the rules, which compare each line with the
/// previous one, both trimmed. Lines are joined with `\n` (no trailing
/// newline).
fn add_blank_lines_between_items(code: &str) -> String {
    let lines: Vec<&str> = code.lines().collect();
    let mut result: Vec<&str> = Vec::with_capacity(lines.len());

    for (i, &line) in lines.iter().enumerate() {
        if i > 0 && needs_blank_line_before(lines[i - 1].trim(), line.trim()) {
            result.push("");
        }
        result.push(line);
    }

    result.join("\n")
}

/// Decides whether a blank line belongs between trimmed lines `prev` and
/// `current`. Several rules may match; the caller still inserts only one line.
fn needs_blank_line_before(prev: &str, current: &str) -> bool {
    // A `#[sqlx(rename …)]` attribute after a line ending in `,`
    // (i.e. a renamed enum variant following the previous variant).
    let renamed_variant = current.starts_with("#[sqlx(rename") && prev.ends_with(',');

    // A struct, impl, derive or function right after a closing brace.
    let item_after_block = prev == "}"
        && ["pub struct", "impl ", "#[derive", "pub async fn", "pub fn"]
            .iter()
            .any(|prefix| current.starts_with(*prefix));

    // A `let` or `Ok(` after an awaited/fallible call or an `.unwrap_or(` statement.
    let prev_is_await_end = prev.ends_with(".await?;")
        || prev.ends_with(".await?")
        || (prev.ends_with(';') && prev.contains(".unwrap_or("));
    let statement_after_await =
        prev_is_await_end && (current.starts_with("let ") || current.starts_with("Ok("));

    // A `let` involving `sqlx::` after a plain `let`.
    let sqlx_let_after_plain_let = current.starts_with("let ")
        && current.contains("sqlx::")
        && prev.starts_with("let ")
        && !prev.contains("sqlx::");

    renamed_variant || item_after_block || statement_after_await || sqlx_let_after_plain_let
}
510
511#[cfg(test)]
512mod tests {
513 use super::*;
514 use crate::introspect::{
515 ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo,
516 };
517 use std::collections::HashMap;
518
519 #[test]
522 fn test_keyword_type() {
523 assert!(is_rust_keyword("type"));
524 }
525
526 #[test]
527 fn test_keyword_fn() {
528 assert!(is_rust_keyword("fn"));
529 }
530
531 #[test]
532 fn test_keyword_let() {
533 assert!(is_rust_keyword("let"));
534 }
535
536 #[test]
537 fn test_keyword_match() {
538 assert!(is_rust_keyword("match"));
539 }
540
541 #[test]
542 fn test_keyword_async() {
543 assert!(is_rust_keyword("async"));
544 }
545
546 #[test]
547 fn test_keyword_await() {
548 assert!(is_rust_keyword("await"));
549 }
550
551 #[test]
552 fn test_keyword_yield() {
553 assert!(is_rust_keyword("yield"));
554 }
555
556 #[test]
557 fn test_keyword_abstract() {
558 assert!(is_rust_keyword("abstract"));
559 }
560
561 #[test]
562 fn test_keyword_try() {
563 assert!(is_rust_keyword("try"));
564 }
565
566 #[test]
567 fn test_not_keyword_name() {
568 assert!(!is_rust_keyword("name"));
569 }
570
571 #[test]
572 fn test_not_keyword_id() {
573 assert!(!is_rust_keyword("id"));
574 }
575
576 #[test]
577 fn test_not_keyword_uppercase_type() {
578 assert!(!is_rust_keyword("Type"));
579 }
580
581 #[test]
584 fn test_normalize_no_underscores() {
585 assert_eq!(normalize_module_name("users"), "users");
586 }
587
588 #[test]
589 fn test_normalize_single_underscore() {
590 assert_eq!(normalize_module_name("user_roles"), "user_roles");
591 }
592
593 #[test]
594 fn test_normalize_double_underscore() {
595 assert_eq!(normalize_module_name("user__roles"), "user_roles");
596 }
597
598 #[test]
599 fn test_normalize_triple_underscore() {
600 assert_eq!(normalize_module_name("a___b"), "a_b");
601 }
602
603 #[test]
604 fn test_normalize_leading_underscore() {
605 assert_eq!(normalize_module_name("_private"), "_private");
606 }
607
608 #[test]
609 fn test_normalize_trailing_underscore() {
610 assert_eq!(normalize_module_name("name_"), "name_");
611 }
612
613 #[test]
614 fn test_normalize_double_leading() {
615 assert_eq!(normalize_module_name("__double_leading"), "_double_leading");
616 }
617
618 #[test]
619 fn test_normalize_multiple_groups() {
620 assert_eq!(normalize_module_name("a__b__c"), "a_b_c");
621 }
622
623 #[test]
626 fn test_build_no_collision_no_prefix() {
627 assert_eq!(build_module_name("public", "users", false), "users");
628 }
629
630 #[test]
631 fn test_build_no_collision_non_default_no_prefix() {
632 assert_eq!(build_module_name("billing", "invoices", false), "invoices");
633 }
634
635 #[test]
636 fn test_build_collision_prefixed() {
637 assert_eq!(build_module_name("billing", "users", true), "billing_users");
638 }
639
640 #[test]
641 fn test_build_collision_default_schema_no_prefix() {
642 assert_eq!(build_module_name("public", "users", true), "users");
643 }
644
645 #[test]
646 fn test_build_collision_normalizes_double_underscore() {
647 assert_eq!(build_module_name("billing", "agent__connector", true), "billing_agent_connector");
648 }
649
650 #[test]
653 fn test_default_schema_public() {
654 assert!(is_default_schema("public"));
655 }
656
657 #[test]
658 fn test_default_schema_main() {
659 assert!(is_default_schema("main"));
660 }
661
662 #[test]
663 fn test_non_default_schema() {
664 assert!(!is_default_schema("billing"));
665 }
666
667 #[test]
670 fn test_imports_empty() {
671 let result = imports_for_derives(&[]);
672 assert!(result.is_empty());
673 }
674
675 #[test]
676 fn test_imports_serialize_only() {
677 let derives = vec!["Serialize".to_string()];
678 let result = imports_for_derives(&derives);
679 assert_eq!(result, vec!["use serde::{Serialize};"]);
680 }
681
682 #[test]
683 fn test_imports_deserialize_only() {
684 let derives = vec!["Deserialize".to_string()];
685 let result = imports_for_derives(&derives);
686 assert_eq!(result, vec!["use serde::{Deserialize};"]);
687 }
688
689 #[test]
690 fn test_imports_both_serde() {
691 let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
692 let result = imports_for_derives(&derives);
693 assert_eq!(result, vec!["use serde::{Serialize, Deserialize};"]);
694 }
695
696 #[test]
697 fn test_imports_non_serde() {
698 let derives = vec!["Hash".to_string()];
699 let result = imports_for_derives(&derives);
700 assert!(result.is_empty());
701 }
702
703 #[test]
704 fn test_imports_non_serde_multiple() {
705 let derives = vec!["PartialEq".to_string(), "Eq".to_string()];
706 let result = imports_for_derives(&derives);
707 assert!(result.is_empty());
708 }
709
710 #[test]
711 fn test_imports_mixed_serde_and_others() {
712 let derives = vec![
713 "Serialize".to_string(),
714 "Hash".to_string(),
715 "Deserialize".to_string(),
716 ];
717 let result = imports_for_derives(&derives);
718 assert_eq!(result.len(), 1);
719 assert!(result[0].contains("Serialize"));
720 assert!(result[0].contains("Deserialize"));
721 }
722
723 #[test]
726 fn test_blank_lines_between_renamed_variants() {
727 let input = "pub enum Foo {\n #[sqlx(rename = \"a\")]\n A,\n #[sqlx(rename = \"b\")]\n B,\n}";
728 let result = add_blank_lines_between_items(input);
729 assert!(result.contains("A,\n\n #[sqlx(rename = \"b\")]"));
730 }
731
732 #[test]
733 fn test_no_blank_line_for_first_variant() {
734 let input = "pub enum Foo {\n #[sqlx(rename = \"a\")]\n A,\n}";
735 let result = add_blank_lines_between_items(input);
736 assert!(!result.contains("{\n\n"));
738 }
739
740 #[test]
741 fn test_no_change_without_rename() {
742 let input = "pub enum Foo {\n A,\n B,\n}";
743 let result = add_blank_lines_between_items(input);
744 assert_eq!(result, input);
745 }
746
747 #[test]
748 fn test_no_change_for_struct() {
749 let input = "pub struct Foo {\n pub a: i32,\n pub b: String,\n}";
750 let result = add_blank_lines_between_items(input);
751 assert_eq!(result, input);
752 }
753
754 #[test]
757 fn test_filter_single_file_strips_super_types() {
758 let mut imports = BTreeSet::new();
759 imports.insert("use super::types::Foo;".to_string());
760 imports.insert("use chrono::NaiveDateTime;".to_string());
761 let result = filter_imports(&imports, true);
762 assert!(!result.contains("use super::types::Foo;"));
763 assert!(result.contains("use chrono::NaiveDateTime;"));
764 }
765
766 #[test]
767 fn test_filter_single_file_keeps_other_imports() {
768 let mut imports = BTreeSet::new();
769 imports.insert("use chrono::NaiveDateTime;".to_string());
770 let result = filter_imports(&imports, true);
771 assert!(result.contains("use chrono::NaiveDateTime;"));
772 }
773
774 #[test]
775 fn test_filter_multi_file_keeps_all() {
776 let mut imports = BTreeSet::new();
777 imports.insert("use super::types::Foo;".to_string());
778 imports.insert("use chrono::NaiveDateTime;".to_string());
779 let result = filter_imports(&imports, false);
780 assert_eq!(result.len(), 2);
781 }
782
783 #[test]
784 fn test_filter_empty_set() {
785 let imports = BTreeSet::new();
786 let result = filter_imports(&imports, true);
787 assert!(result.is_empty());
788 }
789
790 fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
793 TableInfo {
794 schema_name: "public".to_string(),
795 name: name.to_string(),
796 columns,
797 }
798 }
799
800 fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
801 ColumnInfo {
802 name: name.to_string(),
803 data_type: udt_name.to_string(),
804 udt_name: udt_name.to_string(),
805 is_nullable: false,
806 is_primary_key: false,
807 ordinal_position: 0,
808 schema_name: "public".to_string(),
809 column_default: None,
810 }
811 }
812
813 #[test]
814 fn test_generate_empty_schema() {
815 let schema = SchemaInfo::default();
816 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
817 assert!(files.is_empty());
818 }
819
820 #[test]
821 fn test_generate_one_table() {
822 let schema = SchemaInfo {
823 tables: vec![make_table("users", vec![make_col("id", "int4")])],
824 ..Default::default()
825 };
826 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
827 assert_eq!(files.len(), 1);
828 assert_eq!(files[0].filename, "users.rs");
829 }
830
831 #[test]
832 fn test_generate_two_tables() {
833 let schema = SchemaInfo {
834 tables: vec![
835 make_table("users", vec![make_col("id", "int4")]),
836 make_table("posts", vec![make_col("id", "int4")]),
837 ],
838 ..Default::default()
839 };
840 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
841 assert_eq!(files.len(), 2);
842 }
843
844 #[test]
845 fn test_generate_enum_creates_types_file() {
846 let schema = SchemaInfo {
847 enums: vec![EnumInfo {
848 schema_name: "public".to_string(),
849 name: "status".to_string(),
850 variants: vec!["active".to_string(), "inactive".to_string()],
851 default_variant: None,
852 }],
853 ..Default::default()
854 };
855 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
856 assert_eq!(files.len(), 1);
857 assert_eq!(files[0].filename, "types.rs");
858 }
859
860 #[test]
861 fn test_generate_enums_composites_domains_single_types_file() {
862 let schema = SchemaInfo {
863 enums: vec![EnumInfo {
864 schema_name: "public".to_string(),
865 name: "status".to_string(),
866 variants: vec!["active".to_string()],
867 default_variant: None,
868 }],
869 composite_types: vec![CompositeTypeInfo {
870 schema_name: "public".to_string(),
871 name: "address".to_string(),
872 fields: vec![make_col("street", "text")],
873 }],
874 domains: vec![DomainInfo {
875 schema_name: "public".to_string(),
876 name: "email".to_string(),
877 base_type: "text".to_string(),
878 }],
879 ..Default::default()
880 };
881 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
882 let types_files: Vec<_> = files.iter().filter(|f| f.filename == "types.rs").collect();
884 assert_eq!(types_files.len(), 1);
885 }
886
887 #[test]
888 fn test_generate_tables_and_enums() {
889 let schema = SchemaInfo {
890 tables: vec![make_table("users", vec![make_col("id", "int4")])],
891 enums: vec![EnumInfo {
892 schema_name: "public".to_string(),
893 name: "status".to_string(),
894 variants: vec!["active".to_string()],
895 default_variant: None,
896 }],
897 ..Default::default()
898 };
899 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
900 assert_eq!(files.len(), 2); }
902
903 #[test]
904 fn test_generate_filename_normalized() {
905 let schema = SchemaInfo {
906 tables: vec![make_table("user__data", vec![make_col("id", "int4")])],
907 ..Default::default()
908 };
909 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
910 assert_eq!(files[0].filename, "user_data.rs");
911 }
912
913 #[test]
914 fn test_generate_no_origin_for_tables() {
915 let schema = SchemaInfo {
916 tables: vec![make_table("users", vec![make_col("id", "int4")])],
917 ..Default::default()
918 };
919 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
920 assert_eq!(files[0].origin, None);
921 }
922
923 #[test]
924 fn test_generate_types_no_origin() {
925 let schema = SchemaInfo {
926 enums: vec![EnumInfo {
927 schema_name: "public".to_string(),
928 name: "status".to_string(),
929 variants: vec!["active".to_string()],
930 default_variant: None,
931 }],
932 ..Default::default()
933 };
934 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
935 assert_eq!(files[0].origin, None);
936 }
937
938 #[test]
939 fn test_generate_single_file_filters_super_types_imports() {
940 let schema = SchemaInfo {
941 tables: vec![make_table("users", vec![make_col("id", "int4")])],
942 enums: vec![EnumInfo {
943 schema_name: "public".to_string(),
944 name: "status".to_string(),
945 variants: vec!["active".to_string()],
946 default_variant: None,
947 }],
948 ..Default::default()
949 };
950 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), true, TimeCrate::Chrono);
951 let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
953 assert!(!struct_file.code.contains("super::types::"));
954 }
955
956 #[test]
957 fn test_generate_multi_file_keeps_super_types_imports() {
958 let schema = SchemaInfo {
960 tables: vec![make_table("users", vec![make_col("status", "status")])],
961 enums: vec![EnumInfo {
962 schema_name: "public".to_string(),
963 name: "status".to_string(),
964 variants: vec!["active".to_string()],
965 default_variant: None,
966 }],
967 ..Default::default()
968 };
969 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
970 let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
971 assert!(struct_file.code.contains("super::types::"));
972 }
973
974 #[test]
975 fn test_generate_extra_derives_in_struct() {
976 let schema = SchemaInfo {
977 tables: vec![make_table("users", vec![make_col("id", "int4")])],
978 ..Default::default()
979 };
980 let derives = vec!["Serialize".to_string()];
981 let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false, TimeCrate::Chrono);
982 assert!(files[0].code.contains("Serialize"));
983 }
984
985 #[test]
986 fn test_generate_extra_derives_in_enum() {
987 let schema = SchemaInfo {
988 enums: vec![EnumInfo {
989 schema_name: "public".to_string(),
990 name: "status".to_string(),
991 variants: vec!["active".to_string()],
992 default_variant: None,
993 }],
994 ..Default::default()
995 };
996 let derives = vec!["Serialize".to_string()];
997 let files = generate(&schema, DatabaseKind::Postgres, &derives, &HashMap::new(), false, TimeCrate::Chrono);
998 assert!(files[0].code.contains("Serialize"));
999 }
1000
1001 #[test]
1002 fn test_generate_type_overrides_in_struct() {
1003 let mut overrides = HashMap::new();
1004 overrides.insert("jsonb".to_string(), "MyJson".to_string());
1005 let schema = SchemaInfo {
1006 tables: vec![make_table("users", vec![make_col("data", "jsonb")])],
1007 ..Default::default()
1008 };
1009 let files = generate(&schema, DatabaseKind::Postgres, &[], &overrides, false, TimeCrate::Chrono);
1010 assert!(files[0].code.contains("MyJson"));
1011 }
1012
1013 #[test]
1014 fn test_generate_valid_rust_syntax() {
1015 let schema = SchemaInfo {
1016 tables: vec![make_table("users", vec![
1017 make_col("id", "int4"),
1018 make_col("name", "text"),
1019 ])],
1020 enums: vec![EnumInfo {
1021 schema_name: "public".to_string(),
1022 name: "status".to_string(),
1023 variants: vec!["active".to_string(), "inactive".to_string()],
1024 default_variant: None,
1025 }],
1026 ..Default::default()
1027 };
1028 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1029 for f in &files {
1030 let parse_result = syn::parse_file(&f.code);
1032 assert!(parse_result.is_ok(), "Failed to parse {}: {:?}", f.filename, parse_result.err());
1033 }
1034 }
1035
1036 fn make_view(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
1039 TableInfo {
1040 schema_name: "public".to_string(),
1041 name: name.to_string(),
1042 columns,
1043 }
1044 }
1045
1046 #[test]
1047 fn test_generate_one_view() {
1048 let schema = SchemaInfo {
1049 views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1050 ..Default::default()
1051 };
1052 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1053 assert_eq!(files.len(), 1);
1054 assert_eq!(files[0].filename, "active_users.rs");
1055 }
1056
1057 #[test]
1058 fn test_generate_no_origin_for_views() {
1059 let schema = SchemaInfo {
1060 views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1061 ..Default::default()
1062 };
1063 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1064 assert_eq!(files[0].origin, None);
1065 }
1066
1067 #[test]
1068 fn test_generate_tables_and_views() {
1069 let schema = SchemaInfo {
1070 tables: vec![make_table("users", vec![make_col("id", "int4")])],
1071 views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1072 ..Default::default()
1073 };
1074 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1075 assert_eq!(files.len(), 2);
1076 }
1077
1078 #[test]
1079 fn test_generate_view_valid_rust() {
1080 let schema = SchemaInfo {
1081 views: vec![make_view("active_users", vec![
1082 make_col("id", "int4"),
1083 make_col("name", "text"),
1084 ])],
1085 ..Default::default()
1086 };
1087 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1088 let parse_result = syn::parse_file(&files[0].code);
1089 assert!(parse_result.is_ok(), "Failed to parse: {:?}", parse_result.err());
1090 }
1091
1092 #[test]
1093 fn test_generate_view_nullable_column() {
1094 let schema = SchemaInfo {
1095 views: vec![make_view("v", vec![ColumnInfo {
1096 name: "email".to_string(),
1097 data_type: "text".to_string(),
1098 udt_name: "text".to_string(),
1099 is_nullable: true,
1100 is_primary_key: false,
1101 ordinal_position: 0,
1102 schema_name: "public".to_string(),
1103 column_default: None,
1104 }])],
1105 ..Default::default()
1106 };
1107 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1108 assert!(files[0].code.contains("Option<String>"));
1109 }
1110
1111 #[test]
1112 fn test_generate_collision_both_prefixed() {
1113 let schema = SchemaInfo {
1114 tables: vec![
1115 make_table("users", vec![make_col("id", "int4")]),
1116 TableInfo {
1117 schema_name: "billing".to_string(),
1118 name: "users".to_string(),
1119 columns: vec![make_col("id", "int4")],
1120 },
1121 ],
1122 ..Default::default()
1123 };
1124 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1125 let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
1126 assert!(filenames.contains(&"users.rs"));
1127 assert!(filenames.contains(&"billing_users.rs"));
1128 }
1129
1130 #[test]
1131 fn test_generate_no_collision_no_prefix() {
1132 let schema = SchemaInfo {
1133 tables: vec![
1134 make_table("users", vec![make_col("id", "int4")]),
1135 TableInfo {
1136 schema_name: "billing".to_string(),
1137 name: "invoices".to_string(),
1138 columns: vec![make_col("id", "int4")],
1139 },
1140 ],
1141 ..Default::default()
1142 };
1143 let files = generate(&schema, DatabaseKind::Postgres, &[], &HashMap::new(), false, TimeCrate::Chrono);
1144 let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
1145 assert!(filenames.contains(&"users.rs"));
1146 assert!(filenames.contains(&"invoices.rs"));
1147 }
1148
/// With every table in the default schema, no module name gets a prefix.
/// The first two generated files follow the table input order.
#[test]
fn test_generate_single_schema_no_prefix() {
    let id_col = || vec![make_col("id", "int4")];
    let schema = SchemaInfo {
        tables: vec![make_table("users", id_col()), make_table("posts", id_col())],
        ..Default::default()
    };
    let files = generate(
        &schema,
        DatabaseKind::Postgres,
        &[],
        &HashMap::new(),
        false,
        TimeCrate::Chrono,
    );
    let names: Vec<&str> = files.iter().map(|file| file.filename.as_str()).collect();
    assert_eq!(names[0], "users.rs");
    assert_eq!(names[1], "posts.rs");
}
1162
/// In single-file mode, a schema with one table and one view yields exactly two files.
#[test]
fn test_generate_view_single_file_mode() {
    let tables = vec![make_table("users", vec![make_col("id", "int4")])];
    let views = vec![make_view("active_users", vec![make_col("id", "int4")])];
    let schema = SchemaInfo {
        tables,
        views,
        ..Default::default()
    };
    let single_file = true;
    let files = generate(
        &schema,
        DatabaseKind::Postgres,
        &[],
        &HashMap::new(),
        single_file,
        TimeCrate::Chrono,
    );
    assert_eq!(files.len(), 2);
}
1173
/// An unqualified `'literal'::type` cast yields the quoted literal.
#[test]
fn test_parse_pg_enum_default_simple() {
    let parsed = parse_pg_enum_default("'idle'::task_status");
    assert_eq!(parsed.as_deref(), Some("idle"));
}
1183
/// A cast to a schema-qualified type (`schema.type`) still yields the quoted literal.
#[test]
fn test_parse_pg_enum_default_schema_qualified() {
    let parsed = parse_pg_enum_default("'active'::public.task_status");
    assert_eq!(parsed.as_deref(), Some("active"));
}
1191
/// A function-call default such as `nextval(...)` is not treated as an enum default.
#[test]
fn test_parse_pg_enum_default_not_enum() {
    let parsed = parse_pg_enum_default("nextval('users_id_seq')");
    assert!(parsed.is_none(), "expected None, got {:?}", parsed);
}
1197
/// A quoted literal with no `::type` cast is rejected.
#[test]
fn test_parse_pg_enum_default_no_cast() {
    let parsed = parse_pg_enum_default("'hello'");
    assert!(parsed.is_none(), "expected None, got {:?}", parsed);
}
1202
/// An empty default expression yields no enum default.
#[test]
fn test_parse_pg_enum_default_empty() {
    let parsed = parse_pg_enum_default("");
    assert!(parsed.is_none(), "expected None, got {:?}", parsed);
}
1207
/// A column typed with a known enum and defaulting to `'idle'::task_status`
/// records `idle` as that enum's default, keyed by the enum name.
#[test]
fn test_extract_enum_defaults_from_column() {
    let status_column = ColumnInfo {
        name: "status".to_string(),
        data_type: "USER-DEFINED".to_string(),
        udt_name: "task_status".to_string(),
        schema_name: "public".to_string(),
        ordinal_position: 0,
        is_nullable: false,
        is_primary_key: false,
        column_default: Some("'idle'::task_status".to_string()),
    };
    let tasks = TableInfo {
        schema_name: "public".to_string(),
        name: "tasks".to_string(),
        columns: vec![status_column],
    };
    let task_status = EnumInfo {
        schema_name: "public".to_string(),
        name: "task_status".to_string(),
        variants: vec!["idle".to_string(), "running".to_string()],
        default_variant: None,
    };
    let schema = SchemaInfo {
        tables: vec![tasks],
        enums: vec![task_status],
        ..Default::default()
    };

    let defaults = extract_enum_defaults(&schema);
    assert_eq!(
        defaults.get("task_status").map(String::as_str),
        Some("idle")
    );
}
1238
/// An enum-typed column with no column default contributes no entry.
#[test]
fn test_extract_enum_defaults_no_default() {
    let status_column = ColumnInfo {
        name: "status".to_string(),
        data_type: "USER-DEFINED".to_string(),
        udt_name: "task_status".to_string(),
        schema_name: "public".to_string(),
        ordinal_position: 0,
        is_nullable: false,
        is_primary_key: false,
        column_default: None,
    };
    let tasks = TableInfo {
        schema_name: "public".to_string(),
        name: "tasks".to_string(),
        columns: vec![status_column],
    };
    let task_status = EnumInfo {
        schema_name: "public".to_string(),
        name: "task_status".to_string(),
        variants: vec!["idle".to_string()],
        default_variant: None,
    };
    let schema = SchemaInfo {
        tables: vec![tasks],
        enums: vec![task_status],
        ..Default::default()
    };

    let defaults = extract_enum_defaults(&schema);
    assert_eq!(defaults.len(), 0);
}
1267
/// A cast-style default on a non-enum column (here `varchar`) is ignored when
/// the schema declares no enums.
#[test]
fn test_extract_enum_defaults_non_enum_column_ignored() {
    let name_column = ColumnInfo {
        name: "name".to_string(),
        data_type: "character varying".to_string(),
        udt_name: "varchar".to_string(),
        schema_name: "public".to_string(),
        ordinal_position: 0,
        is_nullable: false,
        is_primary_key: false,
        column_default: Some("'hello'::character varying".to_string()),
    };
    let users = TableInfo {
        schema_name: "public".to_string(),
        name: "users".to_string(),
        columns: vec![name_column],
    };
    let schema = SchemaInfo {
        tables: vec![users],
        enums: Vec::new(),
        ..Default::default()
    };

    let defaults = extract_enum_defaults(&schema);
    assert_eq!(defaults.len(), 0);
}
1291
/// When a column defaults to an enum literal, the generated `types.rs` contains
/// an `impl Default` for that enum that returns the matching variant.
#[test]
fn test_generate_enum_with_default() {
    let schema = SchemaInfo {
        tables: vec![TableInfo {
            schema_name: "public".to_string(),
            name: "tasks".to_string(),
            columns: vec![ColumnInfo {
                name: "status".to_string(),
                data_type: "USER-DEFINED".to_string(),
                udt_name: "task_status".to_string(),
                is_nullable: false,
                is_primary_key: false,
                ordinal_position: 0,
                schema_name: "public".to_string(),
                column_default: Some("'idle'::task_status".to_string()),
            }],
        }],
        enums: vec![EnumInfo {
            schema_name: "public".to_string(),
            name: "task_status".to_string(),
            variants: vec!["idle".to_string(), "running".to_string()],
            default_variant: None,
        }],
        ..Default::default()
    };
    let files = generate(
        &schema,
        DatabaseKind::Postgres,
        &[],
        &HashMap::new(),
        false,
        TimeCrate::Chrono,
    );
    // Name the missing file on failure instead of a bare `unwrap` on `None`.
    let types_file = files
        .iter()
        .find(|f| f.filename == "types.rs")
        .expect("types.rs should be generated when the schema declares enums");
    assert!(
        types_file.code.contains("impl Default for TaskStatus"),
        "missing Default impl in types.rs:\n{}",
        types_file.code
    );
    assert!(
        types_file.code.contains("Self::Idle"),
        "Default impl should return Self::Idle in types.rs:\n{}",
        types_file.code
    );
}
1322}