use std::borrow::Cow;
use sea_query::{
Alias, ColumnDef as SeaColumnDef, ForeignKeyAction, MysqlQueryBuilder, PostgresQueryBuilder,
QueryStatementWriter, SchemaStatementBuilder, SimpleExpr, SqliteQueryBuilder,
};
use vespertide_core::{
ColumnDef, ColumnType, ComplexColumnType, ReferenceAction, SimpleColumnType, TableConstraint,
};
use super::create_table::build_create_table_for_backend;
use super::types::{BuiltQuery, DatabaseBackend, RawSql};
/// Normalizes an optional fill value so that it is usable as a SQL expression.
///
/// * `None` stays `None`.
/// * An empty string becomes the SQL empty-string literal `''`.
/// * Any other value is passed through unchanged.
///
/// The result always borrows; no allocation is performed.
#[must_use]
pub fn normalize_fill_with(fill_with: Option<&str>) -> Option<Cow<'_, str>> {
    let raw = fill_with?;
    let normalized = match raw {
        "" => "''",
        other => other,
    };
    Some(Cow::Borrowed(normalized))
}
/// Renders a schema (DDL) statement to a SQL string for the requested backend.
///
/// The statement is serialized with the `sea_query` builder that corresponds to
/// `backend`: `PostgresQueryBuilder`, `MysqlQueryBuilder` or `SqliteQueryBuilder`.
pub fn build_schema_statement<T>(stmt: &T, backend: DatabaseBackend) -> String
where
    T: SchemaStatementBuilder,
{
    match backend {
        DatabaseBackend::Sqlite => stmt.to_string(SqliteQueryBuilder),
        DatabaseBackend::MySql => stmt.to_string(MysqlQueryBuilder),
        DatabaseBackend::Postgres => stmt.to_string(PostgresQueryBuilder),
    }
}
/// Renders a query (DML) statement to a SQL string for the requested backend.
///
/// The statement is serialized with the `sea_query` builder that corresponds to
/// `backend`: `PostgresQueryBuilder`, `MysqlQueryBuilder` or `SqliteQueryBuilder`.
pub fn build_query_statement<T>(stmt: &T, backend: DatabaseBackend) -> String
where
    T: QueryStatementWriter,
{
    match backend {
        DatabaseBackend::Sqlite => stmt.to_string(SqliteQueryBuilder),
        DatabaseBackend::MySql => stmt.to_string(MysqlQueryBuilder),
        DatabaseBackend::Postgres => stmt.to_string(PostgresQueryBuilder),
    }
}
/// Applies the SQL type described by `ty` to the `sea_query` column definition `col`.
///
/// Simple types only need the backend; complex types additionally receive `table`,
/// which is used when deriving the name of an enum type.
pub fn apply_column_type_with_table(
    col: &mut SeaColumnDef,
    ty: &ColumnType,
    table: &str,
    backend: DatabaseBackend,
) {
    match *ty {
        ColumnType::Complex(ref spec) => {
            apply_complex_column_type(col, spec, table, backend);
        }
        ColumnType::Simple(kind) => {
            apply_simple_column_type(col, kind, backend);
        }
    }
}
/// Applies a parameter-less column type to `col`.
///
/// Types whose SQL spelling depends on the backend are delegated to dedicated
/// helpers; every other type maps directly onto a `sea_query` column method.
///
/// # Panics
///
/// Panics via `unreachable!` if `simple` is a variant not handled here.
fn apply_simple_column_type(
    col: &mut SeaColumnDef,
    simple: SimpleColumnType,
    backend: DatabaseBackend,
) {
    // Phase 1: backend-dependent types return early through their helper.
    match simple {
        SimpleColumnType::Timestamptz => return apply_timestamptz_type(col, backend),
        SimpleColumnType::Interval => return apply_interval_type(col, backend),
        SimpleColumnType::Inet => {
            return apply_postgres_text_fallback_type(col, backend, "INET");
        }
        SimpleColumnType::Cidr => {
            return apply_postgres_text_fallback_type(col, backend, "CIDR");
        }
        SimpleColumnType::Macaddr => {
            return apply_postgres_text_fallback_type(col, backend, "MACADDR");
        }
        SimpleColumnType::Xml => {
            return apply_postgres_text_fallback_type(col, backend, "XML");
        }
        _ => {}
    }

    // Phase 2: backend-independent types. The variants handled in phase 1 never
    // reach this point, so the wildcard only fires for an unhandled variant.
    match simple {
        SimpleColumnType::SmallInt => col.small_integer(),
        SimpleColumnType::Integer => col.integer(),
        SimpleColumnType::BigInt => col.big_integer(),
        SimpleColumnType::Real => col.float(),
        SimpleColumnType::DoublePrecision => col.double(),
        SimpleColumnType::Text => col.text(),
        SimpleColumnType::Boolean => col.boolean(),
        SimpleColumnType::Date => col.date(),
        SimpleColumnType::Time => col.time(),
        SimpleColumnType::Timestamp => col.timestamp(),
        SimpleColumnType::Bytea => col.binary(),
        SimpleColumnType::Uuid => col.uuid(),
        SimpleColumnType::Json => col.json(),
        _ => unreachable!("SimpleColumnType is #[non_exhaustive]; all variants are matched above"),
    };
}
/// Applies a timezone-aware timestamp type to `col`.
///
/// Postgres gets `timestamp_with_time_zone()`; MySQL and SQLite fall back to a
/// plain `timestamp()`.
fn apply_timestamptz_type(col: &mut SeaColumnDef, backend: DatabaseBackend) {
    match backend {
        DatabaseBackend::Sqlite | DatabaseBackend::MySql => col.timestamp(),
        DatabaseBackend::Postgres => col.timestamp_with_time_zone(),
    };
}
/// Applies an interval type to `col`.
///
/// Postgres gets `interval(None, None)`; MySQL and SQLite fall back to `text()`.
fn apply_interval_type(col: &mut SeaColumnDef, backend: DatabaseBackend) {
    match backend {
        DatabaseBackend::Sqlite | DatabaseBackend::MySql => col.text(),
        DatabaseBackend::Postgres => col.interval(None, None),
    };
}
/// Applies a Postgres-specific type to `col`, falling back to `text()` elsewhere.
///
/// On Postgres the column is given the custom type named `postgres_type`; on MySQL
/// and SQLite the column becomes a text column instead.
fn apply_postgres_text_fallback_type(
    col: &mut SeaColumnDef,
    backend: DatabaseBackend,
    postgres_type: &str,
) {
    match backend {
        DatabaseBackend::Sqlite | DatabaseBackend::MySql => col.text(),
        DatabaseBackend::Postgres => col.custom(Alias::new(postgres_type)),
    };
}
/// Applies a parameterized column type to `col`.
///
/// * `Varchar` / `Char` use the given length.
/// * `Numeric` is delegated to `apply_numeric_type`.
/// * `Custom` uses the custom type name verbatim.
/// * `Enum` becomes an integer column when its values are integers; otherwise it
///   becomes an enumeration whose type name is built from `table` and the enum name.
///
/// # Panics
///
/// Panics via `unreachable!` if `complex` is a variant not handled here.
fn apply_complex_column_type(
    col: &mut SeaColumnDef,
    complex: &ComplexColumnType,
    table: &str,
    backend: DatabaseBackend,
) {
    match complex {
        // Integer-backed enums are stored as plain integers.
        ComplexColumnType::Enum { values, .. } if values.is_integer() => {
            col.integer();
        }
        ComplexColumnType::Enum { name, values } => {
            let type_name = build_enum_type_name(table, name);
            let variants: Vec<Alias> = values
                .variant_names()
                .into_iter()
                .map(Alias::new)
                .collect();
            col.enumeration(Alias::new(&type_name), variants);
        }
        ComplexColumnType::Custom { custom_type } => {
            col.custom(Alias::new(custom_type));
        }
        ComplexColumnType::Char { length } => {
            col.char_len(*length);
        }
        ComplexColumnType::Varchar { length } => {
            col.string_len(*length);
        }
        ComplexColumnType::Numeric { precision, scale } => {
            apply_numeric_type(col, *precision, *scale, backend);
        }
        _ => unreachable!("ComplexColumnType is #[non_exhaustive]; all variants are matched above"),
    }
}
/// Applies a numeric/decimal type to `col`.
///
/// On Postgres and MySQL the column becomes a decimal whose precision is capped at
/// 28 and whose scale is capped at the (capped) precision. On SQLite the column
/// becomes a double and precision/scale are not used.
///
/// In debug builds this asserts that `scale <= precision`.
fn apply_numeric_type(
    col: &mut SeaColumnDef,
    precision: u32,
    scale: u32,
    backend: DatabaseBackend,
) {
    debug_assert!(
        scale <= precision,
        "numeric scale ({scale}) must be <= precision ({precision}); schema validation should reject this before SQL generation"
    );
    match backend {
        DatabaseBackend::Sqlite => {
            col.double();
        }
        DatabaseBackend::MySql | DatabaseBackend::Postgres => {
            let clamped_precision = precision.min(28);
            let clamped_scale = scale.min(clamped_precision);
            col.decimal_len(clamped_precision, clamped_scale);
        }
    }
}
/// Maps a `ReferenceAction` onto the `sea_query` `ForeignKeyAction` of the same name.
///
/// # Panics
///
/// Panics via `unreachable!` if `action` is a variant not handled here.
pub fn to_sea_fk_action(action: &ReferenceAction) -> ForeignKeyAction {
    match action {
        ReferenceAction::NoAction => ForeignKeyAction::NoAction,
        ReferenceAction::Restrict => ForeignKeyAction::Restrict,
        ReferenceAction::Cascade => ForeignKeyAction::Cascade,
        ReferenceAction::SetNull => ForeignKeyAction::SetNull,
        ReferenceAction::SetDefault => ForeignKeyAction::SetDefault,
        _ => unreachable!("ReferenceAction is #[non_exhaustive]; all variants are matched above"),
    }
}
/// Returns the SQL keyword(s) for a foreign-key reference action,
/// e.g. `"CASCADE"` or `"SET NULL"`.
///
/// # Panics
///
/// Panics via `unreachable!` if `action` is a variant not handled here.
pub fn reference_action_sql(action: &ReferenceAction) -> &'static str {
    match action {
        ReferenceAction::NoAction => "NO ACTION",
        ReferenceAction::Restrict => "RESTRICT",
        ReferenceAction::Cascade => "CASCADE",
        ReferenceAction::SetNull => "SET NULL",
        ReferenceAction::SetDefault => "SET DEFAULT",
        _ => unreachable!("ReferenceAction is #[non_exhaustive]; all variants are matched above"),
    }
}
/// Translates a default-value expression into the form used by `backend`.
///
/// The comparison against well-known expressions is case-insensitive:
///
/// * UUID generators (`gen_random_uuid()`, `uuid()`, `lower(hex(randomblob(16)))`)
///   are replaced with the UUID expression for the target backend.
/// * Current-timestamp spellings (`current_timestamp()`, `now()`,
///   `current_timestamp`, `getdate()`) become `CURRENT_TIMESTAMP`.
/// * A `value::type` cast recognized by `parse_pg_type_cast` is rewritten through
///   `convert_type_cast`.
///
/// Anything else is returned unchanged.
pub fn convert_default_for_backend(default: &str, backend: DatabaseBackend) -> String {
    let lowered = default.to_lowercase();
    match lowered.as_str() {
        "gen_random_uuid()" | "uuid()" | "lower(hex(randomblob(16)))" => {
            let uuid_expr = match backend {
                DatabaseBackend::Postgres => "gen_random_uuid()",
                DatabaseBackend::MySql => "(UUID())",
                DatabaseBackend::Sqlite => "lower(hex(randomblob(16)))",
            };
            return uuid_expr.to_string();
        }
        "current_timestamp()" | "now()" | "current_timestamp" | "getdate()" => {
            return "CURRENT_TIMESTAMP".to_string();
        }
        _ => {}
    }

    match parse_pg_type_cast(default) {
        Some((value, cast_type)) => convert_type_cast(&value, &cast_type, backend),
        None => default.to_string(),
    }
}
/// Splits a Postgres-style cast expression (`value::type`) into `(value, type)`.
///
/// The input is trimmed first. Two shapes are recognized:
///
/// * **Quoted literal** (`'text'::type`): the literal ends at the first single quote
///   that is not doubled (`''` is an escaped quote). The `::` must follow the closing
///   quote immediately. The returned value keeps its surrounding quotes.
/// * **Unquoted** (`value::type`): split at the first `::`; the value is trimmed.
///
/// The returned type name is trimmed and lowercased. `None` is returned when the
/// expression has no cast, the literal is unterminated, or either part is empty.
pub(super) fn parse_pg_type_cast(expr: &str) -> Option<(String, String)> {
    let expr = expr.trim();

    let body = match expr.strip_prefix('\'') {
        Some(body) => body,
        None => {
            // Unquoted form: everything before the first `::` is the value.
            let (raw_value, raw_type) = expr.split_once("::")?;
            let value = raw_value.trim();
            let cast_type = raw_type.trim().to_lowercase();
            if value.is_empty() || cast_type.is_empty() {
                return None;
            }
            return Some((value.to_string(), cast_type));
        }
    };

    // Locate the closing quote. Scanning bytes is safe: `'` is a single-byte ASCII
    // character and never occurs inside a multi-byte UTF-8 sequence, so every index
    // found here is a valid char boundary.
    let bytes = body.as_bytes();
    let mut pos = 0;
    let closing = loop {
        match bytes.get(pos) {
            None => return None,
            // `''` is an escaped quote inside the literal; skip both bytes.
            Some(b'\'') if bytes.get(pos + 1) == Some(&b'\'') => pos += 2,
            Some(b'\'') => break pos,
            Some(_) => pos += 1,
        }
    };

    let after_literal = body.get(closing + 1..)?;
    let cast_type = after_literal.strip_prefix("::")?.trim().to_lowercase();
    if cast_type.is_empty() {
        return None;
    }
    let literal = body.get(..closing)?;
    Some((format!("'{literal}'"), cast_type))
}
/// Maps a (lowercase) Postgres type name to the target type used in a MySQL
/// `CAST(... AS ...)` expression.
///
/// The lookup is an exact, case-sensitive match; any name not listed maps to `CHAR`.
fn pg_type_to_mysql_cast(pg_type: &str) -> &'static str {
    // Each entry pairs the Postgres spellings with their MySQL cast target.
    const CAST_TARGETS: &[(&[&str], &str)] = &[
        (&["json", "jsonb"], "JSON"),
        (
            &["integer", "int", "int4", "smallint", "int2", "bigint", "int8"],
            "SIGNED",
        ),
        (
            &["real", "float4", "double precision", "float8", "numeric", "decimal"],
            "DECIMAL",
        ),
        (&["boolean", "bool"], "UNSIGNED"),
        (&["date"], "DATE"),
        (&["time"], "TIME"),
        (
            &[
                "timestamp",
                "timestamptz",
                "timestamp with time zone",
                "timestamp without time zone",
            ],
            "DATETIME",
        ),
        (&["bytea"], "BINARY"),
    ];

    CAST_TARGETS
        .iter()
        .find(|(spellings, _)| spellings.iter().any(|spelling| *spelling == pg_type))
        .map_or("CHAR", |(_, target)| *target)
}
/// Renders a `value` / `cast_type` pair as a cast expression for `backend`.
///
/// * Postgres: `value::cast_type`.
/// * MySQL: `CAST(value AS <type>)`, with the type chosen by `pg_type_to_mysql_cast`.
/// * SQLite: the bare value; the cast is dropped.
fn convert_type_cast(value: &str, cast_type: &str, backend: DatabaseBackend) -> String {
    match backend {
        DatabaseBackend::Sqlite => value.to_owned(),
        DatabaseBackend::MySql => {
            format!("CAST({value} AS {})", pg_type_to_mysql_cast(cast_type))
        }
        DatabaseBackend::Postgres => [value, cast_type].join("::"),
    }
}
/// Returns `true` when `column_type` is a complex `Enum` type.
pub(super) fn is_enum_type(column_type: &ColumnType) -> bool {
    match column_type {
        ColumnType::Complex(complex) => matches!(complex, ComplexColumnType::Enum { .. }),
        ColumnType::Simple(_) => false,
    }
}
/// Wraps a bare enum default value in single quotes so it is a valid SQL string literal.
///
/// The value is quoted only when `column_type` is an enum and `needs_quoting`
/// reports that `value` is a bare word (not already quoted, not a function call,
/// not `NULL` or a `CURRENT_*` keyword). Otherwise it is returned unchanged.
///
/// Single quotes embedded in the value are doubled (`'` becomes `''`), so a value
/// such as `it's` yields the well-formed literal `'it''s'` rather than a literal
/// that terminates early.
pub fn normalize_enum_default(column_type: &ColumnType, value: &str) -> String {
    if is_enum_type(column_type) && needs_quoting(value) {
        // Standard SQL escapes a quote inside a string literal by doubling it.
        format!("'{}'", value.replace('\'', "''"))
    } else {
        value.to_string()
    }
}
/// Decides whether a default value is a bare word that must be wrapped in quotes.
///
/// The value is trimmed first. It does **not** need quoting when it:
///
/// * already starts with a single or double quote,
/// * contains a parenthesis (treated as a function call or expression), or
/// * is `NULL`, `CURRENT_TIMESTAMP`, `CURRENT_DATE` or `CURRENT_TIME`
///   (ASCII case-insensitive).
///
/// Everything else, including an empty or whitespace-only value, needs quoting.
pub(super) fn needs_quoting(default_str: &str) -> bool {
    const BARE_KEYWORDS: [&str; 4] = ["null", "current_timestamp", "current_date", "current_time"];

    let candidate = default_str.trim();
    let already_quoted = candidate.starts_with(|c: char| c == '\'' || c == '"');
    let has_parens = candidate.contains(|c: char| c == '(' || c == ')');
    let is_keyword = BARE_KEYWORDS
        .iter()
        .any(|keyword| candidate.eq_ignore_ascii_case(keyword));

    // An empty candidate trips none of the checks and therefore needs quoting.
    !(already_quoted || has_parens || is_keyword)
}
/// Builds the `sea_query` column definition for `column` of `table` on `backend`.
///
/// Steps, in order:
///
/// 1. Apply the column type (`table` is needed to name enum types).
/// 2. Add `NOT NULL` when the column is not nullable.
/// 3. If a default is present, convert it with `convert_default_for_backend`, then:
///    * for enum columns with a string default that is a bare word, wrap it in
///      single quotes, doubling any embedded single quote so the literal stays
///      well-formed;
///    * on SQLite, wrap a default that contains `(` but does not start with `(`
///      in parentheses.
///
/// The default is attached as a custom (verbatim) SQL expression.
pub fn build_sea_column_def_with_table(
    backend: DatabaseBackend,
    table: &str,
    column: &ColumnDef,
) -> SeaColumnDef {
    let mut col = SeaColumnDef::new(Alias::new(&column.name));
    apply_column_type_with_table(&mut col, &column.r#type, table, backend);

    if !column.nullable {
        col.not_null();
    }

    if let Some(default) = &column.default {
        let default_str = default.to_sql();
        let converted = convert_default_for_backend(&default_str, backend);

        let final_default =
            if is_enum_type(&column.r#type) && default.is_string() && needs_quoting(&converted) {
                // Standard SQL escapes a quote inside a string literal by doubling it;
                // without this a value such as `it's` would end the literal early.
                format!("'{}'", converted.replace('\'', "''"))
            } else {
                converted
            };

        // On SQLite, defaults containing `(` are parenthesized unless they already
        // start with one.
        let final_default = if backend == DatabaseBackend::Sqlite
            && final_default.contains('(')
            && !final_default.starts_with('(')
        {
            format!("({final_default})")
        } else {
            final_default
        };

        col.default(Into::<SimpleExpr>::into(sea_query::Expr::cust(
            final_default,
        )));
    }

    col
}
/// Builds the `CREATE TYPE ... AS ENUM (...)` statement for an enum column type.
///
/// Returns `None` when `column_type` is not an enum, or when the enum is
/// integer-backed. The type name is derived from `table` and the enum name and is
/// quoted with Postgres identifier quoting. Only the first `per_backend` slot
/// carries SQL; the other two are empty strings.
pub fn build_create_enum_type_sql(
    table: &str,
    column_type: &ColumnType,
) -> Option<super::types::RawSql> {
    match column_type {
        ColumnType::Complex(ComplexColumnType::Enum { name, values }) if !values.is_integer() => {
            let quoted_type = quote_ident(
                &build_enum_type_name(table, name),
                DatabaseBackend::Postgres,
            );
            let variants = values.to_sql_values().join(", ");
            let create_sql = format!("CREATE TYPE {quoted_type} AS ENUM ({variants})");
            Some(super::types::RawSql::per_backend(
                create_sql,
                String::new(),
                String::new(),
            ))
        }
        _ => None,
    }
}
/// Builds the `DROP TYPE ...` statement for an enum column type.
///
/// Returns `None` when `column_type` is not an enum, or when the enum is
/// integer-backed. Integer-backed enums are stored as plain integer columns and
/// `build_create_enum_type_sql` never creates a type for them, so there is nothing
/// to drop; emitting `DROP TYPE` for them would reference a type that does not exist.
///
/// The type name is derived from `table` and the enum name and is quoted with
/// Postgres identifier quoting. Only the first `per_backend` slot carries SQL; the
/// other two are empty strings.
pub fn build_drop_enum_type_sql(
    table: &str,
    column_type: &ColumnType,
) -> Option<super::types::RawSql> {
    if let ColumnType::Complex(ComplexColumnType::Enum { name, values }) = column_type {
        // Mirror the create side: no named type exists for integer-backed enums.
        if values.is_integer() {
            return None;
        }
        let type_name = build_enum_type_name(table, name);
        let type_name = quote_ident(&type_name, DatabaseBackend::Postgres);
        let pg_sql = format!("DROP TYPE {type_name}");
        Some(super::types::RawSql::per_backend(
            pg_sql,
            String::new(),
            String::new(),
        ))
    } else {
        None
    }
}
pub use vespertide_naming::{
build_check_constraint_name, build_enum_type_name, build_foreign_key_name, build_index_name,
build_unique_constraint_name,
};
/// Builds a named `CHECK` clause restricting an enum column to its allowed values.
///
/// Returns `None` when `column_type` is not an enum. The constraint name comes from
/// `build_check_constraint_name(table, column)`; the constraint name and the column
/// are quoted with SQLite identifier quoting.
pub fn build_sqlite_enum_check_clause(
    table: &str,
    column: &str,
    column_type: &ColumnType,
) -> Option<String> {
    match column_type {
        ColumnType::Complex(ComplexColumnType::Enum { values, .. }) => {
            let constraint = quote_ident(
                &build_check_constraint_name(table, column),
                DatabaseBackend::Sqlite,
            );
            let quoted_column = quote_ident(column, DatabaseBackend::Sqlite);
            let allowed = values.to_sql_values().join(", ");
            Some(format!(
                "CONSTRAINT {constraint} CHECK ({quoted_column} IN ({allowed}))"
            ))
        }
        _ => None,
    }
}
/// Collects the enum `CHECK` clauses for every enum-typed column, in column order.
///
/// Columns for which `build_sqlite_enum_check_clause` returns `None` are skipped.
pub fn collect_sqlite_enum_check_clauses(table: &str, columns: &[ColumnDef]) -> Vec<String> {
    let mut clauses = Vec::new();
    for column in columns {
        if let Some(clause) = build_sqlite_enum_check_clause(table, &column.name, &column.r#type) {
            clauses.push(clause);
        }
    }
    clauses
}
/// Renders every explicit `Check` table constraint as a
/// `CONSTRAINT <name> CHECK (<expr>)` clause.
///
/// Other constraint kinds are ignored. The constraint name is quoted with SQLite
/// identifier quoting; the expression is inserted verbatim.
pub fn extract_check_clauses(constraints: &[TableConstraint]) -> Vec<String> {
    constraints
        .iter()
        .filter_map(|constraint| match constraint {
            TableConstraint::Check { name, expr, .. } => {
                let quoted_name = quote_ident(name, DatabaseBackend::Sqlite);
                Some(format!("CONSTRAINT {quoted_name} CHECK ({expr})"))
            }
            _ => None,
        })
        .collect()
}
/// Gathers all `CHECK` clauses for a table: enum-derived clauses first, followed by
/// explicit `Check` constraints.
///
/// An explicit clause is appended only if an identical clause is not already in the
/// result, so duplicates are dropped.
pub fn collect_all_check_clauses(
    table: &str,
    columns: &[ColumnDef],
    constraints: &[TableConstraint],
) -> Vec<String> {
    let mut clauses = collect_sqlite_enum_check_clauses(table, columns);
    for candidate in extract_check_clauses(constraints) {
        let is_new = clauses.iter().all(|existing| *existing != candidate);
        if is_new {
            clauses.push(candidate);
        }
    }
    clauses
}
/// Produces the `CREATE TABLE` query, splicing in extra `CHECK` clauses when given.
///
/// * With no clauses, the statement is returned as `BuiltQuery::CreateTable`.
/// * Otherwise the statement is rendered to SQL for `backend` and the clauses are
///   inserted, comma-separated, just before the last `)` in that SQL. The result is
///   returned as raw SQL, with the same text used for every backend slot. If the
///   rendered SQL contains no `)`, it is returned without the clauses.
pub fn build_create_with_checks(
    backend: DatabaseBackend,
    create_stmt: &sea_query::TableCreateStatement,
    check_clauses: &[String],
) -> BuiltQuery {
    if check_clauses.is_empty() {
        return BuiltQuery::CreateTable(Box::new(create_stmt.clone()));
    }

    let mut sql = build_schema_statement(create_stmt, backend);
    if let Some(closing_paren) = sql.rfind(')') {
        let mut injected = String::from(", ");
        injected.push_str(&check_clauses.join(", "));
        sql.insert_str(closing_paren, &injected);
    }

    BuiltQuery::Raw(RawSql::per_backend(sql.clone(), sql.clone(), sql))
}
/// Builds the `CREATE TABLE` query for the temporary table used during a rebuild.
///
/// The table is created under the name `temp_table`, while `CHECK` clauses are
/// collected using the original `table` name, so constraint names are derived from
/// `table` rather than `temp_table`.
pub fn build_sqlite_temp_table_create(
    backend: DatabaseBackend,
    temp_table: &str,
    table: &str,
    columns: &[ColumnDef],
    constraints: &[TableConstraint],
) -> BuiltQuery {
    let checks = collect_all_check_clauses(table, columns, constraints);
    let temp_create = build_create_table_for_backend(backend, temp_table, columns, constraints);
    build_create_with_checks(backend, &temp_create, &checks)
}
/// Builds the `CREATE [UNIQUE] INDEX` statements for a table's constraints.
///
/// `Index` constraints become `CREATE INDEX` and `Unique` constraints become
/// `CREATE UNIQUE INDEX`; other constraint kinds are ignored. Any constraint that
/// also appears in `pending_constraints` is skipped.
///
/// Identifiers are quoted with SQLite identifier quoting, and the same SQL text is
/// used for every backend slot of the returned raw queries.
pub fn recreate_indexes_after_rebuild(
    table: &str,
    constraints: &[TableConstraint],
    pending_constraints: &[TableConstraint],
) -> Vec<BuiltQuery> {
    let skip: std::collections::BTreeSet<_> = pending_constraints.iter().collect();
    let quoted_table = quote_ident(table, DatabaseBackend::Sqlite);
    let mut queries = Vec::with_capacity(constraints.len());

    for constraint in constraints {
        if skip.contains(constraint) {
            continue;
        }

        // Work out the statement keyword, the quoted index name and the quoted
        // column list; everything after that is shared between both kinds.
        let (statement, index_name, cols_sql) = match constraint {
            TableConstraint::Index { name, columns } => (
                "CREATE INDEX",
                quote_ident(
                    &build_index_name(table, columns, name.as_deref()),
                    DatabaseBackend::Sqlite,
                ),
                quote_idents(columns, DatabaseBackend::Sqlite),
            ),
            TableConstraint::Unique { name, columns, .. } => (
                "CREATE UNIQUE INDEX",
                quote_ident(
                    &build_unique_constraint_name(table, columns, name.as_deref()),
                    DatabaseBackend::Sqlite,
                ),
                quote_idents(columns, DatabaseBackend::Sqlite),
            ),
            _ => continue,
        };

        let sql = format!("{statement} {index_name} ON {quoted_table} ({cols_sql})");
        queries.push(BuiltQuery::Raw(RawSql::per_backend(
            sql.clone(),
            sql.clone(),
            sql,
        )));
    }

    queries
}
/// Returns the enum's name when `column_type` is an enum, and `None` otherwise.
pub fn get_enum_name(column_type: &ColumnType) -> Option<&str> {
    match column_type {
        ColumnType::Complex(ComplexColumnType::Enum { name, .. }) => Some(name.as_str()),
        _ => None,
    }
}
/// Quotes an identifier for `backend`, escaping embedded quote characters.
///
/// Postgres and SQLite use double quotes; MySQL uses backticks. Any occurrence of
/// the quote character inside `name` is doubled.
#[must_use]
pub fn quote_ident(name: &str, backend: DatabaseBackend) -> String {
    let quote = match backend {
        DatabaseBackend::MySql => '`',
        DatabaseBackend::Postgres | DatabaseBackend::Sqlite => '"',
    };

    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push(quote);
    for ch in name.chars() {
        // Doubling the quote character is how it is escaped inside the identifier.
        if ch == quote {
            quoted.push(quote);
        }
        quoted.push(ch);
    }
    quoted.push(quote);
    quoted
}
/// Quotes each identifier with `quote_ident` and joins the results with `", "`.
///
/// An empty slice yields an empty string.
#[must_use]
pub fn quote_idents<T: AsRef<str>>(names: &[T], backend: DatabaseBackend) -> String {
    let mut joined = String::new();
    for (position, name) in names.iter().enumerate() {
        if position > 0 {
            joined.push_str(", ");
        }
        joined.push_str(&quote_ident(name.as_ref(), backend));
    }
    joined
}
#[cfg(test)]
mod tests {
    use super::*;
    use sea_query::{Alias, ColumnDef as SeaColDef, Table};

    /// With no check clauses, `build_create_with_checks` must hand back the
    /// statement as `BuiltQuery::CreateTable` and must not inject any `CHECK`.
    #[test]
    fn build_create_with_checks_empty_clauses_returns_plain_create_table() {
        let mut create_users = Table::create();
        create_users
            .table(Alias::new("users"))
            .col(SeaColDef::new(Alias::new("id")).integer().not_null());

        let built = build_create_with_checks(DatabaseBackend::Postgres, &create_users, &[]);
        let rendered = built.build(DatabaseBackend::Postgres);

        assert!(
            rendered.contains("CREATE TABLE"),
            "expected CREATE TABLE in: {rendered}"
        );
        assert!(
            !rendered.contains("CHECK ("),
            "no CHECK should be injected: {rendered}"
        );
        assert!(
            matches!(built, BuiltQuery::CreateTable(_)),
            "empty-checks branch must return BuiltQuery::CreateTable"
        );
    }
}