use crate::catalog::Catalog;
use crate::config::{ColumnOrderMode, Config, ObjectFilter};
use crate::diff::operations::{MigrationStep, SqlRenderer};
use crate::diff::plan;
use crate::schema_ops::apply_current_schema_to_shadow;
use anyhow::Result;
use std::collections::HashSet;
use std::path::Path;
/// Knobs controlling how much output schema validation produces.
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// When set, a failed validation message lists every planned step (with
    /// its SQL) and any reported column-order mismatch; otherwise the message
    /// only carries a count.
    pub show_differences: bool,
    /// When set, progress lines are printed to stdout while validating.
    pub verbose: bool,
}

impl Default for ValidationConfig {
    /// Both detailed differences and progress output are enabled.
    fn default() -> Self {
        let enabled = true;
        Self {
            show_differences: enabled,
            verbose: enabled,
        }
    }
}
/// Outcome of comparing an actual catalog with an expected one
/// (produced by `validate_catalogs`).
#[derive(Debug)]
pub struct ValidationResult {
/// `true` when no migration steps were planned and no strict-mode
/// column-order errors were found.
pub passed: bool,
/// Migration steps planned from the actual catalog towards the expected one;
/// always empty when `passed` is `true`. Column-order mismatches are not
/// listed here, so this can also be empty on failure.
pub differences: Vec<MigrationStep>,
/// Human-readable summary; detailed only when the caller's
/// `ValidationConfig::show_differences` is set.
pub message: String,
}
/// Validates a live (dev) catalog against the project's schema files.
///
/// The expected catalog is obtained from `apply_current_schema_to_shadow`
/// (using the schema files under `root_dir` and the given `shadow` database)
/// and then compared with `dev_catalog` through [`validate_catalogs`].
///
/// # Errors
/// Propagates failures from building the expected catalog and from planning
/// the comparison.
pub async fn validate_database_against_schema_files(
    dev_catalog: &Catalog,
    config: &Config,
    root_dir: &Path,
    validation_config: &ValidationConfig,
    shadow: &crate::config::ShadowDatabase,
) -> Result<ValidationResult> {
    let announce_progress = validation_config.verbose;
    if announce_progress {
        println!("🔍 Validating database against schema files...");
    }

    let schema_catalog = apply_current_schema_to_shadow(config, root_dir, shadow).await?;

    validate_catalogs(dev_catalog, &schema_catalog, config, validation_config)
}
/// Compares `actual_catalog` with `expected_catalog` after both are passed
/// through the object filter built from `config`.
///
/// Migration steps are planned from the actual catalog towards the expected
/// one. Column order is additionally compared unless
/// `config.migration.column_order` is `Relaxed`. In `Warn` mode mismatches
/// are printed to stderr (regardless of `verbose`) but do not fail validation.
/// In `Strict` mode they make validation fail.
///
/// # Errors
/// Returns an error only if planning the migration steps fails.
pub fn validate_catalogs(
    actual_catalog: &Catalog,
    expected_catalog: &Catalog,
    config: &Config,
    validation_config: &ValidationConfig,
) -> Result<ValidationResult> {
    let object_filter = ObjectFilter::from_config(config);
    let filtered_actual = object_filter.filter_catalog(actual_catalog.clone());
    let filtered_expected = object_filter.filter_catalog(expected_catalog.clone());

    if validation_config.verbose {
        println!("🔍 Comparing schemas...");
    }

    let steps = plan(&filtered_actual, &filtered_expected)?;

    let order_mode = &config.migration.column_order;
    let mismatches = match order_mode {
        ColumnOrderMode::Relaxed => Vec::new(),
        ColumnOrderMode::Warn | ColumnOrderMode::Strict => {
            find_column_order_mismatches(&filtered_actual, &filtered_expected)
        }
    };

    if *order_mode == ColumnOrderMode::Warn && !mismatches.is_empty() {
        eprintln!("Warning: Column order mismatches detected:\n");
        for mismatch in &mismatches {
            eprintln!(" {}", mismatch);
        }
        eprintln!("\nSchema files have columns in a different order than the database.");
    }

    // Only strict mode turns column-order mismatches into a failure.
    let column_order_failed = *order_mode == ColumnOrderMode::Strict && !mismatches.is_empty();

    if !column_order_failed && steps.is_empty() {
        return Ok(ValidationResult {
            passed: true,
            differences: steps,
            message: "Schema validation passed! Database matches expected schema.".to_string(),
        });
    }

    let message = if !validation_config.show_differences {
        let counted_mismatches = if column_order_failed {
            mismatches.len()
        } else {
            0
        };
        format!(
            "Schema validation failed! Found {} differences.",
            steps.len() + counted_mismatches
        )
    } else {
        format_validation_failure(&steps, &mismatches, column_order_failed)
    };

    Ok(ValidationResult {
        passed: false,
        differences: steps,
        message,
    })
}
/// Builds the detailed, multi-line failure message for a validation run.
///
/// Lists every migration step (its id and rendered SQL). When
/// `include_column_order_errors` is set, it also lists the column-order
/// mismatches. It finishes with remediation hints relevant to the kinds of
/// issue present.
fn format_validation_failure(
    differences: &[MigrationStep],
    column_order_mismatches: &[ColumnOrderMismatch],
    include_column_order_errors: bool,
) -> String {
    // Mismatches only count (and are only shown) when the caller asks for it.
    let reported_mismatches: &[ColumnOrderMismatch] = if include_column_order_errors {
        column_order_mismatches
    } else {
        &[]
    };
    let separator = "=".repeat(50);

    let mut out = format!(
        "Schema validation failed! Found {} issue(s):\n",
        differences.len() + reported_mismatches.len()
    );

    if !differences.is_empty() {
        out.push_str("\nRequired changes to bring database in sync:\n");
        out.push_str(&separator);
        out.push('\n');
        for (index, step) in differences.iter().enumerate() {
            out.push_str(&format!("{}. {:?}\n", index + 1, step.id()));
            for rendered in step.to_sql() {
                out.push_str(&format!(" {}\n", rendered.sql));
            }
            out.push('\n');
        }
    }

    if !reported_mismatches.is_empty() {
        out.push_str("\nColumn order mismatches:\n");
        out.push_str(&separator);
        out.push('\n');
        for mismatch in reported_mismatches {
            out.push_str(&format!("{}\n\n", mismatch));
        }
    }

    out.push_str("💡 To fix these issues:\n");
    if !differences.is_empty() {
        out.push_str(" 1. Update your schema files to match the database, OR\n");
        out.push_str(" 2. Run 'pgmt apply' to apply schema files to the database\n");
    }
    if !reported_mismatches.is_empty() {
        out.push_str(" - For column order: Update schema files to match the physical column order in the database\n");
        out.push_str(" - To disable column order checks: Set `migration.column_order: relaxed` in pgmt.yaml\n");
    }
    out
}
/// Error returned by [`validate_baseline_consistency`] when a baseline catalog
/// does not match the catalog expected from the schema files.
#[derive(Debug)]
pub struct BaselineValidationError {
    /// Migration steps needed to turn the baseline into the expected catalog.
    /// May be empty when the failure was not expressed as migration steps.
    pub differences: Vec<MigrationStep>,
}

