axum_postgres_tx/
layer.rs1use std::{
2 marker::PhantomData,
3 task::{Context, Poll},
4};
5
6use axum_core::{
7 extract::Request,
8 response::{IntoResponse, Response},
9};
10use bb8_postgres::tokio_postgres;
11use futures_core::future::BoxFuture;
12
13use super::{Pool, extension::Extension};
14
15pub struct Layer<E> {
16 pool: Pool,
17 _error: PhantomData<E>,
18}
19
20impl<E> Clone for Layer<E> {
21 fn clone(&self) -> Self {
22 Self {
23 pool: self.pool.clone(),
24 _error: self._error,
25 }
26 }
27}
28
29impl From<Pool> for Layer<super::Error> {
30 fn from(value: Pool) -> Self {
31 Self {
32 pool: value,
33 _error: PhantomData,
34 }
35 }
36}
37
38impl<S, E> tower_layer::Layer<S> for Layer<E> {
39 type Service = Service<S, E>;
40 fn layer(&self, inner: S) -> Self::Service {
41 Service {
42 pool: self.pool.clone(),
43 inner,
44 _error: self._error,
45 }
46 }
47}
48
49pub struct Service<S, E> {
50 pool: Pool,
51 inner: S,
52 _error: PhantomData<E>,
53}
54
55impl<S: Clone, E> Clone for Service<S, E> {
56 fn clone(&self) -> Self {
57 Self {
58 pool: self.pool.clone(),
59 inner: self.inner.clone(),
60 _error: self._error,
61 }
62 }
63}
64
65impl<S, E> tower_service::Service<Request> for Service<S, E>
66where
67 S: tower_service::Service<Request, Response = Response> + Send + 'static,
68 S::Future: Send + 'static,
69 E: From<tokio_postgres::Error> + IntoResponse,
70{
71 type Response = S::Response;
72 type Error = S::Error;
73 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
74
75 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
76 self.inner.poll_ready(cx)
77 }
78
79 fn call(&mut self, mut req: Request) -> Self::Future {
80 let ext = Extension::from(self.pool.clone());
81 req.extensions_mut().insert(ext.clone());
82
83 let res = self.inner.call(req);
84
85 Box::pin(async move {
86 let res = res.await?;
87
88 if !res.status().is_server_error()
89 && !res.status().is_client_error()
90 && let Err(err) = ext.commit().await
91 {
92 return Ok(E::from(err).into_response());
93 }
94
95 Ok(res)
96 })
97 }
98}