actix_sqlx_tx/
tx.rs

1//! A request extension that enables the [`Tx`](crate::Tx) extractor.
2
3use std::marker::PhantomData;
4
5use actix_web::{dev::Extensions, FromRequest, HttpMessage, ResponseError};
6use futures_core::future::LocalBoxFuture;
7use sqlx::Transaction;
8
9use crate::{
10    error::Error,
11    slot::{Lease, Slot},
12};
13
14/// An `actix` extractor for a database transaction.
15///
16/// `&mut Tx` implements [`sqlx::Executor`] so it can be used directly with [`sqlx::query()`]
17/// (and [`sqlx::query_as()`], the corresponding macros, etc.):
18///
19/// ```
20/// use actix_sqlx_tx::Tx;
21/// use sqlx::Sqlite;
22///
23/// async fn handler(mut tx: Tx<Sqlite>) -> Result<(), sqlx::Error> {
24///     sqlx::query("...").execute(&mut tx).await?;
25///     /* ... */
26/// #   Ok(())
27/// }
28/// ```
29///
30/// It also implements `Deref<Target = `[`sqlx::Transaction`]`>` and `DerefMut`, so you can call
31/// methods from `Transaction` and its traits:
32///
33/// ```
34/// use actix_sqlx_tx::Tx;
35/// use sqlx::{Acquire as _, Sqlite};
36///
37/// async fn handler(mut tx: Tx<Sqlite>) -> Result<(), sqlx::Error> {
38///     let inner = tx.begin().await?;
39///     /* ... */
40/// #   Ok(())
41/// }
42/// ```
43///
44/// The `E` generic parameter controls the error type returned when the extractor fails. This can be
45/// used to configure the error response returned when the extractor fails:
46///
47/// ```
48/// use actix_web::{
49///     http::{header::ContentType, StatusCode},
50///     HttpResponse, ResponseError,
51/// };
52/// use actix_sqlx_tx::Tx;
53/// use sqlx::Sqlite;
54/// [derive(Debug)]
55/// struct MyError(actix_sqlx_tx::Error);
56///
57/// // The error type must implement From<actix_sqlx_tx::Error>
58/// impl From<actix_sqlx_tx::Error> for MyError {
59///     fn from(error: actix_sqlx_tx::Error) -> Self {
60///         Self(error)
61///     }
62/// }
63///
64/// // The error type must implement ResponseError
65/// impl ResponseError for MyError {
66///     fn error_response(&self) -> HttpResponse {
67///         HttpResponse::build(self.status_code())
68///             .insert_header(ContentType::html())
69///             .body(self.to_string())
70///     }
71///
72///     fn status_code(&self) -> StatusCode {
73///         StatusCode::INTERNAL_SERVER_ERROR
74///     }
75/// }
76///
77/// async fn handler(tx: Tx<Sqlite, MyError>) {
78///     /* ... */
79/// }
80/// ```
81#[derive(Debug)]
82pub struct Tx<DB: sqlx::Database, E = Error>(Lease<sqlx::Transaction<'static, DB>>, PhantomData<E>);
83
84impl<DB: sqlx::Database, E> Tx<DB, E> {
85    /// Explicitly commit the transaction.
86    ///
87    /// By default, the transaction will be committed when a successful response is returned
88    /// (specifically, when the [`Service`](crate::Service) middleware intercepts an HTTP `2XX`
89    /// response). This method allows the transaction to be committed explicitly.
90    ///
91    /// **Note:** trying to use the `Tx` extractor again after calling `commit` will currently
92    /// generate [`Error::OverlappingExtractors`] errors. This may change in future.
93    pub async fn commit(self) -> Result<(), sqlx::Error> {
94        self.0.steal().commit().await
95    }
96}
97
98impl<DB: sqlx::Database, E> AsRef<sqlx::Transaction<'static, DB>> for Tx<DB, E> {
99    fn as_ref(&self) -> &sqlx::Transaction<'static, DB> {
100        &self.0
101    }
102}
103
104impl<DB: sqlx::Database, E> AsMut<sqlx::Transaction<'static, DB>> for Tx<DB, E> {
105    fn as_mut(&mut self) -> &mut sqlx::Transaction<'static, DB> {
106        &mut self.0
107    }
108}
109
110impl<DB: sqlx::Database, E> std::ops::Deref for Tx<DB, E> {
111    type Target = sqlx::Transaction<'static, DB>;
112
113    fn deref(&self) -> &Self::Target {
114        &self.0
115    }
116}
117
118impl<DB: sqlx::Database, E> std::ops::DerefMut for Tx<DB, E> {
119    fn deref_mut(&mut self) -> &mut Self::Target {
120        &mut self.0
121    }
122}
123
124impl<DB: sqlx::Database, E> FromRequest for Tx<DB, E>
125where
126    E: From<Error> + ResponseError + 'static,
127{
128    type Error = E;
129
130    type Future = LocalBoxFuture<'static, Result<Self, Self::Error>>;
131
132    #[inline]
133    fn from_request(req: &actix_web::HttpRequest, _: &mut actix_web::dev::Payload) -> Self::Future {
134        let req = req.clone();
135        Box::pin(async move {
136            // drop ext, or it will drop after request finish
137            let mut ext = req
138                .extensions_mut()
139                .remove::<Lazy<DB>>()
140                .ok_or(Error::MissingExtension)?;
141
142            let tx = ext.get_or_begin().await?;
143
144            Ok(Self(tx, PhantomData))
145        })
146    }
147}
148
149/// The OG `Slot` – the transaction (if any) returns here when the `Extension` is dropped.
150pub(crate) struct TxSlot<DB: sqlx::Database>(Slot<Option<Slot<Transaction<'static, DB>>>>);
151
152impl<DB: sqlx::Database> TxSlot<DB> {
153    /// Create a `TxSlot` bound to the given request extensions.
154    ///
155    /// When the request extensions are dropped, `commit` can be called to commit the transaction
156    /// (if any).
157    pub(crate) fn bind(extensions: &mut Extensions, pool: &sqlx::Pool<DB>) -> Self {
158        let (slot, tx) = Slot::new_leased(None);
159        extensions.insert(Lazy {
160            pool: pool.clone(),
161            tx,
162        });
163        Self(slot)
164    }
165
166    pub(crate) async fn commit(self) -> Result<(), sqlx::Error> {
167        if let Some(tx) = self.0.into_inner().flatten().and_then(Slot::into_inner) {
168            tx.commit().await?;
169        }
170        Ok(())
171    }
172}
173
174/// A lazily acquired transaction.
175///
176/// When the transaction is started, it's inserted into the `Option` leased from the `TxSlot`, so
177/// that when `Lazy` is dropped the transaction is moved to the `TxSlot`.
178struct Lazy<DB: sqlx::Database> {
179    pool: sqlx::Pool<DB>,
180    tx: Lease<Option<Slot<Transaction<'static, DB>>>>,
181}
182
183impl<DB: sqlx::Database> Lazy<DB> {
184    async fn get_or_begin(&mut self) -> Result<Lease<Transaction<'static, DB>>, Error> {
185        let tx = if let Some(tx) = self.tx.as_mut() {
186            tx
187        } else {
188            let tx = self.pool.begin().await?;
189            self.tx.insert(Slot::new(tx))
190        };
191
192        tx.lease().ok_or(Error::OverlappingExtractors)
193    }
194}
195
// `sqlx::Executor` implementations for `&mut Tx`, one per enabled database backend feature.
#[cfg(any(
    feature = "any",
    feature = "mssql",
    feature = "mysql",
    feature = "postgres",
    feature = "sqlite"
))]
mod sqlx_impls {
    use std::fmt::Debug;

