use std::fmt;
use crate::config::DatabaseType;
use crate::database::{require_db, Database};
use crate::error::{Error, Result};
use crate::internal::ConnectionTrait;
use crate::{tide_info, tide_debug};
pub use async_trait::async_trait;
/// A single versioned schema change that [`Migrator`] can apply and revert.
///
/// `version()` is both the sort key and the value stored in the `_migrations`
/// table to mark the migration as applied. Versions are compared as plain
/// strings, so a fixed-width scheme (zero-padded numbers, timestamps) keeps the
/// ordering intuitive.
#[async_trait]
pub trait Migration: Send + Sync {
/// Version identifier; must be unique (the `_migrations.version` column is UNIQUE).
fn version(&self) -> &str;
/// Human-readable name, used in logs and in `MigrationResult`/`MigrationStatus`.
fn name(&self) -> &str;
/// Apply the migration using the DDL helpers on `schema`.
async fn up(&self, schema: &mut Schema) -> Result<()>;
/// Revert what `up` did; invoked by `Migrator::rollback`.
async fn down(&self, schema: &mut Schema) -> Result<()>;
}
pub struct Schema {
database_type: DatabaseType,
statements: Vec<String>,
}
impl Schema {
pub fn new(database_type: DatabaseType) -> Self {
Self {
database_type,
statements: Vec::new(),
}
}
pub async fn create_table<F>(&mut self, name: &str, f: F) -> Result<()>
where
F: FnOnce(&mut TableBuilder),
{
let mut builder = TableBuilder::new(name, self.database_type);
f(&mut builder);
let sql = builder.build_create();
self.execute(&sql).await?;
for index_sql in builder.build_indexes() {
self.execute(&index_sql).await?;
}
Ok(())
}
pub async fn create_table_if_not_exists<F>(&mut self, name: &str, f: F) -> Result<()>
where
F: FnOnce(&mut TableBuilder),
{
let mut builder = TableBuilder::new(name, self.database_type);
f(&mut builder);
let sql = builder.build_create_if_not_exists();
self.execute(&sql).await?;
for index_sql in builder.build_indexes_if_not_exists() {
self.execute(&index_sql).await?;
}
Ok(())
}
pub async fn alter_table<F>(&mut self, name: &str, f: F) -> Result<()>
where
F: FnOnce(&mut AlterTableBuilder),
{
let mut builder = AlterTableBuilder::new(name, self.database_type);
f(&mut builder);
for sql in builder.build() {
self.execute(&sql).await?;
}
Ok(())
}
pub async fn drop_table(&mut self, name: &str) -> Result<()> {
let sql = format!(
"DROP TABLE {}",
self.quote_identifier(name)
);
self.execute(&sql).await
}
pub async fn drop_table_if_exists(&mut self, name: &str) -> Result<()> {
let sql = format!(
"DROP TABLE IF EXISTS {}",
self.quote_identifier(name)
);
self.execute(&sql).await
}
pub async fn rename_table(&mut self, from: &str, to: &str) -> Result<()> {
let sql = match self.database_type {
DatabaseType::MySQL | DatabaseType::MariaDB => format!(
"RENAME TABLE {} TO {}",
self.quote_identifier(from),
self.quote_identifier(to)
),
_ => format!(
"ALTER TABLE {} RENAME TO {}",
self.quote_identifier(from),
self.quote_identifier(to)
),
};
self.execute(&sql).await
}
pub async fn create_index(
&mut self,
table: &str,
name: &str,
columns: &[&str],
unique: bool,
) -> Result<()> {
let index_type = if unique { "UNIQUE INDEX" } else { "INDEX" };
let cols: Vec<String> = columns.iter().map(|c| self.quote_identifier(c)).collect();
let sql = format!(
"CREATE {} {} ON {} ({})",
index_type,
self.quote_identifier(name),
self.quote_identifier(table),
cols.join(", ")
);
self.execute(&sql).await
}
pub async fn drop_index(&mut self, table: &str, name: &str) -> Result<()> {
let sql = match self.database_type {
DatabaseType::MySQL | DatabaseType::MariaDB => format!(
"DROP INDEX {} ON {}",
self.quote_identifier(name),
self.quote_identifier(table)
),
_ => format!(
"DROP INDEX {}",
self.quote_identifier(name)
),
};
self.execute(&sql).await
}
pub async fn raw(&mut self, sql: &str) -> Result<()> {
self.execute(sql).await
}
async fn execute(&mut self, sql: &str) -> Result<()> {
log_migration_sql(sql);
self.statements.push(sql.to_string());
let db = require_db()?;
db.__internal_connection()
.execute_unprepared(sql)
.await
.map_err(|e| Error::query_with_context(
e.to_string(),
crate::error::ErrorContext::new().query(sql.to_string()),
))?;
Ok(())
}
fn quote_identifier(&self, name: &str) -> String {
match self.database_type {
DatabaseType::Postgres | DatabaseType::SQLite => format!("\"{}\"", name),
DatabaseType::MySQL | DatabaseType::MariaDB => format!("`{}`", name),
}
}
pub fn database_type(&self) -> DatabaseType {
self.database_type
}
}
/// A table-level `UNIQUE (...)` constraint declared with `TableBuilder::unique`
/// or `TableBuilder::unique_named`.
#[derive(Debug, Clone)]
pub struct UniqueConstraint {
/// Optional constraint name; when `None`, no `CONSTRAINT <name>` prefix is emitted.
pub name: Option<String>,
/// Column names covered by the constraint, in order (quoted at render time).
pub columns: Vec<String>,
}
/// A multi-column primary key declared with `TableBuilder::primary_key(&[...])`.
#[derive(Debug, Clone)]
pub struct CompositePrimaryKey {
/// Key columns, in order (quoted at render time).
pub columns: Vec<String>,
}
/// Accumulates column, index and constraint declarations for one
/// `CREATE TABLE` statement targeting a single [`DatabaseType`].
///
/// `Schema::create_table*` creates the builder, the caller's closure fills it
/// in, and the result is rendered with `build_create*` / `build_indexes*`.
pub struct TableBuilder {
    name: String,
    database_type: DatabaseType,
    /// Column definitions in declaration order.
    columns: Vec<ColumnDefinition>,
    /// Secondary indexes, rendered as separate `CREATE INDEX` statements.
    indexes: Vec<IndexBuilder>,
    /// Single-column primary key, set by `id`/`increments`/`big_increments`
    /// or `ColumnBuilder::primary_key`.
    primary_key: Option<String>,
    /// Table-level `UNIQUE (...)` constraints.
    unique_constraints: Vec<UniqueConstraint>,
    /// Multi-column primary key set by `primary_key(&[...])`.
    composite_primary_key: Option<CompositePrimaryKey>,
}

impl TableBuilder {
    /// Create an empty builder for table `name`.
    pub fn new(name: &str, database_type: DatabaseType) -> Self {
        Self {
            name: name.to_string(),
            database_type,
            columns: Vec::new(),
            indexes: Vec::new(),
            primary_key: None,
            unique_constraints: Vec::new(),
            composite_primary_key: None,
        }
    }

    /// Shorthand for `big_increments("id")`.
    pub fn id(&mut self) -> &mut Self {
        self.big_increments("id")
    }

    /// Add an auto-incrementing `BIGINT` primary-key column.
    pub fn big_increments(&mut self, name: &str) -> &mut Self {
        self.auto_increment_column(name, ColumnType::BigInteger)
    }

    /// Add an auto-incrementing `INTEGER` primary-key column.
    pub fn increments(&mut self, name: &str) -> &mut Self {
        self.auto_increment_column(name, ColumnType::Integer)
    }

    /// Shared body of `increments`/`big_increments`: push a NOT NULL,
    /// auto-incrementing column and record it as the primary key.
    fn auto_increment_column(&mut self, name: &str, column_type: ColumnType) -> &mut Self {
        self.columns.push(ColumnDefinition {
            name: name.to_string(),
            column_type,
            nullable: false,
            default: None,
            primary_key: true,
            auto_increment: true,
            unique: false,
            check: None,
            extra: None,
        });
        self.primary_key = Some(name.to_string());
        self
    }

