1use std::collections::HashMap;
7
8use prax_schema::Schema;
9use prax_schema::ast::{
10 Attribute, AttributeArg, AttributeValue, Enum, EnumVariant, Field, FieldType, Ident, Model,
11 ScalarType, Span, TypeModifier,
12};
13
14use crate::error::{MigrateResult, MigrationError};
15
/// Outcome of building a schema from introspected database metadata.
#[derive(Debug, Clone)]
pub struct IntrospectionResult {
    /// Schema holding the generated enums and models.
    pub schema: Schema,
    /// Tables skipped during the build, each with the reason it was skipped.
    pub skipped_tables: Vec<SkippedTable>,
    /// Human-readable, non-fatal problems encountered while building models.
    pub warnings: Vec<String>,
}
26
/// A table that was not converted into a model, together with the reason.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SkippedTable {
    /// Database name of the skipped table.
    pub name: String,
    /// Why the table was skipped (configuration exclusion or a build error message).
    pub reason: String,
}
35
/// Options controlling which database objects are introspected.
#[derive(Debug, Clone)]
pub struct IntrospectionConfig {
    /// Database schema (namespace) to read from; `"public"` by default.
    pub database_schema: String,
    /// When non-empty, only tables named here are included.
    pub include_tables: Vec<String>,
    /// Tables that are always skipped; takes precedence over `include_tables`.
    pub exclude_tables: Vec<String>,
    /// Whether entries whose `table_type` is `"VIEW"` are turned into models.
    pub include_views: bool,
    /// Whether database enum types are added to the generated schema.
    pub include_enums: bool,
}

impl Default for IntrospectionConfig {
    /// Reads the `public` schema, includes views and enums, and excludes the
    /// `_prax_migrations`, `_prisma_migrations` and `schema_migrations` tables.
    fn default() -> Self {
        Self {
            database_schema: "public".to_string(),
            include_tables: Vec::new(),
            exclude_tables: vec![
                "_prax_migrations".to_string(),
                "_prisma_migrations".to_string(),
                "schema_migrations".to_string(),
            ],
            include_views: true,
            include_enums: true,
        }
    }
}

impl IntrospectionConfig {
    /// Creates a configuration with the [`Default`] settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the database schema (namespace) to introspect.
    pub fn database_schema(mut self, schema: impl Into<String>) -> Self {
        self.database_schema = schema.into();
        self
    }

    /// Restricts introspection to the given tables (an empty list means "all").
    pub fn include_tables(mut self, tables: Vec<String>) -> Self {
        self.include_tables = tables;
        self
    }

    /// Replaces the exclusion list (including the default exclusions).
    pub fn exclude_tables(mut self, tables: Vec<String>) -> Self {
        self.exclude_tables = tables;
        self
    }

    /// Sets whether views are turned into models.
    pub fn include_views(mut self, include: bool) -> Self {
        self.include_views = include;
        self
    }

    /// Sets whether database enums are added to the schema.
    pub fn include_enums(mut self, include: bool) -> Self {
        self.include_enums = include;
        self
    }

