axum_sqlx_tx/
layer.rs

1//! A [`tower_layer::Layer`] that enables the [`Tx`](crate::Tx) extractor.
2
3use std::marker::PhantomData;
4
5use axum_core::response::IntoResponse;
6use bytes::Bytes;
7use futures_core::future::BoxFuture;
8use http_body::Body;
9
10use crate::{extension::Extension, Marker, State};
11
12/// A [`tower_layer::Layer`] that enables the [`Tx`] extractor.
13///
14/// This layer adds a lazily-initialised transaction to the [request extensions]. The first time the
15/// [`Tx`] extractor is used on a request, a connection is acquired from the configured
16/// [`sqlx::Pool`] and a transaction is started on it. The same transaction will be returned for
17/// subsequent uses of [`Tx`] on the same request. The inner service is then called as normal. Once
18/// the inner service responds, the transaction is committed or rolled back depending on the status
19/// code of the response.
20///
21/// [`Tx`]: crate::Tx
22/// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html
23pub struct Layer<DB: Marker, E> {
24    state: State<DB>,
25    _error: PhantomData<E>,
26}
27
28impl<DB: Marker, E> Layer<DB, E>
29where
30    E: IntoResponse,
31    sqlx::Error: Into<E>,
32{
33    pub(crate) fn new(state: State<DB>) -> Self {
34        Self {
35            state,
36            _error: PhantomData,
37        }
38    }
39}
40
41impl<DB: Marker, E> Clone for Layer<DB, E> {
42    fn clone(&self) -> Self {
43        Self {
44            state: self.state.clone(),
45            _error: self._error,
46        }
47    }
48}
49
50impl<DB: Marker, S, E> tower_layer::Layer<S> for Layer<DB, E>
51where
52    E: IntoResponse,
53    sqlx::Error: Into<E>,
54{
55    type Service = Service<DB, S, E>;
56
57    fn layer(&self, inner: S) -> Self::Service {
58        Service {
59            state: self.state.clone(),
60            inner,
61            _error: self._error,
62        }
63    }
64}
65
66/// A [`tower_service::Service`] that enables the [`Tx`](crate::Tx) extractor.
67///
68/// See [`Layer`] for more information.
69pub struct Service<DB: Marker, S, E> {
70    state: State<DB>,
71    inner: S,
72    _error: PhantomData<E>,
73}
74
75// can't simply derive because `DB` isn't `Clone`
76impl<DB: Marker, S: Clone, E> Clone for Service<DB, S, E> {
77    fn clone(&self) -> Self {
78        Self {
79            state: self.state.clone(),
80            inner: self.inner.clone(),
81            _error: self._error,
82        }
83    }
84}
85
86impl<DB: Marker, S, E, ReqBody, ResBody> tower_service::Service<http::Request<ReqBody>>
87    for Service<DB, S, E>
88where
89    S: tower_service::Service<
90        http::Request<ReqBody>,
91        Response = http::Response<ResBody>,
92        Error = std::convert::Infallible,
93    >,
94    S::Future: Send + 'static,
95    E: IntoResponse,
96    sqlx::Error: Into<E>,
97    ResBody: Body<Data = Bytes> + Send + 'static,
98    ResBody::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
99{
100    type Response = http::Response<axum_core::body::Body>;
101    type Error = S::Error;
102    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
103
104    fn poll_ready(
105        &mut self,
106        cx: &mut std::task::Context<'_>,
107    ) -> std::task::Poll<Result<(), Self::Error>> {
108        self.inner.poll_ready(cx).map_err(|err| match err {})
109    }
110
111    fn call(&mut self, mut req: http::Request<ReqBody>) -> Self::Future {
112        let ext = Extension::new(self.state.clone());
113        req.extensions_mut().insert(ext.clone());
114
115        let res = self.inner.call(req);
116
117        Box::pin(async move {
118            let res = res.await.unwrap(); // inner service is infallible
119
120            if !res.status().is_server_error() && !res.status().is_client_error() {
121                if let Err(error) = ext.resolve().await {
122                    return Ok(error.into().into_response());
123                }
124            }
125
126            Ok(res.map(axum_core::body::Body::new))
127        })
128    }
129}
130
#[cfg(test)]
mod tests {
    use tokio::net::TcpListener;

    use crate::{Error, State};

    use super::Layer;

    /// Compile-time check that [`Layer`] satisfies the trait bounds axum imposes on
    /// `Router::layer`, and that the resulting router can be passed to `axum::serve`.
    ///
    /// This is deliberately not a `#[test]`: it only needs to type-check, so the values it
    /// needs are stubbed out with `todo!()`.
    // The `todo!()` stubs make the rest of the body unreachable and leave this function and its
    // bindings unused. That is intentional, so the corresponding lints are silenced.
    #[allow(unused, unreachable_code, clippy::diverging_sub_expression)]
    fn layer_compiles() {
        let tx_state: State<sqlx::Sqlite> = todo!();
        let tx_layer = Layer::<_, Error>::new(tx_state);

        let router = axum::Router::new()
            .route("/", axum::routing::get(|| async { "hello" }))
            .layer(tx_layer);

        let tcp: TcpListener = todo!();
        axum::serve(tcp, router);
    }
}