    /// `VARCHAR(255)`-style string column.
    pub fn string(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::String)
    }

    /// Unbounded text column.
    pub fn text(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::Text)
    }

    /// 32-bit integer column.
    pub fn integer(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::Integer)
    }

    /// 64-bit integer column.
    pub fn big_integer(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::BigInteger)
    }

    /// 16-bit integer column.
    pub fn small_integer(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::SmallInteger)
    }

    /// `DECIMAL(10, 2)` column.
    pub fn decimal(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::Decimal { precision: 10, scale: 2 })
    }

    /// `DECIMAL(precision, scale)` column.
    pub fn decimal_with(&mut self, name: &str, precision: u32, scale: u32) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::Decimal { precision, scale })
    }

    /// Single-precision float column.
    pub fn float(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::Float)
    }

    /// Double-precision float column.
    pub fn double(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::Double)
    }

    /// Boolean column.
    pub fn boolean(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::Boolean)
    }

    /// Date column.
    pub fn date(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::Date)
    }

    /// Time-of-day column.
    pub fn time(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::Time)
    }

    /// Date-time column.
    pub fn datetime(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::DateTime)
    }

    /// Timestamp column without time zone.
    pub fn timestamp(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::Timestamp)
    }

    /// Timestamp column with time zone (where the backend supports it).
    pub fn timestamptz(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::TimestampTz)
    }

    /// Add NOT NULL `created_at` / `updated_at` `TimestampTz` columns
    /// defaulting to `CURRENT_TIMESTAMP`.
    pub fn timestamps(&mut self) -> &mut Self {
        self.column("created_at", ColumnType::TimestampTz)
            .default_now()
            .not_null();
        self.column("updated_at", ColumnType::TimestampTz)
            .default_now()
            .not_null();
        self
    }

    /// Same as [`TableBuilder::timestamps`] but with `Timestamp` columns.
    pub fn timestamps_naive(&mut self) -> &mut Self {
        self.column("created_at", ColumnType::Timestamp)
            .default_now()
            .not_null();
        self.column("updated_at", ColumnType::Timestamp)
            .default_now()
            .not_null();
        self
    }

    /// Add a nullable `deleted_at` `TimestampTz` column.
    pub fn soft_deletes(&mut self) -> &mut Self {
        self.column("deleted_at", ColumnType::TimestampTz).nullable();
        self
    }

    /// UUID column.
    pub fn uuid(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::Uuid)
    }

    /// JSON column.
    pub fn json(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::Json)
    }

    /// JSONB column (plain JSON/TEXT on backends without JSONB).
    pub fn jsonb(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::Jsonb)
    }

    /// Binary blob column.
    pub fn binary(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::Binary)
    }

    /// Integer-array column.
    pub fn integer_array(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::IntegerArray)
    }

    /// Text-array column.
    pub fn text_array(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::TextArray)
    }

    /// Start a nullable column of `column_type`. It is appended to the table
    /// when the returned [`ColumnBuilder`] is dropped.
    pub fn column(&mut self, name: &str, column_type: ColumnType) -> ColumnBuilder<'_> {
        ColumnBuilder {
            table: self,
            definition: ColumnDefinition {
                name: name.to_string(),
                column_type,
                nullable: true,
                default: None,
                primary_key: false,
                auto_increment: false,
                unique: false,
                check: None,
                extra: None,
            },
        }
    }

    /// `BIGINT` column intended to reference another table's `id`. No
    /// `FOREIGN KEY` constraint is emitted.
    pub fn foreign_id(&mut self, name: &str) -> ColumnBuilder<'_> {
        self.column(name, ColumnType::BigInteger)
    }

    /// Add a non-unique index named `idx_<table>_<col1>_<col2>...`.
    pub fn index(&mut self, columns: &[&str]) -> &mut Self {
        let name = format!("idx_{}_{}", self.name, columns.join("_"));
        self.push_index(name, columns, false)
    }

    /// Add a unique index named `idx_<table>_<cols>_unique`.
    pub fn unique_index(&mut self, columns: &[&str]) -> &mut Self {
        let name = format!("idx_{}_{}_unique", self.name, columns.join("_"));
        self.push_index(name, columns, true)
    }

    /// Add an unnamed table-level `UNIQUE (...)` constraint.
    pub fn unique(&mut self, columns: &[&str]) -> &mut Self {
        self.unique_constraints.push(UniqueConstraint {
            name: None,
            columns: Self::owned_columns(columns),
        });
        self
    }

    /// Add a named table-level `CONSTRAINT <name> UNIQUE (...)`.
    pub fn unique_named(&mut self, name: &str, columns: &[&str]) -> &mut Self {
        self.unique_constraints.push(UniqueConstraint {
            name: Some(name.to_string()),
            columns: Self::owned_columns(columns),
        });
        self
    }

    /// Declare a multi-column primary key. It is rendered in addition to any
    /// single-column key set by `id`/`increments`, so do not combine the two.
    pub fn primary_key(&mut self, columns: &[&str]) -> &mut Self {
        self.composite_primary_key = Some(CompositePrimaryKey {
            columns: Self::owned_columns(columns),
        });
        self
    }

    /// Add a non-unique index with an explicit name.
    pub fn index_named(&mut self, name: &str, columns: &[&str]) -> &mut Self {
        self.push_index(name.to_string(), columns, false)
    }

    /// Queue an index definition.
    fn push_index(&mut self, name: String, columns: &[&str], unique: bool) -> &mut Self {
        self.indexes.push(IndexBuilder {
            name,
            columns: Self::owned_columns(columns),
            unique,
        });
        self
    }

    /// Copy borrowed column names into owned strings.
    fn owned_columns(columns: &[&str]) -> Vec<String> {
        columns.iter().map(|s| s.to_string()).collect()
    }

    fn build_create(&self) -> String {
        self.build_create_internal(false)
    }

    fn build_create_if_not_exists(&self) -> String {
        self.build_create_internal(true)
    }

    /// Render the `CREATE TABLE` statement.
    ///
    /// Columns and table constraints are collected into one list and joined
    /// with `,\n`. A table with no columns but with constraints therefore no
    /// longer gets a dangling leading comma.
    fn build_create_internal(&self, if_not_exists: bool) -> String {
        let exists_clause = if if_not_exists { "IF NOT EXISTS " } else { "" };
        let mut parts: Vec<String> = self
            .columns
            .iter()
            .map(|col| self.build_column_def(col))
            .collect();
        if let Some(ref pk) = self.primary_key {
            parts.push(format!(" PRIMARY KEY ({})", self.quote_identifier(pk)));
        }
        if let Some(ref cpk) = self.composite_primary_key {
            parts.push(format!(" PRIMARY KEY ({})", self.quote_list(&cpk.columns)));
        }
        for uc in &self.unique_constraints {
            let cols = self.quote_list(&uc.columns);
            parts.push(match uc.name {
                Some(ref name) => format!(
                    " CONSTRAINT {} UNIQUE ({})",
                    self.quote_identifier(name),
                    cols
                ),
                None => format!(" UNIQUE ({})", cols),
            });
        }
        format!(
            "CREATE TABLE {}{} (\n{}\n)",
            exists_clause,
            self.quote_identifier(&self.name),
            parts.join(",\n")
        )
    }

    /// Render one column definition line for `CREATE TABLE`.
    fn build_column_def(&self, col: &ColumnDefinition) -> String {
        let mut def = format!(
            " {} {}",
            self.quote_identifier(&col.name),
            self.type_to_sql(&col.column_type)
        );
        if col.auto_increment {
            match self.database_type {
                // Postgres expresses auto-increment through the SERIAL pseudo
                // types, which replace the rendered type entirely.
                DatabaseType::Postgres => {
                    def = match col.column_type {
                        ColumnType::Integer => {
                            format!(" {} SERIAL", self.quote_identifier(&col.name))
                        }
                        _ => format!(" {} BIGSERIAL", self.quote_identifier(&col.name)),
                    };
                }
                DatabaseType::MySQL | DatabaseType::MariaDB => {
                    def.push_str(" AUTO_INCREMENT");
                }
                // Nothing is appended for SQLite: the table-level
                // PRIMARY KEY on an INTEGER column is relied on instead.
                DatabaseType::SQLite => {}
            }
        }
        // PRIMARY KEY already implies NOT NULL / uniqueness.
        if !col.nullable && !col.primary_key {
            def.push_str(" NOT NULL");
        }
        if let Some(ref default) = col.default {
            def.push_str(&format!(" DEFAULT {}", default));
        }
        if col.unique && !col.primary_key {
            def.push_str(" UNIQUE");
        }
        if let Some(ref check_expr) = col.check {
            def.push_str(&format!(" CHECK ({})", check_expr));
        }
        if let Some(ref extra_sql) = col.extra {
            def.push_str(&format!(" {}", extra_sql));
        }
        def
    }

    fn build_indexes(&self) -> Vec<String> {
        self.build_indexes_internal(false)
    }

    fn build_indexes_if_not_exists(&self) -> Vec<String> {
        self.build_indexes_internal(true)
    }

    /// Render one `CREATE [UNIQUE] INDEX` statement per queued index.
    fn build_indexes_internal(&self, if_not_exists: bool) -> Vec<String> {
        // MySQL's CREATE INDEX grammar has no IF NOT EXISTS clause (MariaDB's
        // does), so emitting it there is always a syntax error.
        let backend_supports_clause = !matches!(self.database_type, DatabaseType::MySQL);
        let exists_clause = if if_not_exists && backend_supports_clause {
            "IF NOT EXISTS "
        } else {
            ""
        };
        let table = self.quote_identifier(&self.name);
        self.indexes
            .iter()
            .map(|idx| {
                let index_type = if idx.unique { "UNIQUE INDEX" } else { "INDEX" };
                format!(
                    "CREATE {} {}{} ON {} ({})",
                    index_type,
                    exists_clause,
                    self.quote_identifier(&idx.name),
                    table,
                    self.quote_list(&idx.columns)
                )
            })
            .collect()
    }

    /// Quote each name and join them with `", "`.
    fn quote_list(&self, names: &[String]) -> String {
        names
            .iter()
            .map(|c| self.quote_identifier(c))
            .collect::<Vec<_>>()
            .join(", ")
    }

    fn type_to_sql(&self, column_type: &ColumnType) -> String {
        match self.database_type {
            DatabaseType::Postgres => column_type.to_postgres_sql(),
            DatabaseType::MySQL | DatabaseType::MariaDB => column_type.to_mysql_sql(),
            DatabaseType::SQLite => column_type.to_sqlite_sql(),
        }
    }

    /// Quote `name` as one identifier, doubling any embedded quote character
    /// so the name cannot terminate the identifier early.
    fn quote_identifier(&self, name: &str) -> String {
        match self.database_type {
            DatabaseType::Postgres | DatabaseType::SQLite => {
                format!("\"{}\"", name.replace('"', "\"\""))
            }
            DatabaseType::MySQL | DatabaseType::MariaDB => {
                format!("`{}`", name.replace('`', "``"))
            }
        }
    }
}
/// Fluent single-column builder handed out by [`TableBuilder`]'s column
/// helpers (`string`, `integer`, `column`, ...).
///
/// Modifiers consume and return the builder. When the builder is finally
/// dropped, its definition is appended to the owning table, so the column is
/// registered at the end of the statement that created it.
pub struct ColumnBuilder<'a> {
    table: &'a mut TableBuilder,
    definition: ColumnDefinition,
}

