use sea_query::{Alias, ForeignKey, Index, Table, TableCreateStatement};
use vespertide_core::{ColumnDef, ColumnType, ComplexColumnType, TableConstraint};
use super::helpers::{
build_create_enum_type_sql, build_schema_statement, build_sea_column_def_with_table,
collect_sqlite_enum_check_clauses, to_sea_fk_action,
};
use super::types::{BuiltQuery, DatabaseBackend, RawSql};
use crate::error::QueryError;
/// Assembles a `CREATE TABLE` statement for `table` on the given backend.
///
/// Column definitions are appended first, followed by the table-level
/// constraints handled by `add_create_table_constraints` (primary key,
/// MySQL-only unique indexes, foreign keys). Check and index constraints are
/// not rendered into this statement.
pub(crate) fn build_create_table_for_backend(
    backend: DatabaseBackend,
    table: &str,
    columns: &[ColumnDef],
    constraints: &[TableConstraint],
) -> TableCreateStatement {
    let mut create = Table::create();
    create.table(Alias::new(table));
    add_create_table_columns(&mut create, backend, table, columns, constraints);
    add_create_table_constraints(&mut create, backend, table, columns, constraints);
    create
}
/// Appends one sea_query column definition per entry in `columns` to `stmt`.
///
/// A column's inline `primary_key` marker is only honoured when no
/// table-level `PrimaryKey` constraint is present. Columns that belong to a
/// primary key declared with `auto_increment: true`, and whose type supports
/// auto-increment, are marked auto-increment; on SQLite they are also marked
/// as an inline `PRIMARY KEY`.
fn add_create_table_columns(
    stmt: &mut TableCreateStatement,
    backend: DatabaseBackend,
    table: &str,
    columns: &[ColumnDef],
    constraints: &[TableConstraint],
) {
    let table_pk_declared = constraints
        .iter()
        .any(|constraint| matches!(constraint, TableConstraint::PrimaryKey { .. }));
    let auto_increment_targets = collect_auto_increment_columns(constraints);
    let is_sqlite = matches!(backend, DatabaseBackend::Sqlite);

    for column in columns {
        let mut sea_column = build_sea_column_def_with_table(backend, table, column);

        if !table_pk_declared && column.primary_key.is_some() {
            sea_column.primary_key();
        }

        let wants_auto_increment = auto_increment_targets.contains(column.name.as_str())
            && column.r#type.supports_auto_increment();
        if wants_auto_increment {
            // SQLite only accepts AUTOINCREMENT on a column-level PRIMARY KEY.
            // NOTE(review): every column of an auto-increment key lands here, so a
            // composite auto-increment key would yield several inline PRIMARY KEYs
            // on SQLite — confirm composite auto-increment keys are rejected upstream.
            if is_sqlite {
                sea_column.primary_key();
            }
            sea_column.auto_increment();
        }

        stmt.col(sea_column);
    }
}
/// Returns the names of every column that belongs to a primary-key
/// constraint declared with `auto_increment: true`.
fn collect_auto_increment_columns(
    constraints: &[TableConstraint],
) -> std::collections::HashSet<&str> {
    let mut names: std::collections::HashSet<&str> = std::collections::HashSet::new();
    for constraint in constraints {
        if let TableConstraint::PrimaryKey {
            columns: pk_cols,
            auto_increment: true,
            ..
        } = constraint
        {
            for pk_col in pk_cols {
                names.insert(pk_col.as_ref());
            }
        }
    }
    names
}
/// Appends the table-level constraints that are rendered inside `CREATE TABLE`.
///
/// Primary keys, foreign keys and (on MySQL only) unique indexes are added to
/// `stmt`. `Check` constraints are ignored here, and `Index` constraints are
/// ignored because `build_create_table` emits them as standalone
/// `CREATE INDEX` statements.
fn add_create_table_constraints(
    stmt: &mut TableCreateStatement,
    backend: DatabaseBackend,
    table: &str,
    columns: &[ColumnDef],
    constraints: &[TableConstraint],
) {
    for constraint in constraints {
        match constraint {
            TableConstraint::PrimaryKey {
                columns: pk_cols,
                auto_increment,
                ..
            } => {
                add_primary_key_constraint(stmt, backend, columns, pk_cols, *auto_increment);
            }
            TableConstraint::Unique {
                name,
                columns: unique_cols,
                ..
            } => {
                add_mysql_unique_constraint(stmt, backend, table, unique_cols, name.as_deref());
            }
            TableConstraint::ForeignKey { .. } => {
                add_foreign_key_constraint(stmt, table, constraint);
            }
            // Deliberately not rendered inside CREATE TABLE.
            TableConstraint::Check { .. } | TableConstraint::Index { .. } => {}
            _ => {
                unreachable!("TableConstraint is #[non_exhaustive]; all variants are matched above")
            }
        }
    }
}
/// Adds a table-level `PRIMARY KEY (...)` over `pk_cols` to `stmt`.
///
/// Nothing is added when `should_skip_sqlite_auto_increment_pk` reports that
/// the key is an SQLite auto-increment key; those columns are declared as an
/// inline primary key instead.
fn add_primary_key_constraint(
    stmt: &mut TableCreateStatement,
    backend: DatabaseBackend,
    columns: &[ColumnDef],
    pk_cols: &[impl AsRef<str>],
    auto_increment: bool,
) {
    let emit_table_level_key =
        !should_skip_sqlite_auto_increment_pk(backend, columns, pk_cols, auto_increment);
    if emit_table_level_key {
        let mut key = Index::create();
        pk_cols.iter().for_each(|pk_col| {
            key.col(Alias::new(pk_col.as_ref()));
        });
        stmt.primary_key(&mut key);
    }
}
/// Decides whether the table-level primary key must be omitted on SQLite.
///
/// Returns `true` only when the backend is SQLite, the key is auto-increment,
/// and every key column is found in `columns` (first match by name) with a
/// type that supports auto-increment. SQLite only accepts AUTOINCREMENT on a
/// column-level PRIMARY KEY, so such keys are declared inline instead.
fn should_skip_sqlite_auto_increment_pk(
    backend: DatabaseBackend,
    columns: &[ColumnDef],
    pk_cols: &[impl AsRef<str>],
    auto_increment: bool,
) -> bool {
    if !auto_increment || !matches!(backend, DatabaseBackend::Sqlite) {
        return false;
    }
    pk_cols.iter().all(|pk_col| {
        let wanted: &str = pk_col.as_ref();
        match columns.iter().find(|column| column.name == wanted) {
            Some(column) => column.r#type.supports_auto_increment(),
            None => false,
        }
    })
}
/// On MySQL, adds a named `UNIQUE` index over `unique_cols` to the
/// `CREATE TABLE` statement; on every other backend this is a no-op.
///
/// The index name comes from `build_unique_constraint_name`, which receives the
/// explicit constraint `name` when one was given.
fn add_mysql_unique_constraint(
    stmt: &mut TableCreateStatement,
    backend: DatabaseBackend,
    table: &str,
    unique_cols: &[impl AsRef<str>],
    name: Option<&str>,
) {
    let is_mysql = matches!(backend, DatabaseBackend::MySql);
    if is_mysql {
        let index_name = super::helpers::build_unique_constraint_name(table, unique_cols, name);
        let mut unique_index = Index::create();
        unique_index.name(&index_name).unique();
        for unique_col in unique_cols {
            unique_index.col(Alias::new(unique_col.as_ref()));
        }
        stmt.index(&mut unique_index);
    }
}
/// Adds a named `FOREIGN KEY` clause described by `constraint` to `stmt`.
///
/// Does nothing unless `constraint` is a `TableConstraint::ForeignKey`. The
/// `ON DELETE` / `ON UPDATE` actions are emitted only when they are set.
fn add_foreign_key_constraint(
    stmt: &mut TableCreateStatement,
    table: &str,
    constraint: &TableConstraint,
) {
    let TableConstraint::ForeignKey {
        name,
        columns: fk_cols,
        ref_table,
        ref_columns,
        on_delete,
        on_update,
        ..
    } = constraint
    else {
        return;
    };

    let fk_name = super::helpers::build_foreign_key_name(table, fk_cols, name.as_deref());
    let mut foreign_key = ForeignKey::create();
    foreign_key.name(&fk_name).from_tbl(Alias::new(table));
    fk_cols.iter().for_each(|source_col| {
        foreign_key.from_col(Alias::new(source_col));
    });
    foreign_key.to_tbl(Alias::new(ref_table));
    ref_columns.iter().for_each(|target_col| {
        foreign_key.to_col(Alias::new(target_col));
    });
    if let Some(action) = on_delete {
        foreign_key.on_delete(to_sea_fk_action(action));
    }
    if let Some(action) = on_update {
        foreign_key.on_update(to_sea_fk_action(action));
    }
    stmt.foreign_key(&mut foreign_key);
}
pub fn build_create_table(
backend: DatabaseBackend,
table: &str,
columns: &[ColumnDef],
constraints: &[TableConstraint],
) -> Result<Vec<BuiltQuery>, QueryError> {
let table_def = vespertide_core::TableDef {
description: None,
name: table.into(),
columns: columns.to_vec(),
constraints: constraints.to_vec(),
};
let normalized = table_def.normalize().map_err(|e| {
QueryError::SchemaError(format!("Failed to normalize table '{table}': {e}"))
})?;
let columns = &normalized.columns;
let constraints = &normalized.constraints;
let mut queries = Vec::new();
let mut created_enums = std::collections::HashSet::new();
for column in columns {
if let ColumnType::Complex(ComplexColumnType::Enum { name, .. }) = &column.r#type
&& created_enums.insert(name.clone())
&& let Some(create_type_sql) = build_create_enum_type_sql(table, &column.r#type)
{
queries.push(BuiltQuery::Raw(create_type_sql));
}
}
let (table_constraints, unique_constraints): (Vec<&TableConstraint>, Vec<&TableConstraint>) =
constraints
.iter()
.partition(|c| !matches!(c, TableConstraint::Unique { .. }));
let create_table_stmt = if matches!(backend, DatabaseBackend::MySql) {
build_create_table_for_backend(backend, table, columns, constraints)
} else {
let table_constraints_owned: Vec<TableConstraint> =
table_constraints.iter().copied().cloned().collect();
build_create_table_for_backend(backend, table, columns, &table_constraints_owned)
};
if matches!(backend, DatabaseBackend::Sqlite) {
let enum_check_clauses = collect_sqlite_enum_check_clauses(table, columns);
if enum_check_clauses.is_empty() {
queries.push(BuiltQuery::CreateTable(Box::new(create_table_stmt)));
} else {
let base_sql = build_schema_statement(&create_table_stmt, backend);
let mut modified_sql = base_sql;
if let Some(pos) = modified_sql.rfind(')') {
let check_sql = enum_check_clauses.join(", ");
modified_sql.insert_str(pos, &format!(", {check_sql}"));
}
queries.push(BuiltQuery::Raw(RawSql::per_backend(
modified_sql.clone(),
modified_sql.clone(),
modified_sql,
)));
}
} else {
queries.push(BuiltQuery::CreateTable(Box::new(create_table_stmt)));
}
if matches!(backend, DatabaseBackend::Postgres | DatabaseBackend::Sqlite) {
for constraint in unique_constraints {
if let TableConstraint::Unique {
name,
columns: unique_cols,
..
} = constraint
{
let index_name = super::helpers::build_unique_constraint_name(
table,
unique_cols,
name.as_deref(),
);
let mut idx = Index::create()
.table(Alias::new(table))
.name(&index_name)
.unique()
.to_owned();
for col in unique_cols {
idx.col(Alias::new(col));
}
queries.push(BuiltQuery::CreateIndex(Box::new(idx)));
}
}
}
for constraint in constraints {
if let TableConstraint::Index {
name,
columns: index_cols,
} = constraint
{
let index_name = super::helpers::build_index_name(table, index_cols, name.as_deref());
let mut idx = Index::create()
.table(Alias::new(table))
.name(&index_name)
.to_owned();
for col in index_cols {
idx.col(Alias::new(col));
}
queries.push(BuiltQuery::CreateIndex(Box::new(idx)));
}
}
Ok(queries)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_support::col;
    use insta::{assert_snapshot, with_settings};
    use rstest::rstest;
    use vespertide_core::{ColumnType, EnumValues, SimpleColumnType};

    /// Renders every query for `backend` and joins the SQL with `separator`.
    fn join_queries(queries: &[BuiltQuery], backend: DatabaseBackend, separator: &str) -> String {
        let mut sql = String::new();
        for (index, query) in queries.iter().enumerate() {
            if index > 0 {
                sql.push_str(separator);
            }
            sql.push_str(&query.build(backend));
        }
        sql
    }

    #[rstest]
    #[case::create_table_postgres(
        "create_table_postgres",
        DatabaseBackend::Postgres,
        &["CREATE TABLE \"users\" ( \"id\" integer )"]
    )]
    #[case::create_table_mysql(
        "create_table_mysql",
        DatabaseBackend::MySql,
        &["CREATE TABLE `users` ( `id` int )"]
    )]
    #[case::create_table_sqlite(
        "create_table_sqlite",
        DatabaseBackend::Sqlite,
        &["CREATE TABLE \"users\" ( \"id\" integer )"]
    )]
    fn test_create_table(
        #[case] title: &str,
        #[case] backend: DatabaseBackend,
        #[case] expected: &[&str],
    ) {
        let result = build_create_table(
            backend,
            "users",
            &[col("id", ColumnType::Simple(SimpleColumnType::Integer))],
            &[],
        )
        .unwrap();
        let sql = join_queries(&result, backend, "\n");
        for exp in expected {
            assert!(
                sql.contains(exp),
                "Expected SQL to contain '{exp}', got: {sql}"
            );
        }
        with_settings!({ snapshot_suffix => format!("create_table_{}", title) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::inline_unique_postgres(DatabaseBackend::Postgres)]
    #[case::inline_unique_mysql(DatabaseBackend::MySql)]
    #[case::inline_unique_sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_inline_unique(#[case] backend: DatabaseBackend) {
        use vespertide_core::schema::str_or_bool::StrOrBoolOrArray;
        let mut email_col = col("email", ColumnType::Simple(SimpleColumnType::Text));
        email_col.unique = Some(StrOrBoolOrArray::Bool(true));
        let result = build_create_table(
            backend,
            "users",
            &[
                col("id", ColumnType::Simple(SimpleColumnType::Integer)),
                email_col,
            ],
            &[],
        )
        .unwrap();
        let sql = join_queries(&result, backend, "\n");
        // Case-insensitive check: normalization must surface the inline unique flag.
        assert!(
            sql.to_uppercase().contains("UNIQUE"),
            "Normalized unique constraint should be in SQL, but not found: {sql}"
        );
        with_settings!({ snapshot_suffix => format!("create_table_with_inline_unique_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::table_level_unique_postgres(DatabaseBackend::Postgres)]
    #[case::table_level_unique_mysql(DatabaseBackend::MySql)]
    #[case::table_level_unique_sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_table_level_unique(#[case] backend: DatabaseBackend) {
        let result = build_create_table(
            backend,
            "users",
            &[
                col("id", ColumnType::Simple(SimpleColumnType::Integer)),
                col("email", ColumnType::Simple(SimpleColumnType::Text)),
            ],
            &[TableConstraint::Unique {
                name: Some("uq_email".into()),
                columns: vec!["email".into()],
                strategy: vespertide_core::UniqueConstraintStrategy::DeleteDuplicates {
                    keep: vespertide_core::KeepPolicy::First,
                },
            }],
        )
        .unwrap();
        let sql = join_queries(&result, backend, "\n");
        assert!(sql.contains("CREATE TABLE"));
        match backend {
            DatabaseBackend::MySql => {
                assert!(
                    sql.contains("UNIQUE"),
                    "MySQL should have UNIQUE in CREATE TABLE: {sql}"
                );
            }
            _ => {
                assert!(
                    sql.contains("CREATE UNIQUE INDEX"),
                    "Postgres/SQLite should have CREATE UNIQUE INDEX: {sql}"
                );
            }
        }
        with_settings!({ snapshot_suffix => format!("create_table_with_table_level_unique_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::table_level_unique_no_name_postgres(DatabaseBackend::Postgres)]
    #[case::table_level_unique_no_name_mysql(DatabaseBackend::MySql)]
    #[case::table_level_unique_no_name_sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_table_level_unique_no_name(#[case] backend: DatabaseBackend) {
        let result = build_create_table(
            backend,
            "users",
            &[
                col("id", ColumnType::Simple(SimpleColumnType::Integer)),
                col("email", ColumnType::Simple(SimpleColumnType::Text)),
            ],
            &[TableConstraint::Unique {
                name: None,
                columns: vec!["email".into()],
                strategy: vespertide_core::UniqueConstraintStrategy::DeleteDuplicates {
                    keep: vespertide_core::KeepPolicy::First,
                },
            }],
        )
        .unwrap();
        let sql = join_queries(&result, backend, "\n");
        assert!(sql.contains("CREATE TABLE"));
        match backend {
            DatabaseBackend::MySql => {
                assert!(
                    sql.contains("UNIQUE"),
                    "MySQL should have UNIQUE in CREATE TABLE: {sql}"
                );
            }
            _ => {
                assert!(
                    sql.contains("CREATE UNIQUE INDEX"),
                    "Postgres/SQLite should have CREATE UNIQUE INDEX: {sql}"
                );
            }
        }
        with_settings!({ snapshot_suffix => format!("create_table_with_table_level_unique_no_name_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_enum_column(#[case] backend: DatabaseBackend) {
        let columns = vec![
            ColumnDef {
                name: "id".into(),
                r#type: ColumnType::Simple(SimpleColumnType::Integer),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            },
            ColumnDef {
                name: "status".into(),
                r#type: ColumnType::Complex(ComplexColumnType::Enum {
                    name: "user_status".into(),
                    values: EnumValues::String(vec![
                        "active".into(),
                        "inactive".into(),
                        "pending".into(),
                    ]),
                }),
                nullable: false,
                default: Some("'active'".into()),
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            },
        ];
        let constraints = vec![TableConstraint::PrimaryKey {
            auto_increment: false,
            columns: vec!["id".into()],
            strategy: vespertide_core::PrimaryKeyAdditionStrategy::default(),
        }];
        let result = build_create_table(backend, "users", &columns, &constraints);
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = join_queries(&queries, backend, ";\n");
        with_settings!({ snapshot_suffix => format!("create_table_with_enum_column_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::auto_increment_postgres(DatabaseBackend::Postgres)]
    #[case::auto_increment_mysql(DatabaseBackend::MySql)]
    #[case::auto_increment_sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_auto_increment_primary_key(#[case] backend: DatabaseBackend) {
        let columns = vec![ColumnDef {
            name: "id".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: None,
            unique: None,
            index: None,
            foreign_key: None,
        }];
        let constraints = vec![TableConstraint::PrimaryKey {
            auto_increment: true,
            columns: vec!["id".into()],
            strategy: vespertide_core::PrimaryKeyAdditionStrategy::default(),
        }];
        let result = build_create_table(backend, "users", &columns, &constraints);
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = join_queries(&queries, backend, ";\n");
        match backend {
            DatabaseBackend::Postgres => {
                assert!(
                    sql.contains("SERIAL")
                        || sql.contains("serial")
                        || sql.contains("IDENTITY")
                        || sql.contains("identity"),
                    "PostgreSQL should use SERIAL or IDENTITY for auto_increment, got: {sql}"
                );
            }
            DatabaseBackend::MySql => {
                assert!(
                    sql.contains("AUTO_INCREMENT") || sql.contains("auto_increment"),
                    "MySQL should use AUTO_INCREMENT for auto_increment, got: {sql}"
                );
            }
            DatabaseBackend::Sqlite => {
                assert!(
                    sql.contains("AUTOINCREMENT") || sql.contains("autoincrement"),
                    "SQLite should use AUTOINCREMENT for auto_increment, got: {sql}"
                );
            }
        }
        with_settings!({ snapshot_suffix => format!("create_table_with_auto_increment_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    #[rstest]
    #[case::inline_auto_increment_postgres(DatabaseBackend::Postgres)]
    #[case::inline_auto_increment_mysql(DatabaseBackend::MySql)]
    #[case::inline_auto_increment_sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_inline_auto_increment_primary_key(#[case] backend: DatabaseBackend) {
        use vespertide_core::schema::primary_key::{PrimaryKeyDef, PrimaryKeySyntax};
        let columns = vec![ColumnDef {
            name: "id".into(),
            r#type: ColumnType::Simple(SimpleColumnType::Integer),
            nullable: false,
            default: None,
            comment: None,
            primary_key: Some(PrimaryKeySyntax::Object(PrimaryKeyDef {
                auto_increment: true,
            })),
            unique: None,
            index: None,
            foreign_key: None,
        }];
        let result = build_create_table(backend, "users", &columns, &[]);
        assert!(result.is_ok());
        let queries = result.unwrap();
        let sql = join_queries(&queries, backend, ";\n");
        match backend {
            DatabaseBackend::Postgres => {
                assert!(
                    sql.contains("SERIAL")
                        || sql.contains("serial")
                        || sql.contains("IDENTITY")
                        || sql.contains("identity"),
                    "PostgreSQL should use SERIAL or IDENTITY for auto_increment, got: {sql}"
                );
            }
            DatabaseBackend::MySql => {
                assert!(
                    sql.contains("AUTO_INCREMENT") || sql.contains("auto_increment"),
                    "MySQL should use AUTO_INCREMENT for auto_increment, got: {sql}"
                );
            }
            DatabaseBackend::Sqlite => {
                assert!(
                    sql.contains("AUTOINCREMENT") || sql.contains("autoincrement"),
                    "SQLite should use AUTOINCREMENT for auto_increment, got: {sql}"
                );
            }
        }
        with_settings!({ snapshot_suffix => format!("create_table_with_inline_auto_increment_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }

    /// Check constraints are not rendered by `build_create_table`, while index
    /// constraints still appear (as standalone `CREATE INDEX` statements).
    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_check_and_index_constraints_are_skipped(#[case] backend: DatabaseBackend) {
        let constraints = vec![
            TableConstraint::PrimaryKey {
                columns: vec!["id".into()],
                auto_increment: false,
                strategy: vespertide_core::PrimaryKeyAdditionStrategy::default(),
            },
            TableConstraint::Check {
                name: "chk_age".into(),
                expr: "age > 0".into(),
                strategy: vespertide_core::CheckViolationStrategy::default(),
            },
            TableConstraint::Index {
                name: Some("idx_age".into()),
                columns: vec!["age".into()],
            },
        ];
        let columns = vec![
            ColumnDef {
                name: "id".into(),
                r#type: ColumnType::Simple(SimpleColumnType::Integer),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            },
            ColumnDef {
                name: "age".into(),
                r#type: ColumnType::Simple(SimpleColumnType::Integer),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            },
        ];
        let queries = build_create_table(backend, "users", &columns, &constraints).unwrap();
        let sql = queries
            .iter()
            .map(|q| q.build(backend))
            .collect::<Vec<_>>()
            .join("\n");
        assert!(sql.contains("PRIMARY KEY"));
        assert!(!sql.contains("CHECK (age > 0)"));
        assert!(sql.contains("CREATE INDEX") || sql.contains("idx_age"));
    }

    #[rstest]
    #[case::postgres(DatabaseBackend::Postgres)]
    #[case::mysql(DatabaseBackend::MySql)]
    #[case::sqlite(DatabaseBackend::Sqlite)]
    fn test_create_table_with_timestamp_now_default(#[case] backend: DatabaseBackend) {
        let columns = vec![
            ColumnDef {
                name: "id".into(),
                r#type: ColumnType::Simple(SimpleColumnType::BigInt),
                nullable: false,
                default: None,
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            },
            ColumnDef {
                name: "created_at".into(),
                r#type: ColumnType::Simple(SimpleColumnType::Timestamptz),
                nullable: false,
                default: Some("NOW()".into()),
                comment: None,
                primary_key: None,
                unique: None,
                index: None,
                foreign_key: None,
            },
        ];
        let result = build_create_table(backend, "events", &columns, &[]);
        assert!(result.is_ok(), "build_create_table failed: {result:?}");
        let queries = result.unwrap();
        let sql = join_queries(&queries, backend, "\n");
        if matches!(backend, DatabaseBackend::Sqlite) {
            assert!(
                !sql.contains("NOW()"),
                "SQLite should not contain NOW(), got: {sql}"
            );
            assert!(
                sql.contains("CURRENT_TIMESTAMP"),
                "SQLite should use CURRENT_TIMESTAMP, got: {sql}"
            );
        }
        with_settings!({ snapshot_suffix => format!("create_table_with_timestamp_now_default_{:?}", backend) }, {
            assert_snapshot!(sql);
        });
    }
}