1pub mod composite_gen;
2pub mod crud_gen;
3pub mod domain_gen;
4pub mod entity_parser;
5pub mod enum_gen;
6pub mod identifiers;
7pub mod struct_gen;
8
9use std::collections::{BTreeSet, HashMap};
10use std::path::Path;
11
12use proc_macro2::TokenStream;
13
14use crate::cli::{DatabaseKind, TimeCrate};
15use crate::introspect::SchemaInfo;
16
/// Words that cannot be used as plain Rust identifiers: the strict keywords plus the words
/// reserved for future use, including `gen` (reserved since the 2024 edition).
///
/// Matching is case-sensitive, which is why `Self` is listed separately from `self`.
const RUST_KEYWORDS: &[&str] = &[
    "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum", "extern",
    "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move", "mut", "pub",
    "ref", "return", "self", "Self", "static", "struct", "super", "trait", "true", "type",
    "unsafe", "use", "where", "while", "yield", "abstract", "become", "box", "do", "final", "gen",
    "macro", "override", "priv", "try", "typeof", "unsized", "virtual",
];

/// Returns `true` when `name` is a Rust keyword (strict or reserved) and therefore cannot be
/// emitted verbatim as an identifier in generated code.
///
/// The comparison is exact and case-sensitive: `"type"` is a keyword, `"Type"` is not.
pub fn is_rust_keyword(name: &str) -> bool {
    RUST_KEYWORDS.contains(&name)
}
30
/// Builds the `use` lines required by the user-requested extra derives.
///
/// Only serde's `Serialize` / `Deserialize` need an import. When either is present a single
/// grouped `use serde::{...};` line is returned, listing `Serialize` before `Deserialize`.
/// Any other derive contributes nothing, so the result is empty or has exactly one element.
pub fn imports_for_derives(extra_derives: &[String]) -> Vec<String> {
    let requested: Vec<&str> = ["Serialize", "Deserialize"]
        .iter()
        .copied()
        .filter(|wanted| extra_derives.iter().any(|d| d.as_str() == *wanted))
        .collect();

    if requested.is_empty() {
        return Vec::new();
    }
    vec![format!("use serde::{{{}}};", requested.join(", "))]
}
47
/// Collapses every run of consecutive underscores in `name` into a single `_`.
///
/// All other characters, including leading and trailing underscores, are kept as-is:
/// `"a__b"` becomes `"a_b"`, `"__x"` becomes `"_x"` and `"name_"` stays `"name_"`.
pub fn normalize_module_name(name: &str) -> String {
    let mut collapsed = String::with_capacity(name.len());
    for ch in name.chars() {
        // An underscore directly after an emitted underscore belongs to the same run.
        let continues_run = ch == '_' && collapsed.ends_with('_');
        if !continues_run {
            collapsed.push(ch);
        }
    }
    collapsed
}
66
/// Schema names treated as the database's default namespace: `public` (PostgreSQL),
/// `main` (SQLite) and `dbo` (SQL Server).
const DEFAULT_SCHEMAS: &[&str] = &["public", "main", "dbo"];

