use crate::{
BlueprintExecutor, MigrationError,
schema::render::{
default_index_name, infer_referenced_table, render_column, render_constraint,
render_foreign_key,
},
};
/// SQL dialect targeted when rendering schema statements.
///
/// Identifier quoting goes through `SchemaDialect::quote_ident` (defined outside this
/// section).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchemaDialect {
/// SQLite. `AlterTableBlueprint` emits one `alter table` statement per queued operation.
Sqlite,
/// PostgreSQL. Alter operations are combined into a single `alter table` statement.
Postgres,
/// MariaDB. Alter operations are combined into a single `alter table` statement.
MariaDb,
}
/// Logical column type chosen by the table/alter builders.
///
/// NOTE(review): the mapping to concrete SQL types per dialect presumably happens in
/// `render_column`; confirm there before relying on a specific spelling.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ColumnType {
/// Integer column (`integer`, `increments`).
Integer,
/// Big integer column (`bigint`, `big_increments`, `id`, `foreign_id`).
BigInt,
/// Boolean column (`boolean`).
Bool,
/// Fixed-length character column; the value is the length passed to `char`.
Char(u32),
/// Variable-length character column; the value is the length (`string` uses 255).
Varchar(u32),
/// Unbounded text column (`text`).
Text,
/// Date column (`date`).
Date,
/// Time-of-day column (`time`).
Time,
/// Date-and-time column (`datetime`).
DateTime,
/// Timestamp column (`timestamp`, `timestamps`).
Timestamp,
/// Decimal column as `(precision, scale)`, in the order passed to `decimal`.
Decimal(u32, u32),
/// Floating-point column (`float`).
Float,
/// Double-precision floating-point column (`double`).
Double,
/// JSON column (`json`).
Json,
/// UUID column (`uuid`).
Uuid,
/// Any other type; presumably the string is emitted as the SQL type verbatim — TODO confirm in `render_column`.
Custom(String),
}
/// One queued change in an `AlterTableBlueprint`. Operations are rendered in the order they were queued.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum TableAlterOperation {
/// Drop the column with this name.
DropColumn(String),
/// Add this column (rendered through `render_column`).
AddColumn(ColumnDef),
}
/// Table-level constraint collected by a `TableBlueprint` and rendered through `render_constraint`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum ConstraintDef {
/// Primary key spanning `columns` (added by `TableBlueprint::primary`).
Primary {
columns: Vec<String>,
},
/// Unique constraint over `columns`. `name` is `None` when added via `TableBlueprint::unique`.
Unique {
name: Option<String>,
columns: Vec<String>,
},
/// Check constraint with a raw SQL `expression`. `name` is `None` when added via `TableBlueprint::check`.
Check {
name: Option<String>,
expression: String,
},
}
/// A single column definition as accumulated by the table and alter builders.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct ColumnDef {
/// Column name exactly as passed to the builder.
pub(crate) name: String,
/// Logical column type.
pub(crate) ty: ColumnType,
/// `false` unless a builder's `nullable()` was called.
pub(crate) nullable: bool,
/// Column-level primary-key flag; set by `increments` / `big_increments`.
pub(crate) primary_key: bool,
/// Auto-increment flag; set by `increments` / `big_increments`.
pub(crate) auto_increment: bool,
/// Column-level unique flag; set by a builder's `unique()`.
pub(crate) unique: bool,
/// Raw SQL default text, stored as given (from `DefaultValue` or `default_raw`).
pub(crate) default_raw: Option<String>,
}
/// A column default expressed as raw SQL text.
///
/// No quoting or escaping is applied: the text is stored exactly as supplied and
/// ends up in the column's `default_raw`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DefaultValue {
    /// The raw SQL text of the default.
    pub(crate) sql: String,
}

impl DefaultValue {
    /// Wraps an arbitrary SQL fragment (e.g. `0`, `'pending'`) as a default value.
    pub fn raw(sql: impl Into<String>) -> Self {
        let sql: String = sql.into();
        DefaultValue { sql }
    }
}

