pub mod fields;
pub mod models;
pub mod postgres;
pub mod special;
mod to_tokens;
pub use fields::{AddField, AlterField, RemoveField, RenameField};
pub use models::{CreateModel, DeleteModel, FieldDefinition, MoveModel, RenameModel};
pub use postgres::{CreateCollation, CreateExtension, DropExtension};
pub use special::{RunCode, RunSQL, StateOperation};
use super::{FieldState, FieldType, ModelState, ProjectState};
use pg_escape::{quote_identifier, quote_literal};
use reinhardt_query::prelude::{
Alias, AlterTableStatement, ColumnDef, CreateIndexStatement, CreateTableStatement,
DropIndexStatement, DropTableStatement, Query, SimpleExpr, Value,
};
use serde::{Deserialize, Serialize};
/// Index access method requested for a `CREATE INDEX` operation.
///
/// On PostgreSQL/CockroachDB every method except `BTree` is rendered as a `USING <method>`
/// clause. On MySQL, `Fulltext` and `Spatial` become `CREATE FULLTEXT/SPATIAL INDEX` prefixes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum IndexType {
    /// B-tree (the default access method).
    #[default]
    BTree,
    /// Hash index.
    Hash,
    /// Generalized inverted index.
    Gin,
    /// Generalized search tree.
    Gist,
    /// Block range index.
    Brin,
    /// MySQL full-text index.
    Fulltext,
    /// MySQL spatial index.
    Spatial,
}
impl std::fmt::Display for IndexType {
    /// Writes the lowercase method name, e.g. `gin`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let method = match self {
            Self::BTree => "btree",
            Self::Hash => "hash",
            Self::Gin => "gin",
            Self::Gist => "gist",
            Self::Brin => "brin",
            Self::Fulltext => "fulltext",
            Self::Spatial => "spatial",
        };
        f.write_str(method)
    }
}
/// Value of MySQL's `ALGORITHM=` table-alteration option.
///
/// `Default` is never rendered by `AlterTableOptions::to_sql_suffix`; it means "let the server choose".
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "UPPERCASE")]
pub enum MySqlAlgorithm {
    Instant,
    Inplace,
    Copy,
    #[default]
    Default,
}
impl std::fmt::Display for MySqlAlgorithm {
    /// Writes the SQL keyword, e.g. `INPLACE`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Instant => "INSTANT",
            Self::Inplace => "INPLACE",
            Self::Copy => "COPY",
            Self::Default => "DEFAULT",
        })
    }
}
/// Value of MySQL's `LOCK=` table-alteration option.
///
/// `Default` is never rendered by `AlterTableOptions::to_sql_suffix`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "UPPERCASE")]
pub enum MySqlLock {
    None,
    Shared,
    Exclusive,
    #[default]
    Default,
}
impl std::fmt::Display for MySqlLock {
    /// Writes the SQL keyword, e.g. `SHARED`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            Self::None => "NONE",
            Self::Shared => "SHARED",
            Self::Exclusive => "EXCLUSIVE",
            Self::Default => "DEFAULT",
        };
        f.write_str(keyword)
    }
}
/// Optional MySQL online-DDL knobs (`ALGORITHM=` / `LOCK=`) attached to an operation.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub struct AlterTableOptions {
    /// Requested `ALGORITHM=` value, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub algorithm: Option<MySqlAlgorithm>,
    /// Requested `LOCK=` value, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub lock: Option<MySqlLock>,
}
impl AlterTableOptions {
    /// Options with neither knob set.
    pub fn new() -> Self {
        Self::default()
    }
    /// Builder: sets the `ALGORITHM=` value.
    pub fn with_algorithm(self, algorithm: MySqlAlgorithm) -> Self {
        Self {
            algorithm: Some(algorithm),
            ..self
        }
    }
    /// Builder: sets the `LOCK=` value.
    pub fn with_lock(self, lock: MySqlLock) -> Self {
        Self {
            lock: Some(lock),
            ..self
        }
    }
    /// True when neither knob has been set.
    pub fn is_empty(&self) -> bool {
        !(self.algorithm.is_some() || self.lock.is_some())
    }
    /// Renders the options as an `ALTER TABLE` suffix such as `", ALGORITHM=INPLACE, LOCK=NONE"`.
    ///
    /// Knobs that are unset or explicitly `Default` are omitted. Returns an empty string when
    /// nothing remains.
    pub fn to_sql_suffix(&self) -> String {
        let algorithm = self
            .algorithm
            .filter(|a| *a != MySqlAlgorithm::Default)
            .map(|a| format!("ALGORITHM={}", a));
        let lock = self
            .lock
            .filter(|l| *l != MySqlLock::Default)
            .map(|l| format!("LOCK={}", l));
        let clauses: Vec<String> = algorithm.into_iter().chain(lock).collect();
        if clauses.is_empty() {
            return String::new();
        }
        format!(", {}", clauses.join(", "))
    }
}
/// MySQL partitioning method used in `PARTITION BY <TYPE>(...)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum PartitionType {
    Range,
    List,
    Hash,
    Key,
}
impl std::fmt::Display for PartitionType {
    /// Writes the uppercase SQL keyword, e.g. `RANGE`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            Self::Range => "RANGE",
            Self::List => "LIST",
            Self::Hash => "HASH",
            Self::Key => "KEY",
        };
        f.write_str(keyword)
    }
}
/// Bound specification for a single partition.
///
/// Serialized adjacently tagged (`{"type": "LessThan", "value": "100"}`). serde's internally
/// tagged representation (`tag` without `content`) cannot carry newtype variants wrapping a
/// string, sequence, or integer, and fails at runtime for every variant of this enum.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "type", content = "value")]
pub enum PartitionValues {
    /// `VALUES LESS THAN (...)`. The literal string `"MAXVALUE"` renders as the MAXVALUE keyword.
    LessThan(String),
    /// `VALUES IN (...)` for LIST partitioning.
    In(Vec<String>),
    /// Number of partitions for HASH/KEY partitioning (`PARTITIONS n`).
    ModuloCount(u32),
}
/// One named partition and its bound.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PartitionDef {
    /// Partition name (emitted verbatim after `PARTITION`).
    pub name: String,
    /// Bound that decides which rows land in this partition.
    pub values: PartitionValues,
}
impl PartitionDef {
    /// Creates a partition with an explicit bound.
    pub fn new(name: impl Into<String>, values: PartitionValues) -> Self {
        let name = name.into();
        Self { name, values }
    }
    /// RANGE partition holding rows strictly below `value`.
    pub fn less_than(name: impl Into<String>, value: impl Into<String>) -> Self {
        let bound = PartitionValues::LessThan(value.into());
        Self::new(name, bound)
    }
    /// Catch-all RANGE partition (`VALUES LESS THAN MAXVALUE`).
    pub fn maxvalue(name: impl Into<String>) -> Self {
        Self::less_than(name, "MAXVALUE")
    }
    /// LIST partition holding rows whose key is one of `values`.
    pub fn list_in(name: impl Into<String>, values: Vec<String>) -> Self {
        let bound = PartitionValues::In(values);
        Self::new(name, bound)
    }
}
/// CockroachDB `INTERLEAVE IN PARENT` target for a `CREATE TABLE` operation.
///
/// Only rendered when the dialect is CockroachDB; other dialects ignore it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct InterleaveSpec {
    /// Parent table name (identifier-quoted when rendered).
    pub parent_table: String,
    /// Parent key columns to interleave on (each identifier-quoted when rendered).
    pub parent_columns: Vec<String>,
}
/// MySQL table partitioning specification, rendered by [`PartitionOptions::to_sql`].
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PartitionOptions {
    /// Partitioning method.
    pub partition_type: PartitionType,
    /// Partitioning column or expression, emitted verbatim inside `PARTITION BY TYPE(...)`.
    pub column: String,
    /// Partition definitions. HASH/KEY use a single `ModuloCount` entry holding the count.
    pub partitions: Vec<PartitionDef>,
}
impl PartitionOptions {
    /// Creates a partitioning spec from its parts.
    pub fn new(
        partition_type: PartitionType,
        column: impl Into<String>,
        partitions: Vec<PartitionDef>,
    ) -> Self {
        Self {
            partition_type,
            column: column.into(),
            partitions,
        }
    }
    /// `PARTITION BY RANGE(column)` with the given partitions.
    pub fn range(column: impl Into<String>, partitions: Vec<PartitionDef>) -> Self {
        Self::new(PartitionType::Range, column, partitions)
    }
    /// `PARTITION BY LIST(column)` with the given partitions.
    pub fn list(column: impl Into<String>, partitions: Vec<PartitionDef>) -> Self {
        Self::new(PartitionType::List, column, partitions)
    }
    /// `PARTITION BY HASH(column) PARTITIONS num_partitions`.
    pub fn hash(column: impl Into<String>, num_partitions: u32) -> Self {
        Self::new(
            PartitionType::Hash,
            column,
            vec![PartitionDef::new(
                "",
                PartitionValues::ModuloCount(num_partitions),
            )],
        )
    }
    /// `PARTITION BY KEY(column) PARTITIONS num_partitions`.
    pub fn key(column: impl Into<String>, num_partitions: u32) -> Self {
        Self::new(
            PartitionType::Key,
            column,
            vec![PartitionDef::new(
                "",
                PartitionValues::ModuloCount(num_partitions),
            )],
        )
    }
    /// Renders the `PARTITION BY ...` clause appended to `CREATE TABLE`.
    ///
    /// Integer bound values are emitted bare. Any other bound is emitted as a single-quoted
    /// string literal with embedded quotes doubled.
    pub fn to_sql(&self) -> String {
        /// Renders one RANGE/LIST bound value.
        ///
        /// Plain RANGE/LIST partitioning (without COLUMNS) only accepts INT-typed bound values;
        /// a quoted `'100'` is rejected. Integers are therefore written bare. Everything else
        /// keeps the string-literal form, with `'` escaped so the statement stays well-formed.
        fn bound_literal(value: &str) -> String {
            if value.parse::<i128>().is_ok() {
                value.to_string()
            } else {
                format!("'{}'", value.replace('\'', "''"))
            }
        }

        let mut sql = format!("PARTITION BY {}({})", self.partition_type, self.column);
        match self.partition_type {
            PartitionType::Hash | PartitionType::Key => {
                if let Some(p) = self.partitions.first()
                    && let PartitionValues::ModuloCount(n) = &p.values
                {
                    sql.push_str(&format!(" PARTITIONS {}", n));
                }
            }
            PartitionType::Range | PartitionType::List => {
                let defs: Vec<String> = self
                    .partitions
                    .iter()
                    .map(|p| {
                        let vals = match &p.values {
                            PartitionValues::LessThan(v) if v == "MAXVALUE" => {
                                "VALUES LESS THAN MAXVALUE".to_string()
                            }
                            PartitionValues::LessThan(v) => {
                                format!("VALUES LESS THAN ({})", bound_literal(v))
                            }
                            PartitionValues::In(v) => format!(
                                "VALUES IN ({})",
                                v.iter()
                                    .map(|x| bound_literal(x))
                                    .collect::<Vec<_>>()
                                    .join(", ")
                            ),
                            PartitionValues::ModuloCount(_) => String::new(),
                        };
                        format!("PARTITION {} {}", p.name, vals)
                    })
                    .collect();
                sql.push_str(" (");
                sql.push_str(&defs.join(", "));
                sql.push(')');
            }
        }
        sql
    }
}
/// Deferrability mode appended to foreign-key constraints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum DeferrableOption {
    /// `DEFERRABLE INITIALLY IMMEDIATE`
    Immediate,
    /// `DEFERRABLE INITIALLY DEFERRED`
    Deferred,
}
impl std::fmt::Display for DeferrableOption {
    /// Writes the full `DEFERRABLE INITIALLY ...` clause.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let clause = if matches!(self, Self::Immediate) {
            "DEFERRABLE INITIALLY IMMEDIATE"
        } else {
            "DEFERRABLE INITIALLY DEFERRED"
        };
        f.write_str(clause)
    }
}
/// Table-level constraint. Its `Display` form is the SQL fragment placed inside `CREATE TABLE`.
///
/// Names, columns and tables are written verbatim (not identifier-quoted).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
#[serde(tag = "type")]
pub enum Constraint {
    /// `CONSTRAINT name PRIMARY KEY (cols)`
    PrimaryKey {
        name: String,
        columns: Vec<String>,
    },
    /// `CONSTRAINT name FOREIGN KEY (cols) REFERENCES table(cols) ON DELETE .. ON UPDATE ..`,
    /// optionally followed by a deferrability clause.
    ForeignKey {
        name: String,
        columns: Vec<String>,
        referenced_table: String,
        referenced_columns: Vec<String>,
        on_delete: super::ForeignKeyAction,
        on_update: super::ForeignKeyAction,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        deferrable: Option<DeferrableOption>,
    },
    /// `CONSTRAINT name UNIQUE (cols)`
    Unique {
        name: String,
        columns: Vec<String>,
    },
    /// `CONSTRAINT name CHECK (expression)`; the expression is inserted verbatim.
    Check {
        name: String,
        expression: String,
    },
    /// Single-column foreign key plus a `<name>_unique` UNIQUE constraint on the same column.
    OneToOne {
        name: String,
        column: String,
        referenced_table: String,
        referenced_column: String,
        on_delete: super::ForeignKeyAction,
        on_update: super::ForeignKeyAction,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        deferrable: Option<DeferrableOption>,
    },
    /// Many-to-many relation through a join table; renders only as a `--` SQL comment.
    ManyToMany {
        name: String,
        through_table: String,
        source_column: String,
        target_column: String,
        target_table: String,
    },
    /// PostgreSQL exclusion constraint; `elements` are `(column, operator)` pairs and the
    /// index method defaults to `gist`.
    Exclude {
        name: String,
        elements: Vec<(String, String)>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        using: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        where_clause: Option<String>,
    },
}
impl std::fmt::Display for Constraint {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Builds the `CONSTRAINT .. FOREIGN KEY .. REFERENCES ..` clause shared by the
        /// `ForeignKey` and `OneToOne` variants, including the optional deferrability suffix.
        fn foreign_key_clause(
            name: &str,
            local_columns: &str,
            target_table: &str,
            target_columns: &str,
            on_delete: &super::ForeignKeyAction,
            on_update: &super::ForeignKeyAction,
            deferrable: Option<DeferrableOption>,
        ) -> String {
            let mut clause = format!(
                "CONSTRAINT {} FOREIGN KEY ({}) REFERENCES {}({}) ON DELETE {} ON UPDATE {}",
                name,
                local_columns,
                target_table,
                target_columns,
                on_delete.to_sql_keyword(),
                on_update.to_sql_keyword()
            );
            if let Some(mode) = deferrable {
                clause.push(' ');
                clause.push_str(&mode.to_string());
            }
            clause
        }