    /// Returns `true` if `name` passes the include/exclude filters.
    ///
    /// Exclusion wins over inclusion; an empty `include_tables` admits every
    /// table that is not excluded.
    pub fn should_include_table(&self, name: &str) -> bool {
        // Compare as `&str` so no `String` is allocated per lookup.
        if self.exclude_tables.iter().any(|t| t == name) {
            return false;
        }
        self.include_tables.is_empty() || self.include_tables.iter().any(|t| t == name)
    }
}
114
/// Metadata for one table or view.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableInfo {
    /// Table name as stored in the database.
    pub name: String,
    /// Database schema (namespace) containing the table.
    pub schema: String,
    /// Table kind as reported by the catalog; `"VIEW"` marks views.
    pub table_type: String,
    /// Table comment, if the introspector provides one.
    pub comment: Option<String>,
}
127
/// Metadata for one table column.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ColumnInfo {
    /// Column name.
    pub name: String,
    /// SQL data type name (e.g. `integer`, `ARRAY`, `USER-DEFINED`).
    pub data_type: String,
    /// Underlying type name (e.g. `int4`, or the enum's type name).
    pub udt_name: String,
    /// Maximum length for character types, if any.
    pub character_maximum_length: Option<i32>,
    /// Numeric precision for numeric types, if any.
    pub numeric_precision: Option<i32>,
    /// Whether the column accepts NULL.
    pub is_nullable: bool,
    /// Raw SQL default expression, if any.
    pub column_default: Option<String>,
    /// 1-based position of the column within its table.
    pub ordinal_position: i32,
    /// Column comment, if the introspector provides one.
    pub comment: Option<String>,
}
150
/// Metadata for one table constraint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConstraintInfo {
    /// Constraint name.
    pub name: String,
    /// Constraint kind, e.g. `PRIMARY KEY`, `UNIQUE` or `FOREIGN KEY`.
    pub constraint_type: String,
    /// Table the constraint belongs to.
    pub table_name: String,
    /// Constrained columns, in key order.
    pub columns: Vec<String>,
    /// Referenced table for foreign keys.
    pub referenced_table: Option<String>,
    /// Referenced columns for foreign keys.
    pub referenced_columns: Option<Vec<String>>,
    /// Foreign-key `ON DELETE` rule, if any.
    pub on_delete: Option<String>,
    /// Foreign-key `ON UPDATE` rule, if any.
    pub on_update: Option<String>,
}
171
/// Metadata for one database enum type.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EnumInfo {
    /// Enum type name as stored in the database.
    pub name: String,
    /// Enum labels, in the order the introspector returns them.
    pub values: Vec<String>,
    /// Database schema (namespace) containing the type.
    pub schema: String,
}
182
/// Metadata for one index.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IndexInfo {
    /// Index name.
    pub name: String,
    /// Indexed table.
    pub table_name: String,
    /// Indexed columns, in key order.
    pub columns: Vec<String>,
    /// Whether the index enforces uniqueness.
    pub is_unique: bool,
    /// Whether the index backs the primary key.
    pub is_primary: bool,
    /// Access method name (e.g. `btree`).
    pub index_method: String,
}
199
/// Source of raw database metadata consumed by [`SchemaBuilder`].
///
/// Each method returns catalog information for one kind of object; a
/// backend-specific implementation performs the actual queries (see
/// `postgres_queries` for PostgreSQL SQL — presumably used by such an
/// implementation; confirm in the backend module).
#[async_trait::async_trait]
pub trait Introspector: Send + Sync {
    /// Lists the tables (and views) described by `config`.
    async fn get_tables(&self, config: &IntrospectionConfig) -> MigrateResult<Vec<TableInfo>>;

    /// Lists the columns of `table` in database schema `schema`.
    async fn get_columns(&self, table: &str, schema: &str) -> MigrateResult<Vec<ColumnInfo>>;

    /// Lists the constraints of `table` in database schema `schema`.
    async fn get_constraints(
        &self,
        table: &str,
        schema: &str,
    ) -> MigrateResult<Vec<ConstraintInfo>>;

    /// Lists the indexes of `table` in database schema `schema`.
    async fn get_indexes(&self, table: &str, schema: &str) -> MigrateResult<Vec<IndexInfo>>;

