use crate::app::context::AppContext;
use crate::db::migration::{DownArgs, MigrationInfo, MigrationStatus, Migrator, UpArgs};
use axum_core::extract::FromRef;
use diesel::Connection;
use diesel::backend::Backend;
use diesel::migration::{Migration, MigrationSource};
use diesel_migrations::{MigrationError, MigrationHarness};
use itertools::Itertools;
use serde_derive::Serialize;
use std::cmp::min;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::Mutex;
use strum_macros::{EnumString, IntoStaticStr};
use tracing::info;
/// Ordering applied to the combined list of migrations collected from all
/// registered sources before that list is wrapped and handed to diesel.
#[derive(Debug, Default, Serialize, EnumString, IntoStaticStr)]
pub enum MigrationSortOrder {
/// Leave the migrations in the order in which they were collected from the
/// sources. This is the default.
#[default]
None,
/// Sort the migrations by the string form of their name.
Name,
}
/// Migrator backed by one or more diesel [`MigrationSource`]s that are run over
/// a connection of type `C`.
pub struct DieselMigrator<C>
where
C: Send + Connection + MigrationHarness<C::Backend>,
{
/// Registered migration sources, in the order they were added.
migrators: Vec<Box<dyn Send + Sync + MigrationSource<C::Backend>>>,
/// Ordering applied to the combined migration list when it is wrapped for diesel.
order: MigrationSortOrder,
// Zero-sized marker tying the struct to `C`; no connection is stored here.
// NOTE(review): the `Mutex` wrapper presumably exists so the struct stays `Sync`
// while only requiring `C: Send` — confirm.
_conn: PhantomData<Mutex<C>>,
}
impl<C> DieselMigrator<C>
where
C: Connection + Send + MigrationHarness<C::Backend>,
{
pub fn new(migrator: impl 'static + Send + Sync + MigrationSource<C::Backend>) -> Self {
Self {
migrators: vec![Box::new(migrator)],
order: Default::default(),
_conn: Default::default(),
}
}
pub fn add_migrator(
mut self,
migrator: impl 'static + Send + Sync + MigrationSource<C::Backend>,
) -> Self {
self.migrators.push(Box::new(migrator));
self
}
pub fn order(mut self, order: MigrationSortOrder) -> Self {
self.order = order;
self
}
}
#[async_trait::async_trait]
impl<S, C> Migrator<S> for DieselMigrator<C>
where
S: 'static + Send + Sync + Clone,
AppContext: FromRef<S>,
C: Connection + Send + MigrationHarness<C::Backend>,
{
type Error = crate::error::Error;
#[tracing::instrument(skip_all)]
async fn up(&self, state: &S, args: &UpArgs) -> Result<usize, Self::Error> {
info!("Started applying migrations");
let context = AppContext::from_ref(state);
let mut conn: C = Connection::establish(context.config().database.uri.as_ref())?;
let pending = conn.pending_migrations(DieselMigrationSourceWrapper::try_from(self)?)?;
let pending = if let Some(steps) = args.steps {
let steps = min(steps, pending.len());
pending.into_iter().take(steps).collect()
} else {
pending
};
let completed = conn.run_migrations(&pending)?;
let completed = completed.len();
info!(count = completed, "Completed applying migrations");
Ok(completed)
}
#[tracing::instrument(skip_all)]
async fn down(&self, state: &S, args: &DownArgs) -> Result<usize, Self::Error> {
info!("Started rolling back migrations");
let context = AppContext::from_ref(state);
let mut conn: C = Connection::establish(context.config().database.uri.as_ref())?;
let to_roll_back = conn.applied_migrations()?.len();
let to_roll_back = if let Some(steps) = args.steps {
min(steps, to_roll_back)
} else {
to_roll_back
};
let applied_versions = conn
.applied_migrations()?
.into_iter()
.take(to_roll_back)
.collect_vec();
let mut migrations = Vec::new();
for migrator in &self.migrators {
migrations.extend(migrator.migrations()?);
}
let mut migrations: HashMap<_, _> = migrations
.into_iter()
.map(|m: Box<dyn Migration<C::Backend>>| (m.name().version().as_owned(), m))
.collect();
for version in applied_versions {
let migration_to_revert = migrations
.remove(&version)
.ok_or(MigrationError::UnknownMigrationVersion(version))?;
info!(migration.name = %migration_to_revert.name(), "Rolling back migration");
conn.revert_migration(&migration_to_revert)?;
}
info!(count = to_roll_back, "Completed rolling back migrations");
Ok(to_roll_back)
}
#[tracing::instrument(skip_all)]
async fn status(&self, state: &S) -> Result<Vec<MigrationInfo>, Self::Error> {
let context = AppContext::from_ref(state);
let mut conn: C = Connection::establish(context.config().database.uri.as_ref())?;
let pending = conn
.pending_migrations(DieselMigrationSourceWrapper::try_from(self)?)?
.into_iter()
.map(|migration| {
MigrationInfo::builder()
.name(migration.name().to_string())
.status(MigrationStatus::Pending)
.build()
});
let mut migrations = Vec::new();
for migrator in &self.migrators {
migrations.extend(migrator.migrations()?);
}
let migrations: HashMap<_, _> = migrations
.into_iter()
.map(|m: Box<dyn Migration<C::Backend>>| (m.name().version().as_owned(), m))
.collect();
let applied = conn
.applied_migrations()?
.into_iter()
.map(|version| {
let name = migrations
.get(&version)
.map(|migration| migration.name().to_string())
.unwrap_or(version.to_string());
MigrationInfo::builder()
.name(name)
.status(MigrationStatus::Applied)
.build()
})
.rev();
let migrations = applied.into_iter().chain(pending.into_iter()).collect();
Ok(migrations)
}
}
/// [`MigrationSource`] wrapper around an already-collected list of migrations
/// that can be handed out exactly once.
struct DieselMigrationSourceWrapper<DB: Backend> {
// `Some` until `migrations()` takes the list out; `None` afterwards. The
// `Mutex` allows the list to be moved out through `&self`.
migrations: Mutex<Option<Vec<Box<dyn Migration<DB>>>>>,
}
impl<C> TryFrom<&DieselMigrator<C>> for DieselMigrationSourceWrapper<C::Backend>
where
    C: Connection + Send + MigrationHarness<C::Backend>,
{
    type Error = Box<dyn Send + Sync + std::error::Error>;

    /// Collects the migrations of every source registered on `value`, in
    /// registration order, and applies the migrator's configured sort order.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by any source's `migrations()` call.
    fn try_from(value: &DieselMigrator<C>) -> Result<Self, Self::Error> {
        let mut migrations = Vec::new();
        for migrator in &value.migrators {
            migrations.extend(migrator.migrations()?);
        }

        match value.order {
            MigrationSortOrder::None => {}
            MigrationSortOrder::Name => {
                // The key allocates a `String`, so compute it once per element
                // instead of on every comparison. The sort remains stable.
                migrations.sort_by_cached_key(|migration| migration.name().to_string());
            }
        }

        Ok(Self {
            migrations: Mutex::new(Some(migrations)),
        })
    }
}
impl<DB: Backend> MigrationSource<DB> for DieselMigrationSourceWrapper<DB> {
fn migrations(&self) -> diesel::migration::Result<Vec<Box<dyn Migration<DB>>>> {
let mut migrations = self.migrations.lock().map_err(crate::error::Error::from)?;
match migrations.take() {
Some(migrations) => Ok(migrations),
None => Err(crate::error::db::DbError::Message("DieselMigrationSourceWrapper#migrations was called twice! This is not supported as the migrations were removed by the first call.".to_owned()).into()),
}
}
}