        match self {
            Self::PrimaryKey { name, columns } => write!(
                f,
                "CONSTRAINT {} PRIMARY KEY ({})",
                name,
                columns.join(", ")
            ),
            Self::Unique { name, columns } => {
                write!(f, "CONSTRAINT {} UNIQUE ({})", name, columns.join(", "))
            }
            Self::Check { name, expression } => {
                write!(f, "CONSTRAINT {} CHECK ({})", name, expression)
            }
            Self::ForeignKey {
                name,
                columns,
                referenced_table,
                referenced_columns,
                on_delete,
                on_update,
                deferrable,
            } => {
                let clause = foreign_key_clause(
                    name,
                    &columns.join(", "),
                    referenced_table,
                    &referenced_columns.join(", "),
                    on_delete,
                    on_update,
                    *deferrable,
                );
                f.write_str(&clause)
            }
            Self::OneToOne {
                name,
                column,
                referenced_table,
                referenced_column,
                on_delete,
                on_update,
                deferrable,
            } => {
                let clause = foreign_key_clause(
                    name,
                    column,
                    referenced_table,
                    referenced_column,
                    on_delete,
                    on_update,
                    *deferrable,
                );
                write!(f, "{}, CONSTRAINT {}_unique UNIQUE ({})", clause, name, column)
            }
            Self::ManyToMany { through_table, .. } => {
                write!(f, "-- ManyToMany via {}", through_table)
            }
            Self::Exclude {
                name,
                elements,
                using,
                where_clause,
            } => {
                let element_list = elements
                    .iter()
                    .map(|(col, op)| format!("{} WITH {}", col, op))
                    .collect::<Vec<_>>()
                    .join(", ");
                let method = using.as_deref().unwrap_or("gist");
                write!(
                    f,
                    "CONSTRAINT {} EXCLUDE USING {} ({})",
                    name, method, element_list
                )?;
                if let Some(predicate) = where_clause {
                    write!(f, " WHERE ({})", predicate)?;
                }
                Ok(())
            }
        }
    }
}
/// Origin of the data for a `BulkLoad` operation.
///
/// Serialized adjacently tagged: the variant name goes in `type` and any payload in `value`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", content = "value")]
pub enum BulkLoadSource {
    /// Presumably a file path to load from — confirm against `bulk_load_to_sql`.
    File(String),
    /// Read from standard input.
    Stdin,
    /// Presumably a program/command whose output is loaded — confirm against `bulk_load_to_sql`.
    Program(String),
}
/// Input format of a bulk load; `Text` when unspecified.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum BulkLoadFormat {
    #[default]
    Text,
    Csv,
    Binary,
}
impl std::fmt::Display for BulkLoadFormat {
    /// Writes the uppercase format keyword, e.g. `CSV`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Text => "TEXT",
            Self::Csv => "CSV",
            Self::Binary => "BINARY",
        })
    }
}
/// Tuning options for a `BulkLoad` operation, built with the `with_*` builder methods.
#[non_exhaustive]
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct BulkLoadOptions {
    /// Field delimiter character.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub delimiter: Option<char>,
    /// String that represents NULL in the input.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub null_string: Option<String>,
    /// Whether the input starts with a header row.
    #[serde(default)]
    pub header: bool,
    /// Explicit target column list.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub columns: Option<Vec<String>>,
    /// Load flag named `local` (presumably MySQL `LOCAL INFILE` — confirm in `bulk_load_to_sql`).
    #[serde(default)]
    pub local: bool,
    /// Quote character.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub quote: Option<char>,
    /// Escape character.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub escape: Option<char>,
    /// Line terminator sequence.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub line_terminator: Option<String>,
    /// Input character encoding name.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub encoding: Option<String>,
}
impl BulkLoadOptions {
    /// Options with every field at its default.
    pub fn new() -> Self {
        Default::default()
    }
    /// Builder: sets the field delimiter.
    pub fn with_delimiter(self, delimiter: char) -> Self {
        Self {
            delimiter: Some(delimiter),
            ..self
        }
    }
    /// Builder: sets the NULL marker string.
    pub fn with_null_string(self, null_string: impl Into<String>) -> Self {
        Self {
            null_string: Some(null_string.into()),
            ..self
        }
    }
    /// Builder: sets whether a header row is present.
    pub fn with_header(self, header: bool) -> Self {
        Self { header, ..self }
    }
    /// Builder: sets the explicit target column list.
    pub fn with_columns(self, columns: Vec<String>) -> Self {
        Self {
            columns: Some(columns),
            ..self
        }
    }
    /// Builder: sets the `local` flag.
    pub fn with_local(self, local: bool) -> Self {
        Self { local, ..self }
    }
    /// Builder: sets the quote character.
    pub fn with_quote(self, quote: char) -> Self {
        Self {
            quote: Some(quote),
            ..self
        }
    }
    /// Builder: sets the escape character.
    pub fn with_escape(self, escape: char) -> Self {
        Self {
            escape: Some(escape),
            ..self
        }
    }
    /// Builder: sets the line terminator.
    pub fn with_line_terminator(self, terminator: impl Into<String>) -> Self {
        Self {
            line_terminator: Some(terminator.into()),
            ..self
        }
    }
    /// Builder: sets the input encoding.
    pub fn with_encoding(self, encoding: impl Into<String>) -> Self {
        Self {
            encoding: Some(encoding.into()),
            ..self
        }
    }
}
/// A single schema-migration step.
///
/// Each variant is rendered to dialect-specific SQL by [`Operation::to_sql`] and replayed against
/// the in-memory [`ProjectState`] by [`Operation::state_forwards`]. Serialized internally tagged:
/// the variant name is stored in a `type` field.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "type")]
pub enum Operation {
    /// `CREATE TABLE`. When more than one column is a primary key, the key is emitted as a single
    /// `<name>_pkey` table constraint instead of per-column clauses.
    CreateTable {
        name: String,
        columns: Vec<ColumnDefinition>,
        /// Table-level constraints appended after the column list.
        #[serde(default)]
        constraints: Vec<Constraint>,
        /// SQLite only: `Some(true)` appends `WITHOUT ROWID`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        without_rowid: Option<bool>,
        /// CockroachDB only: appends `INTERLEAVE IN PARENT ...`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        interleave_in_parent: Option<InterleaveSpec>,
        /// MySQL only: appends a `PARTITION BY ...` clause.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        partition: Option<PartitionOptions>,
    },
    /// `DROP TABLE`.
    DropTable {
        name: String,
    },
    /// `ALTER TABLE ... ADD COLUMN`.
    AddColumn {
        table: String,
        column: ColumnDefinition,
        /// MySQL only: `ALGORITHM=`/`LOCK=` suffix.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        mysql_options: Option<AlterTableOptions>,
    },
    /// `ALTER TABLE ... DROP COLUMN`.
    DropColumn {
        table: String,
        column: String,
    },
    /// Changes a column's type only (`ALTER COLUMN ... TYPE` / `MODIFY COLUMN`). SQLite gets a
    /// comment saying table recreation is required.
    AlterColumn {
        table: String,
        column: String,
        /// Not read by `to_sql` or `state_forwards`; presumably kept for reversal — TODO confirm.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        old_definition: Option<ColumnDefinition>,
        new_definition: ColumnDefinition,
        /// MySQL only: `ALGORITHM=`/`LOCK=` suffix.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        mysql_options: Option<AlterTableOptions>,
    },
    /// `ALTER TABLE ... RENAME TO`.
    RenameTable {
        old_name: String,
        new_name: String,
    },
    /// `ALTER TABLE ... RENAME COLUMN ... TO`.
    RenameColumn {
        table: String,
        old_name: String,
        new_name: String,
    },
    /// `ALTER TABLE ... ADD <constraint_sql>`; the constraint text is inserted verbatim.
    AddConstraint {
        table: String,
        constraint_sql: String,
    },
    /// `ALTER TABLE ... DROP CONSTRAINT`.
    DropConstraint {
        table: String,
        constraint_name: String,
    },
    /// `CREATE INDEX` named `idx_<table>_<columns joined by '_'>`, or `idx_<table>_expr` when
    /// non-empty `expressions` are given.
    CreateIndex {
        table: String,
        columns: Vec<String>,
        unique: bool,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        index_type: Option<IndexType>,
        /// Partial-index predicate; ignored on MySQL.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        where_clause: Option<String>,
        /// PostgreSQL only: `CREATE INDEX CONCURRENTLY`.
        #[serde(default)]
        concurrently: bool,
        /// Index expressions used instead of `columns` when non-empty.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        expressions: Option<Vec<String>>,
        /// MySQL only: `ALGORITHM=`/`LOCK=` options.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        mysql_options: Option<AlterTableOptions>,
        /// PostgreSQL only: operator class appended to every column.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        operator_class: Option<String>,
    },
    /// `DROP INDEX`, with the index name recomputed as `idx_<table>_<columns joined by '_'>`.
    DropIndex {
        table: String,
        columns: Vec<String>,
    },
    /// Raw SQL emitted verbatim by `to_sql`; `reverse_sql` is not used by `to_sql`.
    RunSQL {
        sql: String,
        reverse_sql: Option<String>,
    },
    /// Rust code step; `to_sql` emits only a `-- RunRust:` comment with its first line.
    RunRust {
        code: String,
        reverse_code: Option<String>,
    },
    /// Sets (`Some`) or clears (`None`) the table comment; no-op on SQLite.
    AlterTableComment {
        table: String,
        comment: Option<String>,
    },
    /// Adds one UNIQUE constraint per group, named `<table>_<index>_uniq`.
    AlterUniqueTogether {
        table: String,
        unique_together: Vec<Vec<String>>,
    },
    /// Model options with no SQL or state effect in this module.
    AlterModelOptions {
        table: String,
        options: std::collections::HashMap<String, String>,
    },
    /// Joined-table inheritance child: a new table whose `join_column` is
    /// `INTEGER REFERENCES <base_table>(id)`.
    CreateInheritedTable {
        name: String,
        columns: Vec<ColumnDefinition>,
        base_table: String,
        join_column: String,
    },
    /// Single-table inheritance discriminator: adds a `VARCHAR(50)` column with a default value.
    AddDiscriminatorColumn {
        table: String,
        column_name: String,
        default_value: String,
    },
    /// Moves a model between apps in the project state. The table is renamed only when
    /// `rename_table` is set and both table names are `Some`.
    MoveModel {
        model_name: String,
        from_app: String,
        to_app: String,
        rename_table: bool,
        old_table_name: Option<String>,
        new_table_name: Option<String>,
    },
    /// `CREATE SCHEMA [IF NOT EXISTS]`.
    CreateSchema {
        name: String,
        #[serde(default)]
        if_not_exists: bool,
    },
    /// `DROP SCHEMA [IF EXISTS] ... [CASCADE]`; `if_exists` defaults to `true` when deserialized
    /// without it.
    DropSchema {
        name: String,
        #[serde(default)]
        cascade: bool,
        #[serde(default = "default_true")]
        if_exists: bool,
    },
    /// `CREATE EXTENSION [IF NOT EXISTS] ... [SCHEMA ...]`; `if_not_exists` defaults to `true`
    /// when deserialized without it.
    CreateExtension {
        name: String,
        #[serde(default = "default_true")]
        if_not_exists: bool,
        #[serde(default)]
        schema: Option<String>,
    },
    /// Bulk data load, rendered by `bulk_load_to_sql`.
    BulkLoad {
        table: String,
        source: BulkLoadSource,
        #[serde(default)]
        format: BulkLoadFormat,
        #[serde(default)]
        options: BulkLoadOptions,
    },
    /// Sets the next auto-increment value, rendered by `set_auto_increment_to_sql`.
    SetAutoIncrementValue {
        table: String,
        column: String,
        value: i64,
    },
    /// Adds a primary key over `columns`; the constraint is named `<table>_pkey` when
    /// `constraint_name` is `None`.
    CreateCompositePrimaryKey {
        table: String,
        columns: Vec<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        constraint_name: Option<String>,
    },
}
/// serde default for boolean flags that should be on when omitted from serialized input.
const fn default_true() -> bool {
    true
}
impl Operation {
pub fn state_forwards(&self, app_label: &str, state: &mut ProjectState) {
match self {
Operation::CreateTable { name, columns, .. } => {
let mut model = ModelState::new(app_label, name.clone());
for column in columns {
let field = FieldState::new(
column.name.to_string(),
column.type_definition.clone(),
false,
);
model.add_field(field);
}
state.add_model(model);
}
Operation::DropTable { name } => {
state.remove_model(app_label, name);
}
Operation::AddColumn { table, column, .. } => {
if let Some(model) = state.get_model_mut(app_label, table) {
let field = FieldState::new(
column.name.to_string(),
column.type_definition.clone(),
false,
);
model.add_field(field);
}
}
Operation::DropColumn { table, column } => {
if let Some(model) = state.get_model_mut(app_label, table) {
model.remove_field(column);
}
}
Operation::AlterColumn {
table,
column,
new_definition,
..
} => {
if let Some(model) = state.get_model_mut(app_label, table) {
let field = FieldState::new(
column.to_string(),
new_definition.type_definition.clone(),
false,
);
model.alter_field(column, field);
}
}
Operation::RenameTable { old_name, new_name } => {
state.rename_model(app_label, old_name, new_name.to_string());
}
Operation::RenameColumn {
table,
old_name,
new_name,
} => {
if let Some(model) = state.get_model_mut(app_label, table) {
model.rename_field(old_name, new_name.to_string());
}
}
Operation::CreateInheritedTable {
name,
columns,
base_table,
join_column,
} => {
let mut model = ModelState::new(app_label, name.clone());
model.base_model = Some(base_table.to_string());
model.inheritance_type = Some("joined_table".to_string());
let join_field = FieldState::new(
join_column.to_string(),
FieldType::Custom(format!("INTEGER REFERENCES {}(id)", base_table)),
false,
);
model.add_field(join_field);
for column in columns {
let field = FieldState::new(
column.name.to_string(),
column.type_definition.clone(),
false,
);
model.add_field(field);
}
state.add_model(model);
}
Operation::AddDiscriminatorColumn {
table,
column_name,
default_value,
} => {
if let Some(model) = state.get_model_mut(app_label, table) {
model.discriminator_column = Some(column_name.to_string());
model.inheritance_type = Some("single_table".to_string());
let field = FieldState::new(
column_name.to_string(),
FieldType::Custom(format!("VARCHAR(50) DEFAULT '{}'", default_value)),
false,
);
model.add_field(field);
}
}
Operation::AddConstraint { .. }
| Operation::DropConstraint { .. }
| Operation::CreateIndex { .. }
| Operation::DropIndex { .. }
| Operation::RunSQL { .. }
| Operation::RunRust { .. }
| Operation::AlterTableComment { .. }
| Operation::AlterUniqueTogether { .. }
| Operation::AlterModelOptions { .. }
| Operation::SetAutoIncrementValue { .. }
| Operation::CreateCompositePrimaryKey { .. } => {
}
Operation::MoveModel {
model_name,
from_app,
to_app,
rename_table,
old_table_name,
new_table_name,
} => {
if let Some(model) = state.get_model(from_app, model_name).cloned() {
state.remove_model(from_app, model_name);
let mut new_model = model;
new_model.app_label = to_app.to_string();
if *rename_table
&& let (Some(_old_name), Some(new_name)) = (old_table_name, new_table_name)
{
new_model.table_name = new_name.to_string();
}
state.add_model(new_model);
}
}
Operation::CreateSchema { .. }
| Operation::DropSchema { .. }
| Operation::CreateExtension { .. } => {
}
Operation::BulkLoad { .. } => {
}
}
}
/// Renders a column definition for use inside a table that declares a composite primary key.
///
/// Auto-increment handling matches `column_to_sql`, but a `PRIMARY KEY` clause is never emitted
/// because the caller adds one table-level key constraint instead.
fn column_to_sql_without_pk(col: &ColumnDefinition, dialect: &SqlDialect) -> String {
    let ty = &col.type_definition;
    let is_integer = matches!(
        ty,
        FieldType::BigInteger | FieldType::Integer | FieldType::SmallInteger
    );
    let mut tokens = vec![quote_identifier(&col.name)];
    match (col.auto_increment, dialect) {
        (true, SqlDialect::Postgres | SqlDialect::Cockroachdb) => {
            // Integer columns become identity columns; other types keep their native type.
            let identity_base = match ty {
                FieldType::BigInteger => Some("BIGINT"),
                FieldType::Integer => Some("INTEGER"),
                FieldType::SmallInteger => Some("SMALLINT"),
                _ => None,
            };
            match identity_base {
                Some(base) => {
                    tokens.push(format!("{} GENERATED BY DEFAULT AS IDENTITY", base).into())
                }
                None => tokens.push(ty.to_sql_for_dialect(dialect).into()),
            }
        }
        (true, SqlDialect::Mysql) => {
            tokens.push(ty.to_sql_for_dialect(dialect).into());
            tokens.push("AUTO_INCREMENT".into());
        }
        // SQLite collapses every integer width to INTEGER for auto-increment columns.
        (true, SqlDialect::Sqlite) if is_integer => tokens.push("INTEGER".into()),
        (false, _) | (true, SqlDialect::Sqlite) => {
            tokens.push(ty.to_sql_for_dialect(dialect).into());
        }
    }
    if col.not_null {
        tokens.push("NOT NULL".into());
    }
    if col.unique {
        tokens.push("UNIQUE".into());
    }
    if let Some(default) = &col.default {
        tokens.push(format!("DEFAULT {}", default).into());
    }
    tokens.join(" ")
}
/// Renders a single column definition, including inline `PRIMARY KEY` when applicable.
///
/// Auto-increment is dialect specific:
/// - PostgreSQL/CockroachDB integer columns become `GENERATED BY DEFAULT AS IDENTITY`.
/// - MySQL appends `AUTO_INCREMENT`.
/// - SQLite widens integer types to `INTEGER` and, for primary keys, emits
///   `PRIMARY KEY AUTOINCREMENT`.
fn column_to_sql(col: &ColumnDefinition, dialect: &SqlDialect) -> String {
    let mut parts = Vec::new();
    parts.push(quote_identifier(&col.name));
    if col.auto_increment {
        match dialect {
            SqlDialect::Postgres | SqlDialect::Cockroachdb => {
                // Integer columns become identity columns; other types keep their native type.
                let identity_base = match &col.type_definition {
                    FieldType::BigInteger => Some("BIGINT"),
                    FieldType::Integer => Some("INTEGER"),
                    FieldType::SmallInteger => Some("SMALLINT"),
                    _ => None,
                };
                match identity_base {
                    Some(base) => parts
                        .push(format!("{} GENERATED BY DEFAULT AS IDENTITY", base).into()),
                    None => parts.push(col.type_definition.to_sql_for_dialect(dialect).into()),
                }
            }
            SqlDialect::Mysql => {
                parts.push(col.type_definition.to_sql_for_dialect(dialect).into());
                parts.push("AUTO_INCREMENT".into());
            }
            SqlDialect::Sqlite => {
                let widened_to_integer = matches!(
                    &col.type_definition,
                    FieldType::BigInteger | FieldType::Integer | FieldType::SmallInteger
                );
                if widened_to_integer {
                    parts.push("INTEGER".into());
                } else {
                    parts.push(col.type_definition.to_sql_for_dialect(dialect).into());
                }
                if col.primary_key {
                    if widened_to_integer {
                        // INTEGER PRIMARY KEY is the rowid alias and is implicitly NOT NULL.
                        parts.push("PRIMARY KEY AUTOINCREMENT".into());
                    } else {
                        // Any other PRIMARY KEY column in a rowid table accepts NULLs in SQLite
                        // unless NOT NULL is spelled out, so it must not be dropped here.
                        if col.not_null {
                            parts.push("NOT NULL".into());
                        }
                        parts.push("PRIMARY KEY".into());
                    }
                    if col.unique {
                        parts.push("UNIQUE".into());
                    }
                    if let Some(default) = &col.default {
                        parts.push(format!("DEFAULT {}", default).into());
                    }
                    return parts.join(" ");
                }
            }
        }
    } else {
        parts.push(col.type_definition.to_sql_for_dialect(dialect).into());
    }
    if col.not_null {
        parts.push("NOT NULL".into());
    }
    if col.primary_key {
        parts.push("PRIMARY KEY".into());
    }
    if col.unique {
        parts.push("UNIQUE".into());
    }
    if let Some(default) = &col.default {
        parts.push(format!("DEFAULT {}", default).into());
    }
    parts.join(" ")
}
/// Renders this operation as SQL for `dialect`.
///
/// Identifiers are quoted with `quote_identifier`, and string values embedded in DDL are
/// escaped. Operations with no SQL meaning on a dialect yield an empty string or a `--` comment.
/// Some operations (e.g. `AlterUniqueTogether`) produce several newline-separated statements.
pub fn to_sql(&self, dialect: &SqlDialect) -> String {
    match self {
        Operation::CreateTable {
            name,
            columns,
            constraints,
            without_rowid,
            interleave_in_parent,
            partition,
        } => {
            let pk_columns: Vec<&String> = columns
                .iter()
                .filter(|col| col.primary_key)
                .map(|col| &col.name)
                .collect();
            // A multi-column key is emitted once as a table constraint, so per-column
            // PRIMARY KEY clauses must be suppressed.
            let has_composite_pk = pk_columns.len() > 1;
            let mut parts = Vec::new();
            for col in columns {
                if has_composite_pk {
                    parts.push(format!(
                        " {}",
                        Self::column_to_sql_without_pk(col, dialect)
                    ));
                } else {
                    parts.push(format!(" {}", Self::column_to_sql(col, dialect)));
                }
            }
            if has_composite_pk {
                let pk_constraint_name = format!("{}_pkey", name);
                let quoted_pk_columns = pk_columns
                    .iter()
                    .map(|s| quote_identifier(s))
                    .collect::<Vec<_>>()
                    .join(", ");
                let pk_constraint = format!(
                    " CONSTRAINT {} PRIMARY KEY ({})",
                    quote_identifier(&pk_constraint_name),
                    quoted_pk_columns
                );
                parts.push(pk_constraint);
            }
            // ManyToMany is realized by a separate through table. Its Display form is a `--`
            // line comment, which swallows the following `,` separator. When it is the last
            // element, the preceding comma is left dangling before `)`. It contributes nothing
            // to this table, so skip it.
            for constraint in constraints
                .iter()
                .filter(|c| !matches!(c, Constraint::ManyToMany { .. }))
            {
                parts.push(format!(" {}", constraint));
            }
            let mut sql = format!(
                "CREATE TABLE {} (\n{}\n)",
                quote_identifier(name),
                parts.join(",\n")
            );
            if matches!(dialect, SqlDialect::Sqlite)
                && let Some(true) = without_rowid
            {
                sql.push_str(" WITHOUT ROWID");
            }
            if matches!(dialect, SqlDialect::Mysql)
                && let Some(partition_opts) = partition
            {
                sql.push(' ');
                sql.push_str(&partition_opts.to_sql());
            }
            if matches!(dialect, SqlDialect::Cockroachdb)
                && let Some(interleave) = interleave_in_parent
            {
                let quoted_columns = interleave
                    .parent_columns
                    .iter()
                    .map(|col| quote_identifier(col))
                    .collect::<Vec<_>>()
                    .join(", ");
                sql.push_str(&format!(
                    " INTERLEAVE IN PARENT {} ({})",
                    quote_identifier(&interleave.parent_table),
                    quoted_columns
                ));
            }
            sql.push(';');
            sql
        }
        Operation::DropTable { name } => format!("DROP TABLE {};", quote_identifier(name)),
        Operation::AddColumn {
            table,
            column,
            mysql_options,
        } => {
            let base_sql = format!(
                "ALTER TABLE {} ADD COLUMN {}",
                quote_identifier(table),
                Self::column_to_sql(column, dialect)
            );
            if matches!(dialect, SqlDialect::Mysql)
                && let Some(opts) = mysql_options
            {
                let suffix = opts.to_sql_suffix();
                if !suffix.is_empty() {
                    return format!("{}{};", base_sql, suffix);
                }
            }
            format!("{};", base_sql)
        }
        Operation::DropColumn { table, column } => {
            format!(
                "ALTER TABLE {} DROP COLUMN {};",
                quote_identifier(table),
                quote_identifier(column)
            )
        }
        Operation::AlterColumn {
            table,
            column,
            new_definition,
            mysql_options,
            ..
        } => {
            let sql_type = new_definition.type_definition.to_sql_for_dialect(dialect);
            match dialect {
                SqlDialect::Postgres | SqlDialect::Cockroachdb => {
                    format!(
                        "ALTER TABLE {} ALTER COLUMN {} TYPE {};",
                        quote_identifier(table),
                        quote_identifier(column),
                        sql_type
                    )
                }
                SqlDialect::Mysql => {
                    let base_sql = format!(
                        "ALTER TABLE {} MODIFY COLUMN {} {}",
                        quote_identifier(table),
                        quote_identifier(column),
                        sql_type
                    );
                    if let Some(opts) = mysql_options {
                        let suffix = opts.to_sql_suffix();
                        if !suffix.is_empty() {
                            return format!("{}{};", base_sql, suffix);
                        }
                    }
                    format!("{};", base_sql)
                }
                SqlDialect::Sqlite => {
                    format!(
                        "-- SQLite does not support ALTER COLUMN, table recreation required for {}",
                        quote_identifier(table)
                    )
                }
            }
        }
        Operation::RenameColumn {
            table,
            old_name,
            new_name,
        } => {
            format!(
                "ALTER TABLE {} RENAME COLUMN {} TO {};",
                quote_identifier(table),
                quote_identifier(old_name),
                quote_identifier(new_name)
            )
        }
        Operation::RenameTable { old_name, new_name } => {
            format!(
                "ALTER TABLE {} RENAME TO {};",
                quote_identifier(old_name),
                quote_identifier(new_name)
            )
        }
        Operation::AddConstraint {
            table,
            constraint_sql,
        } => {
            format!(
                "ALTER TABLE {} ADD {};",
                quote_identifier(table),
                constraint_sql
            )
        }
        Operation::DropConstraint {
            table,
            constraint_name,
        } => {
            format!(
                "ALTER TABLE {} DROP CONSTRAINT {};",
                quote_identifier(table),
                quote_identifier(constraint_name)
            )
        }
        Operation::CreateIndex {
            table,
            columns,
            unique,
            index_type,
            where_clause,
            concurrently,
            expressions,
            mysql_options,
            operator_class,
        } => {
            let unique_str = if *unique { "UNIQUE " } else { "" };
            let concurrent_str = if *concurrently && matches!(dialect, SqlDialect::Postgres) {
                "CONCURRENTLY "
            } else {
                ""
            };
            // MySQL FULLTEXT/SPATIAL indexes take the place of the UNIQUE keyword.
            let (mysql_prefix, effective_unique) = match (index_type, dialect) {
                (Some(IndexType::Fulltext), SqlDialect::Mysql) => ("FULLTEXT ", ""),
                (Some(IndexType::Spatial), SqlDialect::Mysql) => ("SPATIAL ", ""),
                _ => ("", unique_str),
            };
            let (index_content, name_suffix) =
                if let Some(exprs) = expressions.as_ref().filter(|e| !e.is_empty()) {
                    (exprs.join(", "), "expr".to_string())
                } else {
                    // Operator classes are a PostgreSQL feature; other dialects ignore them.
                    let pg_op_class = operator_class
                        .as_deref()
                        .filter(|_| matches!(dialect, SqlDialect::Postgres));
                    let content = columns
                        .iter()
                        .map(|c| match pg_op_class {
                            Some(op_class) => format!("{} {}", quote_identifier(c), op_class),
                            None => quote_identifier(c).to_string(),
                        })
                        .collect::<Vec<_>>()
                        .join(", ");
                    (content, columns.join("_"))
                };
            let idx_name = format!("idx_{}_{}", table, name_suffix);
            let using_clause = match (index_type, dialect) {
                // B-tree is the default access method, so it never needs an explicit USING.
                (Some(idx_type), SqlDialect::Postgres | SqlDialect::Cockroachdb)
                    if *idx_type != IndexType::BTree =>
                {
                    format!(" USING {}", idx_type)
                }
                // MySQL expresses FULLTEXT/SPATIAL through the CREATE prefix; SQLite has no USING.
                _ => String::new(),
            };
            let mut sql = match dialect {
                SqlDialect::Postgres | SqlDialect::Cockroachdb => {
                    format!(
                        "CREATE {}INDEX {}{}",
                        effective_unique,
                        concurrent_str,
                        quote_identifier(&idx_name)
                    )
                }
                SqlDialect::Mysql => {
                    format!(
                        "CREATE {}{}INDEX {}",
                        mysql_prefix,
                        effective_unique,
                        quote_identifier(&idx_name)
                    )
                }
                SqlDialect::Sqlite => {
                    format!(
                        "CREATE {}INDEX {}",
                        effective_unique,
                        quote_identifier(&idx_name)
                    )
                }
            };
            sql.push_str(&format!(
                " ON {}{} ({})",
                quote_identifier(table),
                using_clause,
                index_content
            ));
            if let Some(where_cond) = where_clause
                && !matches!(dialect, SqlDialect::Mysql)
            {
                sql.push_str(&format!(" WHERE {}", where_cond));
            }
            if matches!(dialect, SqlDialect::Mysql)
                && let Some(opts) = mysql_options
            {
                // `to_sql_suffix` yields the ALTER TABLE form (", ALGORITHM=X, LOCK=Y"). CREATE
                // INDEX takes the same options as space-separated trailing clauses, and a comma
                // there is a syntax error.
                let suffix = opts.to_sql_suffix();
                if let Some(clauses) = suffix.strip_prefix(", ") {
                    sql.push(' ');
                    sql.push_str(&clauses.replace(", ", " "));
                }
            }
            sql.push(';');
            sql
        }
        Operation::DropIndex { table, columns } => {
            let idx_name = format!("idx_{}_{}", table, columns.join("_"));
            match dialect {
                SqlDialect::Mysql => {
                    format!(
                        "DROP INDEX {} ON {};",
                        quote_identifier(&idx_name),
                        quote_identifier(table)
                    )
                }
                SqlDialect::Postgres | SqlDialect::Sqlite | SqlDialect::Cockroachdb => {
                    format!("DROP INDEX {};", quote_identifier(&idx_name))
                }
            }
        }
        Operation::RunSQL { sql, .. } => sql.to_string(),
        Operation::RunRust { code, .. } => {
            format!("-- RunRust: {}", code.lines().next().unwrap_or(""))
        }
        Operation::AlterTableComment { table, comment } => match dialect {
            SqlDialect::Postgres | SqlDialect::Cockroachdb => match comment {
                // quote_literal doubles embedded quotes, so comments like "User's table" stay
                // well-formed.
                Some(comment_text) => format!(
                    "COMMENT ON TABLE {} IS {};",
                    quote_identifier(table),
                    quote_literal(comment_text)
                ),
                None => format!("COMMENT ON TABLE {} IS NULL;", quote_identifier(table)),
            },
            SqlDialect::Mysql => {
                // MySQL string literals treat `\` as an escape (default sql_mode) and need `'`
                // doubled; an absent comment clears it with an empty string.
                let escaped = comment
                    .as_deref()
                    .map(|text| text.replace('\\', "\\\\").replace('\'', "''"))
                    .unwrap_or_default();
                format!(
                    "ALTER TABLE {} COMMENT='{}';",
                    quote_identifier(table),
                    escaped
                )
            }
            SqlDialect::Sqlite => String::new(),
        },
        Operation::AlterUniqueTogether {
            table,
            unique_together,
        } => {
            let mut sql = Vec::new();
            for (idx, fields) in unique_together.iter().enumerate() {
                let constraint_name = format!("{}_{}_uniq", table, idx);
                let fields_str = fields
                    .iter()
                    .map(|f| quote_identifier(f))
                    .collect::<Vec<_>>()
                    .join(", ");
                sql.push(format!(
                    "ALTER TABLE {} ADD CONSTRAINT {} UNIQUE ({});",
                    quote_identifier(table),
                    quote_identifier(&constraint_name),
                    fields_str
                ));
            }
            sql.join("\n")
        }
        Operation::AlterModelOptions { .. } => String::new(),
        Operation::CreateInheritedTable {
            name,
            columns,
            base_table,
            join_column,
        } => {
            let mut parts = Vec::new();
            parts.push(format!(
                " {} INTEGER REFERENCES {}(id)",
                quote_identifier(join_column),
                quote_identifier(base_table)
            ));
            for col in columns {
                parts.push(format!(" {}", Self::column_to_sql(col, dialect)));
            }
            format!(
                "CREATE TABLE {} (\n{}\n);",
                quote_identifier(name),
                parts.join(",\n")
            )
        }
        Operation::AddDiscriminatorColumn {
            table,
            column_name,
            default_value,
        } => {
            // Doubling `'` is valid string-literal escaping in every supported dialect.
            format!(
                "ALTER TABLE {} ADD COLUMN {} VARCHAR(50) DEFAULT '{}';",
                quote_identifier(table),
                quote_identifier(column_name),
                default_value.replace('\'', "''")
            )
        }
        Operation::MoveModel {
            rename_table,
            old_table_name,
            new_table_name,
            ..
        } => {
            if *rename_table {
                if let (Some(old_name), Some(new_name)) = (old_table_name, new_table_name) {
                    match dialect {
                        SqlDialect::Postgres | SqlDialect::Sqlite | SqlDialect::Cockroachdb => {
                            format!(
                                "ALTER TABLE {} RENAME TO {};",
                                quote_identifier(old_name),
                                quote_identifier(new_name)
                            )
                        }
                        SqlDialect::Mysql => {
                            format!(
                                "RENAME TABLE {} TO {};",
                                quote_identifier(old_name),
                                quote_identifier(new_name)
                            )
                        }
                    }
                } else {
                    "-- MoveModel: No table rename specified".to_string()
                }
            } else {
                "-- MoveModel: State-only operation (no table rename)".to_string()
            }
        }
        Operation::CreateSchema {
            name,
            if_not_exists,
        } => {
            let if_not_exists_clause = if *if_not_exists { " IF NOT EXISTS" } else { "" };
            format!(
                "CREATE SCHEMA{} {};",
                if_not_exists_clause,
                quote_identifier(name)
            )
        }
        Operation::DropSchema {
            name,
            cascade,
            if_exists,
        } => {
            let if_exists_clause = if *if_exists { " IF EXISTS" } else { "" };
            let cascade_clause = if *cascade { " CASCADE" } else { "" };
            format!(
                "DROP SCHEMA{} {}{};",
                if_exists_clause,
                quote_identifier(name),
                cascade_clause
            )
        }
        Operation::CreateExtension {
            name,
            if_not_exists,
            schema,
        } => {
            let if_not_exists_clause = if *if_not_exists { " IF NOT EXISTS" } else { "" };
            let schema_clause = if let Some(s) = schema {
                format!(" SCHEMA {}", quote_identifier(s))
            } else {
                String::new()
            };
            format!(
                "CREATE EXTENSION{} {}{};",
                if_not_exists_clause,
                quote_identifier(name),
                schema_clause
            )
        }
        Operation::BulkLoad {
            table,
            source,
            format,
            options,
        } => Self::bulk_load_to_sql(table, source, format, options, dialect),
        Operation::SetAutoIncrementValue {
            table,
            column,
            value,
        } => Self::set_auto_increment_to_sql(table, column, *value, dialect),
        Operation::CreateCompositePrimaryKey {
            table,
            columns,
            constraint_name,
        } => Self::create_composite_pk_to_sql(table, columns, constraint_name.as_deref()),
    }
}
/// Renders SQL that makes the next auto-generated value of `table.column` equal `value`.
///
/// * PostgreSQL / CockroachDB: `setval(pg_get_serial_sequence(..), value, false)`, so the
///   next `nextval` returns exactly `value`.
/// * MySQL: `ALTER TABLE .. AUTO_INCREMENT = value`; the counter is table-level, so
///   `column` is not used.
/// * SQLite: `sqlite_sequence.seq` stores the *last used* rowid, so it is set to
///   `value - 1` to keep the same "next value" meaning as the other dialects.
///   `sqlite_sequence` has no UNIQUE constraint on `name`, so `INSERT OR REPLACE` would
///   append a duplicate row that SQLite never consults. Instead the existing row is
///   updated, and a row is inserted only when none exists yet. This emits two statements.
///   Per the SQLite docs, `sqlite_sequence` only exists once some table uses `AUTOINCREMENT`.
fn set_auto_increment_to_sql(
    table: &str,
    column: &str,
    value: i64,
    dialect: &SqlDialect,
) -> String {
    match dialect {
        SqlDialect::Postgres | SqlDialect::Cockroachdb => {
            format!(
                "SELECT setval(pg_get_serial_sequence({}, {}), {}, false);",
                quote_literal(table),
                quote_literal(column),
                value
            )
        }
        SqlDialect::Mysql => {
            format!(
                "ALTER TABLE {} AUTO_INCREMENT = {};",
                quote_identifier(table),
                value
            )
        }
        SqlDialect::Sqlite => {
            let name = quote_literal(table);
            // `seq` is the last rowid handed out; saturate so i64::MIN cannot overflow.
            let last_used = value.saturating_sub(1);
            format!(
                "UPDATE sqlite_sequence SET seq = {last_used} WHERE name = {name}; \
                 INSERT INTO sqlite_sequence(name, seq) SELECT {name}, {last_used} \
                 WHERE NOT EXISTS (SELECT 1 FROM sqlite_sequence WHERE name = {name});"
            )
        }
    }
}
/// Renders `ALTER TABLE .. ADD CONSTRAINT .. PRIMARY KEY (..)` for a composite key.
///
/// The constraint is named `constraint_name`, or `<table>_pkey` when none is given.
/// An empty `columns` slice yields a deliberately invalid statement. It names the table
/// (non-alphanumeric characters replaced by `_`) so the migration fails loudly instead of
/// silently creating nothing.
fn create_composite_pk_to_sql(
    table: &str,
    columns: &[String],
    constraint_name: Option<&str>,
) -> String {
    if columns.is_empty() {
        let sanitized_table: String = table
            .chars()
            .map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '_' })
            .collect();
        return format!(
            "SYNTAX_ERROR_create_composite_pk_on_{}_requires_at_least_one_column;",
            sanitized_table
        );
    }
    let pk_name = constraint_name.map_or_else(|| format!("{}_pkey", table), str::to_owned);
    let mut key_columns: Vec<String> = Vec::with_capacity(columns.len());
    for column in columns {
        key_columns.push(quote_identifier(column).to_string());
    }
    format!(
        "ALTER TABLE {} ADD CONSTRAINT {} PRIMARY KEY ({});",
        quote_identifier(table),
        quote_identifier(&pk_name),
        key_columns.join(", ")
    )
}
/// Dispatches a bulk load to the dialect-specific renderer.
///
/// PostgreSQL and CockroachDB use `COPY ... FROM`, MySQL uses `LOAD DATA`. SQLite has
/// no bulk-load statement, so a SQL comment is returned instead.
fn bulk_load_to_sql(
    table: &str,
    source: &BulkLoadSource,
    format: &BulkLoadFormat,
    options: &BulkLoadOptions,
    dialect: &SqlDialect,
) -> String {
    match *dialect {
        SqlDialect::Sqlite => format!(
            "-- SQLite does not support bulk loading. Use INSERT statements instead for table {}",
            quote_identifier(table)
        ),
        SqlDialect::Mysql => Self::mysql_load_data_sql(table, source, format, options),
        SqlDialect::Postgres | SqlDialect::Cockroachdb => {
            Self::postgres_copy_from_sql(table, source, format, options)
        }
    }
}
fn postgres_copy_from_sql(
table: &str,
source: &BulkLoadSource,
format: &BulkLoadFormat,
options: &BulkLoadOptions,
) -> String {
let source_clause = match source {
BulkLoadSource::File(path) => format!("'{}'", path),
BulkLoadSource::Stdin => "STDIN".to_string(),
BulkLoadSource::Program(cmd) => format!("PROGRAM '{}'", cmd),
};
let columns_clause = if let Some(cols) = &options.columns {
let quoted_cols = cols
.iter()
.map(|c| quote_identifier(c))
.collect::<Vec<_>>()
.join(", ");
format!(" ({})", quoted_cols)
} else {
String::new()
};
let mut with_options = Vec::new();
with_options.push(format!("FORMAT {}", format));
if let Some(delim) = options.delimiter {
with_options.push(format!("DELIMITER '{}'", delim));
}
if let Some(null_str) = &options.null_string {
with_options.push(format!("NULL '{}'", null_str));
}
if options.header {
with_options.push("HEADER true".to_string());
}
if let Some(quote) = options.quote {
with_options.push(format!("QUOTE '{}'", quote));
}
if let Some(escape) = options.escape {
with_options.push(format!("ESCAPE '{}'", escape));
}
format!(
"COPY {}{} FROM {} WITH ({});",
quote_identifier(table),
columns_clause,
source_clause,
with_options.join(", ")
)
}
/// Renders a MySQL `LOAD DATA [LOCAL] INFILE` statement.
///
/// Clauses follow MySQL's grammar order: `INFILE 'file' INTO TABLE t [CHARACTER SET cs]
/// FIELDS ... LINES ... [IGNORE n LINES] [(cols)]`. `CHARACTER SET` must come after the
/// table name; placing it before `INTO TABLE` is a syntax error.
///
/// The file path and the single-character field options are escaped as MySQL string
/// literals. `ESCAPED BY '\'` would otherwise be unterminated, and Windows paths would
/// lose their backslashes. `line_terminator` is inserted verbatim because it is expected
/// to be written with MySQL escape sequences (the default is the two characters `\n`).
///
/// STDIN and PROGRAM sources are not supported by `LOAD DATA`; a SQL comment is
/// returned for them.
fn mysql_load_data_sql(
    table: &str,
    source: &BulkLoadSource,
    format: &BulkLoadFormat,
    options: &BulkLoadOptions,
) -> String {
    // Quotes `value` as a single-quoted MySQL string literal (backslash-escape mode).
    fn string_literal(value: &str) -> String {
        let mut out = String::with_capacity(value.len() + 2);
        out.push('\'');
        for c in value.chars() {
            match c {
                '\\' => out.push_str("\\\\"),
                '\'' => out.push_str("\\'"),
                other => out.push(other),
            }
        }
        out.push('\'');
        out
    }

    let local_clause = if options.local { " LOCAL" } else { "" };
    let file_path = match source {
        BulkLoadSource::File(path) => path,
        BulkLoadSource::Stdin => {
            return format!(
                "-- MySQL does not support LOAD DATA from STDIN directly for table {}",
                quote_identifier(table)
            );
        }
        BulkLoadSource::Program(_) => {
            return format!(
                "-- MySQL does not support LOAD DATA from PROGRAM directly for table {}",
                quote_identifier(table)
            );
        }
    };
    let columns_clause = if let Some(cols) = &options.columns {
        let quoted_cols = cols
            .iter()
            .map(|c| quote_identifier(c))
            .collect::<Vec<_>>()
            .join(", ");
        format!(" ({})", quoted_cols)
    } else {
        String::new()
    };
    let delimiter = options.delimiter.unwrap_or(match format {
        BulkLoadFormat::Csv => ',',
        BulkLoadFormat::Text | BulkLoadFormat::Binary => '\t',
    });
    let mut field_options = vec![format!(
        "TERMINATED BY {}",
        string_literal(&delimiter.to_string())
    )];
    if *format == BulkLoadFormat::Csv {
        let quote = options.quote.unwrap_or('"');
        field_options.push(format!("ENCLOSED BY {}", string_literal(&quote.to_string())));
    }
    if let Some(escape) = options.escape {
        field_options.push(format!("ESCAPED BY {}", string_literal(&escape.to_string())));
    }
    let line_terminator = options.line_terminator.as_deref().unwrap_or("\\n");
    let encoding_clause = if let Some(enc) = &options.encoding {
        format!(" CHARACTER SET {}", enc)
    } else {
        String::new()
    };
    let ignore_clause = if options.header {
        " IGNORE 1 LINES"
    } else {
        ""
    };
    format!(
        "LOAD DATA{} INFILE {} INTO TABLE {}{} FIELDS {} LINES TERMINATED BY '{}'{}{};",
        local_clause,
        string_literal(file_path),
        quote_identifier(table),
        encoding_clause,
        field_options.join(" "),
        line_terminator,
        ignore_clause,
        columns_clause
    )
}
pub fn to_reverse_sql(
&self,
dialect: &SqlDialect,
project_state: &ProjectState,
) -> super::Result<Option<String>> {
match self {
Operation::CreateTable { name, .. } => {
Ok(Some(format!("DROP TABLE {};", quote_identifier(name))))
}
Operation::AddColumn { table, column, .. } => Ok(Some(format!(
"ALTER TABLE {} DROP COLUMN {};",
quote_identifier(table),
quote_identifier(&column.name)
))),
Operation::RunSQL { reverse_sql, .. } => {
Ok(reverse_sql.as_ref().map(|s| s.to_string()))
}
Operation::RunRust { reverse_code, .. } => Ok(reverse_code.as_ref().map(|code| {
format!(
"-- RunRust (reverse): {}",
code.lines().next().unwrap_or("")
)
})),
Operation::RenameTable { old_name, new_name } => Ok(Some(format!(
"ALTER TABLE {} RENAME TO {};",
quote_identifier(new_name),
quote_identifier(old_name)
))),
Operation::RenameColumn {
table,
old_name,
new_name,
} => Ok(Some(format!(
"ALTER TABLE {} RENAME COLUMN {} TO {};",
quote_identifier(table),
quote_identifier(new_name),
quote_identifier(old_name)
))),
Operation::CreateIndex { table, columns, .. } => {
let columns_joined = columns.join("_");
let index_name = format!("idx_{}_{}", table, columns_joined);
Ok(Some(match dialect {
SqlDialect::Mysql => format!(
"DROP INDEX {} ON {};",
quote_identifier(&index_name),
quote_identifier(table)
),
SqlDialect::Postgres | SqlDialect::Sqlite | SqlDialect::Cockroachdb => {
format!("DROP INDEX {};", quote_identifier(&index_name))
}
}))
}
Operation::AddConstraint {
table,
constraint_sql,
} => {
let constraint_name =
Self::extract_constraint_name(constraint_sql).ok_or_else(|| {
super::MigrationError::InvalidMigration(format!(
"Cannot extract constraint name from: {}",
constraint_sql
))
})?;
Ok(Some(format!(
"ALTER TABLE {} DROP CONSTRAINT {};",
quote_identifier(table),
quote_identifier(&constraint_name)
)))
}
Operation::DropColumn { table, column } => {
if let Some(model) = project_state.find_model_by_table(table)
&& let Some(field) = model.get_field(column)
{
let col_def = ColumnDefinition::from_field_state(column.clone(), field);
let col_sql = Self::column_to_sql(&col_def, dialect);
return Ok(Some(format!(
"ALTER TABLE {} ADD COLUMN {};",
quote_identifier(table),
col_sql
)));
}
Ok(None)
}
Operation::AlterColumn {
table,
column,
old_definition,
new_definition: _,
..
} => {
let resolved_old_def = old_definition.clone().or_else(|| {
project_state
.find_model_by_table(table)
.and_then(|model| model.get_field(column))
.map(|field| ColumnDefinition::from_field_state(column.clone(), field))
});
let Some(old_def) = resolved_old_def else {
return Ok(None);
};
let type_sql = old_def.type_definition.to_sql_for_dialect(dialect);
let null_clause = if old_def.not_null { " NOT NULL" } else { "" };
let sql = match dialect {
SqlDialect::Postgres => {
let nullability_clause = if old_def.not_null {
"SET NOT NULL"
} else {
"DROP NOT NULL"
};
format!(
"ALTER TABLE {table} \
ALTER COLUMN {column} TYPE {type_sql}, \
ALTER COLUMN {column} {nullability_clause};",
table = quote_identifier(table),
column = quote_identifier(column),
type_sql = type_sql,
nullability_clause = nullability_clause,
)
}
SqlDialect::Cockroachdb => {
format!(
"ALTER TABLE {} ALTER COLUMN {} TYPE {};",
quote_identifier(table),
quote_identifier(column),
type_sql
)
}
SqlDialect::Mysql => format!(
"ALTER TABLE {} MODIFY COLUMN {} {}{};",
quote_identifier(table),
quote_identifier(column),
type_sql,
null_clause
),
SqlDialect::Sqlite => format!(
"-- SQLite does not support ALTER COLUMN, table recreation required for {}",
quote_identifier(table)
),
};
Ok(Some(sql))
}
Operation::DropIndex { table, columns } => {
let columns_joined = columns.join("_");
let index_name = format!("idx_{}_{}", table, columns_joined);
let columns_list = columns
.iter()
.map(|c| quote_identifier(c).to_string())
.collect::<Vec<_>>()
.join(", ");
Ok(Some(format!(
"CREATE INDEX {} ON {} ({});",
quote_identifier(&index_name),
quote_identifier(table),
columns_list
)))
}
Operation::DropConstraint {
table,
constraint_name,
} => {
if let Some(model) = project_state.find_model_by_table(table)
&& let Some(constraint_def) = model
.constraints
.iter()
.find(|c| c.name == *constraint_name)
{
let constraint = constraint_def.to_constraint();
return Ok(Some(format!(
"ALTER TABLE {} ADD {};",
quote_identifier(table),
constraint
)));
}
Ok(None)
}
Operation::DropTable { name } => {
if let Some(model) = project_state.find_model_by_table(name) {
let mut parts = Vec::new();
for (field_name, field) in &model.fields {
let col_def = ColumnDefinition::from_field_state(field_name.clone(), field);
parts.push(format!(" {}", Self::column_to_sql(&col_def, dialect)));
}
for constraint_def in &model.constraints {
let constraint = constraint_def.to_constraint();
parts.push(format!(" {}", constraint));
}
return Ok(Some(format!(
"CREATE TABLE {} (\n{}\n);",
quote_identifier(name),
parts.join(",\n")
)));
}
Ok(None)
}
Operation::BulkLoad { table, .. } => {
Ok(Some(format!("TRUNCATE TABLE {};", quote_identifier(table))))
}
_ => Ok(None),
}
}
/// Reverts this operation's effect on the in-memory `ProjectState`.
///
/// Only some operations are reverted here:
/// - `CreateTable` removes the model keyed `(app_label, name)`.
/// - `RenameTable` moves the model back to its old key and table name.
/// - `AddColumn` removes the added field.
/// - `RenameColumn` renames the field back.
///
/// Every other operation leaves `state` untouched. That includes `DropTable`,
/// `DropColumn`, `AlterColumn` and the constraint operations, whose previous state
/// is not carried by the operation.
pub fn state_backwards(&self, app_label: &str, state: &mut ProjectState) {
    match self {
        Operation::CreateTable { name, .. } => {
            let model_key = (app_label.to_string(), name.to_string());
            state.models.remove(&model_key);
        }
        Operation::RenameTable { old_name, new_name } => {
            let renamed_key = (app_label.to_string(), new_name.to_string());
            let Some(mut model) = state.models.remove(&renamed_key) else {
                return;
            };
            model.table_name = old_name.to_string();
            let restored_key = (app_label.to_string(), old_name.to_string());
            state.models.insert(restored_key, model);
        }
        Operation::AddColumn { table, column, .. } => {
            if let Some(model) = state.find_model_by_table_mut(table) {
                model.remove_field(&column.name);
            }
        }
        Operation::RenameColumn {
            table,
            old_name,
            new_name,
        } => {
            if let Some(model) = state.find_model_by_table_mut(table) {
                model.rename_field(new_name, old_name.to_string());
            }
        }
        Operation::AddConstraint { table, .. } => {
            // The model is looked up but not modified: constraint removal is not yet
            // reflected in state.
            let _ = state.find_model_by_table_mut(table);
        }
        _ => {}
    }
}
/// Extracts the constraint name from a definition such as
/// `CONSTRAINT chk_age CHECK (age >= 0)`.
///
/// The `CONSTRAINT` keyword is matched case-insensitively as a whole word. The name that
/// follows may be:
/// - unquoted, ending at whitespace or `(`;
/// - quoted with `"` or a backtick, in which case the quotes are stripped and doubled
///   quote characters are unescaped.
///
/// Because quotes are stripped, callers can re-quote the result with `quote_identifier`
/// without double-quoting it.
///
/// Returns `None` when there is no `CONSTRAINT <name>` clause, the name is empty, or a
/// quoted name is unterminated.
fn extract_constraint_name(constraint_sql: &str) -> Option<String> {
    // Reads one (possibly quoted) identifier from the start of `rest`.
    fn read_identifier(rest: &str) -> Option<String> {
        let mut chars = rest.chars().peekable();
        let first = *chars.peek()?;
        let name: String = if first == '"' || first == '`' {
            chars.next();
            let mut name = String::new();
            loop {
                match chars.next()? {
                    c if c == first => {
                        if chars.peek() == Some(&first) {
                            chars.next();
                            name.push(first);
                        } else {
                            break;
                        }
                    }
                    c => name.push(c),
                }
            }
            name
        } else {
            chars
                .take_while(|c| !c.is_whitespace() && *c != '(')
                .collect()
        };
        (!name.is_empty()).then_some(name)
    }

    const KEYWORD: &str = "CONSTRAINT";
    let sql = constraint_sql.trim();
    // ASCII upper-casing preserves byte offsets, so indices into `upper` are valid in `sql`.
    let upper = sql.to_ascii_uppercase();
    let mut search_from = 0;
    while let Some(offset) = upper[search_from..].find(KEYWORD) {
        let start = search_from + offset;
        let end = start + KEYWORD.len();
        let preceded_by_space = match sql[..start].chars().next_back() {
            None => true,
            Some(c) => c.is_whitespace(),
        };
        let followed_by_space = sql[end..].starts_with(char::is_whitespace);
        if preceded_by_space && followed_by_space {
            return read_identifier(sql[end..].trim_start());
        }
        search_from = end;
    }
    None
}
}
/// A single column as used when rendering DDL (`CREATE TABLE`, `ADD COLUMN`,
/// SQLite table recreation, ...).
///
/// Serialized with serde; the boolean flags and `default` fall back to `false` /
/// `None` when absent from the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ColumnDefinition {
/// Column name (unquoted).
pub name: String,
/// Logical column type; rendered per dialect via `to_sql_for_dialect`.
pub type_definition: FieldType,
/// Whether the column is declared `NOT NULL`.
#[serde(default)]
pub not_null: bool,
/// Whether the column carries a `UNIQUE` constraint.
#[serde(default)]
pub unique: bool,
/// Whether the column is (part of) the primary key.
#[serde(default)]
pub primary_key: bool,
/// Whether the column's value is generated automatically.
#[serde(default)]
pub auto_increment: bool,
/// Raw default taken from the field's params. Whether it is treated as a literal or an
/// expression is decided by the column renderer (not visible here) — TODO confirm.
#[serde(default)]
pub default: Option<String>,
}
impl ColumnDefinition {
    /// Creates a nullable, non-unique, non-key column with no default value.
    pub fn new(name: impl Into<String>, type_def: FieldType) -> Self {
        let name = name.into();
        Self {
            name,
            type_definition: type_def,
            not_null: false,
            unique: false,
            primary_key: false,
            auto_increment: false,
            default: None,
        }
    }

