// migrio 1.1.0
//
// A drop-in database migration library for PostgreSQL.
// Documentation
use std::{
    collections::{hash_map::DefaultHasher, HashMap},
    hash::{Hash, Hasher},
};

use sha2::{Digest, Sha384};
use tokio_postgres::Client;

use crate::{errors::MigrationError, step::MigrateStep};

/// Derive the advisory-lock key for a database from its name
///
/// The name is fed through [`DefaultHasher`] and the 64-bit digest is
/// reinterpreted bit-for-bit as a signed integer, so the result may be
/// negative. The same name always yields the same key within one build.
///
/// Note: the standard library does not promise that `DefaultHasher` produces
/// identical output across Rust releases, so binaries built with different
/// toolchains are not guaranteed to agree on the key.
fn generate_lock_id(db_name: &str) -> i64 {
    let mut state = DefaultHasher::new();
    Hash::hash(db_name, &mut state);
    // Pure bit reinterpretation of the u64 digest; no value is lost.
    i64::from_ne_bytes(state.finish().to_ne_bytes())
}

/// Ask the server for the name of the database this connection is using
///
/// Runs `SELECT current_database()` and returns the first column of the
/// single resulting row as a `String`.
///
/// # Errors
///
/// Returns an error if the query fails.
async fn current_database(client: &Client) -> Result<String, MigrationError> {
    let name = client
        .query_one("SELECT current_database()", &[])
        .await?
        .get::<_, String>(0);
    Ok(name)
}

/// Acquire the global advisory lock that guards migrations on this database
///
/// The lock key is derived from the current database name via
/// `generate_lock_id` and passed to `pg_advisory_lock`. The rows returned by
/// the query are discarded.
///
/// # Errors
///
/// Returns an error if the database name cannot be read or the lock query
/// fails.
async fn lock_database(client: &Client) -> Result<(), MigrationError> {
    let key = generate_lock_id(&current_database(client).await?);
    let _ = client
        .query("SELECT pg_advisory_lock($1)", &[&key])
        .await?;
    Ok(())
}

/// Release the global advisory lock taken by `lock_database`
///
/// The key is recomputed from the current database name and passed to
/// `pg_advisory_unlock`. The boolean reported by that function is returned
/// to the caller unchanged.
///
/// # Errors
///
/// Returns an error if the database name cannot be read or the unlock query
/// fails.
async fn unlock_database(client: &Client) -> Result<bool, MigrationError> {
    let key = generate_lock_id(&current_database(client).await?);
    let released: bool = client
        .query_one("SELECT pg_advisory_unlock($1)", &[&key])
        .await?
        .get(0);
    Ok(released)
}

/// Ensure the `_migrio_changelog` bookkeeping table exists
///
/// Issues a `CREATE TABLE IF NOT EXISTS` statement, so calling this on a
/// database that already has the table is a no-op. The table stores one row
/// per applied migration: its version, description, execution timestamp,
/// checksum and elapsed time.
///
/// # Errors
///
/// Returns an error if the statement fails.
async fn set_migration_table(client: &Client) -> Result<(), MigrationError> {
    let ddl = r"
CREATE TABLE IF NOT EXISTS _migrio_changelog (
  version BIGINT PRIMARY KEY,
  description TEXT NOT NULL,
  date_executed TIMESTAMPTZ NOT NULL DEFAULT now(),
  checksum BYTEA NOT NULL,
  elapsed_time BIGINT NOT NULL
);
            ";
    let _ = client.query(ddl, &[]).await?;
    Ok(())
}

/// A single migration file loaded from disk by `Migration::new`
struct MigrationEntry {
    /// Version parsed from the file-name prefix before the first `_`
    version: i64,
    /// Remainder of the file name with the step extension trimmed and
    /// underscores replaced by spaces
    description: String,
    /// Migration step derived from the file name via
    /// `MigrateStep::from_filename`
    step: MigrateStep,
    /// Full text of the SQL file
    sql: String,
    /// SHA-384 digest of `sql`
    checksum: Vec<u8>,
}

