#![cfg_attr(feature = "_docs", feature(doc_cfg))]
#![deny(unsafe_code)]
#![warn(clippy::pedantic)]
#![allow(
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::cast_sign_loss,
clippy::cast_lossless,
clippy::unreadable_literal,
clippy::doc_markdown
)]
use db::{AppliedMigration, Migrations};
use futures_core::future::LocalBoxFuture;
use itertools::{EitherOrBoth, Itertools};
use sqlx::{ConnectOptions, Connection, Database, Pool, Transaction};
use std::{
borrow::Cow,
str::FromStr,
time::{Duration, Instant},
};
use thiserror::Error;
pub mod db;
#[cfg(feature = "cli")]
#[cfg_attr(feature = "_docs", doc(cfg(feature = "cli")))]
pub mod cli;
#[cfg(feature = "generate")]
#[cfg_attr(feature = "_docs", doc(cfg(feature = "generate")))]
mod gen;
#[cfg(feature = "generate")]
#[cfg_attr(feature = "_docs", doc(cfg(feature = "generate")))]
pub use gen::generate;
/// Type-erased migration step: given the open transaction, returns a boxed
/// future that performs the step and reports success or a [`MigrationError`].
type MigrationFn<DB> = Box<
    dyn for<'tx> Fn(&'tx mut Transaction<DB>) -> LocalBoxFuture<'tx, Result<(), MigrationError>>,