    /// Builds a column from a migration-state field.
    ///
    /// The `primary_key`, `unique` and `auto_increment` flags are read from
    /// `field_state.params`. Each counts as set only when its value parses as the
    /// boolean `true`. Primary keys are always `NOT NULL`; other columns are `NOT NULL`
    /// when the field is not nullable.
    ///
    /// For foreign keys, the column type is the referenced model's primary-key type when
    /// it can be resolved through the global model registry. Otherwise it is
    /// `field_state.field_type`.
    pub fn from_field_state(name: impl Into<String>, field_state: &FieldState) -> Self {
        let name = name.into();
        let params = &field_state.params;
        let flag = |key: &str| -> bool {
            params
                .get(key)
                .and_then(|raw| raw.parse::<bool>().ok())
                .unwrap_or(false)
        };
        let primary_key = flag("primary_key");
        let type_definition = match resolve_foreign_key_column_type(field_state) {
            Some(fk_type) => fk_type,
            None => field_state.field_type.clone(),
        };
        Self {
            name,
            type_definition,
            not_null: primary_key || !field_state.nullable,
            unique: flag("unique"),
            primary_key,
            auto_increment: flag("auto_increment"),
            default: params.get("default").cloned(),
        }
    }
}
/// Resolves the column type of a foreign-key field against the process-wide model
/// registry.
///
/// Returns `None` for non-FK fields or when the target cannot be resolved.
fn resolve_foreign_key_column_type(field_state: &FieldState) -> Option<FieldType> {
    let registry = super::model_registry::global_registry();
    resolve_foreign_key_column_type_with(field_state, registry)
}
/// Resolves a foreign-key field's column type using `registry`.
///
/// The target model comes from the `fk_target` param. When `fk_target_app` is present,
/// an app-qualified lookup is tried first; in all cases a lookup by bare model name is
/// the fallback. If neither resolves and the bare name matches more than one
/// registered model, a warning is logged instead of guessing.
///
/// Returns the field type of the target's first field whose `primary_key` param is
/// `"true"`. Returns `None` if the field has no `fk_target`, the target is unresolved,
/// or the target has no primary-key field.
fn resolve_foreign_key_column_type_with(
    field_state: &FieldState,
    registry: &super::model_registry::ModelRegistry,
) -> Option<FieldType> {
    let params = &field_state.params;
    let target_model = params.get("fk_target")?;
    let resolved = params
        .get("fk_target_app")
        .and_then(|app| registry.find_model_qualified(app, target_model))
        .or_else(|| registry.find_model_by_name(target_model));
    let Some(target) = resolved else {
        if registry.count_models_by_name(target_model) > 1 {
            tracing::warn!(
                model_name = %target_model,
                fk_target_app = ?field_state.params.get("fk_target_app"),
                "FK target name is ambiguous across apps and the qualified \
                 lookup did not resolve a unique target. Refusing to resolve \
                 to avoid silent wrong-target resolution. Ensure the FK \
                 target type is registered and that its `Model::app_label()` \
                 matches one of the registered apps.",
            );
        }
        return None;
    };
    target
        .fields
        .values()
        .find(|candidate| {
            candidate.params.get("primary_key").map(String::as_str) == Some("true")
        })
        .map(|pk_field| pk_field.field_type.clone())
}
/// Maps a Django-style field class name to a `FieldType`.
///
/// Only the last dotted segment of `field_type` is considered, so
/// `"django.db.models.CharField"` and `"CharField"` are equivalent. Numeric attributes
/// (`max_length`, `max_digits`, `decimal_places`) are parsed as `u32`. A missing or
/// unparsable attribute falls back to the per-type default; `CharField` is the exception
/// and requires `max_length`.
///
/// # Errors
///
/// Returns a message when `CharField` lacks a usable `max_length`, or when the type
/// name is not recognized.
pub fn field_type_string_to_field_type(
    field_type: &str,
    attributes: &std::collections::HashMap<String, String>,
) -> Result<FieldType, String> {
    let numeric = |key: &str| attributes.get(key).and_then(|raw| raw.parse::<u32>().ok());
    let varchar_or = |fallback: u32| FieldType::VarChar(numeric("max_length").unwrap_or(fallback));
    let type_name = field_type.rsplit('.').next().unwrap_or(field_type);
    let resolved = match type_name {
        "IntegerField"
        | "PositiveIntegerField"
        | "SmallIntegerField"
        | "PositiveSmallIntegerField"
        | "AutoField" => FieldType::Integer,
        "BigIntegerField"
        | "PositiveBigIntegerField"
        | "BigAutoField"
        | "DurationField"
        | "ForeignKey"
        | "OneToOneField" => FieldType::BigInteger,
        "SmallAutoField" => FieldType::SmallInteger,
        "CharField" => FieldType::VarChar(
            numeric("max_length")
                .ok_or_else(|| "CharField requires max_length attribute".to_string())?,
        ),
        "TextField" => FieldType::Text,
        "SlugField" => varchar_or(50),
        "EmailField" => varchar_or(254),
        "URLField" => varchar_or(200),
        "FileField" | "ImageField" => varchar_or(100),
        // Long enough for the textual form of an IPv6 address.
        "GenericIPAddressField" | "IPAddressField" => FieldType::VarChar(39),
        "BooleanField" | "NullBooleanField" => FieldType::Boolean,
        "DateField" => FieldType::Date,
        "TimeField" => FieldType::Time,
        "DateTimeField" => FieldType::DateTime,
        "FloatField" => FieldType::Float,
        "DecimalField" => FieldType::Decimal {
            precision: numeric("max_digits").unwrap_or(10),
            scale: numeric("decimal_places").unwrap_or(2),
        },
        "BinaryField" => FieldType::Binary,
        "UUIDField" => FieldType::Uuid,
        "JSONField" => FieldType::Json,
        other => return Err(format!("Unsupported field type: {}", other)),
    };
    Ok(resolved)
}
/// Target SQL dialect used when rendering operations to SQL text.
///
/// In several renderers in this module CockroachDB shares the PostgreSQL output
/// (e.g. `setval`, `COPY`), while other operations special-case it.
#[derive(Debug, Clone, Copy)]
pub enum SqlDialect {
/// SQLite.
Sqlite,
/// PostgreSQL.
Postgres,
/// MySQL.
Mysql,
/// CockroachDB.
Cockroachdb,
}
/// Plan for emulating an `ALTER TABLE` on SQLite by rebuilding the table.
///
/// `to_sql_statements` renders it as four statements:
/// 1. create `<table>_new`;
/// 2. copy rows with `INSERT ... SELECT`;
/// 3. drop the original table;
/// 4. rename the new table into place.
#[derive(Debug, Clone)]
pub struct SqliteTableRecreation {
/// Name of the table being rebuilt.
pub table_name: String,
/// Column definitions of the rebuilt table, in declaration order.
pub new_columns: Vec<ColumnDefinition>,
/// Columns selected from the old table; they are inserted positionally into the new one.
pub columns_to_copy: Vec<String>,
/// Structured table constraints emitted after the columns.
pub constraints: Vec<Constraint>,
/// Extra constraint clauses emitted verbatim after `constraints`.
pub raw_constraint_sqls: Vec<String>,
/// When `true`, the new table is created `WITHOUT ROWID`.
pub without_rowid: bool,
}
impl SqliteTableRecreation {
    /// Plans a rebuild that removes `column_to_drop`.
    ///
    /// The column is removed from the definitions and the copy list. Any constraint
    /// that references it is dropped as well.
    pub fn for_drop_column(
        table_name: impl Into<String>,
        current_columns: Vec<ColumnDefinition>,
        column_to_drop: &str,
        current_constraints: Vec<Constraint>,
    ) -> Self {
        let table_name = table_name.into();
        let new_columns: Vec<_> = current_columns
            .into_iter()
            .filter(|c| c.name != column_to_drop)
            .collect();
        let columns_to_copy: Vec<_> = new_columns.iter().map(|c| c.name.to_string()).collect();
        let constraints: Vec<_> = current_constraints
            .into_iter()
            .filter(|c| !Self::constraint_references_column(c, column_to_drop))
            .collect();
        Self {
            table_name,
            new_columns,
            columns_to_copy,
            constraints,
            raw_constraint_sqls: Vec::new(),
            without_rowid: false,
        }
    }

