teststack 0.1.0

Test utilities to run testcontainers
Documentation
use crate::TestContainer;

/// Configuration for the database name.
///
/// Selects which database inside the container the resulting connection
/// settings point at, and whether the library has to create it first.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DbName {
    /// A fresh database with a random name, created by the library
    /// (an underscore followed by a simple-format UUID v4).
    Random,
    /// A database with exactly this name, created by the library.
    Static(String),
    /// The database the image provides out of the box; nothing is created.
    Default,
}

/// Configuration for the database connection.
///
/// Connection settings for the database prepared inside a test container.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbConf {
    /// Full connection URL for the database (scheme, user, host, port and
    /// database name).
    pub url: String,
    /// Name of the database that `url` points at.
    pub db_name: String,
}

/// A test container for a database with the given configuration.
///
/// This is what the backend-specific `run` functions return (enabled by the
/// `postgres` / `mysql` features). Its `conf` field is a [`DbConf`] holding the
/// URL and name of the database prepared inside the container.
pub type DbContainer = TestContainer<DbConf>;

/// PostgreSQL backend (enabled by the `postgres` feature).
#[cfg(feature = "postgres")]
pub mod postgres {
    use crate::{Init, TestContainer};

    use super::*;
    use sqlx::PgPool;
    use testcontainers::core::ContainerPort;
    use testcontainers_modules::postgres::Postgres;

    /// Port the PostgreSQL server listens on inside the container.
    const PORT: u16 = 5432;

    /// Database used for [`DbName::Default`]. `run` also connects to it in order
    /// to create any other database.
    const DEFAULT_DB: &str = "postgres";

    /// Run a PostgreSQL container with the given database name.
    ///
    /// Starts the container, looks up the host port mapped to the server port,
    /// and prepares the database selected by `db_name`.
    ///
    /// # Panics
    ///
    /// Panics if the mapped host port cannot be determined or if the database
    /// cannot be prepared.
    pub async fn run(db_name: DbName) -> DbContainer {
        let container = crate::container(Postgres::default()).await;
        let port = container
            .get_host_port_ipv4(ContainerPort::Tcp(PORT))
            .await
            .expect("failed to get host port");
        let conf = super::setup_database::<sqlx::postgres::Postgres, _>(
            db_url, port, &db_name, DEFAULT_DB,
        )
        .await;
        DbContainer { container, conf }
    }

    /// Build a connection URL for `db_name` on the local host port `port`,
    /// using the `postgres`/`postgres` credentials.
    fn db_url(port: u16, db_name: &str) -> String {
        format!("postgres://postgres:postgres@127.0.0.1:{port}/{db_name}")
    }

    impl Init<PgPool> for TestContainer<DbConf> {
        /// Create a lazily-connecting pool for the container's database.
        ///
        /// NOTE(review): `self`, including the container handle, is dropped when
        /// this returns. Whether that stops the container depends on
        /// `crate::container`; confirm.
        async fn init(self) -> PgPool {
            // `connect_lazy` only fails if the URL cannot be parsed, and the URL is
            // produced by `db_url` above, so failure here is a bug.
            PgPool::connect_lazy(&self.conf.url)
                .expect("generated PostgreSQL connection URL should be valid")
        }
    }
}

/// MySQL backend (enabled by the `mysql` feature).
#[cfg(feature = "mysql")]
pub mod mysql {
    use super::*;
    use crate::{Init, TestContainer};
    use sqlx::MySqlPool;
    use testcontainers::core::ContainerPort;
    use testcontainers_modules::mysql::Mysql;

    /// Port the MySQL server listens on inside the container.
    const PORT: u16 = 3306;

    /// Database used for [`DbName::Default`]. `run` also connects to it in order
    /// to create any other database.
    const DEFAULT_DB: &str = "test";

    /// Run a MySql container with the given database name.
    ///
    /// Starts the container, looks up the host port mapped to the server port,
    /// and prepares the database selected by `db_name`.
    ///
    /// # Panics
    ///
    /// Panics if the mapped host port cannot be determined or if the database
    /// cannot be prepared.
    pub async fn run(db_name: DbName) -> DbContainer {
        let container = crate::container(Mysql::default()).await;
        let port = container
            .get_host_port_ipv4(ContainerPort::Tcp(PORT))
            .await
            .expect("failed to get host port");
        let conf =
            super::setup_database::<sqlx::mysql::MySql, _>(db_url, port, &db_name, DEFAULT_DB)
                .await;
        DbContainer { container, conf }
    }

    /// Build a connection URL for `db_name` on the local host port `port`,
    /// connecting as `root` without a password.
    fn db_url(port: u16, db_name: &str) -> String {
        format!("mysql://root@127.0.0.1:{port}/{db_name}")
    }

