use std::{cell::Cell, future::Future, pin::Pin, sync::Arc};
use sea_orm::{DatabaseConnection, DatabaseTransaction, TransactionTrait};
use super::tx_config::TxConfig;
use super::tx_error::TxError;
use crate::{DbError, DbHandle};
// Task-local marker meaning "this task is currently running a transaction body".
// `with_tx_guard` installs it with the value `true` around the body's future, and
// `is_in_transaction` reads it; a task on which it was never installed reads as `false`.
tokio::task_local! {
static IN_TX: Cell<bool>;
}
/// Reports whether the current task is inside a transaction body started by `Db`.
///
/// Returns `false` when no `IN_TX` scope is active on this task.
fn is_in_transaction() -> bool {
    // `try_with` fails (AccessError) when the task-local was never installed here.
    let flag = IN_TX.try_with(Cell::get);
    flag.unwrap_or_default()
}
/// Awaits `f` with the task-local `IN_TX` flag set to `true` and returns its output.
///
/// tokio's `LocalKey::scope` makes the value visible only while the wrapped future is
/// being polled, so code outside `f` (or on other tasks) does not observe it.
async fn with_tx_guard<F, T>(f: F) -> T
where
    F: Future<Output = T>,
{
    let marker = Cell::new(true);
    IN_TX.scope(marker, f).await
}
/// Database facade; clones share a single [`DbHandle`] through an [`Arc`].
#[derive(Clone)]
pub struct Db {
    handle: Arc<DbHandle>,
}