/// Migration object to handle the migration process
///
/// Holds every migration file discovered by [`Migration::new`] and exposes
/// [`Migration::run`] and [`Migration::undo`] to apply or revert them.
pub struct Migration {
    /// Loaded migration files, sorted by ascending version
    migrations: Vec<MigrationEntry>,
}

impl Migration {
    /// Create a new migration from a directory of SQL files
    ///
    /// The input path must be a directory containing SQL files. Non-SQL files
    /// will be ignored. The SQL file names must have the following format:
    ///
    /// - `{VERSION}_{DESCRIPTION}.sql` for basic migrations
    /// - `{VERSION}_{DESCRIPTION}.{up|down}.sql` for reversible migrations
    ///
    /// `{VERSION}` must be parsable as an `i64` and represents the order of the
    /// migration. `{DESCRIPTION}` is a human-readable description of the
    /// migration.
    ///
    /// The `up` or `down` suffix is optional and represents the direction of a
    /// reversible migration. Each `up` migration must have a corresponding
    /// `down` migration. The `up` migration will be applied and the `down`
    /// migration can be used to revert the migration.
    ///
    /// Every file is read eagerly and its SHA-384 checksum is computed; the
    /// resulting entries are sorted by ascending version.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use migrio::Migration;
    ///
    /// let migration = Migration::new("migrations").unwrap();
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error if the directory or the files are invalid.
    pub fn new<P>(path: P) -> Result<Self, MigrationError>
    where
        P: AsRef<std::path::Path>,
    {
        let mut migrations = Vec::new();

        for res in std::fs::read_dir(path.as_ref())? {
            let entry = res?;
            let path = entry.path();

            // Only regular files with a `.sql` extension are migrations.
            if !path.metadata()?.is_file() {
                continue;
            }

            if path.extension().map_or(true, |ext| ext != "sql") {
                continue;
            }

            // The error values are built lazily so that the common, valid
            // case does not allocate a message that is thrown away.
            let file_name = entry.file_name();
            let file_name = file_name.to_str().ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("Invalid file name: {file_name:?}"),
                )
            })?;

            let (prefix, suffix) = file_name.split_once('_').ok_or_else(|| {
                std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("File name must contain version and description: {file_name:?}"),
                )
            })?;

            let version = prefix
                .parse::<i64>()
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
            let step = MigrateStep::from_filename(suffix);
            // Remove the extension exactly once; `trim_end_matches` would
            // also eat repeated occurrences (e.g. `notes.sql.sql`).
            let description = suffix
                .strip_suffix(step.extension())
                .unwrap_or(suffix)
                .replace('_', " ");
            let sql = std::fs::read_to_string(&path)?;
            let checksum = Vec::from(Sha384::digest(sql.as_bytes()).as_slice());

            migrations.push(MigrationEntry {
                version,
                description,
                step,
                sql,
                checksum,
            });
        }

        migrations.sort_by_key(|m| m.version);

        Ok(Self { migrations })
    }

    /// Get the list of migrations that have been applied as a `HashMap`
    ///
    /// The map is keyed by migration version and holds the checksum recorded
    /// in `_migrio_changelog` when that version was applied.
    async fn migrated(&self, client: &Client) -> Result<HashMap<i64, Vec<u8>>, MigrationError> {
        let rows = client
            .query(
                "SELECT version, checksum FROM _migrio_changelog ORDER BY version",
                &[],
            )
            .await?
            .into_iter()
            .map(|row| {
                let version: i64 = row.get(0);
                let checksum: Vec<u8> = row.get(1);
                (version, checksum)
            })
            .collect();
        Ok(rows)
    }

    /// Apply a specific migration and record it in the changelog
    ///
    /// The migration SQL and the changelog bookkeeping run inside a single
    /// transaction. The changelog row is first inserted with a placeholder
    /// `elapsed_time` of `-1` and then updated with the measured duration in
    /// nanoseconds before the transaction is committed.
    async fn apply(
        &self,
        client: &mut Client,
        migration: &MigrationEntry,
    ) -> Result<(), MigrationError> {
        let start = std::time::Instant::now();

        let transaction = client.transaction().await?;
        let () = transaction.batch_execute(&migration.sql).await?;
        let _ = transaction
            .query(
                r"
INSERT INTO _migrio_changelog (version, description, checksum, elapsed_time)
VALUES ($1, $2, $3, -1)
            ",
                &[
                    &migration.version,
                    &migration.description,
                    &migration.checksum,
                ],
            )
            .await?;

        // `as_nanos` yields a `u128`; saturate instead of silently wrapping
        // when narrowing to the `BIGINT` column type.
        let elapsed = i64::try_from(start.elapsed().as_nanos()).unwrap_or(i64::MAX);

        let _ = transaction
            .query(
                r"
UPDATE _migrio_changelog
SET elapsed_time = $1
WHERE version = $2
            ",
                &[&elapsed, &migration.version],
            )
            .await?;

        transaction.commit().await?;

        Ok(())
    }

    /// Revert a specific migration and remove it from the changelog
    ///
    /// The `down` SQL and the changelog deletion run inside a single
    /// transaction.
    async fn revert(
        &self,
        client: &mut Client,
        migration: &MigrationEntry,
    ) -> Result<(), MigrationError> {
        let transaction = client.transaction().await?;
        let () = transaction.batch_execute(&migration.sql).await?;
        let _ = transaction
            .query(
                r"
DELETE FROM _migrio_changelog WHERE version = $1
            ",
                &[&migration.version],
            )
            .await?;

        transaction.commit().await?;

        Ok(())
    }

    /// Run pending migrations against the database
    ///
    /// This method runs all pending migrations that have not been applied to
    /// the database. Migrations are applied in order based on the version
    /// number and `down` migrations are skipped. The migration process records
    /// each migration in the changelog table to prevent reapplying migrations.
    /// The database is locked during the migration process to prevent issues.
    /// The lock is released again even when a migration fails.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use migrio::Migration;
    /// # use migrio::MigrationError;
    /// use tokio_postgres::NoTls;
    ///
    /// # async fn run_migrations() -> Result<(), MigrationError> {
    /// let (mut client, connection) = tokio_postgres::connect("host=localhost user=postgres", NoTls).await?;
    /// // Handle `connection` by potentially spawning it on a runtime
    /// let migration = Migration::new("migrations")?;
    /// migration.run(&mut client).await?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error if a migration checksum does not match or if there is
    /// an issue applying the migration.
    pub async fn run(&self, client: &mut Client) -> Result<(), MigrationError> {
        lock_database(client).await?;
        let outcome = self.run_locked(client).await;
        // The advisory lock is held by the session, not the transaction, so
        // it must be released explicitly on the failure path as well;
        // otherwise it stays held for as long as the connection lives.
        let released = unlock_database(client).await;
        // Report the migration failure first; an unlock failure only
        // surfaces when the migrations themselves succeeded.
        outcome?;
        released?;
        Ok(())
    }

    /// Body of [`Migration::run`] executed while the database lock is held
    async fn run_locked(&self, client: &mut Client) -> Result<(), MigrationError> {
        set_migration_table(client).await?;
        let migrated = self.migrated(client).await?;
        for migration in self
            .migrations
            .iter()
            .filter(|m| !m.step.is_down_migration())
        {
            match migrated.get(&migration.version) {
                // Already applied: only verify that the file is unchanged.
                Some(checksum) => {
                    if checksum != &migration.checksum {
                        let version = migration.version;
                        return Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidData,
                            format!("Migration checksum mismatch for version {version}"),
                        )
                        .into());
                    }
                }
                None => {
                    self.apply(client, migration).await?;
                }
            }
        }
        Ok(())
    }

    /// Undo applied migrations up to a specific version
    ///
    /// This method reverts all applied migrations up to a specific version.
    /// Reverts are executed in reverse order, starting from the latest applied
    /// migration and running all `down` migrations up to the specified version.
    /// The given version itself is kept; only newer versions are reverted.
    /// The database is locked during the migration process to prevent issues.
    /// The lock is released again even when a revert fails.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use migrio::Migration;
    /// # use migrio::MigrationError;
    /// use tokio_postgres::NoTls;
    ///
    /// # async fn undo_migrations() -> Result<(), MigrationError> {
    /// let (mut client, connection) = tokio_postgres::connect("host=localhost user=postgres", NoTls).await?;
    /// // Handle `connection` by potentially spawning it on a runtime
    /// let migration = Migration::new("migrations")?;
    /// migration.undo(&mut client, 1).await?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    ///
    /// Returns an error if there is an issue reverting a migration.
    pub async fn undo(&self, client: &mut Client, version: i64) -> Result<(), MigrationError> {
        lock_database(client).await?;
        let outcome = self.undo_locked(client, version).await;
        // Release the session-level lock on both the success and the failure
        // path so that a failed revert cannot leave the database locked.
        let released = unlock_database(client).await;
        outcome?;
        released?;
        Ok(())
    }

    /// Body of [`Migration::undo`] executed while the database lock is held
    async fn undo_locked(&self, client: &mut Client, version: i64) -> Result<(), MigrationError> {
        set_migration_table(client).await?;
        let migrated = self.migrated(client).await?;
        // Newest first: only applied `down` steps strictly above `version`.
        for migration in self.migrations.iter().rev().filter(|m| {
            m.step.is_down_migration() && migrated.contains_key(&m.version) && m.version > version
        }) {
            self.revert(client, migration).await?;
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use tokio_postgres::NoTls;

    use super::*;

    /// The lock id derived from a fixed name is expected to be non-zero.
    #[test]
    fn test_generate_lock_id() {
        assert_ne!(generate_lock_id("test"), 0);
    }

    /// Connect to the local test server and drive the connection on a
    /// background task.
    async fn setup_connect() -> Client {
        let (client, connection) =
            tokio_postgres::connect("postgresql://postgres:postgres@127.0.0.1/postgres", NoTls)
                .await
                .unwrap();
        tokio::spawn(async move { connection.await.unwrap() });
        client
    }

    /// Overwrite the stored checksum of version 2 so that it no longer
    /// matches the migration file.
    async fn corrupt_changelog(client: &Client) {
        let statement = "UPDATE _migrio_changelog SET checksum = '\\000' WHERE version = 2";
        let _ = client.query(statement, &[]).await.unwrap();
    }

    #[ignore = "requires a database connection"]
    #[tokio::test]
    async fn test_current_database() {
        let client = setup_connect().await;
        assert_eq!(current_database(&client).await.unwrap(), "postgres");
    }

    #[ignore = "requires a database connection"]
    #[tokio::test]
    async fn test_lock_unlock_database() {
        let client = setup_connect().await;
        lock_database(&client).await.unwrap();
        assert!(unlock_database(&client).await.unwrap());
    }

    #[test]
    fn test_migration_new() {
        let loaded = Migration::new("tests/migrations").unwrap();
        assert_eq!(loaded.migrations.len(), 3);
    }

    #[ignore = "requires a database connection"]
    #[tokio::test]
    async fn test_migration_run_undo() {
        let mut client = setup_connect().await;
        let migration = Migration::new("tests/migrations").unwrap();

        // First run applies everything.
        migration.run(&mut client).await.unwrap();

        // A tampered checksum must make the next run fail.
        corrupt_changelog(&client).await;
        let rerun = migration.run(&mut client).await;
        assert!(rerun.is_err());

        // Revert everything above version 1.
        migration.undo(&mut client, 1).await.unwrap();
    }
}