use async_trait::async_trait;
use super::alter_table;
use crate::Executable;
use crate::schema::{DataType, DatabaseValue};
use crate::{Database, DatabaseError};
/// A schema statement that can build the statement which undoes it.
///
/// `reverse` only borrows `self`, so the original statement remains usable
/// (and can be reversed again) after its inverse has been produced.
#[async_trait]
pub trait AutoReversible
where
    Self: Executable,
{
    /// The statement that undoes `Self`; it must itself be executable.
    type Reversed: Executable;

    /// Builds the inverse of `self` without consuming it.
    fn reverse(&self) -> Self::Reversed;
}
#[async_trait]
impl<'a> AutoReversible for crate::schema::CreateTableStatement<'a> {
    type Reversed = crate::schema::DropTableStatement<'a>;

    /// Reverses a `CREATE TABLE` into a `DROP TABLE` on the same table.
    ///
    /// The drop is always built with `if_exists: true` (presumably emitted as
    /// SQL `IF EXISTS`), and with the `cascade` feature it uses
    /// `DropBehavior::Default`. Column and constraint definitions of the
    /// create statement are not needed to drop the table and are ignored.
    fn reverse(&self) -> Self::Reversed {
        let table_name = self.table_name;

        crate::schema::DropTableStatement {
            table_name,
            if_exists: true,
            #[cfg(feature = "cascade")]
            behavior: crate::schema::DropBehavior::Default,
        }
    }
}
#[async_trait]
impl<'a> AutoReversible for crate::schema::CreateIndexStatement<'a> {
    type Reversed = crate::schema::DropIndexStatement<'a>;

    /// Reverses a `CREATE INDEX` into a `DROP INDEX` for the same index and table.
    ///
    /// The drop is always built with `if_exists: true`. The indexed columns and
    /// the `unique` flag are irrelevant for dropping and are not carried over.
    ///
    /// NOTE(review): `self.if_not_exists` is not consulted. If the create was
    /// issued with `if_not_exists` and the index already existed beforehand,
    /// running this reversal would still drop that pre-existing index.
    fn reverse(&self) -> Self::Reversed {
        let (index_name, table_name) = (self.index_name, self.table_name);

        crate::schema::DropIndexStatement {
            if_exists: true,
            index_name,
            table_name,
        }
    }
}
/// Adds a single column to an existing table.
///
/// Usually built with [`add_column`]. Reversing it yields a
/// [`DropColumnOperation`] for the same table and column.
#[cfg(feature = "auto-reverse")]
pub struct AddColumnOperation<'a> {
    /// Table the column is added to.
    pub table_name: &'a str,
    /// Name of the new column.
    pub name: String,
    /// Data type of the new column.
    pub data_type: DataType,
    /// Whether the new column may hold NULL values.
    pub nullable: bool,
    /// Optional default value for the new column.
    pub default: Option<DatabaseValue>,
}

#[cfg(feature = "auto-reverse")]
#[async_trait]
impl crate::Executable for AddColumnOperation<'_> {
    /// Runs the column addition against `db` through the `alter_table` builder.
    ///
    /// The owned fields are cloned because `execute` only borrows `self`, which
    /// keeps the operation intact (e.g. so it can still be reversed afterwards).
    async fn execute(&self, db: &dyn Database) -> Result<(), DatabaseError> {
        let AddColumnOperation {
            table_name,
            name,
            data_type,
            nullable,
            default,
        } = self;

        alter_table(*table_name)
            .add_column(name.clone(), data_type.clone(), *nullable, default.clone())
            .execute(db)
            .await
    }
}
/// Removes a single column from an existing table.
///
/// This is the operation produced by reversing an [`AddColumnOperation`].
#[cfg(feature = "auto-reverse")]
pub struct DropColumnOperation<'a> {
    /// Table the column is removed from.
    pub table_name: &'a str,
    /// Name of the column to drop.
    pub column_name: String,
}

#[cfg(feature = "auto-reverse")]
#[async_trait]
impl crate::Executable for DropColumnOperation<'_> {
    /// Runs the column removal against `db` through the `alter_table` builder.
    async fn execute(&self, db: &dyn Database) -> Result<(), DatabaseError> {
        let DropColumnOperation {
            table_name,
            column_name,
        } = self;

        alter_table(*table_name)
            .drop_column(column_name.clone())
            .execute(db)
            .await
    }
}
#[cfg(feature = "auto-reverse")]
#[async_trait]
impl<'a> AutoReversible for AddColumnOperation<'a> {
    type Reversed = DropColumnOperation<'a>;

