#![cfg_attr(feature = "_docs", feature(doc_cfg))]
#![warn(clippy::pedantic)]
#![allow(
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_sign_loss,
clippy::cast_lossless,
clippy::unreadable_literal,
clippy::doc_markdown,
clippy::module_name_repetitions
)]
use context::Extensions;
use db::{AppliedMigration, Migrations};
use futures_core::future::LocalBoxFuture;
use itertools::{EitherOrBoth, Itertools};
use sha2::{Digest, Sha256};
use sqlx::{ConnectOptions, Connection, Database, Pool};
use std::{
any::TypeId,
borrow::Cow,
cell::UnsafeCell,
str::FromStr,
time::{Duration, Instant},
};
pub mod context;
pub mod db;
pub mod error;
pub use context::MigrationContext;
pub use error::Error;
#[cfg(feature = "cli")]
#[cfg_attr(feature = "_docs", doc(cfg(feature = "cli")))]
pub mod cli;
#[cfg(feature = "generate")]
#[cfg_attr(feature = "_docs", doc(cfg(feature = "generate")))]
mod gen;
#[cfg(feature = "generate")]
#[cfg_attr(feature = "_docs", doc(cfg(feature = "generate")))]
pub use gen::generate;
/// Boxed migration step: a function taking a [`MigrationContext`] and
/// returning a local (non-`Send`) boxed future with the step's outcome.
type MigrationFn<DB> = Box<
    dyn for<'future> Fn(
        MigrationContext<'future, DB>,
    ) -> LocalBoxFuture<'future, Result<(), MigrationError>>,