>;
/// Name of the table used to record applied migrations, unless a different one
/// is set with `Migrator::set_migrations_table`.
pub const DEFAULT_MIGRATIONS_TABLE: &str = "_sqlx_migrations";
pub mod prelude {
pub use super::Error;
pub use super::Migration;
pub use super::MigrationError;
pub use super::MigrationStatus;
pub use super::MigrationSummary;
pub use super::Migrator;
pub use super::MigratorOptions;
}
/// A single migration: a named `up` step, an optional `down` step used when
/// reverting, and a checksum compared against the database record.
pub struct Migration<DB: Database> {
/// Name recorded in the migrations table; compared on verification when
/// `MigratorOptions::verify_names` is enabled.
name: Cow<'static, str>,
/// Checksum recorded when applied; compared on verification when
/// `MigratorOptions::verify_checksums` is enabled. Empty unless set with
/// `with_checksum`.
checksum: Cow<'static, [u8]>,
/// Step executed when the migration is applied.
up: MigrationFn<DB>,
/// Step executed when the migration is reverted; `None` if not reversible.
down: Option<MigrationFn<DB>>,
}
impl<DB: Database> Migration<DB> {
pub fn new(
name: impl Into<Cow<'static, str>>,
up: impl for<'future> Fn(
&'future mut Transaction<DB>,
) -> LocalBoxFuture<'future, Result<(), MigrationError>>
+ 'static,
) -> Self {
Self {
name: name.into(),
checksum: Cow::default(),
up: Box::new(up),
down: None,
}
}
#[must_use]
pub fn reversible(
mut self,
down: impl for<'future> Fn(
&'future mut Transaction<DB>,
) -> LocalBoxFuture<'future, Result<(), MigrationError>>
+ 'static,
) -> Self {
self.down = Some(Box::new(down));
self
}
#[must_use]
pub fn revertible(
self,
down: impl for<'future> Fn(
&'future mut Transaction<DB>,
) -> LocalBoxFuture<'future, Result<(), MigrationError>>
+ 'static,
) -> Self {
self.reversible(down)
}
#[must_use]
pub fn with_checksum(mut self, checksum: impl Into<Cow<'static, [u8]>>) -> Self {
self.checksum = checksum.into();
self
}
#[must_use]
pub fn name(&self) -> &str {
self.name.as_ref()
}
#[must_use]
pub fn checksum(&self) -> &[u8] {
self.checksum.as_ref()
}
#[must_use]
pub fn is_reversible(&self) -> bool {
self.down.is_some()
}
#[must_use]
pub fn is_revertible(&self) -> bool {
self.down.is_some()
}
}
/// Two migrations are equal when both their names and their checksums match;
/// the `up`/`down` steps themselves are not compared.
impl<DB: Database> PartialEq for Migration<DB> {
    fn eq(&self, other: &Self) -> bool {
        (&self.name, &self.checksum) == (&other.name, &other.checksum)
    }
}
impl<DB: Database> Eq for Migration<DB> {}
/// Runs and tracks migrations over a single database connection.
///
/// A migration's version is its 1-based position in the list of local
/// migrations; applied migrations are recorded in the table named by `table`.
pub struct Migrator<DB>
where
DB: Database,
DB::Connection: db::Migrations,
{
/// Controls which properties are verified against the database records.
options: MigratorOptions,
/// Connection every operation runs on.
conn: DB::Connection,
/// Name of the table holding applied-migration records.
table: Cow<'static, str>,
/// Local migrations, in version order.
migrations: Vec<Migration<DB>>,
}
impl<DB> Migrator<DB>
where
    DB: Database,
    DB::Connection: db::Migrations,
{
    /// Creates a migrator over an existing connection.
    ///
    /// Uses [`MigratorOptions::default`] and the [`DEFAULT_MIGRATIONS_TABLE`]
    /// table, and starts without any local migrations.
    pub fn new(conn: DB::Connection) -> Self {
        Self {
            options: MigratorOptions::default(),
            conn,
            table: Cow::Borrowed(DEFAULT_MIGRATIONS_TABLE),
            migrations: Vec::new(),
        }
    }

    /// Connects to the database at `url`, with statement logging disabled.
    ///
    /// # Errors
    ///
    /// Fails if `url` cannot be parsed or the connection cannot be established.
    pub async fn connect(url: &str) -> Result<Self, sqlx::Error> {
        let mut options = url.parse::<<DB::Connection as Connection>::Options>()?;
        options.disable_statement_logging();
        Self::connect_with(&options).await
    }

    /// Connects using already-built connection `options`.
    ///
    /// # Errors
    ///
    /// Fails if the connection cannot be established.
    pub async fn connect_with(
        options: &<DB::Connection as Connection>::Options,
    ) -> Result<Self, sqlx::Error> {
        let conn = DB::Connection::connect_with(options).await?;
        Ok(Self::new(conn))
    }

    /// Acquires a connection from `pool` and detaches it, so the migrator owns
    /// the connection.
    ///
    /// # Errors
    ///
    /// Fails if a connection cannot be acquired from the pool.
    pub async fn connect_with_pool(pool: &Pool<DB>) -> Result<Self, sqlx::Error> {
        let pooled = pool.acquire().await?;
        Ok(Self::new(pooled.detach()))
    }

    /// Overrides the name of the table that stores applied migrations.
    pub fn set_migrations_table(&mut self, name: impl AsRef<str>) {
        self.table = Cow::Owned(String::from(name.as_ref()));
    }

    /// Appends `migrations` to the local migrations, in iteration order.
    ///
    /// Each migration's version is its 1-based position in the resulting list.
    pub fn add_migrations(&mut self, migrations: impl IntoIterator<Item = Migration<DB>>) {
        self.migrations.extend(migrations);
    }

    /// Replaces the verification options.
    pub fn set_options(&mut self, options: MigratorOptions) {
        self.options = options;
    }

    /// Returns the local migrations, in version order.
    pub fn local_migrations(&self) -> &[Migration<DB>] {
        self.migrations.as_slice()
    }
}
impl<DB> Migrator<DB>
where
    DB: Database,
    DB::Connection: db::Migrations,
{
    /// Applies pending local migrations up to and including `version`.
    ///
    /// Migrations already recorded in the migrations table are skipped. The
    /// remaining ones up to `version` run in order inside a single transaction,
    /// and each is recorded in the migrations table right after it runs. The
    /// transaction is committed only after every migration has succeeded.
    ///
    /// Requesting a `version` lower than the currently applied one applies
    /// nothing, and the summary reports the applied version unchanged.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidVersion`] if `version` does not name a local migration.
    /// - Any verification error reported by [`Migrator::verify`].
    /// - [`Error::Migration`] if a migration's `up` step fails.
    /// - [`Error::Database`] if a database operation fails.
    pub async fn migrate(&mut self, version: u64) -> Result<MigrationSummary, Error> {
        self.local_migration(version)?;
        self.check_migrations().await?;
        let db_migrations = self.conn.list_migrations(&self.table).await?;
        let applied_count = db_migrations.len();
        // `version` was validated by `local_migration`, so the cast is lossless.
        let to_apply = self
            .migrations
            .iter()
            .enumerate()
            .take(version as usize)
            .skip(applied_count);
        let mut tx = self.conn.begin().await?;
        for (idx, mig) in to_apply {
            let version = idx as u64 + 1;
            let start = Instant::now();
            tracing::info!(
                version,
                name = %mig.name,
                "applying migration"
            );
            (*mig.up)(&mut tx).await.map_err(|error| Error::Migration {
                name: mig.name.clone(),
                version,
                error,
            })?;
            let execution_time = start.elapsed();
            DB::Connection::add_migration(
                &self.table,
                AppliedMigration {
                    version,
                    name: mig.name.clone(),
                    checksum: mig.checksum.clone(),
                    execution_time,
                },
                &mut tx,
            )
            .await?;
            tracing::info!(
                version,
                name = %mig.name,
                execution_time = %humantime::Duration::from(execution_time),
                "migration applied"
            );
        }
        tracing::info!("committing changes");
        tx.commit().await?;
        Ok(MigrationSummary {
            old_version: Self::version_or_none(applied_count as u64),
            // Migrating never lowers the version: a target below the applied
            // version is a no-op.
            new_version: Some(version.max(applied_count as u64)),
        })
    }

    /// Applies every pending local migration.
    ///
    /// With no local migrations, this only verifies the database and returns a
    /// summary with both versions set to `None`.
    ///
    /// # Errors
    ///
    /// Same as [`Migrator::migrate`].
    pub async fn migrate_all(&mut self) -> Result<MigrationSummary, Error> {
        if self.migrations.is_empty() {
            // Still verify, so that migrations recorded in the database but
            // missing locally are reported instead of silently ignored.
            self.check_migrations().await?;
            return Ok(MigrationSummary {
                new_version: None,
                old_version: None,
            });
        }
        // `migrate` performs the same verification, so it is not repeated here.
        self.migrate(self.migrations.len() as _).await
    }

    /// Reverts applied migrations from the latest applied one down to and
    /// including `version`, in reverse order, inside a single transaction.
    ///
    /// Each reverted migration is removed from the migrations table. A migration
    /// without a `down` step is only removed from the table, and a warning is
    /// logged. If `version` is above the applied version, nothing is reverted.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidVersion`] if `version` does not name a local migration.
    /// - Any verification error reported by [`Migrator::verify`].
    /// - [`Error::Revert`] if a migration's `down` step fails.
    /// - [`Error::Database`] if a database operation fails.
    pub async fn revert(&mut self, version: u64) -> Result<MigrationSummary, Error> {
        self.local_migration(version)?;
        self.check_migrations().await?;
        let db_migrations = self.conn.list_migrations(&self.table).await?;
        let applied_count = db_migrations.len();
        // `version` is at least 1 and fits in `usize` after `local_migration`.
        let first_idx = (version - 1) as usize;
        // Applied migrations from `version` onwards, latest first.
        let to_revert = self
            .migrations
            .iter()
            .enumerate()
            .take(applied_count)
            .skip(first_idx)
            .rev();
        let mut tx = self.conn.begin().await?;
        for (idx, mig) in to_revert {
            let version = idx as u64 + 1;
            let start = Instant::now();
            tracing::info!(
                version,
                name = %mig.name,
                "reverting migration"
            );
            match &mig.down {
                Some(down) => {
                    down(&mut tx).await.map_err(|error| Error::Revert {
                        name: mig.name.clone(),
                        version,
                        error,
                    })?;
                }
                None => {
                    tracing::warn!(
                        version,
                        name = %mig.name,
                        "no down migration found"
                    );
                }
            }
            let execution_time = start.elapsed();
            DB::Connection::remove_migration(&self.table, version, &mut tx).await?;
            tracing::info!(
                version,
                name = %mig.name,
                execution_time = %humantime::Duration::from(execution_time),
                "migration reverted"
            );
        }
        tracing::info!("committing changes");
        tx.commit().await?;
        // Nothing above the applied version can be reverted, so the resulting
        // version is clamped to what was applied before.
        let new_version = (version - 1).min(applied_count as u64);
        Ok(MigrationSummary {
            old_version: Self::version_or_none(applied_count as u64),
            new_version: Self::version_or_none(new_version),
        })
    }

    /// Reverts every applied migration.
    ///
    /// With no local migrations, this only verifies the database and returns a
    /// summary with both versions set to `None`.
    ///
    /// # Errors
    ///
    /// Same as [`Migrator::revert`].
    pub async fn revert_all(&mut self) -> Result<MigrationSummary, Error> {
        if self.migrations.is_empty() {
            // Still verify, so that migrations recorded in the database but
            // missing locally are reported instead of silently ignored.
            self.check_migrations().await?;
            return Ok(MigrationSummary {
                new_version: None,
                old_version: None,
            });
        }
        // `revert` performs the same verification, so it is not repeated here.
        self.revert(1).await
    }

    /// Rewrites the migrations table so that exactly the first `version` local
    /// migrations are recorded as applied, without running any of them.
    ///
    /// Forced records carry a zero execution time. No verification against the
    /// existing records is performed.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidVersion`] if `version` does not name a local migration.
    /// - [`Error::Database`] if a database operation fails.
    pub async fn force_version(&mut self, version: u64) -> Result<MigrationSummary, Error> {
        self.local_migration(version)?;
        self.conn.ensure_migrations_table(&self.table).await?;
        let db_migrations = self.conn.list_migrations(&self.table).await?;
        let mut tx = self.conn.begin().await?;
        // Clear through the transaction (it dereferences to the connection) so
        // the clear is only committed together with the new records, instead of
        // leaving the table emptied if recording a forced migration fails.
        tx.clear_migrations(&self.table).await?;
        // `version` was validated by `local_migration`, so the cast is lossless.
        for (idx, mig) in self.migrations.iter().enumerate().take(version as usize) {
            let version = idx as u64 + 1;
            DB::Connection::add_migration(
                &self.table,
                AppliedMigration {
                    version,
                    name: mig.name.clone(),
                    checksum: mig.checksum.clone(),
                    execution_time: Duration::default(),
                },
                &mut tx,
            )
            .await?;
            tracing::info!(
                version,
                name = %mig.name,
                "migration forcibly set as applied"
            );
        }
        tracing::info!("committing changes");
        tx.commit().await?;
        Ok(MigrationSummary {
            old_version: Self::version_or_none(db_migrations.len() as u64),
            new_version: Some(version),
        })
    }

    /// Checks that the migrations recorded in the database match the local
    /// ones, according to the configured [`MigratorOptions`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::MissingMigrations`], [`Error::NameMismatch`],
    /// [`Error::ChecksumMismatch`] or [`Error::Database`].
    pub async fn verify(&mut self) -> Result<(), Error> {
        self.check_migrations().await
    }

    /// Returns the status of every local migration, followed by every applied
    /// migration that has no local counterpart (`missing_local == true`).
    ///
    /// Ensures the migrations table exists first.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Database`] if a database operation fails.
    pub async fn status(&mut self) -> Result<Vec<MigrationStatus>, Error> {
        self.conn.ensure_migrations_table(&self.table).await?;
        let migrations = self.conn.list_migrations(&self.table).await?;
        let mut status = Vec::with_capacity(self.migrations.len());
        for (idx, pair) in self
            .migrations
            .iter()
            .zip_longest(migrations.into_iter())
            .enumerate()
        {
            let version = idx as u64 + 1;
            match pair {
                EitherOrBoth::Both(local, db) => status.push(MigrationStatus {
                    version,
                    name: local.name.clone().into_owned(),
                    reversible: local.is_reversible(),
                    checksum: local.checksum.clone().into_owned(),
                    applied: Some(db),
                    missing_local: false,
                }),
                EitherOrBoth::Left(local) => status.push(MigrationStatus {
                    version,
                    name: local.name.clone().into_owned(),
                    reversible: local.is_reversible(),
                    checksum: local.checksum.clone().into_owned(),
                    applied: None,
                    missing_local: false,
                }),
                EitherOrBoth::Right(r) => status.push(MigrationStatus {
                    version: r.version,
                    name: r.name.clone().into_owned(),
                    checksum: Vec::default(),
                    reversible: false,
                    applied: Some(r),
                    missing_local: true,
                }),
            }
        }
        Ok(status)
    }

    /// Maps a version number to `Option`, using `None` for version `0`
    /// ("no migrations applied").
    fn version_or_none(version: u64) -> Option<u64> {
        if version == 0 {
            None
        } else {
            Some(version)
        }
    }
}
impl<DB> Migrator<DB>
where
DB: Database,
DB::Connection: db::Migrations,
{
fn local_migration(&self, version: u64) -> Result<&Migration<DB>, Error> {
if version == 0 {
return Err(Error::InvalidVersion {
version,
min_version: 1,
max_version: self.migrations.len() as _,
});
}
if self.migrations.is_empty() {
return Err(Error::InvalidVersion {
version,
min_version: 1,
max_version: self.migrations.len() as _,
});
}
let idx = version - 1;
self.migrations
.get(idx as usize)
.ok_or(Error::InvalidVersion {
version,
min_version: 1,
max_version: self.migrations.len() as _,
})
}
async fn check_migrations(&mut self) -> Result<(), Error> {
self.conn.ensure_migrations_table(&self.table).await?;
let migrations = self.conn.list_migrations(&self.table).await?;
if self.migrations.len() < migrations.len() {
return Err(Error::MissingMigrations {
local_count: self.migrations.len(),
db_count: migrations.len(),
});
}
for (idx, (db_migration, local_migration)) in migrations
.into_iter()
.zip(self.migrations.iter())
.enumerate()
{
let version = idx as u64 + 1;
if self.options.verify_names && db_migration.name != local_migration.name {
return Err(Error::NameMismatch {
version,
local_name: local_migration.name.clone(),
db_name: db_migration.name.clone(),
});
}
if self.options.verify_checksums && db_migration.checksum != local_migration.checksum {
return Err(Error::ChecksumMismatch {
version,
local_checksum: local_migration.checksum.clone(),
db_checksum: db_migration.checksum.clone(),
});
}
}
Ok(())
}
}
/// Options controlling how applied migrations are verified against the local
/// ones.
///
/// Both checks are enabled by default.
// `Clone`/`Copy`/`PartialEq` let callers reuse one set of options across
// several migrators and compare configurations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MigratorOptions {
    /// Fail verification when a local checksum differs from the recorded one.
    pub verify_checksums: bool,
    /// Fail verification when a local name differs from the recorded one.
    pub verify_names: bool,
}

