use bevy::prelude::App;
use bevy_persistence_database::PersistencePluginCore;
use bevy_persistence_database::persistence_plugin::PersistencePluginConfig;
use bevy_persistence_database::{ ArangoDbConnection, DatabaseConnection };
#[cfg(feature = "postgres")]
use bevy_persistence_database::PostgresDbConnection;
use std::sync::Arc;
use std::sync::{OnceLock, atomic::{AtomicUsize, Ordering}, Mutex};
use testcontainers::{core::WaitFor, runners::AsyncRunner, ContainerAsync, GenericImage, ImageExt};
use tokio::runtime::Runtime;
/// Tokio runtime shared by the helpers in this file. It is created lazily by
/// whichever helper calls `get_or_init` on it first and is then reused.
static TEST_RT: OnceLock<Arc<Runtime>> = OnceLock::new();
/// State describing the ArangoDB test container stored in the `GLOBAL` static.
struct GlobalContainerState {
/// Runtime used to block on async database calls made against this container.
rt: Arc<Runtime>,
/// Handle to the started container; teardown empties it with `take()`.
container: Mutex<Option<ContainerAsync<GenericImage>>>,
/// Container id handed to `docker rm` during teardown.
container_id: String,
/// Base HTTP URL of the container (`http://127.0.0.1:<mapped port>`).
base_url: String,
}
/// Removes the ArangoDB container with `docker rm -f -v` when the state is dropped,
/// unless `bevy_persistence_database_KEEP_CONTAINER` is set to `1` or `true`
/// (case-insensitive), in which case the container is left running.
impl Drop for GlobalContainerState {
    fn drop(&mut self) {
        let keep_requested = match std::env::var("bevy_persistence_database_KEEP_CONTAINER") {
            Ok(value) => value == "1" || value.eq_ignore_ascii_case("true"),
            Err(_) => false,
        };

        // `&mut self` gives exclusive access, so the slot can be reached without locking.
        let slot = self.container.get_mut().unwrap();

        if keep_requested {
            if slot.is_some() {
                eprintln!("[bevy_persistence_database tests] bevy_persistence_database_KEEP_CONTAINER=1 set; leaving ArangoDB container running at {}", self.base_url);
            }
            return;
        }

        if slot.is_none() {
            return;
        }

        // Best effort: the exit status of `docker rm` is intentionally ignored.
        let _ = std::process::Command::new("docker")
            .args(["rm", "-f", "-v", self.container_id.as_str()])
            .status();

        // Forget the handle instead of dropping it so its destructor never runs
        // (presumably to stop testcontainers from removing the container a second time).
        std::mem::forget(slot.take());
    }
}
/// Lazily initialised state of the shared ArangoDB container (filled in by `ensure_global`).
static GLOBAL: OnceLock<GlobalContainerState> = OnceLock::new();
/// Monotonic counter used to give every per-test database a unique `test_db_<n>` name.
static DB_COUNTER: AtomicUsize = AtomicUsize::new(0);
/// Starts an `arangodb:3.12.5` container and waits until it logs
/// "is ready for business" on stdout.
///
/// Returns `(container handle, base URL, container id)`.
///
/// # Panics
/// Panics if the container cannot be started or the host port mapped to 8529
/// cannot be resolved.
async fn start_container() -> (ContainerAsync<GenericImage>, String, String) {
    let request = GenericImage::new("arangodb", "3.12.5")
        .with_wait_for(WaitFor::message_on_stdout("is ready for business"))
        .with_env_var("ARANGO_ROOT_PASSWORD", "password");

    let arango = request
        .start()
        .await
        .expect("Failed to start ArangoDB container");

    let mapped_port = arango.get_host_port_ipv4(8529).await.unwrap();
    let base_url = format!("http://127.0.0.1:{}", mapped_port);
    let container_id = arango.id().to_string();

    (arango, base_url, container_id)
}
fn ensure_global() -> &'static GlobalContainerState {
GLOBAL.get_or_init(|| {
let rt = TEST_RT.get_or_init(|| {
Arc::new(
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("failed to build test tokio runtime"),
)
}).clone();
let (container, base_url, container_id) = rt.block_on(start_container());
GlobalContainerState { rt, container: Mutex::new(Some(container)), container_id, base_url }
})
}
#[ctor::ctor]
fn initialize_logging() {
if std::env::var("RUST_LOG").is_err() {
unsafe {
std::env::set_var("RUST_LOG", "warn");
}
}
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_test_writer()
.init();
let _ = ensure_global();
}
/// Destructor hook that removes the shared ArangoDB container, if one was started.
///
/// Setting `bevy_persistence_database_KEEP_CONTAINER` to `1` or `true`
/// (case-insensitive) leaves the container running and prints its URL instead.
#[ctor::dtor]
fn teardown_container() {
    // Nothing to do when the container was never started.
    let Some(state) = GLOBAL.get() else {
        return;
    };

    let keep_requested = match std::env::var("bevy_persistence_database_KEEP_CONTAINER") {
        Ok(value) => value == "1" || value.eq_ignore_ascii_case("true"),
        Err(_) => false,
    };

    let mut slot = state.container.lock().unwrap();
    match (keep_requested, slot.is_some()) {
        (true, true) => {
            eprintln!("[bevy_persistence_database tests] bevy_persistence_database_KEEP_CONTAINER=1 set; leaving ArangoDB container running at {}", state.base_url);
        }
        (false, true) => {
            // Best effort: the exit status of `docker rm` is intentionally ignored.
            let _ = std::process::Command::new("docker")
                .args(["rm", "-f", "-v", state.container_id.as_str()])
                .status();
            // Forget the handle so its destructor never runs (presumably to avoid a
            // second removal attempt by testcontainers).
            std::mem::forget(slot.take());
        }
        (_, false) => {}
    }
}
/// Guard value returned together with a database connection by the setup helpers.
/// Within this file it is always constructed with `inner: None`.
pub struct ContainerGuard {
/// Optional container handle; it is forgotten, not dropped, when the guard is dropped.
pub(crate) inner: Option<ContainerAsync<GenericImage>>,
}
/// Dropping the guard never runs the destructor of a held container handle:
/// the handle is taken out of the guard and forgotten.
impl Drop for ContainerGuard {
    fn drop(&mut self) {
        // Forgetting `None` is a no-op; forgetting `Some(handle)` skips the handle's drop.
        std::mem::forget(self.inner.take());
    }
}
/// Database backends the test suite can run against.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TestBackend {
    Arango,
    Postgres,
}