    /// Plans a rebuild that replaces the definition of `column_name` with `new_definition`.
    pub fn for_alter_column(
        table_name: impl Into<String>,
        current_columns: Vec<ColumnDefinition>,
        column_name: &str,
        new_definition: ColumnDefinition,
        current_constraints: Vec<Constraint>,
    ) -> Self {
        let table_name = table_name.into();
        let new_columns: Vec<_> = current_columns
            .into_iter()
            .map(|c| {
                if c.name == column_name {
                    new_definition.clone()
                } else {
                    c
                }
            })
            .collect();
        let columns_to_copy: Vec<_> = new_columns.iter().map(|c| c.name.to_string()).collect();
        Self {
            table_name,
            new_columns,
            columns_to_copy,
            constraints: current_constraints,
            raw_constraint_sqls: Vec::new(),
            without_rowid: false,
        }
    }

    /// Plans a rebuild that appends the raw constraint clause `constraint_sql`.
    pub fn for_add_constraint(
        table_name: impl Into<String>,
        current_columns: Vec<ColumnDefinition>,
        current_constraints: Vec<Constraint>,
        constraint_sql: String,
    ) -> Self {
        let table_name = table_name.into();
        let columns_to_copy: Vec<_> = current_columns.iter().map(|c| c.name.to_string()).collect();
        Self {
            table_name,
            new_columns: current_columns,
            columns_to_copy,
            constraints: current_constraints,
            raw_constraint_sqls: vec![constraint_sql],
            without_rowid: false,
        }
    }

    /// Plans a rebuild without the constraint named `constraint_name`.
    pub fn for_drop_constraint(
        table_name: impl Into<String>,
        current_columns: Vec<ColumnDefinition>,
        current_constraints: Vec<Constraint>,
        constraint_name: &str,
    ) -> Self {
        let table_name = table_name.into();
        let columns_to_copy: Vec<_> = current_columns.iter().map(|c| c.name.to_string()).collect();
        let constraints: Vec<_> = current_constraints
            .into_iter()
            .filter(|c| !Self::constraint_has_name(c, constraint_name))
            .collect();
        Self {
            table_name,
            new_columns: current_columns,
            columns_to_copy,
            constraints,
            raw_constraint_sqls: Vec::new(),
            without_rowid: false,
        }
    }

    /// Renders the rebuild as four statements: `CREATE TABLE "<table>_new"`,
    /// `INSERT ... SELECT`, `DROP TABLE`, and `ALTER TABLE ... RENAME TO`.
    ///
    /// Identifiers are double-quoted with embedded `"` doubled, so unusual table or
    /// column names cannot break the statements. Ordinary names render unchanged.
    pub fn to_sql_statements(&self) -> Vec<String> {
        let table = Self::quote_sqlite_identifier(&self.table_name);
        let temp_table = Self::quote_sqlite_identifier(&format!("{}_new", self.table_name));
        let create_parts: Vec<String> = self
            .new_columns
            .iter()
            .map(|c| Operation::column_to_sql(c, &SqlDialect::Sqlite))
            .chain(self.constraints.iter().map(|c| c.to_string()))
            .chain(self.raw_constraint_sqls.iter().cloned())
            .collect();
        let mut create_sql = format!(
            "CREATE TABLE {} (\n {}\n)",
            temp_table,
            create_parts.join(",\n ")
        );
        if self.without_rowid {
            create_sql.push_str(" WITHOUT ROWID");
        }
        create_sql.push(';');
        let columns_list = self
            .columns_to_copy
            .iter()
            .map(|c| Self::quote_sqlite_identifier(c))
            .collect::<Vec<_>>()
            .join(", ");
        let insert_sql = format!(
            "INSERT INTO {} SELECT {} FROM {};",
            temp_table, columns_list, table
        );
        let drop_sql = format!("DROP TABLE {};", table);
        let rename_sql = format!("ALTER TABLE {} RENAME TO {};", temp_table, table);
        vec![create_sql, insert_sql, drop_sql, rename_sql]
    }

    /// Double-quotes `name` as a SQLite identifier, doubling any embedded `"`.
    fn quote_sqlite_identifier(name: &str) -> String {
        format!("\"{}\"", name.replace('"', "\"\""))
    }

    /// Whether `constraint` involves `column_name` and must be dropped with it.
    ///
    /// CHECK expressions are matched on whole-word occurrences, so dropping `id` does
    /// not discard checks on `paid` or `valid`. String literals inside the expression
    /// are not parsed and may still cause a conservative match.
    fn constraint_references_column(constraint: &Constraint, column_name: &str) -> bool {
        match constraint {
            Constraint::PrimaryKey { columns, .. } => columns.iter().any(|c| c == column_name),
            Constraint::ForeignKey { columns, .. } => columns.iter().any(|c| c == column_name),
            Constraint::Unique { columns, .. } => columns.iter().any(|c| c == column_name),
            Constraint::Check { expression, .. } => {
                Self::expression_mentions_identifier(expression, column_name)
            }
            Constraint::OneToOne { column, .. } => column == column_name,
            Constraint::ManyToMany { source_column, .. } => source_column == column_name,
            Constraint::Exclude { elements, .. } => {
                elements.iter().any(|(col, _)| col == column_name)
            }
        }
    }

    /// Returns `true` if `identifier` occurs in `expression` as a whole word.
    ///
    /// "Whole word" means the match is not directly preceded or followed by an
    /// identifier character (alphanumeric, `_`, `$`). Quoted occurrences such as
    /// `"price"` match. An empty identifier never matches.
    fn expression_mentions_identifier(expression: &str, identifier: &str) -> bool {
        if identifier.is_empty() {
            return false;
        }
        let is_ident_char = |c: char| c.is_alphanumeric() || c == '_' || c == '$';
        let mut search_from = 0;
        while let Some(offset) = expression[search_from..].find(identifier) {
            let start = search_from + offset;
            let end = start + identifier.len();
            let clean_before = !expression[..start]
                .chars()
                .next_back()
                .is_some_and(is_ident_char);
            let clean_after = !expression[end..].chars().next().is_some_and(is_ident_char);
            if clean_before && clean_after {
                return true;
            }
            // Advance one character so overlapping occurrences are still examined.
            search_from = start + expression[start..].chars().next().map_or(1, char::len_utf8);
        }
        false
    }

    /// Whether `constraint` is named `constraint_name`.
    fn constraint_has_name(constraint: &Constraint, constraint_name: &str) -> bool {
        match constraint {
            Constraint::PrimaryKey { name, .. } => name == constraint_name,
            Constraint::ForeignKey { name, .. } => name == constraint_name,
            Constraint::Unique { name, .. } => name == constraint_name,
            Constraint::Check { name, .. } => name == constraint_name,
            Constraint::OneToOne { name, .. } => name == constraint_name,
            Constraint::ManyToMany { name, .. } => name == constraint_name,
            Constraint::Exclude { name, .. } => name == constraint_name,
        }
    }
}
impl Operation {
pub fn requires_sqlite_recreation(&self) -> bool {
matches!(
self,
Operation::DropColumn { .. }
| Operation::AlterColumn { .. }
| Operation::AddConstraint { .. }
| Operation::DropConstraint { .. }
)
}
pub fn reverse_requires_sqlite_recreation(&self) -> bool {
matches!(
self,
Operation::AddColumn { .. }
| Operation::AlterColumn { .. }
| Operation::AddConstraint { .. }
| Operation::DropConstraint { .. }
)
}
pub fn to_reverse_operation(
&self,
project_state: &ProjectState,
) -> super::Result<Option<Operation>> {
match self {
Operation::CreateTable { name, .. } => {
Ok(Some(Operation::DropTable { name: name.clone() }))
}
Operation::DropTable { name } => {
if let Some(model) = project_state.find_model_by_table(name) {
let columns: Vec<ColumnDefinition> = model
.fields
.iter()
.map(|(field_name, field)| {
ColumnDefinition::from_field_state(field_name.clone(), field)
})
.collect();
let constraints: Vec<Constraint> = model
.constraints
.iter()
.map(|c| c.to_constraint())
.collect();
return Ok(Some(Operation::CreateTable {
name: name.clone(),
columns,
constraints,
without_rowid: None,
interleave_in_parent: None,
partition: None,
}));
}
Ok(None)
}
Operation::AddColumn { table, column, .. } => Ok(Some(Operation::DropColumn {
table: table.clone(),
column: column.name.clone(),
})),
Operation::DropColumn { table, column } => {
if let Some(model) = project_state.find_model_by_table(table)
&& let Some(field) = model.get_field(column)
{
let col_def = ColumnDefinition::from_field_state(column.clone(), field);
return Ok(Some(Operation::AddColumn {
table: table.clone(),
column: col_def,
mysql_options: None,
}));
}
Ok(None)
}
Operation::AlterColumn {
table,
column,
old_definition,
new_definition: _,
..
} => {
let resolved_old_def = old_definition.clone().or_else(|| {
project_state
.find_model_by_table(table)
.and_then(|model| model.get_field(column))
.map(|field| ColumnDefinition::from_field_state(column.clone(), field))
});
if let Some(col_def) = resolved_old_def {
return Ok(Some(Operation::AlterColumn {
table: table.clone(),
column: column.clone(),
old_definition: None,
new_definition: col_def,
mysql_options: None,
}));
}
Ok(None)
}
Operation::AddConstraint {
table,
constraint_sql,
} => {
if let Some(constraint_name) = Self::extract_constraint_name(constraint_sql) {
return Ok(Some(Operation::DropConstraint {
table: table.clone(),
constraint_name,
}));
}
Err(super::MigrationError::InvalidMigration(format!(
"Cannot extract constraint name from: {}",
constraint_sql
)))
}
Operation::DropConstraint {
table,
constraint_name,
} => {
if let Some(model) = project_state.find_model_by_table(table)
&& let Some(constraint_def) = model
.constraints
.iter()
.find(|c| c.name == *constraint_name)
{
let constraint = constraint_def.to_constraint();
return Ok(Some(Operation::AddConstraint {
table: table.clone(),
constraint_sql: format!("{}", constraint),
}));
}
Ok(None)
}
Operation::RenameTable { old_name, new_name } => Ok(Some(Operation::RenameTable {
old_name: new_name.clone(),
new_name: old_name.clone(),
})),
Operation::RenameColumn {
table,
old_name,
new_name,
} => Ok(Some(Operation::RenameColumn {
table: table.clone(),
old_name: new_name.clone(),
new_name: old_name.clone(),
})),
Operation::CreateIndex { table, columns, .. } => Ok(Some(Operation::DropIndex {
table: table.clone(),
columns: columns.clone(),
})),
Operation::DropIndex { table, columns } => {
Ok(Some(Operation::CreateIndex {
table: table.clone(),
columns: columns.clone(),
unique: false,
index_type: None,
where_clause: None,
concurrently: false,
expressions: None,
mysql_options: None,
operator_class: None,
}))
}
Operation::RunSQL { .. } | Operation::RunRust { .. } | Operation::BulkLoad { .. } => {
Ok(None)
}
_ => Ok(None),
}
}
}
pub use Operation::{AddColumn, AlterColumn, CreateTable, DropColumn};
/// A migration operation lowered to an executable statement.
///
/// Builder-based variants are rendered per database by `to_sql_string`. `RawSql` is
/// passed through unchanged.
pub enum OperationStatement {
/// `CREATE TABLE` built with the query builder.
TableCreate(CreateTableStatement),
/// `DROP TABLE` built with the query builder.
TableDrop(DropTableStatement),
/// General `ALTER TABLE` (add/drop/alter column, ...).
TableAlter(AlterTableStatement),
/// `ALTER TABLE` used for a table rename; rendered the same way as `TableAlter`.
TableRename(AlterTableStatement),
/// `CREATE INDEX` built with the query builder.
IndexCreate(CreateIndexStatement),
/// `DROP INDEX` built with the query builder.
IndexDrop(DropIndexStatement),
/// Pre-rendered SQL text, executed as-is.
RawSql(String),
}
impl OperationStatement {
    /// Executes the statement through a PostgreSQL executor (connection, pool or
    /// transaction).
    ///
    /// The SQL is rendered for PostgreSQL by `to_sql_string`, because the executor is
    /// bound to `sqlx::Postgres`.
    ///
    /// # Errors
    ///
    /// Propagates any `sqlx::Error` raised while running the query.
    pub async fn execute<'c, E>(&self, executor: E) -> Result<(), sqlx::Error>
    where
        E: sqlx::Executor<'c, Database = sqlx::Postgres>,
    {
        use crate::backends::types::DatabaseType;
        let rendered = self.to_sql_string(DatabaseType::Postgres);
        sqlx::query(&rendered).execute(executor).await?;
        Ok(())
    }

