#![cfg_attr(feature = "_docs", feature(doc_cfg))]
#![warn(clippy::pedantic)]
#![allow(
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_sign_loss,
clippy::cast_lossless,
clippy::unreadable_literal,
clippy::doc_markdown,
clippy::module_name_repetitions
)]
use db::{AppliedMigration, Migrations};
use futures_core::future::LocalBoxFuture;
use itertools::{EitherOrBoth, Itertools};
use sha2::{Digest, Sha256};
use sqlx::{ConnectOptions, Connection, Database, Executor, Pool};
use state::TypeMap;
use std::{
borrow::Cow,
str::FromStr,
sync::Arc,
time::{Duration, Instant},
};
pub mod context;
pub mod db;
pub mod error;
pub use context::MigrationContext;
pub use error::Error;
#[cfg(feature = "cli")]
#[cfg_attr(feature = "_docs", doc(cfg(feature = "cli")))]
pub mod cli;
#[cfg(feature = "generate")]
#[cfg_attr(feature = "_docs", doc(cfg(feature = "generate")))]
mod gen;
#[cfg(feature = "generate")]
#[cfg_attr(feature = "_docs", doc(cfg(feature = "generate")))]
pub use gen::generate;
/// Boxed migration closure: takes the migration context and yields a future
/// resolving to `Result<(), MigrationError>`. Uses [`LocalBoxFuture`], so the
/// future is not required to be `Send`.
type MigrationFn<DB> =
    Box<dyn Fn(&mut MigrationContext<DB>) -> LocalBoxFuture<Result<(), MigrationError>>>;