impl<'a> ColumnBuilder<'a> {
    /// Emit `NOT NULL` for this column.
    pub fn not_null(self) -> Self {
        self.with_nullability(false)
    }

    /// Allow `NULL` values (the starting state for `TableBuilder::column`).
    pub fn nullable(self) -> Self {
        self.with_nullability(true)
    }

    /// Set a literal default, rendered through `DefaultValue::to_sql`.
    pub fn default(mut self, value: impl Into<DefaultValue>) -> Self {
        let rendered = value.into().to_sql();
        self.definition.default = Some(rendered);
        self
    }

    /// Default the column to `CURRENT_TIMESTAMP`.
    pub fn default_now(mut self) -> Self {
        self.definition.default = Some(String::from("CURRENT_TIMESTAMP"));
        self
    }

    /// Request a column-level `UNIQUE`.
    pub fn unique(mut self) -> Self {
        self.definition.unique = true;
        self
    }

    /// Make this column the table's single-column primary key; also forces
    /// the column to be non-nullable.
    pub fn primary_key(mut self) -> Self {
        let column_name = self.definition.name.clone();
        self.table.primary_key = Some(column_name);
        self.definition.nullable = false;
        self.definition.primary_key = true;
        self
    }

    /// Attach `CHECK (expression)`; `expression` is inserted verbatim.
    pub fn check(mut self, expression: &str) -> Self {
        self.definition.check = Some(String::from(expression));
        self
    }

    /// Append raw SQL after the rest of the column definition.
    pub fn extra(mut self, sql: &str) -> Self {
        self.definition.extra = Some(sql.to_owned());
        self
    }

    /// Shared body of `not_null` / `nullable`.
    fn with_nullability(mut self, nullable: bool) -> Self {
        self.definition.nullable = nullable;
        self
    }
}

impl<'a> Drop for ColumnBuilder<'a> {
    /// Move the finished definition into the owning table. A definition whose
    /// name is empty is discarded rather than pushed.
    fn drop(&mut self) {
        let placeholder = ColumnDefinition {
            name: String::new(),
            column_type: ColumnType::String,
            nullable: true,
            default: None,
            primary_key: false,
            auto_increment: false,
            unique: false,
            check: None,
            extra: None,
        };
        let finished = std::mem::replace(&mut self.definition, placeholder);
        if finished.name.is_empty() {
            return;
        }
        self.table.columns.push(finished);
    }
}
/// Queues changes to an existing table. `Schema::alter_table` renders them
/// with `build` and executes one statement per operation, in order.
pub struct AlterTableBuilder {
    name: String,
    database_type: DatabaseType,
    operations: Vec<AlterOperation>,
}

impl AlterTableBuilder {
    /// Create an empty builder for table `name`.
    pub fn new(name: &str, database_type: DatabaseType) -> Self {
        Self {
            name: name.to_string(),
            database_type,
            operations: Vec::new(),
        }
    }