/// Default value consisting of the SQL keyword `current_timestamp`.
pub fn current_timestamp() -> DefaultValue {
    DefaultValue::raw(String::from("current_timestamp"))
}
/// Foreign-key clause attached to a column of a `TableBlueprint`, rendered through `render_foreign_key`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct ForeignKeyDef {
/// The local column holding the reference.
pub(crate) column: String,
/// Referenced table. It can be empty when only a referential action was set on the builder.
pub(crate) references_table: String,
/// Referenced column; the builders default it to `"id"`.
pub(crate) references_column: String,
/// `on delete` action, if one was chosen.
pub(crate) on_delete: Option<ForeignKeyAction>,
/// `on update` action, if one was chosen.
pub(crate) on_update: Option<ForeignKeyAction>,
}
/// Referential action for a foreign key's `on delete` / `on update` clause.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ForeignKeyAction {
/// Selected by `cascade_on_delete` / `cascade_on_update`.
Cascade,
/// Selected by `restrict_on_delete` / `restrict_on_update`.
Restrict,
/// Selected by `null_on_delete` / `null_on_update`.
SetNull,
/// Selected by `no_action_on_delete` / `no_action_on_update`.
NoAction,
}
/// Description of an index on a table, optionally unique.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct IndexBlueprint {
/// Index name.
name: String,
/// Owning table; empty when built with `IndexBlueprint::named`.
table: String,
/// Indexed columns, in order.
columns: Vec<String>,
/// Whether `create_sql` emits `create unique index`.
unique: bool,
}
/// Accumulates the definition of a table to be created.
///
/// `create_statements` renders columns, then foreign keys, then constraints inside a
/// single `create table`. It then adds one `create index` statement per entry in `indexes`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TableBlueprint {
/// Table name.
name: String,
/// Columns in declaration order.
columns: Vec<ColumnDef>,
/// Foreign-key clauses, rendered after the columns.
foreign_keys: Vec<ForeignKeyDef>,
/// Table-level constraints, rendered after the foreign keys.
constraints: Vec<ConstraintDef>,
/// Indexes created by separate statements after `create table`.
indexes: Vec<IndexBlueprint>,
}
/// Accumulates column additions and drops for an existing table.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AlterTableBlueprint {
/// Name of the table being altered.
name: String,
/// Queued operations, rendered in insertion order.
operations: Vec<TableAlterOperation>,
}
/// Conversion from the "one or more column names" shapes accepted by the schema
/// builders into an owned, ordered list of column names.
///
/// Implemented for single names (`&str`, `String`), fixed-size arrays and `Vec`s of
/// names, and tuples of two to four `AsRef<str>` values.
pub trait IntoSchemaColumns {
    /// Consumes `self` and returns the column names in their original order.
    fn into_schema_columns(self) -> Vec<String>;
}

impl IntoSchemaColumns for &str {
    fn into_schema_columns(self) -> Vec<String> {
        vec![String::from(self)]
    }
}

impl IntoSchemaColumns for String {
    fn into_schema_columns(self) -> Vec<String> {
        Vec::from([self])
    }
}

impl<const N: usize> IntoSchemaColumns for [&str; N] {
    fn into_schema_columns(self) -> Vec<String> {
        self.iter().copied().map(String::from).collect()
    }
}

impl<const N: usize> IntoSchemaColumns for [String; N] {
    fn into_schema_columns(self) -> Vec<String> {
        Vec::from(self)
    }
}

impl IntoSchemaColumns for Vec<&str> {
    fn into_schema_columns(self) -> Vec<String> {
        self.iter().copied().map(String::from).collect()
    }
}

impl IntoSchemaColumns for Vec<String> {
    // Already the target representation; hand it back untouched.
    fn into_schema_columns(self) -> Vec<String> {
        self
    }
}

