axum_sqlx_tx/tx.rs
//! The [`Tx`](crate::Tx) extractor, which gives request handlers access to a database transaction.
2
3use std::{fmt, marker::PhantomData};
4
5use axum_core::{
6 extract::{FromRef, FromRequestParts},
7 response::IntoResponse,
8};
9use futures_core::{future::BoxFuture, stream::BoxStream};
10use http::request::Parts;
11use parking_lot::{lock_api::ArcMutexGuard, RawMutex};
12
13use crate::{
14 extension::{Extension, LazyTransaction},
15 Config, Error, Marker, State,
16};
17
18/// An `axum` extractor for a database transaction.
19///
20/// `&mut Tx` implements [`sqlx::Executor`] so it can be used directly with [`sqlx::query()`]
21/// (and [`sqlx::query_as()`], the corresponding macros, etc.):
22///
23/// ```
24/// use axum_sqlx_tx::Tx;
25/// use sqlx::Sqlite;
26///
27/// async fn handler(mut tx: Tx<Sqlite>) -> Result<(), sqlx::Error> {
28/// sqlx::query("...").execute(&mut tx).await?;
29/// /* ... */
30/// # Ok(())
31/// }
32/// ```
33///
34/// It also implements `Deref<Target = `[`sqlx::Transaction`]`>` and `DerefMut`, so you can call
35/// methods from `Transaction` and its traits:
36///
37/// ```
38/// use axum_sqlx_tx::Tx;
39/// use sqlx::{Acquire as _, Sqlite};
40///
41/// async fn handler(mut tx: Tx<Sqlite>) -> Result<(), sqlx::Error> {
42/// let inner = tx.begin().await?;
43/// /* ... */
44/// # Ok(())
45/// }
46/// ```
47///
48/// The `E` generic parameter controls the error type returned when the extractor fails. This can be
49/// used to configure the error response returned when the extractor fails:
50///
51/// ```
52/// use axum::response::IntoResponse;
53/// use axum_sqlx_tx::Tx;
54/// use sqlx::Sqlite;
55///
56/// struct MyError(axum_sqlx_tx::Error);
57///
58/// // The error type must implement From<axum_sqlx_tx::Error>
59/// impl From<axum_sqlx_tx::Error> for MyError {
60/// fn from(error: axum_sqlx_tx::Error) -> Self {
61/// Self(error)
62/// }
63/// }
64///
65/// // The error type must implement IntoResponse
66/// impl IntoResponse for MyError {
67/// fn into_response(self) -> axum::response::Response {
68/// (http::StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response()
69/// }
70/// }
71///
72/// async fn handler(tx: Tx<Sqlite, MyError>) {
73/// /* ... */
74/// }
75/// ```
76pub struct Tx<DB: Marker, E = Error> {
77 tx: ArcMutexGuard<RawMutex, LazyTransaction<DB>>,
78 _error: PhantomData<E>,
79}
80
81impl<DB: Marker, E> Tx<DB, E> {
82 /// Crate a [`State`] and [`Layer`](crate::Layer) to enable the extractor.
83 ///
84 /// This is convenient to use from a type alias, e.g.
85 ///
86 /// ```
87 /// # async fn foo() {
88 /// type Tx = axum_sqlx_tx::Tx<sqlx::Sqlite>;
89 ///
90 /// let pool: sqlx::SqlitePool = todo!();
91 /// let (state, layer) = Tx::setup(pool);
92 /// # }
93 /// ```
94 pub fn setup(pool: sqlx::Pool<DB::Driver>) -> (State<DB>, crate::Layer<DB, Error>) {
95 Config::new(pool).setup()
96 }
97
98 /// Configure extractor behaviour.
99 ///
100 /// See the [`Config`] API for available options.
101 ///
102 /// This is convenient to use from a type alias, e.g.
103 ///
104 /// ```
105 /// # async fn foo() {
106 /// type Tx = axum_sqlx_tx::Tx<sqlx::Sqlite>;
107 ///
108 /// # let pool: sqlx::SqlitePool = todo!();
109 /// let config = Tx::config(pool);
110 /// # }
111 /// ```
112 pub fn config(pool: sqlx::Pool<DB::Driver>) -> Config<DB, Error> {
113 Config::new(pool)
114 }
115
116 /// Explicitly commit the transaction.
117 ///
118 /// By default, the transaction will be committed when a successful response is returned
119 /// (specifically, when the [`Service`](crate::Service) middleware intercepts an HTTP `2XX` or
120 /// `3XX` response). This method allows the transaction to be committed explicitly.
121 ///
122 /// **Note:** trying to use the `Tx` extractor again after calling `commit` will currently
123 /// generate [`Error::OverlappingExtractors`] errors. This may change in future.
124 pub async fn commit(mut self) -> Result<(), sqlx::Error> {
125 self.tx.commit().await
126 }
127}
128
129impl<DB: Marker, E> fmt::Debug for Tx<DB, E> {
130 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131 f.debug_struct("Tx").finish_non_exhaustive()
132 }
133}
134
135impl<DB: Marker, E> AsRef<sqlx::Transaction<'static, DB::Driver>> for Tx<DB, E> {
136 fn as_ref(&self) -> &sqlx::Transaction<'static, DB::Driver> {
137 self.tx.as_ref()
138 }
139}
140
141impl<DB: Marker, E> AsMut<sqlx::Transaction<'static, DB::Driver>> for Tx<DB, E> {
142 fn as_mut(&mut self) -> &mut sqlx::Transaction<'static, DB::Driver> {
143 self.tx.as_mut()
144 }
145}
146
147impl<DB: Marker, E> std::ops::Deref for Tx<DB, E> {
148 type Target = sqlx::Transaction<'static, DB::Driver>;
149
150 fn deref(&self) -> &Self::Target {
151 self.tx.as_ref()
152 }
153}
154
155impl<DB: Marker, E> std::ops::DerefMut for Tx<DB, E> {
156 fn deref_mut(&mut self) -> &mut Self::Target {
157 self.tx.as_mut()
158 }
159}
160
161impl<DB: Marker, S, E> FromRequestParts<S> for Tx<DB, E>
162where
163 S: Sync,
164 E: From<Error> + IntoResponse + Send,
165 State<DB>: FromRef<S>,
166{
167 type Rejection = E;
168
169 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
170 let ext: &Extension<DB> = parts.extensions.get().ok_or(Error::MissingExtension)?;
171
172 let tx = ext.acquire().await?;
173
174 Ok(Self {
175 tx,
176 _error: PhantomData,
177 })
178 }
179}
180
181impl<'c, DB, E> sqlx::Executor<'c> for &'c mut Tx<DB, E>
182where
183 DB: Marker,
184 for<'t> &'t mut <DB::Driver as sqlx::Database>::Connection:
185 sqlx::Executor<'t, Database = DB::Driver>,
186 E: std::fmt::Debug + Send,
187{
188 type Database = DB::Driver;
189
190 #[allow(clippy::type_complexity)]
191 fn fetch_many<'e, 'q: 'e, Q>(
192 self,
193 query: Q,
194 ) -> BoxStream<
195 'e,
196 Result<
197 sqlx::Either<
198 <Self::Database as sqlx::Database>::QueryResult,
199 <Self::Database as sqlx::Database>::Row,
200 >,
201 sqlx::Error,
202 >,
203 >
204 where
205 'c: 'e,
206 Q: sqlx::Execute<'q, Self::Database> + 'q,
207 {
208 (&mut ***self).fetch_many(query)
209 }
210
211 fn fetch_optional<'e, 'q: 'e, Q>(
212 self,
213 query: Q,
214 ) -> BoxFuture<'e, Result<Option<<Self::Database as sqlx::Database>::Row>, sqlx::Error>>
215 where
216 'c: 'e,
217 Q: sqlx::Execute<'q, Self::Database> + 'q,
218 {
219 (&mut ***self).fetch_optional(query)
220 }
221
222 fn prepare_with<'e, 'q: 'e>(
223 self,
224 sql: &'q str,
225 parameters: &'e [<Self::Database as sqlx::Database>::TypeInfo],
226 ) -> BoxFuture<'e, Result<<Self::Database as sqlx::Database>::Statement<'q>, sqlx::Error>>
227 where
228 'c: 'e,
229 {
230 (&mut ***self).prepare_with(sql, parameters)
231 }
232
233 fn describe<'e, 'q: 'e>(
234 self,
235 sql: &'q str,
236 ) -> BoxFuture<'e, Result<sqlx::Describe<Self::Database>, sqlx::Error>>
237 where
238 'c: 'e,
239 {
240 (&mut ***self).describe(sql)
241 }
242}