actix_sqlx_tx/
middleware.rs

1use std::{marker::PhantomData, rc::Rc};
2
3use actix_utils::future::{ready, Ready};
4use actix_web::{
5    body::MessageBody,
6    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
7    HttpMessage, ResponseError,
8};
9use futures_core::future::LocalBoxFuture;
10
11use crate::{tx::TxSlot, Error};
12
13/// This middleware adds a lazily-initialised transaction to the [request extensions]. The first time the
14/// [`Tx`] extractor is used on a request, a connection is acquired from the configured
15/// [`sqlx::Pool`] and a transaction is started on it. The same transaction will be returned for
16/// subsequent uses of [`Tx`] on the same request. The inner service is then called as normal. Once
17/// the inner service responds, the transaction is committed or rolled back depending on the status
18/// code of the response.
19///
20/// [`Tx`]: crate::Tx
21/// [request extensions]: https://docs.rs/actix-web/latest/actix_web/dev/struct.Extensions.html
22/// [refer to axum-sqlx-tx]: https://github.com/wasdacraic/axum-sqlx-tx
23#[derive(Clone)]
24pub struct TransactionMiddleware<DB: sqlx::Database, E = Error> {
25    pool: Rc<sqlx::Pool<DB>>,
26    _error: PhantomData<E>,
27}
28
29impl<DB: sqlx::Database> TransactionMiddleware<DB> {
30    pub fn new(pool: sqlx::Pool<DB>) -> Self {
31        Self::new_with_error(pool)
32    }
33
34    /// Construct a new layer with a specific error type.
35    pub fn new_with_error<E>(pool: sqlx::Pool<DB>) -> TransactionMiddleware<DB, E> {
36        TransactionMiddleware {
37            pool: Rc::new(pool),
38            _error: PhantomData,
39        }
40    }
41}
42
43impl<S, B, DB, E> Transform<S, ServiceRequest> for TransactionMiddleware<DB, E>
44where
45    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
46    S::Future: 'static,
47    B: MessageBody + 'static,
48    DB: sqlx::Database + 'static,
49    E: From<Error> + ResponseError + 'static,
50{
51    type Response = ServiceResponse<B>;
52
53    type Error = actix_web::Error;
54
55    type Transform = InnerTransactionMiddleware<S, DB, E>;
56
57    type InitError = ();
58
59    type Future = Ready<Result<Self::Transform, Self::InitError>>;
60
61    fn new_transform(&self, service: S) -> Self::Future {
62        ready(Ok(InnerTransactionMiddleware {
63            service: Rc::new(service),
64            pool: Rc::clone(&self.pool),
65            _error: self._error,
66        }))
67    }
68}
69
70#[doc(hidden)]
71pub struct InnerTransactionMiddleware<S, DB: sqlx::Database, E = Error> {
72    service: Rc<S>,
73    pool: Rc<sqlx::Pool<DB>>,
74    _error: PhantomData<E>,
75}
76
77impl<DB: sqlx::Database, S: Clone, E> Clone for InnerTransactionMiddleware<S, DB, E> {
78    fn clone(&self) -> Self {
79        Self {
80            pool: self.pool.clone(),
81            service: self.service.clone(),
82            _error: self._error,
83        }
84    }
85}
86
87impl<S, B, DB, E> Service<ServiceRequest> for InnerTransactionMiddleware<S, DB, E>
88where
89    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
90    S::Future: 'static,
91    DB: sqlx::Database + 'static,
92    E: From<Error> + ResponseError + 'static,
93{
94    type Response = ServiceResponse<B>;
95
96    type Error = actix_web::Error;
97
98    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
99
100    forward_ready!(service);
101
102    fn call(&self, req: ServiceRequest) -> Self::Future {
103        let transaction = TxSlot::bind(&mut req.extensions_mut(), &self.pool);
104        let srv = Rc::clone(&self.service);
105        let res = srv.call(req);
106
107        Box::pin(async move {
108            let res = res.await?;
109
110            if res.status().is_success() {
111                if let Err(error) = transaction.commit().await {
112                    return Err(E::from(Error::Database { error }).into());
113                }
114            }
115            Ok(res)
116        })
117    }
118}