use arc_swap::ArcSwapOption;
use std::future::Future;
use std::sync::{Arc, OnceLock};
#[cfg(not(feature = "runtime-tokio"))]
use std::cell::RefCell;
use crate::error::{Error, Result};
use crate::internal::{DbBackend, InternalConnection};
use super::{ConnectionRef, Database};
/// Process-wide `Database` value; populated lazily by `global_db_handle` via
/// `Database::global_handle`.
static GLOBAL_DB: OnceLock<Database> = OnceLock::new();
/// Process-wide slot for the optional shared `InternalConnection`; created empty
/// (holding `None`) on first access through `global_connection_slot`.
static GLOBAL_CONNECTION: OnceLock<ArcSwapOption<InternalConnection>> = OnceLock::new();
// Task-local database override, compiled only when the `runtime-tokio` feature is on.
// `with_connection_override` scopes a `DatabaseHandle` into it for the duration of a
// future, and `current_override_handle` reads it back.
#[cfg(feature = "runtime-tokio")]
tokio::task_local! {
pub(super) static TASK_DB_OVERRIDE: DatabaseHandle;
}
// Thread-local database override used when the `runtime-tokio` feature is off.
// Starts out as `None`; `with_connection_override` swaps a `DatabaseHandle` in and
// `current_override_handle` clones the current value out.
#[cfg(not(feature = "runtime-tokio"))]
thread_local! {
pub(super) static THREAD_DB_OVERRIDE: RefCell<Option<DatabaseHandle>> = const { RefCell::new(None) };
}
/// What a database scope resolves to: either a shared connection or a transaction.
///
/// Both variants hold an `Arc`, so cloning a handle only clones the `Arc`.
#[derive(Clone)]
pub(crate) enum DatabaseHandle {
/// A shared, reference-counted `InternalConnection`.
Connection(Arc<InternalConnection>),
/// A shared, reference-counted `DatabaseTransaction`.
Transaction(Arc<crate::internal::DatabaseTransaction>),
}
/// Returns the process-wide connection slot, creating it empty on first use.
///
/// The slot holds `None` until some other code stores a connection into it.
pub(super) fn global_connection_slot() -> &'static ArcSwapOption<InternalConnection> {
    /// Builds the initial, unset slot.
    fn empty_slot() -> ArcSwapOption<InternalConnection> {
        ArcSwapOption::new(None)
    }

    GLOBAL_CONNECTION.get_or_init(empty_slot)
}
/// Returns the process-wide `Database` value, building it with
/// `Database::global_handle` the first time it is requested.
pub(super) fn global_db_handle() -> &'static Database {
    let handle: &'static Database = GLOBAL_DB.get_or_init(Database::global_handle);
    handle
}
/// Panics with `message` as the panic payload text; never returns.
///
/// Kept as a separate function so callers can use it in expression position
/// thanks to the `!` return type.
pub(super) fn panic_missing_global_db(message: &str) -> ! {
    panic!("{message}")
}
/// Returns the global `Database` handle.
///
/// # Panics
///
/// Panics if the global handle reports that it is not connected. Use
/// [`try_db`] or [`require_db`] when a missing connection must be handled
/// without panicking.
pub fn db() -> &'static Database {
    let handle = global_db_handle();

    // Guard clause: bail out loudly before handing back an unusable handle.
    if !handle.is_connected() {
        panic_missing_global_db(
            "Global database connection not initialized. \
             Call Database::init() or Database::set_global() before using models. \
             Use try_db() for a non-panicking alternative.",
        );
    }

    handle
}
pub fn require_db() -> Result<Database> {
global_connection_slot()
.load_full()
.map(|inner| Database::from_handle(DatabaseHandle::Connection(inner)))
.ok_or_else(|| {
Error::connection(
"Global database connection not initialized. \
Call Database::init() or Database::set_global() before using models."
.to_string(),
)
})
}
/// Returns a `Database` wrapping the current global connection, or `None`
/// when the global slot is empty.
pub fn try_db() -> Option<Database> {
    let inner = global_connection_slot().load_full()?;
    Some(Database::from_handle(DatabaseHandle::Connection(inner)))
}
/// Reports whether a global connection is currently stored in the global slot.
///
/// The answer is a snapshot: another thread may install or clear the
/// connection immediately afterwards.
pub fn has_global_db() -> bool {
    // `load` inspects the slot through a short-lived guard; unlike `load_full`
    // it does not clone the `Arc` (and pay for a refcount round-trip) just to
    // answer a yes/no question.
    global_connection_slot().load().is_some()
}
/// Returns a clone of the task-local override handle, if one can be read.
///
/// Any failure to access the task-local is mapped to `None`.
#[cfg(feature = "runtime-tokio")]
fn current_override_handle() -> Option<DatabaseHandle> {
    let lookup = TASK_DB_OVERRIDE.try_with(DatabaseHandle::clone);
    lookup.ok()
}
/// Returns a clone of the thread-local override handle, or `None` when no
/// override is currently installed on this thread.
#[cfg(not(feature = "runtime-tokio"))]
fn current_override_handle() -> Option<DatabaseHandle> {
    THREAD_DB_OVERRIDE.with(|cell| cell.borrow().as_ref().cloned())
}
#[doc(hidden)]
pub fn __current_db() -> Result<Database> {
if let Some(handle) = current_override_handle() {
return Ok(Database::from_handle(handle));
}
require_db()
}
/// Resolves the handle for the current scope.
///
/// The active override wins; the global connection is consulted only when no
/// override is installed.
///
/// # Errors
///
/// Returns a connection error when neither an override nor a global
/// connection is available.
pub(super) fn current_scope_handle() -> Result<DatabaseHandle> {
    current_override_handle()
        .or_else(|| {
            global_connection_slot()
                .load_full()
                .map(DatabaseHandle::Connection)
        })
        .ok_or_else(|| {
            Error::connection(
                "Global database connection not initialized. \
                 Call Database::init() or Database::set_global() before using models."
                    .to_string(),
            )
        })
}
/// Runs `future` with `handle` scoped into the task-local override and
/// returns the future's output.
#[cfg(feature = "runtime-tokio")]
pub(super) async fn with_connection_override<F: Future>(
    handle: DatabaseHandle,
    future: F,
) -> F::Output {
    let scoped = TASK_DB_OVERRIDE.scope(handle, future);
    scoped.await
}
/// Runs `future` with `handle` installed as the thread-local database
/// override and returns the future's output.
///
/// The override is installed only while `future` is being polled. Installing
/// it once around `future.await` would leave it in the thread-local while
/// this task is suspended, so unrelated work polled on the same thread would
/// see this task's handle (possibly an open transaction). Scoping it to each
/// poll also means the override follows the task if it is resumed on another
/// thread.
///
/// Nested calls compose: each poll saves whatever override was active,
/// installs its own, and puts the saved one back when the poll finishes or
/// unwinds.
#[cfg(not(feature = "runtime-tokio"))]
pub(super) async fn with_connection_override<F>(handle: DatabaseHandle, future: F) -> F::Output
where
    F: Future,
{
    /// Ends one poll scope: restores the override that was active before the
    /// poll and parks this scope's handle again for the next poll.
    struct PollScope<'a> {
        parked: &'a mut Option<DatabaseHandle>,
        previous: Option<DatabaseHandle>,
    }

    impl Drop for PollScope<'_> {
        fn drop(&mut self) {
            let previous = self.previous.take();
            // Runs on normal return and on unwind, so a panicking poll cannot
            // leak the override into the rest of the thread.
            *self.parked = THREAD_DB_OVERRIDE.with(|slot| slot.replace(previous));
        }
    }

    let mut future = std::pin::pin!(future);
    // The handle is parked here between polls.
    let mut parked = Some(handle);

    std::future::poll_fn(move |cx| {
        let previous = THREAD_DB_OVERRIDE.with(|slot| slot.replace(parked.take()));
        let _scope = PollScope {
            parked: &mut parked,
            previous,
        };
        future.as_mut().poll(cx)
    })
    .await
}
/// Resolves the current scope's handle and converts it into the matching
/// `ConnectionRef` variant.
///
/// # Errors
///
/// Propagates the error from `current_scope_handle` when no handle is
/// available.
#[doc(hidden)]
pub fn __current_connection() -> Result<ConnectionRef> {
    current_scope_handle().map(|handle| match handle {
        DatabaseHandle::Transaction(tx) => ConnectionRef::Transaction(tx),
        DatabaseHandle::Connection(inner) => ConnectionRef::Database(inner),
    })
}
/// Reports the database backend of the current scope's connection or
/// transaction.
///
/// # Errors
///
/// Propagates the error from `current_scope_handle` when no handle is
/// available.
#[doc(hidden)]
pub fn __current_backend() -> Result<DbBackend> {
    // Brings `get_database_backend` into scope for both handle kinds.
    use crate::internal::ConnectionTrait;

    let backend = match current_scope_handle()? {
        DatabaseHandle::Transaction(tx) => tx.as_ref().get_database_backend(),
        DatabaseHandle::Connection(inner) => inner.connection().get_database_backend(),
    };
    Ok(backend)
}