impl std::error::Error for BaselineValidationError {}

impl std::fmt::Display for BaselineValidationError {
    /// Lists each difference with its id and rendered SQL. Statements longer
    /// than 100 characters are cut to their first 100 characters and suffixed
    /// with `...`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Measured in chars, not bytes: slicing a `String` at a fixed byte
        // offset panics when that offset lands inside a multi-byte UTF-8
        // character (e.g. non-ASCII identifiers, comments or literals).
        const MAX_SQL_PREVIEW_CHARS: usize = 100;

        writeln!(
            f,
            "Baseline validation found {} unexpected difference(s):\n",
            self.differences.len()
        )?;
        for (i, step) in self.differences.iter().enumerate() {
            writeln!(f, " {}. {:?}", i + 1, step.id())?;
            for rendered in step.to_sql() {
                let sql = rendered.sql.as_str();
                // `nth` yields the byte offset of the first char past the
                // limit, which is always a valid char boundary.
                match sql.char_indices().nth(MAX_SQL_PREVIEW_CHARS) {
                    Some((cut, _)) => writeln!(f, " {}...", &sql[..cut])?,
                    None => writeln!(f, " {}", sql)?,
                }
            }
        }
        Ok(())
    }
}
/// Checks that a baseline catalog matches the expected catalog, without
/// progress output or detailed messages.
///
/// Delegates to [`validate_catalogs`] with `show_differences` and `verbose`
/// disabled. Column-order warnings in `Warn` mode may still be printed to
/// stderr by that function.
///
/// # Errors
/// Returns [`BaselineValidationError`] carrying the planned migration steps
/// when validation fails. The list is empty when the comparison itself could
/// not be performed, or when the only failures were strict-mode column-order
/// mismatches. In the first case the underlying cause is printed to stderr,
/// because the error type has no field to carry it.
pub fn validate_baseline_consistency(
    baseline_catalog: &Catalog,
    expected_catalog: &Catalog,
    config: &Config,
) -> Result<(), BaselineValidationError> {
    let validation_config = ValidationConfig {
        show_differences: false,
        verbose: false,
    };
    let result = validate_catalogs(
        baseline_catalog,
        expected_catalog,
        config,
        &validation_config,
    )
    .map_err(|err| {
        // Previously discarded, which made the resulting error read as
        // "found 0 unexpected difference(s)" with no hint of the real cause.
        eprintln!("Error: baseline comparison could not be completed: {:#}", err);
        BaselineValidationError {
            differences: vec![],
        }
    })?;
    if result.passed {
        Ok(())
    } else {
        Err(BaselineValidationError {
            differences: result.differences,
        })
    }
}
/// A table whose new definition places a newly added column before a column
/// that already existed.
#[derive(Debug, Clone)]
pub struct ColumnOrderViolation {
    /// Schema of the affected table.
    pub schema: String,
    /// Name of the affected table.
    pub table: String,
    /// First column, in new-definition order, that the old table did not have.
    pub new_column: String,
    /// First pre-existing column that follows `new_column` in the new definition.
    pub old_column_after: String,
}