/// Reports whether `schema` is one of the default schemas listed in `DEFAULT_SCHEMAS`.
///
/// The comparison is exact and case-sensitive. Names from these schemas are never given
/// a schema prefix by the naming helpers in this module.
pub fn is_default_schema(schema: &str) -> bool {
    DEFAULT_SCHEMAS.iter().any(|&default| default == schema)
}
74
75pub fn rust_type_name_for(schema_info: &SchemaInfo, schema: &str, name: &str) -> String {
82 use heck::ToUpperCamelCase;
83 if type_name_has_cross_schema_collision(schema_info, name) && !is_default_schema(schema) {
84 format!(
85 "{}{}",
86 schema.to_upper_camel_case(),
87 name.to_upper_camel_case()
88 )
89 } else {
90 name.to_upper_camel_case()
91 }
92}
93
94pub fn required_pg_search_path(schema_info: &SchemaInfo) -> Vec<String> {
107 let mut schemas: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
108 for e in &schema_info.enums {
109 if !is_default_schema(&e.schema_name) {
110 schemas.insert(e.schema_name.clone());
111 }
112 }
113 for c in &schema_info.composite_types {
114 if !is_default_schema(&c.schema_name) {
115 schemas.insert(c.schema_name.clone());
116 }
117 }
118 for d in &schema_info.domains {
119 if !is_default_schema(&d.schema_name) {
120 schemas.insert(d.schema_name.clone());
121 }
122 }
123 schemas.into_iter().collect()
124}
125
126pub fn type_name_has_cross_schema_collision(schema_info: &SchemaInfo, name: &str) -> bool {
129 let mut schemas: std::collections::BTreeSet<&str> = std::collections::BTreeSet::new();
130 schemas.extend(
131 schema_info
132 .enums
133 .iter()
134 .filter(|e| e.name == name)
135 .map(|e| e.schema_name.as_str()),
136 );
137 schemas.extend(
138 schema_info
139 .composite_types
140 .iter()
141 .filter(|c| c.name == name)
142 .map(|c| c.schema_name.as_str()),
143 );
144 schemas.extend(
145 schema_info
146 .domains
147 .iter()
148 .filter(|d| d.name == name)
149 .map(|d| d.schema_name.as_str()),
150 );
151 schemas.len() > 1
152}
153
154pub fn build_module_name(schema_name: &str, table_name: &str, name_collides: bool) -> String {
157 if name_collides && !is_default_schema(schema_name) {
158 normalize_module_name(&format!("{}_{}", schema_name, table_name))
159 } else {
160 normalize_module_name(table_name)
161 }
162}
163
164fn find_colliding_names(schema_info: &SchemaInfo) -> BTreeSet<&str> {
166 let mut seen: HashMap<&str, BTreeSet<&str>> = HashMap::new();
167 for t in &schema_info.tables {
168 seen.entry(t.name.as_str())
169 .or_default()
170 .insert(t.schema_name.as_str());
171 }
172 for v in &schema_info.views {
173 seen.entry(v.name.as_str())
174 .or_default()
175 .insert(v.schema_name.as_str());
176 }
177 seen.into_iter()
178 .filter(|(_, schemas)| schemas.len() > 1)
179 .map(|(name, _)| name)
180 .collect()
181}
182
/// One source file produced by the generator: a file name plus its formatted contents.
#[derive(Debug, Clone)]
pub struct GeneratedFile {
    /// Bare file name such as `users.rs` or `types.rs`; presumably written relative to the
    /// chosen output directory by the caller — confirm.
    pub filename: String,
    /// Optional origin annotation. Every file built in this module sets it to `None`;
    /// NOTE(review): its meaning is defined by code outside this view — confirm.
    pub origin: Option<String>,
    /// Formatted Rust source code for the file.
    pub code: String,
}
191
192pub fn generate(
194 schema_info: &SchemaInfo,
195 db_kind: DatabaseKind,
196 extra_derives: &[String],
197 type_overrides: &HashMap<String, String>,
198 single_file: bool,
199 time_crate: TimeCrate,
200) -> crate::error::Result<Vec<GeneratedFile>> {
201 generate_with_domain_style(
202 schema_info,
203 db_kind,
204 extra_derives,
205 type_overrides,
206 single_file,
207 time_crate,
208 crate::cli::DomainStyle::Alias,
209 )
210}
211
212pub fn generate_with_domain_style(
215 schema_info: &SchemaInfo,
216 db_kind: DatabaseKind,
217 extra_derives: &[String],
218 type_overrides: &HashMap<String, String>,
219 single_file: bool,
220 time_crate: TimeCrate,
221 domain_style: crate::cli::DomainStyle,
222) -> crate::error::Result<Vec<GeneratedFile>> {
223 let mut files = Vec::new();
224
225 let colliding_names = find_colliding_names(schema_info);
227
228 for table in &schema_info.tables {
230 let (tokens, imports) = struct_gen::generate_struct(
231 table,
232 db_kind,
233 schema_info,
234 extra_derives,
235 type_overrides,
236 false,
237 time_crate,
238 );
239 let imports = filter_imports(&imports, single_file);
240 let code = format_tokens_with_imports(&tokens, &imports)?;
241 let module_name = build_module_name(
242 &table.schema_name,
243 &table.name,
244 colliding_names.contains(table.name.as_str()),
245 );
246 files.push(GeneratedFile {
247 filename: format!("{}.rs", module_name),
248 origin: None,
249 code,
250 });
251 }
252
253 for view in &schema_info.views {
255 let (tokens, imports) = struct_gen::generate_struct(
256 view,
257 db_kind,
258 schema_info,
259 extra_derives,
260 type_overrides,
261 true,
262 time_crate,
263 );
264 let imports = filter_imports(&imports, single_file);
265 let code = format_tokens_with_imports(&tokens, &imports)?;
266 let module_name = build_module_name(
267 &view.schema_name,
268 &view.name,
269 colliding_names.contains(view.name.as_str()),
270 );
271 files.push(GeneratedFile {
272 filename: format!("{}.rs", module_name),
273 origin: None,
274 code,
275 });
276 }
277
278 let mut types_blocks: Vec<String> = Vec::new();
281 let mut types_imports = BTreeSet::new();
282
283 let enum_defaults = extract_enum_defaults(schema_info);
285 for enum_info in &schema_info.enums {
286 enum_gen::check_variant_collisions(enum_info)?;
287 let mut enriched = enum_info.clone();
288 if enriched.default_variant.is_none() {
289 if let Some(default) = enum_defaults.get(&enum_info.name) {
290 enriched.default_variant = Some(default.clone());
291 }
292 }
293 let (tokens, imports) =
294 enum_gen::generate_enum_with_schema(&enriched, db_kind, extra_derives, schema_info);
295 types_blocks.push(format_tokens(&tokens)?);
296 types_imports.extend(imports);
297 }
298
299 for composite in &schema_info.composite_types {
300 let (tokens, imports) = composite_gen::generate_composite(
301 composite,
302 db_kind,
303 schema_info,
304 extra_derives,
305 type_overrides,
306 time_crate,
307 );
308 types_blocks.push(format_tokens(&tokens)?);
309 types_imports.extend(imports);
310 }
311
312 for domain in &schema_info.domains {
313 let (tokens, imports) = domain_gen::generate_domain_with_style(
314 domain,
315 db_kind,
316 schema_info,
317 type_overrides,
318 time_crate,
319 domain_style,
320 );
321 types_blocks.push(format_tokens(&tokens)?);
322 types_imports.extend(imports);
323 }
324
325 if !types_blocks.is_empty() {
326 let import_lines: String = types_imports.iter().map(|i| format!("{}\n", i)).collect();
327 let body = types_blocks.join("\n");
328 let code = if import_lines.is_empty() {
329 body
330 } else {
331 format!("{}\n\n{}", import_lines.trim_end(), body)
332 };
333 files.push(GeneratedFile {
334 filename: "types.rs".to_string(),
335 origin: None,
336 code,
337 });
338 }
339
340 Ok(files)
341}
342
343fn extract_enum_defaults(schema_info: &SchemaInfo) -> HashMap<String, String> {
346 let mut defaults: HashMap<String, String> = HashMap::new();
347
348 let all_columns = schema_info
349 .tables
350 .iter()
351 .chain(schema_info.views.iter())
352 .flat_map(|t| t.columns.iter());
353
354 for col in all_columns {
355 let default_expr = match &col.column_default {
356 Some(d) => d,
357 None => continue,
358 };
359
360 let base_udt = col.udt_name.strip_prefix('_').unwrap_or(&col.udt_name);
362
363 let enum_match = schema_info.enums.iter().find(|e| e.name == base_udt);
365 if enum_match.is_none() {
366 continue;
367 }
368
369 if let Some(variant) = parse_pg_enum_default(default_expr) {
371 defaults.entry(base_udt.to_string()).or_insert(variant);
372 }
373 }
374
375 defaults
376}
377
/// Extracts the label from a Postgres default expression of the form `'label'::type`.
///
/// Surrounding whitespace is ignored. The expression must start with a single-quoted
/// literal that is immediately followed by a `::` cast; otherwise the result is `None`.
/// Doubled quotes inside the literal (`''`) are the SQL escape for a single quote and are
/// unescaped, so `'it''s'::mood` yields `it's`.
fn parse_pg_enum_default(default_expr: &str) -> Option<String> {
    let body = default_expr.trim().strip_prefix('\'')?;
    let mut value = String::new();
    let mut chars = body.char_indices().peekable();

    while let Some((idx, c)) = chars.next() {
        if c != '\'' {
            value.push(c);
            continue;
        }
        // `''` is an escaped quote, not the end of the literal.
        if let Some(&(_, '\'')) = chars.peek() {
            chars.next();
            value.push('\'');
            continue;
        }
        // Closing quote: the literal only counts if it is cast with `::`.
        let rest = &body[idx + c.len_utf8()..];
        return if rest.starts_with("::") {
            Some(value)
        } else {
            None
        };
    }

    // Unterminated literal.
    None
}
391
/// Drops `super::types::` imports when generating in single-file mode.
///
/// In multi-file mode the set is returned unchanged (as a clone). In single-file mode every
/// import whose text contains `super::types::` is removed. Presumably the types are emitted
/// alongside the structs in that mode, so the path would not resolve — confirm.
fn filter_imports(imports: &BTreeSet<String>, single_file: bool) -> BTreeSet<String> {
    if !single_file {
        return imports.clone();
    }
    imports
        .iter()
        .filter(|import| !import.contains("super::types::"))
        .cloned()
        .collect()
}
404
/// Finds the `tab_spaces` setting that a rustfmt config near `start_dir` specifies.
///
/// The search starts at `start_dir`, or at its parent when `start_dir` is a file. It walks
/// up the directory tree, looking for `rustfmt.toml` and then `.rustfmt.toml` in each
/// directory. The first readable config file decides the result: its `tab_spaces` value if
/// it sets one, otherwise 4. Returns 4 when no config file is found up to the root.
pub fn detect_tab_spaces(start_dir: &Path) -> usize {
    const DEFAULT_TAB_SPACES: usize = 4;

    let mut dir = if start_dir.is_file() {
        start_dir.parent().unwrap_or(start_dir)
    } else {
        start_dir
    };
    loop {
        for name in &["rustfmt.toml", ".rustfmt.toml"] {
            if let Ok(content) = std::fs::read_to_string(dir.join(name)) {
                // The nearest config file wins, even when it does not set tab_spaces.
                return parse_tab_spaces(&content).unwrap_or(DEFAULT_TAB_SPACES);
            }
        }
        match dir.parent() {
            Some(parent) => dir = parent,
            None => return DEFAULT_TAB_SPACES,
        }
    }
}

