midas 0.7.6

Do painless migrations 🦀
Documentation
use anyhow::Context;
use indoc::indoc;
use postgres::tls::{
  MakeTlsConnect,
  TlsConnect,
};
use postgres::{
  Client,
  NoTls,
  Socket,
};
use url::Url;

use super::{
  AnyhowResult,
  Driver as SequelDriver,
  VecSerial,
};

/// A migration driver backed by a PostgreSQL connection.
pub struct Postgres {
  /// Name of the connected database, extracted from the connection URL's
  /// final path segment when the driver is constructed.
  database_name: String,
  /// Client connection through which every driver query is issued.
  client: Client,
}

/// Implement the Postgres struct
impl Postgres {
  /// Create a new instance of Postgres
  pub fn new(database_url: &str) -> AnyhowResult<Self> {
    Self::new_tls(database_url, NoTls)
  }

  /// Create a new instance of Postgres with TLS
  pub fn new_tls<T>(database_url: &str, tls_mode: T) -> AnyhowResult<Self>
  where
    T: MakeTlsConnect<Socket> + 'static + Send,
    T::TlsConnect: Send,
    T::Stream: Send,
    <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
  {
    // Get the database name from the URL
    let url = Url::parse(database_url)?;
    let database_name = url
      .path_segments()
      .and_then(|s| s.last())
      .context("Database name not found")?;

    // Open the connection
    let client = Client::connect(url.as_str(), tls_mode)?;

    // Create a new instance of Postgres
    let mut db = Postgres {
      client,
      database_name: database_name.into(),
    };

    // Ensure the midas schema
    db.ensure_midas_schema()?;
    Ok(db)
  }
}

impl SequelDriver for Postgres {
  /// Create the `midas` schema and its `__schema_migrations` bookkeeping
  /// table if they do not exist yet, and grant all privileges on the schema
  /// to `public`.
  fn ensure_midas_schema(&mut self) -> AnyhowResult<()> {
    self
      .client
      .execute("create schema if not exists midas", &[])
      .context("Failed to create midas schema")?;
    self
      .client
      .execute("grant all on schema midas to public", &[])
      .context("Failed to grant privileges on midas schema")?;
    // NOTE(review): `migration` is nullable; the readers below report a NULL
    // as an error instead of panicking.
    let payload = indoc! {"
      create table if not exists midas.__schema_migrations (
        id bigint generated by default as identity primary key,
        migration bigint
      )
    "};
    self
      .client
      .execute(payload, &[])
      .context("Failed to create migrations table")?;
    Ok(())
  }

  /// Drop the migrations bookkeeping table (errors if it does not exist,
  /// since the statement has no `if exists`).
  fn drop_migration_table(&mut self) -> AnyhowResult<()> {
    let payload = "drop table midas.__schema_migrations";
    self
      .client
      .execute(payload, &[])
      .context("Failed to drop migrations table")?;
    Ok(())
  }

  /// Drop database `db_name` if it exists, then create it again empty.
  ///
  /// `db_name` is quoted as an SQL identifier, so it is used verbatim
  /// (case and characters such as `-` are preserved) and cannot inject SQL.
  fn drop_database(&mut self, db_name: &str) -> AnyhowResult<()> {
    let quoted = quote_identifier(db_name);

    let payload = format!("drop database if exists {quoted}");
    self
      .client
      .execute(&payload, &[])
      .with_context(|| format!("Failed to drop database {db_name}"))?;

    let payload = format!("create database {quoted}");
    self
      .client
      .execute(&payload, &[])
      .with_context(|| format!("Failed to create database {db_name}"))?;
    Ok(())
  }

  /// Return the number of rows in the migrations table.
  fn count_migrations(&mut self) -> AnyhowResult<i64> {
    log::trace!("Retrieving migrations count");
    let payload = "select count(*) as count from midas.__schema_migrations";
    let row = self
      .client
      .query_one(payload, &[])
      .context("Failed to count migrations")?;
    // `count(*)` is never NULL, so the infallible getter is safe here.
    Ok(row.get::<_, i64>(0))
  }

  /// Return every recorded migration number, in insertion (`id`) order.
  ///
  /// # Errors
  ///
  /// Fails if the query fails or a stored `migration` value cannot be read
  /// (e.g. it is NULL), rather than panicking.
  fn get_completed_migrations(&mut self) -> AnyhowResult<VecSerial> {
    log::trace!("Retrieving all completed migrations");
    let payload = "select migration from midas.__schema_migrations order by id asc";
    let rows = self
      .client
      .query(payload, &[])
      .context("Failed to get completed migrations")?;
    let result = rows
      .iter()
      .map(|row| row.try_get("migration"))
      .collect::<Result<VecSerial, _>>()
      .context("Failed to read completed migrations")?;
    Ok(result)
  }

  /// Return the most recently recorded migration number, or `-1` when the
  /// migrations table is empty.
  fn get_last_completed_migration(&mut self) -> AnyhowResult<i64> {
    log::trace!("Checking and retrieving the last migration stored on migrations table");
    let payload = "select migration from midas.__schema_migrations order by id desc limit 1";
    let row = self
      .client
      .query_opt(payload, &[])
      .context("Failed to get last completed migration")?;

    let last = match row {
      Some(row) => row
        .try_get(0)
        .context("Failed to read last completed migration")?,
      None => -1,
    };
    Ok(last)
  }

  /// Record `migration_number` as completed.
  fn add_completed_migration(&mut self, migration_number: i64) -> AnyhowResult<()> {
    log::trace!("Adding migration to migrations table");
    let payload = "insert into midas.__schema_migrations (migration) values ($1)";
    self
      .client
      .execute(payload, &[&migration_number])
      .context("Failed to add completed migration")?;
    Ok(())
  }

  /// Remove every row recording `migration_number`.
  fn delete_completed_migration(&mut self, migration_number: i64) -> AnyhowResult<()> {
    log::trace!("Removing a migration in the migrations table");
    let payload = "delete from midas.__schema_migrations where migration = $1";
    self
      .client
      .execute(payload, &[&migration_number])
      .context("Failed to delete completed migration")?;
    Ok(())
  }

  /// Remove the most recently inserted row (highest `id`) from the
  /// migrations table; a no-op when the table is empty.
  fn delete_last_completed_migration(&mut self) -> AnyhowResult<()> {
    log::trace!("Removing the last migration in the migrations table");
    let payload =
      "delete from midas.__schema_migrations where id = (select max(id) from midas.__schema_migrations)";
    self
      .client
      .execute(payload, &[])
      .context("Failed to delete last completed migration")?;
    Ok(())
  }

  /// Execute the migration script `query` via the simple query protocol,
  /// which permits multiple `;`-separated statements in one call.
  fn migrate(&mut self, query: &str, migration_number: i64) -> AnyhowResult<()> {
    self
      .client
      .simple_query(query)
      .with_context(|| format!("Failed to execute migration - {migration_number}"))?;
    Ok(())
  }

  /// Name of the database this driver is connected to.
  fn db_name(&self) -> &str {
    &self.database_name
  }
}

/// Quote `name` as a PostgreSQL identifier: wrap it in double quotes and
/// double any embedded double quote, so the name is taken literally.
fn quote_identifier(name: &str) -> String {
  format!("\"{}\"", name.replace('"', "\"\""))
}