use std::cell::RefCell;
use sea_orm::{Database, DatabaseConnection};
use crate::runtime;
// Panic message raised by `db` and `with_db` when the current thread's
// connection slot is still empty, i.e. `establish` has not stored a
// connection on this thread.
const DATABASE_NOT_ESTABLISHED: &str = "rustrails_support::database::establish() must be called on this thread before accessing the database connection";
// Per-thread storage for the database connection. Each thread gets its own
// slot, which starts out as `None` and is filled by `establish`. `RefCell`
// provides the interior mutability needed to swap the stored connection
// through the shared reference handed out by `LocalKey::with`.
thread_local! {
static DB_CONNECTION: RefCell<Option<DatabaseConnection>> = const { RefCell::new(None) };
}
/// Errors reported by this module's connection setup.
#[derive(Debug, thiserror::Error)]
pub enum DatabaseError {
/// Wraps the `sea_orm::DbErr` produced while connecting. The `#[from]`
/// attribute lets `?` convert a `DbErr` into this variant, and the display
/// text is the fixed prefix followed by the underlying error's message.
#[error("database connection failed: {0}")]
ConnectionFailed(#[from] sea_orm::DbErr),
}
/// Connects to the database at `url` and stores the connection for the
/// current thread.
///
/// The connect future is driven to completion through `runtime::block_on`.
/// On success the new connection is placed in this thread's `DB_CONNECTION`
/// slot, replacing and dropping whatever an earlier call stored there. The
/// slot is thread-local, so other threads are unaffected and must call
/// `establish` themselves.
///
/// # Errors
///
/// Returns [`DatabaseError::ConnectionFailed`] when `Database::connect`
/// fails. The slot is not touched in that case, so a previously stored
/// connection stays in place.
///
/// # Panics
///
/// Panics if called from inside a [`with_db`] closure on the same thread:
/// the slot is still immutably borrowed there, so it cannot be replaced.
pub fn establish(url: &str) -> Result<(), DatabaseError> {
    let fresh = runtime::block_on(Database::connect(url))?;

    // Swap the new connection in and discard the one it displaced, if any.
    DB_CONNECTION.with(|slot| drop(slot.replace(Some(fresh))));

    Ok(())
}
/// Returns a clone of the connection stored for the current thread.
///
/// Use [`with_db`] instead when a borrowed connection is enough and the
/// clone can be avoided.
///
/// # Panics
///
/// Panics with the `DATABASE_NOT_ESTABLISHED` message when [`establish`] has
/// not stored a connection on this thread.
pub fn db() -> DatabaseConnection {
    DB_CONNECTION.with(|slot| {
        let guard = slot.borrow();
        match guard.as_ref() {
            Some(connection) => connection.clone(),
            None => panic!("{DATABASE_NOT_ESTABLISHED}"),
        }
    })
}
/// Runs `f` with a reference to the current thread's connection and returns
/// whatever `f` returns.
///
/// Unlike [`db`], the connection is lent rather than cloned. The slot stays
/// immutably borrowed for the whole time `f` runs.
///
/// # Panics
///
/// * Panics with the `DATABASE_NOT_ESTABLISHED` message when [`establish`]
///   has not stored a connection on this thread.
/// * Panics if `f` calls [`establish`] on the same thread, because the slot
///   cannot be mutably borrowed while this function holds its shared borrow.
pub fn with_db<F, R>(f: F) -> R
where
    F: FnOnce(&DatabaseConnection) -> R,
{
    DB_CONNECTION.with(|slot| {
        // The guard must outlive the call to `f`, which borrows from it.
        let guard = slot.borrow();
        match guard.as_ref() {
            Some(connection) => f(connection),
            None => panic!("{DATABASE_NOT_ESTABLISHED}"),
        }
    })
}
/// Reports whether the current thread has a stored connection.
///
/// Returns `true` once [`establish`] has succeeded on this thread and
/// `false` while the thread-local slot is still empty.
pub fn is_established() -> bool {
    DB_CONNECTION.with(|slot| {
        let guard = slot.borrow();
        guard.is_some()
    })
}
#[cfg(test)]
mod tests {
    use std::{any::Any, thread};

    use sea_orm::{
        ConnectionTrait, DatabaseBackend,
        sea_query::{Alias, ColumnDef, Expr, Query, Table},
    };

    use super::{db, establish, is_established, with_db};
    use crate::runtime;

    /// URL of the throwaway SQLite database most tests connect to.
    const IN_MEMORY_URL: &str = "sqlite::memory:";

    /// Expectation message shared by every test that needs a working connection.
    const CONNECT_SHOULD_SUCCEED: &str = "sqlite in-memory connection should succeed";