    /// Start an `ADD COLUMN`. The operation is queued when the returned
    /// [`AlterColumnBuilder`] is dropped.
    pub fn add_column(&mut self, name: &str, column_type: ColumnType) -> AlterColumnBuilder<'_> {
        AlterColumnBuilder {
            builder: self,
            definition: ColumnDefinition {
                name: name.to_string(),
                column_type,
                nullable: true,
                default: None,
                primary_key: false,
                auto_increment: false,
                unique: false,
                check: None,
                extra: None,
            },
        }
    }

    /// Queue `DROP COLUMN <name>`.
    pub fn drop_column(&mut self, name: &str) -> &mut Self {
        self.operations.push(AlterOperation::DropColumn(name.to_string()));
        self
    }

    /// Queue `RENAME COLUMN <from> TO <to>`.
    pub fn rename_column(&mut self, from: &str, to: &str) -> &mut Self {
        self.operations.push(AlterOperation::RenameColumn(
            from.to_string(),
            to.to_string(),
        ));
        self
    }

    /// Queue a column type change. See `build_operation` for the SQLite
    /// limitation.
    pub fn change_column(&mut self, name: &str, column_type: ColumnType) -> &mut Self {
        self.operations.push(AlterOperation::ChangeColumnType(
            name.to_string(),
            column_type,
        ));
        self
    }

    /// Queue `CREATE [UNIQUE] INDEX <name> ON <table> (<columns>)`.
    pub fn add_index(&mut self, name: &str, columns: &[&str], unique: bool) -> &mut Self {
        self.operations.push(AlterOperation::AddIndex(IndexBuilder {
            name: name.to_string(),
            columns: columns.iter().map(|s| s.to_string()).collect(),
            unique,
        }));
        self
    }

    /// Queue dropping index `name`.
    pub fn drop_index(&mut self, name: &str) -> &mut Self {
        self.operations.push(AlterOperation::DropIndex(name.to_string()));
        self
    }

    /// Render every queued operation, in queue order.
    fn build(&self) -> Vec<String> {
        self.operations
            .iter()
            .map(|op| self.build_operation(op))
            .collect()
    }

    /// Render one operation as a single SQL statement.
    fn build_operation(&self, op: &AlterOperation) -> String {
        let table = self.quote_identifier(&self.name);
        match op {
            AlterOperation::AddColumn(col) => {
                let col_def = self.build_column_def(col);
                format!("ALTER TABLE {} ADD COLUMN {}", table, col_def.trim())
            }
            AlterOperation::DropColumn(name) => format!(
                "ALTER TABLE {} DROP COLUMN {}",
                table,
                self.quote_identifier(name)
            ),
            // The same statement is emitted for every backend. It needs
            // MySQL 8.0+, MariaDB 10.5.2+ or SQLite 3.25+.
            AlterOperation::RenameColumn(from, to) => format!(
                "ALTER TABLE {} RENAME COLUMN {} TO {}",
                table,
                self.quote_identifier(from),
                self.quote_identifier(to)
            ),
            AlterOperation::ChangeColumnType(name, column_type) => {
                let type_sql = self.type_to_sql(column_type);
                match self.database_type {
                    DatabaseType::Postgres => format!(
                        "ALTER TABLE {} ALTER COLUMN {} TYPE {}",
                        table,
                        self.quote_identifier(name),
                        type_sql
                    ),
                    DatabaseType::MySQL | DatabaseType::MariaDB => format!(
                        "ALTER TABLE {} MODIFY COLUMN {} {}",
                        table,
                        self.quote_identifier(name),
                        type_sql
                    ),
                    // NOTE(review): SQLite cannot change a column's type in
                    // place. This renders an SQL comment, which the caller
                    // still executes, so the change is silently skipped rather
                    // than reported as an error.
                    DatabaseType::SQLite => format!(
                        "-- SQLite does not support ALTER COLUMN TYPE; table recreation needed for {}",
                        name
                    ),
                }
            }
            AlterOperation::AddIndex(idx) => {
                let index_type = if idx.unique { "UNIQUE INDEX" } else { "INDEX" };
                let cols: Vec<String> = idx
                    .columns
                    .iter()
                    .map(|c| self.quote_identifier(c))
                    .collect();
                format!(
                    "CREATE {} {} ON {} ({})",
                    index_type,
                    self.quote_identifier(&idx.name),
                    table,
                    cols.join(", ")
                )
            }
            AlterOperation::DropIndex(name) => match self.database_type {
                DatabaseType::MySQL | DatabaseType::MariaDB => {
                    format!("DROP INDEX {} ON {}", self.quote_identifier(name), table)
                }
                _ => format!("DROP INDEX {}", self.quote_identifier(name)),
            },
        }
    }

    /// Render a column definition for `ADD COLUMN`. Only name, type,
    /// nullability, default and `UNIQUE` are emitted.
    fn build_column_def(&self, col: &ColumnDefinition) -> String {
        let mut def = format!(
            "{} {}",
            self.quote_identifier(&col.name),
            self.type_to_sql(&col.column_type)
        );
        if !col.nullable {
            def.push_str(" NOT NULL");
        }
        if let Some(ref default) = col.default {
            def.push_str(&format!(" DEFAULT {}", default));
        }
        if col.unique {
            def.push_str(" UNIQUE");
        }
        def
    }

    fn type_to_sql(&self, column_type: &ColumnType) -> String {
        match self.database_type {
            DatabaseType::Postgres => column_type.to_postgres_sql(),
            DatabaseType::MySQL | DatabaseType::MariaDB => column_type.to_mysql_sql(),
            DatabaseType::SQLite => column_type.to_sqlite_sql(),
        }
    }

    /// Quote `name` as one identifier, doubling any embedded quote character
    /// so the name cannot terminate the identifier early.
    fn quote_identifier(&self, name: &str) -> String {
        match self.database_type {
            DatabaseType::Postgres | DatabaseType::SQLite => {
                format!("\"{}\"", name.replace('"', "\"\""))
            }
            DatabaseType::MySQL | DatabaseType::MariaDB => {
                format!("`{}`", name.replace('`', "``"))
            }
        }
    }
}
/// Fluent builder for a column added through `AlterTableBuilder::add_column`.
///
/// When dropped, the finished definition is queued on the parent builder as
/// an `AddColumn` operation.
pub struct AlterColumnBuilder<'a> {
    builder: &'a mut AlterTableBuilder,
    definition: ColumnDefinition,
}

impl<'a> AlterColumnBuilder<'a> {
    /// Emit `NOT NULL` for the new column.
    pub fn not_null(self) -> Self {
        self.with_nullability(false)
    }

    /// Allow `NULL` values (the starting state for `add_column`).
    pub fn nullable(self) -> Self {
        self.with_nullability(true)
    }

    /// Set a literal default, rendered through `DefaultValue::to_sql`.
    pub fn default(mut self, value: impl Into<DefaultValue>) -> Self {
        let rendered = value.into().to_sql();
        self.definition.default = Some(rendered);
        self
    }

    /// Default the column to `CURRENT_TIMESTAMP`.
    pub fn default_now(mut self) -> Self {
        self.definition.default = Some(String::from("CURRENT_TIMESTAMP"));
        self
    }

    /// Request a column-level `UNIQUE`.
    pub fn unique(mut self) -> Self {
        self.definition.unique = true;
        self
    }

    /// Shared body of `not_null` / `nullable`.
    fn with_nullability(mut self, nullable: bool) -> Self {
        self.definition.nullable = nullable;
        self
    }
}

impl<'a> Drop for AlterColumnBuilder<'a> {
    /// Queue the finished definition as an `AddColumn` operation. A definition
    /// whose name is empty is discarded.
    fn drop(&mut self) {
        let placeholder = ColumnDefinition {
            name: String::new(),
            column_type: ColumnType::String,
            nullable: true,
            default: None,
            primary_key: false,
            auto_increment: false,
            unique: false,
            check: None,
            extra: None,
        };
        let finished = std::mem::replace(&mut self.definition, placeholder);
        if finished.name.is_empty() {
            return;
        }
        self.builder
            .operations
            .push(AlterOperation::AddColumn(finished));
    }
}
/// Backend-neutral column type, rendered per backend by the `to_*_sql`
/// methods.
#[derive(Debug, Clone)]
pub enum ColumnType {
    SmallInteger,
    Integer,
    BigInteger,
    Float,
    Double,
    Decimal {
        precision: u32,
        scale: u32,
    },
    String,
    Text,
    Boolean,
    Date,
    Time,
    DateTime,
    Timestamp,
    TimestampTz,
    Uuid,
    Json,
    Jsonb,
    Binary,
    IntegerArray,
    TextArray,
    /// Verbatim SQL type, emitted unchanged on every backend.
    Custom(String),
}

impl ColumnType {
    /// PostgreSQL type name for this column type.
    pub fn to_postgres_sql(&self) -> String {
        let keyword = match self {
            ColumnType::Decimal { precision, scale } => {
                return format!("DECIMAL({}, {})", precision, scale)
            }
            ColumnType::Custom(raw) => return raw.clone(),
            ColumnType::SmallInteger => "SMALLINT",
            ColumnType::Integer => "INTEGER",
            ColumnType::BigInteger => "BIGINT",
            ColumnType::Float => "REAL",
            ColumnType::Double => "DOUBLE PRECISION",
            ColumnType::String => "VARCHAR(255)",
            ColumnType::Text => "TEXT",
            ColumnType::Boolean => "BOOLEAN",
            ColumnType::Date => "DATE",
            ColumnType::Time => "TIME",
            ColumnType::DateTime | ColumnType::Timestamp => "TIMESTAMP",
            ColumnType::TimestampTz => "TIMESTAMPTZ",
            ColumnType::Uuid => "UUID",
            ColumnType::Json => "JSON",
            ColumnType::Jsonb => "JSONB",
            ColumnType::Binary => "BYTEA",
            ColumnType::IntegerArray => "INTEGER[]",
            ColumnType::TextArray => "TEXT[]",
        };
        keyword.to_string()
    }