    /// Lists the enum types defined in database schema `schema`.
    async fn get_enums(&self, schema: &str) -> MigrateResult<Vec<EnumInfo>>;
}
222
223pub struct SchemaBuilder {
225 config: IntrospectionConfig,
226 tables: Vec<TableInfo>,
227 columns: HashMap<String, Vec<ColumnInfo>>,
228 constraints: HashMap<String, Vec<ConstraintInfo>>,
229 indexes: HashMap<String, Vec<IndexInfo>>,
230 enums: Vec<EnumInfo>,
231}
232
233impl SchemaBuilder {
234 pub fn new(config: IntrospectionConfig) -> Self {
236 Self {
237 config,
238 tables: Vec::new(),
239 columns: HashMap::new(),
240 constraints: HashMap::new(),
241 indexes: HashMap::new(),
242 enums: Vec::new(),
243 }
244 }
245
246 pub fn with_tables(mut self, tables: Vec<TableInfo>) -> Self {
248 self.tables = tables;
249 self
250 }
251
252 pub fn with_columns(mut self, table: &str, columns: Vec<ColumnInfo>) -> Self {
254 self.columns.insert(table.to_string(), columns);
255 self
256 }
257
258 pub fn with_constraints(mut self, table: &str, constraints: Vec<ConstraintInfo>) -> Self {
260 self.constraints.insert(table.to_string(), constraints);
261 self
262 }
263
264 pub fn with_indexes(mut self, table: &str, indexes: Vec<IndexInfo>) -> Self {
266 self.indexes.insert(table.to_string(), indexes);
267 self
268 }
269
270 pub fn with_enums(mut self, enums: Vec<EnumInfo>) -> Self {
272 self.enums = enums;
273 self
274 }
275
276 pub fn build(self) -> MigrateResult<IntrospectionResult> {
278 let mut schema = Schema::new();
279 let mut skipped_tables = Vec::new();
280 let mut warnings = Vec::new();
281
282 if self.config.include_enums {
284 for enum_info in &self.enums {
285 let prax_enum = self.build_enum(enum_info);
286 schema.add_enum(prax_enum);
287 }
288 }
289
290 for table in &self.tables {
292 if !self.config.should_include_table(&table.name) {
293 skipped_tables.push(SkippedTable {
294 name: table.name.clone(),
295 reason: "Excluded by configuration".to_string(),
296 });
297 continue;
298 }
299
300 if table.table_type == "VIEW" && !self.config.include_views {
302 continue;
303 }
304
305 match self.build_model(table) {
306 Ok(model) => {
307 schema.add_model(model);
308 }
309 Err(e) => {
310 warnings.push(format!("Failed to build model for '{}': {}", table.name, e));
311 skipped_tables.push(SkippedTable {
312 name: table.name.clone(),
313 reason: e.to_string(),
314 });
315 }
316 }
317 }
318
319 Ok(IntrospectionResult {
320 schema,
321 skipped_tables,
322 warnings,
323 })
324 }
325
326 fn build_enum(&self, info: &EnumInfo) -> Enum {
328 let span = Span::new(0, 0);
329 let name = Ident::new(to_pascal_case(&info.name), span);
330 let mut prax_enum = Enum::new(name, span);
331
332 for value in &info.values {
333 prax_enum.add_variant(EnumVariant::new(Ident::new(value.clone(), span), span));
334 }
335
336 prax_enum
337 }
338
339 fn build_model(&self, table: &TableInfo) -> MigrateResult<Model> {
341 let span = Span::new(0, 0);
342 let name = Ident::new(to_pascal_case(&table.name), span);
343 let mut model = Model::new(name, span);
344
345 let model_name = to_pascal_case(&table.name);
347 if table.name != model_name && table.name != to_snake_case(&model_name) {
348 model.attributes.push(Attribute::new(
349 Ident::new("map", span),
350 vec![AttributeArg::positional(
351 AttributeValue::String(table.name.clone()),
352 span,
353 )],
354 span,
355 ));
356 }
357
358 let columns = self.columns.get(&table.name).cloned().unwrap_or_default();
360
361 let constraints = self
363 .constraints
364 .get(&table.name)
365 .cloned()
366 .unwrap_or_default();
367
368 let pk_columns: Vec<&str> = constraints
370 .iter()
371 .filter(|c| c.constraint_type == "PRIMARY KEY")
372 .flat_map(|c| c.columns.iter().map(|s| s.as_str()))
373 .collect();
374
375 let unique_columns: Vec<&str> = constraints
377 .iter()
378 .filter(|c| c.constraint_type == "UNIQUE")
379 .filter(|c| c.columns.len() == 1)
380 .flat_map(|c| c.columns.iter().map(|s| s.as_str()))
381 .collect();
382
383 for column in columns {
385 let field = self.build_field(&column, &pk_columns, &unique_columns)?;
386 model.add_field(field);
387 }
388
389 Ok(model)
390 }
391
392 fn build_field(
394 &self,
395 column: &ColumnInfo,
396 pk_columns: &[&str],
397 unique_columns: &[&str],
398 ) -> MigrateResult<Field> {
399 let span = Span::new(0, 0);
400 let name = Ident::new(&column.name, span);
401
402 let (field_type, needs_map) = self.sql_type_to_prax(&column.udt_name, &column.data_type)?;
404
405 let modifier = if column.is_nullable {
407 TypeModifier::Optional
408 } else {
409 TypeModifier::Required
410 };
411
412 let mut attributes = Vec::new();
413
414 if pk_columns.contains(&column.name.as_str()) {
416 attributes.push(Attribute::simple(Ident::new("id", span), span));
417
418 if let Some(default) = &column.column_default
420 && (default.contains("nextval") || default.contains("SERIAL"))
421 {
422 attributes.push(Attribute::simple(Ident::new("auto", span), span));
423 }
424 }
425
426 if unique_columns.contains(&column.name.as_str()) {
428 attributes.push(Attribute::simple(Ident::new("unique", span), span));
429 }
430
431 if let Some(default) = &column.column_default
433 && !default.contains("nextval")
434 && let Some(value) = parse_default_value(default)
435 {
436 attributes.push(Attribute::new(
437 Ident::new("default", span),
438 vec![AttributeArg::positional(value, span)],
439 span,
440 ));
441 }
442
443 if needs_map {
445 attributes.push(Attribute::new(
446 Ident::new("map", span),
447 vec![AttributeArg::positional(
448 AttributeValue::String(column.name.clone()),
449 span,
450 )],
451 span,
452 ));
453 }
454
455 Ok(Field::new(name, field_type, modifier, attributes, span))
456 }
457
458 fn sql_type_to_prax(
460 &self,
461 udt_name: &str,
462 data_type: &str,
463 ) -> MigrateResult<(FieldType, bool)> {
464 let enum_names: Vec<&str> = self.enums.iter().map(|e| e.name.as_str()).collect();
466 if enum_names.contains(&udt_name) {
467 return Ok((FieldType::Enum(to_pascal_case(udt_name).into()), false));
468 }
469
470 let scalar = match udt_name {
471 "int2" | "int4" | "integer" | "smallint" => ScalarType::Int,
472 "int8" | "bigint" => ScalarType::BigInt,
473 "float4" | "float8" | "real" | "double precision" => ScalarType::Float,
474 "numeric" | "decimal" | "money" => ScalarType::Decimal,
475 "text" | "varchar" | "char" | "character varying" | "character" | "bpchar" => {
476 ScalarType::String
477 }
478 "bool" | "boolean" => ScalarType::Boolean,
479 "timestamp"
480 | "timestamptz"
481 | "timestamp with time zone"
482 | "timestamp without time zone" => ScalarType::DateTime,
483 "date" => ScalarType::Date,
484 "time" | "timetz" | "time with time zone" | "time without time zone" => {
485 ScalarType::Time
486 }
487 "json" | "jsonb" => ScalarType::Json,
488 "bytea" => ScalarType::Bytes,
489 "uuid" => ScalarType::Uuid,
490 _ => {
491 match data_type {
493 "integer" | "smallint" => ScalarType::Int,
494 "bigint" => ScalarType::BigInt,
495 "real" | "double precision" => ScalarType::Float,
496 "numeric" => ScalarType::Decimal,
497 "character varying" | "character" | "text" => ScalarType::String,
498 "boolean" => ScalarType::Boolean,
499 "timestamp with time zone" | "timestamp without time zone" => {
500 ScalarType::DateTime
501 }
502 "date" => ScalarType::Date,
503 "time with time zone" | "time without time zone" => ScalarType::Time,
504 "json" | "jsonb" => ScalarType::Json,
505 "bytea" => ScalarType::Bytes,
506 "uuid" => ScalarType::Uuid,
507 "ARRAY" => {
508 ScalarType::Json
510 }
511 "USER-DEFINED" => {
512 return Err(MigrationError::InvalidMigration(format!(
514 "Unknown user-defined type: {}",
515 udt_name
516 )));
517 }
518 _ => {
519 return Err(MigrationError::InvalidMigration(format!(
520 "Unknown SQL type: {} ({})",
521 udt_name, data_type
522 )));
523 }
524 }
525 }
526 };
527
528 Ok((FieldType::Scalar(scalar), false))
529 }
530}
531
532fn parse_default_value(default: &str) -> Option<AttributeValue> {
534 let trimmed = default.trim();
535
536 if trimmed == "true" || trimmed == "TRUE" {
538 return Some(AttributeValue::Boolean(true));
539 }
540 if trimmed == "false" || trimmed == "FALSE" {
541 return Some(AttributeValue::Boolean(false));
542 }
543
544 if trimmed.to_uppercase() == "NULL" {
546 return None;
547 }
548
549 if let Ok(int) = trimmed.parse::<i64>() {
551 return Some(AttributeValue::Int(int));
552 }
553
554 if let Ok(float) = trimmed.parse::<f64>() {
556 return Some(AttributeValue::Float(float));
557 }
558
559 if (trimmed.starts_with('\'') && trimmed.ends_with('\''))
561 || (trimmed.starts_with('"') && trimmed.ends_with('"'))
562 {
563 let inner = &trimmed[1..trimmed.len() - 1];
564 return Some(AttributeValue::String(inner.to_string()));
565 }
566
567 if let Some(pos) = trimmed.find("::") {
569 return parse_default_value(&trimmed[..pos]);
570 }
571
572 if trimmed.ends_with("()") || trimmed.contains('(') {
574 let func_name = if let Some(paren_pos) = trimmed.find('(') {
575 &trimmed[..paren_pos]
576 } else {
577 &trimmed[..trimmed.len() - 2]
578 };
579 return Some(AttributeValue::Function(
580 func_name.to_string().into(),
581 vec![],
582 ));
583 }
584
585 Some(AttributeValue::String(trimmed.to_string()))
587}
588
/// Converts a snake_case identifier to PascalCase.
///
/// Empty segments from leading, trailing or repeated underscores are dropped,
/// and only the first character of each segment is uppercased.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split('_').filter(|w| !w.is_empty()) {
        let mut letters = word.chars();
        if let Some(head) = letters.next() {
            out.extend(head.to_uppercase());
            out.push_str(letters.as_str());
        }
    }
    out
}
602
/// Converts a PascalCase/camelCase identifier to snake_case.
///
/// Every uppercase character starts a new `_`-separated word (except at the
/// very start), so runs of capitals are split letter by letter:
/// `HTTPResponse` becomes `h_t_t_p_response`.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            // `to_lowercase` may yield several chars (e.g. 'İ' → "i\u{307}"); keep them all.
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
    }
    result
}
618
/// SQL for reading PostgreSQL catalog metadata.
///
/// Every query binds the database schema name as `$1`; the per-table queries
/// (`COLUMNS`, `CONSTRAINTS`, `INDEXES`) bind the table name as `$2`.
pub mod postgres_queries {
    /// Tables and views in schema `$1`, ordered by name.
    pub const TABLES: &str = r#"
        SELECT
            table_name,
            table_schema,
            table_type
        FROM information_schema.tables
        WHERE table_schema = $1
        ORDER BY table_name
    "#;