/// Default name of the table used to track applied migrations.
pub const DEFAULT_MIGRATIONS_TABLE: &str = "_sqlx_migrations";
pub mod prelude {
    //! Convenience re-exports of the crate's most commonly used items.
    pub use super::Migration;
    pub use super::MigrationContext;
    pub use super::MigrationError;
    pub use super::MigrationStatus;
    pub use super::MigrationSummary;
    pub use super::Migrator;
    pub use super::MigratorOptions;
}
/// A single named database migration: an `up` closure that applies it and an
/// optional `down` closure that reverts it.
pub struct Migration<DB: Database> {
    /// The migration's name; equality of migrations is based on it alone.
    name: Cow<'static, str>,
    /// Applies the migration.
    up: MigrationFn<DB>,
    /// Reverts the migration, when one was registered via `reversible`.
    down: Option<MigrationFn<DB>>,
}
impl<DB: Database> Migration<DB> {
pub fn new(
name: impl Into<Cow<'static, str>>,
up: impl Fn(&mut MigrationContext<DB>) -> LocalBoxFuture<Result<(), MigrationError>> + 'static,
) -> Self {
Self {
name: name.into(),
up: Box::new(up),
down: None,
}
}
#[must_use]
pub fn reversible(
mut self,
down: impl Fn(&mut MigrationContext<DB>) -> LocalBoxFuture<Result<(), MigrationError>> + 'static,
) -> Self {
self.down = Some(Box::new(down));
self
}
#[must_use]
pub fn revertible(
self,
down: impl Fn(&mut MigrationContext<DB>) -> LocalBoxFuture<Result<(), MigrationError>> + 'static,
) -> Self {
self.reversible(down)
}
#[must_use]
pub fn name(&self) -> &str {
self.name.as_ref()
}
#[must_use]
pub fn is_reversible(&self) -> bool {
self.down.is_some()
}
#[must_use]
pub fn is_revertible(&self) -> bool {
self.down.is_some()
}
}
impl<DB: Database> Eq for Migration<DB> {}
/// Migrations compare by name only; the `up`/`down` closures take no part in
/// equality.
impl<DB: Database> PartialEq for Migration<DB> {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name
    }
}
/// Applies, reverts, and inspects migrations over a single database
/// connection.
#[must_use]
pub struct Migrator<Db>
where
    Db: Database,
    Db::Connection: db::Migrations,
{
    /// Verification toggles (checksums / names).
    options: MigratorOptions,
    /// The connection all migration work runs on.
    conn: Db::Connection,
    /// Name of the bookkeeping table; defaults to [`DEFAULT_MIGRATIONS_TABLE`].
    table: Cow<'static, str>,
    /// Locally registered migrations, in application order (version = index + 1).
    migrations: Vec<Migration<Db>>,
    /// Type-keyed extensions exposed to migrations through [`MigrationContext`].
    extensions: Arc<TypeMap!(Send + Sync)>,
}
impl<Db> Migrator<Db>
where
Db: Database,
Db::Connection: db::Migrations,
for<'a> &'a mut Db::Connection: Executor<'a>,
{
pub fn new(conn: Db::Connection) -> Self {
Self {
options: MigratorOptions::default(),
conn,
table: Cow::Borrowed(DEFAULT_MIGRATIONS_TABLE),
migrations: Vec::default(),
extensions: Arc::new(<TypeMap![Send + Sync]>::new()),
}
}
pub async fn connect(url: &str) -> Result<Self, sqlx::Error> {
let mut opts: <<Db as Database>::Connection as Connection>::Options = url.parse()?;
opts = opts.disable_statement_logging();
let mut conn = Db::Connection::connect_with(&opts).await?;
conn.execute(
r#"--sql
SET client_min_messages TO WARNING;
"#,
)
.await?;
Ok(Self {
options: MigratorOptions::default(),
conn,
table: Cow::Borrowed(DEFAULT_MIGRATIONS_TABLE),
migrations: Vec::default(),
extensions: Arc::new(<TypeMap![Send + Sync]>::new()),
})
}
pub async fn connect_with(
options: &<Db::Connection as Connection>::Options,
) -> Result<Self, sqlx::Error> {
let mut conn = Db::Connection::connect_with(options).await?;
conn.execute(
r#"--sql
SET client_min_messages TO WARNING;
"#,
)
.await?;
Ok(Self {
options: MigratorOptions::default(),
conn,
table: Cow::Borrowed(DEFAULT_MIGRATIONS_TABLE),
migrations: Vec::default(),
extensions: Arc::new(<TypeMap![Send + Sync]>::new()),
})
}
pub async fn connect_with_pool(pool: &Pool<Db>) -> Result<Self, sqlx::Error> {
let mut conn = pool.acquire().await?;
conn.execute(
r#"--sql
SET client_min_messages TO WARNING;
"#,
)
.await?;
Ok(Self {
options: MigratorOptions::default(),
conn: conn.detach(),
table: Cow::Borrowed(DEFAULT_MIGRATIONS_TABLE),
migrations: Vec::default(),
extensions: Arc::new(<TypeMap![Send + Sync]>::new()),
})
}
pub fn set_migrations_table(&mut self, name: impl AsRef<str>) {
self.table = Cow::Owned(name.as_ref().to_string());
}
pub fn add_migrations(&mut self, migrations: impl IntoIterator<Item = Migration<Db>>) {
self.migrations.extend(migrations);
}
pub fn set_options(&mut self, options: MigratorOptions) {
self.options = options;
}
pub fn with<T: Send + Sync + 'static>(&mut self, value: T) -> &mut Self {
self.set(value);
self
}
pub fn set<T: Send + Sync + 'static>(&mut self, value: T) {
self.extensions.set(value);
}
pub fn local_migrations(&self) -> &[Migration<Db>] {
&self.migrations
}
}
impl<Db> Migrator<Db>
where
    Db: Database,
    Db::Connection: db::Migrations,
    for<'a> &'a mut Db::Connection: Executor<'a>,
{
    /// Applies pending migrations up to and including `target_version`
    /// (1-based), all inside a single transaction.
    ///
    /// Each pending migration's `up` closure runs twice: first in hash-only
    /// mode to compute its SHA-256 checksum, then for real against the
    /// database.
    ///
    /// # Errors
    ///
    /// Fails on an invalid target version, database errors, migration
    /// errors, or (when enabled) checksum mismatches.
    #[allow(clippy::missing_panics_doc)]
    pub async fn migrate(mut self, target_version: u64) -> Result<MigrationSummary, Error> {
        // Reject versions outside the range of locally known migrations.
        self.local_migration(target_version)?;
        self.conn.ensure_migrations_table(&self.table).await?;
        let db_migrations = self.conn.list_migrations(&self.table).await?;
        self.check_migrations(&db_migrations)?;
        let to_apply = self.migrations.iter();
        // The database's current version equals its number of applied rows.
        let db_version = db_migrations.len() as _;
        let mut conn = self.conn;
        // NOTE(review): raw `BEGIN`/`COMMIT` strings assume the backend
        // accepts them verbatim — confirm for every supported driver.
        conn.execute("BEGIN").await?;
        for (idx, mig) in to_apply.enumerate() {
            // Versions are 1-based; `idx` is the 0-based slot.
            let mig_version = idx as u64 + 1;
            if mig_version > target_version {
                break;
            }
            // Skip migrations the database already has applied.
            if mig_version <= db_version {
                continue;
            }
            let start = Instant::now();
            tracing::info!(
                version = mig_version,
                name = %mig.name,
                "applying migration"
            );
            let hasher = Sha256::new();
            // First pass: hash-only, to derive the checksum without touching
            // the database.
            let mut ctx = MigrationContext {
                hash_only: true,
                ext: self.extensions.clone(),
                hasher,
                conn,
            };
            (*mig.up)(&mut ctx)
                .await
                .map_err(|error| Error::Migration {
                    name: mig.name.clone(),
                    version: mig_version,
                    error,
                })?;
            let checksum = std::mem::take(&mut ctx.hasher).finalize().to_vec();
            // Second pass: actually execute the migration.
            ctx.hash_only = false;
            (*mig.up)(&mut ctx)
                .await
                .map_err(|error| Error::Migration {
                    name: mig.name.clone(),
                    version: mig_version,
                    error,
                })?;
            let execution_time = start.elapsed();
            if self.options.verify_checksums {
                // NOTE(review): versions <= db_version are skipped above, so
                // `db_migrations.get(idx)` is `None` for every migration that
                // reaches this point — this mismatch branch looks unreachable.
                // Confirm whether the check was meant to run for already
                // applied migrations instead.
                if let Some(db_mig) = db_migrations.get(idx) {
                    if db_mig.checksum != checksum {
                        ctx.conn.execute("ROLLBACK").await?;
                        return Err(Error::ChecksumMismatch {
                            version: mig_version,
                            local_checksum: checksum.clone().into(),
                            db_checksum: db_mig.checksum.clone(),
                        });
                    }
                }
            }
            // Record the applied migration in the bookkeeping table.
            ctx.conn
                .add_migration(
                    &self.table,
                    AppliedMigration {
                        version: mig_version,
                        name: mig.name.clone(),
                        checksum: checksum.into(),
                        execution_time,
                    },
                )
                .await?;
            // Hand the connection back for the next iteration.
            conn = ctx.conn;
            tracing::info!(
                version = mig_version,
                name = %mig.name,
                execution_time = %humantime::Duration::from(execution_time),
                "migration applied"
            );
        }
        tracing::info!("committing changes");
        conn.execute("COMMIT").await?;
        Ok(MigrationSummary {
            old_version: if db_migrations.is_empty() {
                None
            } else {
                Some(db_migrations.len() as _)
            },
            new_version: Some(target_version.max(db_version)),
        })
    }
    /// Applies every local migration. Returns an empty summary (no versions)
    /// when no migrations are registered.
    ///
    /// # Errors
    ///
    /// Same failure modes as [`Migrator::migrate`].
    pub async fn migrate_all(self) -> Result<MigrationSummary, Error> {
        if self.migrations.is_empty() {
            return Ok(MigrationSummary {
                new_version: None,
                old_version: None,
            });
        }
        let migrations = self.migrations.len() as _;
        self.migrate(migrations).await
    }
    /// Reverts applied migrations from the newest down to and including
    /// `target_version` (1-based), inside a single transaction.
    ///
    /// A migration without a `down` closure is logged as a warning, but its
    /// bookkeeping row is removed regardless.
    ///
    /// # Errors
    ///
    /// Fails on an invalid target version, database errors, or revert errors.
    #[allow(clippy::missing_panics_doc)]
    pub async fn revert(mut self, target_version: u64) -> Result<MigrationSummary, Error> {
        self.local_migration(target_version)?;
        self.conn.ensure_migrations_table(&self.table).await?;
        let db_migrations = self.conn.list_migrations(&self.table).await?;
        self.check_migrations(&db_migrations)?;
        // Select applied migrations with version >= target_version, then walk
        // them newest-first.
        let to_revert = self
            .migrations
            .iter()
            .enumerate()
            .skip_while(|(idx, _)| idx + 1 < target_version as _)
            .take_while(|(idx, _)| *idx < db_migrations.len())
            .collect::<Vec<_>>()
            .into_iter()
            .rev();
        let mut conn = self.conn;
        conn.execute("BEGIN").await?;
        for (idx, mig) in to_revert {
            let version = idx as u64 + 1;
            let start = Instant::now();
            tracing::info!(
                version,
                name = %mig.name,
                "reverting migration"
            );
            let hasher = Sha256::new();
            // Reverts run for real immediately; no checksum pass is needed.
            let mut ctx = MigrationContext {
                hash_only: false,
                ext: self.extensions.clone(),
                hasher,
                conn,
            };
            match &mig.down {
                Some(down) => {
                    down(&mut ctx).await.map_err(|error| Error::Revert {
                        name: mig.name.clone(),
                        version,
                        error,
                    })?;
                }
                None => {
                    // Best-effort: warn but still drop the bookkeeping row.
                    tracing::warn!(
                        version,
                        name = %mig.name,
                        "no down migration found"
                    );
                }
            }
            let execution_time = start.elapsed();
            ctx.conn.remove_migration(&self.table, version).await?;
            conn = ctx.conn;
            tracing::info!(
                version,
                name = %mig.name,
                execution_time = %humantime::Duration::from(execution_time),
                "migration reverted"
            );
        }
        tracing::info!("committing changes");
        conn.execute("COMMIT").await?;
        Ok(MigrationSummary {
            old_version: if db_migrations.is_empty() {
                None
            } else {
                Some(db_migrations.len() as _)
            },
            // Reverting through version 1 leaves the database unversioned.
            new_version: if target_version == 1 {
                None
            } else {
                Some(target_version - 1)
            },
        })
    }
    /// Reverts every applied migration.
    ///
    /// # Errors
    ///
    /// Same failure modes as [`Migrator::revert`].
    pub async fn revert_all(self) -> Result<MigrationSummary, Error> {
        self.revert(1).await
    }
    /// Rewrites the migrations table so the database reports `version`,
    /// without executing any migration bodies (checksums are still computed
    /// via hash-only passes). `version == 0` clears the table entirely.
    ///
    /// # Errors
    ///
    /// Fails on an invalid version, database errors, or migration errors
    /// raised during the hash-only passes.
    #[allow(clippy::missing_panics_doc)]
    pub async fn force_version(mut self, version: u64) -> Result<MigrationSummary, Error> {
        self.conn.ensure_migrations_table(&self.table).await?;
        let db_migrations = self.conn.list_migrations(&self.table).await?;
        if version == 0 {
            self.conn.clear_migrations(&self.table).await?;
            return Ok(MigrationSummary {
                old_version: if db_migrations.is_empty() {
                    None
                } else {
                    Some(db_migrations.len() as _)
                },
                new_version: None,
            });
        }
        self.local_migration(version)?;
        // Migrations 1..=version, to be marked as applied.
        let migrations = self
            .migrations
            .iter()
            .enumerate()
            .take_while(|(idx, _)| *idx < version as usize);
        self.conn.clear_migrations(&self.table).await?;
        let mut conn = self.conn;
        conn.execute("BEGIN").await?;
        for (idx, mig) in migrations {
            let mig_version = idx as u64 + 1;
            let hasher = Sha256::new();
            // Hash-only: record checksums without executing migration SQL.
            let mut ctx = MigrationContext {
                hash_only: true,
                ext: self.extensions.clone(),
                hasher,
                conn,
            };
            (*mig.up)(&mut ctx)
                .await
                .map_err(|error| Error::Migration {
                    name: mig.name.clone(),
                    version: mig_version,
                    error,
                })?;
            let checksum = std::mem::take(&mut ctx.hasher).finalize().to_vec();
            ctx.conn
                .add_migration(
                    &self.table,
                    AppliedMigration {
                        version: mig_version,
                        name: mig.name.clone(),
                        checksum: checksum.into(),
                        // Nothing actually ran, so there is no meaningful timing.
                        execution_time: Duration::default(),
                    },
                )
                .await?;
            conn = ctx.conn;
            tracing::info!(
                version = idx + 1,
                name = %mig.name,
                "migration forcibly set as applied"
            );
        }
        tracing::info!("committing changes");
        conn.execute("COMMIT").await?;
        Ok(MigrationSummary {
            old_version: if db_migrations.is_empty() {
                None
            } else {
                Some(db_migrations.len() as _)
            },
            new_version: Some(version),
        })
    }
    /// Validates local migrations against the database: counts and names,
    /// plus checksums when [`MigratorOptions::verify_checksums`] is enabled.
    /// Returns the first failure found.
    ///
    /// # Errors
    ///
    /// Fails on database errors, name mismatches, or checksum mismatches.
    #[allow(clippy::missing_panics_doc)]
    pub async fn verify(mut self) -> Result<(), Error> {
        self.conn.ensure_migrations_table(&self.table).await?;
        let migrations = self.conn.list_migrations(&self.table).await?;
        self.check_migrations(&migrations)?;
        if self.options.verify_checksums {
            for res in self.verify_checksums(&migrations).await?.1 {
                res?;
            }
        }
        Ok(())
    }
    /// Reports the status of every migration, local and applied, pairing them
    /// positionally via `zip_longest`.
    ///
    /// # Errors
    ///
    /// Fails on database errors raised while listing or hashing migrations.
    #[allow(clippy::missing_panics_doc)]
    pub async fn status(mut self) -> Result<Vec<MigrationStatus>, Error> {
        self.conn.ensure_migrations_table(&self.table).await?;
        let migrations = self.conn.list_migrations(&self.table).await?;
        let mut status = Vec::with_capacity(self.migrations.len());
        // `verify_checksums` consumes and returns the migrator because it
        // threads the connection through hashing contexts.
        let (migrator, checksums) = self.verify_checksums(&migrations).await?;
        self = migrator;
        for (idx, pair) in self.migrations.iter().zip_longest(migrations).enumerate() {
            let version = idx as u64 + 1;
            match pair {
                // Present both locally and in the database.
                EitherOrBoth::Both(local, db) => status.push(MigrationStatus {
                    version,
                    name: local.name.clone().into_owned(),
                    reversible: local.is_reversible(),
                    applied: Some(db),
                    missing_local: false,
                    checksum_ok: checksums.get(idx).map_or(true, Result::is_ok),
                }),
                // Local only: not applied yet.
                EitherOrBoth::Left(local) => status.push(MigrationStatus {
                    version,
                    name: local.name.clone().into_owned(),
                    reversible: local.is_reversible(),
                    applied: None,
                    missing_local: false,
                    checksum_ok: checksums.get(idx).map_or(true, Result::is_ok),
                }),
                // Database only: the local migration source is missing.
                EitherOrBoth::Right(r) => status.push(MigrationStatus {
                    version: r.version,
                    name: r.name.clone().into_owned(),
                    reversible: false,
                    applied: Some(r),
                    missing_local: true,
                    checksum_ok: checksums.get(idx).map_or(true, Result::is_ok),
                }),
            }
        }
        Ok(status)
    }
}
impl<Db> Migrator<Db>
where
Db: Database,
Db::Connection: db::Migrations,
for<'a> &'a mut Db::Connection: Executor<'a>,
{
fn local_migration(&self, version: u64) -> Result<&Migration<Db>, Error> {
if version == 0 {
return Err(Error::InvalidVersion {
version,
min_version: 1,
max_version: self.migrations.len() as _,
});
}
if self.migrations.is_empty() {
return Err(Error::InvalidVersion {
version,
min_version: 1,
max_version: self.migrations.len() as _,
});
}
let idx = version - 1;
self.migrations
.get(idx as usize)
.ok_or(Error::InvalidVersion {
version,
min_version: 1,
max_version: self.migrations.len() as _,
})
}
fn check_migrations(&mut self, migrations: &[AppliedMigration<'_>]) -> Result<(), Error> {
if self.migrations.len() < migrations.len() {
return Err(Error::MissingMigrations {
local_count: self.migrations.len(),
db_count: migrations.len(),
});
}
for (idx, (db_migration, local_migration)) in
migrations.iter().zip(self.migrations.iter()).enumerate()
{
let version = idx as u64 + 1;
if self.options.verify_names && db_migration.name != local_migration.name {
return Err(Error::NameMismatch {
version,
local_name: local_migration.name.clone(),
db_name: db_migration.name.to_string().into(),
});
}
}
Ok(())
}
async fn verify_checksums(
mut self,
migrations: &[AppliedMigration<'_>],
) -> Result<(Self, Vec<Result<(), Error>>), Error> {
let mut results = Vec::with_capacity(self.migrations.len());
let local_migrations = self.migrations.iter();
let mut conn = self.conn;
for (idx, mig) in local_migrations.enumerate() {
let mig_version = idx as u64 + 1;
let hasher = Sha256::new();
let mut ctx = MigrationContext {
hash_only: true,
ext: self.extensions.clone(),
hasher,
conn,
};
(*mig.up)(&mut ctx)
.await
.map_err(|error| Error::Migration {
name: mig.name.clone(),
version: mig_version,
error,
})?;
let checksum = std::mem::take(&mut ctx.hasher).finalize().to_vec();
conn = ctx.conn;
if let Some(db_mig) = migrations.get(idx) {
if db_mig.checksum == checksum {
results.push(Ok(()));
} else {
results.push(Err(Error::ChecksumMismatch {
version: mig_version,
local_checksum: checksum.clone().into(),
db_checksum: db_mig.checksum.clone().into_owned().into(),
}));
}
}
}
conn.execute("ROLLBACK").await?;
self.conn = conn;
Ok((self, results))
}
}
/// Toggles controlling the migrator's verification behavior.
#[derive(Debug)]
pub struct MigratorOptions {
    /// Compare local migration checksums against applied ones (default: `true`).
    pub verify_checksums: bool,
    /// Compare local migration names against applied ones (default: `true`).
    pub verify_names: bool,
}
impl Default for MigratorOptions {
fn default() -> Self {
Self {
verify_checksums: true,
verify_names: true,
}
}
}
/// Version change reported by a migrate/revert/force operation.
#[derive(Debug, Clone)]
pub struct MigrationSummary {
    /// Version before the operation; `None` when no migrations were applied.
    pub old_version: Option<u64>,
    /// Version after the operation; `None` when the database ends up unversioned.
    pub new_version: Option<u64>,
}
/// Status of a single migration as reported by [`Migrator::status`].
#[derive(Debug, Clone)]
pub struct MigrationStatus {
    /// 1-based migration version.
    pub version: u64,
    /// Migration name (taken from the local source when one exists).
    pub name: String,
    /// Whether the local migration has a `down` closure.
    pub reversible: bool,
    /// The applied database record, if this migration has been applied.
    pub applied: Option<db::AppliedMigration<'static>>,
    /// `true` when the database has this migration but no local source exists.
    pub missing_local: bool,
    /// `false` when checksum verification found a mismatch for this migration.
    pub checksum_ok: bool,
}
/// Error type produced by migration `up`/`down` closures.
pub type MigrationError = anyhow::Error;
/// Supported database backend selector.
#[cfg_attr(feature = "cli", derive(clap::ValueEnum))]
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum DatabaseType {
    Postgres,
    Sqlite,
    /// Matches sqlx's database-agnostic `Any` driver.
    Any,
}
impl DatabaseType {
fn sqlx_type(self) -> &'static str {
match self {
DatabaseType::Postgres => "Postgres",
DatabaseType::Sqlite => "Sqlite",
DatabaseType::Any => "Any",
}
}
}
impl FromStr for DatabaseType {
    type Err = anyhow::Error;
    /// Parses a lowercase database name: `postgres`, `sqlite`, or `any`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let ty = match s {
            "postgres" => Self::Postgres,
            "sqlite" => Self::Sqlite,
            "any" => Self::Any,
            other => anyhow::bail!("invalid database type `{}`", other),
        };
        Ok(ty)
    }
}