axum_sqlx_tx/
tx.rs

//! The [`Tx`](crate::Tx) extractor, which gives request handlers access to a database transaction.
2
3use std::{fmt, marker::PhantomData};
4
5use axum_core::{
6    extract::{FromRef, FromRequestParts},
7    response::IntoResponse,
8};
9use futures_core::{future::BoxFuture, stream::BoxStream};
10use http::request::Parts;
11use parking_lot::{lock_api::ArcMutexGuard, RawMutex};
12
13use crate::{
14    extension::{Extension, LazyTransaction},
15    Config, Error, Marker, State,
16};
17
/// An `axum` extractor for a database transaction.
///
/// `&mut Tx` implements [`sqlx::Executor`] so it can be used directly with [`sqlx::query()`]
/// (and [`sqlx::query_as()`], the corresponding macros, etc.):
///
/// ```
/// use axum_sqlx_tx::Tx;
/// use sqlx::Sqlite;
///
/// async fn handler(mut tx: Tx<Sqlite>) -> Result<(), sqlx::Error> {
///     sqlx::query("...").execute(&mut tx).await?;
///     /* ... */
/// #   Ok(())
/// }
/// ```
///
/// It also implements `Deref<Target = `[`sqlx::Transaction`]`>` and `DerefMut`, so you can call
/// methods from `Transaction` and its traits:
///
/// ```
/// use axum_sqlx_tx::Tx;
/// use sqlx::{Acquire as _, Sqlite};
///
/// async fn handler(mut tx: Tx<Sqlite>) -> Result<(), sqlx::Error> {
///     let inner = tx.begin().await?;
///     /* ... */
/// #   Ok(())
/// }
/// ```
///
/// The `E` generic parameter is the rejection type returned when the extractor fails, so it
/// controls the error response sent in that case. It defaults to [`Error`]. A custom type must
/// implement `From<axum_sqlx_tx::Error>` and `IntoResponse`:
///
/// ```
/// use axum::response::IntoResponse;
/// use axum_sqlx_tx::Tx;
/// use sqlx::Sqlite;
///
/// struct MyError(axum_sqlx_tx::Error);
///
/// // The error type must implement From<axum_sqlx_tx::Error>
/// impl From<axum_sqlx_tx::Error> for MyError {
///     fn from(error: axum_sqlx_tx::Error) -> Self {
///         Self(error)
///     }
/// }
///
/// // The error type must implement IntoResponse
/// impl IntoResponse for MyError {
///     fn into_response(self) -> axum::response::Response {
///         (http::StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
///     }
/// }
///
/// async fn handler(tx: Tx<Sqlite, MyError>) {
///     /* ... */
/// }
/// ```
pub struct Tx<DB: Marker, E = Error> {
    // Lock guard over the request's `LazyTransaction`, obtained from the request `Extension` when
    // the extractor runs. The lock stays held for as long as this `Tx` is alive.
    tx: ArcMutexGuard<RawMutex, LazyTransaction<DB>>,
    // `Tx` never stores an `E`; this marker only carries the rejection type.
    _error: PhantomData<E>,
}
80
81impl<DB: Marker, E> Tx<DB, E> {
82    /// Crate a [`State`] and [`Layer`](crate::Layer) to enable the extractor.
83    ///
84    /// This is convenient to use from a type alias, e.g.
85    ///
86    /// ```
87    /// # async fn foo() {
88    /// type Tx = axum_sqlx_tx::Tx<sqlx::Sqlite>;
89    ///
90    /// let pool: sqlx::SqlitePool = todo!();
91    /// let (state, layer) = Tx::setup(pool);
92    /// # }
93    /// ```
94    pub fn setup(pool: sqlx::Pool<DB::Driver>) -> (State<DB>, crate::Layer<DB, Error>) {
95        Config::new(pool).setup()
96    }
97
98    /// Configure extractor behaviour.
99    ///
100    /// See the [`Config`] API for available options.
101    ///
102    /// This is convenient to use from a type alias, e.g.
103    ///
104    /// ```
105    /// # async fn foo() {
106    /// type Tx = axum_sqlx_tx::Tx<sqlx::Sqlite>;
107    ///
108    /// # let pool: sqlx::SqlitePool = todo!();
109    /// let config = Tx::config(pool);
110    /// # }
111    /// ```
112    pub fn config(pool: sqlx::Pool<DB::Driver>) -> Config<DB, Error> {
113        Config::new(pool)
114    }
115
116    /// Explicitly commit the transaction.
117    ///
118    /// By default, the transaction will be committed when a successful response is returned
119    /// (specifically, when the [`Service`](crate::Service) middleware intercepts an HTTP `2XX` or
120    /// `3XX` response). This method allows the transaction to be committed explicitly.
121    ///
122    /// **Note:** trying to use the `Tx` extractor again after calling `commit` will currently
123    /// generate [`Error::OverlappingExtractors`] errors. This may change in future.
124    pub async fn commit(mut self) -> Result<(), sqlx::Error> {
125        self.tx.commit().await
126    }
127}
128
129impl<DB: Marker, E> fmt::Debug for Tx<DB, E> {
130    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131        f.debug_struct("Tx").finish_non_exhaustive()
132    }
133}
134
135impl<DB: Marker, E> AsRef<sqlx::Transaction<'static, DB::Driver>> for Tx<DB, E> {
136    fn as_ref(&self) -> &sqlx::Transaction<'static, DB::Driver> {
137        self.tx.as_ref()
138    }
139}
140
141impl<DB: Marker, E> AsMut<sqlx::Transaction<'static, DB::Driver>> for Tx<DB, E> {
142    fn as_mut(&mut self) -> &mut sqlx::Transaction<'static, DB::Driver> {
143        self.tx.as_mut()
144    }
145}
146
147impl<DB: Marker, E> std::ops::Deref for Tx<DB, E> {
148    type Target = sqlx::Transaction<'static, DB::Driver>;
149
150    fn deref(&self) -> &Self::Target {
151        self.tx.as_ref()
152    }
153}
154
155impl<DB: Marker, E> std::ops::DerefMut for Tx<DB, E> {
156    fn deref_mut(&mut self) -> &mut Self::Target {
157        self.tx.as_mut()
158    }
159}
160
161impl<DB: Marker, S, E> FromRequestParts<S> for Tx<DB, E>
162where
163    S: Sync,
164    E: From<Error> + IntoResponse + Send,
165    State<DB>: FromRef<S>,
166{
167    type Rejection = E;
168
169    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
170        let ext: &Extension<DB> = parts.extensions.get().ok_or(Error::MissingExtension)?;
171
172        let tx = ext.acquire().await?;
173
174        Ok(Self {
175            tx,
176            _error: PhantomData,
177        })
178    }
179}
180
181impl<'c, DB, E> sqlx::Executor<'c> for &'c mut Tx<DB, E>
182where
183    DB: Marker,
184    for<'t> &'t mut <DB::Driver as sqlx::Database>::Connection:
185        sqlx::Executor<'t, Database = DB::Driver>,
186    E: std::fmt::Debug + Send,
187{
188    type Database = DB::Driver;
189
190    #[allow(clippy::type_complexity)]
191    fn fetch_many<'e, 'q: 'e, Q>(
192        self,
193        query: Q,
194    ) -> BoxStream<
195        'e,
196        Result<
197            sqlx::Either<
198                <Self::Database as sqlx::Database>::QueryResult,
199                <Self::Database as sqlx::Database>::Row,
200            >,
201            sqlx::Error,
202        >,
203    >
204    where
205        'c: 'e,
206        Q: sqlx::Execute<'q, Self::Database> + 'q,
207    {
208        (&mut ***self).fetch_many(query)
209    }
210
211    fn fetch_optional<'e, 'q: 'e, Q>(
212        self,
213        query: Q,
214    ) -> BoxFuture<'e, Result<Option<<Self::Database as sqlx::Database>::Row>, sqlx::Error>>
215    where
216        'c: 'e,
217        Q: sqlx::Execute<'q, Self::Database> + 'q,
218    {
219        (&mut ***self).fetch_optional(query)
220    }
221
222    fn prepare_with<'e, 'q: 'e>(
223        self,
224        sql: &'q str,
225        parameters: &'e [<Self::Database as sqlx::Database>::TypeInfo],
226    ) -> BoxFuture<'e, Result<<Self::Database as sqlx::Database>::Statement<'q>, sqlx::Error>>
227    where
228        'c: 'e,
229    {
230        (&mut ***self).prepare_with(sql, parameters)
231    }
232
233    fn describe<'e, 'q: 'e>(
234        self,
235        sql: &'q str,
236    ) -> BoxFuture<'e, Result<sqlx::Describe<Self::Database>, sqlx::Error>>
237    where
238        'c: 'e,
239    {
240        (&mut ***self).describe(sql)
241    }
242}