1use std::cell::RefCell;
26
27use crate::data::db::ArclyDbPool;
28#[cfg(any(feature = "db-sqlx", feature = "db-seaorm", feature = "db-diesel"))]
29use crate::data::db::DbDriver;
30use crate::data::DataError;
31use crate::web::context::RequestContext;
32use crate::web::error::{HttpException, Internal};
33
/// A database transaction opened by [`ArclyDbPool::begin`], wrapping the
/// driver-specific transaction handle.
///
/// Only the variants whose driver feature is enabled are compiled in; with no
/// async driver feature enabled the enum has no variants.
pub enum ArclyTransaction {
    /// Transaction handle from the sea-orm driver (`db-seaorm` feature).
    #[cfg(feature = "db-seaorm")]
    SeaOrm(sea_orm::DatabaseTransaction),
    /// Transaction handle from the sqlx driver (`db-sqlx` feature), over the
    /// `sqlx::Any` backend.
    #[cfg(feature = "db-sqlx")]
    Sqlx(sqlx::Transaction<'static, sqlx::Any>),
}
43
44impl ArclyTransaction {
45 pub async fn commit(self) -> Result<(), DataError> {
46 match self {
47 #[cfg(feature = "db-sqlx")]
48 ArclyTransaction::Sqlx(tx) => tx
49 .commit()
50 .await
51 .map_err(|e| DataError::query(e.to_string())),
52 #[cfg(feature = "db-seaorm")]
53 ArclyTransaction::SeaOrm(tx) => tx
54 .commit()
55 .await
56 .map_err(|e| DataError::query(e.to_string())),
57 #[allow(unreachable_patterns)]
58 _ => Ok(()),
59 }
60 }
61
62 pub async fn rollback(self) -> Result<(), DataError> {
63 match self {
64 #[cfg(feature = "db-sqlx")]
65 ArclyTransaction::Sqlx(tx) => tx
66 .rollback()
67 .await
68 .map_err(|e| DataError::query(e.to_string())),
69 #[cfg(feature = "db-seaorm")]
70 ArclyTransaction::SeaOrm(tx) => tx
71 .rollback()
72 .await
73 .map_err(|e| DataError::query(e.to_string())),
74 #[allow(unreachable_patterns)]
75 _ => Ok(()),
76 }
77 }
78}
79
80impl ArclyDbPool {
81 #[allow(unreachable_code)]
83 pub async fn begin(&self) -> Result<ArclyTransaction, DataError> {
84 match self.primary() {
85 #[cfg(feature = "db-sqlx")]
86 DbDriver::Sqlx(pool) => Ok(ArclyTransaction::Sqlx(
87 pool.begin()
88 .await
89 .map_err(|e| DataError::connection(e.to_string()))?,
90 )),
91 #[cfg(feature = "db-seaorm")]
92 DbDriver::SeaOrm(conn) => {
93 use sea_orm::TransactionTrait;
94 Ok(ArclyTransaction::SeaOrm(
95 conn.begin()
96 .await
97 .map_err(|e| DataError::connection(e.to_string()))?,
98 ))
99 }
100 #[cfg(feature = "db-diesel")]
101 DbDriver::Diesel(_) => Err(DataError::config(
102 "#[Transactional] is not supported on sync Diesel pools — \
103 run the whole transaction inside DieselBlockingPool::transaction(…)",
104 )),
105 #[allow(unreachable_patterns)]
106 _ => Err(DataError::config("no database driver feature enabled")),
107 }
108 }
109}
110
tokio::task_local! {
    /// Slot holding the transaction opened by [`run_transactional`] for the
    /// current task, if any.
    ///
    /// Installed for the duration of the handler body via `CURRENT_TX.scope(..)`.
    /// The slot is empty while [`with_current_tx`] has lent the transaction
    /// out. `run_transactional` empties it again before committing or rolling
    /// back.
    static CURRENT_TX: RefCell<Option<ArclyTransaction>>;
}
117
118pub async fn with_current_tx<R, F, Fut>(work: F) -> Result<Option<R>, DataError>
125where
126 F: FnOnce(ArclyTransaction) -> Fut,
127 Fut: std::future::Future<Output = (ArclyTransaction, Result<R, DataError>)>,
128{
129 let taken = CURRENT_TX
131 .try_with(|slot| slot.borrow_mut().take())
132 .ok()
133 .flatten();
134
135 let Some(tx) = taken else { return Ok(None) };
136
137 let (tx, result) = work(tx).await;
138
139 let _ = CURRENT_TX.try_with(|slot| *slot.borrow_mut() = Some(tx));
141 result.map(Some)
142}
143
144pub fn in_transaction() -> bool {
146 CURRENT_TX
147 .try_with(|slot| slot.borrow().is_some())
148 .unwrap_or(false)
149}
150
151#[doc(hidden)]
160pub async fn run_transactional<T, Fut>(ctx: &RequestContext, body: Fut) -> Result<T, HttpException>
161where
162 Fut: std::future::Future<Output = Result<T, HttpException>>,
163{
164 let registry = ctx
165 .try_inject::<crate::data::DataSourceRegistry<ArclyDbPool>>()
166 .ok_or_else(|| {
167 Internal::new(
168 "#[Transactional] requires DataSourceRegistry<ArclyDbPool> in the DI container",
169 )
170 })?;
171
172 let pool = registry.for_tenant(ctx.tenant());
173 let tx = pool
174 .begin()
175 .await
176 .map_err(|e| Internal::new(format!("failed to begin transaction: {e}")))?;
177
178 CURRENT_TX
179 .scope(RefCell::new(Some(tx)), async move {
180 let outcome = body.await;
181
182 let tx = CURRENT_TX
183 .try_with(|slot| slot.borrow_mut().take())
184 .ok()
185 .flatten();
186
187 match (outcome, tx) {
188 (Ok(v), Some(tx)) => {
189 tx.commit()
190 .await
191 .map_err(|e| Internal::new(format!("commit failed: {e}")))?;
192 Ok(v)
193 }
194 (Ok(v), None) => Ok(v),
196 (Err(e), Some(tx)) => {
197 if let Err(rb) = tx.rollback().await {
198 tracing::error!(error = %rb, "rollback failed after handler error");
199 }
200 Err(e)
201 }
202 (Err(e), None) => Err(e),
203 }
204 })
205 .await
206}