use std::{
error::Error,
sync::{Arc, Mutex, RwLock},
};
use toasty::{Db, schema::ModelSet};
use tokio::runtime::Runtime;
use crate::{ExecLog, Isolate, LoggingDriver, Setup};
/// Process-wide lock coordinating driver tests run through `Test::run`.
///
/// Non-serial tests hold it shared (read lock) so they may overlap with each
/// other; serial tests hold it exclusively (write lock) so nothing else runs
/// alongside them. The guarded value is `()`: only the lock itself matters.
static TEST_LOCK: RwLock<()> = RwLock::new(());
/// Per-test harness state: the driver setup, table-name isolation, the Tokio
/// runtime that drives the async test body, and the tables to clean up afterwards.
pub struct Test {
// Supplies the driver for each database and the `delete_table` cleanup hook.
setup: Arc<dyn Setup>,
// Provides the table-name prefix applied to every database this test builds.
isolate: Isolate,
// Taken out by `run` while it executes; `run` puts it back when it completes normally.
runtime: Option<Runtime>,
// Wraps the ops log of the most recently built `LoggingDriver`; starts out empty.
exec_log: ExecLog,
// Names of tables created by `try_setup_db_with`; `run` deletes them after the test body.
tables: Vec<String>,
// When true, `run` takes `TEST_LOCK` exclusively (write) instead of shared (read).
serial: bool,
}
impl Test {
    /// Creates a test harness for `setup`, with a fresh table-name isolation,
    /// an empty execution log and a current-thread Tokio runtime (all drivers enabled).
    ///
    /// # Panics
    ///
    /// Panics if the Tokio runtime cannot be built.
    pub fn new(setup: Arc<dyn Setup>) -> Self {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to create Tokio runtime");
        Test {
            setup,
            isolate: Isolate::new(),
            runtime: Some(runtime),
            exec_log: ExecLog::new(Arc::new(Mutex::new(Vec::new()))),
            tables: vec![],
            serial: false,
        }
    }

    /// Builds a database for `models` with default builder settings.
    ///
    /// Equivalent to [`Test::try_setup_db_with`] with a no-op customization.
    ///
    /// # Errors
    ///
    /// Returns any error from building the database or pushing its schema.
    pub async fn try_setup_db(&mut self, models: ModelSet) -> toasty::Result<Db> {
        self.try_setup_db_with(models, |_| {}).await
    }

    /// Builds a database for `models`, letting `customize` adjust the builder.
    ///
    /// Table names get this test's isolation prefix, and `customize` runs after
    /// the models and prefix are set. The setup's driver is wrapped in a
    /// `LoggingDriver`, and the test's execution log is replaced (before building)
    /// with one reading that driver's ops log. After the schema is pushed, every
    /// table name is recorded so [`Test::run`] can delete it afterwards.
    ///
    /// # Errors
    ///
    /// Returns any error from building the database or pushing its schema.
    pub async fn try_setup_db_with(
        &mut self,
        models: ModelSet,
        customize: impl FnOnce(&mut toasty::db::Builder),
    ) -> toasty::Result<Db> {
        let mut builder = toasty::Db::builder();
        builder.models(models);
        builder.table_name_prefix(&self.isolate.table_prefix());
        customize(&mut builder);

        let logging_driver = LoggingDriver::new(self.setup.driver());
        let ops_log = logging_driver.ops_log_handle();
        self.exec_log = ExecLog::new(ops_log);

        let db = builder.build(logging_driver).await?;
        db.push_schema().await?;

        // Remember every created table so `run` can drop it once the test finishes.
        for table in &db.schema().db.tables {
            self.tables.push(table.name.clone());
        }
        Ok(db)
    }

    /// Like [`Test::try_setup_db`], but panics on failure.
    pub async fn setup_db(&mut self, models: ModelSet) -> Db {
        self.try_setup_db(models).await.unwrap()
    }

    /// Like [`Test::try_setup_db_with`], but panics on failure.
    pub async fn setup_db_with(
        &mut self,
        models: ModelSet,
        customize: impl FnOnce(&mut toasty::db::Builder),
    ) -> Db {
        self.try_setup_db_with(models, customize).await.unwrap()
    }

    /// Returns the capability descriptor reported by the setup's driver.
    pub fn capability(&self) -> &'static toasty_core::driver::Capability {
        self.setup.driver().capability()
    }

    /// Mutable access to the execution log of the most recently built database.
    pub fn log(&mut self) -> &mut ExecLog {
        &mut self.exec_log
    }

    /// Sets whether [`Test::run`] must run this test exclusively (`true`) or may
    /// share the global test lock with other non-serial tests (`false`).
    pub fn set_serial(&mut self, serial: bool) {
        self.serial = serial;
    }

    /// Runs the async test body `f` on this test's runtime, then deletes every
    /// table created through the `setup_db*` methods during the run.
    ///
    /// The global test lock is held for the whole call, including cleanup. The
    /// lock is exclusive when serial mode is on, otherwise shared.
    /// Table cleanup runs even if `f` panics; the panic is re-raised afterwards.
    ///
    /// # Panics
    ///
    /// - if the runtime is missing ("runtime already consumed");
    /// - if `f` returns an error (as converted through [`TestResult`]);
    /// - with the original payload, if `f` itself panicked.
    pub fn run<R>(&mut self, f: impl AsyncFn(&mut Test) -> R)
    where
        R: Into<TestResult>,
    {
        // Serial tests exclude all others; non-serial tests share the lock.
        // Poisoning is ignored: the guarded value is `()`, so nothing can be corrupt.
        let _guard: Box<dyn std::any::Any> = if self.serial {
            Box::new(TEST_LOCK.write().unwrap_or_else(|e| e.into_inner()))
        } else {
            Box::new(TEST_LOCK.read().unwrap_or_else(|e| e.into_inner()))
        };

        let runtime = self.runtime.take().expect("runtime already consumed");

        // Box the test future so a large async state machine does not live on the stack.
        let fut: std::pin::Pin<Box<dyn std::future::Future<Output = R>>> = Box::pin(f(self));

        // A failing `assert!` inside the body panics. Catch it so the tables the
        // body created are still deleted; the panic is resumed after cleanup.
        let outcome =
            std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| runtime.block_on(fut)));

        // Drain the list so a later `run` on this `Test` does not delete these tables twice.
        for table in std::mem::take(&mut self.tables) {
            runtime.block_on(self.setup.delete_table(&table));
        }

        // Put the runtime back before reporting any failure.
        self.runtime = Some(runtime);

        match outcome {
            Ok(value) => {
                let result: TestResult = value.into();
                if let Some(error) = result.error {
                    panic!("Driver test returned an error: {error}");
                }
            }
            Err(payload) => std::panic::resume_unwind(payload),
        }
    }
}
/// Outcome of a driver test body, as consumed by `Test::run`.
///
/// `error` is `None` when the body succeeded and holds the failure otherwise.
pub struct TestResult {
    error: Option<Box<dyn Error>>,
}

/// A test body that returns `()` always counts as a success.
impl From<()> for TestResult {
    fn from((): ()) -> Self {
        Self { error: None }
    }
}

/// A test body that returns a `Result` succeeds on `Ok` (the value is discarded)
/// and fails with the boxed error on `Err`.
impl<O, E> From<Result<O, E>> for TestResult
where
    E: Into<Box<dyn Error>>,
{
    fn from(value: Result<O, E>) -> Self {
        let error = match value {
            Ok(_) => None,
            Err(failure) => Some(failure.into()),
        };
        Self { error }
    }
}