    /// Columns of table `$2` in schema `$1`, in declaration order.
    /// `is_nullable` is returned as a boolean instead of `'YES'`/`'NO'`.
    pub const COLUMNS: &str = r#"
        SELECT
            column_name,
            data_type,
            udt_name,
            character_maximum_length,
            numeric_precision,
            is_nullable = 'YES' as is_nullable,
            column_default,
            ordinal_position
        FROM information_schema.columns
        WHERE table_schema = $1 AND table_name = $2
        ORDER BY ordinal_position
    "#;

    /// Constraints of table `$2` in schema `$1`, one row per constrained column.
    ///
    /// Foreign-key rows also carry the referenced table and column plus the
    /// delete/update rules.
    ///
    /// The `key_column_usage` join also matches on the table name. PostgreSQL
    /// only requires constraint names to be unique per table, so another table
    /// in the same schema can have a constraint with the same name. Matching on
    /// name and schema alone would attach that table's columns.
    ///
    /// NOTE(review): the `constraint_column_usage` and `referential_constraints`
    /// joins are still keyed only by name and schema. Composite foreign keys
    /// also yield every local × referenced column pairing. Consider pairing via
    /// `key_column_usage.position_in_unique_constraint`.
    pub const CONSTRAINTS: &str = r#"
        SELECT
            tc.constraint_name,
            tc.constraint_type,
            tc.table_name,
            kcu.column_name,
            ccu.table_name AS referenced_table,
            ccu.column_name AS referenced_column,
            rc.delete_rule,
            rc.update_rule
        FROM information_schema.table_constraints tc
        LEFT JOIN information_schema.key_column_usage kcu
            ON tc.constraint_name = kcu.constraint_name
            AND tc.table_schema = kcu.table_schema
            AND tc.table_name = kcu.table_name
        LEFT JOIN information_schema.constraint_column_usage ccu
            ON tc.constraint_name = ccu.constraint_name
            AND tc.table_schema = ccu.table_schema
            AND tc.constraint_type = 'FOREIGN KEY'
        LEFT JOIN information_schema.referential_constraints rc
            ON tc.constraint_name = rc.constraint_name
            AND tc.table_schema = rc.constraint_schema
        WHERE tc.table_schema = $1 AND tc.table_name = $2
        ORDER BY tc.constraint_name, kcu.ordinal_position
    "#;

