use sea_query::{Alias, Query, Table};
use vespertide_core::{ColumnDef, TableDef};
use super::helpers::{
build_sea_column_def_with_table, build_sqlite_temp_table_create, convert_default_for_backend,
normalize_fill_with, quote_ident, recreate_indexes_after_rebuild,
};
use super::rename_table::build_rename_table;
use super::types::{BuiltQuery, DatabaseBackend, RawSql};
use crate::error::QueryError;
#[expect(
clippy::too_many_arguments,
reason = "nullability builder needs action fields, fill strategy, backend, and SQLite rebuild context; NullabilityContext is deferred"
)]
pub fn build_modify_column_nullable(
backend: DatabaseBackend,
table: &str,
column: &str,
nullable: bool,
fill_with: Option<&str>,
delete_null_rows: bool,
current_schema: &[TableDef],
pending_constraints: &[vespertide_core::TableConstraint],
) -> Result<Vec<BuiltQuery>, QueryError> {
let mut queries = Vec::new();
if !nullable && delete_null_rows {
let quoted_table = quote_ident(table, backend);
let quoted_column = quote_ident(column, backend);
let delete_sql = format!("DELETE FROM {quoted_table} WHERE {quoted_column} IS NULL");
queries.push(BuiltQuery::Raw(RawSql::uniform(delete_sql)));
}
else if !nullable && let Some(fill_value) = normalize_fill_with(fill_with) {
let fill_value = convert_default_for_backend(&fill_value, backend);
let quoted_table = quote_ident(table, backend);
let quoted_column = quote_ident(column, backend);
let update_sql = format!(
"UPDATE {quoted_table} SET {quoted_column} = {fill_value} WHERE {quoted_column} IS NULL"
);
queries.push(BuiltQuery::Raw(RawSql::uniform(update_sql)));
}
match backend {
DatabaseBackend::Postgres => {
let quoted_table = quote_ident(table, backend);
let quoted_column = quote_ident(column, backend);
let alter_sql = if nullable {
format!("ALTER TABLE {quoted_table} ALTER COLUMN {quoted_column} DROP NOT NULL")
} else {
format!("ALTER TABLE {quoted_table} ALTER COLUMN {quoted_column} SET NOT NULL")
};
queries.push(BuiltQuery::Raw(RawSql::uniform(alter_sql)));
}
DatabaseBackend::MySql => {
let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::SchemaError(format!("Table '{table}' not found in current schema. MySQL requires current schema information to modify column nullability.")))?;
let column_def = table_def.columns.iter().find(|c| c.name == column).ok_or_else(|| QueryError::SchemaError(format!("Column '{column}' not found in table '{table}'. MySQL requires column information to modify nullability.")))?;
let modified_col_def = ColumnDef {
nullable,
..column_def.clone()
};
let sea_col = build_sea_column_def_with_table(backend, table, &modified_col_def);
let stmt = Table::alter()
.table(Alias::new(table))
.modify_column(sea_col)
.to_owned();
queries.push(BuiltQuery::AlterTable(Box::new(stmt)));
}
DatabaseBackend::Sqlite => {
let table_def = current_schema.iter().find(|t| t.name == table).ok_or_else(|| QueryError::SchemaError(format!("Table '{table}' not found in current schema. SQLite requires current schema information to modify column nullability.")))?;
let mut new_columns = table_def.columns.clone();
if let Some(col) = new_columns.iter_mut().find(|c| c.name == column) {
col.nullable = nullable;
}
let temp_table = format!("{table}_temp");
let create_query = build_sqlite_temp_table_create(
backend,
&temp_table,
table,
&new_columns,
&table_def.constraints,
);
queries.push(create_query);
let column_aliases: Vec<Alias> = table_def
.columns
.iter()
.map(|c| Alias::new(&c.name))
.collect();
let mut select_query = Query::select();
for col_alias in &column_aliases {
select_query.column(col_alias.clone());
}
select_query.from(Alias::new(table));
let insert_stmt = Query::insert()
.into_table(Alias::new(&temp_table))
.columns(column_aliases.clone())
.select_from(select_query)
.unwrap()
.to_owned();
queries.push(BuiltQuery::Insert(Box::new(insert_stmt)));
let drop_table = Table::drop().table(Alias::new(table)).to_owned();
queries.push(BuiltQuery::DropTable(Box::new(drop_table)));
queries.push(build_rename_table(&temp_table, table));
queries.extend(recreate_indexes_after_rebuild(
table,
&table_def.constraints,
pending_constraints,
));
}
}
Ok(queries)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_support::col_n as col;
use insta::{assert_snapshot, with_settings};
use rstest::rstest;
use vespertide_core::{ColumnDef, ColumnType, SimpleColumnType, TableConstraint};
fn table_def(
name: &str,
columns: Vec<ColumnDef>,
constraints: Vec<TableConstraint>,
) -> TableDef {
TableDef {
name: name.into(),
description: None,
columns,
constraints,
}
}
#[rstest]
#[case::postgres_set_not_null(DatabaseBackend::Postgres, false, None)]
#[case::postgres_drop_not_null(DatabaseBackend::Postgres, true, None)]
#[case::postgres_set_not_null_with_fill(DatabaseBackend::Postgres, false, Some("'unknown'"))]
#[case::mysql_set_not_null(DatabaseBackend::MySql, false, None)]
#[case::mysql_drop_not_null(DatabaseBackend::MySql, true, None)]
#[case::mysql_set_not_null_with_fill(DatabaseBackend::MySql, false, Some("'unknown'"))]
#[case::sqlite_set_not_null(DatabaseBackend::Sqlite, false, None)]
#[case::sqlite_drop_not_null(DatabaseBackend::Sqlite, true, None)]
#[case::sqlite_set_not_null_with_fill(DatabaseBackend::Sqlite, false, Some("'unknown'"))]
fn test_build_modify_column_nullable(
#[case] backend: DatabaseBackend,
#[case] nullable: bool,
#[case] fill_with: Option<&str>,
) {
let schema = vec![table_def(
"users",
vec![
col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
col(
"email",
ColumnType::Simple(SimpleColumnType::Text),
!nullable,
),
],
vec![],
)];
let result = build_modify_column_nullable(
backend,
"users",
"email",
nullable,
fill_with,
false,
&schema,
&[],
);
assert!(result.is_ok());
let queries = result.unwrap();
let sql = queries
.iter()
.map(|q| q.build(backend))
.collect::<Vec<String>>()
.join("\n");
let suffix = format!(
"{}_{}_users{}",
match backend {
DatabaseBackend::Postgres => "postgres",
DatabaseBackend::MySql => "mysql",
DatabaseBackend::Sqlite => "sqlite",
},
if nullable { "nullable" } else { "not_null" },
if fill_with.is_some() {
"_with_fill"
} else {
""
}
);
with_settings!({ snapshot_suffix => suffix }, {
assert_snapshot!(sql);
});
}
#[rstest]
#[case::postgres_table_not_found(DatabaseBackend::Postgres)]
#[case::mysql_table_not_found(DatabaseBackend::MySql)]
#[case::sqlite_table_not_found(DatabaseBackend::Sqlite)]
fn test_table_not_found(#[case] backend: DatabaseBackend) {
if backend == DatabaseBackend::Postgres {
return;
}
let result =
build_modify_column_nullable(backend, "users", "email", false, None, false, &[], &[]);
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("Table 'users' not found"));
}
#[rstest]
#[case::postgres_column_not_found(DatabaseBackend::Postgres)]
#[case::mysql_column_not_found(DatabaseBackend::MySql)]
#[case::sqlite_column_not_found(DatabaseBackend::Sqlite)]
fn test_column_not_found(#[case] backend: DatabaseBackend) {
if backend == DatabaseBackend::Postgres || backend == DatabaseBackend::Sqlite {
return;
}
let schema = vec![table_def(
"users",
vec![col(
"id",
ColumnType::Simple(SimpleColumnType::Integer),
false,
)],
vec![],
)];
let result = build_modify_column_nullable(
backend,
"users",
"email",
false,
None,
false,
&schema,
&[],
);
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("Column 'email' not found"));
}
#[rstest]
#[case::postgres_with_index(DatabaseBackend::Postgres)]
#[case::mysql_with_index(DatabaseBackend::MySql)]
#[case::sqlite_with_index(DatabaseBackend::Sqlite)]
fn test_modify_nullable_with_index(#[case] backend: DatabaseBackend) {
let schema = vec![table_def(
"users",
vec![
col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
col("email", ColumnType::Simple(SimpleColumnType::Text), true),
],
vec![TableConstraint::Index {
name: Some("idx_email".into()),
columns: vec!["email".into()],
}],
)];
let result = build_modify_column_nullable(
backend,
"users",
"email",
false,
None,
false,
&schema,
&[],
);
assert!(result.is_ok());
let queries = result.unwrap();
let sql = queries
.iter()
.map(|q| q.build(backend))
.collect::<Vec<String>>()
.join("\n");
if backend == DatabaseBackend::Sqlite {
assert!(sql.contains("CREATE INDEX"));
assert!(sql.contains("idx_email"));
}
let suffix = format!(
"{}_with_index",
match backend {
DatabaseBackend::Postgres => "postgres",
DatabaseBackend::MySql => "mysql",
DatabaseBackend::Sqlite => "sqlite",
}
);
with_settings!({ snapshot_suffix => suffix }, {
assert_snapshot!(sql);
});
}
#[rstest]
#[case::postgres_fill_now(DatabaseBackend::Postgres)]
#[case::mysql_fill_now(DatabaseBackend::MySql)]
#[case::sqlite_fill_now(DatabaseBackend::Sqlite)]
fn test_fill_with_now_converted_to_current_timestamp(#[case] backend: DatabaseBackend) {
let schema = vec![table_def(
"orders",
vec![
col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
col(
"paid_at",
ColumnType::Simple(SimpleColumnType::Timestamptz),
true,
),
],
vec![],
)];
let result = build_modify_column_nullable(
backend,
"orders",
"paid_at",
false,
Some("NOW()"),
false,
&schema,
&[],
);
assert!(result.is_ok());
let queries = result.unwrap();
let sql = queries
.iter()
.map(|q| q.build(backend))
.collect::<Vec<String>>()
.join("\n");
assert!(
!sql.contains("NOW()"),
"SQL should not contain NOW(), got: {sql}"
);
assert!(
sql.contains("CURRENT_TIMESTAMP"),
"SQL should contain CURRENT_TIMESTAMP, got: {sql}"
);
let suffix = format!(
"{}_fill_now",
match backend {
DatabaseBackend::Postgres => "postgres",
DatabaseBackend::MySql => "mysql",
DatabaseBackend::Sqlite => "sqlite",
}
);
with_settings!({ snapshot_suffix => suffix }, {
assert_snapshot!(sql);
});
}
#[rstest]
#[case::postgres_with_default(DatabaseBackend::Postgres)]
#[case::mysql_with_default(DatabaseBackend::MySql)]
#[case::sqlite_with_default(DatabaseBackend::Sqlite)]
fn test_with_default_value(#[case] backend: DatabaseBackend) {
let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text), true);
email_col.default = Some("'default@example.com'".into());
let schema = vec![table_def(
"users",
vec![
col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
email_col,
],
vec![],
)];
let result = build_modify_column_nullable(
backend,
"users",
"email",
false,
None,
false,
&schema,
&[],
);
assert!(result.is_ok());
let queries = result.unwrap();
let sql = queries
.iter()
.map(|q| q.build(backend))
.collect::<Vec<String>>()
.join("\n");
if backend == DatabaseBackend::MySql || backend == DatabaseBackend::Sqlite {
assert!(sql.contains("DEFAULT"));
}
let suffix = format!(
"{}_with_default",
match backend {
DatabaseBackend::Postgres => "postgres",
DatabaseBackend::MySql => "mysql",
DatabaseBackend::Sqlite => "sqlite",
}
);
with_settings!({ snapshot_suffix => suffix }, {
assert_snapshot!(sql);
});
}
#[rstest]
#[case::postgres_delete_null_rows(DatabaseBackend::Postgres)]
#[case::mysql_delete_null_rows(DatabaseBackend::MySql)]
#[case::sqlite_delete_null_rows(DatabaseBackend::Sqlite)]
fn test_delete_null_rows(#[case] backend: DatabaseBackend) {
let schema = vec![table_def(
"orders",
vec![
col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
col(
"user_id",
ColumnType::Simple(SimpleColumnType::Integer),
true,
),
],
vec![],
)];
let result = build_modify_column_nullable(
backend,
"orders",
"user_id",
false,
None,
true,
&schema,
&[],
);
assert!(result.is_ok());
let queries = result.unwrap();
let sql = queries
.iter()
.map(|q| q.build(backend))
.collect::<Vec<String>>()
.join("\n");
assert!(
sql.contains("DELETE FROM"),
"Expected DELETE FROM in SQL, got: {sql}"
);
assert!(
sql.contains("IS NULL"),
"Expected IS NULL in SQL, got: {sql}"
);
assert!(
!sql.contains("UPDATE"),
"Should NOT contain UPDATE, got: {sql}"
);
let suffix = format!(
"{}_delete_null_rows",
match backend {
DatabaseBackend::Postgres => "postgres",
DatabaseBackend::MySql => "mysql",
DatabaseBackend::Sqlite => "sqlite",
}
);
with_settings!({ snapshot_suffix => suffix }, {
assert_snapshot!(sql);
});
}
#[rstest]
#[case::postgres_delete_null_rows_nullable(DatabaseBackend::Postgres)]
fn test_delete_null_rows_with_nullable_true(#[case] backend: DatabaseBackend) {
let schema = vec![table_def(
"orders",
vec![
col("id", ColumnType::Simple(SimpleColumnType::Integer), false),
col(
"user_id",
ColumnType::Simple(SimpleColumnType::Integer),
false,
),
],
vec![],
)];
let result = build_modify_column_nullable(
backend,
"orders",
"user_id",
true,
None,
true,
&schema,
&[],
);
assert!(result.is_ok());
let queries = result.unwrap();
let sql = queries
.iter()
.map(|q| q.build(backend))
.collect::<Vec<String>>()
.join("\n");
assert!(
!sql.contains("DELETE FROM"),
"Should NOT contain DELETE when nullable=true, got: {sql}"
);
}
}