use std::collections::{HashMap, HashSet};
use sqlx::{Connection, Pool};
use crate::error::Error;
use crate::migration::{AppliedMigrationSqlRow, Migration};
#[cfg(all(
any(feature = "postgres", feature = "mysql", feature = "sqlite"),
feature = "any"
))]
mod any;
#[cfg(feature = "mysql")]
mod mysql;
#[cfg(feature = "sqlite")]
mod sqlite;
#[cfg(feature = "postgres")]
mod postgres;
/// Result type for operations that return references to registered migrations.
///
/// Each `&'a Box<dyn Migration<DB>>` borrows from the migrator's set of
/// registered migrations for the lifetime `'a`.
type MigrationVecResult<'a, DB> = Result<Vec<&'a Box<dyn Migration<DB>>>, Error>;
/// Which subset of registered migrations a [`Plan`] selects.
///
/// This is a plain fieldless enum, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PlanType {
    /// Every registered migration, whether applied or not.
    All,
    /// Only migrations that are not yet applied, in dependency order.
    Apply,
    /// Only applied migrations, in reverse dependency order so that dependents
    /// are reverted before the migrations they depend on.
    Revert,
}
/// Describes which migrations `Migrate::generate_migration_plan` should select.
///
/// The fields are private; construct a plan with `Plan::new`, which rejects a
/// migration name given without an app name.
#[derive(Debug)]
pub struct Plan {
// Selects all migrations, only unapplied ones, or only applied ones.
plan_type: PlanType,
// Optional app name; when set, the generated plan is truncated after the last
// entry belonging to this app (or to `migration` within it, if also set).
app: Option<String>,
// Optional migration name within `app`; only valid together with `app`.
migration: Option<String>,
}
impl Plan {
    /// Build a migration plan description.
    ///
    /// `app` optionally restricts the plan to end at that app's migrations, and
    /// `migration` optionally narrows this further to one named migration of
    /// that app.
    ///
    /// # Errors
    ///
    /// Returns [`Error::AppNameRequired`] when `migration` is given but `app`
    /// is not, because a migration name is only meaningful within an app.
    pub fn new(
        plan_type: PlanType,
        app: Option<String>,
        migration: Option<String>,
    ) -> Result<Self, Error> {
        if let (None, Some(_)) = (app.as_ref(), migration.as_ref()) {
            return Err(Error::AppNameRequired);
        }
        Ok(Self {
            plan_type,
            app,
            migration,
        })
    }
}
/// Registry of the migrations known to a migrator.
///
/// Implementors only expose the backing set. The provided methods handle
/// registration, including pulling in each migration's parents and the
/// migrations it replaces.
pub trait Info<DB>
where
    DB: sqlx::Database,
{
    /// Shared access to the registered migrations.
    fn migrations(&self) -> &HashSet<Box<dyn Migration<DB>>>;

    /// Mutable access to the registered migrations.
    fn migrations_mut(&mut self) -> &mut HashSet<Box<dyn Migration<DB>>>;

    /// Register each migration in `migrations`, in order, via
    /// [`Info::add_migration`].
    fn add_migrations(&mut self, migrations: Vec<Box<dyn Migration<DB>>>) {
        migrations
            .into_iter()
            .for_each(|migration| self.add_migration(migration));
    }

    /// Register `migration` together with, recursively, its parents and the
    /// migrations it replaces.
    ///
    /// A migration that is already present in the set is ignored, which also
    /// stops the recursion from revisiting it.
    fn add_migration(&mut self, migration: Box<dyn Migration<DB>>) {
        // Read the dependencies before the box is moved into the set.
        let parents = migration.parents();
        let replaced = migration.replaces();
        let newly_inserted = self.migrations_mut().insert(migration);
        if !newly_inserted {
            return;
        }
        // Parents first, then replaced migrations.
        for dependency in parents.into_iter().chain(replaced) {
            self.add_migration(dependency);
        }
    }
}
/// Backend-specific database operations that `Migrate` builds on.
///
/// NOTE(review): concrete implementations presumably live in the
/// feature-gated backend submodules; confirm there for exact semantics.
#[async_trait::async_trait]
pub trait DatabaseOperation<DB>
where
DB: sqlx::Database,
{
/// Expected to create the table that records applied migrations if it does
/// not exist yet. `Migrate::list_applied_migrations` calls this before reading
/// applied rows.
async fn ensure_migration_table_exists(
&self,
connection: &mut <DB as sqlx::Database>::Connection,
) -> Result<(), Error>;
/// Expected to drop the applied-migrations table if it exists. The provided
/// `Migrate` methods do not call this.
async fn drop_migration_table_if_exists(
&self,
connection: &mut <DB as sqlx::Database>::Connection,
) -> Result<(), Error>;
/// Record `migration` as applied. `Migrate::apply_migration` calls this after
/// running the migration's `up` operations, inside the same transaction when
/// the migration is atomic.
// `&Box` matches the element type of the plans built by `Migrate`.
#[allow(clippy::borrowed_box)]
async fn add_migration_to_db_table(
&self,
migration: &Box<dyn Migration<DB>>,
connection: &mut <DB as sqlx::Database>::Connection,
) -> Result<(), Error>;
/// Remove the applied record for `migration`. `Migrate::revert_migration`
/// calls this after running the migration's `down` operations, inside the same
/// transaction when the migration is atomic.
// `&Box` matches the element type of the plans built by `Migrate`.
#[allow(clippy::borrowed_box)]
async fn delete_migration_from_db_table(
&self,
migration: &Box<dyn Migration<DB>>,
connection: &mut <DB as sqlx::Database>::Connection,
) -> Result<(), Error>;
/// Fetch the rows describing migrations recorded as applied. Callers compare
/// these rows against registered migrations with `==`.
async fn fetch_applied_migration_from_db(
&self,
connection: &mut <DB as sqlx::Database>::Connection,
) -> Result<Vec<AppliedMigrationSqlRow>, Error>;
/// Acquire the migration lock. `Migrate::apply_all` and `Migrate::revert_all`
/// hold it while running a plan; presumably it guards against concurrent
/// migrators (TODO confirm per backend).
async fn lock(&self, connection: &mut <DB as sqlx::Database>::Connection) -> Result<(), Error>;
/// Release the lock taken by [`DatabaseOperation::lock`].
async fn unlock(
&self,
connection: &mut <DB as sqlx::Database>::Connection,
) -> Result<(), Error>;
}
#[async_trait::async_trait]
pub trait Migrate<DB>: Info<DB> + DatabaseOperation<DB> + Send + Sync
where
DB: sqlx::Database,
{
async fn list_applied_migrations(
&self,
connection: &mut <DB as sqlx::Database>::Connection,
) -> MigrationVecResult<DB> {
if cfg!(feature = "tracing") {
tracing::info!("Fetching applied migrations");
}
self.ensure_migration_table_exists(connection).await?;
let applied_migration_list = self.fetch_applied_migration_from_db(connection).await?;
let mut applied_migrations = Vec::new();
for migration in self.migrations() {
if applied_migration_list
.iter()
.any(|sqlx_migration| sqlx_migration == migration)
{
applied_migrations.push(migration);
}
}
Ok(applied_migrations)
}
async fn generate_migration_plan(
&self,
plan: Plan,
connection: &mut <DB as sqlx::Database>::Connection,
) -> MigrationVecResult<DB> {
let applied_migrations = self.list_applied_migrations(connection).await?;
if cfg!(feature = "tracing") {
tracing::info!("Generating {:?} migration plan", plan);
}
let mut migration_plan = Vec::new();
let mut parents_due_to_run_before = HashMap::<_, Vec<_>>::new();
for migration in self.migrations() {
for run_before_migration in migration.run_before() {
parents_due_to_run_before
.entry(run_before_migration)
.or_default()
.push(migration);
}
}
while migration_plan.len() != self.migrations().len() {
let old_migration_plan_length = migration_plan.len();
for migration in self.migrations() {
let all_parents_applied = migration
.parents()
.iter()
.all(|migration| migration_plan.contains(&migration));
let all_run_before_parents_added = parents_due_to_run_before
.get(migration)
.unwrap_or(&vec![])
.iter()
.all(|migration| migration_plan.contains(migration));
if all_parents_applied
&& all_run_before_parents_added
&& !migration_plan.contains(&migration)
{
migration_plan.push(migration);
}
}
if old_migration_plan_length == migration_plan.len() {
return Err(Error::FailedToCreateMigrationPlan);
}
}
for migration in migration_plan.clone() {
if !migration.replaces().is_empty() {
let replaces_applied = migration
.replaces()
.iter()
.any(|replace_migration| applied_migrations.contains(&replace_migration));
if replaces_applied {
if applied_migrations.contains(&migration) {
return Err(Error::BothMigrationTypeApplied);
}
migration_plan.retain(|&plan_migration| migration != plan_migration);
} else {
for replaced_migration in migration.replaces() {
migration_plan
.retain(|&plan_migration| &replaced_migration != plan_migration);
}
}
}
}
match plan.plan_type {
PlanType::Apply => {
migration_plan.retain(|migration| !applied_migrations.contains(migration));
}
PlanType::Revert => {
migration_plan.retain(|migration| applied_migrations.contains(migration));
migration_plan.reverse();
}
PlanType::All => {}
};
if let Some(app) = plan.app {
let position = if let Some(name) = plan.migration {
let Some(pos) = migration_plan
.iter()
.rposition(|migration| migration.app() == app && migration.name() == name)
else {
if migration_plan
.iter()
.any(|migration| migration.app() == app)
{
return Err(Error::MigrationNameNotExists {
app,
migration: name,
});
}
return Err(Error::AppNameNotExists { app });
};
pos
} else {
let Some(pos) = migration_plan
.iter()
.rposition(|migration| migration.app() == app)
else {
return Err(Error::AppNameNotExists { app });
};
pos
};
migration_plan.truncate(position + 1);
}
Ok(migration_plan)
}
async fn apply_all(&self, pool: &Pool<DB>) -> Result<(), Error> {
let mut connection = pool.acquire().await?;
if cfg!(feature = "tracing") {
tracing::info!("Applying all migration");
}
self.lock(&mut connection).await?;
let plan = Plan::new(PlanType::Apply, None, None)?;
for migration in self.generate_migration_plan(plan, &mut connection).await? {
self.apply_migration(migration, &mut connection).await?;
}
self.unlock(&mut connection).await?;
connection.close().await?;
Ok(())
}
#[allow(clippy::borrowed_box)]
async fn apply_migration(
&self,
migration: &Box<dyn Migration<DB>>,
connection: &mut <DB as sqlx::Database>::Connection,
) -> Result<(), Error> {
if cfg!(feature = "tracing") {
tracing::info!(
"Applying {} migration {}",
migration.app(),
migration.name()
);
}
if migration.is_atomic() {
let mut transaction = connection.begin().await?;
for operation in migration.operations() {
operation.up(&mut transaction).await?;
}
self.add_migration_to_db_table(migration, &mut transaction)
.await?;
transaction.commit().await?;
} else {
for operation in migration.operations() {
operation.up(connection).await?;
}
self.add_migration_to_db_table(migration, connection)
.await?;
}
Ok(())
}
async fn revert_all(&self, pool: &Pool<DB>) -> Result<(), Error> {
let mut connection = pool.acquire().await?;
if cfg!(feature = "tracing") {
tracing::info!("Reverting all migration");
}
self.lock(&mut connection).await?;
let plan = Plan::new(PlanType::Revert, None, None)?;
for migration in self.generate_migration_plan(plan, &mut connection).await? {
self.revert_migration(migration, &mut connection).await?;
}
self.unlock(&mut connection).await?;
connection.close().await?;
Ok(())
}
#[allow(clippy::borrowed_box)]
async fn revert_migration(
&self,
migration: &Box<dyn Migration<DB>>,
connection: &mut <DB as sqlx::Database>::Connection,
) -> Result<(), Error> {
if cfg!(feature = "tracing") {
tracing::info!(
"Reverting {} migration {}",
migration.app(),
migration.name()
);
}
let mut operations = migration.operations();
operations.reverse();
if migration.is_atomic() {
let mut transaction = connection.begin().await?;
for operation in operations {
operation.down(&mut transaction).await?;
}
self.delete_migration_from_db_table(migration, &mut transaction)
.await?;
transaction.commit().await?;
} else {
for operation in operations {
operation.down(connection).await?;
}
self.delete_migration_from_db_table(migration, connection)
.await?;
}
Ok(())
}
}
/// In-memory registry of migrations for database backend `DB`.
///
/// Implements [`Info`] over its set of registered migrations.
pub struct Migrator<DB: sqlx::Database> {
    /// Every migration registered so far, including parents and replaced
    /// migrations pulled in by `Info::add_migration`.
    migrations: HashSet<Box<dyn Migration<DB>>>,
}
// Written by hand: `#[derive(Default)]` would add an unwanted `DB: Default`
// bound on the backend type parameter.
impl<DB: sqlx::Database> Default for Migrator<DB> {
    /// Create a migrator with no registered migrations.
    fn default() -> Self {
        Self {
            migrations: HashSet::new(),
        }
    }
}
/// Exposes the migrator's backing set so the provided [`Info`] methods can
/// register migrations into it.
impl<DB: sqlx::Database> Info<DB> for Migrator<DB> {
    fn migrations(&self) -> &HashSet<Box<dyn Migration<DB>>> {
        let Self { migrations } = self;
        migrations
    }

    fn migrations_mut(&mut self) -> &mut HashSet<Box<dyn Migration<DB>>> {
        let Self { migrations } = self;
        migrations
    }
}