impl std::fmt::Debug for Db {
    /// Shows only the engine reported by the handle; other internals are elided (`..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let engine = self.handle.engine();
        let mut out = f.debug_struct("Db");
        out.field("engine", &engine);
        out.finish_non_exhaustive()
    }
}
impl Db {
#[must_use]
pub(crate) fn new(handle: DbHandle) -> Self {
Self {
handle: Arc::new(handle),
}
}
pub(crate) fn sea_internal(&self) -> DatabaseConnection {
self.handle.sea_internal()
}
pub fn conn(&self) -> Result<DbConn<'_>, DbError> {
if is_in_transaction() {
return Err(DbError::ConnRequestedInsideTx);
}
Ok(DbConn {
conn: self.handle.sea_internal_ref(),
})
}
pub async fn lock(&self, module: &str, key: &str) -> crate::Result<crate::DbLockGuard> {
self.handle.lock(module, key).await
}
pub async fn try_lock(
&self,
module: &str,
key: &str,
config: crate::LockConfig,
) -> crate::Result<Option<crate::DbLockGuard>> {
self.handle.try_lock(module, key, config).await
}
pub async fn transaction_ref<F, T>(&self, f: F) -> Result<T, DbError>
where
F: for<'a> FnOnce(
&'a DbTx<'a>,
)
-> Pin<Box<dyn Future<Output = Result<T, DbError>> + Send + 'a>>
+ Send,
T: Send + 'static,
{
let txn = self.handle.sea_internal_ref().begin().await?;
let tx = DbTx { tx: &txn };
let res = with_tx_guard(f(&tx)).await;
match res {
Ok(v) => {
txn.commit().await?;
Ok(v)
}
Err(e) => {
_ = txn.rollback().await;
Err(e)
}
}
}
pub async fn transaction_ref_mapped<F, T, E>(&self, f: F) -> Result<T, E>
where
E: From<DbError> + Send + 'static,
F: for<'a> FnOnce(&'a DbTx<'a>) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'a>>
+ Send,
T: Send + 'static,
{
let txn = self
.handle
.sea_internal_ref()
.begin()
.await
.map_err(DbError::from)
.map_err(E::from)?;
let tx = DbTx { tx: &txn };
let res = with_tx_guard(f(&tx)).await;
match res {
Ok(v) => {
txn.commit().await.map_err(DbError::from).map_err(E::from)?;
Ok(v)
}
Err(e) => {
_ = txn.rollback().await;
Err(e)
}
}
}
pub async fn transaction_ref_mapped_with_config<F, T, E>(
&self,
config: TxConfig,
f: F,
) -> Result<T, E>
where
E: From<DbError> + Send + 'static,
F: for<'a> FnOnce(&'a DbTx<'a>) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'a>>
+ Send,
T: Send + 'static,
{
use sea_orm::{AccessMode, IsolationLevel};
let isolation: Option<IsolationLevel> = config.isolation.map(Into::into);
let access_mode: Option<AccessMode> = config.access_mode.map(Into::into);
let txn = self
.handle
.sea_internal_ref()
.begin_with_config(isolation, access_mode)
.await
.map_err(DbError::from)
.map_err(E::from)?;
let tx = DbTx { tx: &txn };
let res = with_tx_guard(f(&tx)).await;
match res {
Ok(v) => {
txn.commit().await.map_err(DbError::from).map_err(E::from)?;
Ok(v)
}
Err(e) => {
_ = txn.rollback().await;
Err(e)
}
}
}
pub async fn transaction<F, T>(self, f: F) -> (Self, anyhow::Result<T>)
where
F: for<'a> FnOnce(
&'a DbTx<'a>,
)
-> Pin<Box<dyn Future<Output = anyhow::Result<T>> + Send + 'a>>
+ Send,
T: Send + 'static,
{
let txn = match self.handle.sea_internal_ref().begin().await {
Ok(t) => t,
Err(e) => return (self, Err(e.into())),
};
let tx = DbTx { tx: &txn };
let res = with_tx_guard(f(&tx)).await;
match res {
Ok(v) => match txn.commit().await {
Ok(()) => (self, Ok(v)),
Err(e) => (self, Err(e.into())),
},
Err(e) => {
_ = txn.rollback().await;
(self, Err(e))
}
}
}
pub async fn in_transaction<T, E, F>(self, f: F) -> (Self, Result<T, TxError<E>>)
where
T: Send + 'static,
E: std::fmt::Debug + std::fmt::Display + Send + 'static,
F: for<'a> FnOnce(&'a DbTx<'a>) -> Pin<Box<dyn Future<Output = Result<T, E>> + Send + 'a>>
+ Send,
{
use super::tx_error::InfraError;
let txn = match self.handle.sea_internal_ref().begin().await {
Ok(txn) => txn,
Err(e) => return (self, Err(TxError::Infra(InfraError::new(e.to_string())))),
};
let tx = DbTx { tx: &txn };
let res = with_tx_guard(f(&tx)).await;
match res {
Ok(v) => match txn.commit().await {
Ok(()) => (self, Ok(v)),
Err(e) => (self, Err(TxError::Infra(InfraError::new(e.to_string())))),
},
Err(e) => {
_ = txn.rollback().await;
(self, Err(TxError::Domain(e)))
}
}
}
pub async fn transaction_with_config<T, F>(
self,
config: TxConfig,
f: F,
) -> (Self, anyhow::Result<T>)
where
T: Send + 'static,
F: for<'a> FnOnce(
&'a DbTx<'a>,
)
-> Pin<Box<dyn Future<Output = anyhow::Result<T>> + Send + 'a>>
+ Send,
{
use sea_orm::{AccessMode, IsolationLevel};
let isolation: Option<IsolationLevel> = config.isolation.map(Into::into);
let access_mode: Option<AccessMode> = config.access_mode.map(Into::into);
let txn = match self
.handle
.sea_internal_ref()
.begin_with_config(isolation, access_mode)
.await
{
Ok(t) => t,
Err(e) => return (self, Err(e.into())),
};
let tx = DbTx { tx: &txn };
let res = with_tx_guard(f(&tx)).await;
match res {
Ok(v) => match txn.commit().await {
Ok(()) => (self, Ok(v)),
Err(e) => (self, Err(e.into())),
},
Err(e) => {
_ = txn.rollback().await;
(self, Err(e))
}
}
}
#[must_use]
pub fn db_engine(&self) -> &'static str {
use sea_orm::{ConnectionTrait, DbBackend};
match self.handle.sea_internal_ref().get_database_backend() {
DbBackend::Postgres => "postgres",
DbBackend::MySql => "mysql",
DbBackend::Sqlite => "sqlite",
}
}
}
/// A borrowed, non-transactional SeaORM connection handed out by `Db::conn`.
pub struct DbConn<'a> {
    pub(crate) conn: &'a DatabaseConnection,
}

impl std::fmt::Debug for DbConn<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Renders as `DbConn { .. }`; the wrapped connection is not printed.
        let mut builder = f.debug_struct("DbConn");
        builder.finish_non_exhaustive()
    }
}
/// A borrowed view of an open SeaORM transaction, passed to the closures given to `Db`'s
/// transaction methods.
pub struct DbTx<'a> {
    pub(crate) tx: &'a DatabaseTransaction,
}

impl std::fmt::Debug for DbTx<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Renders as `DbTx { .. }`; the wrapped transaction is not printed.
        let mut builder = f.debug_struct("DbTx");
        builder.finish_non_exhaustive()
    }
}