use sea_query::{Alias, Query, Table};
use vespertide_core::{ColumnDef, TableDef};
use super::helpers::{
build_sea_column_def_with_table, build_sqlite_temp_table_create, convert_default_for_backend,
normalize_fill_with, recreate_indexes_after_rebuild,
};
use super::rename_table::build_rename_table;
use super::types::{BuiltQuery, DatabaseBackend, RawSql};
use crate::error::QueryError;
#[allow(clippy::too_many_arguments)]
pub fn build_modify_column_nullable(
backend: &DatabaseBackend,
table: &str,
column: &str,
nullable: bool,
fill_with: Option<&str>,
delete_null_rows: bool,
current_schema: &[TableDef],
pending_constraints: &[vespertide_core::TableConstraint],
) -> Result<Vec<BuiltQuery>, QueryError> {
let mut queries = Vec::new();
if !nullable && delete_null_rows {
let delete_sql = match backend {
DatabaseBackend::Postgres | DatabaseBackend::Sqlite => {
format!("DELETE FROM \"{}\" WHERE \"{}\" IS NULL", table, column)
}
DatabaseBackend::MySql => {
format!("DELETE FROM `{}` WHERE `{}` IS NULL", table, column)
}
};
queries.push(BuiltQuery::Raw(RawSql::uniform(delete_sql)));
}
else if !nullable && let Some(fill_value) = normalize_fill_with(fill_with) {
let fill_value = convert_default_for_backend(&fill_value, backend);
let update_sql = match backend {
DatabaseBackend::Postgres | DatabaseBackend::Sqlite => format!(
"UPDATE \"{}\" SET \"{}\" = {} WHERE \"{}\" IS NULL",
table, column, fill_value, column
),
DatabaseBackend::MySql => format!(
"UPDATE `{}` SET `{}` = {} WHERE `{}` IS NULL",
table, column, fill_value, column
),
};
queries.push(BuiltQuery::Raw(RawSql::uniform(update_sql)));
}
match backend {
DatabaseBackend::Postgres => {
let alter_sql = if nullable {
format!(
"ALTER TABLE \"{}\" ALTER COLUMN \"{}\" DROP NOT NULL",
table, column
)
} else {
format!(
"ALTER TABLE \"{}\" ALTER COLUMN \"{}\" SET NOT NULL",
table, column
)
};
queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
}
DatabaseBackend::MySql => {
let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. MySQL requires current schema information to modify column nullability.", table)))?;
let column_def = table_def.columns.iter().find(|c| c.name == column).ok_or_else(|| QueryError::Other(format!("Column '{}' not found in table '{}'. MySQL requires column information to modify nullability.", column, table)))?;
let modified_col_def = ColumnDef {
nullable,
..column_def.clone()
};
let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
let stmt = Table::alter()
.table(Alias::new(table))
.modify_column(sea_col)
.to_owned();
queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
}
DatabaseBackend::Sqlite => {
let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::Other(format!("Table '{}' not found in current schema. SQLite requires current schema information to modify column nullability.", table)))?;
let mut new_columns = table_def.columns.clone();
if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
col.nullable = nullable;
}
let temp_table = format!("{}_temp", table);
let create_query = build_sqlite_temp_table_create(
backend,
&temp_table,
table,
&new_columns,
&table_def.constraints,
);
queries.push(create_query);
let column_aliases: Vec<Alias> = table_def
.columns
.iter()
.map(|c| Alias::new(&c.name))
.collect();
let mut select_query = Query::select();
for col_alias in &column_aliases {
select_query = select_query.column(col_alias.clone()).to_owned();
}
select_query = select_query.from(Alias::new(table)).to_owned();
let insert_stmt = Query::insert()
.into_table(Alias::new(&temp_table))
.columns(column_aliases.clone())
.select_from(select_query)
.unwrap()
.to_owned();
queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
let drop_table = Table::drop().table(Alias::new(table)).to_owned();
queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
queries.push(build_rename_table(&temp_table, table));
queries.extend(recreate_indexes_after_rebuild(
table,
&table_def.constraints,
pending_constraints,
));
}
}
Ok(queries)
}
#[cfg(test)]
mod tests {
use super::*;
use insta::{assert_snapshot, with_settings};
use rstest::rstest;
use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};
fn col(name: &str, ty: ColumnType, nullable: bool) -> ColumnDef {
ColumnDef {
name: name.to_string(),
r#type: ty,
nullable,
default: None,
comment: None,
primary_key: None,
unique: None,
index: None,
foreign_key: None,
}
}
fn table_def(
name: &str,
columns: Vec<ColumnDef>,
constraints: Vec<TableConstraint>,
) -> TableDef {
TableDef {
name: name.to_string(),
description: None,
columns,
constraints,
}
}
#[rstest]
#[case::postgres_set_not_null(DatabaseBackend::Postgres, false, None)]
#[case::postgres_drop_not_null(DatabaseBackend::Postgres, true, None)]
#[case::postgres_set_not_null_with_fill(DatabaseBackend::Postgres, false, Some("'unknown'"))]
#[case::mysql_set_not_null(DatabaseBackend::MySql, false, None)]
#[case::mysql_drop_not_null(DatabaseBackend::MySql, true, None)]
#[case::mysql_set_not_null_with_fill(DatabaseBackend::MySql, false, Some("'unknown'"))]
#[case::sqlite_set_not_null(DatabaseBackend::Sqlite, false, None)]
#[case::sqlite_drop_not_null(DatabaseBackend::Sqlite, true, None)]
#[case::sqlite_set_not_null_with_fill(DatabaseBackend::Sqlite, false, Some("'unknown'"))]
fn test_build_modify_column_nullable(
#[case] backend: DatabaseBackend,
#[case] nullable: bool,
#[case] fill_with: Option<&str>,
) {
let schema = vec![table_def(
"users",
vec![
col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
col(
"email",
ColumnType::Simple(SimpleColumnType::Text),
!nullable,
),
],
vec![],
)];
let result = build_modify_column_nullable(
&backend,
"users",
"email",
nullable,
fill_with,
false,
&schema,
&[],
);
assert!(result.is_ok());
let queries = result.unwrap();
let sql = queries
.iter()
.map(|q| q.build(backend))
.collect::<Vec<String>>()
.join("\n");
let suffix = format!(
"{}_{}_users{}",
match backend {
DatabaseBackend::Postgres => "postgres",
DatabaseBackend::MySql => "mysql",
DatabaseBackend::Sqlite => "sqlite",
},
if nullable { "nullable" } else { "not_null" },
if fill_with.is_some() {
"_with_fill"
} else {
""
}
);
with_settings!({ snapshot_suffix => suffix }, {
assert_snapshot!(sql);
});
}
#[rstest]
#[case::postgres_table_not_found(DatabaseBackend::Postgres)]
#[case::mysql_table_not_found(DatabaseBackend::MySql)]
#[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
fn test_table_not_found(#[case] backend: DatabaseBackend) {
if backend == DatabaseBackend::Postgres {
return;
}
let result =
build_modify_column_nullable(&backend, "users", "email", false, None, false, &[], &[]);
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("Table 'users' not found"));
}
#[rstest]
#[case::postgres_column_not_found(DatabaseBackend::Postgres)]
#[case::mysql_column_not_found(DatabaseBackend::MySql)]
#[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
fn test_column_not_found(#[case] backend: DatabaseBackend) {
if backend == DatabaseBackend::Postgres || backend == DatabaseBackend::Sqlite {
return;
}
let schema = vec![table_def(
"users",
vec![col(
"id",
ColumnType::Simple(SimpleColumnType::Integer),
false,
)],
vec![],
)];
let result = build_modify_column_nullable(
&backend,
"users",
"email",
false,
None,
false,
&schema,
&[],
);
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("Column 'email' not found"));
}
#[rstest]
#[case::postgres_with_index(DatabaseBackend::Postgres)]
#[case::mysql_with_index(DatabaseBackend::MySql)]
#[case::sqlite_with_index(DatabaseBackend::Sqlite)]
fn test_modify_nullable_with_index(#[case] backend: DatabaseBackend) {
let schema = vec![table_def(
"users",
vec![
col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
col("email", ColumnType::Simple(SimpleColumnType::Text), true),
],
vec![TableConstraint::Index {
name: Some("idx_email".into()),
columns: vec!["email".into()],
}],
)];
let result = build_modify_column_nullable(
&backend,
"users",
"email",
false,
None,
false,
&schema,
&[],
);
assert!(result.is_ok());
let queries = result.unwrap();
let sql = queries
.iter()
.map(|q| q.build(backend))
.collect::<Vec<String>>()
.join("\n");
if backend == DatabaseBackend::Sqlite {
assert!(sql.contains("CREATE INDEX"));
assert!(sql.contains("idx_email"));
}
let suffix = format!(
"{}_with_index",
match backend {
DatabaseBackend::Postgres => "postgres",
DatabaseBackend::MySql => "mysql",
DatabaseBackend::Sqlite => "sqlite",
}
);
with_settings!({ snapshot_suffix => suffix }, {
assert_snapshot!(sql);
});
}
#[rstest]
#[case::postgres_fill_now(DatabaseBackend::Postgres)]
#[case::mysql_fill_now(DatabaseBackend::MySql)]
#[case::sqlite_fill_now(DatabaseBackend::Sqlite)]
fn test_fill_with_now_converted_to_current_timestamp(#[case] backend: DatabaseBackend) {
let schema = vec![table_def(
"orders",
vec![
col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
col(
"paid_at",
ColumnType::Simple(SimpleColumnType::Timestamptz),
true,
),
],
vec![],
)];
let result = build_modify_column_nullable(
&backend,
"orders",
"paid_at",
false,
Some("NOW()"),
false,
&schema,
&[],
);
assert!(result.is_ok());
let queries = result.unwrap();
let sql = queries
.iter()
.map(|q| q.build(backend))
.collect::<Vec<String>>()
.join("\n");
assert!(
!sql.contains("NOW()"),
"SQL should not contain NOW(), got: {}",
sql
);
assert!(
sql.contains("CURRENT_TIMESTAMP"),
"SQL should contain CURRENT_TIMESTAMP, got: {}",
sql
);
let suffix = format!(
"{}_fill_now",
match backend {
DatabaseBackend::Postgres => "postgres",
DatabaseBackend::MySql => "mysql",
DatabaseBackend::Sqlite => "sqlite",
}
);
with_settings!({ snapshot_suffix => suffix }, {
assert_snapshot!(sql);
});
}
#[rstest]
#[case::postgres_with_default(DatabaseBackend::Postgres)]
#[case::mysql_with_default(DatabaseBackend::MySql)]
#[case::sqlite_with_default(DatabaseBackend::Sqlite)]
fn test_with_default_value(#[case] backend: DatabaseBackend) {
let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
email_col.default = Some("'default@example.com'".into());
let schema = vec![table_def(
"users",
vec![
col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
email_col,
],
vec![],
)];
let result = build_modify_column_nullable(
&backend,
"users",
"email",
false,
None,
false,
&schema,
&[],
);
assert!(result.is_ok());
let queries = result.unwrap();
let sql = queries
.iter()
.map(|q| q.build(backend))
.collect::<Vec<String>>()
.join("\n");
if backend == DatabaseBackend::MySql || backend == DatabaseBackend::Sqlite {
assert!(sql.contains("DEFAULT"));
}
let suffix = format!(
"{}_with_default",
match backend {
DatabaseBackend::Postgres => "postgres",
DatabaseBackend::MySql => "mysql",
DatabaseBackend::Sqlite => "sqlite",
}
);
with_settings!({ snapshot_suffix => suffix }, {
assert_snapshot!(sql);
});
}
#[rstest]
#[case::postgres_delete_null_rows(DatabaseBackend::Postgres)]
#[case::mysql_delete_null_rows(DatabaseBackend::MySql)]
#[case::sqlite_delete_null_rows(DatabaseBackend::Sqlite)]
fn test_delete_null_rows(#[case] backend: DatabaseBackend) {
let schema = vec![table_def(
"orders",
vec![
col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
col(
"user_id",
ColumnType::Simple(SimpleColumnType::Integer),
true,
),
],
vec![],
)];
let result = build_modify_column_nullable(
&backend,
"orders",
"user_id",
false,
None,
true,
&schema,
&[],
);
assert!(result.is_ok());
let queries = result.unwrap();
let sql = queries
.iter()
.map(|q| q.build(backend))
.collect::<Vec<String>>()
.join("\n");
assert!(
sql.contains("DELETE FROM"),
"Expected DELETE FROM in SQL, got: {}",
sql
);
assert!(
sql.contains("IS NULL"),
"Expected IS NULL in SQL, got: {}",
sql
);
assert!(
!sql.contains("UPDATE"),
"Should NOT contain UPDATE, got: {}",
sql
);
let suffix = format!(
"{}_delete_null_rows",
match backend {
DatabaseBackend::Postgres => "postgres",
DatabaseBackend::MySql => "mysql",
DatabaseBackend::Sqlite => "sqlite",
}
);
with_settings!({ snapshot_suffix => suffix }, {
assert_snapshot!(sql);
});
}
#[rstest]
#[case::postgres_delete_null_rows_nullable(DatabaseBackend::Postgres)]
fn test_delete_null_rows_with_nullable_true(#[case] backend: DatabaseBackend) {
let schema = vec![table_def(
"orders",
vec![
col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
col(
"user_id",
ColumnType::Simple(SimpleColumnType::Integer),
false,
),
],
vec![],
)];
let result = build_modify_column_nullable(
&backend,
"orders",
"user_id",
true,
None,
true,
&schema,
&[],
);
assert!(result.is_ok());
let queries = result.unwrap();
let sql = queries
.iter()
.map(|q| q.build(backend))
.collect::<Vec<String>>()
.join("\n");
assert!(
!sql.contains("DELETE FROM"),
"Should NOT contain DELETE when nullable=true, got: {}",
sql
);
}
}