    /// MySQL/MariaDB type name. Types without a native equivalent are mapped
    /// down: booleans to `TINYINT(1)`, arrays and JSONB to `JSON`, UUIDs to
    /// `CHAR(36)`.
    pub fn to_mysql_sql(&self) -> String {
        let keyword = match self {
            ColumnType::Decimal { precision, scale } => {
                return format!("DECIMAL({}, {})", precision, scale)
            }
            ColumnType::Custom(raw) => return raw.clone(),
            ColumnType::SmallInteger => "SMALLINT",
            ColumnType::Integer => "INT",
            ColumnType::BigInteger => "BIGINT",
            ColumnType::Float => "FLOAT",
            ColumnType::Double => "DOUBLE",
            ColumnType::String => "VARCHAR(255)",
            ColumnType::Text => "TEXT",
            ColumnType::Boolean => "TINYINT(1)",
            ColumnType::Date => "DATE",
            ColumnType::Time => "TIME",
            ColumnType::DateTime => "DATETIME",
            ColumnType::Timestamp | ColumnType::TimestampTz => "TIMESTAMP",
            ColumnType::Uuid => "CHAR(36)",
            ColumnType::Json
            | ColumnType::Jsonb
            | ColumnType::IntegerArray
            | ColumnType::TextArray => "JSON",
            ColumnType::Binary => "BLOB",
        };
        keyword.to_string()
    }

    /// SQLite type name: every type collapses to `INTEGER`, `REAL`, `TEXT`
    /// or `BLOB` (or the custom text).
    pub fn to_sqlite_sql(&self) -> String {
        match self {
            ColumnType::Custom(raw) => raw.clone(),
            ColumnType::Binary => String::from("BLOB"),
            ColumnType::SmallInteger
            | ColumnType::Integer
            | ColumnType::BigInteger
            | ColumnType::Boolean => String::from("INTEGER"),
            ColumnType::Float | ColumnType::Double | ColumnType::Decimal { .. } => {
                String::from("REAL")
            }
            ColumnType::String
            | ColumnType::Text
            | ColumnType::Uuid
            | ColumnType::Date
            | ColumnType::Time
            | ColumnType::DateTime
            | ColumnType::Timestamp
            | ColumnType::TimestampTz
            | ColumnType::Json
            | ColumnType::Jsonb
            | ColumnType::IntegerArray
            | ColumnType::TextArray => String::from("TEXT"),
        }
    }
}
/// A column default, rendered to SQL by [`DefaultValue::to_sql`].
#[derive(Debug, Clone)]
pub enum DefaultValue {
    String(String),
    Integer(i64),
    Float(f64),
    Boolean(bool),
    /// Inserted verbatim (e.g. a function call); not quoted or escaped.
    Raw(String),
    Null,
}

impl DefaultValue {
    /// Render as a SQL literal. Strings are single-quoted with embedded `'`
    /// doubled; `Raw` is returned unchanged.
    pub fn to_sql(&self) -> String {
        match self {
            Self::String(text) => {
                let escaped = text.replace('\'', "''");
                format!("'{}'", escaped)
            }
            Self::Integer(number) => number.to_string(),
            Self::Float(number) => number.to_string(),
            Self::Boolean(flag) => String::from(if *flag { "TRUE" } else { "FALSE" }),
            Self::Raw(sql) => sql.to_owned(),
            Self::Null => String::from("NULL"),
        }
    }
}

impl From<&str> for DefaultValue {
    fn from(s: &str) -> Self {
        Self::String(s.to_owned())
    }
}

impl From<String> for DefaultValue {
    fn from(s: String) -> Self {
        Self::String(s)
    }
}

impl From<i32> for DefaultValue {
    fn from(i: i32) -> Self {
        Self::Integer(i64::from(i))
    }
}

impl From<i64> for DefaultValue {
    fn from(i: i64) -> Self {
        Self::Integer(i)
    }
}

impl From<f64> for DefaultValue {
    fn from(f: f64) -> Self {
        Self::Float(f)
    }
}