    use futures_core::{future::BoxFuture, stream::BoxStream};
    // The forwarding calls below (`(&mut **self).fetch_many(..)` etc.) are method calls on
    // `&mut sqlx::Transaction`. They only resolve when the `Executor` trait is in scope, and
    // implementing a trait does not bring it into scope.
    use sqlx::Executor as _;

    // Implements `sqlx::Executor` for `&mut Tx<$db, E>`. Each method reborrows the inner
    // transaction (`**self` goes through `Tx`'s `Deref`/`DerefMut` impls) and forwards the call.
    //
    // `E: Debug + Send` makes `&mut Tx<$db, E>` `Debug` and `Send`, since `Tx` derives `Debug`
    // and holds `PhantomData<E>`. This is presumably required by `Executor`'s supertraits.
    macro_rules! impl_executor {
        ($db:path) => {
            impl<'c, E: Debug + Send> sqlx::Executor<'c> for &'c mut super::Tx<$db, E> {
                type Database = $db;

                // The nested stream/result/either return type is dictated by the trait signature.
                #[allow(clippy::type_complexity)]
                fn fetch_many<'e, 'q: 'e, Q: 'q>(
                    self,
                    query: Q,
                ) -> BoxStream<
                    'e,
                    Result<
                        sqlx::Either<
                            <Self::Database as sqlx::Database>::QueryResult,
                            <Self::Database as sqlx::Database>::Row,
                        >,
                        sqlx::Error,
                    >,
                >
                where
                    'c: 'e,
                    Q: sqlx::Execute<'q, Self::Database>,
                {
                    (&mut **self).fetch_many(query)
                }

                fn fetch_optional<'e, 'q: 'e, Q: 'q>(
                    self,
                    query: Q,
                ) -> BoxFuture<
                    'e,
                    Result<Option<<Self::Database as sqlx::Database>::Row>, sqlx::Error>,
                >
                where
                    'c: 'e,
                    Q: sqlx::Execute<'q, Self::Database>,
                {
                    (&mut **self).fetch_optional(query)
                }

                fn prepare_with<'e, 'q: 'e>(
                    self,
                    sql: &'q str,
                    parameters: &'e [<Self::Database as sqlx::Database>::TypeInfo],
                ) -> BoxFuture<
                    'e,
                    Result<
                        <Self::Database as sqlx::database::HasStatement<'q>>::Statement,
                        sqlx::Error,
                    >,
                >
                where
                    'c: 'e,
                {
                    (&mut **self).prepare_with(sql, parameters)
                }

                fn describe<'e, 'q: 'e>(
                    self,
                    sql: &'q str,
                ) -> BoxFuture<'e, Result<sqlx::Describe<Self::Database>, sqlx::Error>>
                where
                    'c: 'e,
                {
                    (&mut **self).describe(sql)
                }
            }
        };
    }

    #[cfg(feature = "any")]
    impl_executor!(sqlx::Any);

    #[cfg(feature = "mssql")]
    impl_executor!(sqlx::Mssql);

    #[cfg(feature = "mysql")]
    impl_executor!(sqlx::MySql);

    #[cfg(feature = "postgres")]
    impl_executor!(sqlx::Postgres);

    #[cfg(feature = "sqlite")]
    impl_executor!(sqlx::Sqlite);
}