    /// Renders the statement as SQL text for `db_type`.
    ///
    /// `TableAlter` and `TableRename` share the same `ALTER TABLE` builder. `RawSql`
    /// is returned verbatim.
    pub fn to_sql_string(&self, db_type: crate::backends::types::DatabaseType) -> String {
        use crate::backends::sql_build_helpers as build;
        match self {
            Self::TableCreate(stmt) => build::build_create_table_sql(db_type, stmt),
            Self::TableDrop(stmt) => build::build_drop_table_sql(db_type, stmt),
            Self::TableAlter(stmt) | Self::TableRename(stmt) => {
                build::build_alter_table_sql(db_type, stmt)
            }
            Self::IndexCreate(stmt) => build::build_create_index_sql(db_type, stmt),
            Self::IndexDrop(stmt) => build::build_drop_index_sql(db_type, stmt),
            Self::RawSql(sql) => sql.clone(),
        }
    }
}
impl Operation {
pub fn to_statement(&self) -> OperationStatement {
match self {
Operation::CreateTable {
name,
columns,
constraints,
..
} => {
OperationStatement::TableCreate(self.build_create_table(name, columns, constraints))
}
Operation::DropTable { name } => {
OperationStatement::TableDrop(self.build_drop_table(name))
}
Operation::AddColumn { table, column, .. } => {
OperationStatement::TableAlter(self.build_add_column(table, column))
}
Operation::DropColumn { table, column } => {
OperationStatement::TableAlter(self.build_drop_column(table, column))
}
Operation::AlterColumn {
table,
column,
new_definition,
..
} => OperationStatement::TableAlter(self.build_alter_column(
table,
column,
new_definition,
)),
Operation::RenameTable { old_name, new_name } => {
OperationStatement::TableRename(self.build_rename_table(old_name, new_name))
}
Operation::RenameColumn {
table,
old_name,
new_name,
} => OperationStatement::RawSql(format!(
"ALTER TABLE {} RENAME COLUMN {} TO {}",
quote_identifier(table),
quote_identifier(old_name),
quote_identifier(new_name)
)),
Operation::AddConstraint {
table,
constraint_sql,
} => {
OperationStatement::RawSql(format!(
"ALTER TABLE {} ADD {}",
quote_identifier(table),
constraint_sql
))
}
Operation::DropConstraint {
table,
constraint_name,
} => OperationStatement::RawSql(format!(
"ALTER TABLE {} DROP CONSTRAINT {}",
quote_identifier(table),
quote_identifier(constraint_name)
)),
Operation::CreateIndex {
table,
columns,
unique,
..
} => {
let idx_name = format!("idx_{}_{}", table, columns.join("_"));
OperationStatement::IndexCreate(
self.build_create_index(&idx_name, table, columns, *unique),
)
}
Operation::DropIndex { table, columns } => {
let idx_name = format!("idx_{}_{}", table, columns.join("_"));
OperationStatement::IndexDrop(self.build_drop_index(&idx_name))
}
Operation::RunSQL { sql, .. } => OperationStatement::RawSql(sql.to_string()),
Operation::RunRust { code, .. } => {
OperationStatement::RawSql(format!(
"-- RunRust: {}",
code.lines().next().unwrap_or("")
))
}
Operation::AlterTableComment { table, comment } => {
OperationStatement::RawSql(if let Some(comment_text) = comment {
format!(
"COMMENT ON TABLE {} IS '{}'",
quote_identifier(table),
comment_text.replace('\'', "''") )
} else {
format!("COMMENT ON TABLE {} IS NULL", quote_identifier(table))
})
}
Operation::AlterUniqueTogether {
table,
unique_together,
} => {
let mut sqls = Vec::new();
for (idx, fields) in unique_together.iter().enumerate() {
let constraint_name = format!("{}_{}_uniq", table, idx);
let fields_str: Vec<String> = fields
.iter()
.map(|f| quote_identifier(f).to_string())
.collect();
sqls.push(format!(
"ALTER TABLE {} ADD CONSTRAINT {} UNIQUE ({})",
quote_identifier(table),
quote_identifier(&constraint_name),
fields_str.join(", ")
));
}
OperationStatement::RawSql(sqls.join(";\n"))
}
Operation::AlterModelOptions { .. } => OperationStatement::RawSql(String::new()),
Operation::CreateInheritedTable {
name,
columns,
base_table,
join_column,
} => {
let mut stmt = Query::create_table();
stmt.table(Alias::new(name.as_str())).if_not_exists();
let join_col = ColumnDef::new(Alias::new(join_column.as_str()));
let join_col = join_col.integer();
stmt.col(join_col);
for col in columns {
let mut column = ColumnDef::new(Alias::new(col.name.as_str()));
column = self.apply_column_type(column, &col.type_definition);
stmt.col(column);
}
let mut fk = reinhardt_query::prelude::ForeignKey::create();
fk.from_tbl(Alias::new(name.as_str()))
.from_col(Alias::new(join_column.as_str()))
.to_tbl(Alias::new(base_table.as_str()))
.to_col(Alias::new("id"));
stmt.foreign_key_from_builder(&mut fk);
OperationStatement::TableCreate(stmt.to_owned())
}
Operation::AddDiscriminatorColumn {
table,
column_name,
default_value,
} => {
let mut stmt = Query::alter_table();
stmt.table(Alias::new(table.as_str()));
let mut col = ColumnDef::new(Alias::new(column_name.as_str()));
col = col
.string_len(50)
.default(SimpleExpr::from(default_value.to_string()));
stmt.add_column(col);
OperationStatement::TableAlter(stmt.to_owned())
}
Operation::MoveModel {
rename_table,
old_table_name,
new_table_name,
..
} => {
if *rename_table {
if let (Some(old_name), Some(new_name)) = (old_table_name, new_table_name) {
OperationStatement::TableRename(self.build_rename_table(old_name, new_name))
} else {
OperationStatement::RawSql("-- MoveModel: State-only operation".to_string())
}
} else {
OperationStatement::RawSql("-- MoveModel: State-only operation".to_string())
}
}
Operation::CreateSchema {
name,
if_not_exists,
} => {
let sql = if *if_not_exists {
format!("CREATE SCHEMA IF NOT EXISTS {}", quote_identifier(name))
} else {
format!("CREATE SCHEMA {}", quote_identifier(name))
};
OperationStatement::RawSql(sql)
}
Operation::DropSchema {
name,
cascade,
if_exists,
} => {
let if_exists_clause = if *if_exists { " IF EXISTS" } else { "" };
let cascade_clause = if *cascade { " CASCADE" } else { "" };
let sql = format!(
"DROP SCHEMA{} {}{}",
if_exists_clause,
quote_identifier(name),
cascade_clause
);
OperationStatement::RawSql(sql)
}
Operation::CreateExtension {
name,
if_not_exists,
schema,
} => {
let if_not_exists_clause = if *if_not_exists { " IF NOT EXISTS" } else { "" };
let schema_clause = if let Some(s) = schema {
format!(" SCHEMA {}", quote_identifier(s))
} else {
String::new()
};
let sql = format!(
"CREATE EXTENSION{} {}{}",
if_not_exists_clause,
quote_identifier(name),
schema_clause
);
OperationStatement::RawSql(sql)
}
Operation::BulkLoad {
table,
source,
format,
options,
} => {
OperationStatement::RawSql(Self::postgres_copy_from_sql(
table, source, format, options,
))
}
Operation::SetAutoIncrementValue { table, .. } => {
OperationStatement::RawSql(format!(
"SELECT 1/0 AS \"SetAutoIncrementValue on {} requires dialect-aware rendering; call Operation::to_sql(&dialect) instead of to_statement()\";",
table.replace('"', "\"\"")
))
}
Operation::CreateCompositePrimaryKey {
table,
columns,
constraint_name,
} => OperationStatement::RawSql(Self::create_composite_pk_to_sql(
table,
columns,
constraint_name.as_deref(),
)),
}
}
/// Builds a `CREATE TABLE IF NOT EXISTS` statement for the table `name`.
///
/// Each [`ColumnDefinition`] contributes a column carrying its type (via
/// `apply_column_type`). When set, it also carries `NOT NULL`, `UNIQUE`,
/// `PRIMARY KEY`, auto-increment and a default converted by
/// `convert_default_value`.
///
/// Table-level constraints are handled as follows:
/// - Primary-key, foreign-key, unique and one-to-one constraints are added to
///   the statement.
/// - `CHECK`, many-to-many and exclusion constraints are not emitted by this
///   builder.
fn build_create_table(
    &self,
    name: &str,
    columns: &[ColumnDefinition],
    constraints: &[Constraint],
) -> CreateTableStatement {
    let mut stmt = Query::create_table();
    stmt.table(Alias::new(name)).if_not_exists();
    for col in columns {
        let mut column = ColumnDef::new(Alias::new(col.name.as_str()));
        column = self.apply_column_type(column, &col.type_definition);
        if col.not_null {
            column = column.not_null(true);
        }
        if col.unique {
            column = column.unique(true);
        }
        if col.primary_key {
            column = column.primary_key(true);
        }
        if col.auto_increment {
            column = column.auto_increment(true);
        }
        if let Some(default) = &col.default {
            column = column.default(SimpleExpr::from(self.convert_default_value(default)));
        }
        stmt.col(column);
    }
    for constraint in constraints {
        match constraint {
            Constraint::PrimaryKey { columns, .. } => {
                let col_idens: Vec<Alias> =
                    columns.iter().map(|c| Alias::new(c.as_str())).collect();
                stmt.primary_key(col_idens);
            }
            Constraint::ForeignKey {
                name: constraint_name,
                columns,
                referenced_table,
                referenced_columns,
                on_delete,
                on_update,
                ..
            } => {
                let mut fk = reinhardt_query::prelude::ForeignKey::create();
                // Bind the constraint's name separately. Destructuring it as
                // `name` would shadow the table name, and the FK would then
                // claim to originate from a table named after the constraint.
                fk.name(Alias::new(constraint_name.as_str()))
                    .from_tbl(Alias::new(name))
                    .to_tbl(Alias::new(referenced_table.as_str()));
                for col in columns {
                    fk.from_col(Alias::new(col.as_str()));
                }
                for col in referenced_columns {
                    fk.to_col(Alias::new(col.as_str()));
                }
                fk.on_delete((*on_delete).into());
                fk.on_update((*on_update).into());
                stmt.foreign_key_from_builder(&mut fk);
            }
            Constraint::Unique { columns, .. } => {
                let col_idens: Vec<Alias> =
                    columns.iter().map(|c| Alias::new(c.as_str())).collect();
                stmt.unique(col_idens);
            }
            Constraint::Check { .. } => {
                // NOTE(review): CHECK constraints are currently not rendered by
                // this builder and are silently skipped — confirm they are
                // emitted through another path.
            }
            Constraint::OneToOne {
                name: constraint_name,
                column,
                referenced_table,
                referenced_column,
                on_delete,
                on_update,
                ..
            } => {
                let mut fk = reinhardt_query::prelude::ForeignKey::create();
                // Same shadowing concern as the ForeignKey arm: the source
                // table is the table being created, not the constraint name.
                fk.name(Alias::new(constraint_name.as_str()))
                    .from_tbl(Alias::new(name))
                    .to_tbl(Alias::new(referenced_table.as_str()))
                    .from_col(Alias::new(column.as_str()))
                    .to_col(Alias::new(referenced_column.as_str()))
                    .on_delete((*on_delete).into())
                    .on_update((*on_update).into());
                stmt.foreign_key_from_builder(&mut fk);
            }
            Constraint::ManyToMany { .. } => {
                // Many-to-many relations produce no column or constraint on
                // this table; nothing to add here.
            }
            Constraint::Exclude { .. } => {
                // Exclusion constraints are not expressible through this
                // builder; nothing is emitted here.
            }
        }
    }
    stmt.to_owned()
}
/// Builds `DROP TABLE IF EXISTS <name> CASCADE` for the given table.
fn build_drop_table(&self, name: &str) -> DropTableStatement {
    let mut drop = Query::drop_table();
    drop.table(Alias::new(name)).if_exists().cascade();
    drop.to_owned()
}
/// Builds an `ALTER TABLE <table> ADD COLUMN ...` statement for `column`.
///
/// Only the column's type, its `NOT NULL` flag and its default (converted via
/// `convert_default_value`) are applied. The `unique`, `primary_key` and
/// `auto_increment` flags are not read here.
fn build_add_column(&self, table: &str, column: &ColumnDefinition) -> AlterTableStatement {
    let typed = self.apply_column_type(
        ColumnDef::new(Alias::new(column.name.as_str())),
        &column.type_definition,
    );
    let constrained = if column.not_null {
        typed.not_null(true)
    } else {
        typed
    };
    let finished = match &column.default {
        Some(raw) => constrained.default(SimpleExpr::from(self.convert_default_value(raw))),
        None => constrained,
    };
    let mut alter = Query::alter_table();
    alter.table(Alias::new(table));
    alter.add_column(finished);
    alter.to_owned()
}
/// Builds `ALTER TABLE <table> DROP COLUMN <column>`.
fn build_drop_column(&self, table: &str, column: &str) -> AlterTableStatement {
    let mut alter = Query::alter_table();
    alter.table(Alias::new(table));
    alter.drop_column(Alias::new(column));
    alter.to_owned()
}
/// Builds an `ALTER TABLE` statement that modifies `column` on `table`.
///
/// The new type from `new_definition` is always applied. `NOT NULL` is added
/// only when `new_definition.not_null` is set. Defaults and the remaining
/// flags of `new_definition` are not read here.
fn build_alter_column(
    &self,
    table: &str,
    column: &str,
    new_definition: &ColumnDefinition,
) -> AlterTableStatement {
    let retyped = self.apply_column_type(
        ColumnDef::new(Alias::new(column)),
        &new_definition.type_definition,
    );
    let target = if new_definition.not_null {
        retyped.not_null(true)
    } else {
        retyped
    };
    let mut alter = Query::alter_table();
    alter.table(Alias::new(table));
    alter.modify_column(target);
    alter.to_owned()
}
/// Builds an `ALTER TABLE` statement renaming `old_name` to `new_name`.
fn build_rename_table(&self, old_name: &str, new_name: &str) -> AlterTableStatement {
    let (source, target) = (Alias::new(old_name), Alias::new(new_name));
    let mut alter = Query::alter_table();
    alter.table(source);
    alter.rename_table(target);
    alter.to_owned()
}
/// Builds a `CREATE [UNIQUE] INDEX <name> ON <table> (<columns>)` statement.
///
/// Columns are added in the order given; `unique` toggles the `UNIQUE` keyword.
fn build_create_index(
    &self,
    name: &str,
    table: &str,
    columns: &[String],
    unique: bool,
) -> CreateIndexStatement {
    let mut index = Query::create_index();
    index.name(Alias::new(name));
    index.table(Alias::new(table));
    columns.iter().for_each(|column| {
        index.col(Alias::new(column));
    });
    if unique {
        index.unique();
    }
    index.to_owned()
}
/// Builds `DROP INDEX <name>`.
fn build_drop_index(&self, name: &str) -> DropIndexStatement {
    let mut drop = Query::drop_index();
    drop.name(Alias::new(name));
    drop.to_owned()
}
fn apply_column_type(&self, col_def: ColumnDef, field_type: &FieldType) -> ColumnDef {
use FieldType;
match field_type {
FieldType::Integer => col_def.integer(),
FieldType::BigInteger => col_def.big_integer(),
FieldType::SmallInteger => col_def.small_integer(),
FieldType::TinyInt => col_def.tiny_integer(),
FieldType::VarChar(max_length) => col_def.string_len(*max_length),
FieldType::Char(max_length) => col_def.char_len(*max_length),
FieldType::Text | FieldType::TinyText | FieldType::MediumText | FieldType::LongText => {
col_def.text()
}
FieldType::Boolean => col_def.custom(Alias::new("BOOLEAN")),
FieldType::DateTime => col_def.timestamp(),
FieldType::TimestampTz => col_def.timestamp_with_time_zone(),
FieldType::Date => col_def.date(),
FieldType::Time => col_def.time(),
FieldType::Decimal { precision, scale } => col_def.decimal(*precision, *scale),
FieldType::Float => col_def.float(),
FieldType::Double | FieldType::Real => col_def.double(),
FieldType::Json => col_def.json(),
FieldType::JsonBinary => col_def.json_binary(),
FieldType::Uuid => col_def.uuid(),
FieldType::Binary | FieldType::Bytea => col_def.binary(0),
FieldType::Blob | FieldType::TinyBlob | FieldType::MediumBlob | FieldType::LongBlob => {
col_def.binary(0)
}
FieldType::MediumInt => col_def.integer(),
FieldType::Year => col_def.small_integer(),
FieldType::Enum { values } => {
col_def.custom(Alias::new(format!("ENUM({})", values.join(","))))
}
FieldType::Set { values } => {
col_def.custom(Alias::new(format!("SET({})", values.join(","))))
}
FieldType::ForeignKey { .. } => {
col_def.integer()
}
FieldType::OneToOne { .. } => {
col_def.big_integer()
}
FieldType::ManyToMany { .. } => {
col_def.big_integer()
}
FieldType::Array(inner) => {
let inner_sql = inner.to_sql_string();
col_def.custom(Alias::new(format!("{}[]", inner_sql)))
}
FieldType::HStore => col_def.custom(Alias::new("HSTORE")),
FieldType::CIText => col_def.custom(Alias::new("CITEXT")),
FieldType::Int4Range => col_def.custom(Alias::new("INT4RANGE")),
FieldType::Int8Range => col_def.custom(Alias::new("INT8RANGE")),
FieldType::NumRange => col_def.custom(Alias::new("NUMRANGE")),
FieldType::DateRange => col_def.custom(Alias::new("DATERANGE")),
FieldType::TsRange => col_def.custom(Alias::new("TSRANGE")),
FieldType::TsTzRange => col_def.custom(Alias::new("TSTZRANGE")),
FieldType::TsVector => col_def.custom(Alias::new("TSVECTOR")),
FieldType::TsQuery => col_def.custom(Alias::new("TSQUERY")),
FieldType::Custom(custom_type) => col_def.custom(Alias::new(custom_type)),
}
}
/// Converts a textual column default into a query [`Value`].
///
/// The input is trimmed and then classified, first match wins:
/// - `null` (any case) → `Value::String(None)`
/// - `true` / `false` (any case) → `Value::Bool`
/// - an integer that fits in `i64` → `Value::BigInt`
/// - a finite floating-point number → `Value::Double`
/// - text wrapped in matching `"…"` or `'…'` (at least two characters) → the
///   text between the quotes
/// - a JSON array/object → converted via `json_to_sea_value`
/// - anything containing `(` (e.g. `now()`) or a known SQL constant such as
///   `CURRENT_TIMESTAMP` → the trimmed text unchanged
/// - anything else → the text wrapped in single quotes, embedded `'` doubled
fn convert_default_value(&self, default: &str) -> Value {
    const SQL_CONSTANTS: &[&str] = &[
        "CURRENT_TIMESTAMP",
        "CURRENT_DATE",
        "CURRENT_TIME",
        "CURRENT_USER",
        "SESSION_USER",
        "LOCALTIME",
        "LOCALTIMESTAMP",
    ];
    let trimmed = default.trim();
    if trimmed.eq_ignore_ascii_case("null") {
        return Value::String(None);
    }
    if trimmed.eq_ignore_ascii_case("true") {
        return Value::Bool(Some(true));
    }
    if trimmed.eq_ignore_ascii_case("false") {
        return Value::Bool(Some(false));
    }
    if let Ok(i) = trimmed.parse::<i64>() {
        return Value::BigInt(Some(i));
    }
    // `f64::from_str` also accepts "inf", "infinity" and "NaN"; those are not
    // numeric SQL literals, so let them fall through to the string handling.
    if let Ok(f) = trimmed.parse::<f64>()
        && f.is_finite()
    {
        return Value::Double(Some(f));
    }
    // A lone `"` or `'` both starts and ends with the quote character; without
    // the length guard, slicing `[1..0]` below would panic.
    if trimmed.len() >= 2
        && ((trimmed.starts_with('"') && trimmed.ends_with('"'))
            || (trimmed.starts_with('\'') && trimmed.ends_with('\'')))
    {
        // The first and last bytes are ASCII quotes, so both slice bounds fall
        // on character boundaries.
        let unquoted = &trimmed[1..trimmed.len() - 1];
        return Value::String(Some(Box::new(unquoted.to_string())));
    }
    if ((trimmed.starts_with('[') && trimmed.ends_with(']'))
        || (trimmed.starts_with('{') && trimmed.ends_with('}')))
        && let Ok(json) = serde_json::from_str::<serde_json::Value>(trimmed)
    {
        return json_to_sea_value(&json);
    }
    // Function-call style defaults (e.g. `now()`, `gen_random_uuid()`) pass through.
    if trimmed.contains('(') {
        return Value::String(Some(Box::new(trimmed.to_string())));
    }
    if SQL_CONSTANTS
        .iter()
        .any(|c| trimmed.eq_ignore_ascii_case(c))
    {
        return Value::String(Some(Box::new(trimmed.to_string())));
    }
    // NOTE(review): this branch pre-quotes the text while the quoted-string
    // branch above strips quotes; confirm how `Value::String` is rendered so
    // both forms of the same literal end up identical.
    Value::String(Some(Box::new(format!("'{}'", trimmed.replace('\'', "''")))))
}
}
fn json_to_sea_value(json: &serde_json::Value) -> Value {
match json {
serde_json::Value::Null => Value::String(None),
serde_json::Value::Bool(b) => Value::Bool(Some(*b)),
serde_json::Value::Number(n) => {
if let Some(i) = n.as_i64() {
Value::BigInt(Some(i))
} else if let Some(f) = n.as_f64() {
Value::Double(Some(f))
} else {
Value::String(Some(Box::new(n.to_string())))
}
}
serde_json::Value::String(s) => Value::String(Some(Box::new(s.clone()))),
serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
Value::String(Some(Box::new(json.to_string())))
}
}
}
use super::operation_trait::MigrationOperation;
impl MigrationOperation for Operation {
    /// Returns a short lowercase fragment naming this operation, used when
    /// deriving a migration's file name.
    ///
    /// Raw SQL and Rust operations return `None`, since nothing descriptive can
    /// be derived from them.
    fn migration_name_fragment(&self) -> Option<String> {
        match self {
            Operation::CreateTable { name, .. } => Some(name.to_lowercase()),
            Operation::DropTable { name } => Some(format!("delete_{}", name.to_lowercase())),
            Operation::AddColumn { table, column, .. } => Some(format!(
                "{}_{}",
                table.to_lowercase(),
                column.name.to_lowercase()
            )),
            Operation::DropColumn { table, column } => Some(format!(
                "remove_{}_{}",
                table.to_lowercase(),
                column.to_lowercase()
            )),
            Operation::AlterColumn { table, column, .. } => Some(format!(
                "alter_{}_{}",
                table.to_lowercase(),
                column.to_lowercase()
            )),
            Operation::RenameTable { old_name, new_name } => Some(format!(
                "rename_{}_to_{}",
                old_name.to_lowercase(),
                new_name.to_lowercase()
            )),
            Operation::RenameColumn {
                table, new_name, ..
            } => Some(format!(
                "rename_{}_{}",
                table.to_lowercase(),
                new_name.to_lowercase()
            )),
            Operation::AddConstraint { table, .. } => {
                Some(format!("add_constraint_{}", table.to_lowercase()))
            }
            Operation::DropConstraint {
                constraint_name, ..
            } => Some(format!(
                "drop_constraint_{}",
                constraint_name.to_lowercase()
            )),
            Operation::CreateIndex { table, unique, .. } => {
                let prefix = if *unique {
                    "create_unique_index"
                } else {
                    "create_index"
                };
                Some(format!("{}_{}", prefix, table.to_lowercase()))
            }
            Operation::DropIndex { table, .. } => {
                Some(format!("drop_index_{}", table.to_lowercase()))
            }
            Operation::RunSQL { .. } | Operation::RunRust { .. } => None,
            Operation::AlterTableComment { table, .. } => {
                Some(format!("alter_comment_{}", table.to_lowercase()))
            }
            Operation::AlterUniqueTogether { table, .. } => {
                Some(format!("alter_unique_{}", table.to_lowercase()))
            }
            Operation::AlterModelOptions { table, .. } => {
                Some(format!("alter_options_{}", table.to_lowercase()))
            }
            Operation::CreateInheritedTable { name, .. } => {
                Some(format!("create_inherited_{}", name.to_lowercase()))
            }
            Operation::AddDiscriminatorColumn { table, .. } => {
                Some(format!("add_discriminator_{}", table.to_lowercase()))
            }
            // NOTE(review): `model_name` is interpolated twice
            // ("move_<from>_<model>_<to>_<model>"). Kept unchanged so generated
            // names stay stable — confirm whether this is intended.
            Operation::MoveModel {
                model_name,
                from_app,
                to_app,
                ..
            } => Some(format!(
                "move_{}_{}_{}_{}",
                from_app.to_lowercase(),
                model_name.to_lowercase(),
                to_app.to_lowercase(),
                model_name.to_lowercase()
            )),
            Operation::CreateSchema { name, .. } => {
                Some(format!("create_schema_{}", name.to_lowercase()))
            }
            Operation::DropSchema { name, .. } => {
                Some(format!("drop_schema_{}", name.to_lowercase()))
            }
            Operation::CreateExtension { name, .. } => {
                Some(format!("create_extension_{}", name.to_lowercase()))
            }
            Operation::BulkLoad { table, .. } => {
                Some(format!("bulk_load_{}", table.to_lowercase()))
            }
            Operation::SetAutoIncrementValue { table, column, .. } => Some(format!(
                "set_auto_increment_{}_{}",
                table.to_lowercase(),
                column.to_lowercase()
            )),
            Operation::CreateCompositePrimaryKey { table, .. } => {
                Some(format!("composite_pk_{}", table.to_lowercase()))
            }
        }
    }