    /// Undoes the column addition by dropping the same column from the same table.
    ///
    /// Only the table and column names are carried over. The data type,
    /// nullability and default are not needed to drop a column.
    fn reverse(&self) -> Self::Reversed {
        let column_name = self.name.clone();

        DropColumnOperation {
            column_name,
            table_name: self.table_name,
        }
    }
}
/// Creates an [`AddColumnOperation`] that adds column `name` to `table`.
///
/// Constructing the operation has no side effects; it only takes effect once
/// it is executed. It is therefore `#[must_use]`, because discarding it
/// unexecuted is almost certainly a bug.
///
/// # Arguments
/// * `table` - table to alter; borrowed for the lifetime of the returned operation.
/// * `name` - name of the new column.
/// * `data_type` - data type of the new column.
/// * `nullable` - whether the new column may hold NULL values.
/// * `default` - optional default value for the new column.
#[cfg(feature = "auto-reverse")]
#[must_use]
pub fn add_column(
    table: &str,
    name: impl Into<String>,
    data_type: DataType,
    nullable: bool,
    default: Option<DatabaseValue>,
) -> AddColumnOperation<'_> {
    AddColumnOperation {
        table_name: table,
        name: name.into(),
        data_type,
        nullable,
        default,
    }
}
// Unit tests for `AutoReversible` and the column operations in this file.
//
// The tests only build statements and inspect their fields; no database is
// involved. "Executable" coverage is therefore a compile-time trait-bound
// check (`assert_executable`) rather than a real `execute` call.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DatabaseValue;
    use crate::schema::{Column, DataType, create_index, create_table};

    /// Compiles only if `T` implements `crate::Executable`; does nothing at runtime.
    fn assert_executable<T: crate::Executable>(_: &T) {}

    mod create_table {
        use super::*;

        #[test]
        fn test_create_table_auto_reverse_basic() {
            let create = create_table("users");
            let drop = create.reverse();
            assert_eq!(drop.table_name, "users");
            assert!(drop.if_exists);
        }

        #[test]
        fn test_create_table_auto_reverse_non_consuming() {
            let create = create_table("users");
            let drop = create.reverse();
            assert_eq!(drop.table_name, "users");
            assert!(drop.if_exists);
            // `reverse` borrows, so the original is still usable and reversible.
            assert_eq!(create.table_name, "users");
            let drop2 = create.reverse();
            assert_eq!(drop2.table_name, "users");
            assert!(drop2.if_exists);
        }

        #[test]
        fn test_create_table_auto_reverse_with_columns() {
            let create = create_table("products")
                .column(Column {
                    name: "id".to_string(),
                    data_type: DataType::Int,
                    nullable: false,
                    auto_increment: true,
                    default: None,
                })
                .column(Column {
                    name: "name".to_string(),
                    data_type: DataType::Text,
                    nullable: false,
                    auto_increment: false,
                    default: Some(DatabaseValue::String("Unknown".to_string())),
                });
            let drop = create.reverse();
            assert_eq!(drop.table_name, "products");
            assert!(drop.if_exists);
        }

        #[test]
        fn test_create_table_auto_reverse_with_constraints() {
            let create = create_table("orders")
                .column(Column {
                    name: "id".to_string(),
                    data_type: DataType::Int,
                    nullable: false,
                    auto_increment: true,
                    default: None,
                })
                .primary_key("id")
                .foreign_key(("user_id", "users.id"));
            let drop = create.reverse();
            assert_eq!(drop.table_name, "orders");
            assert!(drop.if_exists);
        }

        #[test]
        #[cfg(feature = "cascade")]
        fn test_create_table_auto_reverse_cascade_behavior() {
            let create = create_table("test_cascade");
            let drop = create.reverse();
            assert_eq!(drop.table_name, "test_cascade");
            assert!(drop.if_exists);
            assert_eq!(drop.behavior, crate::schema::DropBehavior::Default);
        }

        #[test]
        fn test_create_table_auto_reverse_executable_trait() {
            let create = create_table("async_test");
            let drop = create.reverse();
            // Both directions must be runnable through the `Executable` trait.
            assert_executable(&create);
            assert_executable(&drop);
            assert_eq!(drop.table_name, "async_test");
            assert!(drop.if_exists);
        }

        #[test]
        fn test_create_table_auto_reverse_complex_async() {
            let create = create_table("complex_async")
                .column(Column {
                    name: "id".to_string(),
                    data_type: DataType::BigInt,
                    nullable: false,
                    auto_increment: true,
                    default: None,
                })
                .column(Column {
                    name: "timestamp".to_string(),
                    data_type: DataType::DateTime,
                    nullable: true,
                    auto_increment: false,
                    default: Some(DatabaseValue::String("CURRENT_TIMESTAMP".to_string())),
                })
                .primary_key("id");
            let drop = create.reverse();
            assert_eq!(drop.table_name, "complex_async");
            assert!(drop.if_exists);
            #[cfg(feature = "cascade")]
            assert_eq!(drop.behavior, crate::schema::DropBehavior::Default);
        }
    }

    mod create_index {
        use super::*;

        #[test]
        fn test_create_index_auto_reverse_basic() {
            let create = create_index("idx_users_email")
                .table("users")
                .column("email");
            let drop = create.reverse();
            assert_eq!(drop.index_name, "idx_users_email");
            assert_eq!(drop.table_name, "users");
            assert!(drop.if_exists);
        }

        #[test]
        fn test_create_index_auto_reverse_non_consuming() {
            let create = create_index("idx_users_email")
                .table("users")
                .column("email");
            let drop = create.reverse();
            assert_eq!(drop.index_name, "idx_users_email");
            assert_eq!(drop.table_name, "users");
            assert!(drop.if_exists);
            // `reverse` borrows, so the original is still usable and reversible.
            assert_eq!(create.index_name, "idx_users_email");
            assert_eq!(create.table_name, "users");
            let drop2 = create.reverse();
            assert_eq!(drop2.index_name, "idx_users_email");
            assert!(drop2.if_exists);
        }

        #[test]
        fn test_create_index_auto_reverse_multi_column() {
            let create = create_index("idx_users_name")
                .table("users")
                .columns(vec!["first_name", "last_name"]);
            let drop = create.reverse();
            assert_eq!(drop.index_name, "idx_users_name");
            assert_eq!(drop.table_name, "users");
            assert!(drop.if_exists);
        }

        #[test]
        fn test_create_index_auto_reverse_unique() {
            let create = create_index("idx_unique_email")
                .table("users")
                .column("email")
                .unique(true);
            let drop = create.reverse();
            assert_eq!(drop.index_name, "idx_unique_email");
            assert_eq!(drop.table_name, "users");
            assert!(drop.if_exists);
        }

        #[test]
        fn test_create_index_auto_reverse_if_not_exists() {
            let create = create_index("idx_conditional")
                .table("users")
                .column("email")
                .if_not_exists(true);
            let drop = create.reverse();
            assert_eq!(drop.index_name, "idx_conditional");
            assert_eq!(drop.table_name, "users");
            assert!(drop.if_exists);
        }

        #[test]
        fn test_create_index_auto_reverse_executable_trait() {
            let create = create_index("idx_async_test")
                .table("test_table")
                .column("test_column");
            let drop = create.reverse();
            // Both directions must be runnable through the `Executable` trait.
            assert_executable(&create);
            assert_executable(&drop);
            assert_eq!(drop.index_name, "idx_async_test");
            assert_eq!(drop.table_name, "test_table");
            assert!(drop.if_exists);
        }

        #[test]
        fn test_create_index_auto_reverse_complex_async() {
            let create = create_index("idx_complex_async")
                .table("complex_table")
                .columns(vec!["col1", "col2", "col3"])
                .unique(true)
                .if_not_exists(true);
            let drop = create.reverse();
            assert_eq!(drop.index_name, "idx_complex_async");
            assert_eq!(drop.table_name, "complex_table");
            assert!(drop.if_exists);
            assert_eq!(create.index_name, "idx_complex_async");
            assert_eq!(create.table_name, "complex_table");
            assert_eq!(create.columns, vec!["col1", "col2", "col3"]);
            assert!(create.unique);
            assert!(create.if_not_exists);
        }
    }

    // Gated as a whole module (not per test) so that `use super::*;` is not
    // left unused, and warned about, when the feature is disabled.
    #[cfg(feature = "auto-reverse")]
    mod add_column {
        use super::*;

        #[test]
        fn test_add_column_reversal() {
            let add = add_column("users", "age", DataType::Int, true, None);
            let drop = add.reverse();
            assert_eq!(drop.table_name, "users");
            assert_eq!(drop.column_name, "age");
        }

        #[test]
        fn test_add_column_with_default() {
            let add = add_column(
                "users",
                "status",
                DataType::Text,
                false,
                Some(DatabaseValue::String("active".to_string())),
            );
            let drop = add.reverse();
            assert_eq!(drop.table_name, "users");
            assert_eq!(drop.column_name, "status");
        }

        #[test]
        fn test_add_column_non_consuming() {
            let add = add_column(
                "products",
                "price",
                DataType::Real,
                true,
                Some(DatabaseValue::Real64(0.0)),
            );
            let drop = add.reverse();
            assert_eq!(drop.table_name, "products");
            assert_eq!(drop.column_name, "price");
            // `reverse` borrows, so every field of the original is intact.
            assert_eq!(add.table_name, "products");
            assert_eq!(add.name, "price");
            assert_eq!(add.data_type, DataType::Real);
            assert!(add.nullable);
        }

        #[test]
        fn test_add_column_executable_trait() {
            let add = add_column(
                "async_table",
                "new_column",
                DataType::BigInt,
                false,
                Some(DatabaseValue::Int64(42)),
            );
            // Both the operation and its reverse must be runnable via `Executable`.
            assert_executable(&add);
            assert_executable(&add.reverse());
            assert_eq!(add.table_name, "async_table");
            assert_eq!(add.name, "new_column");
            assert_eq!(add.data_type, DataType::BigInt);
            assert!(!add.nullable);
        }

        #[test]
        fn test_add_column_complex_async() {
            let add = add_column(
                "complex_table",
                "complex_column",
                DataType::VarChar(255),
                true,
                Some(DatabaseValue::String("default_value".to_string())),
            );
            let drop = add.reverse();
            assert_eq!(add.table_name, "complex_table");
            assert_eq!(add.name, "complex_column");
            assert_eq!(drop.table_name, "complex_table");
            assert_eq!(drop.column_name, "complex_column");
        }
    }
}