use crate::TestContainer;
/// Selects which database a test container hands out.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DbName {
    /// Create a fresh database with a unique, randomly generated name.
    Random,
    /// Create a database with exactly the given name.
    Static(String),
    /// Use the database the container image already provides; nothing is created.
    Default,
}
/// Connection details for the database prepared inside a test container.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbConf {
    /// Connection URL pointing at the selected database.
    pub url: String,
    /// Name of the selected database (the last path segment of `url`).
    pub db_name: String,
}
pub type DbContainer = TestContainer<DbConf>;
/// PostgreSQL test-container support (enabled by the `postgres` feature).
#[cfg(feature = "postgres")]
pub mod postgres {
    use super::*;
    use crate::{Init, TestContainer};
    use sqlx::PgPool;
    use testcontainers::core::ContainerPort;
    use testcontainers_modules::postgres::Postgres;

    /// Port PostgreSQL listens on inside the container.
    const INTERNAL_PORT: u16 = 5432;

    /// Database used for [`DbName::Default`] and connected through when creating others.
    const DEFAULT_DB: &str = "postgres";

    /// Starts a PostgreSQL container and prepares the database selected by `db_name`.
    ///
    /// # Panics
    ///
    /// Panics if the host port mapped to the container's PostgreSQL port cannot be
    /// obtained, or if preparing the database fails.
    pub async fn run(db_name: DbName) -> DbContainer {
        let container = crate::container(Postgres::default()).await;
        let port = container
            .get_host_port_ipv4(ContainerPort::Tcp(INTERNAL_PORT))
            .await
            .expect("failed to get host port");
        let conf = super::setup_database::<sqlx::postgres::Postgres, _>(
            db_url, port, &db_name, DEFAULT_DB,
        )
        .await;
        DbContainer { container, conf }
    }

    /// Builds a connection URL for `db_name` on `127.0.0.1:port` using the
    /// `postgres`/`postgres` credentials.
    fn db_url(port: u16, db_name: &str) -> String {
        format!("postgres://postgres:postgres@127.0.0.1:{port}/{db_name}")
    }

    impl Init<PgPool> for TestContainer<DbConf> {
        /// Builds a lazily connecting pool for `self.conf.url`.
        ///
        /// NOTE(review): `self`, including its container handle, is dropped when this
        /// returns. If dropping that handle stops the container, the returned pool targets
        /// a server that is gone. Confirm against `TestContainer` / `crate::container`.
        ///
        /// # Panics
        ///
        /// Panics if the stored URL cannot be turned into connect options.
        async fn init(self) -> PgPool {
            PgPool::connect_lazy(&self.conf.url).expect("invalid PostgreSQL connection URL")
        }
    }
}
/// MySQL test-container support (enabled by the `mysql` feature).
#[cfg(feature = "mysql")]
pub mod mysql {
    use super::*;
    use crate::{Init, TestContainer};
    use sqlx::MySqlPool;
    use testcontainers::core::ContainerPort;
    use testcontainers_modules::mysql::Mysql;

    /// Port MySQL listens on inside the container.
    const INTERNAL_PORT: u16 = 3306;

    /// Database used for [`DbName::Default`] and connected through when creating others.
    const DEFAULT_DB: &str = "test";

    /// Starts a MySQL container and prepares the database selected by `db_name`.
    ///
    /// # Panics
    ///
    /// Panics if the host port mapped to the container's MySQL port cannot be obtained,
    /// or if preparing the database fails.
    pub async fn run(db_name: DbName) -> DbContainer {
        let container = crate::container(Mysql::default()).await;
        let port = container
            .get_host_port_ipv4(ContainerPort::Tcp(INTERNAL_PORT))
            .await
            .expect("failed to get host port");
        let conf =
            super::setup_database::<sqlx::mysql::MySql, _>(db_url, port, &db_name, DEFAULT_DB)
                .await;
        DbContainer { container, conf }
    }

    /// Builds a connection URL for `db_name` on `127.0.0.1:port` as user `root` with no
    /// password.
    fn db_url(port: u16, db_name: &str) -> String {
        format!("mysql://root@127.0.0.1:{port}/{db_name}")
    }