/// Returns the backends selected by the `bevy_persistence_database_TEST_BACKENDS`
/// environment variable.
///
/// The variable is a comma-separated list; entries are trimmed and compared
/// case-insensitively against `arango` and `postgres`. Unknown entries, and entries
/// whose Cargo feature is not enabled, are skipped. When nothing is selected, every
/// backend whose feature is compiled in is returned (Arango first).
pub fn configured_backends() -> Vec<TestBackend> {
    let raw = std::env::var("bevy_persistence_database_TEST_BACKENDS").unwrap_or_default();
    let mut selected: Vec<TestBackend> = Vec::new();

    let entries = raw.split(',').map(str::trim).filter(|entry| !entry.is_empty());
    for entry in entries {
        if entry.eq_ignore_ascii_case("arango") {
            #[cfg(feature = "arango")]
            selected.push(TestBackend::Arango);
        } else if entry.eq_ignore_ascii_case("postgres") {
            #[cfg(feature = "postgres")]
            selected.push(TestBackend::Postgres);
        }
    }

    // Fall back to everything that is compiled in.
    if selected.is_empty() {
        #[cfg(feature = "arango")]
        selected.push(TestBackend::Arango);
        #[cfg(feature = "postgres")]
        selected.push(TestBackend::Postgres);
    }

    selected
}
/// State describing the Postgres test container stored in the `PG_GLOBAL` static.
#[cfg(feature = "postgres")]
struct PgGlobalContainerState {
/// Runtime used to block on async database calls made against this container.
rt: Arc<Runtime>,
/// Handle to the started container; teardown empties it with `take()`.
container: Mutex<Option<ContainerAsync<GenericImage>>>,
/// Container id handed to `docker rm` during teardown.
container_id: String,
/// Host the container is reachable on (`127.0.0.1` as produced by `start_postgres_container`).
host: String,
/// Host port mapped to the container's port 5432.
port: u16,
}
/// Removes the Postgres container with `docker rm -f -v` when the state is dropped,
/// unless `bevy_persistence_database_KEEP_CONTAINER` is set to `1` or `true`
/// (case-insensitive), in which case the container is left running.
#[cfg(feature = "postgres")]
impl Drop for PgGlobalContainerState {
    fn drop(&mut self) {
        let keep_requested = match std::env::var("bevy_persistence_database_KEEP_CONTAINER") {
            Ok(value) => value == "1" || value.eq_ignore_ascii_case("true"),
            Err(_) => false,
        };

        // `&mut self` gives exclusive access, so the slot can be reached without locking.
        let slot = self.container.get_mut().unwrap();

        if keep_requested {
            if slot.is_some() {
                eprintln!("[bevy_persistence_database tests] bevy_persistence_database_KEEP_CONTAINER=1 set; leaving Postgres container running at {}:{}", self.host, self.port);
            }
            return;
        }

        if slot.is_none() {
            return;
        }

        // Best effort: the exit status of `docker rm` is intentionally ignored.
        let _ = std::process::Command::new("docker")
            .args(["rm", "-f", "-v", self.container_id.as_str()])
            .status();

        // Forget the handle instead of dropping it so its destructor never runs
        // (presumably to stop testcontainers from removing the container a second time).
        std::mem::forget(slot.take());
    }
}
/// Lazily initialised state of the shared Postgres container (filled in by `ensure_pg_global`).
#[cfg(feature = "postgres")]
static PG_GLOBAL: OnceLock<PgGlobalContainerState> = OnceLock::new();
/// Starts a `postgres:16` container with user `postgres` / password `password`.
///
/// Returns `(container handle, host, mapped port, container id)`. No wait condition
/// is configured on the image; `setup_backend` retries its database calls, presumably
/// to cover the time the server needs before it accepts connections.
///
/// # Panics
/// Panics if the container cannot be started or the host port mapped to 5432
/// cannot be resolved.
#[cfg(feature = "postgres")]
async fn start_postgres_container() -> (ContainerAsync<GenericImage>, String, u16, String) {
    let postgres = GenericImage::new("postgres", "16")
        .with_env_var("POSTGRES_PASSWORD", "password")
        .with_env_var("POSTGRES_USER", "postgres")
        .start()
        .await
        .expect("Failed to start Postgres container");

    let mapped_port = postgres.get_host_port_ipv4(5432).await.unwrap();
    let container_id = postgres.id().to_string();

    (postgres, String::from("127.0.0.1"), mapped_port, container_id)
}
/// Returns the shared Postgres container state, starting the container (and the
/// shared Tokio runtime, if it does not exist yet) on the first call.
///
/// # Panics
/// Panics if the runtime cannot be built or the container fails to start.
#[cfg(feature = "postgres")]
fn ensure_pg_global() -> &'static PgGlobalContainerState {
    PG_GLOBAL.get_or_init(|| {
        let build_runtime = || {
            let runtime = tokio::runtime::Builder::new_multi_thread()
                .enable_all()
                .build()
                .expect("failed to build test tokio runtime");
            Arc::new(runtime)
        };
        let rt = Arc::clone(TEST_RT.get_or_init(build_runtime));

        // Block the calling thread until the container has been started.
        let (handle, host, port, container_id) = rt.block_on(start_postgres_container());

        PgGlobalContainerState {
            rt,
            container: Mutex::new(Some(handle)),
            container_id,
            host,
            port,
        }
    })
}
/// Creates a fresh, uniquely named database (`test_db_<n>`) on the shared container
/// of `backend` and returns a connection to it, together with a `ContainerGuard`
/// that holds no container.
///
/// For Postgres, both database creation and the connection are attempted up to
/// 30 times with a 250 ms pause after every failed attempt.
///
/// # Panics
/// Panics if the container cannot be started, if the database cannot be created or
/// connected to, or if `TestBackend::Postgres` is requested without the `postgres`
/// feature.
pub fn setup_backend(backend: TestBackend) -> (Arc<dyn DatabaseConnection>, ContainerGuard) {
    let connection: Arc<dyn DatabaseConnection> = match backend {
        TestBackend::Arango => {
            let state = ensure_global();
            let db_name = format!("test_db_{}", DB_COUNTER.fetch_add(1, Ordering::Relaxed));

            let db = state
                .rt
                .block_on(async {
                    ArangoDbConnection::ensure_database(&state.base_url, "root", "password", &db_name)
                        .await
                        .expect("Failed to create database");
                    ArangoDbConnection::connect(&state.base_url, "root", "password", &db_name).await
                })
                .expect("Failed to connect to per-test database");

            Arc::new(db)
        }
        #[cfg(feature = "postgres")]
        TestBackend::Postgres => {
            /// Runs `attempt` up to 30 times, sleeping 250 ms after each failure.
            /// Returns the first success, or the debug rendering of the last error.
            fn retry_blocking<T, E: std::fmt::Debug>(
                mut attempt: impl FnMut() -> Result<T, E>,
            ) -> Result<T, String> {
                let mut last_failure = String::new();
                for _ in 0..30 {
                    match attempt() {
                        Ok(value) => return Ok(value),
                        Err(err) => {
                            last_failure = format!("{err:?}");
                            std::thread::sleep(std::time::Duration::from_millis(250));
                        }
                    }
                }
                Err(last_failure)
            }

            let pg = ensure_pg_global();
            let db_name = format!("test_db_{}", DB_COUNTER.fetch_add(1, Ordering::Relaxed));

            let creation = retry_blocking(|| {
                pg.rt.block_on(async {
                    PostgresDbConnection::ensure_database(&pg.host, "postgres", "password", &db_name, Some(pg.port)).await
                })
            });
            if let Err(reason) = creation {
                panic!("Failed to create Postgres database after retries: {}", reason);
            }

            let db = retry_blocking(|| {
                pg.rt.block_on(PostgresDbConnection::connect(&pg.host, "postgres", "password", &db_name, Some(pg.port)))
            })
            .unwrap_or_else(|reason| {
                panic!("Failed to connect to Postgres per-test database after retries: {}", reason)
            });

            Arc::new(db)
        }
        #[cfg(not(feature = "postgres"))]
        TestBackend::Postgres => panic!("postgres feature not enabled"),
    };

    // The containers are owned by the process-wide statics, so the guard stays empty.
    (connection, ContainerGuard { inner: None })
}
/// Sets up a per-test database on the first backend reported by `configured_backends()`.
///
/// # Panics
/// Panics with "No test backend configured or available" when no backend is available.
pub fn setup_sync() -> (Arc<dyn DatabaseConnection>, ContainerGuard) {
    match configured_backends().into_iter().next() {
        Some(backend) => setup_backend(backend),
        None => panic!("No test backend configured or available"),
    }
}
pub fn run_async<F: std::future::Future>(fut: F) -> F::Output {
let rt = TEST_RT.get_or_init(|| {
Arc::new(
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("failed to build test tokio runtime"),
)
}).clone();
rt.block_on(fut)
}
/// Builds a Bevy `App` with the persistence plugin installed on top of `db`.
///
/// Batching is enabled, `batch_size` is used as the commit batch size, and the
/// plugin is configured with a thread count of 4.
pub fn make_app(db: Arc<dyn DatabaseConnection>, batch_size: usize) -> App {
    let config = PersistencePluginConfig {
        batching_enabled: true,
        commit_batch_size: batch_size,
        thread_count: 4,
    };

    // `db` is owned and not used afterwards, so it is moved into the plugin
    // rather than cloned.
    let plugin = PersistencePluginCore::new(db).with_config(config);

    let mut app = App::new();
    app.add_plugins(plugin);
    app
}
/// Destructor hook that removes the shared Postgres container, if one was started.
///
/// The container is kept (and its address printed) when either
/// `bevy_persistence_database_KEEP_CONTAINER` — the variable every other keep check
/// in this file reads — or the older `BEVY_POSTGRESDB_KEEP_CONTAINER` is set to `1`
/// or `true` (case-insensitive). Previously only the older name was honoured here,
/// so setting the common variable still deleted the Postgres container at exit.
#[cfg(feature = "postgres")]
#[ctor::dtor]
fn teardown_pg_container() {
    /// Environment variables that request keeping the container, in lookup order.
    const KEEP_VARS: [&str; 2] = [
        "bevy_persistence_database_KEEP_CONTAINER",
        "BEVY_POSTGRESDB_KEEP_CONTAINER",
    ];

    // Nothing to do when the container was never started.
    let Some(state) = PG_GLOBAL.get() else {
        return;
    };

    // First variable (if any) whose value asks for the container to be kept.
    let keep_var = KEEP_VARS.iter().copied().find(|name| {
        std::env::var(name)
            .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
            .unwrap_or(false)
    });

    if let Some(name) = keep_var {
        if state.container.lock().unwrap().is_some() {
            eprintln!(
                "[bevy_persistence_database tests] {}=1 set; leaving Postgres container running at {}:{}",
                name, state.host, state.port
            );
        }
        return;
    }

    let mut slot = state.container.lock().unwrap();
    if slot.is_some() {
        // Best effort: the exit status of `docker rm` is intentionally ignored.
        let _ = std::process::Command::new("docker")
            .args(["rm", "-f", "-v", state.container_id.as_str()])
            .status();
        // Forget the handle so its destructor never runs (presumably to avoid a
        // second removal attempt by testcontainers).
        std::mem::forget(slot.take());
    }
}