impl Default for MigratorOptions {
    fn default() -> Self {
        Self {
            verify_checksums: true,
            verify_names: true,
        }
    }
}
/// Outcome of a migrate, revert or force operation.
///
/// A version of `None` means that no migrations are applied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationSummary {
    /// Version applied before the operation.
    pub old_version: Option<u64>,
    /// Version applied after the operation.
    pub new_version: Option<u64>,
}
/// Status of one migration as reported by `Migrator::status`.
#[derive(Debug, Clone)]
pub struct MigrationStatus {
/// 1-based version; for entries missing locally, the version recorded in the database.
pub version: u64,
/// Local name, or the recorded name when the migration is missing locally.
pub name: String,
/// Whether the local migration has a `down` step; `false` when missing locally.
pub reversible: bool,
/// Local checksum; empty when the migration is missing locally.
pub checksum: Vec<u8>,
/// The database record, if the migration has been applied.
pub applied: Option<db::AppliedMigration<'static>>,
/// `true` if the migration is recorded in the database but has no local counterpart.
pub missing_local: bool,
}
#[derive(Debug, Error)]
pub enum Error {
#[error("{0}")]
Database(sqlx::Error),
#[error(
"invalid version specified: {version} (available versions: {min_version}-{max_version})"
)]
InvalidVersion {
version: u64,
min_version: u64,
max_version: u64,
},
#[error("there were no local migrations found")]
NoMigrations,
#[error("missing migrations ({local_count} local, but {db_count} already applied)")]
MissingMigrations { local_count: usize, db_count: usize },
#[error("error applying migration: {error}")]
Migration {
name: Cow<'static, str>,
version: u64,
error: MigrationError,
},
#[error("error reverting migration: {error}")]
Revert {
name: Cow<'static, str>,
version: u64,
error: MigrationError,
},
#[error("expected migration {version} to be {local_name} but it was applied as {db_name}")]
NameMismatch {
version: u64,
local_name: Cow<'static, str>,
db_name: Cow<'static, str>,
},
#[error("invalid checksum for migration {version}")]
ChecksumMismatch {
version: u64,
local_checksum: Cow<'static, [u8]>,
db_checksum: Cow<'static, [u8]>,
},
}
impl From<sqlx::Error> for Error {
fn from(err: sqlx::Error) -> Self {
Self::Database(err)
}
}
/// Error type returned by migration `up`/`down` steps (an alias for `anyhow::Error`).
pub type MigrationError = anyhow::Error;
/// Kind of database backend.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum DatabaseType {
    /// Maps to the sqlx type named `Postgres`.
    Postgres,
    /// Maps to the sqlx type named `Sqlite`.
    Sqlite,
    /// Maps to the sqlx type named `Any`.
    Any,
}

impl DatabaseType {
    /// Returns the name of the sqlx database type that corresponds to this backend.
    fn sqlx_type(self) -> &'static str {
        match self {
            Self::Postgres => "Postgres",
            Self::Sqlite => "Sqlite",
            Self::Any => "Any",
        }
    }
}
impl FromStr for DatabaseType {
    type Err = anyhow::Error;

    /// Parses a lowercase backend name: `postgres`, `sqlite` or `any`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let parsed = match s {
            "postgres" => Self::Postgres,
            "sqlite" => Self::Sqlite,
            "any" => Self::Any,
            db => return Err(anyhow::anyhow!("invalid database type `{}`", db)),
        };
        Ok(parsed)
    }
}