    impl Init<MySqlPool> for TestContainer<DbConf> {
        /// Create a lazily-connecting pool for the container's database.
        ///
        /// NOTE(review): `self`, including the container handle, is dropped when
        /// this returns. Whether that stops the container depends on
        /// `crate::container`; confirm.
        async fn init(self) -> MySqlPool {
            // `connect_lazy` only fails if the URL cannot be parsed, and the URL is
            // produced by `db_url` above, so failure here is a bug.
            MySqlPool::connect_lazy(&self.conf.url)
                .expect("generated MySQL connection URL should be valid")
        }
    }
}

/// Prepare the database a container should use and return its connection settings.
///
/// `db_url` turns a host port and a database name into a connection URL, and
/// `default` is the database the image provides on start-up. The URL of
/// `default` is exported as `DATABASE_URL`. That same URL is also used to
/// create the target database for [`DbName::Random`] and [`DbName::Static`].
/// [`DbName::Default`] simply reuses `default`.
#[cfg(any(feature = "mysql", feature = "postgres"))]
async fn setup_database<DB, F>(db_url: F, port: u16, db_name: &DbName, default: &str) -> DbConf
where
    DB: sqlx::Database,
    F: Fn(u16, &str) -> String,
    for<'c> &'c mut DB::Connection: sqlx::Executor<'c, Database = DB>,
{
    let admin_url = db_url(port, default);

    // `set_var` mutates process-global state and is not sound while other threads
    // may read or write the environment. The value written is always the URL of the
    // image's default database; it does NOT depend on `db_name`. It exists only so
    // that [`sqlx::test`] integration can find the server.
    unsafe {
        ::std::env::set_var("DATABASE_URL", &admin_url);
    }

    // Decide which database to use and whether it still has to be created.
    let (name, needs_create) = match db_name {
        DbName::Default => (default.to_owned(), false),
        DbName::Static(given) => (given.clone(), true),
        DbName::Random => (format!("_{}", uuid::Uuid::new_v4().simple()), true),
    };
    if needs_create {
        init_database::<DB>(&admin_url, &name).await;
    }

    DbConf {
        url: db_url(port, &name),
        db_name: name,
    }
}

/// Initialize the database with the given name.
///
/// Opens a single connection to `db_url` (not a pool) and issues
/// `CREATE DATABASE` for `db_name`, then closes the connection.
///
/// The name is sent as a *quoted* identifier so the database is created exactly
/// as given:
/// - PostgreSQL folds unquoted identifiers to lower case, so `MyDb` would be
///   created as `mydb` and a later connection to `/MyDb` would fail.
/// - Neither backend accepts characters such as `-` in an unquoted name.
///
/// # Panics
///
/// Panics if `db_url` is not a valid URL, if the connection cannot be
/// established, or if the `CREATE DATABASE` statement fails (for example
/// because the database already exists).
#[cfg(any(feature = "mysql", feature = "postgres"))]
async fn init_database<DB>(db_url: &str, db_name: &str)
where
    DB: sqlx::Database,
    for<'c> &'c mut DB::Connection: sqlx::Executor<'c, Database = DB>,
{
    use sqlx::{ConnectOptions, Connection, Executor};
    let mut conn = <DB::Connection as sqlx::Connection>::Options::from_url(
        &db_url.parse().expect("invalid database URL"),
    )
    .expect("invalid database connection options")
    .connect()
    .await
    .expect("failed to connect to test database");
    (&mut conn)
        .execute(format!("CREATE DATABASE {}", quote_identifier(db_url, db_name)).as_str())
        .await
        .expect("Failed to create database");
    // Best-effort graceful shutdown. The database already exists at this point,
    // so a failure to close cleanly is not worth aborting the test setup for.
    let _ = conn.close().await;
}

/// Quote `name` as an SQL identifier for the backend addressed by `db_url`.
///
/// `mysql://` / `mariadb://` URLs use backtick quoting. Every other scheme uses
/// ANSI double quotes, as PostgreSQL does. An embedded quote character is
/// escaped by doubling it, which both backends accept.
#[cfg(any(feature = "mysql", feature = "postgres"))]
fn quote_identifier(db_url: &str, name: &str) -> String {
    let scheme = db_url.split_once("://").map_or("", |(scheme, _)| scheme);
    let quote = if matches!(scheme, "mysql" | "mariadb") {
        '`'
    } else {
        '"'
    };
    let escaped = name.replace(quote, &format!("{quote}{quote}"));
    format!("{quote}{escaped}{quote}")
}