    impl Init<MySqlPool> for TestContainer<DbConf> {
        /// Builds a lazily connecting pool for `self.conf.url`.
        ///
        /// NOTE(review): `self`, including its container handle, is dropped when this
        /// returns. If dropping that handle stops the container, the returned pool targets
        /// a server that is gone. Confirm against `TestContainer` / `crate::container`.
        ///
        /// # Panics
        ///
        /// Panics if the stored URL cannot be turned into connect options.
        async fn init(self) -> MySqlPool {
            MySqlPool::connect_lazy(&self.conf.url).expect("invalid MySQL connection URL")
        }
    }
}
/// Prepares the database selected by `db_name` on the server reachable at `port`.
///
/// `db_url` turns a port and a database name into a connection URL. `default` names the
/// database that is connected through; it is never created here. `DATABASE_URL` is set to
/// the URL of that `default` database.
///
/// - [`DbName::Random`]: creates a database named `_` followed by a fresh simple-format UUID.
/// - [`DbName::Static`]: creates a database with the given name.
/// - [`DbName::Default`]: uses `default` unchanged and creates nothing.
///
/// Returns the URL and the name of the selected database.
///
/// # Panics
///
/// Panics if creating the database fails.
#[cfg(any(feature = "mysql", feature = "postgres"))]
async fn setup_database<DB, F>(db_url: F, port: u16, db_name: &DbName, default: &str) -> DbConf
where
    DB: sqlx::Database,
    F: Fn(u16, &str) -> String,
    for<'c> &'c mut DB::Connection: sqlx::Executor<'c, Database = DB>,
{
    let admin_url = db_url(port, default);

    // SAFETY: `set_var` is only sound while no other thread reads or writes the process
    // environment. NOTE(review): nothing here enforces that, and tests may run on several
    // threads at once. Confirm the callers uphold it.
    unsafe { ::std::env::set_var("DATABASE_URL", &admin_url) };

    // Decide the target name, and whether it still has to be created on the server.
    let (target, must_create) = match db_name {
        DbName::Default => (default.to_string(), false),
        DbName::Static(name) => (name.to_string(), true),
        DbName::Random => (format!("_{}", uuid::Uuid::new_v4().simple()), true),
    };
    if must_create {
        init_database::<DB>(&admin_url, &target).await;
    }

    DbConf {
        url: db_url(port, &target),
        db_name: target,
    }
}
/// Creates the database `db_name` on the server reachable through `db_url`.
///
/// `db_url` must point at a database that already exists on that server. A one-off
/// connection is opened through it, `CREATE DATABASE` is issued with `db_name` quoted as
/// an identifier, and the connection is then closed.
///
/// # Panics
///
/// Panics if:
/// - `db_url` is not a valid URL;
/// - the connection cannot be established;
/// - the `CREATE DATABASE` statement fails (for example, because the database already
///   exists).
#[cfg(any(feature = "mysql", feature = "postgres"))]
async fn init_database<DB>(db_url: &str, db_name: &str)
where
    DB: sqlx::Database,
    for<'c> &'c mut DB::Connection: sqlx::Executor<'c, Database = DB>,
{
    use sqlx::{ConnectOptions, Connection, Executor};

    let url = db_url.parse().expect("invalid test database URL");
    let mut conn = <DB::Connection as Connection>::Options::from_url(&url)
        .expect("failed to build connect options from test database URL")
        .connect()
        .await
        .expect("failed to connect to test database");

    let statement = format!("CREATE DATABASE {}", quote_identifier(db_url, db_name));
    (&mut conn)
        .execute(statement.as_str())
        .await
        .expect("Failed to create database");

    // Dropping a connection only releases local resources. `close` also tells the server
    // the session is over. The database already exists at this point, so a failure here
    // is deliberately ignored rather than failing the caller.
    let _ = conn.close().await;
}

/// Quotes `name` as an SQL identifier for the server behind `db_url`.
///
/// PostgreSQL URLs get double quotes. Unquoted identifiers are folded to lowercase there,
/// so a mixed-case name would not match the case-preserving name placed in the connection
/// URL. Every other server (MySQL here) gets backticks. An embedded quote character is
/// escaped by doubling it.
#[cfg(any(feature = "mysql", feature = "postgres"))]
fn quote_identifier(db_url: &str, name: &str) -> String {
    let quote = if db_url.starts_with("postgres") { '"' } else { '`' };
    let escaped = name.replace(quote, &format!("{quote}{quote}"));
    format!("{quote}{escaped}{quote}")
}