#![forbid(unsafe_code)]
#![warn(missing_docs)]
use log::{debug, info, trace, warn};
use rusqlite::{Connection, ToSql};
mod errors;
#[cfg(test)]
mod tests;
pub use errors::{Error, MigrationDefinitionError, Result, SchemaVersionError};
use std::{
cmp::{self, Ordering},
fmt,
num::NonZeroUsize,
};
/// One migration.
///
/// A migration always has an `up` SQL script, run when migrating forward. It may optionally have a
/// `down` SQL script, run when reverting it. A migration without `down` cannot be reverted.
#[derive(Debug, PartialEq, Clone)]
pub struct M<'u> {
    up: &'u str,
    down: Option<&'u str>,
}

impl<'u> M<'u> {
    /// Create a migration whose forward step runs `sql`, with no `down` step.
    ///
    /// The SQL is executed as a batch when the migration is applied.
    #[must_use]
    pub fn up(sql: &'u str) -> Self {
        Self {
            up: sql,
            down: None,
        }
    }

    /// Set the SQL run to revert this migration, replacing any previously set `down` step.
    #[must_use]
    pub fn down(mut self, sql: &'u str) -> Self {
        self.down = Some(sql);
        self
    }
}
/// Schema version of a database, classified against a set of migrations.
///
/// The numeric value is the database `user_version`. A set of `n` migrations classifies
/// `1..=n` as [`SchemaVersion::Inside`] and anything greater as [`SchemaVersion::Outside`].
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum SchemaVersion {
    /// No schema version is set (the database version is 0).
    NoneSet,
    /// The version is within the range of the defined migrations.
    Inside(NonZeroUsize),
    /// The version is greater than the number of defined migrations.
    Outside(NonZeroUsize),
}

impl From<&SchemaVersion> for usize {
    /// Translate a schema version to its raw database version (`NoneSet` is 0).
    fn from(schema_version: &SchemaVersion) -> usize {
        match schema_version {
            SchemaVersion::NoneSet => 0,
            SchemaVersion::Inside(v) | SchemaVersion::Outside(v) => v.get(),
        }
    }
}

impl From<SchemaVersion> for usize {
    /// Translate a schema version to its raw database version (`NoneSet` is 0).
    fn from(schema_version: SchemaVersion) -> Self {
        From::from(&schema_version)
    }
}

impl fmt::Display for SchemaVersion {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SchemaVersion::NoneSet => write!(f, "0 (no version set)"),
            SchemaVersion::Inside(v) => write!(f, "{} (inside)", v),
            SchemaVersion::Outside(v) => write!(f, "{} (outside)", v),
        }
    }
}

impl cmp::PartialOrd for SchemaVersion {
    /// Order schema versions by their numeric database version.
    ///
    /// `Inside(n)` and `Outside(n)` are unequal under the derived `PartialEq`. They are therefore
    /// reported as incomparable (`None`) rather than `Equal`, keeping `PartialOrd` consistent
    /// with `PartialEq` as the trait contract requires.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let self_usize: usize = self.into();
        let other_usize: usize = other.into();
        match self_usize.cmp(&other_usize) {
            Ordering::Equal if self != other => None,
            ord => Some(ord),
        }
    }
}
/// An ordered set of migrations.
///
/// The migration at index `i` brings the database from schema version `i` to `i + 1`. The current
/// schema version is stored in the SQLite `user_version` pragma.
#[derive(Debug, PartialEq, Clone)]
pub struct Migrations<'m> {
    ms: Vec<M<'m>>,
}

impl<'m> Migrations<'m> {
    /// Create a set of migrations, applied in vector order.
    pub fn new(ms: Vec<M<'m>>) -> Self {
        Self { ms }
    }