>;
/// Default name of the bookkeeping table that tracks applied migrations;
/// override with [`Migrator::set_migrations_table`].
pub const DEFAULT_MIGRATIONS_TABLE: &str = "_sqlx_migrations";
/// Convenience re-exports of the crate's most commonly used items.
pub mod prelude {
    pub use super::Migration;
    pub use super::MigrationContext;
    pub use super::MigrationError;
    pub use super::MigrationStatus;
    pub use super::MigrationSummary;
    pub use super::Migrator;
    pub use super::MigratorOptions;
}
/// A named database migration with an "up" (apply) step and an optional
/// "down" (revert) step.
pub struct Migration<DB: Database> {
    // Human-readable identifier; also the sole basis of equality (see `PartialEq`).
    name: Cow<'static, str>,
    // Applies the migration.
    up: MigrationFn<DB>,
    // Reverts the migration when set via `reversible`/`revertible`.
    down: Option<MigrationFn<DB>>,
}
impl<DB: Database> Migration<DB> {
pub fn new(
name: impl Into<Cow<'static, str>>,
up: impl for<'future> Fn(
MigrationContext<'future, DB>,
) -> LocalBoxFuture<'future, Result<(), MigrationError>>
+ 'static,
) -> Self {
Self {
name: name.into(),
up: Box::new(up),
down: None,
}
}
#[must_use]
pub fn reversible(
mut self,
down: impl for<'future> Fn(
MigrationContext<'future, DB>,
) -> LocalBoxFuture<'future, Result<(), MigrationError>>
+ 'static,
) -> Self {
self.down = Some(Box::new(down));
self
}
#[must_use]
pub fn revertible(
self,
down: impl for<'future> Fn(
MigrationContext<'future, DB>,
) -> LocalBoxFuture<'future, Result<(), MigrationError>>
+ 'static,
) -> Self {
self.reversible(down)
}
#[must_use]
pub fn name(&self) -> &str {
self.name.as_ref()
}
#[must_use]
pub fn is_reversible(&self) -> bool {
self.down.is_some()
}
#[must_use]
pub fn is_revertible(&self) -> bool {
self.down.is_some()
}
}
// Migrations compare by name alone; the boxed up/down closures are not
// comparable, so two migrations with equal names are considered the same.
impl<DB: Database> Eq for Migration<DB> {}
impl<DB: Database> PartialEq for Migration<DB> {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
    }
}
/// Applies, reverts and inspects migrations over a single database
/// connection.
#[must_use]
pub struct Migrator<DB>
where
    DB: Database,
    DB::Connection: db::Migrations,
{
    // Behavioural toggles (checksum / name verification).
    options: MigratorOptions,
    // Dedicated connection used for all migration operations.
    conn: DB::Connection,
    // Bookkeeping table name, `DEFAULT_MIGRATIONS_TABLE` unless overridden.
    table: Cow<'static, str>,
    // Locally registered migrations; index 0 corresponds to version 1.
    migrations: Vec<Migration<DB>>,
    // Extension values handed to each `MigrationContext` as a raw pointer
    // (`UnsafeCell::get`), hence the interior mutability.
    extensions: UnsafeCell<Extensions>,
}
impl<DB> Migrator<DB>
where
    DB: Database,
    DB::Connection: db::Migrations,
{
    /// Create a migrator that uses an existing connection.
    pub fn new(conn: DB::Connection) -> Self {
        Self {
            options: MigratorOptions::default(),
            conn,
            table: Cow::Borrowed(DEFAULT_MIGRATIONS_TABLE),
            migrations: Vec::default(),
            extensions: UnsafeCell::new(Extensions::default()),
        }
    }
    /// Connect to the database at `url`, with statement logging disabled.
    ///
    /// # Errors
    /// Returns an error if the url cannot be parsed or the connection fails.
    pub async fn connect(url: &str) -> Result<Self, sqlx::Error> {
        let mut opts: <<DB as Database>::Connection as Connection>::Options = url.parse()?;
        // Migrations produce a lot of one-off statements; logging them is noise.
        opts.disable_statement_logging();
        Ok(Self::new(DB::Connection::connect_with(&opts).await?))
    }
    /// Connect to the database with explicit connect options.
    ///
    /// # Errors
    /// Returns an error if the connection fails.
    pub async fn connect_with(
        options: &<DB::Connection as Connection>::Options,
    ) -> Result<Self, sqlx::Error> {
        Ok(Self::new(DB::Connection::connect_with(options).await?))
    }
    /// Acquire a connection from `pool` and detach it for exclusive use by
    /// the migrator (it is not returned to the pool).
    ///
    /// # Errors
    /// Returns an error if acquiring a connection fails.
    pub async fn connect_with_pool(pool: &Pool<DB>) -> Result<Self, sqlx::Error> {
        let conn = pool.acquire().await?;
        Ok(Self::new(conn.detach()))
    }
    /// Override the name of the bookkeeping table.
    pub fn set_migrations_table(&mut self, name: impl AsRef<str>) {
        self.table = Cow::Owned(name.as_ref().to_string());
    }
    /// Register local migrations, appended in order after any already added.
    pub fn add_migrations(&mut self, migrations: impl IntoIterator<Item = Migration<DB>>) {
        // `extend` accepts any `IntoIterator`; no explicit `.into_iter()` needed.
        self.migrations.extend(migrations);
    }
    /// Replace the migrator's options.
    pub fn set_options(&mut self, options: MigratorOptions) {
        self.options = options;
    }
    /// Builder-style variant of [`Self::set`].
    pub fn with<T: Send + Sync + 'static>(mut self, value: T) -> Self {
        self.set(value);
        self
    }
    /// Store an extension value that migrations can retrieve through their
    /// [`MigrationContext`], keyed by type; replaces any previous value of
    /// the same type.
    pub fn set<T: Send + Sync + 'static>(&mut self, value: T) {
        // `&mut self` guarantees exclusive access, so `UnsafeCell::get_mut`
        // yields a safe `&mut Extensions` — no `unsafe` deref required.
        self.extensions
            .get_mut()
            .map
            .insert(TypeId::of::<T>(), Box::new(value));
    }
    /// The locally registered migrations, in version order.
    pub fn local_migrations(&self) -> &[Migration<DB>] {
        &self.migrations
    }
}
impl<DB> Migrator<DB>
where
    DB: Database,
    DB::Connection: db::Migrations,
{
    /// Apply pending local migrations up to and including the 1-based
    /// `target_version`.
    ///
    /// Versions the database already records as applied are skipped; this
    /// method never reverts, so a target below the database's version leaves
    /// the database untouched. All work runs in one transaction committed at
    /// the end.
    ///
    /// # Errors
    /// Fails when `target_version` is not a valid local version, when the
    /// database state conflicts with the local list, or when a migration
    /// itself errors (the transaction is then dropped without committing).
    pub async fn migrate(&mut self, target_version: u64) -> Result<MigrationSummary, Error> {
        // Validate the target before touching the database.
        self.local_migration(target_version)?;
        self.conn.ensure_migrations_table(&self.table).await?;
        let db_migrations = self.conn.list_migrations(&self.table).await?;
        self.check_migrations(&db_migrations)?;
        let to_apply = self.migrations.iter();
        // The transaction is handed to each `MigrationContext` as a raw
        // pointer (`tx.get()`); soundness relies on migration closures using
        // the context only while polled here, never storing the pointer.
        let tx = UnsafeCell::new(self.conn.begin().await?);
        let db_version = db_migrations.len() as _;
        for (idx, mig) in to_apply.enumerate() {
            // Versions are 1-based.
            let mig_version = idx as u64 + 1;
            if mig_version > target_version {
                break;
            }
            // Already applied according to the database — skip.
            if mig_version <= db_version {
                continue;
            }
            let start = Instant::now();
            tracing::info!(
                version = mig_version,
                name = %mig.name,
                "applying migration"
            );
            // First pass: `hash_only` run that feeds the migration's
            // statements into the hasher to produce its checksum (presumably
            // without executing them — confirm in `MigrationContext`).
            let hasher = UnsafeCell::new(Sha256::new());
            let ctx = MigrationContext {
                hash_only: true,
                ext: self.extensions.get(),
                hasher: hasher.get(),
                tx: tx.get(),
            };
            (*mig.up)(ctx).await.map_err(|error| Error::Migration {
                name: mig.name.clone(),
                version: mig_version,
                error,
            })?;
            let checksum = hasher.into_inner().finalize().to_vec();
            // Second pass: the real run that applies the migration.
            let hasher = UnsafeCell::new(Sha256::new());
            let ctx = MigrationContext {
                hash_only: false,
                ext: self.extensions.get(),
                hasher: hasher.get(),
                tx: tx.get(),
            };
            (*mig.up)(ctx).await.map_err(|error| Error::Migration {
                name: mig.name.clone(),
                version: mig_version,
                error,
            })?;
            let execution_time = start.elapsed();
            if self.options.verify_checksums {
                // NOTE(review): this branch appears unreachable —
                // `db_migrations.get(idx)` is only `Some` when
                // `mig_version <= db_version`, which the `continue` above
                // already skipped. Kept as a defensive check; the real
                // verification lives in `verify_checksums`.
                if let Some(db_mig) = db_migrations.get(idx) {
                    if db_mig.checksum != checksum {
                        let tx = tx.into_inner();
                        tx.rollback().await?;
                        return Err(Error::ChecksumMismatch {
                            version: mig_version,
                            local_checksum: checksum.clone().into(),
                            db_checksum: db_mig.checksum.clone(),
                        });
                    }
                }
            }
            // Record the applied migration within the same transaction.
            DB::Connection::add_migration(
                &self.table,
                AppliedMigration {
                    version: mig_version,
                    name: mig.name.clone(),
                    checksum: checksum.into(),
                    execution_time,
                },
                // SAFETY(review): no `MigrationContext` holding `tx.get()` is
                // alive at this point, so this exclusive borrow is unique —
                // assuming closures did not retain the pointer (see above).
                unsafe { &mut *tx.get() },
            )
            .await?;
            tracing::info!(
                version = mig_version,
                name = %mig.name,
                execution_time = %humantime::Duration::from(execution_time),
                "migration applied"
            );
        }
        tracing::info!("committing changes");
        tx.into_inner().commit().await?;
        Ok(MigrationSummary {
            old_version: if db_migrations.is_empty() {
                None
            } else {
                Some(db_migrations.len() as _)
            },
            // Nothing is ever reverted here, so the new version can never
            // drop below the database's previous version.
            new_version: Some(target_version.max(db_version)),
        })
    }
    /// Apply every locally registered migration; returns an empty summary
    /// (no versions) when none are registered.
    ///
    /// # Errors
    /// Same failure modes as [`Self::migrate`].
    pub async fn migrate_all(&mut self) -> Result<MigrationSummary, Error> {
        if self.migrations.is_empty() {
            return Ok(MigrationSummary {
                new_version: None,
                old_version: None,
            });
        }
        self.migrate(self.migrations.len() as _).await
    }
    /// Revert applied migrations from the newest down to and including the
    /// 1-based `target_version`, in one committed transaction.
    ///
    /// Migrations without a `down` step are only warned about, but their
    /// bookkeeping rows are still removed.
    ///
    /// # Errors
    /// Fails when `target_version` is not a valid local version, when the
    /// database state conflicts with the local list, or when a `down`
    /// migration errors.
    pub async fn revert(&mut self, target_version: u64) -> Result<MigrationSummary, Error> {
        self.local_migration(target_version)?;
        self.conn.ensure_migrations_table(&self.table).await?;
        let db_migrations = self.conn.list_migrations(&self.table).await?;
        self.check_migrations(&db_migrations)?;
        // Keep versions >= target that are actually applied, then reverse so
        // the newest migration is reverted first (collected because `rev`
        // needs a double-ended iterator).
        let to_revert = self
            .migrations
            .iter()
            .enumerate()
            .skip_while(|(idx, _)| idx + 1 < target_version as _)
            .take_while(|(idx, _)| *idx < db_migrations.len())
            .collect::<Vec<_>>()
            .into_iter()
            .rev();
        // Shared with the contexts via raw pointer, as in `migrate`.
        let tx = UnsafeCell::new(self.conn.begin().await?);
        for (idx, mig) in to_revert {
            let version = idx as u64 + 1;
            let start = Instant::now();
            tracing::info!(
                version,
                name = %mig.name,
                "reverting migration"
            );
            let hasher = UnsafeCell::new(Sha256::new());
            let ctx = MigrationContext {
                hash_only: false,
                ext: self.extensions.get(),
                hasher: hasher.get(),
                tx: tx.get(),
            };
            match &mig.down {
                Some(down) => {
                    down(ctx).await.map_err(|error| Error::Revert {
                        name: mig.name.clone(),
                        version,
                        error,
                    })?;
                }
                None => {
                    // Irreversible migration: warn, but still drop its record
                    // below so version bookkeeping stays consistent.
                    tracing::warn!(
                        version,
                        name = %mig.name,
                        "no down migration found"
                    );
                }
            }
            let execution_time = start.elapsed();
            DB::Connection::remove_migration(
                &self.table,
                version,
                // SAFETY(review): no context holding `tx.get()` is alive here.
                unsafe { &mut *tx.get() },
            )
            .await?;
            tracing::info!(
                version,
                name = %mig.name,
                execution_time = %humantime::Duration::from(execution_time),
                "migration reverted"
            );
        }
        tracing::info!("committing changes");
        tx.into_inner().commit().await?;
        Ok(MigrationSummary {
            old_version: if db_migrations.is_empty() {
                None
            } else {
                Some(db_migrations.len() as _)
            },
            // Everything from `target_version` upward was reverted, so the
            // database now sits just below the reverted range.
            new_version: if target_version == 1 {
                None
            } else {
                Some(target_version - 1)
            },
        })
    }
    /// Revert every applied migration (down to and including version 1).
    ///
    /// # Errors
    /// Same failure modes as [`Self::revert`].
    pub async fn revert_all(&mut self) -> Result<MigrationSummary, Error> {
        self.revert(1).await
    }
    /// Rewrite the bookkeeping table so the database reports `version` as
    /// applied without actually running any migration: the table is cleared
    /// and repopulated from hash-only passes (checksums recomputed, zero
    /// execution time). `version` 0 just clears the table.
    ///
    /// # Errors
    /// Fails when a non-zero `version` is not a valid local version, when a
    /// hash-only pass errors, or on any database error.
    pub async fn force_version(&mut self, version: u64) -> Result<MigrationSummary, Error> {
        self.conn.ensure_migrations_table(&self.table).await?;
        let db_migrations = self.conn.list_migrations(&self.table).await?;
        if version == 0 {
            self.conn.clear_migrations(&self.table).await?;
            return Ok(MigrationSummary {
                old_version: if db_migrations.is_empty() {
                    None
                } else {
                    Some(db_migrations.len() as _)
                },
                new_version: None,
            });
        }
        self.local_migration(version)?;
        // Every local migration up to and including `version`.
        let migrations = self
            .migrations
            .iter()
            .enumerate()
            .take_while(|(idx, _)| *idx < version as usize);
        // NOTE(review): the clear happens outside the transaction below, so a
        // failure mid-loop can leave the table emptied — confirm intended.
        self.conn.clear_migrations(&self.table).await?;
        let tx = UnsafeCell::new(self.conn.begin().await?);
        for (idx, mig) in migrations {
            let mig_version = idx as u64 + 1;
            // Hash-only pass: compute the checksum without applying anything.
            let hasher = UnsafeCell::new(Sha256::new());
            let ctx = MigrationContext {
                hash_only: true,
                ext: self.extensions.get(),
                hasher: hasher.get(),
                tx: tx.get(),
            };
            (*mig.up)(ctx).await.map_err(|error| Error::Migration {
                name: mig.name.clone(),
                version: mig_version,
                error,
            })?;
            let checksum = hasher.into_inner().finalize().to_vec();
            DB::Connection::add_migration(
                &self.table,
                AppliedMigration {
                    version: mig_version,
                    name: mig.name.clone(),
                    checksum: checksum.into(),
                    // Nothing was executed, so no meaningful duration exists.
                    execution_time: Duration::default(),
                },
                // SAFETY(review): no context holding `tx.get()` is alive here.
                unsafe { &mut *tx.get() },
            )
            .await?;
            tracing::info!(
                version = idx + 1,
                name = %mig.name,
                "migration forcibly set as applied"
            );
        }
        tracing::info!("committing changes");
        tx.into_inner().commit().await?;
        Ok(MigrationSummary {
            old_version: if db_migrations.is_empty() {
                None
            } else {
                Some(db_migrations.len() as _)
            },
            new_version: Some(version),
        })
    }
    /// Check that the database's recorded migrations are consistent with the
    /// local list — counts, names (if enabled), and checksums (if enabled) —
    /// returning the first problem found.
    ///
    /// # Errors
    /// Fails on database errors, on count/name mismatch, or on the first
    /// checksum mismatch when checksum verification is enabled.
    pub async fn verify(&mut self) -> Result<(), Error> {
        self.conn.ensure_migrations_table(&self.table).await?;
        let migrations = self.conn.list_migrations(&self.table).await?;
        self.check_migrations(&migrations)?;
        if self.options.verify_checksums {
            for res in self.verify_checksums(&migrations).await? {
                res?;
            }
        }
        Ok(())
    }
    /// Report the status of every known migration by pairing the local list
    /// with the database's records positionally (`zip_longest`): present on
    /// both sides, local-only (not yet applied), or database-only (missing
    /// locally).
    ///
    /// # Errors
    /// Fails on database errors or if recomputing checksums fails hard.
    pub async fn status(&mut self) -> Result<Vec<MigrationStatus>, Error> {
        self.conn.ensure_migrations_table(&self.table).await?;
        let migrations = self.conn.list_migrations(&self.table).await?;
        let mut status = Vec::with_capacity(self.migrations.len());
        let checksums = self.verify_checksums(&migrations).await?;
        for (idx, pair) in self
            .migrations
            .iter()
            .zip_longest(migrations.into_iter())
            .enumerate()
        {
            let version = idx as u64 + 1;
            match pair {
                // Known locally and applied in the database.
                EitherOrBoth::Both(local, db) => status.push(MigrationStatus {
                    version,
                    name: local.name.clone().into_owned(),
                    reversible: local.is_reversible(),
                    applied: Some(db),
                    missing_local: false,
                    // Indices without a checksum entry count as OK.
                    checksum_ok: checksums.get(idx).map_or(true, Result::is_ok),
                }),
                // Known locally but not yet applied.
                EitherOrBoth::Left(local) => status.push(MigrationStatus {
                    version,
                    name: local.name.clone().into_owned(),
                    reversible: local.is_reversible(),
                    applied: None,
                    missing_local: false,
                    checksum_ok: checksums.get(idx).map_or(true, Result::is_ok),
                }),
                // Recorded in the database but absent from the local list.
                EitherOrBoth::Right(r) => status.push(MigrationStatus {
                    version: r.version,
                    name: r.name.clone().into_owned(),
                    reversible: false,
                    applied: Some(r),
                    missing_local: true,
                    checksum_ok: checksums.get(idx).map_or(true, Result::is_ok),
                }),
            }
        }
        Ok(status)
    }
}
impl<DB> Migrator<DB>
where
DB: Database,
DB::Connection: db::Migrations,
{
fn local_migration(&self, version: u64) -> Result<&Migration<DB>, Error> {
if version == 0 {
return Err(Error::InvalidVersion {
version,
min_version: 1,
max_version: self.migrations.len() as _,
});
}
if self.migrations.is_empty() {
return Err(Error::InvalidVersion {
version,
min_version: 1,
max_version: self.migrations.len() as _,
});
}
let idx = version - 1;
self.migrations
.get(idx as usize)
.ok_or(Error::InvalidVersion {
version,
min_version: 1,
max_version: self.migrations.len() as _,
})
}
fn check_migrations(&mut self, migrations: &[AppliedMigration<'_>]) -> Result<(), Error> {
if self.migrations.len() < migrations.len() {
return Err(Error::MissingMigrations {
local_count: self.migrations.len(),
db_count: migrations.len(),
});
}
for (idx, (db_migration, local_migration)) in
migrations.iter().zip(self.migrations.iter()).enumerate()
{
let version = idx as u64 + 1;
if self.options.verify_names && db_migration.name != local_migration.name {
return Err(Error::NameMismatch {
version,
local_name: local_migration.name.clone(),
db_name: db_migration.name.to_string().into(),
});
}
}
Ok(())
}
async fn verify_checksums(
&mut self,
migrations: &[AppliedMigration<'_>],
) -> Result<Vec<Result<(), Error>>, Error> {
let mut results = Vec::with_capacity(self.migrations.len());
let local_migrations = self.migrations.iter();
let tx = UnsafeCell::new(self.conn.begin().await?);
for (idx, mig) in local_migrations.enumerate() {
let mig_version = idx as u64 + 1;
let hasher = UnsafeCell::new(Sha256::new());
let ctx = MigrationContext {
hash_only: true,
ext: self.extensions.get(),
hasher: hasher.get(),
tx: tx.get(),
};
(*mig.up)(ctx).await.map_err(|error| Error::Migration {
name: mig.name.clone(),
version: mig_version,
error,
})?;
let checksum = hasher.into_inner().finalize().to_vec();
if let Some(db_mig) = migrations.get(idx) {
if db_mig.checksum == checksum {
results.push(Ok(()));
} else {
results.push(Err(Error::ChecksumMismatch {
version: mig_version,
local_checksum: checksum.clone().into(),
db_checksum: db_mig.checksum.clone().into_owned().into(),
}));
}
}
}
tx.into_inner().rollback().await?;
Ok(results)
}
}
/// Behavioural toggles for a [`Migrator`].
///
/// Both verifications are enabled by default. Two plain `bool`s, so the type
/// is cheap to `Copy` around.
#[derive(Debug, Clone, Copy)]
pub struct MigratorOptions {
    /// Compare locally computed checksums with those stored in the database.
    pub verify_checksums: bool,
    /// Compare local migration names with those stored in the database.
    pub verify_names: bool,
}
impl Default for MigratorOptions {
    fn default() -> Self {
        // All verifications on: the safe default for schema management.
        Self {
            verify_checksums: true,
            verify_names: true,
        }
    }
}
/// Database schema versions before and after a migrate/revert/force
/// operation.
#[derive(Debug, Clone)]
pub struct MigrationSummary {
    /// Version recorded before the operation; `None` if no migrations were
    /// applied.
    pub old_version: Option<u64>,
    /// Version after the operation; `None` if the database now has no
    /// applied migrations.
    pub new_version: Option<u64>,
}
/// The status of a single migration, as reported by [`Migrator::status`].
#[derive(Debug, Clone)]
pub struct MigrationStatus {
    /// 1-based migration version.
    pub version: u64,
    /// Migration name.
    pub name: String,
    /// Whether the local migration has a "down" step (`false` when the
    /// migration is missing locally).
    pub reversible: bool,
    /// The database's record of this migration, if it was applied.
    pub applied: Option<db::AppliedMigration<'static>>,
    /// `true` when the database records a migration that is not registered
    /// locally.
    pub missing_local: bool,
    /// `false` only when a recomputed checksum disagrees with the stored one.
    pub checksum_ok: bool,
}
/// Error type returned from individual migration steps; any error
/// convertible into [`anyhow::Error`] works.
pub type MigrationError = anyhow::Error;
/// Database flavours understood by this crate's helpers.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum DatabaseType {
    /// PostgreSQL.
    Postgres,
    /// SQLite.
    Sqlite,
    /// sqlx's database-agnostic `Any` driver.
    Any,
}
impl DatabaseType {
fn sqlx_type(self) -> &'static str {
match self {
DatabaseType::Postgres => "Postgres",
DatabaseType::Sqlite => "Sqlite",
DatabaseType::Any => "Any",
}
}
}
impl FromStr for DatabaseType {
    type Err = anyhow::Error;
    /// Parse a lowercase database identifier (`postgres`, `sqlite`, `any`).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "postgres" {
            Ok(Self::Postgres)
        } else if s == "sqlite" {
            Ok(Self::Sqlite)
        } else if s == "any" {
            Ok(Self::Any)
        } else {
            Err(anyhow::anyhow!("invalid database type `{}`", s))
        }
    }
}