/// Extracts an integer `tab_spaces = N` setting from rustfmt TOML text, if present.
///
/// Surrounding whitespace is tolerated, as is a trailing `#` comment on the same line.
fn parse_tab_spaces(content: &str) -> Option<usize> {
    content.lines().find_map(|line| {
        // Drop any TOML comment; an integer value can never contain '#'.
        let line = line.split('#').next().unwrap_or("").trim();
        let rest = line.strip_prefix("tab_spaces")?;
        let rest = rest.trim_start().strip_prefix('=').unwrap_or(rest);
        rest.trim().parse::<usize>().ok()
    })
}
436
437pub(crate) fn parse_and_format(tokens: &TokenStream) -> crate::error::Result<String> {
440 parse_and_format_with_tab_spaces(tokens, 4)
441}
442
443pub(crate) fn parse_and_format_with_tab_spaces(
444 tokens: &TokenStream,
445 tab_spaces: usize,
446) -> crate::error::Result<String> {
447 let file = syn::parse2::<syn::File>(tokens.clone()).map_err(|e| {
448 crate::error::Error::Config(format!(
449 "Internal sqlx-gen bug: failed to parse generated code: {}. \
450 Raw tokens:\n {}\n\
451 Please report this with the input schema.",
452 e, tokens
453 ))
454 })?;
455 let raw = prettyplease::unparse(&file);
456 let raw = indent_multiline_raw_strings(&raw, tab_spaces);
457 Ok(add_blank_lines_between_items(&raw))
458}
459
460pub(crate) fn format_tokens(tokens: &TokenStream) -> crate::error::Result<String> {
462 parse_and_format(tokens)
463}
464
465pub fn format_tokens_with_imports(
466 tokens: &TokenStream,
467 imports: &BTreeSet<String>,
468) -> crate::error::Result<String> {
469 format_tokens_with_imports_and_tab_spaces(tokens, imports, 4)
470}
471
472pub fn format_tokens_with_imports_and_tab_spaces(
473 tokens: &TokenStream,
474 imports: &BTreeSet<String>,
475 tab_spaces: usize,
476) -> crate::error::Result<String> {
477 let formatted = parse_and_format_with_tab_spaces(tokens, tab_spaces)?;
478
479 let used_imports: Vec<&String> = imports
480 .iter()
481 .filter(|imp| is_import_used(imp, &formatted))
482 .collect();
483
484 if used_imports.is_empty() {
485 Ok(formatted)
486 } else {
487 let import_lines: String = used_imports.iter().map(|i| format!("{}\n", i)).collect();
488 Ok(format!("{}\n\n{}", import_lines.trim_end(), formatted))
489 }
490}
491
/// Heuristically decides whether the `use` line `import` is needed by `code`.
///
/// The check is textual and errs on the side of keeping imports. An import is kept when any
/// name it brings into scope occurs as a substring of `code`, so a longer identifier
/// containing the name also counts. Glob imports (`::*`) are always kept. For grouped
/// imports (`a::{B, c::D, E as F}`), each entry is checked separately, by its last path
/// segment or its alias.
fn is_import_used(import: &str, code: &str) -> bool {
    let trimmed = import.trim().trim_end_matches(';');
    let path = trimmed.strip_prefix("use ").unwrap_or(trimmed);

    if path.ends_with("::*") {
        return true;
    }

    // Use the outermost braces so nested groups stay inside the list being split.
    if let (Some(start), Some(end)) = (path.find('{'), path.rfind('}')) {
        if start < end {
            return path[start + 1..end]
                .split(',')
                .map(str::trim)
                .filter(|entry| !entry.is_empty())
                .any(|entry| import_entry_is_used(entry, code));
        }
    }

    import_entry_is_used(path, code)
}

/// Decides whether one import entry (`Name`, `path::Name`, `Name as Alias`, `*`, `self`)
/// appears to be referenced in `code`. Entries with no searchable name are treated as used.
fn import_entry_is_used(entry: &str, code: &str) -> bool {
    // `Name as Alias` brings `Alias`, not `Name`, into scope.
    let entry = entry.rsplit_once(" as ").map_or(entry, |(_, alias)| alias);
    // Last path segment, minus stray braces left over from splitting nested groups.
    let name = entry
        .rsplit("::")
        .next()
        .unwrap_or(entry)
        .trim_matches(|c: char| c == '{' || c == '}' || c.is_whitespace());
    match name {
        // Globs, `self` and anonymous trait imports (`as _`) cannot be checked by name.
        "*" | "self" | "_" => true,
        _ => code.contains(name),
    }
}
524
/// Re-indents the contents of multi-line raw string literals (`r#"` ... `"#`).
///
/// prettyplease prints raw string contents verbatim, so SQL embedded in generated code keeps
/// whatever indentation it was built with. For each raw string spanning several lines, this:
/// * removes the common leading indentation of its content lines;
/// * re-emits them at `4 + 2 * tab_spaces` columns, preserving relative indentation;
/// * turns whitespace-only content lines into empty lines;
/// * places the closing `"#...` line at `4 + tab_spaces` columns.
///
/// Note that this rewrites whitespace *inside* the literal. The fixed base of 4 presumably
/// matches the nesting depth at which generated queries appear — NOTE(review): confirm
/// against the code generators.
///
/// A literal whose closing `"#` is not the first non-blank text on its line, or that is never
/// closed, is emitted exactly as printed. The output joins lines with `\n` and has no
/// trailing newline.
fn indent_multiline_raw_strings(code: &str, tab_spaces: usize) -> String {
    let close_indent = 4 + tab_spaces;
    let sql_indent = 4 + 2 * tab_spaces;
    let mut result: Vec<String> = Vec::new();
    let mut inside_raw = false;
    let mut raw_lines: Vec<&str> = Vec::new();

    for line in code.lines() {
        if !inside_raw {
            if let Some(pos) = line.find("r#\"") {
                // Literals opened and closed on the same line are left alone.
                if !line[pos + 3..].contains("\"#") {
                    inside_raw = true;
                    raw_lines.clear();
                }
            }
            result.push(line.to_string());
        } else if line.trim_start().starts_with("\"#") {
            let min_indent = raw_lines
                .iter()
                .filter(|l| !l.trim().is_empty())
                .map(|l| l.len() - l.trim_start().len())
                .min()
                .unwrap_or(0);
            for raw_line in &raw_lines {
                let trimmed = raw_line.trim();
                if trimmed.is_empty() {
                    result.push(String::new());
                } else {
                    let original_indent = raw_line.len() - raw_line.trim_start().len();
                    let relative = original_indent.saturating_sub(min_indent);
                    result.push(format!(
                        "{}{}{}",
                        " ".repeat(sql_indent),
                        " ".repeat(relative),
                        trimmed
                    ));
                }
            }
            // The closing line carries `"#` plus whatever follows it (e.g. `"#,` or `"#)`).
            result.push(format!("{}{}", " ".repeat(close_indent), line.trim()));
            inside_raw = false;
        } else if line.contains("\"#") {
            // Inside `r#"..."#` any `"#` terminates the literal. Here it ends mid-line, which
            // cannot be re-indented without splitting the line, so keep the literal verbatim.
            result.extend(raw_lines.iter().map(|l| l.to_string()));
            result.push(line.to_string());
            inside_raw = false;
        } else {
            raw_lines.push(line);
        }
    }

    if inside_raw {
        // Unclosed literal: emit the buffered lines unchanged instead of dropping them.
        result.extend(raw_lines.iter().map(|l| l.to_string()));
    }

    result.join("\n")
}
587
/// Inserts blank lines into formatted code where vertical separation improves readability.
///
/// Lines are compared in trimmed form, and at most one empty line is inserted before a line
/// even if several rules match. See `needs_blank_line_between` for the rules. Lines are
/// joined with `\n`; a trailing newline in the input is not preserved.
fn add_blank_lines_between_items(code: &str) -> String {
    let mut result: Vec<&str> = Vec::new();
    let mut prev: Option<&str> = None;

    for line in code.lines() {
        if let Some(prev_line) = prev {
            if needs_blank_line_between(prev_line.trim(), line.trim()) {
                result.push("");
            }
        }
        result.push(line);
        prev = Some(line);
    }

    result.join("\n")
}

