actix_sqlx_tx/
middleware.rs1use std::{marker::PhantomData, rc::Rc};
2
3use actix_utils::future::{ready, Ready};
4use actix_web::{
5 body::MessageBody,
6 dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
7 HttpMessage, ResponseError,
8};
9use futures_core::future::LocalBoxFuture;
10
11use crate::{tx::TxSlot, Error};
12
13#[derive(Clone)]
24pub struct TransactionMiddleware<DB: sqlx::Database, E = Error> {
25 pool: Rc<sqlx::Pool<DB>>,
26 _error: PhantomData<E>,
27}
28
29impl<DB: sqlx::Database> TransactionMiddleware<DB> {
30 pub fn new(pool: sqlx::Pool<DB>) -> Self {
31 Self::new_with_error(pool)
32 }
33
34 pub fn new_with_error<E>(pool: sqlx::Pool<DB>) -> TransactionMiddleware<DB, E> {
36 TransactionMiddleware {
37 pool: Rc::new(pool),
38 _error: PhantomData,
39 }
40 }
41}
42
43impl<S, B, DB, E> Transform<S, ServiceRequest> for TransactionMiddleware<DB, E>
44where
45 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
46 S::Future: 'static,
47 B: MessageBody + 'static,
48 DB: sqlx::Database + 'static,
49 E: From<Error> + ResponseError + 'static,
50{
51 type Response = ServiceResponse<B>;
52
53 type Error = actix_web::Error;
54
55 type Transform = InnerTransactionMiddleware<S, DB, E>;
56
57 type InitError = ();
58
59 type Future = Ready<Result<Self::Transform, Self::InitError>>;
60
61 fn new_transform(&self, service: S) -> Self::Future {
62 ready(Ok(InnerTransactionMiddleware {
63 service: Rc::new(service),
64 pool: Rc::clone(&self.pool),
65 _error: self._error,
66 }))
67 }
68}
69
70#[doc(hidden)]
71pub struct InnerTransactionMiddleware<S, DB: sqlx::Database, E = Error> {
72 service: Rc<S>,
73 pool: Rc<sqlx::Pool<DB>>,
74 _error: PhantomData<E>,
75}
76
77impl<DB: sqlx::Database, S: Clone, E> Clone for InnerTransactionMiddleware<S, DB, E> {
78 fn clone(&self) -> Self {
79 Self {
80 pool: self.pool.clone(),
81 service: self.service.clone(),
82 _error: self._error,
83 }
84 }
85}
86
87impl<S, B, DB, E> Service<ServiceRequest> for InnerTransactionMiddleware<S, DB, E>
88where
89 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
90 S::Future: 'static,
91 DB: sqlx::Database + 'static,
92 E: From<Error> + ResponseError + 'static,
93{
94 type Response = ServiceResponse<B>;
95
96 type Error = actix_web::Error;
97
98 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
99
100 forward_ready!(service);
101
102 fn call(&self, req: ServiceRequest) -> Self::Future {
103 let transaction = TxSlot::bind(&mut req.extensions_mut(), &self.pool);
104 let srv = Rc::clone(&self.service);
105 let res = srv.call(req);
106
107 Box::pin(async move {
108 let res = res.await?;
109
110 if res.status().is_success() {
111 if let Err(error) = transaction.commit().await {
112 return Err(E::from(Error::Database { error }).into());
113 }
114 }
115 Ok(res)
116 })
117 }
118}