use std::cell::RefCell;
use crate::data::db::ArclyDbPool;
#[cfg(any(feature = "db-sqlx", feature = "db-seaorm", feature = "db-diesel"))]
use crate::data::db::DbDriver;
use crate::data::DataError;
use crate::web::context::RequestContext;
use crate::web::error::{HttpException, Internal};
/// An open database transaction, wrapping the transaction type of whichever
/// driver is compiled in.
///
/// Each variant only exists when its Cargo feature is enabled, so the set of
/// variants depends on the build configuration.
pub enum ArclyTransaction {
/// A `sqlx` transaction over the `sqlx::Any` database (feature `db-sqlx`).
#[cfg(feature = "db-sqlx")]
Sqlx(sqlx::Transaction<'static, sqlx::Any>),
/// A SeaORM transaction (feature `db-seaorm`).
#[cfg(feature = "db-seaorm")]
SeaOrm(sea_orm::DatabaseTransaction),
}
impl ArclyTransaction {
    /// Consumes the transaction and commits it through the wrapped driver.
    ///
    /// # Errors
    ///
    /// A driver-level failure is stringified and returned via
    /// `DataError::query`.
    pub async fn commit(self) -> Result<(), DataError> {
        match self {
            #[cfg(feature = "db-sqlx")]
            Self::Sqlx(inner) => {
                let committed = inner.commit().await;
                committed.map_err(|cause| DataError::query(cause.to_string()))
            }
            #[cfg(feature = "db-seaorm")]
            Self::SeaOrm(inner) => {
                let committed = inner.commit().await;
                committed.map_err(|cause| DataError::query(cause.to_string()))
            }
            // Catch-all so the match is accepted regardless of which driver
            // features are compiled in; it is flagged unreachable (hence the
            // `allow`) when the arms above already cover every variant.
            #[allow(unreachable_patterns)]
            _ => Ok(()),
        }
    }