impl std::fmt::Display for ColumnOrderViolation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            schema,
            table,
            new_column,
            old_column_after,
        } = self;
        write!(
            f,
            "Table {}.{}: new column '{}' must come after existing column '{}'",
            schema, table, new_column, old_column_after
        )
    }
}
/// A table that has the same set of columns in the database and in the schema
/// files, but in a different order.
#[derive(Debug, Clone)]
pub struct ColumnOrderMismatch {
    /// Schema of the affected table.
    pub schema: String,
    /// Name of the affected table.
    pub table: String,
    /// Column names in the order declared by the expected catalog.
    pub expected_order: Vec<String>,
    /// Column names in the order found in the actual catalog.
    pub actual_order: Vec<String>,
}

impl std::fmt::Display for ColumnOrderMismatch {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let expected = self.expected_order.join(", ");
        let actual = self.actual_order.join(", ");
        write!(
            f,
            "Table {}.{}: column order mismatch\n Expected: [{}]\n Actual: [{}]",
            self.schema, self.table, expected, actual
        )
    }
}
/// Reports tables (present in both catalogs) whose new definition places a
/// newly added column before a column that already existed.
///
/// At most one violation is reported per table. It names the first added
/// column and the first pre-existing column that follows it. Tables missing
/// from `old_catalog` and dropped columns are ignored.
pub fn validate_column_order(
    old_catalog: &Catalog,
    new_catalog: &Catalog,
) -> Vec<ColumnOrderViolation> {
    let tables_before: std::collections::HashMap<(&str, &str), &crate::catalog::table::Table> =
        old_catalog
            .tables
            .iter()
            .map(|table| ((table.schema.as_str(), table.name.as_str()), table))
            .collect();

    new_catalog
        .tables
        .iter()
        .filter_map(|table| {
            let previous = tables_before.get(&(table.schema.as_str(), table.name.as_str()))?;
            let existing: HashSet<&str> = previous
                .columns
                .iter()
                .map(|column| column.name.as_str())
                .collect();

            // One shared iterator: the second search resumes right after the
            // first added column, so it only sees columns that follow it.
            let mut columns = table.columns.iter();
            let first_added = columns.find(|column| !existing.contains(column.name.as_str()))?;
            let existing_after = columns.find(|column| existing.contains(column.name.as_str()))?;

            Some(ColumnOrderViolation {
                schema: table.schema.clone(),
                table: table.name.clone(),
                new_column: first_added.name.clone(),
                old_column_after: existing_after.name.clone(),
            })
        })
        .collect()
}
/// Finds tables present in both catalogs that have exactly the same column
/// names but in a different order.
///
/// Tables absent from `actual_catalog`, and tables whose column sets differ,
/// are skipped. Those differences are left to the migration planner.
pub fn find_column_order_mismatches(
    actual_catalog: &Catalog,
    expected_catalog: &Catalog,
) -> Vec<ColumnOrderMismatch> {
    let live_tables: std::collections::HashMap<(&str, &str), &crate::catalog::table::Table> =
        actual_catalog
            .tables
            .iter()
            .map(|table| ((table.schema.as_str(), table.name.as_str()), table))
            .collect();

    expected_catalog
        .tables
        .iter()
        .filter_map(|wanted| {
            let live = live_tables.get(&(wanted.schema.as_str(), wanted.name.as_str()))?;
            let wanted_order: Vec<&str> =
                wanted.columns.iter().map(|column| column.name.as_str()).collect();
            let live_order: Vec<&str> =
                live.columns.iter().map(|column| column.name.as_str()).collect();

            let same_column_set = wanted_order.iter().copied().collect::<HashSet<&str>>()
                == live_order.iter().copied().collect::<HashSet<&str>>();

            (same_column_set && wanted_order != live_order).then(|| ColumnOrderMismatch {
                schema: wanted.schema.clone(),
                table: wanted.name.clone(),
                expected_order: wanted_order.into_iter().map(String::from).collect(),
                actual_order: live_order.into_iter().map(String::from).collect(),
            })
        })
        .collect()
}
/// Enforces the "new columns go last" rule from [`validate_column_order`]
/// according to `mode`.
///
/// * `Relaxed` — the check is skipped entirely.
/// * `Warn` — violations are printed to stderr and `Ok(())` is returned.
/// * `Strict` — violations are turned into an error listing each one.
///
/// # Errors
/// Only in `Strict` mode, when at least one violation is found.
pub fn apply_column_order_validation(
    old_catalog: &Catalog,
    new_catalog: &Catalog,
    mode: ColumnOrderMode,
) -> Result<()> {
    if matches!(mode, ColumnOrderMode::Relaxed) {
        return Ok(());
    }
    let violations = validate_column_order(old_catalog, new_catalog);
    if violations.is_empty() {
        return Ok(());
    }

    match mode {
        ColumnOrderMode::Relaxed => Ok(()),
        ColumnOrderMode::Warn => {
            eprintln!("Warning: Column order validation issues detected:\n");
            violations
                .iter()
                .for_each(|violation| eprintln!(" {}", violation));
            eprintln!(
                "\nNew columns should be placed at the end of table definitions to match physical column order."
            );
            Ok(())
        }
        ColumnOrderMode::Strict => {
            let mut report = String::from("Column order validation failed.\n\n");
            report.extend(violations.iter().map(|violation| format!("{}\n", violation)));
            report.push_str("\nTo fix: Move new columns to the end of your table definition.\n");
            report.push_str(
                "To disable this check: Set `migration.column_order: relaxed` in pgmt.yaml",
            );
            Err(anyhow::anyhow!("{}", report))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::catalog::id::DbObjectId;
    use crate::catalog::table::{Column, Table};
    use crate::config::Config;

    /// A nullable `text` column with no default, identity, comment or dependencies.
    fn text_column(col_name: &str) -> Column {
        Column {
            name: col_name.to_string(),
            data_type: "text".to_string(),
            default: None,
            generated: None,
            identity: None,
            comment: None,
            depends_on: vec![],
            not_null: false,
        }
    }

    /// Builds `schema.name` with the given `text` columns, depending only on its schema.
    fn make_test_table(schema: &str, name: &str, columns: Vec<&str>) -> Table {
        let schema_dependency = DbObjectId::Schema {
            name: schema.to_string(),
        };
        Table::new(
            schema.to_string(),
            name.to_string(),
            columns.into_iter().map(text_column).collect(),
            None,
            None,
            vec![schema_dependency],
        )
    }

    fn make_catalog_with_table(table: Table) -> Catalog {
        let mut catalog = Catalog::empty();
        catalog.tables.push(table);
        catalog
    }

    /// Catalog holding a single `public.users` table with the given columns.
    fn users_catalog(columns: Vec<&str>) -> Catalog {
        make_catalog_with_table(make_test_table("public", "users", columns))
    }

    /// Runs `validate_column_order` on two versions of `public.users`.
    fn users_violations(
        old_columns: Vec<&str>,
        new_columns: Vec<&str>,
    ) -> Vec<ColumnOrderViolation> {
        validate_column_order(&users_catalog(old_columns), &users_catalog(new_columns))
    }

    /// Runs `apply_column_order_validation` on two versions of `public.users`.
    fn users_column_check(
        old_columns: Vec<&str>,
        new_columns: Vec<&str>,
        mode: ColumnOrderMode,
    ) -> anyhow::Result<()> {
        apply_column_order_validation(&users_catalog(old_columns), &users_catalog(new_columns), mode)
    }

    /// Runs `find_column_order_mismatches` on actual vs expected `public.users`.
    fn users_mismatches(
        actual_columns: Vec<&str>,
        expected_columns: Vec<&str>,
    ) -> Vec<ColumnOrderMismatch> {
        find_column_order_mismatches(&users_catalog(actual_columns), &users_catalog(expected_columns))
    }

    #[test]
    fn test_validation_config_default() {
        let ValidationConfig {
            show_differences,
            verbose,
        } = ValidationConfig::default();
        assert!(show_differences);
        assert!(verbose);
    }

    #[test]
    fn test_validate_catalogs_same() {
        let empty = Catalog::empty();
        let result = validate_catalogs(
            &empty,
            &empty,
            &Config::default(),
            &ValidationConfig::default(),
        )
        .unwrap();
        assert!(result.passed);
        assert!(result.differences.is_empty());
    }

    #[test]
    fn test_format_validation_failure() {
        let message = format_validation_failure(&[], &[], false);
        assert!(message.contains("Schema validation failed"));
        assert!(message.contains("To fix these issues"));
    }

    #[test]
    fn test_new_column_at_end_passes() {
        assert!(users_violations(vec!["id", "name"], vec!["id", "name", "email"]).is_empty());
    }

    #[test]
    fn test_new_column_in_middle_fails() {
        let violations = users_violations(vec!["id", "name"], vec!["id", "email", "name"]);
        assert_eq!(violations.len(), 1);
        let violation = &violations[0];
        assert_eq!(violation.schema, "public");
        assert_eq!(violation.table, "users");
        assert_eq!(violation.new_column, "email");
        assert_eq!(violation.old_column_after, "name");
    }

    #[test]
    fn test_new_column_at_start_fails() {
        let violations = users_violations(vec!["id", "name"], vec!["email", "id", "name"]);
        assert_eq!(violations.len(), 1);
        assert_eq!(violations[0].new_column, "email");
        assert_eq!(violations[0].old_column_after, "id");
    }

    #[test]
    fn test_multiple_new_columns_at_end_passes() {
        let violations =
            users_violations(vec!["id", "name"], vec!["id", "name", "email", "phone"]);
        assert!(violations.is_empty());
    }

    #[test]
    fn test_new_table_no_validation() {
        let violations = validate_column_order(
            &Catalog::empty(),
            &users_catalog(vec!["id", "name", "email"]),
        );
        assert!(violations.is_empty());
    }

    #[test]
    fn test_dropped_columns_ignored() {
        let violations =
            users_violations(vec!["id", "legacy", "name"], vec!["id", "name", "email"]);
        assert!(violations.is_empty());
    }

    #[test]
    fn test_relaxed_mode_allows_violations() {
        let result = users_column_check(
            vec!["id", "name"],
            vec!["id", "email", "name"],
            ColumnOrderMode::Relaxed,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_strict_mode_rejects_violations() {
        let err = users_column_check(
            vec!["id", "name"],
            vec!["id", "email", "name"],
            ColumnOrderMode::Strict,
        )
        .expect_err("strict mode must reject a new column placed before an existing one")
        .to_string();
        assert!(err.contains("Column order validation failed"));
        assert!(err.contains("email"));
    }

    #[test]
    fn test_warn_mode_allows_violations() {
        let result = users_column_check(
            vec!["id", "name"],
            vec!["id", "email", "name"],
            ColumnOrderMode::Warn,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_column_order_violation_display() {
        let display = ColumnOrderViolation {
            schema: "public".to_string(),
            table: "users".to_string(),
            new_column: "email".to_string(),
            old_column_after: "name".to_string(),
        }
        .to_string();
        assert!(display.contains("public.users"));
        assert!(display.contains("email"));
        assert!(display.contains("name"));
    }

    #[test]
    fn test_find_column_order_mismatches_same_order() {
        let columns = vec!["id", "name", "email"];
        assert!(users_mismatches(columns.clone(), columns).is_empty());
    }

    #[test]
    fn test_find_column_order_mismatches_different_order() {
        let mismatches =
            users_mismatches(vec!["id", "name", "email"], vec!["id", "email", "name"]);
        assert_eq!(mismatches.len(), 1);
        let mismatch = &mismatches[0];
        assert_eq!(mismatch.schema, "public");
        assert_eq!(mismatch.table, "users");
        assert_eq!(mismatch.actual_order, vec!["id", "name", "email"]);
        assert_eq!(mismatch.expected_order, vec!["id", "email", "name"]);
    }

    #[test]
    fn test_find_column_order_mismatches_different_columns_ignored() {
        // Differing column sets are the planner's job, not a pure ordering issue.
        assert!(users_mismatches(vec!["id", "name"], vec!["id", "name", "email"]).is_empty());
    }

    #[test]
    fn test_find_column_order_mismatches_new_table_ignored() {
        let mismatches =
            find_column_order_mismatches(&Catalog::empty(), &users_catalog(vec!["id", "name"]));
        assert!(mismatches.is_empty());
    }

    #[test]
    fn test_column_order_mismatch_display() {
        let display = ColumnOrderMismatch {
            schema: "public".to_string(),
            table: "users".to_string(),
            expected_order: vec!["id".to_string(), "email".to_string(), "name".to_string()],
            actual_order: vec!["id".to_string(), "name".to_string(), "email".to_string()],
        }
        .to_string();
        assert!(display.contains("public.users"));
        assert!(display.contains("column order mismatch"));
        assert!(display.contains("id, email, name"));
        assert!(display.contains("id, name, email"));
    }
}