#![cfg_attr(feature = "fail-on-warnings", deny(warnings))]
#![warn(clippy::all, clippy::pedantic, clippy::nursery, clippy::cargo)]
#![allow(clippy::multiple_crate_versions)]
use std::{future::Future, pin::Pin, sync::Arc};
use async_trait::async_trait;
use switchy_database::{Database, DatabaseError};
use switchy_schema::{
migration::{Migration, MigrationSource},
runner::MigrationRunner,
version::{DEFAULT_MIGRATIONS_TABLE, VersionTracker},
};
use crate::TestError;
/// Fluent builder for exercising an ordered list of migrations against a database.
///
/// Besides applying the migrations, the builder can run an initial setup
/// closure before anything else, run data-seeding closures immediately before
/// or after a specific migration, and optionally roll every migration back
/// once they have all been applied.
pub struct MigrationTestBuilder<'a> {
    /// Migrations to apply, in execution order.
    migrations: Vec<Arc<dyn Migration<'a> + 'a>>,
    /// Actions registered through `with_data_before` / `with_data_after`.
    breakpoints: Vec<Breakpoint<'a>>,
    /// Optional closure executed once, before any migration runs.
    initial_setup: Option<SetupFn<'a>>,
    /// When `true`, `run` rolls back all migrations after applying them.
    with_rollback: bool,
    /// Custom migration-tracking table name; `None` means no override was requested.
    table_name: Option<String>,
}
/// A one-shot action attached to a specific migration.
struct Breakpoint<'a> {
    /// ID of the migration this action is attached to.
    migration_id: String,
    /// Whether the action runs before or after the migration is applied.
    timing: BreakpointTiming,
    /// The closure to execute against the database.
    action: SetupFn<'a>,
}
/// Position of a breakpoint action relative to its migration.
///
/// The enum carries no data, so it derives `Copy` alongside `Clone`; values
/// can be passed around and compared without explicit cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum BreakpointTiming {
    /// The action runs immediately before the migration is applied.
    Before,
    /// The action runs immediately after the migration is applied.
    After,
}
/// Boxed, `Send` future returned by a setup action; it may borrow the
/// database for `'db` and resolves to `Ok(())` or a `DatabaseError`.
type SetupFuture<'db> = Pin<Box<dyn Future<Output = Result<(), DatabaseError>> + Send + 'db>>;