    /// Create a set of migrations from any iterable of [`M`], applied in iteration order.
    pub fn new_iter<I: IntoIterator<Item = M<'m>>>(ms: I) -> Self {
        Self::new(ms.into_iter().collect())
    }

    /// Classify a raw `user_version` value against this set of migrations.
    fn db_version_to_schema(&self, db_version: usize) -> SchemaVersion {
        match NonZeroUsize::new(db_version) {
            None => SchemaVersion::NoneSet,
            Some(v) if db_version <= self.ms.len() => SchemaVersion::Inside(v),
            Some(v) => SchemaVersion::Outside(v),
        }
    }

    /// Read the current schema version of the database, classified against these migrations.
    ///
    /// # Errors
    ///
    /// Returns an error if the `user_version` pragma cannot be read.
    pub fn current_version(&self, conn: &Connection) -> Result<SchemaVersion> {
        let db_version = user_version(conn)?;
        Ok(self.db_version_to_schema(db_version))
    }

    /// Apply the `up` scripts of migrations `current_version..target_version` in one transaction.
    fn goto_up(
        &self,
        conn: &mut Connection,
        current_version: usize,
        target_version: usize,
    ) -> Result<()> {
        debug_assert!(current_version <= target_version);
        debug_assert!(target_version <= self.ms.len());

        trace!("start migration transaction");
        let tx = conn.transaction()?;
        for m in &self.ms[current_version..target_version] {
            debug!("Running: {}", m.up);
            tx.execute_batch(m.up)
                .map_err(|e| Error::with_sql(e, m.up))?;
        }
        set_user_version(&tx, target_version)?;
        tx.commit()?;
        trace!("committed migration transaction");
        Ok(())
    }

    /// Revert migrations `target_version..current_version`, newest first, in one transaction.
    ///
    /// Fails without touching the database if any of those migrations has no `down` script.
    fn goto_down(
        &self,
        conn: &mut Connection,
        current_version: usize,
        target_version: usize,
    ) -> Result<()> {
        debug_assert!(current_version >= target_version);
        debug_assert!(target_version <= self.ms.len());

        let to_revert = &self.ms[target_version..current_version];

        // Validate before opening the transaction, so a missing `down` is reported up front.
        if let Some((offset, bad_m)) = to_revert
            .iter()
            .enumerate()
            .find(|(_, m)| m.down.is_none())
        {
            warn!("Cannot revert: {:?}", bad_m);
            return Err(Error::MigrationDefinition(
                MigrationDefinitionError::DownNotDefined {
                    migration_index: target_version + offset,
                },
            ));
        }

        trace!("start migration transaction");
        let tx = conn.transaction()?;
        for m in to_revert.iter().rev() {
            let down = m
                .down
                .expect("every down migration was checked to be defined before the transaction");
            debug!("Running: {}", down);
            tx.execute_batch(down)
                .map_err(|e| Error::with_sql(e, down))?;
        }
        set_user_version(&tx, target_version)?;
        tx.commit()?;
        trace!("committed migration transaction");
        Ok(())
    }

    /// Move the database from its current `user_version` to `target_db_version`.
    ///
    /// The caller must ensure `target_db_version <= self.ms.len()`.
    fn goto(&self, conn: &mut Connection, target_db_version: usize) -> Result<()> {
        let current_version = user_version(conn)?;

        let res = match target_db_version.cmp(&current_version) {
            Ordering::Less => {
                // Reverting needs the `down` scripts of every applied migration, which are
                // unknown when the database is ahead of the defined migrations.
                if current_version > self.ms.len() {
                    return Err(Error::MigrationDefinition(
                        MigrationDefinitionError::DatabaseTooFarAhead,
                    ));
                }
                debug!(
                    "rollback to older version requested, target_db_version: {}, current_version: {}",
                    target_db_version, current_version
                );
                self.goto_down(conn, current_version, target_db_version)
            }
            Ordering::Equal => {
                debug!("no migration to run, db already up to date");
                return Ok(());
            }
            Ordering::Greater => {
                debug!(
                    "some migrations to run, target_db_version: {}, current_version: {}",
                    target_db_version, current_version
                );
                self.goto_up(conn, current_version, target_db_version)
            }
        };

        if res.is_ok() {
            info!("Database migrated to version {}", target_db_version);
        }
        res
    }

    /// Schema version reached once every migration has been applied.
    fn max_schema_version(&self) -> SchemaVersion {
        match NonZeroUsize::new(self.ms.len()) {
            None => SchemaVersion::NoneSet,
            Some(v) => SchemaVersion::Inside(v),
        }
    }

    /// Migrate the database to the latest schema version, applying all pending migrations.
    ///
    /// # Errors
    ///
    /// - [`MigrationDefinitionError::NoMigrationsDefined`] if the set is empty.
    /// - [`MigrationDefinitionError::DatabaseTooFarAhead`] if the database version is greater
    ///   than the number of migrations.
    /// - An SQL error if a migration or the pragma update fails (the transaction is not committed).
    pub fn to_latest(&self, conn: &mut Connection) -> Result<()> {
        let v_max = self.max_schema_version();
        match v_max {
            SchemaVersion::NoneSet => {
                warn!("no migration defined");
                Err(Error::MigrationDefinition(
                    MigrationDefinitionError::NoMigrationsDefined,
                ))
            }
            SchemaVersion::Inside(_) => {
                debug!("some migrations defined, try to migrate");
                self.goto(conn, v_max.into())
            }
            SchemaVersion::Outside(_) => unreachable!(),
        }
    }

    /// Migrate the database forward or backward to schema version `version`.
    ///
    /// `version` 0 reverts every migration.
    ///
    /// # Errors
    ///
    /// - [`MigrationDefinitionError::NoMigrationsDefined`] if the set is empty.
    /// - [`SchemaVersionError::TargetVersionOutOfRange`] if `version` is greater than the number
    ///   of migrations.
    /// - [`MigrationDefinitionError::DatabaseTooFarAhead`] when reverting a database whose version
    ///   is greater than the number of migrations.
    /// - [`MigrationDefinitionError::DownNotDefined`] when reverting a migration with no `down`.
    /// - An SQL error if a migration or the pragma update fails (the transaction is not committed).
    pub fn to_version(&self, conn: &mut Connection, version: usize) -> Result<()> {
        let target_version: SchemaVersion = self.db_version_to_schema(version);
        let v_max = self.max_schema_version();
        match v_max {
            SchemaVersion::NoneSet => {
                warn!("no migrations defined");
                Err(Error::MigrationDefinition(
                    MigrationDefinitionError::NoMigrationsDefined,
                ))
            }
            SchemaVersion::Inside(_) => {
                if target_version > v_max {
                    warn!("specified version is higher than the max supported version");
                    return Err(Error::SpecifiedSchemaVersion(
                        SchemaVersionError::TargetVersionOutOfRange {
                            specified: target_version,
                            highest: v_max,
                        },
                    ));
                }
                self.goto(conn, target_version.into())
            }
            SchemaVersion::Outside(_) => unreachable!(),
        }
    }

    /// Check that every `up` migration applies cleanly, by running them on an in-memory database.
    ///
    /// # Errors
    ///
    /// Returns the error of [`Migrations::to_latest`], or an error if the in-memory database
    /// cannot be opened.
    pub fn validate(&self) -> Result<()> {
        let mut conn = Connection::open_in_memory()?;
        self.to_latest(&mut conn)
    }
}
/// Read the raw `user_version` pragma of `conn`.
///
/// The value is read as an `i64` and converted with an `as` cast. A negative stored value therefore
/// wraps to a very large `usize`, which `Migrations` classifies as `SchemaVersion::Outside`.
fn user_version(conn: &Connection) -> Result<usize, rusqlite::Error> {
    // Explicitly typed empty parameter list, matching the signature accepted by the rusqlite
    // versions this crate supports.
    // NOTE(review): `allow(deprecated)` presumably keeps compatibility with older rusqlite
    // versions while silencing newer ones — confirm before switching call forms.
    let no_params: &[&dyn ToSql] = &[];
    #[allow(deprecated)]
    let raw: i64 = conn.query_row("PRAGMA user_version", no_params, |row| row.get(0))?;
    Ok(raw as usize)
}
fn set_user_version(conn: &Connection, v: usize) -> Result<()> {
trace!("set user version to: {}", v);
let v = v as u32;
conn.pragma_update(None, "user_version", &v)
.map_err(|e| Error::RusqliteError {
query: format!("PRAGMA user_version = {}; -- Approximate query", v),
err: e,
})
}