1use std::cell::RefCell;
26
27use crate::data::db::ArclyDbPool;
28#[cfg(any(feature = "db-sqlx", feature = "db-seaorm", feature = "db-diesel"))]
29use crate::data::db::DbDriver;
30use crate::data::DataError;
31use crate::web::context::RequestContext;
32use crate::web::error::{HttpException, Internal};
33
/// A database transaction handle, tagged by the driver that produced it.
///
/// The set of variants depends on the enabled `db-*` features. With neither
/// `db-sqlx` nor `db-seaorm` enabled, the enum has no variants at all.
pub enum ArclyTransaction {
    /// Transaction on an sqlx `Any` pool (feature `db-sqlx`).
    #[cfg(feature = "db-sqlx")]
    Sqlx(sqlx::Transaction<'static, sqlx::Any>),
    /// SeaORM transaction (feature `db-seaorm`).
    #[cfg(feature = "db-seaorm")]
    SeaOrm(sea_orm::DatabaseTransaction),
}
43
44impl ArclyTransaction {
45 pub async fn commit(self) -> Result<(), DataError> {
46 match self {
47 #[cfg(feature = "db-sqlx")]
48 ArclyTransaction::Sqlx(tx) => tx.commit().await.map_err(|e| DataError(e.to_string())),
49 #[cfg(feature = "db-seaorm")]
50 ArclyTransaction::SeaOrm(tx) => tx.commit().await.map_err(|e| DataError(e.to_string())),
51 #[allow(unreachable_patterns)]
52 _ => Ok(()),
53 }
54 }
55
56 pub async fn rollback(self) -> Result<(), DataError> {
57 match self {
58 #[cfg(feature = "db-sqlx")]
59 ArclyTransaction::Sqlx(tx) => tx.rollback().await.map_err(|e| DataError(e.to_string())),
60 #[cfg(feature = "db-seaorm")]
61 ArclyTransaction::SeaOrm(tx) => {
62 tx.rollback().await.map_err(|e| DataError(e.to_string()))
63 }
64 #[allow(unreachable_patterns)]
65 _ => Ok(()),
66 }
67 }
68}
69
70impl ArclyDbPool {
71 #[allow(unreachable_code)]
73 pub async fn begin(&self) -> Result<ArclyTransaction, DataError> {
74 match self.primary() {
75 #[cfg(feature = "db-sqlx")]
76 DbDriver::Sqlx(pool) => Ok(ArclyTransaction::Sqlx(
77 pool.begin().await.map_err(|e| DataError(e.to_string()))?,
78 )),
79 #[cfg(feature = "db-seaorm")]
80 DbDriver::SeaOrm(conn) => {
81 use sea_orm::TransactionTrait;
82 Ok(ArclyTransaction::SeaOrm(
83 conn.begin().await.map_err(|e| DataError(e.to_string()))?,
84 ))
85 }
86 #[cfg(feature = "db-diesel")]
87 DbDriver::Diesel(_) => Err(DataError(
88 "#[Transactional] is not supported on sync Diesel pools — \
89 run the whole transaction inside DieselBlockingPool::transaction(…)"
90 .into(),
91 )),
92 #[allow(unreachable_patterns)]
93 _ => Err(DataError("no database driver feature enabled".into())),
94 }
95 }
96}
97
tokio::task_local! {
    /// Per-task slot holding the transaction opened by `run_transactional`.
    ///
    /// `Some` while the transaction is parked here. `with_current_tx` takes
    /// it out for the duration of a call and puts it back afterwards, so the
    /// slot reads `None` while a call is in flight.
    static CURRENT_TX: RefCell<Option<ArclyTransaction>>;
}
104
105pub async fn with_current_tx<R, F, Fut>(work: F) -> Result<Option<R>, DataError>
112where
113 F: FnOnce(ArclyTransaction) -> Fut,
114 Fut: std::future::Future<Output = (ArclyTransaction, Result<R, DataError>)>,
115{
116 let taken = CURRENT_TX
118 .try_with(|slot| slot.borrow_mut().take())
119 .ok()
120 .flatten();
121
122 let Some(tx) = taken else { return Ok(None) };
123
124 let (tx, result) = work(tx).await;
125
126 let _ = CURRENT_TX.try_with(|slot| *slot.borrow_mut() = Some(tx));
128 result.map(Some)
129}
130
131pub fn in_transaction() -> bool {
133 CURRENT_TX
134 .try_with(|slot| slot.borrow().is_some())
135 .unwrap_or(false)
136}
137
138#[doc(hidden)]
147pub async fn run_transactional<T, Fut>(ctx: &RequestContext, body: Fut) -> Result<T, HttpException>
148where
149 Fut: std::future::Future<Output = Result<T, HttpException>>,
150{
151 let registry = ctx
152 .try_inject::<crate::data::DataSourceRegistry<ArclyDbPool>>()
153 .ok_or_else(|| {
154 Internal::new(
155 "#[Transactional] requires DataSourceRegistry<ArclyDbPool> in the DI container",
156 )
157 })?;
158
159 let pool = registry.for_tenant(ctx.tenant());
160 let tx = pool
161 .begin()
162 .await
163 .map_err(|e| Internal::new(format!("failed to begin transaction: {e}")))?;
164
165 CURRENT_TX
166 .scope(RefCell::new(Some(tx)), async move {
167 let outcome = body.await;
168
169 let tx = CURRENT_TX
170 .try_with(|slot| slot.borrow_mut().take())
171 .ok()
172 .flatten();
173
174 match (outcome, tx) {
175 (Ok(v), Some(tx)) => {
176 tx.commit()
177 .await
178 .map_err(|e| Internal::new(format!("commit failed: {e}")))?;
179 Ok(v)
180 }
181 (Ok(v), None) => Ok(v),
183 (Err(e), Some(tx)) => {
184 if let Err(rb) = tx.rollback().await {
185 tracing::error!(error = %rb, "rollback failed after handler error");
186 }
187 Err(e)
188 }
189 (Err(e), None) => Err(e),
190 }
191 })
192 .await
193}