// Implements `IntoSchemaColumns` for a tuple whose elements are all `AsRef<str>`.
// Each argument is a `(TypeParam, binding)` pair; the bindings destructure the tuple
// positionally, so names come out in tuple order.
macro_rules! impl_into_schema_columns_tuple {
    ($(($param:ident, $binding:ident)),+ $(,)?) => {
        impl<$($param),+> IntoSchemaColumns for ($($param,)+)
        where
            $($param: AsRef<str>,)+
        {
            fn into_schema_columns(self) -> Vec<String> {
                let ($($binding,)+) = self;
                Vec::from([$(String::from($binding.as_ref())),+])
            }
        }
    };
}

impl_into_schema_columns_tuple!((A, a), (B, b));
impl_into_schema_columns_tuple!((A, a), (B, b), (C, c));
impl_into_schema_columns_tuple!((A, a), (B, b), (C, c), (D, d));
impl TableBlueprint {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
columns: Vec::new(),
foreign_keys: Vec::new(),
constraints: Vec::new(),
indexes: Vec::new(),
}
}
pub fn id(&mut self) {
self.big_increments("id");
}
pub fn increments(&mut self, name: &str) {
self.columns.push(ColumnDef {
name: name.to_owned(),
ty: ColumnType::Integer,
nullable: false,
primary_key: true,
auto_increment: true,
unique: false,
default_raw: None,
});
}
pub fn big_increments(&mut self, name: &str) {
self.columns.push(ColumnDef {
name: name.to_owned(),
ty: ColumnType::BigInt,
nullable: false,
primary_key: true,
auto_increment: true,
unique: false,
default_raw: None,
});
}
pub fn string(&mut self, name: &str) -> ColumnBuilder<'_> {
self.push_column(name, ColumnType::Varchar(255))
}
pub fn char(&mut self, name: &str, len: u32) -> ColumnBuilder<'_> {
self.push_column(name, ColumnType::Char(len))
}
pub fn varchar(&mut self, name: &str, len: u32) -> ColumnBuilder<'_> {
self.push_column(name, ColumnType::Varchar(len))
}
pub fn text(&mut self, name: &str) -> ColumnBuilder<'_> {
self.push_column(name, ColumnType::Text)
}
pub fn integer(&mut self, name: &str) -> ColumnBuilder<'_> {
self.push_column(name, ColumnType::Integer)
}
pub fn bigint(&mut self, name: &str) -> ColumnBuilder<'_> {
self.push_column(name, ColumnType::BigInt)
}
pub fn boolean(&mut self, name: &str) -> ColumnBuilder<'_> {
self.push_column(name, ColumnType::Bool)
}
pub fn date(&mut self, name: &str) -> ColumnBuilder<'_> {
self.push_column(name, ColumnType::Date)
}
pub fn time(&mut self, name: &str) -> ColumnBuilder<'_> {
self.push_column(name, ColumnType::Time)
}
pub fn datetime(&mut self, name: &str) -> ColumnBuilder<'_> {
self.push_column(name, ColumnType::DateTime)
}
pub fn timestamp(&mut self, name: &str) -> ColumnBuilder<'_> {
self.push_column(name, ColumnType::Timestamp)
}
pub fn decimal(&mut self, name: &str, precision: u32, scale: u32) -> ColumnBuilder<'_> {
self.push_column(name, ColumnType::Decimal(precision, scale))
}
pub fn float(&mut self, name: &str) -> ColumnBuilder<'_> {
self.push_column(name, ColumnType::Float)
}
pub fn double(&mut self, name: &str) -> ColumnBuilder<'_> {
self.push_column(name, ColumnType::Double)
}
pub fn json(&mut self, name: &str) -> ColumnBuilder<'_> {
self.push_column(name, ColumnType::Json)
}
pub fn uuid(&mut self, name: &str) -> ColumnBuilder<'_> {
self.push_column(name, ColumnType::Uuid)
}
pub(crate) fn custom(&mut self, name: &str, ty: ColumnType) -> ColumnBuilder<'_> {
self.push_column(name, ty)
}
pub fn timestamps(&mut self) {
self.timestamp("created_at").default(current_timestamp());
self.timestamp("updated_at").default(current_timestamp());
}
pub fn remember_token(&mut self) -> ColumnBuilder<'_> {
self.push_column("remember_token", ColumnType::Varchar(100))
.nullable()
}
pub fn foreign_id(&mut self, column: &str) -> ForeignKeyBuilder<'_> {
self.foreign(column, ColumnType::BigInt)
}
pub fn foreign(&mut self, column: &str, ty: ColumnType) -> ForeignKeyBuilder<'_> {
self.columns.push(ColumnDef {
name: column.to_owned(),
ty,
nullable: false,
primary_key: false,
auto_increment: false,
unique: false,
default_raw: None,
});
let index = self.columns.len() - 1;
ForeignKeyBuilder {
table: self,
index,
foreign_key: None,
}
}
pub fn unique<I>(&mut self, columns: I)
where
I: IntoSchemaColumns,
{
self.constraints.push(ConstraintDef::Unique {
name: None,
columns: columns.into_schema_columns(),
});
}
pub fn primary<I>(&mut self, columns: I)
where
I: IntoSchemaColumns,
{
self.constraints.push(ConstraintDef::Primary {
columns: columns.into_schema_columns(),
});
}
pub fn unique_named<I>(&mut self, name: &str, columns: I)
where
I: IntoSchemaColumns,
{
self.constraints.push(ConstraintDef::Unique {
name: Some(name.to_owned()),
columns: columns.into_schema_columns(),
});
}
pub fn check(&mut self, expression: &str) {
self.constraints.push(ConstraintDef::Check {
name: None,
expression: expression.to_owned(),
});
}
pub fn constraint(&mut self, name: &str) -> ConstraintBuilder<'_> {
ConstraintBuilder {
table: self,
name: name.to_owned(),
}
}
pub fn check_named(&mut self, name: &str, expression: &str) {
self.constraints.push(ConstraintDef::Check {
name: Some(name.to_owned()),
expression: expression.to_owned(),
});
}
pub fn index<I>(&mut self, name: &str, columns: I)
where
I: IntoSchemaColumns,
{
self.indexes
.push(IndexBlueprint::new(name, self.name.as_str(), columns));
}
pub fn unique_index<I>(&mut self, name: &str, columns: I)
where
I: IntoSchemaColumns,
{
self.indexes.push(IndexBlueprint::new_unique(
name,
self.name.as_str(),
columns,
));
}
fn push_column(&mut self, name: &str, ty: ColumnType) -> ColumnBuilder<'_> {
self.columns.push(ColumnDef {
name: name.to_owned(),
ty,
nullable: false,
primary_key: false,
auto_increment: false,
unique: false,
default_raw: None,
});
let index = self.columns.len() - 1;
ColumnBuilder { table: self, index }
}
pub fn create_sql(&self, dialect: SchemaDialect) -> String {
self.create_statements(dialect)
.into_iter()
.next()
.expect("create table statements are never empty")
}
pub fn create_statements(&self, dialect: SchemaDialect) -> Vec<String> {
let name = dialect.quote_ident(&self.name);
let mut defs = self
.columns
.iter()
.map(|column| render_column(dialect, column))
.collect::<Vec<_>>();
defs.extend(
self.foreign_keys
.iter()
.map(|foreign| render_foreign_key(dialect, foreign)),
);
defs.extend(
self.constraints
.iter()
.map(|constraint| render_constraint(dialect, constraint)),
);
let mut statements = vec![format!("create table {name} ({});", defs.join(", "))];
statements.extend(self.indexes.iter().map(|index| index.create_sql(dialect)));
statements
}
pub fn drop_sql(&self, dialect: SchemaDialect) -> String {
format!("drop table if exists {};", dialect.quote_ident(&self.name))
}
pub async fn create<C>(self, ctx: &mut C) -> Result<(), MigrationError>
where
C: BlueprintExecutor,
{
for sql in self.create_statements(ctx.dialect()) {
ctx.execute_raw_blueprint(&sql).await?;
}
Ok(())
}
}
impl IndexBlueprint {
pub fn named(name: &str) -> Self {
Self {
name: name.to_owned(),
table: String::new(),
columns: Vec::new(),
unique: false,
}
}
pub fn new<I>(name: &str, table: &str, columns: I) -> Self
where
I: IntoSchemaColumns,
{
Self {
name: name.to_owned(),
table: table.to_owned(),
columns: columns.into_schema_columns(),
unique: false,
}
}
pub fn new_unique<I>(name: &str, table: &str, columns: I) -> Self
where
I: IntoSchemaColumns,
{
Self {
name: name.to_owned(),
table: table.to_owned(),
columns: columns.into_schema_columns(),
unique: true,
}
}
pub fn create_sql(&self, dialect: SchemaDialect) -> String {
let unique = if self.unique { "unique " } else { "" };
let name = dialect.quote_ident(&self.name);
let table = dialect.quote_ident(&self.table);
let columns = self
.columns
.iter()
.map(|column| dialect.quote_ident(column))
.collect::<Vec<_>>()
.join(", ");
format!("create {unique}index {name} on {table} ({columns});")
}
pub fn drop_sql(&self, dialect: SchemaDialect) -> String {
format!("drop index if exists {};", dialect.quote_ident(&self.name))
}
}
impl AlterTableBlueprint {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
operations: Vec::new(),
}
}
pub fn drop_column(&mut self, name: &str) {
self.operations
.push(TableAlterOperation::DropColumn(name.to_owned()));
}
pub fn string(&mut self, name: &str) -> AlterColumnBuilder<'_> {
self.push_column(name, ColumnType::Varchar(255))
}
pub fn text(&mut self, name: &str) -> AlterColumnBuilder<'_> {
self.push_column(name, ColumnType::Text)
}
pub fn varchar(&mut self, name: &str, len: u32) -> AlterColumnBuilder<'_> {
self.push_column(name, ColumnType::Varchar(len))
}
pub fn integer(&mut self, name: &str) -> AlterColumnBuilder<'_> {
self.push_column(name, ColumnType::Integer)
}
pub fn bigint(&mut self, name: &str) -> AlterColumnBuilder<'_> {
self.push_column(name, ColumnType::BigInt)
}
pub fn boolean(&mut self, name: &str) -> AlterColumnBuilder<'_> {
self.push_column(name, ColumnType::Bool)
}
pub fn date(&mut self, name: &str) -> AlterColumnBuilder<'_> {
self.push_column(name, ColumnType::Date)
}
pub fn time(&mut self, name: &str) -> AlterColumnBuilder<'_> {
self.push_column(name, ColumnType::Time)
}
pub fn datetime(&mut self, name: &str) -> AlterColumnBuilder<'_> {
self.push_column(name, ColumnType::DateTime)
}
pub fn timestamp(&mut self, name: &str) -> AlterColumnBuilder<'_> {
self.push_column(name, ColumnType::Timestamp)
}
pub fn uuid(&mut self, name: &str) -> AlterColumnBuilder<'_> {
self.push_column(name, ColumnType::Uuid)
}
pub(crate) fn custom(&mut self, name: &str, ty: ColumnType) -> AlterColumnBuilder<'_> {
self.push_column(name, ty)
}
pub fn decimal(&mut self, name: &str, precision: u32, scale: u32) -> AlterColumnBuilder<'_> {
self.push_column(name, ColumnType::Decimal(precision, scale))
}
pub fn float(&mut self, name: &str) -> AlterColumnBuilder<'_> {
self.push_column(name, ColumnType::Float)
}
pub fn double(&mut self, name: &str) -> AlterColumnBuilder<'_> {
self.push_column(name, ColumnType::Double)
}
pub fn json(&mut self, name: &str) -> AlterColumnBuilder<'_> {
self.push_column(name, ColumnType::Json)
}
pub fn drop_columns<const N: usize>(&mut self, names: [&str; N]) {
for name in names {
self.drop_column(name);
}
}
pub fn drop_timestamps(&mut self) {
self.drop_columns(["created_at", "updated_at"]);
}
fn push_column(&mut self, name: &str, ty: ColumnType) -> AlterColumnBuilder<'_> {
self.operations
.push(TableAlterOperation::AddColumn(ColumnDef {
name: name.to_owned(),
ty,
nullable: false,
primary_key: false,
auto_increment: false,
unique: false,
default_raw: None,
}));
let index = self.operations.len() - 1;
AlterColumnBuilder { table: self, index }
}
pub(crate) fn sql_statements(&self, dialect: SchemaDialect) -> Vec<String> {
if self.operations.is_empty() {
return Vec::new();
}
let table_name = dialect.quote_ident(&self.name);
match dialect {
SchemaDialect::Sqlite => self
.operations
.iter()
.map(|operation| match operation {
TableAlterOperation::DropColumn(name) => format!(
"alter table {table_name} drop column {};",
dialect.quote_ident(name)
),
TableAlterOperation::AddColumn(column) => format!(
"alter table {table_name} add column {};",
render_column(dialect, column)
),
})
.collect(),
SchemaDialect::Postgres | SchemaDialect::MariaDb => {
let actions = self
.operations
.iter()
.map(|operation| match operation {
TableAlterOperation::DropColumn(name) => {
format!("drop column {}", dialect.quote_ident(name))
}
TableAlterOperation::AddColumn(column) => {
format!("add column {}", render_column(dialect, column))
}
})
.collect::<Vec<_>>();
vec![format!("alter table {table_name} {};", actions.join(", "))]
}
}
}
}
/// Modifier handle for a column just added to a [`TableBlueprint`].
pub struct ColumnBuilder<'a> {
    table: &'a mut TableBlueprint,
    /// Position of the column in `table.columns`.
    index: usize,
}