/// Boxed one-shot action that receives the database and returns a future to await.
///
/// The closure must be `Send` and valid for `'a`; it works for any borrow
/// lifetime `'db` of the database it is handed.
type SetupFn<'a> = Box<dyn for<'db> FnOnce(&'db dyn Database) -> SetupFuture<'db> + Send + 'a>;
/// Migration source backed by an in-memory vector of migrations.
struct VecMigrationSource<'a> {
    /// The migrations this source hands out, in the order they were supplied.
    migrations: Vec<Arc<dyn Migration<'a> + 'a>>,
}
impl<'a> VecMigrationSource<'a> {
    /// Wraps `migrations` in a source, taking ownership of the vector as-is.
    #[must_use]
    fn new(migrations: Vec<Arc<dyn Migration<'a> + 'a>>) -> Self {
        Self { migrations }
    }
}
// Exposes the stored vector through the `MigrationSource` trait.
#[async_trait]
impl<'a> MigrationSource<'a> for VecMigrationSource<'a> {
    /// Returns a new vector holding a clone of every stored `Arc` handle.
    ///
    /// This implementation always returns `Ok`.
    async fn migrations(&self) -> switchy_schema::Result<Vec<Arc<dyn Migration<'a> + 'a>>> {
        Ok(Vec::clone(&self.migrations))
    }
}
impl<'a> MigrationTestBuilder<'a> {
#[must_use]
pub fn new(migrations: Vec<Arc<dyn Migration<'a> + 'a>>) -> Self {
Self {
migrations,
breakpoints: Vec::new(),
initial_setup: None,
with_rollback: false,
table_name: None,
}
}
#[must_use]
pub fn with_initial_setup<F>(mut self, setup: F) -> Self
where
F: for<'db> FnOnce(
&'db dyn Database,
)
-> Pin<Box<dyn Future<Output = Result<(), DatabaseError>> + Send + 'db>>
+ Send
+ 'a,
{
self.initial_setup = Some(Box::new(setup));
self
}
#[must_use]
pub fn with_data_before<F>(mut self, migration_id: &str, setup: F) -> Self
where
F: for<'db> FnOnce(
&'db dyn Database,
)
-> Pin<Box<dyn Future<Output = Result<(), DatabaseError>> + Send + 'db>>
+ Send
+ 'a,
{
self.breakpoints.push(Breakpoint {
migration_id: migration_id.to_string(),
timing: BreakpointTiming::Before,
action: Box::new(setup),
});
self
}
#[must_use]
pub fn with_data_after<F>(mut self, migration_id: &str, setup: F) -> Self
where
F: for<'db> FnOnce(
&'db dyn Database,
)
-> Pin<Box<dyn Future<Output = Result<(), DatabaseError>> + Send + 'db>>
+ Send
+ 'a,
{
self.breakpoints.push(Breakpoint {
migration_id: migration_id.to_string(),
timing: BreakpointTiming::After,
action: Box::new(setup),
});
self
}
#[must_use]
pub const fn with_rollback(mut self) -> Self {
self.with_rollback = true;
self
}
#[must_use]
pub fn with_table_name(mut self, table_name: impl Into<String>) -> Self {
self.table_name = Some(table_name.into());
self
}
pub async fn run(self, db: &dyn Database) -> Result<(), TestError> {
use std::collections::BTreeMap;
let migrations = self.migrations;
let breakpoints = self.breakpoints;
let initial_setup = self.initial_setup;
let with_rollback = self.with_rollback;
let table_name = self.table_name;
if let Some(setup) = initial_setup {
setup(db).await?;
}
let mut breakpoints_by_migration: BTreeMap<
usize,
(Vec<Breakpoint<'_>>, Vec<Breakpoint<'_>>),
> = BTreeMap::new();
for breakpoint in breakpoints {
let migration_index = migrations
.iter()
.position(|m| m.id() == breakpoint.migration_id)
.ok_or_else(|| {
TestError::Migration(switchy_schema::MigrationError::Validation(format!(
"Migration '{}' not found in migration list",
breakpoint.migration_id
)))
})?;
let entry = breakpoints_by_migration
.entry(migration_index)
.or_insert((Vec::new(), Vec::new()));
match breakpoint.timing {
BreakpointTiming::Before => entry.0.push(breakpoint),
BreakpointTiming::After => entry.1.push(breakpoint),
}
}
let mut current_migration_index = 0;
for (breakpoint_migration_index, (before_breakpoints, after_breakpoints)) in
breakpoints_by_migration
{
if current_migration_index < breakpoint_migration_index {
let migrations_to_run =
migrations[current_migration_index..breakpoint_migration_index].to_vec();
if !migrations_to_run.is_empty() {
let source = VecMigrationSource::new(migrations_to_run);
let mut runner = MigrationRunner::new(Box::new(source));
if let Some(ref table_name) = table_name {
runner = runner.with_table_name(table_name.clone());
}
runner.run(db).await?;
}
current_migration_index = breakpoint_migration_index;
}
let target_migration = &migrations[breakpoint_migration_index];
for breakpoint in before_breakpoints {
(breakpoint.action)(db).await?;
}
target_migration
.up(db)
.await
.map_err(TestError::Migration)?;
for breakpoint in after_breakpoints {
(breakpoint.action)(db).await?;
}
if let Some(ref table_name) = table_name {
Self::record_migration(db, table_name, target_migration.id()).await?;
} else {
Self::record_migration(db, DEFAULT_MIGRATIONS_TABLE, target_migration.id()).await?;
}
current_migration_index += 1;
}
if current_migration_index < migrations.len() {
let remaining_migrations = migrations[current_migration_index..].to_vec();
let source = VecMigrationSource::new(remaining_migrations);
let mut runner = MigrationRunner::new(Box::new(source));
if let Some(ref table_name) = table_name {
runner = runner.with_table_name(table_name.clone());
}
runner.run(db).await?;
}
if with_rollback {
let source = VecMigrationSource::new(migrations);
let mut runner = MigrationRunner::new(Box::new(source));
if let Some(ref table_name) = table_name {
runner = runner.with_table_name(table_name.clone());
}
runner
.rollback(db, switchy_schema::runner::RollbackStrategy::All)
.await?;
}
Ok(())
}
async fn record_migration(
db: &dyn Database,
table_name: &str,
migration_id: &str,
) -> Result<(), TestError> {
let version_tracker = VersionTracker::with_table_name(table_name);
version_tracker.ensure_table_exists(db).await?;
version_tracker.record_migration(db, migration_id).await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use switchy_database::{query, query::FilterableQuery};
    use switchy_schema::migration::Migration;

    /// Trait-object handle type consumed by `MigrationTestBuilder::new`.
    #[cfg(feature = "sqlite")]
    type DynMigration = Arc<dyn Migration<'static> + 'static>;

    /// Minimal migration that executes raw SQL for `up` and, optionally, `down`.
    struct TestMigration {
        id: String,
        up_sql: String,
        down_sql: Option<String>,
    }

    impl TestMigration {
        /// Builds a migration from borrowed strings, copying them into owned fields.
        fn new(id: &str, up_sql: &str, down_sql: Option<&str>) -> Self {
            Self {
                id: id.to_owned(),
                up_sql: up_sql.to_owned(),
                down_sql: down_sql.map(str::to_owned),
            }
        }
    }

    #[async_trait]
    impl Migration<'static> for TestMigration {
        fn id(&self) -> &str {
            &self.id
        }

        async fn up(&self, db: &dyn Database) -> switchy_schema::Result<()> {
            db.exec_raw(&self.up_sql).await?;
            Ok(())
        }

        /// Executes the down SQL when present; otherwise does nothing.
        async fn down(&self, db: &dyn Database) -> switchy_schema::Result<()> {
            if let Some(down_sql) = &self.down_sql {
                db.exec_raw(down_sql).await?;
            }
            Ok(())
        }

        fn description(&self) -> Option<&str> {
            None
        }
    }

    /// Wraps a `TestMigration` with both up and down SQL in a trait-object handle.
    #[cfg(feature = "sqlite")]
    fn test_migration(id: &str, up_sql: &str, down_sql: &str) -> DynMigration {
        Arc::new(TestMigration::new(id, up_sql, Some(down_sql)))
    }

    /// `001_create_users`: creates the `users` table.
    #[cfg(feature = "sqlite")]
    fn create_users() -> DynMigration {
        test_migration(
            "001_create_users",
            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)",
            "DROP TABLE users",
        )
    }

    /// `002_add_email_column`: adds `email` to `users`.
    #[cfg(feature = "sqlite")]
    fn add_email_column() -> DynMigration {
        test_migration(
            "002_add_email_column",
            "ALTER TABLE users ADD COLUMN email TEXT",
            "ALTER TABLE users DROP COLUMN email",
        )
    }

    /// `003_add_age_column`: adds `age` to `users`.
    #[cfg(feature = "sqlite")]
    fn add_age_column() -> DynMigration {
        test_migration(
            "003_add_age_column",
            "ALTER TABLE users ADD COLUMN age INTEGER",
            "ALTER TABLE users DROP COLUMN age",
        )
    }

    /// `001_create_test`: creates `test_table`.
    #[cfg(feature = "sqlite")]
    fn create_test_table() -> DynMigration {
        test_migration(
            "001_create_test",
            "CREATE TABLE test_table (id INTEGER)",
            "DROP TABLE test_table",
        )
    }

    /// Breakpoint action that inserts a `users` row named Alice.
    #[cfg(feature = "sqlite")]
    fn insert_alice(
        db: &dyn Database,
    ) -> Pin<Box<dyn Future<Output = Result<(), DatabaseError>> + Send + '_>> {
        Box::pin(async move {
            db.exec_raw("INSERT INTO users (name) VALUES ('Alice')")
                .await?;
            Ok(())
        })
    }

    /// A single migration with no extras creates its table.
    #[cfg(feature = "sqlite")]
    #[test_log::test(switchy_async::test)]
    async fn test_migration_test_builder_basic() {
        let db = crate::create_empty_in_memory().await.unwrap();

        MigrationTestBuilder::new(vec![create_users()])
            .run(&*db)
            .await
            .unwrap();

        let tables = query::select("sqlite_master")
            .columns(&["name"])
            .where_eq("type", "table")
            .where_eq("name", "users")
            .execute(db.as_ref())
            .await
            .unwrap();
        assert_eq!(tables.len(), 1);
    }

    /// Without rollback, the migrated table is still present after `run`.
    #[cfg(feature = "sqlite")]
    #[test_log::test(switchy_async::test)]
    async fn test_migration_test_builder_default_persistence() {
        let db = crate::create_empty_in_memory().await.unwrap();

        MigrationTestBuilder::new(vec![create_test_table()])
            .run(&*db)
            .await
            .unwrap();

        let tables = query::select("sqlite_master")
            .columns(&["name"])
            .where_eq("type", "table")
            .where_eq("name", "test_table")
            .execute(db.as_ref())
            .await
            .unwrap();
        assert_eq!(tables.len(), 1);
    }

    /// A custom tracking-table name results in a table of that name.
    #[cfg(feature = "sqlite")]
    #[test_log::test(switchy_async::test)]
    async fn test_migration_test_builder_custom_table_name() {
        let db = crate::create_empty_in_memory().await.unwrap();

        MigrationTestBuilder::new(vec![create_test_table()])
            .with_table_name("__custom_migrations")
            .run(&*db)
            .await
            .unwrap();

        let tables = query::select("sqlite_master")
            .columns(&["name"])
            .where_eq("type", "table")
            .where_eq("name", "__custom_migrations")
            .execute(db.as_ref())
            .await
            .unwrap();
        assert_eq!(tables.len(), 1);
    }

    /// Data inserted before a migration survives it and gets a NULL new column.
    #[cfg(feature = "sqlite")]
    #[test_log::test(switchy_async::test)]
    async fn test_with_data_before_breakpoint() {
        let db = crate::create_empty_in_memory().await.unwrap();

        MigrationTestBuilder::new(vec![create_users(), add_email_column()])
            .with_data_before("002_add_email_column", insert_alice)
            .run(&*db)
            .await
            .unwrap();

        let rows = query::select("users")
            .columns(&["name", "email"])
            .where_eq("name", "Alice")
            .execute(db.as_ref())
            .await
            .unwrap();
        assert_eq!(rows.len(), 1);

        let alice = &rows[0];
        assert_eq!(alice.get("name").unwrap().as_str().unwrap(), "Alice");
        assert_eq!(
            alice.get("email").unwrap(),
            switchy_database::DatabaseValue::Null
        );
    }

    /// An "after" action can use the column its migration just added.
    #[cfg(feature = "sqlite")]
    #[test_log::test(switchy_async::test)]
    async fn test_with_data_after_breakpoint() {
        let db = crate::create_empty_in_memory().await.unwrap();

        MigrationTestBuilder::new(vec![create_users(), add_email_column()])
            .with_data_after("002_add_email_column", |db| {
                Box::pin(async move {
                    db.exec_raw(
                        "INSERT INTO users (name, email) VALUES ('Bob', 'bob@example.com')",
                    )
                    .await?;
                    Ok(())
                })
            })
            .run(&*db)
            .await
            .unwrap();

        let rows = query::select("users")
            .columns(&["name", "email"])
            .where_eq("name", "Bob")
            .execute(db.as_ref())
            .await
            .unwrap();
        assert_eq!(rows.len(), 1);

        let bob = &rows[0];
        assert_eq!(bob.get("name").unwrap().as_str().unwrap(), "Bob");
        assert_eq!(
            bob.get("email").unwrap().as_str().unwrap(),
            "bob@example.com"
        );
    }

    /// Before/after actions on several migrations all run, in migration order.
    #[cfg(feature = "sqlite")]
    #[test_log::test(switchy_async::test)]
    async fn test_multiple_breakpoints_in_sequence() {
        let db = crate::create_empty_in_memory().await.unwrap();

        MigrationTestBuilder::new(vec![
            create_users(),
            add_email_column(),
            add_age_column(),
        ])
        .with_data_before("002_add_email_column", insert_alice)
        .with_data_after("002_add_email_column", |db| {
            Box::pin(async move {
                db.exec_raw("UPDATE users SET email = 'alice@example.com' WHERE name = 'Alice'")
                    .await?;
                Ok(())
            })
        })
        .with_data_after("003_add_age_column", |db| {
            Box::pin(async move {
                db.exec_raw("UPDATE users SET age = 30 WHERE name = 'Alice'")
                    .await?;
                Ok(())
            })
        })
        .run(&*db)
        .await
        .unwrap();

        let rows = query::select("users")
            .columns(&["name", "email", "age"])
            .where_eq("name", "Alice")
            .execute(db.as_ref())
            .await
            .unwrap();
        assert_eq!(rows.len(), 1);

        let alice = &rows[0];
        assert_eq!(alice.get("name").unwrap().as_str().unwrap(), "Alice");
        assert_eq!(
            alice.get("email").unwrap().as_str().unwrap(),
            "alice@example.com"
        );
        assert_eq!(alice.get("age").unwrap().as_i64().unwrap(), 30);
    }

    /// The initial setup closure runs, and the migrations still apply.
    #[cfg(feature = "sqlite")]
    #[test_log::test(switchy_async::test)]
    async fn test_initial_setup_functionality() {
        let db = crate::create_empty_in_memory().await.unwrap();

        MigrationTestBuilder::new(vec![create_users()])
            .with_initial_setup(|db| {
                Box::pin(async move {
                    db.exec_raw("CREATE TABLE temp_setup (value TEXT)").await?;
                    db.exec_raw("INSERT INTO temp_setup VALUES ('setup_complete')")
                        .await?;
                    Ok(())
                })
            })
            .run(&*db)
            .await
            .unwrap();

        let setup_rows = query::select("temp_setup")
            .columns(&["value"])
            .execute(db.as_ref())
            .await
            .unwrap();
        assert_eq!(setup_rows.len(), 1);
        assert_eq!(
            setup_rows[0].get("value").unwrap().as_str().unwrap(),
            "setup_complete"
        );

        let tables = query::select("sqlite_master")
            .columns(&["name"])
            .where_eq("type", "table")
            .where_eq("name", "users")
            .execute(db.as_ref())
            .await
            .unwrap();
        assert_eq!(tables.len(), 1);
    }

    /// A breakpoint naming an unknown migration makes `run` fail with a clear message.
    #[cfg(feature = "sqlite")]
    #[test_log::test(switchy_async::test)]
    async fn test_breakpoint_with_nonexistent_migration_id() {
        let db = crate::create_empty_in_memory().await.unwrap();

        let outcome = MigrationTestBuilder::new(vec![create_users()])
            .with_data_before("999_nonexistent", |_db| Box::pin(async move { Ok(()) }))
            .run(&*db)
            .await;

        assert!(outcome.is_err());
        let rendered = format!("{:?}", outcome.unwrap_err());
        assert!(rendered.contains("Migration '999_nonexistent' not found"));
    }

    /// Rollback removes the schema even when breakpoints were used.
    #[cfg(feature = "sqlite")]
    #[test_log::test(switchy_async::test)]
    async fn test_rollback_works_with_breakpoints() {
        let db = crate::create_empty_in_memory().await.unwrap();

        MigrationTestBuilder::new(vec![create_users(), add_email_column()])
            .with_data_before("002_add_email_column", insert_alice)
            .with_rollback()
            .run(&*db)
            .await
            .unwrap();

        let lookup = query::select("sqlite_master")
            .columns(&["name"])
            .where_eq("type", "table")
            .where_eq("name", "users")
            .execute(db.as_ref())
            .await;

        // Either the query fails outright or it must find no `users` table.
        if let Ok(tables) = lookup {
            assert!(tables.is_empty());
        }
    }
}