    /// Indexes on table `$2` in schema `$1`, ordered by index name.
    ///
    /// Columns are aggregated in index-key order. Expression keys (attnum 0 in
    /// `indkey`) match no `pg_attribute` row, so they are not listed.
    pub const INDEXES: &str = r#"
        SELECT
            i.relname AS index_name,
            t.relname AS table_name,
            array_agg(a.attname ORDER BY array_position(ix.indkey, a.attnum)) AS columns,
            ix.indisunique AS is_unique,
            ix.indisprimary AS is_primary,
            am.amname AS index_method
        FROM pg_index ix
        JOIN pg_class i ON ix.indexrelid = i.oid
        JOIN pg_class t ON ix.indrelid = t.oid
        JOIN pg_namespace n ON t.relnamespace = n.oid
        JOIN pg_am am ON i.relam = am.oid
        JOIN pg_attribute a ON a.attrelid = t.oid AND a.attnum = ANY(ix.indkey)
        WHERE n.nspname = $1 AND t.relname = $2
        GROUP BY i.relname, t.relname, ix.indisunique, ix.indisprimary, am.amname
        ORDER BY i.relname
    "#;

    /// Enum types in schema `$1`, ordered by type name.
    /// Labels are listed in their declared sort order.
    pub const ENUMS: &str = r#"
        SELECT
            t.typname AS enum_name,
            n.nspname AS schema_name,
            array_agg(e.enumlabel ORDER BY e.enumsortorder) AS enum_values
        FROM pg_type t
        JOIN pg_namespace n ON t.typnamespace = n.oid
        JOIN pg_enum e ON t.oid = e.enumtypid
        WHERE n.nspname = $1
        GROUP BY t.typname, n.nspname
        ORDER BY t.typname
    "#;
}
706
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_to_pascal_case() {
        let cases = [
            ("user", "User"),
            ("user_profile", "UserProfile"),
            ("user_profile_settings", "UserProfileSettings"),
            ("_user_", "User"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_pascal_case(input), expected);
        }
    }

    #[test]
    fn test_to_snake_case() {
        let cases = [
            ("User", "user"),
            ("UserProfile", "user_profile"),
            ("HTTPResponse", "h_t_t_p_response"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected);
        }
    }

    #[test]
    fn test_parse_default_value_boolean() {
        let parsed_true = parse_default_value("true");
        let parsed_false = parse_default_value("false");
        assert!(matches!(parsed_true, Some(AttributeValue::Boolean(true))));
        assert!(matches!(parsed_false, Some(AttributeValue::Boolean(false))));
    }

    #[test]
    fn test_parse_default_value_int() {
        for (text, expected) in [("42", 42), ("-5", -5)] {
            assert!(matches!(
                parse_default_value(text),
                Some(AttributeValue::Int(v)) if v == expected
            ));
        }
    }

    #[test]
    // 3.14 is an arbitrary fixture value here, not an approximation of PI.
    #[allow(clippy::approx_constant)]
    fn test_parse_default_value_float() {
        let Some(AttributeValue::Float(value)) = parse_default_value("3.14") else {
            panic!("Expected Float");
        };
        assert!((value - 3.14).abs() < 0.001);
    }

    #[test]
    fn test_parse_default_value_string() {
        let Some(AttributeValue::String(text)) = parse_default_value("'hello'") else {
            panic!("Expected String");
        };
        assert_eq!(text.as_str(), "hello");
    }

    #[test]
    fn test_parse_default_value_function() {
        let Some(AttributeValue::Function(name, args)) = parse_default_value("now()") else {
            panic!("Expected Function");
        };
        assert_eq!(name.as_str(), "now");
        assert!(args.is_empty());
    }

    #[test]
    fn test_parse_default_value_with_cast() {
        let Some(AttributeValue::String(text)) = parse_default_value("'active'::status_type")
        else {
            panic!("Expected String");
        };
        assert_eq!(text.as_str(), "active");
    }

    #[test]
    fn test_config_should_include_table() {
        let defaults = IntrospectionConfig::default();
        assert!(defaults.should_include_table("users"));
        assert!(!defaults.should_include_table("_prax_migrations"));
    }

    #[test]
    fn test_config_include_specific_tables() {
        let only_users = IntrospectionConfig::new().include_tables(vec!["users".to_string()]);
        assert!(only_users.should_include_table("users"));
        assert!(!only_users.should_include_table("posts"));
    }

    #[test]
    fn test_sql_type_mapping() {
        let builder = SchemaBuilder::new(IntrospectionConfig::default());

        assert!(matches!(
            builder.sql_type_to_prax("int4", "integer").unwrap().0,
            FieldType::Scalar(ScalarType::Int)
        ));
        assert!(matches!(
            builder.sql_type_to_prax("text", "text").unwrap().0,
            FieldType::Scalar(ScalarType::String)
        ));
        assert!(matches!(
            builder.sql_type_to_prax("bool", "boolean").unwrap().0,
            FieldType::Scalar(ScalarType::Boolean)
        ));
        assert!(matches!(
            builder
                .sql_type_to_prax("timestamptz", "timestamp with time zone")
                .unwrap()
                .0,
            FieldType::Scalar(ScalarType::DateTime)
        ));
        assert!(matches!(
            builder.sql_type_to_prax("uuid", "uuid").unwrap().0,
            FieldType::Scalar(ScalarType::Uuid)
        ));
    }
}