use rusqlite::{Connection, Error as SqlError, Transaction};
use thiserror::Error;
pub fn migrate(
conn: &mut Connection,
migrations: &[fn(&Transaction) -> Result<(), SqlError>],
) -> Result<usize, MigrationError> {
let current_version = get_user_version(conn)?;
let mut last_successful_version = current_version;
for (index, migration) in migrations.iter().enumerate() {
let migration_version = index + 1;
if migration_version > last_successful_version {
let tx = conn.transaction()?;
match migration(&tx) {
Ok(()) => {
set_user_version(&tx, migration_version)?;
tx.commit()?;
last_successful_version = migration_version;
}
Err(error) => {
tx.rollback()?;
return Err(MigrationError {
version: last_successful_version,
error: error,
});
}
}
}
}
Ok(last_successful_version)
}
/// Error returned by [`migrate`] when the database could not be brought fully up to date.
#[derive(Debug, Error)]
#[error("Error performing migration. Rolled back to version {version}. Error: {error}")]
pub struct MigrationError {
    /// Schema version reported as the point the database was rolled back to.
    /// It is `0` when no version information was available (see the `From<SqlError>` impl).
    pub version: usize,
    /// The underlying SQLite error that caused the migration run to stop.
    pub error: SqlError,
}
impl From<SqlError> for MigrationError {
fn from(error: SqlError) -> Self {
MigrationError { version: 0, error }
}
}
pub fn get_user_version(conn: &Connection) -> Result<usize, SqlError> {
let version: i32 = conn.pragma_query_value(None, "user_version", |row| row.get(0))?;
Ok(version as usize)
}
/// Records `version` as the database's schema version via the `user_version` pragma.
///
/// The pragma is issued on `tx`, so it belongs to the same transaction as whatever else the
/// caller runs on `tx`.
fn set_user_version(tx: &Transaction, version: usize) -> Result<(), SqlError> {
    tx.pragma_update(None, "user_version", version)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Signature shared by every migration under test.
    type Migration = fn(&Transaction) -> Result<(), SqlError>;

    /// Opens a fresh in-memory database, whose `user_version` starts at 0.
    fn create_test_db() -> Connection {
        Connection::open_in_memory().unwrap()
    }

    /// Returns whether a table named `name` exists in the main schema.
    fn table_exists(conn: &Connection, name: &str) -> bool {
        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type = 'table' AND name = ?1",
                rusqlite::params![name],
                |row| row.get(0),
            )
            .unwrap();
        count > 0
    }

    fn migration1(tx: &Transaction) -> Result<(), SqlError> {
        tx.execute("CREATE TABLE test (id INTEGER PRIMARY KEY)", [])?;
        Ok(())
    }

    fn migration2(tx: &Transaction) -> Result<(), SqlError> {
        tx.execute("ALTER TABLE test ADD COLUMN name TEXT", [])?;
        Ok(())
    }

    /// Fails immediately without touching the database.
    fn failing_migration(_tx: &Transaction) -> Result<(), SqlError> {
        Err(SqlError::SqliteFailure(
            rusqlite::ffi::Error::new(rusqlite::ffi::SQLITE_CONSTRAINT),
            Some("Test error".to_string()),
        ))
    }

    /// Writes to the database and then fails, so that a rollback is observable.
    fn partially_failing_migration(tx: &Transaction) -> Result<(), SqlError> {
        tx.execute("CREATE TABLE partial (id INTEGER PRIMARY KEY)", [])?;
        failing_migration(tx)
    }

    #[test]
    fn test_empty_migrations() {
        let mut conn = create_test_db();
        let migrations: &[Migration] = &[];
        let result = migrate(&mut conn, migrations).unwrap();
        assert_eq!(result, 0);
        assert_eq!(get_user_version(&conn).unwrap(), 0);
    }

    #[test]
    fn test_single_migration() {
        let mut conn = create_test_db();
        let migrations: &[Migration] = &[migration1];
        let result = migrate(&mut conn, migrations).unwrap();
        assert_eq!(result, 1);
        assert_eq!(get_user_version(&conn).unwrap(), 1);
    }

    #[test]
    fn test_multiple_migrations() {
        let mut conn = create_test_db();
        let migrations: &[Migration] = &[migration1, migration2];
        let result = migrate(&mut conn, migrations).unwrap();
        assert_eq!(result, 2);
        assert_eq!(get_user_version(&conn).unwrap(), 2);
    }

    #[test]
    fn test_migration_failure_rollback() {
        let mut conn = create_test_db();
        let migrations: &[Migration] = &[migration1, failing_migration];
        let error =
            migrate(&mut conn, migrations).expect_err("Migrate should have returned an error");
        assert_eq!(error.version, 1);
        assert_eq!(get_user_version(&conn).unwrap(), 1);
    }

    #[test]
    fn test_failed_migration_writes_are_rolled_back() {
        let mut conn = create_test_db();
        let migrations: &[Migration] = &[migration1, partially_failing_migration];
        let error =
            migrate(&mut conn, migrations).expect_err("Migrate should have returned an error");
        assert_eq!(error.version, 1);
        assert_eq!(get_user_version(&conn).unwrap(), 1);
        // Migration 1's table was committed...
        assert!(table_exists(&conn, "test"));
        // ...but the table created by the failed migration 2 was undone.
        assert!(!table_exists(&conn, "partial"));
    }

    #[test]
    fn test_idempotent_migrations() {
        let mut conn = create_test_db();
        let migrations: &[Migration] = &[migration1];
        let result1 = migrate(&mut conn, migrations).unwrap();
        assert_eq!(result1, 1);
        let result2 = migrate(&mut conn, migrations).unwrap();
        assert_eq!(result2, 1);
        assert_eq!(get_user_version(&conn).unwrap(), 1);
    }

    #[test]
    fn test_incremental_migrations() {
        let mut conn = create_test_db();
        let first: &[Migration] = &[migration1];
        assert_eq!(migrate(&mut conn, first).unwrap(), 1);
        // Re-running migration1 would fail because `test` already exists, so success here
        // also shows that already-applied migrations are skipped.
        let both: &[Migration] = &[migration1, migration2];
        assert_eq!(migrate(&mut conn, both).unwrap(), 2);
        assert_eq!(get_user_version(&conn).unwrap(), 2);
    }

    #[test]
    fn test_version_ahead_of_migrations_is_preserved() {
        let mut conn = create_test_db();
        let both: &[Migration] = &[migration1, migration2];
        migrate(&mut conn, both).unwrap();
        // A shorter migration list must neither run anything nor lower the version.
        let first: &[Migration] = &[migration1];
        assert_eq!(migrate(&mut conn, first).unwrap(), 2);
        assert_eq!(get_user_version(&conn).unwrap(), 2);
    }
}