1use std::marker::PhantomData;
4
5use actix_web::{dev::Extensions, FromRequest, HttpMessage, ResponseError};
6use futures_core::future::LocalBoxFuture;
7use sqlx::Transaction;
8
9use crate::{
10 error::Error,
11 slot::{Lease, Slot},
12};
13
14#[derive(Debug)]
82pub struct Tx<DB: sqlx::Database, E = Error>(Lease<sqlx::Transaction<'static, DB>>, PhantomData<E>);
83
84impl<DB: sqlx::Database, E> Tx<DB, E> {
85 pub async fn commit(self) -> Result<(), sqlx::Error> {
94 self.0.steal().commit().await
95 }
96}
97
98impl<DB: sqlx::Database, E> AsRef<sqlx::Transaction<'static, DB>> for Tx<DB, E> {
99 fn as_ref(&self) -> &sqlx::Transaction<'static, DB> {
100 &self.0
101 }
102}
103
104impl<DB: sqlx::Database, E> AsMut<sqlx::Transaction<'static, DB>> for Tx<DB, E> {
105 fn as_mut(&mut self) -> &mut sqlx::Transaction<'static, DB> {
106 &mut self.0
107 }
108}
109
110impl<DB: sqlx::Database, E> std::ops::Deref for Tx<DB, E> {
111 type Target = sqlx::Transaction<'static, DB>;
112
113 fn deref(&self) -> &Self::Target {
114 &self.0
115 }
116}
117
118impl<DB: sqlx::Database, E> std::ops::DerefMut for Tx<DB, E> {
119 fn deref_mut(&mut self) -> &mut Self::Target {
120 &mut self.0
121 }
122}
123
124impl<DB: sqlx::Database, E> FromRequest for Tx<DB, E>
125where
126 E: From<Error> + ResponseError + 'static,
127{
128 type Error = E;
129
130 type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
131
132 #[inline]
133 fn from_request(req: &actix_web::HttpRequest, _: &mut actix_web::dev::Payload) -> Self::Future {
134 let req = req.clone();
135 Box::pin(async move {
136 let mut ext = req
138 .extensions_mut()
139 .remove::<Lazy<DB>>()
140 .ok_or(Error::MissingExtension)?;
141
142 let tx = ext.get_or_begin().await?;
143
144 Ok(Self(tx, PhantomData))
145 })
146 }
147}
148
149pub(crate) struct TxSlot<DB: sqlx::Database>(Slot<Option<Slot<Transaction<'static, DB>>>>);
151
152impl<DB: sqlx::Database> TxSlot<DB> {
153 pub(crate) fn bind(extensions: &mut Extensions, pool: &sqlx::Pool<DB>) -> Self {
158 let (slot, tx) = Slot::new_leased(None);
159 extensions.insert(Lazy {
160 pool: pool.clone(),
161 tx,
162 });
163 Self(slot)
164 }
165
166 pub(crate) async fn commit(self) -> Result<(), sqlx::Error> {
167 if let Some(tx) = self.0.into_inner().flatten().and_then(Slot::into_inner) {
168 tx.commit().await?;
169 }
170 Ok(())
171 }
172}
173
174struct Lazy<DB: sqlx::Database> {
179 pool: sqlx::Pool<DB>,
180 tx: Lease<Option<Slot<Transaction<'static, DB>>>>,
181}
182
183impl<DB: sqlx::Database> Lazy<DB> {
184 async fn get_or_begin(&mut self) -> Result<Lease<Transaction<'static, DB>>, Error> {
185 let tx = if let Some(tx) = self.tx.as_mut() {
186 tx
187 } else {
188 let tx = self.pool.begin().await?;
189 self.tx.insert(Slot::new(tx))
190 };
191
192 tx.lease().ok_or(Error::OverlappingExtractors)
193 }
194}
195
#[cfg(any(
    feature = "any",
    feature = "mssql",
    feature = "mysql",
    feature = "postgres",
    feature = "sqlite"
))]
mod sqlx_impls {
    // `sqlx::Executor` implementations for `&mut Tx<DB, E>`, one per enabled database
    // feature, so a `Tx` can be passed directly to sqlx query functions.

    use std::fmt::Debug;

    use futures_core::{future::BoxFuture, stream::BoxStream};

    // Implements `sqlx::Executor` for `&mut Tx<$db, E>` by forwarding each required
    // method to the `Executor` impl of `&mut sqlx::Transaction<'static, $db>`.
    // The signatures must mirror `sqlx::Executor` exactly; only the bodies are ours.
    // `&mut **self` goes through `Tx`'s `DerefMut` to reach the transaction.
    macro_rules! forward_executor {
        ($db:path) => {
            impl<'c, E: Debug + Send> sqlx::Executor<'c> for &'c mut super::Tx<$db, E> {
                type Database = $db;

                #[allow(clippy::type_complexity)]
                fn fetch_many<'e, 'q: 'e, Q: 'q>(
                    self,
                    query: Q,
                ) -> BoxStream<
                    'e,
                    Result<
                        sqlx::Either<
                            <Self::Database as sqlx::Database>::QueryResult,
                            <Self::Database as sqlx::Database>::Row,
                        >,
                        sqlx::Error,
                    >,
                >
                where
                    'c: 'e,
                    Q: sqlx::Execute<'q, Self::Database>,
                {
                    sqlx::Executor::fetch_many(&mut **self, query)
                }

                fn fetch_optional<'e, 'q: 'e, Q: 'q>(
                    self,
                    query: Q,
                ) -> BoxFuture<
                    'e,
                    Result<Option<<Self::Database as sqlx::Database>::Row>, sqlx::Error>,
                >
                where
                    'c: 'e,
                    Q: sqlx::Execute<'q, Self::Database>,
                {
                    sqlx::Executor::fetch_optional(&mut **self, query)
                }

                fn prepare_with<'e, 'q: 'e>(
                    self,
                    sql: &'q str,
                    parameters: &'e [<Self::Database as sqlx::Database>::TypeInfo],
                ) -> BoxFuture<
                    'e,
                    Result<
                        <Self::Database as sqlx::database::HasStatement<'q>>::Statement,
                        sqlx::Error,
                    >,
                >
                where
                    'c: 'e,
                {
                    sqlx::Executor::prepare_with(&mut **self, sql, parameters)
                }

                fn describe<'e, 'q: 'e>(
                    self,
                    sql: &'q str,
                ) -> BoxFuture<'e, Result<sqlx::Describe<Self::Database>, sqlx::Error>>
                where
                    'c: 'e,
                {
                    sqlx::Executor::describe(&mut **self, sql)
                }
            }
        };
    }

    #[cfg(feature = "any")]
    forward_executor!(sqlx::Any);

    #[cfg(feature = "mssql")]
    forward_executor!(sqlx::Mssql);

    #[cfg(feature = "mysql")]
    forward_executor!(sqlx::MySql);

    #[cfg(feature = "postgres")]
    forward_executor!(sqlx::Postgres);

    #[cfg(feature = "sqlite")]
    forward_executor!(sqlx::Sqlite);
}