/// Modifier handle for a column queued by an [`AlterTableBlueprint`].
pub struct AlterColumnBuilder<'a> {
    table: &'a mut AlterTableBlueprint,
    /// Position of the `AddColumn` operation in `table.operations`.
    index: usize,
}

/// Pending named constraint.
///
/// Nothing is recorded on the table until `unique` or `check` is called, so
/// dropping this builder unused is always a mistake.
#[must_use = "a named constraint is only recorded once `unique` or `check` is called"]
pub struct ConstraintBuilder<'a> {
    table: &'a mut TableBlueprint,
    /// Name given to the constraint once it is recorded.
    name: String,
}
impl<'a> ColumnBuilder<'a> {
pub fn index(self) -> Self {
let table_name = self.table.name.clone();
let column_name = self.table.columns[self.index].name.clone();
let name = default_index_name(&table_name, &column_name);
self.table.indexes.push(IndexBlueprint::new(
&name,
&table_name,
[column_name.as_str()],
));
self
}
pub fn default(self, value: DefaultValue) -> Self {
self.table.columns[self.index].default_raw = Some(value.sql);
self
}
pub fn nullable(self) -> Self {
self.table.columns[self.index].nullable = true;
self
}
pub fn unique(self) -> Self {
self.table.columns[self.index].unique = true;
self
}
pub fn default_raw(self, value: &str) -> Self {
self.table.columns[self.index].default_raw = Some(value.to_owned());
self
}
}
impl<'a> AlterColumnBuilder<'a> {
pub fn default(self, value: DefaultValue) -> Self {
if let TableAlterOperation::AddColumn(column) = &mut self.table.operations[self.index] {
column.default_raw = Some(value.sql);
}
self
}
pub fn nullable(self) -> Self {
if let TableAlterOperation::AddColumn(column) = &mut self.table.operations[self.index] {
column.nullable = true;
}
self
}
pub fn unique(self) -> Self {
if let TableAlterOperation::AddColumn(column) = &mut self.table.operations[self.index] {
column.unique = true;
}
self
}
pub fn default_raw(self, value: &str) -> Self {
if let TableAlterOperation::AddColumn(column) = &mut self.table.operations[self.index] {
column.default_raw = Some(value.to_owned());
}
self
}
}
impl<'a> ConstraintBuilder<'a> {
pub fn unique<I>(self, columns: I)
where
I: IntoSchemaColumns,
{
self.table.constraints.push(ConstraintDef::Unique {
name: Some(self.name),
columns: columns.into_schema_columns(),
});
}
pub fn check(self, expression: &str) {
self.table.constraints.push(ConstraintDef::Check {
name: Some(self.name),
expression: expression.to_owned(),
});
}
}
pub struct ForeignKeyBuilder<'a> {
table: &'a mut TableBlueprint,
index: usize,
foreign_key: Option<usize>,
}
impl<'a> ForeignKeyBuilder<'a> {
pub fn index(self) -> Self {
let table_name = self.table.name.clone();
let column_name = self.table.columns[self.index].name.clone();
let name = default_index_name(&table_name, &column_name);
self.table.indexes.push(IndexBlueprint::new(
&name,
&table_name,
[column_name.as_str()],
));
self
}
pub fn default(self, value: DefaultValue) -> Self {
self.table.columns[self.index].default_raw = Some(value.sql);
self
}
pub fn nullable(self) -> Self {
self.table.columns[self.index].nullable = true;
self
}
pub fn unique(self) -> Self {
self.table.columns[self.index].unique = true;
self
}
pub fn default_raw(self, value: &str) -> Self {
self.table.columns[self.index].default_raw = Some(value.to_owned());
self
}
pub fn constrained(self) -> Self {
let column = self.table.columns[self.index].name.clone();
let referenced_table = infer_referenced_table(&column);
self.references(&referenced_table)
}
pub fn references(self, table: &str) -> Self {
self.references_column(table, "id")
}
pub fn references_column(mut self, table: &str, column: &str) -> Self {
let foreign_key = ForeignKeyDef {
column: self.table.columns[self.index].name.clone(),
references_table: table.to_owned(),
references_column: column.to_owned(),
on_delete: None,
on_update: None,
};
match self.foreign_key {
Some(index) => self.table.foreign_keys[index] = foreign_key,
None => {
self.table.foreign_keys.push(foreign_key);
self.foreign_key = Some(self.table.foreign_keys.len() - 1);
}
}
self
}
pub fn cascade_on_delete(self) -> Self {
self.with_on_delete(ForeignKeyAction::Cascade)
}
pub fn restrict_on_delete(self) -> Self {
self.with_on_delete(ForeignKeyAction::Restrict)
}
pub fn null_on_delete(self) -> Self {
self.with_on_delete(ForeignKeyAction::SetNull)
}
pub fn no_action_on_delete(self) -> Self {
self.with_on_delete(ForeignKeyAction::NoAction)
}
pub fn cascade_on_update(self) -> Self {
self.with_on_update(ForeignKeyAction::Cascade)
}
pub fn restrict_on_update(self) -> Self {
self.with_on_update(ForeignKeyAction::Restrict)
}
pub fn null_on_update(self) -> Self {
self.with_on_update(ForeignKeyAction::SetNull)
}
pub fn no_action_on_update(self) -> Self {
self.with_on_update(ForeignKeyAction::NoAction)
}
fn with_on_delete(mut self, action: ForeignKeyAction) -> Self {
let index = self.ensure_foreign_key();
self.table.foreign_keys[index].on_delete = Some(action);
self
}
fn with_on_update(mut self, action: ForeignKeyAction) -> Self {
let index = self.ensure_foreign_key();
self.table.foreign_keys[index].on_update = Some(action);
self
}
fn ensure_foreign_key(&mut self) -> usize {
if let Some(index) = self.foreign_key {
index
} else {
self.table.foreign_keys.push(ForeignKeyDef {
column: self.table.columns[self.index].name.clone(),
references_table: String::new(),
references_column: "id".to_owned(),
on_delete: None,
on_update: None,
});
let index = self.table.foreign_keys.len() - 1;
self.foreign_key = Some(index);
index
}
}
}