    /// Returns a one-line, human-readable description of this operation.
    ///
    /// Raw SQL and Rust bodies are previewed: only their first 50 characters
    /// are shown, followed by `...` when the body is longer.
    fn describe(&self) -> String {
        /// Returns at most the first `PREVIEW_CHARS` characters of `text`,
        /// plus `...` when anything was cut off.
        ///
        /// The cut is made on a character boundary. Slicing by byte offset
        /// would panic when the offset falls inside a multi-byte UTF-8 character.
        fn preview(text: &str) -> String {
            const PREVIEW_CHARS: usize = 50;
            match text.char_indices().nth(PREVIEW_CHARS) {
                Some((cut, _)) => format!("{}...", &text[..cut]),
                None => text.to_string(),
            }
        }

        match self {
            Operation::CreateTable { name, .. } => format!("Create table {}", name),
            Operation::DropTable { name } => format!("Drop table {}", name),
            Operation::AddColumn { table, column, .. } => {
                format!("Add column {} to {}", column.name, table)
            }
            Operation::DropColumn { table, column } => {
                format!("Drop column {} from {}", column, table)
            }
            Operation::AlterColumn { table, column, .. } => {
                format!("Alter column {} on {}", column, table)
            }
            Operation::RenameTable { old_name, new_name } => {
                format!("Rename table {} to {}", old_name, new_name)
            }
            Operation::RenameColumn {
                table,
                old_name,
                new_name,
            } => format!("Rename column {} to {} on {}", old_name, new_name, table),
            Operation::AddConstraint { table, .. } => format!("Add constraint on {}", table),
            Operation::DropConstraint {
                table,
                constraint_name,
            } => format!("Drop constraint {} from {}", constraint_name, table),
            Operation::CreateIndex { table, unique, .. } => {
                if *unique {
                    format!("Create unique index on {}", table)
                } else {
                    format!("Create index on {}", table)
                }
            }
            Operation::DropIndex { table, .. } => format!("Drop index on {}", table),
            Operation::RunSQL { sql, .. } => format!("RunSQL: {}", preview(sql)),
            Operation::RunRust { code, .. } => format!("RunRust: {}", preview(code)),
            Operation::AlterTableComment { table, comment } => match comment {
                Some(c) => format!("Set comment on {} to '{}'", table, c),
                None => format!("Remove comment from {}", table),
            },
            Operation::AlterUniqueTogether { table, .. } => {
                format!("Alter unique_together on {}", table)
            }
            Operation::AlterModelOptions { table, .. } => {
                format!("Alter model options on {}", table)
            }
            Operation::CreateInheritedTable {
                name, base_table, ..
            } => {
                format!("Create inherited table {} from {}", name, base_table)
            }
            Operation::AddDiscriminatorColumn {
                table, column_name, ..
            } => format!("Add discriminator column {} to {}", column_name, table),
            Operation::MoveModel {
                model_name,
                from_app,
                to_app,
                ..
            } => format!("Move model {} from {} to {}", model_name, from_app, to_app),
            Operation::CreateSchema { name, .. } => format!("Create schema {}", name),
            Operation::DropSchema { name, .. } => format!("Drop schema {}", name),
            Operation::CreateExtension { name, .. } => format!("Create extension {}", name),
            Operation::BulkLoad { table, source, .. } => {
                let source_desc = match source {
                    BulkLoadSource::File(path) => format!("file '{}'", path),
                    BulkLoadSource::Stdin => "STDIN".to_string(),
                    BulkLoadSource::Program(cmd) => format!("program '{}'", cmd),
                };
                format!("Bulk load data into {} from {}", table, source_desc)
            }
            Operation::SetAutoIncrementValue {
                table,
                column,
                value,
            } => format!("Set auto-increment of {}.{} to {}", table, column, value),
            Operation::CreateCompositePrimaryKey { table, columns, .. } => format!(
                "Create composite primary key on {} ({})",
                table,
                columns.join(", ")
            ),
        }
    }