/// Returns whether an empty line belongs between the trimmed lines `prev` and `current`.
///
/// A blank line is needed when any of these holds:
/// * `current` starts with `#[sqlx(rename` and `prev` ends with `,`. This separates renamed
///   enum variants, but not the first one after `{`.
/// * `prev` is exactly `}` and `current` starts an item (`pub struct`, `impl `, `#[derive`,
///   `pub async fn`, `pub fn`).
/// * `prev` ends an awaited/defaulted statement and `current` starts with `let ` or `Ok(`.
///   Such a `prev` ends with `.await?;` or `.await?`, or is a `;`-terminated line containing
///   `.unwrap_or(`.
/// * `current` is a `let` mentioning `sqlx::` and `prev` is a `let` that does not.
fn needs_blank_line_between(prev: &str, current: &str) -> bool {
    let renamed_variant = current.starts_with("#[sqlx(rename") && prev.ends_with(',');

    let item_after_block = prev == "}"
        && ["pub struct", "impl ", "#[derive", "pub async fn", "pub fn"]
            .iter()
            .any(|start| current.starts_with(*start));

    let prev_is_await_end = prev.ends_with(".await?;")
        || prev.ends_with(".await?")
        || (prev.ends_with(';') && prev.contains(".unwrap_or("));
    let after_await =
        prev_is_await_end && (current.starts_with("let ") || current.starts_with("Ok("));

    let query_after_plain_let = current.starts_with("let ")
        && current.contains("sqlx::")
        && prev.starts_with("let ")
        && !prev.contains("sqlx::");

    renamed_variant || item_after_block || after_await || query_after_plain_let
}
645
646#[cfg(test)]
647mod tests {
648 use super::*;
649 use crate::introspect::{
650 ColumnInfo, CompositeTypeInfo, DomainInfo, EnumInfo, SchemaInfo, TableInfo,
651 };
652 use std::collections::HashMap;
653
654 #[test]
657 fn test_keyword_type() {
658 assert!(is_rust_keyword("type"));
659 }
660
661 #[test]
662 fn test_keyword_fn() {
663 assert!(is_rust_keyword("fn"));
664 }
665
666 #[test]
667 fn test_keyword_let() {
668 assert!(is_rust_keyword("let"));
669 }
670
671 #[test]
672 fn test_keyword_match() {
673 assert!(is_rust_keyword("match"));
674 }
675
676 #[test]
677 fn test_keyword_async() {
678 assert!(is_rust_keyword("async"));
679 }
680
681 #[test]
682 fn test_keyword_await() {
683 assert!(is_rust_keyword("await"));
684 }
685
686 #[test]
687 fn test_keyword_yield() {
688 assert!(is_rust_keyword("yield"));
689 }
690
691 #[test]
692 fn test_keyword_abstract() {
693 assert!(is_rust_keyword("abstract"));
694 }
695
696 #[test]
697 fn test_keyword_try() {
698 assert!(is_rust_keyword("try"));
699 }
700
701 #[test]
702 fn test_not_keyword_name() {
703 assert!(!is_rust_keyword("name"));
704 }
705
706 #[test]
707 fn test_not_keyword_id() {
708 assert!(!is_rust_keyword("id"));
709 }
710
711 #[test]
712 fn test_not_keyword_uppercase_type() {
713 assert!(!is_rust_keyword("Type"));
714 }
715
716 #[test]
719 fn test_normalize_no_underscores() {
720 assert_eq!(normalize_module_name("users"), "users");
721 }
722
723 #[test]
724 fn test_normalize_single_underscore() {
725 assert_eq!(normalize_module_name("user_roles"), "user_roles");
726 }
727
728 #[test]
729 fn test_normalize_double_underscore() {
730 assert_eq!(normalize_module_name("user__roles"), "user_roles");
731 }
732
733 #[test]
734 fn test_normalize_triple_underscore() {
735 assert_eq!(normalize_module_name("a___b"), "a_b");
736 }
737
738 #[test]
739 fn test_normalize_leading_underscore() {
740 assert_eq!(normalize_module_name("_private"), "_private");
741 }
742
743 #[test]
744 fn test_normalize_trailing_underscore() {
745 assert_eq!(normalize_module_name("name_"), "name_");
746 }
747
748 #[test]
749 fn test_normalize_double_leading() {
750 assert_eq!(normalize_module_name("__double_leading"), "_double_leading");
751 }
752
753 #[test]
754 fn test_normalize_multiple_groups() {
755 assert_eq!(normalize_module_name("a__b__c"), "a_b_c");
756 }
757
758 #[test]
761 fn test_build_no_collision_no_prefix() {
762 assert_eq!(build_module_name("public", "users", false), "users");
763 }
764
765 #[test]
766 fn test_build_no_collision_non_default_no_prefix() {
767 assert_eq!(build_module_name("billing", "invoices", false), "invoices");
768 }
769
770 #[test]
771 fn test_build_collision_prefixed() {
772 assert_eq!(build_module_name("billing", "users", true), "billing_users");
773 }
774
775 #[test]
776 fn test_build_collision_default_schema_no_prefix() {
777 assert_eq!(build_module_name("public", "users", true), "users");
778 }
779
780 #[test]
781 fn test_build_collision_normalizes_double_underscore() {
782 assert_eq!(
783 build_module_name("billing", "agent__connector", true),
784 "billing_agent_connector"
785 );
786 }
787
788 #[test]
791 fn test_default_schema_public() {
792 assert!(is_default_schema("public"));
793 }
794
795 #[test]
796 fn test_default_schema_main() {
797 assert!(is_default_schema("main"));
798 }
799
800 #[test]
801 fn test_non_default_schema() {
802 assert!(!is_default_schema("billing"));
803 }
804
805 #[test]
808 fn test_imports_empty() {
809 let result = imports_for_derives(&[]);
810 assert!(result.is_empty());
811 }
812
813 #[test]
814 fn test_imports_serialize_only() {
815 let derives = vec!["Serialize".to_string()];
816 let result = imports_for_derives(&derives);
817 assert_eq!(result, vec!["use serde::{Serialize};"]);
818 }
819
820 #[test]
821 fn test_imports_deserialize_only() {
822 let derives = vec!["Deserialize".to_string()];
823 let result = imports_for_derives(&derives);
824 assert_eq!(result, vec!["use serde::{Deserialize};"]);
825 }
826
827 #[test]
828 fn test_imports_both_serde() {
829 let derives = vec!["Serialize".to_string(), "Deserialize".to_string()];
830 let result = imports_for_derives(&derives);
831 assert_eq!(result, vec!["use serde::{Serialize, Deserialize};"]);
832 }
833
834 #[test]
835 fn test_imports_non_serde() {
836 let derives = vec!["Hash".to_string()];
837 let result = imports_for_derives(&derives);
838 assert!(result.is_empty());
839 }
840
841 #[test]
842 fn test_imports_non_serde_multiple() {
843 let derives = vec!["PartialEq".to_string(), "Eq".to_string()];
844 let result = imports_for_derives(&derives);
845 assert!(result.is_empty());
846 }
847
848 #[test]
849 fn test_imports_mixed_serde_and_others() {
850 let derives = vec![
851 "Serialize".to_string(),
852 "Hash".to_string(),
853 "Deserialize".to_string(),
854 ];
855 let result = imports_for_derives(&derives);
856 assert_eq!(result.len(), 1);
857 assert!(result[0].contains("Serialize"));
858 assert!(result[0].contains("Deserialize"));
859 }
860
861 #[test]
864 fn test_blank_lines_between_renamed_variants() {
865 let input = "pub enum Foo {\n #[sqlx(rename = \"a\")]\n A,\n #[sqlx(rename = \"b\")]\n B,\n}";
866 let result = add_blank_lines_between_items(input);
867 assert!(result.contains("A,\n\n #[sqlx(rename = \"b\")]"));
868 }
869
870 #[test]
871 fn test_no_blank_line_for_first_variant() {
872 let input = "pub enum Foo {\n #[sqlx(rename = \"a\")]\n A,\n}";
873 let result = add_blank_lines_between_items(input);
874 assert!(!result.contains("{\n\n"));
876 }
877
878 #[test]
879 fn test_no_change_without_rename() {
880 let input = "pub enum Foo {\n A,\n B,\n}";
881 let result = add_blank_lines_between_items(input);
882 assert_eq!(result, input);
883 }
884
885 #[test]
886 fn test_no_change_for_struct() {
887 let input = "pub struct Foo {\n pub a: i32,\n pub b: String,\n}";
888 let result = add_blank_lines_between_items(input);
889 assert_eq!(result, input);
890 }
891
892 fn schema_with_two_role_enums() -> SchemaInfo {
895 SchemaInfo {
896 enums: vec![
897 crate::introspect::EnumInfo {
898 schema_name: "auth".into(),
899 name: "role".into(),
900 variants: vec!["admin".into(), "user".into()],
901 default_variant: None,
902 },
903 crate::introspect::EnumInfo {
904 schema_name: "billing".into(),
905 name: "role".into(),
906 variants: vec!["payer".into(), "payee".into()],
907 default_variant: None,
908 },
909 ],
910 ..Default::default()
911 }
912 }
913
914 #[test]
915 fn rust_type_name_prefixes_schema_on_cross_schema_collision() {
916 let s = schema_with_two_role_enums();
917 assert_eq!(rust_type_name_for(&s, "auth", "role"), "AuthRole");
918 assert_eq!(rust_type_name_for(&s, "billing", "role"), "BillingRole");
919 }
920
921 #[test]
922 fn rust_type_name_keeps_bare_name_when_unique() {
923 let s = SchemaInfo {
924 enums: vec![crate::introspect::EnumInfo {
925 schema_name: "auth".into(),
926 name: "role".into(),
927 variants: vec!["admin".into()],
928 default_variant: None,
929 }],
930 ..Default::default()
931 };
932 assert_eq!(rust_type_name_for(&s, "auth", "role"), "Role");
933 }
934
935 #[test]
936 fn required_search_path_collects_non_default_schemas() {
937 let s = SchemaInfo {
938 enums: vec![
939 crate::introspect::EnumInfo {
940 schema_name: "auth".into(),
941 name: "role".into(),
942 variants: vec!["x".into()],
943 default_variant: None,
944 },
945 crate::introspect::EnumInfo {
946 schema_name: "public".into(),
947 name: "status".into(),
948 variants: vec!["y".into()],
949 default_variant: None,
950 },
951 ],
952 composite_types: vec![crate::introspect::CompositeTypeInfo {
953 schema_name: "billing".into(),
954 name: "addr".into(),
955 fields: vec![],
956 }],
957 domains: vec![crate::introspect::DomainInfo {
958 schema_name: "auth".into(),
959 name: "email".into(),
960 base_type: "text".into(),
961 }],
962 ..Default::default()
963 };
964 assert_eq!(required_pg_search_path(&s), vec!["auth", "billing"]);
966 }
967
968 #[test]
969 fn required_search_path_empty_when_only_default_schema() {
970 let s = SchemaInfo {
971 enums: vec![crate::introspect::EnumInfo {
972 schema_name: "public".into(),
973 name: "status".into(),
974 variants: vec!["y".into()],
975 default_variant: None,
976 }],
977 ..Default::default()
978 };
979 assert!(required_pg_search_path(&s).is_empty());
980 }
981
982 #[test]
983 fn rust_type_name_default_schema_keeps_bare_name_even_on_collision() {
984 let s = SchemaInfo {
985 enums: vec![
986 crate::introspect::EnumInfo {
987 schema_name: "public".into(),
988 name: "role".into(),
989 variants: vec!["a".into()],
990 default_variant: None,
991 },
992 crate::introspect::EnumInfo {
993 schema_name: "auth".into(),
994 name: "role".into(),
995 variants: vec!["b".into()],
996 default_variant: None,
997 },
998 ],
999 ..Default::default()
1000 };
1001 assert_eq!(rust_type_name_for(&s, "public", "role"), "Role");
1003 assert_eq!(rust_type_name_for(&s, "auth", "role"), "AuthRole");
1004 }
1005
1006 #[test]
1009 fn test_filter_single_file_strips_super_types() {
1010 let mut imports = BTreeSet::new();
1011 imports.insert("use super::types::Foo;".to_string());
1012 imports.insert("use chrono::NaiveDateTime;".to_string());
1013 let result = filter_imports(&imports, true);
1014 assert!(!result.contains("use super::types::Foo;"));
1015 assert!(result.contains("use chrono::NaiveDateTime;"));
1016 }
1017
1018 #[test]
1019 fn test_filter_single_file_keeps_other_imports() {
1020 let mut imports = BTreeSet::new();
1021 imports.insert("use chrono::NaiveDateTime;".to_string());
1022 let result = filter_imports(&imports, true);
1023 assert!(result.contains("use chrono::NaiveDateTime;"));
1024 }
1025
1026 #[test]
1027 fn test_filter_multi_file_keeps_all() {
1028 let mut imports = BTreeSet::new();
1029 imports.insert("use super::types::Foo;".to_string());
1030 imports.insert("use chrono::NaiveDateTime;".to_string());
1031 let result = filter_imports(&imports, false);
1032 assert_eq!(result.len(), 2);
1033 }
1034
1035 #[test]
1036 fn test_filter_empty_set() {
1037 let imports = BTreeSet::new();
1038 let result = filter_imports(&imports, true);
1039 assert!(result.is_empty());
1040 }
1041
1042 fn make_table(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
1045 TableInfo {
1046 schema_name: "public".to_string(),
1047 name: name.to_string(),
1048 columns,
1049 }
1050 }
1051
1052 fn make_col(name: &str, udt_name: &str) -> ColumnInfo {
1053 ColumnInfo {
1054 name: name.to_string(),
1055 data_type: udt_name.to_string(),
1056 udt_name: udt_name.to_string(),
1057 is_nullable: false,
1058 is_primary_key: false,
1059 ordinal_position: 0,
1060 schema_name: "public".to_string(),
1061 udt_schema: None,
1062 column_default: None,
1063 }
1064 }
1065
1066 #[test]
1067 fn test_generate_empty_schema() {
1068 let schema = SchemaInfo::default();
1069 let files = generate(
1070 &schema,
1071 DatabaseKind::Postgres,
1072 &[],
1073 &HashMap::new(),
1074 false,
1075 TimeCrate::Chrono,
1076 )
1077 .unwrap();
1078 assert!(files.is_empty());
1079 }
1080
1081 #[test]
1082 fn test_generate_one_table() {
1083 let schema = SchemaInfo {
1084 tables: vec![make_table("users", vec![make_col("id", "int4")])],
1085 ..Default::default()
1086 };
1087 let files = generate(
1088 &schema,
1089 DatabaseKind::Postgres,
1090 &[],
1091 &HashMap::new(),
1092 false,
1093 TimeCrate::Chrono,
1094 )
1095 .unwrap();
1096 assert_eq!(files.len(), 1);
1097 assert_eq!(files[0].filename, "users.rs");
1098 }
1099
1100 #[test]
1101 fn test_generate_two_tables() {
1102 let schema = SchemaInfo {
1103 tables: vec![
1104 make_table("users", vec![make_col("id", "int4")]),
1105 make_table("posts", vec![make_col("id", "int4")]),
1106 ],
1107 ..Default::default()
1108 };
1109 let files = generate(
1110 &schema,
1111 DatabaseKind::Postgres,
1112 &[],
1113 &HashMap::new(),
1114 false,
1115 TimeCrate::Chrono,
1116 )
1117 .unwrap();
1118 assert_eq!(files.len(), 2);
1119 }
1120
1121 #[test]
1122 fn test_generate_enum_creates_types_file() {
1123 let schema = SchemaInfo {
1124 enums: vec![EnumInfo {
1125 schema_name: "public".to_string(),
1126 name: "status".to_string(),
1127 variants: vec!["active".to_string(), "inactive".to_string()],
1128 default_variant: None,
1129 }],
1130 ..Default::default()
1131 };
1132 let files = generate(
1133 &schema,
1134 DatabaseKind::Postgres,
1135 &[],
1136 &HashMap::new(),
1137 false,
1138 TimeCrate::Chrono,
1139 )
1140 .unwrap();
1141 assert_eq!(files.len(), 1);
1142 assert_eq!(files[0].filename, "types.rs");
1143 }
1144
1145 #[test]
1146 fn test_generate_enums_composites_domains_single_types_file() {
1147 let schema = SchemaInfo {
1148 enums: vec![EnumInfo {
1149 schema_name: "public".to_string(),
1150 name: "status".to_string(),
1151 variants: vec!["active".to_string()],
1152 default_variant: None,
1153 }],
1154 composite_types: vec![CompositeTypeInfo {
1155 schema_name: "public".to_string(),
1156 name: "address".to_string(),
1157 fields: vec![make_col("street", "text")],
1158 }],
1159 domains: vec![DomainInfo {
1160 schema_name: "public".to_string(),
1161 name: "email".to_string(),
1162 base_type: "text".to_string(),
1163 }],
1164 ..Default::default()
1165 };
1166 let files = generate(
1167 &schema,
1168 DatabaseKind::Postgres,
1169 &[],
1170 &HashMap::new(),
1171 false,
1172 TimeCrate::Chrono,
1173 )
1174 .unwrap();
1175 let types_files: Vec<_> = files.iter().filter(|f| f.filename == "types.rs").collect();
1177 assert_eq!(types_files.len(), 1);
1178 }
1179
1180 #[test]
1181 fn test_generate_tables_and_enums() {
1182 let schema = SchemaInfo {
1183 tables: vec![make_table("users", vec![make_col("id", "int4")])],
1184 enums: vec![EnumInfo {
1185 schema_name: "public".to_string(),
1186 name: "status".to_string(),
1187 variants: vec!["active".to_string()],
1188 default_variant: None,
1189 }],
1190 ..Default::default()
1191 };
1192 let files = generate(
1193 &schema,
1194 DatabaseKind::Postgres,
1195 &[],
1196 &HashMap::new(),
1197 false,
1198 TimeCrate::Chrono,
1199 )
1200 .unwrap();
1201 assert_eq!(files.len(), 2); }
1203
1204 #[test]
1205 fn test_generate_filename_normalized() {
1206 let schema = SchemaInfo {
1207 tables: vec![make_table("user__data", vec![make_col("id", "int4")])],
1208 ..Default::default()
1209 };
1210 let files = generate(
1211 &schema,
1212 DatabaseKind::Postgres,
1213 &[],
1214 &HashMap::new(),
1215 false,
1216 TimeCrate::Chrono,
1217 )
1218 .unwrap();
1219 assert_eq!(files[0].filename, "user_data.rs");
1220 }
1221
1222 #[test]
1223 fn test_generate_no_origin_for_tables() {
1224 let schema = SchemaInfo {
1225 tables: vec![make_table("users", vec![make_col("id", "int4")])],
1226 ..Default::default()
1227 };
1228 let files = generate(
1229 &schema,
1230 DatabaseKind::Postgres,
1231 &[],
1232 &HashMap::new(),
1233 false,
1234 TimeCrate::Chrono,
1235 )
1236 .unwrap();
1237 assert_eq!(files[0].origin, None);
1238 }
1239
1240 #[test]
1241 fn test_generate_types_no_origin() {
1242 let schema = SchemaInfo {
1243 enums: vec![EnumInfo {
1244 schema_name: "public".to_string(),
1245 name: "status".to_string(),
1246 variants: vec!["active".to_string()],
1247 default_variant: None,
1248 }],
1249 ..Default::default()
1250 };
1251 let files = generate(
1252 &schema,
1253 DatabaseKind::Postgres,
1254 &[],
1255 &HashMap::new(),
1256 false,
1257 TimeCrate::Chrono,
1258 )
1259 .unwrap();
1260 assert_eq!(files[0].origin, None);
1261 }
1262
1263 #[test]
1264 fn test_generate_single_file_filters_super_types_imports() {
1265 let schema = SchemaInfo {
1266 tables: vec![make_table("users", vec![make_col("id", "int4")])],
1267 enums: vec![EnumInfo {
1268 schema_name: "public".to_string(),
1269 name: "status".to_string(),
1270 variants: vec!["active".to_string()],
1271 default_variant: None,
1272 }],
1273 ..Default::default()
1274 };
1275 let files = generate(
1276 &schema,
1277 DatabaseKind::Postgres,
1278 &[],
1279 &HashMap::new(),
1280 true,
1281 TimeCrate::Chrono,
1282 )
1283 .unwrap();
1284 let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
1286 assert!(!struct_file.code.contains("super::types::"));
1287 }
1288
1289 #[test]
1290 fn test_generate_multi_file_keeps_super_types_imports() {
1291 let schema = SchemaInfo {
1293 tables: vec![make_table("users", vec![make_col("status", "status")])],
1294 enums: vec![EnumInfo {
1295 schema_name: "public".to_string(),
1296 name: "status".to_string(),
1297 variants: vec!["active".to_string()],
1298 default_variant: None,
1299 }],
1300 ..Default::default()
1301 };
1302 let files = generate(
1303 &schema,
1304 DatabaseKind::Postgres,
1305 &[],
1306 &HashMap::new(),
1307 false,
1308 TimeCrate::Chrono,
1309 )
1310 .unwrap();
1311 let struct_file = files.iter().find(|f| f.filename == "users.rs").unwrap();
1312 assert!(struct_file.code.contains("super::types::"));
1313 }
1314
1315 #[test]
1316 fn test_generate_extra_derives_in_struct() {
1317 let schema = SchemaInfo {
1318 tables: vec![make_table("users", vec![make_col("id", "int4")])],
1319 ..Default::default()
1320 };
1321 let derives = vec!["Serialize".to_string()];
1322 let files = generate(
1323 &schema,
1324 DatabaseKind::Postgres,
1325 &derives,
1326 &HashMap::new(),
1327 false,
1328 TimeCrate::Chrono,
1329 )
1330 .unwrap();
1331 assert!(files[0].code.contains("Serialize"));
1332 }
1333
1334 #[test]
1335 fn test_generate_extra_derives_in_enum() {
1336 let schema = SchemaInfo {
1337 enums: vec![EnumInfo {
1338 schema_name: "public".to_string(),
1339 name: "status".to_string(),
1340 variants: vec!["active".to_string()],
1341 default_variant: None,
1342 }],
1343 ..Default::default()
1344 };
1345 let derives = vec!["Serialize".to_string()];
1346 let files = generate(
1347 &schema,
1348 DatabaseKind::Postgres,
1349 &derives,
1350 &HashMap::new(),
1351 false,
1352 TimeCrate::Chrono,
1353 )
1354 .unwrap();
1355 assert!(files[0].code.contains("Serialize"));
1356 }
1357
1358 #[test]
1359 fn test_generate_type_overrides_in_struct() {
1360 let mut overrides = HashMap::new();
1361 overrides.insert("jsonb".to_string(), "MyJson".to_string());
1362 let schema = SchemaInfo {
1363 tables: vec![make_table("users", vec![make_col("data", "jsonb")])],
1364 ..Default::default()
1365 };
1366 let files = generate(
1367 &schema,
1368 DatabaseKind::Postgres,
1369 &[],
1370 &overrides,
1371 false,
1372 TimeCrate::Chrono,
1373 )
1374 .unwrap();
1375 assert!(files[0].code.contains("MyJson"));
1376 }
1377
1378 #[test]
1379 fn test_generate_valid_rust_syntax() {
1380 let schema = SchemaInfo {
1381 tables: vec![make_table(
1382 "users",
1383 vec![make_col("id", "int4"), make_col("name", "text")],
1384 )],
1385 enums: vec![EnumInfo {
1386 schema_name: "public".to_string(),
1387 name: "status".to_string(),
1388 variants: vec!["active".to_string(), "inactive".to_string()],
1389 default_variant: None,
1390 }],
1391 ..Default::default()
1392 };
1393 let files = generate(
1394 &schema,
1395 DatabaseKind::Postgres,
1396 &[],
1397 &HashMap::new(),
1398 false,
1399 TimeCrate::Chrono,
1400 )
1401 .unwrap();
1402 for f in &files {
1403 let parse_result = syn::parse_file(&f.code);
1405 assert!(
1406 parse_result.is_ok(),
1407 "Failed to parse {}: {:?}",
1408 f.filename,
1409 parse_result.err()
1410 );
1411 }
1412 }
1413
1414 fn make_view(name: &str, columns: Vec<ColumnInfo>) -> TableInfo {
1417 TableInfo {
1418 schema_name: "public".to_string(),
1419 name: name.to_string(),
1420 columns,
1421 }
1422 }
1423
1424 #[test]
1425 fn test_generate_one_view() {
1426 let schema = SchemaInfo {
1427 views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1428 ..Default::default()
1429 };
1430 let files = generate(
1431 &schema,
1432 DatabaseKind::Postgres,
1433 &[],
1434 &HashMap::new(),
1435 false,
1436 TimeCrate::Chrono,
1437 )
1438 .unwrap();
1439 assert_eq!(files.len(), 1);
1440 assert_eq!(files[0].filename, "active_users.rs");
1441 }
1442
1443 #[test]
1444 fn test_generate_no_origin_for_views() {
1445 let schema = SchemaInfo {
1446 views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1447 ..Default::default()
1448 };
1449 let files = generate(
1450 &schema,
1451 DatabaseKind::Postgres,
1452 &[],
1453 &HashMap::new(),
1454 false,
1455 TimeCrate::Chrono,
1456 )
1457 .unwrap();
1458 assert_eq!(files[0].origin, None);
1459 }
1460
1461 #[test]
1462 fn test_generate_tables_and_views() {
1463 let schema = SchemaInfo {
1464 tables: vec![make_table("users", vec![make_col("id", "int4")])],
1465 views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1466 ..Default::default()
1467 };
1468 let files = generate(
1469 &schema,
1470 DatabaseKind::Postgres,
1471 &[],
1472 &HashMap::new(),
1473 false,
1474 TimeCrate::Chrono,
1475 )
1476 .unwrap();
1477 assert_eq!(files.len(), 2);
1478 }
1479
1480 #[test]
1481 fn test_generate_view_valid_rust() {
1482 let schema = SchemaInfo {
1483 views: vec![make_view(
1484 "active_users",
1485 vec![make_col("id", "int4"), make_col("name", "text")],
1486 )],
1487 ..Default::default()
1488 };
1489 let files = generate(
1490 &schema,
1491 DatabaseKind::Postgres,
1492 &[],
1493 &HashMap::new(),
1494 false,
1495 TimeCrate::Chrono,
1496 )
1497 .unwrap();
1498 let parse_result = syn::parse_file(&files[0].code);
1499 assert!(
1500 parse_result.is_ok(),
1501 "Failed to parse: {:?}",
1502 parse_result.err()
1503 );
1504 }
1505
1506 #[test]
1507 fn test_generate_view_nullable_column() {
1508 let schema = SchemaInfo {
1509 views: vec![make_view(
1510 "v",
1511 vec![ColumnInfo {
1512 name: "email".to_string(),
1513 data_type: "text".to_string(),
1514 udt_name: "text".to_string(),
1515 is_nullable: true,
1516 is_primary_key: false,
1517 ordinal_position: 0,
1518 schema_name: "public".to_string(),
1519 udt_schema: None,
1520 column_default: None,
1521 }],
1522 )],
1523 ..Default::default()
1524 };
1525 let files = generate(
1526 &schema,
1527 DatabaseKind::Postgres,
1528 &[],
1529 &HashMap::new(),
1530 false,
1531 TimeCrate::Chrono,
1532 )
1533 .unwrap();
1534 assert!(files[0].code.contains("Option<String>"));
1535 }
1536
1537 #[test]
1538 fn test_generate_collision_both_prefixed() {
1539 let schema = SchemaInfo {
1540 tables: vec![
1541 make_table("users", vec![make_col("id", "int4")]),
1542 TableInfo {
1543 schema_name: "billing".to_string(),
1544 name: "users".to_string(),
1545 columns: vec![make_col("id", "int4")],
1546 },
1547 ],
1548 ..Default::default()
1549 };
1550 let files = generate(
1551 &schema,
1552 DatabaseKind::Postgres,
1553 &[],
1554 &HashMap::new(),
1555 false,
1556 TimeCrate::Chrono,
1557 )
1558 .unwrap();
1559 let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
1560 assert!(filenames.contains(&"users.rs"));
1561 assert!(filenames.contains(&"billing_users.rs"));
1562 }
1563
1564 #[test]
1565 fn test_generate_no_collision_no_prefix() {
1566 let schema = SchemaInfo {
1567 tables: vec![
1568 make_table("users", vec![make_col("id", "int4")]),
1569 TableInfo {
1570 schema_name: "billing".to_string(),
1571 name: "invoices".to_string(),
1572 columns: vec![make_col("id", "int4")],
1573 },
1574 ],
1575 ..Default::default()
1576 };
1577 let files = generate(
1578 &schema,
1579 DatabaseKind::Postgres,
1580 &[],
1581 &HashMap::new(),
1582 false,
1583 TimeCrate::Chrono,
1584 )
1585 .unwrap();
1586 let filenames: Vec<_> = files.iter().map(|f| f.filename.as_str()).collect();
1587 assert!(filenames.contains(&"users.rs"));
1588 assert!(filenames.contains(&"invoices.rs"));
1589 }
1590
1591 #[test]
1592 fn test_generate_single_schema_no_prefix() {
1593 let schema = SchemaInfo {
1594 tables: vec![
1595 make_table("users", vec![make_col("id", "int4")]),
1596 make_table("posts", vec![make_col("id", "int4")]),
1597 ],
1598 ..Default::default()
1599 };
1600 let files = generate(
1601 &schema,
1602 DatabaseKind::Postgres,
1603 &[],
1604 &HashMap::new(),
1605 false,
1606 TimeCrate::Chrono,
1607 )
1608 .unwrap();
1609 assert_eq!(files[0].filename, "users.rs");
1610 assert_eq!(files[1].filename, "posts.rs");
1611 }
1612
1613 #[test]
1614 fn test_generate_view_single_file_mode() {
1615 let schema = SchemaInfo {
1616 tables: vec![make_table("users", vec![make_col("id", "int4")])],
1617 views: vec![make_view("active_users", vec![make_col("id", "int4")])],
1618 ..Default::default()
1619 };
1620 let files = generate(
1621 &schema,
1622 DatabaseKind::Postgres,
1623 &[],
1624 &HashMap::new(),
1625 true,
1626 TimeCrate::Chrono,
1627 )
1628 .unwrap();
1629 assert_eq!(files.len(), 2);
1630 }
1631
1632 #[test]
1635 fn test_parse_pg_enum_default_simple() {
1636 assert_eq!(
1637 parse_pg_enum_default("'idle'::task_status"),
1638 Some("idle".to_string())
1639 );
1640 }
1641
1642 #[test]
1643 fn test_parse_pg_enum_default_schema_qualified() {
1644 assert_eq!(
1645 parse_pg_enum_default("'active'::public.task_status"),
1646 Some("active".to_string())
1647 );
1648 }
1649
1650 #[test]
1651 fn test_parse_pg_enum_default_not_enum() {
1652 assert_eq!(parse_pg_enum_default("nextval('users_id_seq')"), None);
1654 }
1655
1656 #[test]
1657 fn test_parse_pg_enum_default_no_cast() {
1658 assert_eq!(parse_pg_enum_default("'hello'"), None);
1659 }
1660
1661 #[test]
1662 fn test_parse_pg_enum_default_empty() {
1663 assert_eq!(parse_pg_enum_default(""), None);
1664 }
1665
1666 #[test]
1669 fn test_extract_enum_defaults_from_column() {
1670 let schema = SchemaInfo {
1671 tables: vec![TableInfo {
1672 schema_name: "public".to_string(),
1673 name: "tasks".to_string(),
1674 columns: vec![ColumnInfo {
1675 name: "status".to_string(),
1676 data_type: "USER-DEFINED".to_string(),
1677 udt_name: "task_status".to_string(),
1678 is_nullable: false,
1679 is_primary_key: false,
1680 ordinal_position: 0,
1681 schema_name: "public".to_string(),
1682 udt_schema: None,
1683 column_default: Some("'idle'::task_status".to_string()),
1684 }],
1685 }],
1686 enums: vec![EnumInfo {
1687 schema_name: "public".to_string(),
1688 name: "task_status".to_string(),
1689 variants: vec!["idle".to_string(), "running".to_string()],
1690 default_variant: None,
1691 }],
1692 ..Default::default()
1693 };
1694 let defaults = extract_enum_defaults(&schema);
1695 assert_eq!(defaults.get("task_status"), Some(&"idle".to_string()));
1696 }
1697
1698 #[test]
1699 fn test_extract_enum_defaults_no_default() {
1700 let schema = SchemaInfo {
1701 tables: vec![TableInfo {
1702 schema_name: "public".to_string(),
1703 name: "tasks".to_string(),
1704 columns: vec![ColumnInfo {
1705 name: "status".to_string(),
1706 data_type: "USER-DEFINED".to_string(),
1707 udt_name: "task_status".to_string(),
1708 is_nullable: false,
1709 is_primary_key: false,
1710 ordinal_position: 0,
1711 schema_name: "public".to_string(),
1712 udt_schema: None,
1713 column_default: None,
1714 }],
1715 }],
1716 enums: vec![EnumInfo {
1717 schema_name: "public".to_string(),
1718 name: "task_status".to_string(),
1719 variants: vec!["idle".to_string()],
1720 default_variant: None,
1721 }],
1722 ..Default::default()
1723 };
1724 let defaults = extract_enum_defaults(&schema);
1725 assert!(defaults.is_empty());
1726 }
1727
1728 #[test]
1729 fn test_extract_enum_defaults_non_enum_column_ignored() {
1730 let schema = SchemaInfo {
1731 tables: vec![TableInfo {
1732 schema_name: "public".to_string(),
1733 name: "users".to_string(),
1734 columns: vec![ColumnInfo {
1735 name: "name".to_string(),
1736 data_type: "character varying".to_string(),
1737 udt_name: "varchar".to_string(),
1738 is_nullable: false,
1739 is_primary_key: false,
1740 ordinal_position: 0,
1741 schema_name: "public".to_string(),
1742 udt_schema: None,
1743 column_default: Some("'hello'::character varying".to_string()),
1744 }],
1745 }],
1746 enums: vec![],
1747 ..Default::default()
1748 };
1749 let defaults = extract_enum_defaults(&schema);
1750 assert!(defaults.is_empty());
1751 }
1752
1753 #[test]
1754 fn test_generate_enum_with_default() {
1755 let schema = SchemaInfo {
1756 tables: vec![TableInfo {
1757 schema_name: "public".to_string(),
1758 name: "tasks".to_string(),
1759 columns: vec![ColumnInfo {
1760 name: "status".to_string(),
1761 data_type: "USER-DEFINED".to_string(),
1762 udt_name: "task_status".to_string(),
1763 is_nullable: false,
1764 is_primary_key: false,
1765 ordinal_position: 0,
1766 schema_name: "public".to_string(),
1767 udt_schema: None,
1768 column_default: Some("'idle'::task_status".to_string()),
1769 }],
1770 }],
1771 enums: vec![EnumInfo {
1772 schema_name: "public".to_string(),
1773 name: "task_status".to_string(),
1774 variants: vec!["idle".to_string(), "running".to_string()],
1775 default_variant: None,
1776 }],
1777 ..Default::default()
1778 };
1779 let files = generate(
1780 &schema,
1781 DatabaseKind::Postgres,
1782 &[],
1783 &HashMap::new(),
1784 false,
1785 TimeCrate::Chrono,
1786 )
1787 .unwrap();
1788 let types_file = files.iter().find(|f| f.filename == "types.rs").unwrap();
1789 assert!(types_file.code.contains("impl Default for TaskStatus"));
1790 assert!(types_file.code.contains("Self::Idle"));
1791 }
1792}