    /// Consumes the transaction and rolls it back through the wrapped driver.
    ///
    /// # Errors
    ///
    /// A driver-level failure is stringified and returned via
    /// `DataError::query`.
    pub async fn rollback(self) -> Result<(), DataError> {
        match self {
            #[cfg(feature = "db-sqlx")]
            Self::Sqlx(inner) => {
                let rolled_back = inner.rollback().await;
                rolled_back.map_err(|cause| DataError::query(cause.to_string()))
            }
            #[cfg(feature = "db-seaorm")]
            Self::SeaOrm(inner) => {
                let rolled_back = inner.rollback().await;
                rolled_back.map_err(|cause| DataError::query(cause.to_string()))
            }
            // Same catch-all as in `commit`: only reachable in builds where
            // the arms above do not cover every variant.
            #[allow(unreachable_patterns)]
            _ => Ok(()),
        }
    }
}
impl ArclyDbPool {
    /// Starts a new transaction on the driver returned by `self.primary()`.
    ///
    /// # Errors
    ///
    /// * `DataError::connection` when the sqlx or SeaORM driver fails to
    ///   begin a transaction (the driver error is stringified).
    /// * `DataError::config` for Diesel pools, which are rejected here, and
    ///   when no arm for an enabled driver feature matches.
    #[allow(unreachable_code)]
    pub async fn begin(&self) -> Result<ArclyTransaction, DataError> {
        match self.primary() {
            #[cfg(feature = "db-sqlx")]
            DbDriver::Sqlx(pool) => {
                let started = pool
                    .begin()
                    .await
                    .map_err(|cause| DataError::connection(cause.to_string()))?;
                Ok(ArclyTransaction::Sqlx(started))
            }
            #[cfg(feature = "db-seaorm")]
            DbDriver::SeaOrm(conn) => {
                // Brings `begin()` into scope for the SeaORM connection.
                use sea_orm::TransactionTrait;
                let started = conn
                    .begin()
                    .await
                    .map_err(|cause| DataError::connection(cause.to_string()))?;
                Ok(ArclyTransaction::SeaOrm(started))
            }
            // Diesel pools are refused outright; the message points callers
            // at the blocking-pool transaction API instead.
            #[cfg(feature = "db-diesel")]
            DbDriver::Diesel(_) => Err(DataError::config(
                "#[Transactional] is not supported on sync Diesel pools — \
                 run the whole transaction inside DieselBlockingPool::transaction(…)",
            )),
            // Fallback for builds where none of the driver arms above apply.
            #[allow(unreachable_patterns)]
            _ => Err(DataError::config("no database driver feature enabled")),
        }
    }
}
// Task-local slot holding the transaction of the current task, if any.
// `run_transactional` installs a value with `scope`, `with_current_tx`
// temporarily takes the transaction out and puts it back afterwards, and
// `in_transaction` only inspects whether the slot is occupied.
tokio::task_local! {
static CURRENT_TX: RefCell<Option<ArclyTransaction>>;
}
/// Lends the task's current transaction to `work` and puts it back afterwards.
///
/// `work` receives the transaction by value and must hand it back together
/// with its result. The transaction is stored in the task-local slot again
/// only after the future returned by `work` has completed.
///
/// # Returns
///
/// * `Ok(None)` when the task-local slot is not set for this task or holds no
///   transaction; `work` is not called in that case.
/// * `Ok(Some(value))` when `work` produced `Ok(value)`.
///
/// # Errors
///
/// Propagates the `DataError` produced by `work`; the transaction is still
/// returned to the slot before the error is surfaced.
pub async fn with_current_tx<R, F, Fut>(work: F) -> Result<Option<R>, DataError>
where
    F: FnOnce(ArclyTransaction) -> Fut,
    Fut: std::future::Future<Output = (ArclyTransaction, Result<R, DataError>)>,
{
    // Move the transaction out of the slot so `work` can own it across awaits.
    let borrowed = match CURRENT_TX.try_with(|cell| cell.borrow_mut().take()) {
        Ok(Some(active)) => active,
        Ok(None) | Err(_) => return Ok(None),
    };

    let (returned, outcome) = work(borrowed).await;

    // Best effort: the access result is deliberately ignored.
    let _ = CURRENT_TX.try_with(|cell| {
        *cell.borrow_mut() = Some(returned);
    });

    Ok(Some(outcome?))
}
/// Reports whether the current task's transaction slot holds a transaction.
///
/// Returns `false` both when the slot is empty and when the task-local is not
/// accessible from the calling context.
pub fn in_transaction() -> bool {
    matches!(
        CURRENT_TX.try_with(|cell| cell.borrow().is_some()),
        Ok(true)
    )
}
/// Runs `body` inside a database transaction scoped to the current task.
///
/// Steps:
/// 1. Resolves `DataSourceRegistry<ArclyDbPool>` from `ctx` and picks the pool
///    for `ctx.tenant()`.
/// 2. Begins a transaction on that pool and installs it in the `CURRENT_TX`
///    task-local for the duration of `body`.
/// 3. After `body` finishes, takes the transaction back out of the slot and
///    commits it when `body` returned `Ok`, or rolls it back when `body`
///    returned `Err`.
///
/// # Errors
///
/// * An `Internal` error when the registry is not available from `ctx`, when
///   the transaction cannot be started, or when the commit fails.
/// * The error returned by `body`, unchanged. A rollback failure in that case
///   is logged and does not replace the handler's error.
///
/// If the transaction is no longer in the slot after `body` succeeded, the
/// handler's value is still returned, but a warning is logged because this
/// function committed nothing.
#[doc(hidden)]
pub async fn run_transactional<T, Fut>(ctx: &RequestContext, body: Fut) -> Result<T, HttpException>
where
    Fut: std::future::Future<Output = Result<T, HttpException>>,
{
    let registry = ctx
        .try_inject::<crate::data::DataSourceRegistry<ArclyDbPool>>()
        .ok_or_else(|| {
            Internal::new(
                "#[Transactional] requires DataSourceRegistry<ArclyDbPool> in the DI container",
            )
        })?;
    let pool = registry.for_tenant(ctx.tenant());
    let tx = pool
        .begin()
        .await
        .map_err(|e| Internal::new(format!("failed to begin transaction: {e}")))?;

    CURRENT_TX
        .scope(RefCell::new(Some(tx)), async move {
            let outcome = body.await;

            // Reclaim ownership of the transaction so it can be finalized.
            let tx = CURRENT_TX
                .try_with(|slot| slot.borrow_mut().take())
                .ok()
                .flatten();

            match (outcome, tx) {
                (Ok(v), Some(tx)) => {
                    tx.commit()
                        .await
                        .map_err(|e| Internal::new(format!("commit failed: {e}")))?;
                    Ok(v)
                }
                (Ok(v), None) => {
                    // The slot was emptied while `body` ran and never
                    // refilled (for example a `with_current_tx` call whose
                    // future did not run to completion). Surface this instead
                    // of reporting success silently.
                    tracing::warn!(
                        "transaction missing from task-local slot after handler success; \
                         no commit was issued"
                    );
                    Ok(v)
                }
                (Err(e), Some(tx)) => {
                    // The handler's error takes precedence over a rollback
                    // failure, which is only logged.
                    if let Err(rb) = tx.rollback().await {
                        tracing::error!(error = %rb, "rollback failed after handler error");
                    }
                    Err(e)
                }
                (Err(e), None) => Err(e),
            }
        })
        .await
}