    /// Runs `test` on a freshly spawned thread and returns its result.
    ///
    /// The connection slot is thread-local, so a new thread guarantees each
    /// test starts with an empty slot. A panic on the spawned thread is
    /// re-raised on the calling thread so the test still fails.
    fn run_isolated<R>(test: impl FnOnce() -> R + Send + 'static) -> R
    where
        R: Send + 'static,
    {
        thread::spawn(test)
            .join()
            .unwrap_or_else(|payload| std::panic::resume_unwind(payload))
    }

    /// Extracts the text of a panic payload, accepting both `String` and
    /// `&str` payloads and falling back to a fixed placeholder otherwise.
    fn panic_message(payload: Box<dyn Any + Send>) -> String {
        payload
            .downcast_ref::<String>()
            .cloned()
            .or_else(|| {
                payload
                    .downcast_ref::<&str>()
                    .map(|text| (*text).to_owned())
            })
            .unwrap_or_else(|| "non-string panic payload".to_owned())
    }

    /// A successful `establish` leaves the thread marked as established.
    #[test]
    fn establish_connects_to_in_memory_sqlite() {
        run_isolated(|| {
            // Kept alive for the whole test body; presumably `block_on`
            // needs it. TODO confirm against `crate::runtime`.
            let _rt = runtime::init_runtime();

            establish(IN_MEMORY_URL).expect(CONNECT_SHOULD_SUCCEED);

            assert!(is_established());
        });
    }

    /// `db` must panic with the documented message on a thread that never
    /// called `establish`.
    #[test]
    fn db_panics_before_establish() {
        let message = run_isolated(|| {
            std::panic::catch_unwind(db)
                .map_err(panic_message)
                .expect_err("db should panic before establish is called")
        });

        assert!(message.contains("database::establish() must be called on this thread"));
    }

    /// The clone handed out by `db` can talk to the database.
    #[test]
    fn db_returns_a_usable_connection_after_establish() {
        run_isolated(|| {
            let _rt = runtime::init_runtime();
            establish(IN_MEMORY_URL).expect(CONNECT_SHOULD_SUCCEED);

            let pong = runtime::block_on(async { db().ping().await });

            pong.expect("stored connection should respond to ping");
        });
    }

    /// The closure given to `with_db` receives the stored connection.
    #[test]
    fn with_db_passes_the_connection_into_the_closure() {
        run_isolated(|| {
            let _rt = runtime::init_runtime();
            establish(IN_MEMORY_URL).expect(CONNECT_SHOULD_SUCCEED);

            assert_eq!(
                with_db(|conn| conn.get_database_backend()),
                DatabaseBackend::Sqlite
            );
        });
    }

    /// A second `establish` swaps in a new connection: a table created
    /// through the first connection is not visible through the second.
    #[test]
    fn establish_twice_replaces_the_connection() {
        run_isolated(|| {
            let _rt = runtime::init_runtime();

            establish(IN_MEMORY_URL).expect("first sqlite in-memory connection should succeed");
            runtime::block_on(async {
                let create_table = Table::create()
                    .table(Alias::new("replacement_check"))
                    .col(
                        ColumnDef::new(Alias::new("id"))
                            .integer()
                            .not_null()
                            .primary_key(),
                    )
                    .to_owned();

                db().execute(&create_table)
                    .await
                    .expect("table creation should succeed");
            });

            establish(IN_MEMORY_URL).expect("second sqlite in-memory connection should succeed");
            let lookup = runtime::block_on(async {
                db().query_one(
                    &Query::select()
                        .expr(Expr::col(Alias::new("id")))
                        .from(Alias::new("replacement_check"))
                        .limit(1)
                        .to_owned(),
                )
                .await
            });

            // The query is expected to fail because the table only exists
            // in the database behind the first connection.
            assert!(lookup.is_err());
        });
    }

    /// `is_established` flips from `false` to `true` across `establish`.
    #[test]
    fn is_established_reflects_connection_state() {
        run_isolated(|| {
            let _rt = runtime::init_runtime();

            assert!(!is_established());

            establish(IN_MEMORY_URL).expect(CONNECT_SHOULD_SUCCEED);

            assert!(is_established());
        });
    }

    /// A failed connection surfaces as `DatabaseError` with the fixed prefix.
    #[test]
    fn database_error_displays_the_underlying_failure() {
        run_isolated(|| {
            let _rt = runtime::init_runtime();

            let rendered = establish("not-a-valid-database-url")
                .expect_err("invalid database URLs should fail")
                .to_string();

            assert!(rendered.starts_with("database connection failed:"));
        });
    }

    /// `with_db` returns the closure's own result, not just the connection.
    #[test]
    fn with_db_can_return_a_computed_value() {
        run_isolated(|| {
            let _rt = runtime::init_runtime();
            establish(IN_MEMORY_URL).expect(CONNECT_SHOULD_SUCCEED);

            assert!(with_db(
                |conn| conn.get_database_backend() == DatabaseBackend::Sqlite
            ));
        });
    }
}