    /// Returns a canonicalised copy of this operation.
    ///
    /// The affected variants are:
    /// - `CreateTable`: columns sorted by name, constraints sorted.
    /// - `CreateIndex` / `DropIndex`: column lists sorted.
    /// - `AlterUniqueTogether`: each field list sorted, then the lists
    ///   themselves sorted.
    ///
    /// Every other variant is returned as a plain clone.
    fn normalize(&self) -> Self
    where
        Self: Sized + Clone,
    {
        match self {
            Operation::CreateTable {
                name,
                columns,
                constraints,
                without_rowid,
                interleave_in_parent,
                partition,
            } => {
                let mut sorted_columns = columns.clone();
                sorted_columns.sort_by(|a, b| a.name.cmp(&b.name));
                let mut sorted_constraints = constraints.clone();
                sorted_constraints.sort();
                Operation::CreateTable {
                    name: name.clone(),
                    columns: sorted_columns,
                    constraints: sorted_constraints,
                    without_rowid: *without_rowid,
                    interleave_in_parent: interleave_in_parent.clone(),
                    partition: partition.clone(),
                }
            }
            // NOTE(review): column order is significant for composite indexes;
            // sorting here makes (a, b) and (b, a) compare equal — confirm this
            // normalized form is only used for comparison, not SQL generation.
            Operation::CreateIndex {
                table,
                columns,
                unique,
                index_type,
                where_clause,
                concurrently,
                expressions,
                mysql_options,
                operator_class,
            } => {
                let mut sorted_columns = columns.clone();
                sorted_columns.sort();
                Operation::CreateIndex {
                    table: table.clone(),
                    columns: sorted_columns,
                    unique: *unique,
                    index_type: *index_type,
                    where_clause: where_clause.clone(),
                    concurrently: *concurrently,
                    expressions: expressions.clone(),
                    mysql_options: *mysql_options,
                    operator_class: operator_class.clone(),
                }
            }
            Operation::DropIndex { table, columns } => {
                let mut sorted_columns = columns.clone();
                sorted_columns.sort();
                Operation::DropIndex {
                    table: table.clone(),
                    columns: sorted_columns,
                }
            }
            Operation::AlterUniqueTogether {
                table,
                unique_together,
            } => {
                let mut sorted_unique_together: Vec<Vec<String>> = unique_together
                    .iter()
                    .map(|field_list| {
                        let mut sorted = field_list.clone();
                        sorted.sort();
                        sorted
                    })
                    .collect();
                sorted_unique_together.sort();
                Operation::AlterUniqueTogether {
                    table: table.clone(),
                    unique_together: sorted_unique_together,
                }
            }
            // All remaining variants (including AlterModelOptions) are already
            // in canonical form.
            _ => self.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use FieldType;
use rstest::rstest;
#[test]
fn test_create_table_to_statement() {
let op = Operation::CreateTable {
name: "users".to_string(),
columns: vec![
ColumnDefinition {
name: "id".to_string(),
type_definition: FieldType::Integer,
not_null: false,
unique: false,
primary_key: true,
auto_increment: true,
default: None,
},
ColumnDefinition {
name: "name".to_string(),
type_definition: FieldType::VarChar(100),
not_null: true,
unique: false,
primary_key: false,
auto_increment: false,
default: None,
},
],
constraints: vec![],
without_rowid: None,
partition: None,
interleave_in_parent: None,
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("CREATE TABLE"),
"SQL should contain CREATE TABLE keyword, got: {}",
sql
);
assert!(
sql.contains("users"),
"SQL should reference 'users' table, got: {}",
sql
);
assert!(
sql.contains("id") && sql.contains("name"),
"SQL should contain both 'id' and 'name' columns, got: {}",
sql
);
}
#[test]
fn test_drop_table_to_statement() {
let op = Operation::DropTable {
name: "users".to_string(),
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("DROP TABLE"),
"SQL should contain DROP TABLE keyword, got: {}",
sql
);
assert!(
sql.contains("users"),
"SQL should reference 'users' table, got: {}",
sql
);
assert!(
sql.contains("CASCADE"),
"SQL should include CASCADE option, got: {}",
sql
);
}
#[test]
fn test_add_column_to_statement() {
let op = Operation::AddColumn {
table: "users".to_string(),
column: ColumnDefinition {
name: "email".to_string(),
type_definition: FieldType::VarChar(255),
not_null: true,
unique: false,
primary_key: false,
auto_increment: false,
default: Some("''".to_string()),
},
mysql_options: None,
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("ALTER TABLE"),
"SQL should contain ALTER TABLE keyword, got: {}",
sql
);
assert!(
sql.contains("users"),
"SQL should reference 'users' table, got: {}",
sql
);
assert!(
sql.contains("ADD COLUMN"),
"SQL should contain ADD COLUMN clause, got: {}",
sql
);
assert!(
sql.contains("email"),
"SQL should reference 'email' column, got: {}",
sql
);
}
#[test]
fn test_drop_column_to_statement() {
let op = Operation::DropColumn {
table: "users".to_string(),
column: "email".to_string(),
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("ALTER TABLE"),
"SQL should contain ALTER TABLE keyword, got: {}",
sql
);
assert!(
sql.contains("users"),
"SQL should reference 'users' table, got: {}",
sql
);
assert!(
sql.contains("DROP COLUMN"),
"SQL should contain DROP COLUMN clause, got: {}",
sql
);
assert!(
sql.contains("email"),
"SQL should reference 'email' column, got: {}",
sql
);
}
#[test]
fn test_alter_column_to_statement() {
let op = Operation::AlterColumn {
table: "users".to_string(),
column: "age".to_string(),
old_definition: None,
new_definition: ColumnDefinition {
name: "age".to_string(),
type_definition: FieldType::BigInteger,
not_null: true,
unique: false,
primary_key: false,
auto_increment: false,
default: None,
},
mysql_options: None,
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("ALTER TABLE"),
"SQL should contain ALTER TABLE keyword, got: {}",
sql
);
assert!(
sql.contains("users"),
"SQL should reference 'users' table, got: {}",
sql
);
assert!(
sql.contains("age"),
"SQL should reference 'age' column, got: {}",
sql
);
}
#[test]
fn test_rename_table_to_statement() {
let op = Operation::RenameTable {
old_name: "users".to_string(),
new_name: "accounts".to_string(),
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("users"),
"SQL should reference old table name 'users', got: {}",
sql
);
assert!(
sql.contains("accounts"),
"SQL should reference new table name 'accounts', got: {}",
sql
);
}
#[test]
fn test_rename_column_to_statement() {
let op = Operation::RenameColumn {
table: "users".to_string(),
old_name: "name".to_string(),
new_name: "full_name".to_string(),
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("ALTER TABLE"),
"SQL should contain ALTER TABLE keyword, got: {}",
sql
);
assert!(
sql.contains("users"),
"SQL should reference 'users' table, got: {}",
sql
);
assert!(
sql.contains("RENAME COLUMN"),
"SQL should contain RENAME COLUMN clause, got: {}",
sql
);
assert!(
sql.contains("name"),
"SQL should reference old column name 'name', got: {}",
sql
);
assert!(
sql.contains("full_name"),
"SQL should reference new column name 'full_name', got: {}",
sql
);
}
#[test]
fn test_add_constraint_to_statement() {
let op = Operation::AddConstraint {
table: "users".to_string(),
constraint_sql: "CONSTRAINT age_check CHECK (age >= 0)".to_string(),
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("ALTER TABLE"),
"SQL should contain ALTER TABLE keyword, got: {}",
sql
);
assert!(
sql.contains("users"),
"SQL should reference 'users' table, got: {}",
sql
);
assert!(
sql.contains("ADD"),
"SQL should contain ADD keyword, got: {}",
sql
);
assert!(
sql.contains("age_check"),
"SQL should contain constraint name 'age_check', got: {}",
sql
);
}
#[test]
fn test_drop_constraint_to_statement() {
let op = Operation::DropConstraint {
table: "users".to_string(),
constraint_name: "age_check".to_string(),
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("ALTER TABLE"),
"SQL should contain ALTER TABLE keyword, got: {}",
sql
);
assert!(
sql.contains("users"),
"SQL should reference 'users' table, got: {}",
sql
);
assert!(
sql.contains("DROP CONSTRAINT"),
"SQL should contain DROP CONSTRAINT clause, got: {}",
sql
);
assert!(
sql.contains("age_check"),
"SQL should reference constraint 'age_check', got: {}",
sql
);
}
#[test]
fn test_create_index_to_statement() {
let op = Operation::CreateIndex {
table: "users".to_string(),
columns: vec!["email".to_string()],
unique: false,
index_type: None,
where_clause: None,
concurrently: false,
expressions: None,
mysql_options: None,
operator_class: None,
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("CREATE INDEX"),
"SQL should contain CREATE INDEX keywords, got: {}",
sql
);
assert!(
sql.contains("users"),
"SQL should reference 'users' table, got: {}",
sql
);
assert!(
sql.contains("email"),
"SQL should reference 'email' column, got: {}",
sql
);
}
#[test]
fn test_create_unique_index_to_statement() {
let op = Operation::CreateIndex {
table: "users".to_string(),
columns: vec!["email".to_string()],
unique: true,
index_type: None,
where_clause: None,
concurrently: false,
expressions: None,
mysql_options: None,
operator_class: None,
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("CREATE UNIQUE INDEX"),
"SQL should contain CREATE UNIQUE INDEX keywords, got: {}",
sql
);
assert!(
sql.contains("users"),
"SQL should reference 'users' table, got: {}",
sql
);
assert!(
sql.contains("email"),
"SQL should reference 'email' column, got: {}",
sql
);
}
#[test]
fn test_drop_index_to_statement() {
let op = Operation::DropIndex {
table: "users".to_string(),
columns: vec!["email".to_string()],
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("DROP INDEX"),
"SQL should contain DROP INDEX keywords, got: {}",
sql
);
assert!(
sql.contains("idx_users_email"),
"SQL should contain generated index name 'idx_users_email', got: {}",
sql
);
}
#[test]
fn test_run_sql_to_statement() {
let op = Operation::RunSQL {
sql: "CREATE EXTENSION IF NOT EXISTS \"uuid-ossp\"".to_string(),
reverse_sql: Some("DROP EXTENSION \"uuid-ossp\"".to_string()),
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("CREATE EXTENSION"),
"SQL should contain CREATE EXTENSION keywords, got: {}",
sql
);
assert!(
sql.contains("uuid-ossp"),
"SQL should reference 'uuid-ossp' extension, got: {}",
sql
);
}
#[test]
fn test_alter_table_comment_to_statement() {
let op = Operation::AlterTableComment {
table: "users".to_string(),
comment: Some("User accounts table".to_string()),
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("COMMENT ON TABLE"),
"SQL should contain COMMENT ON TABLE keywords, got: {}",
sql
);
assert!(
sql.contains("users"),
"SQL should reference 'users' table, got: {}",
sql
);
assert!(
sql.contains("User accounts table"),
"SQL should include comment text 'User accounts table', got: {}",
sql
);
}
#[test]
fn test_alter_table_comment_null_to_statement() {
let op = Operation::AlterTableComment {
table: "users".to_string(),
comment: None,
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("COMMENT ON TABLE"),
"SQL should contain COMMENT ON TABLE keywords, got: {}",
sql
);
assert!(
sql.contains("users"),
"SQL should reference 'users' table, got: {}",
sql
);
assert!(
sql.contains("NULL"),
"SQL should include NULL for null comment, got: {}",
sql
);
}
#[test]
fn test_alter_unique_together_to_statement() {
let op = Operation::AlterUniqueTogether {
table: "users".to_string(),
unique_together: vec![vec!["email".to_string(), "username".to_string()]],
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("ALTER TABLE"),
"SQL should contain ALTER TABLE keyword, got: {}",
sql
);
assert!(
sql.contains("users"),
"SQL should reference 'users' table, got: {}",
sql
);
assert!(
sql.contains("ADD CONSTRAINT"),
"SQL should contain ADD CONSTRAINT clause, got: {}",
sql
);
assert!(
sql.contains("UNIQUE"),
"SQL should contain UNIQUE keyword, got: {}",
sql
);
assert!(
sql.contains("email") && sql.contains("username"),
"SQL should reference both 'email' and 'username' columns, got: {}",
sql
);
}
#[test]
fn test_alter_unique_together_empty() {
let op = Operation::AlterUniqueTogether {
table: "users".to_string(),
unique_together: vec![],
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert_eq!(
sql, "",
"SQL should be empty for empty unique_together constraint"
);
}
#[test]
fn test_alter_model_options_to_statement() {
let mut options = std::collections::HashMap::new();
options.insert("db_table".to_string(), "custom_users".to_string());
let op = Operation::AlterModelOptions {
table: "users".to_string(),
options,
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert_eq!(sql, "", "SQL should be empty for model options operation");
}
#[test]
fn test_create_inherited_table_to_statement() {
let op = Operation::CreateInheritedTable {
name: "admin_users".to_string(),
columns: vec![ColumnDefinition {
name: "admin_level".to_string(),
type_definition: FieldType::Integer,
not_null: true,
unique: false,
primary_key: false,
auto_increment: false,
default: Some("1".to_string()),
}],
base_table: "users".to_string(),
join_column: "user_id".to_string(),
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("CREATE TABLE"),
"SQL should contain CREATE TABLE keywords, got: {}",
sql
);
assert!(
sql.contains("admin_users"),
"SQL should reference 'admin_users' table, got: {}",
sql
);
assert!(
sql.contains("user_id"),
"SQL should include join column 'user_id', got: {}",
sql
);
}
#[test]
fn test_add_discriminator_column_to_statement() {
let op = Operation::AddDiscriminatorColumn {
table: "users".to_string(),
column_name: "user_type".to_string(),
default_value: "regular".to_string(),
};
let stmt = op.to_statement();
let sql = stmt.to_sql_string(crate::backends::types::DatabaseType::Postgres);
assert!(
sql.contains("ALTER TABLE"),
"SQL should contain ALTER TABLE keyword, got: {}",
sql
);
assert!(
sql.contains("users"),
"SQL should reference 'users' table, got: {}",
sql
);
assert!(
sql.contains("ADD COLUMN"),
"SQL should contain ADD COLUMN clause, got: {}",
sql
);
assert!(
sql.contains("user_type"),
"SQL should reference 'user_type' column, got: {}",
sql
);
}
/// Forward-applying `CreateTable` must register a model whose fields mirror the
/// declared columns one-to-one.
#[test]
fn test_state_forwards_create_table() {
    let mut state = ProjectState::new();
    let op = Operation::CreateTable {
        name: "users".to_string(),
        columns: vec![
            ColumnDefinition {
                primary_key: true,
                auto_increment: true,
                ..ColumnDefinition::new("id", FieldType::Integer)
            },
            ColumnDefinition {
                not_null: true,
                ..ColumnDefinition::new("name", FieldType::VarChar(100))
            },
        ],
        constraints: vec![],
        without_rowid: None,
        partition: None,
        interleave_in_parent: None,
    };
    op.state_forwards("myapp", &mut state);

    // `expect` replaces the `assert!(is_some())` + `unwrap()` pair: a single
    // lookup that panics with the same message when the model is missing.
    let model = state
        .get_model("myapp", "users")
        .expect("Model 'users' should exist in state");
    assert_eq!(
        model.fields.len(),
        2,
        "Model should have exactly 2 fields, got: {}",
        model.fields.len()
    );
    for field in ["id", "name"] {
        assert!(
            model.fields.contains_key(field),
            "Model should contain '{}' field",
            field
        );
    }
}
/// Forward-applying `DropTable` removes the model from project state.
#[test]
fn test_state_forwards_drop_table() {
    // Arrange: a project containing a single `myapp.users` model.
    let mut users = ModelState::new("myapp", "users");
    users.add_field(FieldState::new("id".to_string(), FieldType::Integer, false));
    let mut state = ProjectState::new();
    state.add_model(users);

    // Act.
    let drop_op = Operation::DropTable {
        name: "users".to_string(),
    };
    drop_op.state_forwards("myapp", &mut state);

    // Assert.
    let removed = state.get_model("myapp", "users").is_none();
    assert!(
        removed,
        "Model 'users' should be removed from state after drop"
    );
}
/// Forward-applying `AddColumn` appends the new field to the existing model.
#[test]
fn test_state_forwards_add_column() {
    let mut users = ModelState::new("myapp", "users");
    users.add_field(FieldState::new("id".to_string(), FieldType::Integer, false));
    let mut state = ProjectState::new();
    state.add_model(users);

    // A required VARCHAR(255) column; every other flag keeps its constructor default.
    let email = ColumnDefinition {
        not_null: true,
        ..ColumnDefinition::new("email", FieldType::VarChar(255))
    };
    let add_op = Operation::AddColumn {
        table: "users".to_string(),
        column: email,
        mysql_options: None,
    };
    add_op.state_forwards("myapp", &mut state);

    let model = state.get_model("myapp", "users").unwrap();
    let field_count = model.fields.len();
    assert_eq!(
        field_count, 2,
        "Model should have 2 fields after adding 'email', got: {}",
        field_count
    );
    assert!(
        model.fields.contains_key("email"),
        "Model should contain newly added 'email' field"
    );
}
/// Forward-applying `DropColumn` removes exactly the named field.
#[test]
fn test_state_forwards_drop_column() {
    // Start from a two-field model and drop the non-key column.
    let mut users = ModelState::new("myapp", "users");
    for (field_name, field_type) in [
        ("id", FieldType::Integer),
        ("email", FieldType::VarChar(255)),
    ] {
        users.add_field(FieldState::new(field_name.to_string(), field_type, false));
    }
    let mut state = ProjectState::new();
    state.add_model(users);

    let drop_op = Operation::DropColumn {
        table: "users".to_string(),
        column: "email".to_string(),
    };
    drop_op.state_forwards("myapp", &mut state);

    let model = state.get_model("myapp", "users").unwrap();
    let remaining = model.fields.len();
    assert_eq!(
        remaining, 1,
        "Model should have 1 field after dropping 'email', got: {}",
        remaining
    );
    assert!(
        !model.fields.contains_key("email"),
        "Model should not contain dropped 'email' field"
    );
}
/// Forward-applying `RenameTable` moves the model from the old key to the new one.
#[test]
fn test_state_forwards_rename_table() {
    let mut users = ModelState::new("myapp", "users");
    users.add_field(FieldState::new("id".to_string(), FieldType::Integer, false));
    let mut state = ProjectState::new();
    state.add_model(users);

    let rename_op = Operation::RenameTable {
        old_name: "users".to_string(),
        new_name: "accounts".to_string(),
    };
    rename_op.state_forwards("myapp", &mut state);

    let exists = |model_name: &str| state.get_model("myapp", model_name).is_some();
    assert!(
        !exists("users"),
        "Old model name 'users' should not exist after rename"
    );
    assert!(
        exists("accounts"),
        "New model name 'accounts' should exist after rename"
    );
}
/// Forward-applying `RenameColumn` re-keys the field under its new name.
#[test]
fn test_state_forwards_rename_column() {
    let mut users = ModelState::new("myapp", "users");
    users.add_field(FieldState::new("name".to_string(), FieldType::VarChar(255), false));
    let mut state = ProjectState::new();
    state.add_model(users);

    let rename_op = Operation::RenameColumn {
        table: "users".to_string(),
        old_name: "name".to_string(),
        new_name: "full_name".to_string(),
    };
    rename_op.state_forwards("myapp", &mut state);

    let model = state.get_model("myapp", "users").unwrap();
    let still_has_old = model.fields.contains_key("name");
    let has_new = model.fields.contains_key("full_name");
    assert!(
        !still_has_old,
        "Old field name 'name' should not exist after rename"
    );
    assert!(
        has_new,
        "New field name 'full_name' should exist after rename"
    );
}
/// The reverse of `CreateTable` is a `DROP TABLE` on the same table.
#[test]
fn test_to_reverse_sql_create_table() {
    let op = Operation::CreateTable {
        name: "users".to_string(),
        columns: vec![],
        constraints: vec![],
        without_rowid: None,
        partition: None,
        interleave_in_parent: None,
    };
    let state = ProjectState::default();
    // Separate `expect`s distinguish "generation failed" (and show the error)
    // from "no reverse produced"; the old `is_ok() && ...unwrap()` hid the error.
    let sql = op
        .to_reverse_sql(&SqlDialect::Postgres, &state)
        .expect("CreateTable reverse SQL generation should succeed")
        .expect("CreateTable should have reverse SQL operation");
    assert!(
        sql.contains("DROP TABLE"),
        "Reverse SQL should contain DROP TABLE, got: {}",
        sql
    );
    assert!(
        sql.contains("users"),
        "Reverse SQL should reference 'users' table, got: {}",
        sql
    );
}
/// `DropTable` has no reverse SQL: the dropped structure cannot be rebuilt from
/// the operation alone.
#[test]
fn test_to_reverse_sql_drop_table() {
    let op = Operation::DropTable {
        name: "users".to_string(),
    };
    let state = ProjectState::default();
    let reverse = op.to_reverse_sql(&SqlDialect::Postgres, &state);
    // A single pattern covers both the success and the `None` requirement, and
    // the failure message shows what was actually returned.
    assert!(
        matches!(reverse, Ok(None)),
        "DropTable should not have reverse SQL (cannot recreate table structure), got: {:?}",
        reverse
    );
}
/// The reverse of `AddColumn` drops the same column again.
#[test]
fn test_to_reverse_sql_add_column() {
    let op = Operation::AddColumn {
        table: "users".to_string(),
        column: ColumnDefinition::new("email", FieldType::VarChar(255)),
        mysql_options: None,
    };
    let state = ProjectState::default();
    // Surface the generation error instead of collapsing it into a boolean.
    let sql = op
        .to_reverse_sql(&SqlDialect::Postgres, &state)
        .expect("AddColumn reverse SQL generation should succeed")
        .expect("AddColumn should have reverse SQL operation");
    assert!(
        sql.contains("DROP COLUMN"),
        "Reverse SQL should contain DROP COLUMN, got: {}",
        sql
    );
    assert!(
        sql.contains("email"),
        "Reverse SQL should reference 'email' column, got: {}",
        sql
    );
}
fn alter_column_with_old_def() -> Operation {
Operation::AlterColumn {
table: "products".to_string(),
column: "name".to_string(),
old_definition: Some(ColumnDefinition {
name: "name".to_string(),
type_definition: FieldType::VarChar(50),
not_null: false,
unique: false,
primary_key: false,
auto_increment: false,
default: None,
}),
new_definition: ColumnDefinition {
name: "name".to_string(),
type_definition: FieldType::Text,
not_null: false,
unique: false,
primary_key: false,
auto_increment: false,
default: None,
},
mysql_options: None,
}
}
/// Postgres reverses a type change with `ALTER COLUMN ... TYPE <old type>`.
#[test]
fn test_to_reverse_sql_alter_column_postgres() {
    let op = alter_column_with_old_def();
    let state = ProjectState::default();
    let maybe_sql = op
        .to_reverse_sql(&SqlDialect::Postgres, &state)
        .expect("reverse SQL should succeed");
    let sql = maybe_sql.expect("reverse SQL should be present");

    let uses_alter_type = sql.contains("ALTER COLUMN") && sql.contains("TYPE");
    assert!(
        uses_alter_type,
        "Postgres reverse SQL should use ALTER COLUMN ... TYPE syntax, got: {}",
        sql
    );
    let restores_old_type = sql.contains("VARCHAR(50)");
    assert!(
        restores_old_type,
        "Postgres reverse SQL should restore VARCHAR(50), got: {}",
        sql
    );
}
/// MySQL reverses a type change with `MODIFY COLUMN` and must not leak any
/// Postgres-only syntax.
#[test]
fn test_to_reverse_sql_alter_column_mysql() {
    let op = alter_column_with_old_def();
    let state = ProjectState::default();
    let sql = op
        .to_reverse_sql(&SqlDialect::Mysql, &state)
        .expect("reverse SQL should succeed")
        .expect("reverse SQL should be present");
    assert!(
        sql.contains("MODIFY COLUMN"),
        "MySQL reverse SQL should use MODIFY COLUMN syntax, got: {}",
        sql
    );
    // Postgres-only fragments that MySQL would reject.
    let postgres_only = [
        (
            "ALTER COLUMN",
            "MySQL reverse SQL must not emit Postgres ALTER COLUMN syntax",
        ),
        (
            " TYPE ",
            "MySQL reverse SQL must not contain Postgres ' TYPE ' token",
        ),
    ];
    for (token, reason) in postgres_only {
        assert!(!sql.contains(token), "{}, got: {}", reason, sql);
    }
    assert!(
        sql.contains("VARCHAR(50)"),
        "MySQL reverse SQL should restore VARCHAR(50), got: {}",
        sql
    );
}
/// CockroachDB's reverse is pinned to one statement with no comma-combined
/// actions and, for now, no nullability clause.
#[test]
fn test_to_reverse_sql_alter_column_cockroachdb() {
    const EXPECTED: &str = "ALTER TABLE products ALTER COLUMN name TYPE VARCHAR(50);";
    let op = alter_column_with_old_def();
    let state = ProjectState::default();
    let sql = op
        .to_reverse_sql(&SqlDialect::Cockroachdb, &state)
        .expect("reverse SQL should succeed")
        .expect("reverse SQL should be present");

    assert_eq!(
        sql, EXPECTED,
        "CockroachDB reverse SQL must match the pinned single-statement \
         form exactly (no extra clauses), got: {}",
        sql
    );

    // Any `;` left after stripping the trailing terminator means several statements.
    let has_internal_terminator = sql.trim().trim_end_matches(';').trim().contains(';');
    assert!(
        !has_internal_terminator,
        "CockroachDB reverse SQL must be exactly one statement (no internal \
         `;`), got: {}",
        sql
    );

    let comma_combined = sql.contains(",\n") || sql.contains(", ALTER COLUMN");
    assert!(
        !comma_combined,
        "CockroachDB reverse SQL must not emit the Postgres comma-combined \
         form (CockroachDB rejects it), got: {}",
        sql
    );

    let touches_nullability = sql.contains("SET NOT NULL") || sql.contains("DROP NOT NULL");
    assert!(
        !touches_nullability,
        "CockroachDB reverse SQL must not emit nullability clause under \
         the current stop-gap (tracked in #4640), got: {}",
        sql
    );
}
/// SQLite cannot alter a column in place, so the reverse SQL must be a comment
/// rather than an executable `ALTER TABLE`.
#[test]
fn test_to_reverse_sql_alter_column_sqlite() {
    let op = alter_column_with_old_def();
    let state = ProjectState::default();
    let sql = op
        .to_reverse_sql(&SqlDialect::Sqlite, &state)
        .expect("reverse SQL should succeed")
        .expect("reverse SQL should be present");

    let is_comment = sql.trim_start().starts_with("--");
    assert!(
        is_comment,
        "SQLite reverse SQL should be a SQL comment (recreation handled by executor), got: {}",
        sql
    );

    let upper_body = sql.trim_start_matches("--").trim_start().to_uppercase();
    assert!(
        !upper_body.contains("ALTER TABLE"),
        "SQLite reverse SQL body must not emit executable ALTER TABLE statement, got: {}",
        sql
    );
}
/// The reverse operation of a type change targets the column's previous type.
#[test]
fn test_to_reverse_operation_alter_column_uses_old_definition() {
    let forward = alter_column_with_old_def();
    let state = ProjectState::default();
    let reverse = forward
        .to_reverse_operation(&state)
        .expect("reverse operation should succeed")
        .expect("reverse operation should be present (old_definition is supplied)");

    let restored = match reverse {
        Operation::AlterColumn { new_definition, .. } => new_definition,
        other => panic!("reverse operation should be AlterColumn, got: {:?}", other),
    };
    assert!(
        matches!(restored.type_definition, FieldType::VarChar(50)),
        "reverse AlterColumn should restore VARCHAR(50), got: {:?}",
        restored.type_definition
    );
}
/// `RunSQL` with an explicit `reverse_sql` returns that SQL as its reverse.
#[test]
fn test_to_reverse_sql_run_sql_with_reverse() {
    let op = Operation::RunSQL {
        sql: "CREATE INDEX idx_name ON users(name)".to_string(),
        reverse_sql: Some("DROP INDEX idx_name".to_string()),
    };
    let state = ProjectState::default();
    // Surface the generation error instead of collapsing it into a boolean.
    let sql = op
        .to_reverse_sql(&SqlDialect::Postgres, &state)
        .expect("RunSQL reverse SQL generation should succeed")
        .expect("RunSQL with reverse_sql should have reverse SQL");
    assert!(
        sql.contains("DROP INDEX"),
        "Reverse SQL should contain provided reverse_sql, got: {}",
        sql
    );
}
/// `RunSQL` without `reverse_sql` is irreversible: the result is `Ok(None)`.
#[test]
fn test_to_reverse_sql_run_sql_without_reverse() {
    let op = Operation::RunSQL {
        sql: "CREATE INDEX idx_name ON users(name)".to_string(),
        reverse_sql: None,
    };
    let state = ProjectState::default();
    let reverse = op.to_reverse_sql(&SqlDialect::Postgres, &state);
    assert!(
        matches!(reverse, Ok(None)),
        "RunSQL without reverse_sql should not have reverse SQL, got: {:?}",
        reverse
    );
}
/// `ColumnDefinition::new` sets only name and type; every flag is off and there
/// is no default.
#[test]
fn test_column_definition_new() {
    // Exhaustive destructuring: adding a field to ColumnDefinition forces this
    // test to be revisited.
    let ColumnDefinition {
        name,
        type_definition,
        not_null,
        unique,
        primary_key,
        auto_increment,
        default,
    } = ColumnDefinition::new("id", FieldType::Integer);
    assert_eq!(name, "id", "Column name should be 'id'");
    assert_eq!(
        type_definition,
        FieldType::Integer,
        "Column type should be Integer"
    );
    assert!(!not_null, "not_null should default to false");
    assert!(!unique, "unique should default to false");
    assert!(!primary_key, "primary_key should default to false");
    assert!(!auto_increment, "auto_increment should default to false");
    assert!(default.is_none(), "default should be None");
}
/// A non-Option `bool` with `default = true` becomes a NOT NULL column carrying
/// that default.
#[rstest]
fn from_field_state_non_optional_bool_with_true_default() {
    let mut source = FieldState::new("is_active", FieldType::Boolean, false);
    source
        .params
        .insert("default".to_string(), "true".to_string());

    let col = ColumnDefinition::from_field_state("is_active", &source);

    assert_eq!(col.name, "is_active", "Column name should round-trip");
    assert_eq!(
        col.type_definition,
        FieldType::Boolean,
        "Boolean field type should round-trip"
    );
    assert!(
        col.not_null,
        "Non-Optional bool must emit NOT NULL (regression #4573)"
    );
    let expected_default = Some("true".to_string());
    assert_eq!(
        col.default, expected_default,
        "`#[field(default = true)]` must propagate as Some(\"true\")"
    );
    assert!(!col.primary_key, "Non-PK field must not be primary_key");
}
/// A `false` default is a real default: it must propagate, not be treated as "unset".
#[rstest]
fn from_field_state_non_optional_bool_with_false_default() {
    let mut source = FieldState::new("is_superuser", FieldType::Boolean, false);
    source
        .params
        .insert("default".to_string(), "false".to_string());

    let col = ColumnDefinition::from_field_state("is_superuser", &source);

    assert!(
        col.not_null,
        "Non-Optional bool with default=false must emit NOT NULL"
    );
    let expected_default = Some("false".to_string());
    assert_eq!(
        col.default, expected_default,
        "default=false must propagate as Some(\"false\")"
    );
}
/// An `Option<bool>` with a default stays nullable but still gets the default.
#[rstest]
fn from_field_state_optional_bool_with_default() {
    let mut source = FieldState::new("maybe_flag", FieldType::Boolean, true);
    source
        .params
        .insert("default".to_string(), "true".to_string());

    let col = ColumnDefinition::from_field_state("maybe_flag", &source);

    assert!(
        !col.not_null,
        "Optional bool must remain NULLABLE — no regression on Option<T>"
    );
    let expected_default = Some("true".to_string());
    assert_eq!(
        col.default, expected_default,
        "Default propagation must work for Optional fields too"
    );
}
/// NOT NULL for non-Option fields applies to every type, not only `bool`.
#[rstest]
fn from_field_state_non_optional_non_bool() {
    let col = ColumnDefinition::from_field_state(
        "username",
        &FieldState::new("username", FieldType::VarChar(150), false),
    );
    assert!(
        col.not_null,
        "Non-Optional String must emit NOT NULL (regression #4573 — bug \
         affected all field types, not just bool)"
    );
    assert!(
        col.default.is_none(),
        "No default annotation → default = None"
    );
}
/// `primary_key` wins over `nullable`: a PK column is always NOT NULL.
#[rstest]
fn from_field_state_primary_key_is_always_not_null() {
    // Deliberately declared nullable, so the primary-key rule has something to override.
    let mut pk_field = FieldState::new("id", FieldType::Uuid, true);
    pk_field
        .params
        .insert("primary_key".to_string(), "true".to_string());

    let col = ColumnDefinition::from_field_state("id", &pk_field);

    assert!(
        col.primary_key,
        "primary_key param must propagate to ColumnDefinition"
    );
    assert!(
        col.not_null,
        "Primary key must be NOT NULL regardless of nullable flag"
    );
}
/// A plain `Option<T>` field with no params maps to a nullable, default-less, non-PK column.
#[rstest]
fn from_field_state_optional_field_remains_nullable() {
    let source = FieldState::new("last_login", FieldType::TimestampTz, true);
    let col = ColumnDefinition::from_field_state("last_login", &source);

    assert!(
        !col.not_null,
        "Optional field with no default must remain NULLABLE"
    );
    assert!(col.default.is_none(), "No default → default = None");
    assert!(!col.primary_key, "Non-PK field must not be primary_key");
}
/// The literal `null` converts to an absent string value.
#[test]
fn test_convert_default_value_null() {
    // Any operation can host `convert_default_value`; an empty CreateTable is
    // the cheapest receiver to build.
    let receiver = Operation::CreateTable {
        name: String::from("test"),
        columns: Vec::new(),
        constraints: Vec::new(),
        without_rowid: None,
        partition: None,
        interleave_in_parent: None,
    };
    let converted = receiver.convert_default_value("null");
    let is_null_string = matches!(converted, Value::String(None));
    assert!(
        is_null_string,
        "NULL value should be converted to Value::String(None)"
    );
}
/// `true` and `false` convert to the corresponding `Value::Bool`.
#[test]
fn test_convert_default_value_bool() {
    let receiver = Operation::CreateTable {
        name: String::from("test"),
        columns: Vec::new(),
        constraints: Vec::new(),
        without_rowid: None,
        partition: None,
        interleave_in_parent: None,
    };
    for (literal, expected) in [("true", true), ("false", false)] {
        let value = receiver.convert_default_value(literal);
        assert!(
            matches!(value, Value::Bool(Some(flag)) if flag == expected),
            "'{}' should be converted to Value::Bool(Some({}))",
            literal,
            expected
        );
    }
}
/// An integer literal converts to `Value::BigInt`.
#[test]
fn test_convert_default_value_integer() {
    let receiver = Operation::CreateTable {
        name: String::from("test"),
        columns: Vec::new(),
        constraints: Vec::new(),
        without_rowid: None,
        partition: None,
        interleave_in_parent: None,
    };
    let converted = receiver.convert_default_value("42");
    let is_expected_bigint = matches!(converted, Value::BigInt(Some(42)));
    assert!(
        is_expected_bigint,
        "Integer '42' should be converted to Value::BigInt(Some(42))"
    );
}
/// A decimal literal converts to `Value::Double` holding the parsed number.
#[test]
fn test_convert_default_value_float() {
    let op = Operation::CreateTable {
        name: "test".to_string(),
        columns: vec![],
        constraints: vec![],
        without_rowid: None,
        partition: None,
        interleave_in_parent: None,
    };
    let value = op.convert_default_value("3.15");
    // Checking only the variant (`Value::Double(_)`) would also accept
    // `Double(None)` or a mis-parsed number, so pin the payload as well.
    assert!(
        matches!(value, Value::Double(Some(v)) if (v - 3.15).abs() < 1e-9),
        "Float '3.15' should be converted to Value::Double(Some(3.15)), got {value:?}"
    );
}
/// An already single-quoted literal is unquoted before being stored.
#[test]
fn test_convert_default_value_string() {
    let op = Operation::CreateTable {
        name: "test".to_string(),
        columns: vec![],
        constraints: vec![],
        without_rowid: None,
        partition: None,
        interleave_in_parent: None,
    };
    match op.convert_default_value("'hello'") {
        Value::String(Some(s)) => assert_eq!(
            *s, "hello",
            "Quoted string should be unquoted and stored as 'hello'"
        ),
        // Report the variant actually produced, like the sibling tests do.
        other => panic!("Expected Value::String(Some(\"hello\")), got {other:?}"),
    }
}
/// Unquoted text is wrapped as an SQL string literal, with embedded quotes doubled.
#[rstest]
#[case("pending", "'pending'")]
#[case("active", "'active'")]
#[case("hello world", "'hello world'")]
#[case("it's", "'it''s'")]
fn test_convert_default_value_plain_string(#[case] input: &str, #[case] expected: &str) {
    let receiver = Operation::CreateTable {
        name: String::from("test"),
        columns: Vec::new(),
        constraints: Vec::new(),
        without_rowid: None,
        partition: None,
        interleave_in_parent: None,
    };
    let quoted = match receiver.convert_default_value(input) {
        Value::String(Some(s)) => s,
        other => panic!("Expected Value::String(Some(\"{expected}\")), got {other:?}"),
    };
    assert_eq!(
        *quoted, expected,
        "Plain string '{input}' should be auto-quoted as SQL string literal"
    );
}
/// SQL keyword constants pass through verbatim (no quoting), whatever their case.
#[rstest]
#[case("CURRENT_TIMESTAMP")]
#[case("current_timestamp")]
#[case("CURRENT_DATE")]
#[case("CURRENT_TIME")]
#[case("CURRENT_USER")]
#[case("SESSION_USER")]
#[case("LOCALTIME")]
#[case("LOCALTIMESTAMP")]
fn test_convert_default_value_sql_constant(#[case] input: &str) {
    let receiver = Operation::CreateTable {
        name: String::from("test"),
        columns: Vec::new(),
        constraints: Vec::new(),
        without_rowid: None,
        partition: None,
        interleave_in_parent: None,
    };
    let raw = match receiver.convert_default_value(input) {
        Value::String(Some(s)) => s,
        other => panic!("Expected Value::String(Some(\"{input}\")), got {other:?}"),
    };
    assert_eq!(*raw, input, "SQL constant '{input}' should remain unquoted");
}
/// SQL function calls such as `NOW()` pass through verbatim (no quoting).
#[rstest]
#[case("NOW()")]
#[case("uuid_generate_v4()")]
#[case("gen_random_uuid()")]
fn test_convert_default_value_sql_function(#[case] input: &str) {
    let receiver = Operation::CreateTable {
        name: String::from("test"),
        columns: Vec::new(),
        constraints: Vec::new(),
        without_rowid: None,
        partition: None,
        interleave_in_parent: None,
    };
    let raw = match receiver.convert_default_value(input) {
        Value::String(Some(s)) => s,
        other => panic!("Expected Value::String(Some(\"{input}\")), got {other:?}"),
    };
    assert_eq!(*raw, input, "SQL function '{input}' should remain unquoted");
}
/// Smoke test: applying `Integer` to a fresh column def must not panic. The
/// resulting `ColumnDef` is not inspected.
#[test]
fn test_apply_column_type_integer() {
    let receiver = Operation::CreateTable {
        name: String::from("test"),
        columns: Vec::new(),
        constraints: Vec::new(),
        without_rowid: None,
        partition: None,
        interleave_in_parent: None,
    };
    let _ = receiver.apply_column_type(ColumnDef::new(Alias::new("id")), &FieldType::Integer);
}
/// Smoke test: applying a length-bounded `VarChar` must not panic. The
/// resulting `ColumnDef` is not inspected.
#[test]
fn test_apply_column_type_varchar_with_length() {
    let receiver = Operation::CreateTable {
        name: String::from("test"),
        columns: Vec::new(),
        constraints: Vec::new(),
        without_rowid: None,
        partition: None,
        interleave_in_parent: None,
    };
    let _ = receiver.apply_column_type(
        ColumnDef::new(Alias::new("name")),
        &FieldType::VarChar(100),
    );
}
/// Smoke test: applying a `Custom` type name must not panic. The resulting
/// `ColumnDef` is not inspected.
#[test]
fn test_apply_column_type_custom() {
    let receiver = Operation::CreateTable {
        name: String::from("test"),
        columns: Vec::new(),
        constraints: Vec::new(),
        without_rowid: None,
        partition: None,
        interleave_in_parent: None,
    };
    let custom = FieldType::Custom("CUSTOM_TYPE".to_string());
    let _ = receiver.apply_column_type(ColumnDef::new(Alias::new("data")), &custom);
}
/// A multi-column index lists every column and gets a name derived from the
/// table and all of its columns.
#[test]
fn test_create_index_composite() {
    let index_op = Operation::CreateIndex {
        table: "users".to_string(),
        columns: ["first_name", "last_name"]
            .iter()
            .map(|c| c.to_string())
            .collect(),
        unique: false,
        index_type: None,
        where_clause: None,
        concurrently: false,
        expressions: None,
        mysql_options: None,
        operator_class: None,
    };
    let sql = index_op.to_sql(&SqlDialect::Postgres);
    let expectations = [
        ("first_name", "SQL should include 'first_name' column"),
        ("last_name", "SQL should include 'last_name' column"),
        (
            "idx_users_first_name_last_name",
            "SQL should include composite index name",
        ),
    ];
    for (fragment, reason) in expectations {
        assert!(sql.contains(fragment), "{}, got: {}", reason, sql);
    }
}
/// Single quotes inside a table comment are doubled in the rendered statement.
#[test]
fn test_alter_table_comment_with_quotes() {
    let comment_op = Operation::AlterTableComment {
        table: "users".to_string(),
        comment: Some("User's account table".to_string()),
    };
    let rendered = comment_op
        .to_statement()
        .to_sql_string(crate::backends::types::DatabaseType::Postgres);

    let has_comment_clause = rendered.contains("COMMENT ON TABLE");
    assert!(
        has_comment_clause,
        "SQL should contain COMMENT ON TABLE keywords, got: {}",
        rendered
    );
    let quote_is_escaped = rendered.contains("User''s account table");
    assert!(
        quote_is_escaped,
        "SQL should properly escape single quotes in comment, got: {}",
        rendered
    );
}
/// Forward-applying `AlterColumn` replaces the stored field type.
#[test]
fn test_state_forwards_alter_column() {
    let mut users = ModelState::new("myapp", "users");
    users.add_field(FieldState::new("age".to_string(), FieldType::Integer, false));
    let mut state = ProjectState::new();
    state.add_model(users);

    // Widen `age` from INTEGER to a required BIGINT.
    let widened = ColumnDefinition {
        not_null: true,
        ..ColumnDefinition::new("age", FieldType::BigInteger)
    };
    let alter_op = Operation::AlterColumn {
        table: "users".to_string(),
        column: "age".to_string(),
        old_definition: None,
        new_definition: widened,
        mysql_options: None,
    };
    alter_op.state_forwards("myapp", &mut state);

    let model = state.get_model("myapp", "users").unwrap();
    let age = model.fields.get("age").unwrap();
    assert_eq!(
        age.field_type,
        FieldType::BigInteger,
        "Field type should be updated to BigInteger, got: {}",
        age.field_type
    );
}
/// Forward-applying `CreateInheritedTable` records the base model and marks the
/// model as joined-table inheritance.
#[test]
fn test_state_forwards_create_inherited_table() {
    let mut state = ProjectState::new();
    let op = Operation::CreateInheritedTable {
        name: "admin_users".to_string(),
        columns: vec![ColumnDefinition {
            not_null: true,
            ..ColumnDefinition::new("admin_level", FieldType::Integer)
        }],
        base_table: "users".to_string(),
        join_column: "user_id".to_string(),
    };
    op.state_forwards("myapp", &mut state);

    // `expect` replaces the `assert!(is_some())` + `unwrap()` pair and panics with the same message.
    let model = state
        .get_model("myapp", "admin_users")
        .expect("Inherited table 'admin_users' should exist in state");
    assert_eq!(
        model.base_model,
        Some("users".to_string()),
        "base_model should be set to 'users'"
    );
    assert_eq!(
        model.inheritance_type,
        Some("joined_table".to_string()),
        "inheritance_type should be 'joined_table'"
    );
}
/// Forward-applying `AddDiscriminatorColumn` records the discriminator and marks
/// the model as single-table inheritance.
#[test]
fn test_state_forwards_add_discriminator_column() {
    let mut users = ModelState::new("myapp", "users");
    users.add_field(FieldState::new("id".to_string(), FieldType::Integer, false));
    let mut state = ProjectState::new();
    state.add_model(users);

    let discriminator_op = Operation::AddDiscriminatorColumn {
        table: "users".to_string(),
        column_name: "user_type".to_string(),
        default_value: "regular".to_string(),
    };
    discriminator_op.state_forwards("myapp", &mut state);

    let model = state.get_model("myapp", "users").unwrap();
    let expected_column = Some("user_type".to_string());
    let expected_kind = Some("single_table".to_string());
    assert_eq!(
        model.discriminator_column, expected_column,
        "discriminator_column should be set to 'user_type'"
    );
    assert_eq!(
        model.inheritance_type, expected_kind,
        "inheritance_type should be 'single_table'"
    );
}
/// The reverse `DROP TABLE` double-quotes an identifier containing a hyphen.
#[rstest]
fn test_to_reverse_sql_create_table_quotes_identifiers() {
    let hyphenated = Operation::CreateTable {
        name: String::from("user-data"),
        columns: Vec::new(),
        constraints: Vec::new(),
        without_rowid: None,
        partition: None,
        interleave_in_parent: None,
    };
    let state = ProjectState::default();
    let reverse = hyphenated.to_reverse_sql(&SqlDialect::Postgres, &state);
    let sql = reverse.unwrap().unwrap();
    assert_eq!(
        sql, r#"DROP TABLE "user-data";"#,
        "Identifiers with special characters must be quoted"
    );
}
#[rstest]
fn test_to_reverse_sql_add_column_quotes_identifiers() {
let op = Operation::AddColumn {
table: "my table".to_string(),
column: ColumnDefinition {
name: "my column".to_string(),
type_definition: FieldType::VarChar(255),
not_null: false,
unique: false,
primary_key: false,
auto_increment: false,
default: None,
},
mysql_options: None,
};
let state = ProjectState::default();
let sql = op
.to_reverse_sql(&SqlDialect::Postgres, &state)
.unwrap()
.unwrap();
assert_eq!(
sql, "ALTER TABLE \"my table\" DROP COLUMN \"my column\";",
"Table and column names with spaces must be quoted"
);
}
/// An injection-shaped table name is neutralised by identifier quoting in the
/// reverse rename.
#[rstest]
fn test_to_reverse_sql_rename_table_quotes_identifiers() {
    let hostile = Operation::RenameTable {
        old_name: String::from("old; DROP TABLE users;--"),
        new_name: String::from("new-name"),
    };
    let state = ProjectState::default();
    let reverse = hostile.to_reverse_sql(&SqlDialect::Postgres, &state);
    let sql = reverse.unwrap().unwrap();
    assert_eq!(
        sql, r#"ALTER TABLE "new-name" RENAME TO "old; DROP TABLE users;--";"#,
        "SQL injection attempt must be quoted as identifier"
    );
}
/// The reverse column rename swaps old/new and quotes every identifier.
#[rstest]
fn test_to_reverse_sql_rename_column_quotes_identifiers() {
    let spaced = Operation::RenameColumn {
        table: String::from("my table"),
        old_name: String::from("old col"),
        new_name: String::from("new col"),
    };
    let state = ProjectState::default();
    let reverse = spaced.to_reverse_sql(&SqlDialect::Postgres, &state);
    let sql = reverse.unwrap().unwrap();
    assert_eq!(
        sql, r#"ALTER TABLE "my table" RENAME COLUMN "new col" TO "old col";"#,
        "Identifiers with spaces must be quoted"
    );
}
/// The derived index name in the reverse `DROP INDEX` is quoted when it
/// contains special characters.
#[rstest]
fn test_to_reverse_sql_create_index_quotes_identifiers() {
    let index_op = Operation::CreateIndex {
        table: String::from("my-table"),
        columns: vec![String::from("col a")],
        unique: false,
        index_type: None,
        where_clause: None,
        concurrently: false,
        expressions: None,
        mysql_options: None,
        operator_class: None,
    };
    let state = ProjectState::default();
    let reverse = index_op.to_reverse_sql(&SqlDialect::Postgres, &state);
    let sql = reverse.unwrap().unwrap();
    let quoted_drop = sql.contains(r#"DROP INDEX "idx_my-table_col a""#);
    assert!(quoted_drop, "Index name must be quoted, got: {}", sql);
}
/// MySQL's `DROP INDEX` needs the owning table, so the reverse SQL ends in `ON <table>`.
#[rstest]
fn test_to_reverse_sql_create_index_emits_on_table_clause_for_mysql() {
    let index_op = Operation::CreateIndex {
        table: String::from("users"),
        columns: vec![String::from("email")],
        unique: false,
        index_type: None,
        where_clause: None,
        concurrently: false,
        expressions: None,
        mysql_options: None,
        operator_class: None,
    };
    let state = ProjectState::default();
    let reverse = index_op.to_reverse_sql(&SqlDialect::Mysql, &state);
    let sql = reverse.unwrap().unwrap();
    assert_eq!(
        sql, "DROP INDEX idx_users_email ON users;",
        "MySQL reverse SQL must include `ON <table>` clause"
    );
}
/// Dialects other than MySQL drop an index by name alone.
#[rstest]
#[case(SqlDialect::Postgres, "DROP INDEX idx_users_email;")]
#[case(SqlDialect::Sqlite, "DROP INDEX idx_users_email;")]
#[case(SqlDialect::Cockroachdb, "DROP INDEX idx_users_email;")]
fn test_to_reverse_sql_create_index_omits_on_table_for_non_mysql(
    #[case] dialect: SqlDialect,
    #[case] expected: &str,
) {
    let index_op = Operation::CreateIndex {
        table: String::from("users"),
        columns: vec![String::from("email")],
        unique: false,
        index_type: None,
        where_clause: None,
        concurrently: false,
        expressions: None,
        mysql_options: None,
        operator_class: None,
    };
    let state = ProjectState::default();
    let reverse = index_op.to_reverse_sql(&dialect, &state);
    let sql = reverse.unwrap().unwrap();
    assert_eq!(
        sql, expected,
        "Non-MySQL reverse SQL must remain unchanged for dialect {:?}",
        dialect
    );
}
/// The reverse of `AddConstraint` is a `DROP CONSTRAINT` on the quoted table.
#[rstest]
fn test_to_reverse_sql_add_constraint_quotes_identifiers() {
    let constraint_op = Operation::AddConstraint {
        table: String::from("my-table"),
        constraint_sql: String::from("CONSTRAINT chk_positive CHECK (x > 0)"),
    };
    let state = ProjectState::default();
    let reverse = constraint_op.to_reverse_sql(&SqlDialect::Postgres, &state);
    let sql = reverse.unwrap().unwrap();

    let table_quoted = sql.contains(r#"ALTER TABLE "my-table""#);
    assert!(
        table_quoted,
        "Table name with special characters must be quoted, got: {}",
        sql
    );
    let drops_constraint = sql.contains("DROP CONSTRAINT");
    assert!(
        drops_constraint,
        "Should contain DROP CONSTRAINT, got: {}",
        sql
    );
}
/// The reverse of a bulk load truncates the (quoted) target table.
#[rstest]
fn test_to_reverse_sql_bulk_load_quotes_identifiers() {
    let load_op = Operation::BulkLoad {
        table: String::from("user-data"),
        source: BulkLoadSource::Stdin,
        format: BulkLoadFormat::default(),
        options: BulkLoadOptions::default(),
    };
    let state = ProjectState::default();
    let reverse = load_op.to_reverse_sql(&SqlDialect::Postgres, &state);
    let sql = reverse.unwrap().unwrap();
    assert_eq!(
        sql, r#"TRUNCATE TABLE "user-data";"#,
        "Table name must be quoted"
    );
}
/// Postgres-family dialects reset the sequence via `setval(pg_get_serial_sequence(...))`.
#[rstest]
#[case::postgres(SqlDialect::Postgres)]
#[case::cockroachdb(SqlDialect::Cockroachdb)]
fn test_set_auto_increment_postgres_uses_setval(#[case] dialect: SqlDialect) {
    const EXPECTED: &str = "SELECT setval(pg_get_serial_sequence('users', 'id'), 1000, false);";
    let reset = Operation::SetAutoIncrementValue {
        table: String::from("users"),
        column: String::from("id"),
        value: 1000,
    };
    assert_eq!(reset.to_sql(&dialect), EXPECTED);
}
/// MySQL sets the next auto-increment value with a table option.
#[test]
fn test_set_auto_increment_mysql_alters_table() {
    const EXPECTED: &str = "ALTER TABLE users AUTO_INCREMENT = 1000;";
    let reset = Operation::SetAutoIncrementValue {
        table: String::from("users"),
        column: String::from("id"),
        value: 1000,
    };
    assert_eq!(reset.to_sql(&SqlDialect::Mysql), EXPECTED);
}
/// SQLite tracks AUTOINCREMENT counters in `sqlite_sequence`, so the value is upserted there.
#[test]
fn test_set_auto_increment_sqlite_upserts_sqlite_sequence() {
    const EXPECTED: &str =
        "INSERT OR REPLACE INTO sqlite_sequence(name, seq) VALUES ('users', 1000);";
    let reset = Operation::SetAutoIncrementValue {
        table: String::from("users"),
        column: String::from("id"),
        value: 1000,
    };
    assert_eq!(reset.to_sql(&SqlDialect::Sqlite), EXPECTED);
}
/// The table name goes into a string literal for `pg_get_serial_sequence`, so a
/// single quote in it must be doubled.
#[test]
fn test_set_auto_increment_postgres_escapes_literals() {
    let reset = Operation::SetAutoIncrementValue {
        table: String::from("user's"),
        column: String::from("id"),
        value: 42,
    };
    let sql = reset.to_sql(&SqlDialect::Postgres);
    let escaped = sql.contains("'user''s'");
    assert!(
        escaped,
        "single quote in table name must be escaped: {}",
        sql
    );
}
/// With no explicit name, every dialect emits `ALTER TABLE ... ADD CONSTRAINT
/// <table>_pkey PRIMARY KEY (...)` listing all columns.
#[rstest]
#[case::postgres(SqlDialect::Postgres)]
#[case::mysql(SqlDialect::Mysql)]
#[case::sqlite(SqlDialect::Sqlite)]
#[case::cockroachdb(SqlDialect::Cockroachdb)]
fn test_composite_pk_default_name(#[case] dialect: SqlDialect) {
    let pk_op = Operation::CreateCompositePrimaryKey {
        table: String::from("order_items"),
        columns: vec![String::from("order_id"), String::from("line_number")],
        constraint_name: None,
    };
    let sql = pk_op.to_sql(&dialect);
    let expectations = [
        ("ALTER TABLE", "SQL should use ALTER TABLE"),
        ("ADD CONSTRAINT", "SQL should add a named constraint"),
        ("PRIMARY KEY", "SQL should add PRIMARY KEY"),
        (
            "order_items_pkey",
            "Default constraint name should be table_pkey",
        ),
    ];
    for (fragment, reason) in expectations {
        assert!(sql.contains(fragment), "{}: {}", reason, sql);
    }
    let all_columns_present = ["order_id", "line_number"]
        .iter()
        .all(|column| sql.contains(column));
    assert!(
        all_columns_present,
        "Both PK columns must appear: {}",
        sql
    );
}
/// An explicit constraint name replaces the default, and simple identifiers stay unquoted.
#[test]
fn test_composite_pk_custom_name_and_quoting() {
    const EXPECTED: &str = "ALTER TABLE tbl ADD CONSTRAINT my_pk PRIMARY KEY (a, b);";
    let pk_op = Operation::CreateCompositePrimaryKey {
        table: String::from("tbl"),
        columns: ["a", "b"].iter().map(|c| c.to_string()).collect(),
        constraint_name: Some(String::from("my_pk")),
    };
    assert_eq!(pk_op.to_sql(&SqlDialect::Postgres), EXPECTED);
}
/// An empty column list must yield SQL that fails loudly on every engine,
/// not a statement (like `SELECT 1/0`) that some engines accept.
#[test]
fn test_composite_pk_empty_columns_produces_failing_sql() {
    let pk_op = Operation::CreateCompositePrimaryKey {
        table: String::from("tbl"),
        columns: Vec::new(),
        constraint_name: None,
    };
    let dialects = [SqlDialect::Postgres, SqlDialect::Mysql, SqlDialect::Sqlite];
    for dialect in &dialects {
        let sql = pk_op.to_sql(dialect);
        let is_diagnostic_error = sql.starts_with("SYNTAX_ERROR_create_composite_pk_on_")
            && sql.contains("requires_at_least_one_column");
        assert!(
            is_diagnostic_error,
            "Empty column list must emit a syntax-error statement with diagnostic ({:?}): {}",
            dialect,
            sql
        );
        assert!(
            !sql.contains("SELECT 1/0"),
            "Must not fall back to SELECT 1/0 (silently passes on SQLite / lax MySQL): {}",
            sql
        );
    }
}
#[rstest]
#[case::big_integer(FieldType::BigInteger)]
#[case::integer(FieldType::Integer)]
#[case::small_integer(FieldType::SmallInteger)]
fn test_column_to_sql_sqlite_auto_increment_pk_emits_integer(#[case] field_type: FieldType) {
let mut col = ColumnDefinition::new("id", field_type);
col.primary_key = true;
col.auto_increment = true;
col.not_null = true;
let sql = Operation::column_to_sql(&col, &SqlDialect::Sqlite);
assert!(
sql.contains("INTEGER PRIMARY KEY AUTOINCREMENT"),
"SQLite auto_increment PK must emit `INTEGER PRIMARY KEY AUTOINCREMENT`: {}",
sql
);
assert!(
!sql.contains("BIGINT"),
"SQLite auto_increment must not emit BIGINT (rejected by SQLite): {}",
sql
);
assert!(
!sql.contains("SMALLINT"),
"SQLite auto_increment must not emit SMALLINT (rejected by SQLite): {}",
sql
);
}
/// A plain `BigInteger` on SQLite is normalised to `INTEGER` and gets no AUTOINCREMENT.
#[test]
fn test_column_to_sql_sqlite_big_integer_without_auto_increment_no_autoincrement() {
    let counter = ColumnDefinition {
        not_null: true,
        ..ColumnDefinition::new("count", FieldType::BigInteger)
    };
    let sql = Operation::column_to_sql(&counter, &SqlDialect::Sqlite);
    let has_autoincrement = sql.contains("AUTOINCREMENT");
    assert!(
        !has_autoincrement,
        "Non-auto_increment column must not emit AUTOINCREMENT: {}",
        sql
    );
    let has_bigint = sql.contains("BIGINT");
    assert!(
        !has_bigint,
        "emitter is expected to normalize BigInteger to INTEGER for SQLite: {}",
        sql
    );
}
/// The SQLite normalisation must not leak into Postgres, which keeps `BIGINT` identity syntax.
#[test]
fn test_column_to_sql_postgres_big_integer_auto_increment_unchanged() {
    let id = ColumnDefinition {
        primary_key: true,
        auto_increment: true,
        not_null: true,
        ..ColumnDefinition::new("id", FieldType::BigInteger)
    };
    let sql = Operation::column_to_sql(&id, &SqlDialect::Postgres);
    let uses_identity = sql.contains("BIGINT GENERATED BY DEFAULT AS IDENTITY");
    assert!(
        uses_identity,
        "Postgres auto_increment BigInteger must emit identity syntax: {}",
        sql
    );
}
#[test]
fn test_column_to_sql_sqlite_auto_increment_uuid_pk_omits_autoincrement() {
    // auto_increment set on a UUID PK. SQLite only accepts AUTOINCREMENT on
    // INTEGER PRIMARY KEY, so the emitter must drop the keyword here. It must
    // also leave the column type as it is rather than rewriting it to INTEGER.
    let column = {
        let mut def = ColumnDefinition::new("id", FieldType::Uuid);
        def.primary_key = true;
        def.auto_increment = true;
        def.not_null = true;
        def
    };
    let sql = Operation::column_to_sql(&column, &SqlDialect::Sqlite);

    assert!(
        sql.contains("PRIMARY KEY"),
        "UUID PK must still emit PRIMARY KEY: {}",
        sql
    );

    let emits_autoincrement = sql.contains("AUTOINCREMENT");
    assert!(
        !emits_autoincrement,
        "non-integer auto_increment PK must not emit AUTOINCREMENT (SQLite rejects it): {}",
        sql
    );

    let widened_to_integer = sql.contains("INTEGER");
    assert!(
        !widened_to_integer,
        "UUID column type must not be widened to INTEGER: {}",
        sql
    );
}
#[test]
fn test_column_to_sql_without_pk_sqlite_auto_increment_emits_integer() {
    // In the composite-PK path the column is rendered without its own PRIMARY
    // KEY clause. The BigInteger -> INTEGER normalization must still apply there.
    let column = {
        let mut def = ColumnDefinition::new("id", FieldType::BigInteger);
        def.auto_increment = true;
        def.not_null = true;
        def
    };
    let sql = Operation::column_to_sql_without_pk(&column, &SqlDialect::Sqlite);

    let has_integer = sql.contains("INTEGER");
    assert!(
        has_integer,
        "SQLite auto_increment column (composite PK path) must emit INTEGER: {}",
        sql
    );

    let has_bigint = sql.contains("BIGINT");
    assert!(
        !has_bigint,
        "SQLite auto_increment must not emit BIGINT in composite PK path: {}",
        sql
    );
}
/// Tests for `resolve_foreign_key_column_type_with`, which resolves the primary-key
/// type of a foreign-key target model from a `ModelRegistry`.
///
/// The cases pin the lookup behaviour:
/// - an app-qualified hit wins, even when the bare model name is ambiguous;
/// - an app-qualified miss falls back to a by-name lookup only when that name is
///   unambiguous;
/// - an ambiguous bare name resolves to `None`;
/// - a field without an `fk_target` param resolves to `None`.
mod resolve_foreign_key_column_type_tests {
    use super::super::resolve_foreign_key_column_type_with;
    use super::FieldType;
    use crate::migrations::autodetector::FieldState;
    use crate::migrations::model_registry::{FieldMetadata, ModelMetadata, ModelRegistry};

    /// Builds model metadata whose only field is an `id` column flagged as the
    /// primary key with type `pk_type`.
    fn target_model(app: &str, name: &str, table: &str, pk_type: FieldType) -> ModelMetadata {
        let mut meta = ModelMetadata::new(app, name, table);
        meta.add_field(
            "id".to_string(),
            FieldMetadata::new(pk_type).with_param("primary_key", "true"),
        );
        meta
    }

    /// Builds an `owner_id` field state whose `fk_target` param names `target_model`.
    /// When `target_app` is given, it is stored under the `fk_target_app` param.
    fn fk_field_state(target_model: &str, target_app: Option<&str>) -> FieldState {
        let mut fs = FieldState::new("owner_id", FieldType::Uuid, false);
        fs.params
            .insert("fk_target".to_string(), target_model.to_string());
        if let Some(app) = target_app {
            fs.params
                .insert("fk_target_app".to_string(), app.to_string());
        }
        fs
    }

    /// Creates a fresh registry and registers every model in `models`, in order.
    fn registry_with(models: impl IntoIterator<Item = ModelMetadata>) -> ModelRegistry {
        let registry = ModelRegistry::new();
        for model in models {
            registry.register_model(model);
        }
        registry
    }

    #[test]
    fn qualified_hit_resolves_to_target_pk_type() {
        let registry = registry_with([target_model(
            "auth",
            "User",
            "auth_user",
            FieldType::BigInteger,
        )]);
        let fs = fk_field_state("User", Some("auth"));
        let resolved = resolve_foreign_key_column_type_with(&fs, &registry);
        assert_eq!(resolved, Some(FieldType::BigInteger));
    }

    #[test]
    fn qualified_miss_falls_back_to_by_name_when_unambiguous() {
        // "blog.User" is not registered. The only "User" lives in reinhardt_auth,
        // so the by-name fallback is unambiguous.
        let registry = registry_with([target_model(
            "reinhardt_auth",
            "User",
            "auth_user",
            FieldType::Uuid,
        )]);
        let fs = fk_field_state("User", Some("blog"));
        let resolved = resolve_foreign_key_column_type_with(&fs, &registry);
        assert_eq!(resolved, Some(FieldType::Uuid));
    }

    #[test]
    fn ambiguous_by_name_returns_none() {
        let registry = registry_with([
            target_model("auth", "User", "auth_user", FieldType::BigInteger),
            target_model("billing", "User", "billing_user", FieldType::Uuid),
        ]);
        let fs = fk_field_state("User", None);
        let resolved = resolve_foreign_key_column_type_with(&fs, &registry);
        assert_eq!(resolved, None);
    }

    #[test]
    fn path_typed_disambiguates_ambiguous_name() {
        let registry = registry_with([
            target_model("blog", "User", "blog_user", FieldType::BigInteger),
            target_model(
                "reinhardt_auth",
                "User",
                "reinhardt_auth_user",
                FieldType::Uuid,
            ),
        ]);
        let fs = fk_field_state("User", Some("reinhardt_auth"));
        let resolved = resolve_foreign_key_column_type_with(&fs, &registry);
        assert_eq!(resolved, Some(FieldType::Uuid));
    }

    #[test]
    fn qualified_miss_with_ambiguous_by_name_returns_none() {
        // "blog.User" is not registered, and two other apps both define "User",
        // so the by-name fallback must not guess.
        let registry = registry_with([
            target_model("auth", "User", "auth_user", FieldType::BigInteger),
            target_model("billing", "User", "billing_user", FieldType::Uuid),
        ]);
        let fs = fk_field_state("User", Some("blog"));
        let resolved = resolve_foreign_key_column_type_with(&fs, &registry);
        assert_eq!(resolved, None);
    }

    #[test]
    fn no_fk_target_param_returns_none() {
        let registry = registry_with([target_model(
            "auth",
            "User",
            "auth_user",
            FieldType::BigInteger,
        )]);
        let fs = FieldState::new("name", FieldType::VarChar(64), false);
        let resolved = resolve_foreign_key_column_type_with(&fs, &registry);
        assert_eq!(resolved, None);
    }
}
}