impl From<bool> for DefaultValue {
    fn from(b: bool) -> Self {
        Self::Boolean(b)
    }
}
/// Internal, fully-resolved description of one column.
///
/// Built by `TableBuilder`/`ColumnBuilder` and `AlterColumnBuilder`; rendered by
/// their respective `build_column_def` methods.
#[derive(Debug, Clone)]
struct ColumnDefinition {
/// Unquoted column name (quoted at render time).
name: String,
/// Logical type, mapped to backend SQL by `ColumnType::to_*_sql`.
column_type: ColumnType,
/// When false, `NOT NULL` is emitted (TableBuilder skips it for primary-key columns).
nullable: bool,
/// Already-rendered SQL for the `DEFAULT` clause (a quoted literal, `CURRENT_TIMESTAMP`, ...).
default: Option<String>,
/// Single-column primary key; TableBuilder then suppresses `NOT NULL`/`UNIQUE` output.
primary_key: bool,
/// Rendered as SERIAL/BIGSERIAL on Postgres or with AUTO_INCREMENT on MySQL/MariaDB.
auto_increment: bool,
/// Emit a column-level `UNIQUE`.
unique: bool,
/// Expression placed verbatim inside `CHECK (...)` (TableBuilder only).
check: Option<String>,
/// Raw SQL appended verbatim after the definition (TableBuilder only).
extra: Option<String>,
}
/// A secondary index, rendered as `CREATE [UNIQUE] INDEX <name> ON <table> (<columns>)`.
#[derive(Debug, Clone)]
struct IndexBuilder {
name: String,
columns: Vec<String>,
unique: bool,
}
/// One change queued on an `AlterTableBuilder`; each becomes exactly one SQL statement.
#[derive(Debug, Clone)]
enum AlterOperation {
AddColumn(ColumnDefinition),
DropColumn(String),
/// (current name, new name)
RenameColumn(String, String),
ChangeColumnType(String, ColumnType),
AddIndex(IndexBuilder),
DropIndex(String),
}
pub struct Migrator {
migrations: Vec<Box<dyn Migration>>,
}
impl Migrator {
pub fn new() -> Self {
Self {
migrations: Vec::new(),
}
}
#[allow(clippy::should_implement_trait)]
pub fn add<M: Migration + 'static>(mut self, migration: M) -> Self {
self.migrations.push(Box::new(migration));
self
}
#[doc(hidden)]
pub fn add_boxed(mut self, migration: Box<dyn Migration>) -> Self {
self.migrations.push(migration);
self
}
pub async fn run(&self) -> Result<MigrationResult> {
self.ensure_migrations_table().await?;
let applied = self.get_applied_migrations().await?;
let mut result = MigrationResult::new();
let db = require_db()?;
let db_type = detect_database_type(db);
let mut migrations: Vec<_> = self.migrations.iter().collect();
migrations.sort_by_key(|m| m.version());
for migration in migrations {
let version = migration.version();
if applied.contains(&version.to_string()) {
result.skipped.push(MigrationInfo {
version: version.to_string(),
name: migration.name().to_string(),
});
continue;
}
log_migration_start(version, migration.name());
let mut schema = Schema::new(db_type);
migration.up(&mut schema).await?;
self.record_migration(version, migration.name()).await?;
result.applied.push(MigrationInfo {
version: version.to_string(),
name: migration.name().to_string(),
});
log_migration_complete(version, migration.name());
}
Ok(result)
}
pub async fn rollback(&self) -> Result<MigrationResult> {
self.ensure_migrations_table().await?;
let applied = self.get_applied_migrations().await?;
let mut result = MigrationResult::new();
if applied.is_empty() {
return Ok(result);
}
let last_version = match applied.last() {
Some(v) => v,
None => return Ok(result),
};
let db = require_db()?;
let db_type = detect_database_type(db);
for migration in &self.migrations {
if migration.version() == last_version {
log_migration_rollback(last_version, migration.name());
let mut schema = Schema::new(db_type);
migration.down(&mut schema).await?;
self.remove_migration_record(last_version).await?;
result.rolled_back.push(MigrationInfo {
version: migration.version().to_string(),
name: migration.name().to_string(),
});
break;
}
}
Ok(result)
}
pub async fn rollback_steps(&self, steps: usize) -> Result<MigrationResult> {
let mut result = MigrationResult::new();
for _ in 0..steps {
let step_result = self.rollback().await?;
if step_result.rolled_back.is_empty() {
break;
}
result.rolled_back.extend(step_result.rolled_back);
}
Ok(result)
}
pub async fn reset(&self) -> Result<MigrationResult> {
let applied = self.get_applied_migrations().await?;
self.rollback_steps(applied.len()).await
}
pub async fn refresh(&self) -> Result<MigrationResult> {
let reset_result = self.reset().await?;
let run_result = self.run().await?;
Ok(MigrationResult {
applied: run_result.applied,
skipped: run_result.skipped,
rolled_back: reset_result.rolled_back,
})
}
pub async fn status(&self) -> Result<Vec<MigrationStatus>> {
self.ensure_migrations_table().await?;
let applied = self.get_applied_migrations().await?;
let mut status = Vec::new();
let mut migrations: Vec<_> = self.migrations.iter().collect();
migrations.sort_by_key(|m| m.version());
for migration in migrations {
let is_applied = applied.contains(&migration.version().to_string());
status.push(MigrationStatus {
version: migration.version().to_string(),
name: migration.name().to_string(),
applied: is_applied,
});
}
Ok(status)
}
async fn ensure_migrations_table(&self) -> Result<()> {
let db = require_db()?;
let db_type = detect_database_type(db);
let sql = match db_type {
DatabaseType::Postgres => {
r#"
CREATE TABLE IF NOT EXISTS "_migrations" (
"id" SERIAL PRIMARY KEY,
"version" VARCHAR(255) NOT NULL UNIQUE,
"name" VARCHAR(255) NOT NULL,
"applied_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)
"#
}
DatabaseType::MySQL | DatabaseType::MariaDB => {
r#"
CREATE TABLE IF NOT EXISTS `_migrations` (
`id` INT AUTO_INCREMENT PRIMARY KEY,
`version` VARCHAR(255) NOT NULL UNIQUE,
`name` VARCHAR(255) NOT NULL,
`applied_at` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)
"#
}
DatabaseType::SQLite => {
r#"
CREATE TABLE IF NOT EXISTS "_migrations" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
"version" TEXT NOT NULL UNIQUE,
"name" TEXT NOT NULL,
"applied_at" TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
)
"#
}
};
db.__internal_connection()
.execute_unprepared(sql)
.await
.map_err(|e| Error::query(e.to_string()))?;
Ok(())
}
async fn get_applied_migrations(&self) -> Result<Vec<String>> {
let db = require_db()?;
use crate::internal::Statement;
let backend = db.__internal_connection().get_database_backend();
let db_type = detect_database_type(db);
let q = |id: &str| quote_migration_identifier(id, db_type);
let sql = format!(
"SELECT {} FROM {} ORDER BY {} ASC",
q("version"), q("_migrations"), q("version")
);
let stmt = Statement::from_string(backend, sql);
let results = db
.__internal_connection()
.query_all_raw(stmt)
.await
.map_err(|e| Error::query(e.to_string()))?;
let registered_versions: std::collections::HashSet<_> =
self.migrations.iter().map(|m| m.version().to_string()).collect();
let mut versions = Vec::new();
for row in results {
let version: String = row
.try_get("", "version")
.map_err(|e| Error::query(e.to_string()))?;
if registered_versions.contains(&version) {
versions.push(version);
}
}
Ok(versions)
}
async fn record_migration(&self, version: &str, name: &str) -> Result<()> {
let db = require_db()?;
let db_type = detect_database_type(db);
let q = |id: &str| quote_migration_identifier(id, db_type);
let sql = format!(
"INSERT INTO {} ({}, {}) VALUES ('{}', '{}')",
q("_migrations"), q("version"), q("name"),
version.replace('\'', "''"),
name.replace('\'', "''")
);
db.__internal_connection()
.execute_unprepared(&sql)
.await
.map_err(|e| Error::query(e.to_string()))?;
Ok(())
}
async fn remove_migration_record(&self, version: &str) -> Result<()> {
let db = require_db()?;
let db_type = detect_database_type(db);
let q = |id: &str| quote_migration_identifier(id, db_type);
let sql = format!(
"DELETE FROM {} WHERE {} = '{}'",
q("_migrations"), q("version"),
version.replace('\'', "''")
);
db.__internal_connection()
.execute_unprepared(&sql)
.await
.map_err(|e| Error::query(e.to_string()))?;
Ok(())
}
}
impl Default for Migrator {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone)]
pub struct MigrationResult {
pub applied: Vec<MigrationInfo>,
pub skipped: Vec<MigrationInfo>,
pub rolled_back: Vec<MigrationInfo>,
}
impl MigrationResult {
fn new() -> Self {
Self {
applied: Vec::new(),
skipped: Vec::new(),
rolled_back: Vec::new(),
}
}
pub fn has_applied(&self) -> bool {
!self.applied.is_empty()
}
pub fn has_rolled_back(&self) -> bool {
!self.rolled_back.is_empty()
}
}
impl fmt::Display for MigrationResult {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if !self.applied.is_empty() {
writeln!(f, "Applied migrations:")?;
for m in &self.applied {
writeln!(f, " ✓ {} - {}", m.version, m.name)?;
}
}
if !self.skipped.is_empty() {
writeln!(f, "Skipped migrations (already applied):")?;
for m in &self.skipped {
writeln!(f, " - {} - {}", m.version, m.name)?;
}
}
if !self.rolled_back.is_empty() {
writeln!(f, "Rolled back migrations:")?;
for m in &self.rolled_back {
writeln!(f, " ↩ {} - {}", m.version, m.name)?;
}
}
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct MigrationInfo {
pub version: String,
pub name: String,
}
#[derive(Debug, Clone)]
pub struct MigrationStatus {
pub version: String,
pub name: String,
pub applied: bool,
}
impl fmt::Display for MigrationStatus {
    /// Renders `[✓] <version> - <name>` for applied migrations and
    /// `[○] <version> - <name>` for pending ones.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let marker = match self.applied {
            true => "✓",
            false => "○",
        };
        write!(f, "[{}] {} - {}", marker, self.version, self.name)
    }
}
/// Returns the backend type reported by `db.backend()`.
///
/// Thin wrapper used by the migration bookkeeping code to choose
/// identifier-quoting rules.
fn detect_database_type(db: &Database) -> DatabaseType {
db.backend()
}
/// Quotes `name` as an SQL identifier for `db_type`.
///
/// MySQL and MariaDB wrap the name in backticks; every other backend wraps
/// it in double quotes. Any quote character already inside `name` is doubled.
/// Doubling is how each of these dialects escapes its quote character inside
/// a quoted identifier, so the result is always one well-formed identifier.
/// Names without quote characters are unaffected.
fn quote_migration_identifier(name: &str, db_type: DatabaseType) -> String {
    match db_type {
        DatabaseType::MySQL | DatabaseType::MariaDB => {
            format!("`{}`", name.replace('`', "``"))
        }
        _ => format!(r#""{}""#, name.replace('"', r#""""#)),
    }
}
/// Logs `sql` through `tide_debug!` when the `TIDE_LOG_QUERIES` environment
/// variable is set.
///
/// The variable must be present and hold valid Unicode, because
/// `std::env::var` returns `Err` otherwise. It is read on every call.
fn log_migration_sql(sql: &str) {
    let logging_enabled = std::env::var("TIDE_LOG_QUERIES").is_ok();
    if !logging_enabled {
        return;
    }
    tide_debug!("Migration SQL: {}", sql);
}
/// Logs via `tide_info!` that migration `version` - `name` is about to run.
fn log_migration_start(version: &str, name: &str) {
tide_info!("Running migration: {} - {}", version, name);
}
/// Logs via `tide_info!` that migration `version` - `name` has completed.
fn log_migration_complete(version: &str, name: &str) {
tide_info!("Completed migration: {} - {}", version, name);
}
/// Logs via `tide_info!` that migration `version` - `name` is being rolled back.
fn log_migration_rollback(version: &str, name: &str) {
tide_info!("Rolling back migration: {} - {}", version, name);
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_column_type_postgres() {
        assert_eq!(ColumnType::Integer.to_postgres_sql(), "INTEGER");
        assert_eq!(ColumnType::BigInteger.to_postgres_sql(), "BIGINT");
        assert_eq!(ColumnType::String.to_postgres_sql(), "VARCHAR(255)");
        assert_eq!(ColumnType::Text.to_postgres_sql(), "TEXT");
        assert_eq!(ColumnType::Boolean.to_postgres_sql(), "BOOLEAN");
        assert_eq!(ColumnType::Jsonb.to_postgres_sql(), "JSONB");
        assert_eq!(ColumnType::IntegerArray.to_postgres_sql(), "INTEGER[]");
        assert_eq!(ColumnType::Timestamp.to_postgres_sql(), "TIMESTAMP");
        assert_eq!(ColumnType::TimestampTz.to_postgres_sql(), "TIMESTAMPTZ");
        assert_eq!(ColumnType::Date.to_postgres_sql(), "DATE");
        assert_eq!(ColumnType::Time.to_postgres_sql(), "TIME");
    }

    #[test]
    fn test_column_type_mysql() {
        assert_eq!(ColumnType::Integer.to_mysql_sql(), "INT");
        assert_eq!(ColumnType::BigInteger.to_mysql_sql(), "BIGINT");
        assert_eq!(ColumnType::Boolean.to_mysql_sql(), "TINYINT(1)");
        assert_eq!(ColumnType::Jsonb.to_mysql_sql(), "JSON");
        assert_eq!(ColumnType::Timestamp.to_mysql_sql(), "TIMESTAMP");
        assert_eq!(ColumnType::TimestampTz.to_mysql_sql(), "TIMESTAMP");
        assert_eq!(ColumnType::Date.to_mysql_sql(), "DATE");
        assert_eq!(ColumnType::Time.to_mysql_sql(), "TIME");
    }

    #[test]
    fn test_column_type_sqlite() {
        assert_eq!(ColumnType::Integer.to_sqlite_sql(), "INTEGER");
        assert_eq!(ColumnType::BigInteger.to_sqlite_sql(), "INTEGER");
        assert_eq!(ColumnType::String.to_sqlite_sql(), "TEXT");
        assert_eq!(ColumnType::Boolean.to_sqlite_sql(), "INTEGER");
        assert_eq!(ColumnType::Timestamp.to_sqlite_sql(), "TEXT");
        assert_eq!(ColumnType::TimestampTz.to_sqlite_sql(), "TEXT");
        assert_eq!(ColumnType::Date.to_sqlite_sql(), "TEXT");
        assert_eq!(ColumnType::Time.to_sqlite_sql(), "TEXT");
    }

    #[test]
    fn test_default_value() {
        assert_eq!(DefaultValue::String("test".to_string()).to_sql(), "'test'");
        assert_eq!(DefaultValue::Integer(42).to_sql(), "42");
        assert_eq!(DefaultValue::Boolean(true).to_sql(), "TRUE");
        assert_eq!(DefaultValue::Boolean(false).to_sql(), "FALSE");
        assert_eq!(DefaultValue::Null.to_sql(), "NULL");
    }

    #[test]
    fn test_table_builder_create() {
        let mut builder = TableBuilder::new("users", DatabaseType::Postgres);
        builder.id();
        builder.string("email").unique().not_null();
        builder.string("name").not_null();
        builder.boolean("active").default(true);
        builder.timestamps();
        let sql = builder.build_create();
        assert!(sql.contains("CREATE TABLE"));
        assert!(sql.contains("\"users\""));
        assert!(sql.contains("\"id\" BIGSERIAL"));
        assert!(sql.contains("\"email\""));
        assert!(sql.contains("\"name\""));
        assert!(sql.contains("\"active\""));
        assert!(sql.contains("\"created_at\""));
        assert!(sql.contains("\"updated_at\""));
    }

    #[test]
    fn test_timestamps_feature() {
        let mut builder = TableBuilder::new("posts", DatabaseType::Postgres);
        builder.id();
        builder.string("title").not_null();
        builder.timestamps();
        let sql = builder.build_create();
        assert!(sql.contains("\"created_at\" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP"),
            "PostgreSQL should have created_at with TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP. Got: {}", sql);
        assert!(sql.contains("\"updated_at\" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP"),
            "PostgreSQL should have updated_at with TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP. Got: {}", sql);
        let mut builder = TableBuilder::new("posts", DatabaseType::MySQL);
        builder.id();
        builder.string("title").not_null();
        builder.timestamps();
        let sql = builder.build_create();
        assert!(sql.contains("`created_at` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP"),
            "MySQL should have created_at with TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP. Got: {}", sql);
        assert!(sql.contains("`updated_at` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP"),
            "MySQL should have updated_at with TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP. Got: {}", sql);
        let mut builder = TableBuilder::new("posts", DatabaseType::MariaDB);
        builder.id();
        builder.string("title").not_null();
        builder.timestamps();
        let sql = builder.build_create();
        assert!(sql.contains("`created_at` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP"),
            "MariaDB should have created_at with TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP. Got: {}", sql);
        assert!(sql.contains("`updated_at` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP"),
            "MariaDB should have updated_at with TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP. Got: {}", sql);
        let mut builder = TableBuilder::new("posts", DatabaseType::SQLite);
        builder.id();
        builder.string("title").not_null();
        builder.timestamps();
        let sql = builder.build_create();
        assert!(sql.contains("\"created_at\" TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP"),
            "SQLite should have created_at with TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP. Got: {}", sql);
        assert!(sql.contains("\"updated_at\" TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP"),
            "SQLite should have updated_at with TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP. Got: {}", sql);
    }

    #[test]
    fn test_timestamps_naive_feature() {
        let mut builder = TableBuilder::new("logs", DatabaseType::Postgres);
        builder.id();
        builder.text("message").not_null();
        builder.timestamps_naive();
        let sql = builder.build_create();
        assert!(sql.contains("\"created_at\" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP"),
            "PostgreSQL timestamps_naive should use TIMESTAMP. Got: {}", sql);
        assert!(sql.contains("\"updated_at\" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP"),
            "PostgreSQL timestamps_naive should use TIMESTAMP. Got: {}", sql);
    }

    #[test]
    fn test_timestamptz_column() {
        let mut builder = TableBuilder::new("sessions", DatabaseType::Postgres);
        builder.id();
        builder.string("token").not_null();
        builder.timestamptz("expires_at").not_null();
        builder.timestamptz("last_activity").nullable();
        let sql = builder.build_create();
        assert!(sql.contains("\"expires_at\" TIMESTAMPTZ NOT NULL"),
            "Should have expires_at as TIMESTAMPTZ NOT NULL. Got: {}", sql);
        assert!(sql.contains("\"last_activity\" TIMESTAMPTZ"),
            "Should have last_activity as TIMESTAMPTZ. Got: {}", sql);
        assert!(!sql.contains("\"last_activity\" TIMESTAMPTZ NOT NULL"),
            "last_activity should be nullable. Got: {}", sql);
    }

    #[test]
    fn test_timestamp_vs_timestamptz() {
        let mut builder = TableBuilder::new("events", DatabaseType::Postgres);
        builder.id();
        builder.timestamp("local_time");
        builder.timestamptz("utc_time");
        let sql = builder.build_create();
        assert!(sql.contains("\"local_time\" TIMESTAMP"),
            "timestamp() should produce TIMESTAMP. Got: {}", sql);
        assert!(sql.contains("\"utc_time\" TIMESTAMPTZ"),
            "timestamptz() should produce TIMESTAMPTZ. Got: {}", sql);
    }

    #[test]
    fn test_date_time_columns() {
        let mut builder = TableBuilder::new("schedules", DatabaseType::Postgres);
        builder.id();
        builder.date("event_date");
        builder.time("start_time");
        builder.datetime("local_datetime");
        builder.timestamp("naive_timestamp");
        builder.timestamptz("utc_timestamp");
        let sql = builder.build_create();
        assert!(sql.contains("\"event_date\" DATE"), "Should have DATE column. Got: {}", sql);
        assert!(sql.contains("\"start_time\" TIME"), "Should have TIME column. Got: {}", sql);
        assert!(sql.contains("\"local_datetime\" TIMESTAMP"), "Should have TIMESTAMP for datetime. Got: {}", sql);
        assert!(sql.contains("\"naive_timestamp\" TIMESTAMP"), "Should have TIMESTAMP. Got: {}", sql);
        assert!(sql.contains("\"utc_timestamp\" TIMESTAMPTZ"), "Should have TIMESTAMPTZ. Got: {}", sql);
    }

    #[test]
    fn test_soft_deletes_feature() {
        let mut builder = TableBuilder::new("posts", DatabaseType::Postgres);
        builder.id();
        builder.soft_deletes();
        let sql = builder.build_create();
        assert!(sql.contains("\"deleted_at\" TIMESTAMPTZ"),
            "Should have deleted_at TIMESTAMPTZ column. Got: {}", sql);
        assert!(!sql.contains("\"deleted_at\" TIMESTAMPTZ NOT NULL"),
            "deleted_at should be nullable (no NOT NULL). Got: {}", sql);
    }

    #[test]
    fn test_alter_table_builder() {
        let mut builder = AlterTableBuilder::new("users", DatabaseType::Postgres);
        builder.add_column("phone", ColumnType::String).nullable();
        builder.drop_column("legacy");
        builder.rename_column("name", "full_name");
        let statements = builder.build();
        assert_eq!(statements.len(), 3);
        assert!(statements[0].contains("ADD COLUMN"));
        assert!(statements[1].contains("DROP COLUMN"));
        assert!(statements[2].contains("RENAME COLUMN"));
    }

    #[test]
    fn test_multi_column_unique_constraint() {
        let mut builder = TableBuilder::new("user_roles", DatabaseType::Postgres);
        builder.big_integer("user_id").not_null();
        builder.big_integer("role_id").not_null();
        builder.unique(&["user_id", "role_id"]);
        let sql = builder.build_create();
        assert!(sql.contains("UNIQUE (\"user_id\", \"role_id\")"),
            "Should have multi-column unique constraint. Got: {}", sql);
        let mut builder = TableBuilder::new("users", DatabaseType::Postgres);
        builder.id();
        builder.string("email").not_null();
        builder.big_integer("tenant_id").not_null();
        builder.unique_named("uq_user_email_tenant", &["email", "tenant_id"]);
        let sql = builder.build_create();
        assert!(sql.contains("CONSTRAINT \"uq_user_email_tenant\" UNIQUE (\"email\", \"tenant_id\")"),
            "Should have named unique constraint. Got: {}", sql);
    }

    #[test]
    fn test_composite_primary_key() {
        let mut builder = TableBuilder::new("user_roles", DatabaseType::Postgres);
        builder.big_integer("user_id").not_null();
        builder.big_integer("role_id").not_null();
        builder.timestamps();
        builder.primary_key(&["user_id", "role_id"]);
        let sql = builder.build_create();
        assert!(sql.contains("PRIMARY KEY (\"user_id\", \"role_id\")"),
            "Should have composite primary key. Got: {}", sql);
        assert!(!sql.contains("BIGINT PRIMARY KEY"),
            "Individual columns should not be marked as primary key. Got: {}", sql);
    }

    #[test]
    fn test_check_constraint() {
        let mut builder = TableBuilder::new("products", DatabaseType::Postgres);
        builder.id();
        builder.decimal("price").check("price >= 0");
        builder.integer("quantity").check("quantity >= 0");
        let sql = builder.build_create();
        assert!(sql.contains("CHECK (price >= 0)"),
            "Should have CHECK constraint on price. Got: {}", sql);
        assert!(sql.contains("CHECK (quantity >= 0)"),
            "Should have CHECK constraint on quantity. Got: {}", sql);
    }

    #[test]
    fn test_extra_sql_attribute() {
        let mut builder = TableBuilder::new("logs", DatabaseType::MySQL);
        builder.id();
        builder.text("message").extra("COLLATE utf8mb4_unicode_ci");
        let sql = builder.build_create();
        assert!(sql.contains("COLLATE utf8mb4_unicode_ci"),
            "Should include extra SQL. Got: {}", sql);
        let mut builder = TableBuilder::new("logs", DatabaseType::MariaDB);
        builder.id();
        builder.text("message").extra("COLLATE utf8mb4_unicode_ci");
        let sql = builder.build_create();
        assert!(sql.contains("COLLATE utf8mb4_unicode_ci"),
            "MariaDB should include extra SQL. Got: {}", sql);
    }

    // NOTE(review): this exercises `to_mysql_sql()`, presumably because MariaDB
    // shares the MySQL type mapping; it duplicates `test_column_type_mysql`.
    #[test]
    fn test_column_type_mariadb() {
        assert_eq!(ColumnType::Integer.to_mysql_sql(), "INT");
        assert_eq!(ColumnType::BigInteger.to_mysql_sql(), "BIGINT");
        assert_eq!(ColumnType::Boolean.to_mysql_sql(), "TINYINT(1)");
        assert_eq!(ColumnType::Jsonb.to_mysql_sql(), "JSON");
        assert_eq!(ColumnType::Timestamp.to_mysql_sql(), "TIMESTAMP");
        assert_eq!(ColumnType::TimestampTz.to_mysql_sql(), "TIMESTAMP");
        assert_eq!(ColumnType::Date.to_mysql_sql(), "DATE");
        assert_eq!(ColumnType::Time.to_mysql_sql(), "TIME");
    }

    #[test]
    fn test_mariadb_table_builder_create() {
        let mut builder = TableBuilder::new("users", DatabaseType::MariaDB);
        builder.id();
        builder.string("email").unique().not_null();
        builder.string("name").not_null();
        builder.boolean("active").default(true);
        builder.timestamps();
        let sql = builder.build_create();
        assert!(sql.contains("CREATE TABLE"));
        assert!(sql.contains("`users`"));
        assert!(sql.contains("`id` BIGINT AUTO_INCREMENT"));
        assert!(sql.contains("`email`"));
        assert!(sql.contains("`name`"));
        assert!(sql.contains("`active`"));
        assert!(sql.contains("`created_at`"));
        assert!(sql.contains("`updated_at`"));
    }

    #[test]
    fn test_mariadb_alter_table_builder() {
        let mut builder = AlterTableBuilder::new("users", DatabaseType::MariaDB);
        builder.add_column("phone", ColumnType::String).nullable();
        builder.drop_column("legacy");
        builder.rename_column("name", "full_name");
        let statements = builder.build();
        assert_eq!(statements.len(), 3);
        assert!(statements[0].contains("ADD COLUMN"));
        assert!(statements[0].contains("`phone`"));
        assert!(statements[1].contains("DROP COLUMN"));
        assert!(statements[1].contains("`legacy`"));
        assert!(statements[2].contains("RENAME COLUMN"));
        assert!(statements[2].contains("`name`"));
        assert!(statements[2].contains("`full_name`"));
    }

    #[test]
    fn test_mariadb_quoting() {
        let schema = Schema::new(DatabaseType::MariaDB);
        assert_eq!(schema.quote_identifier("users"), "`users`");
        assert_eq!(schema.quote_identifier("email"), "`email`");
    }

    // Identifiers used by the `_migrations` bookkeeping statements.
    #[test]
    fn test_quote_migration_identifier() {
        assert_eq!(quote_migration_identifier("version", DatabaseType::MySQL), "`version`");
        assert_eq!(quote_migration_identifier("version", DatabaseType::MariaDB), "`version`");
        assert_eq!(quote_migration_identifier("version", DatabaseType::Postgres), "\"version\"");
        assert_eq!(quote_migration_identifier("_migrations", DatabaseType::SQLite), "\"_migrations\"");
    }

    #[test]
    fn test_migration_result_flags_and_display() {
        let mut result = MigrationResult::new();
        assert!(!result.has_applied());
        assert!(!result.has_rolled_back());
        assert_eq!(result.to_string(), "");
        result.applied.push(MigrationInfo {
            version: "001".to_string(),
            name: "create_users".to_string(),
        });
        result.skipped.push(MigrationInfo {
            version: "000".to_string(),
            name: "init".to_string(),
        });
        result.rolled_back.push(MigrationInfo {
            version: "002".to_string(),
            name: "add_posts".to_string(),
        });
        assert!(result.has_applied());
        assert!(result.has_rolled_back());
        assert_eq!(
            result.to_string(),
            "Applied migrations:\n ✓ 001 - create_users\n\
Skipped migrations (already applied):\n - 000 - init\n\
Rolled back migrations:\n ↩ 002 - add_posts\n"
        );
    }

    #[test]
    fn test_migration_status_display() {
        let applied = MigrationStatus {
            version: "001".to_string(),
            name: "create_users".to_string(),
            applied: true,
        };
        let pending = MigrationStatus {
            version: "002".to_string(),
            name: "add_posts".to_string(),
            applied: false,
        };
        assert_eq!(applied.to_string(), "[✓] 001 - create_users");
        assert_eq!(pending.to_string(), "[○] 002 - add_posts");
    }
}