// axum_postgres_tx/layer.rs

1use std::{
2    marker::PhantomData,
3    task::{Context, Poll},
4};
5
6use axum_core::{
7    extract::Request,
8    response::{IntoResponse, Response},
9};
10use bb8_postgres::tokio_postgres;
11use futures_core::future::BoxFuture;
12
13use super::{Pool, extension::Extension};
14
15pub struct Layer<E> {
16    pool: Pool,
17    _error: PhantomData<E>,
18}
19
20impl<E> Clone for Layer<E> {
21    fn clone(&self) -> Self {
22        Self {
23            pool: self.pool.clone(),
24            _error: self._error,
25        }
26    }
27}
28
29impl From<Pool> for Layer<super::Error> {
30    fn from(value: Pool) -> Self {
31        Self {
32            pool: value,
33            _error: PhantomData,
34        }
35    }
36}
37
38impl<S, E> tower_layer::Layer<S> for Layer<E> {
39    type Service = Service<S, E>;
40    fn layer(&self, inner: S) -> Self::Service {
41        Service {
42            pool: self.pool.clone(),
43            inner,
44            _error: self._error,
45        }
46    }
47}
48
49pub struct Service<S, E> {
50    pool: Pool,
51    inner: S,
52    _error: PhantomData<E>,
53}
54
55impl<S: Clone, E> Clone for Service<S, E> {
56    fn clone(&self) -> Self {
57        Self {
58            pool: self.pool.clone(),
59            inner: self.inner.clone(),
60            _error: self._error,
61        }
62    }
63}
64
65impl<S, E> tower_service::Service<Request> for Service<S, E>
66where
67    S: tower_service::Service<Request, Response = Response> + Send + 'static,
68    S::Future: Send + 'static,
69    E: From<tokio_postgres::Error> + IntoResponse,
70{
71    type Response = S::Response;
72    type Error = S::Error;
73    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
74
75    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
76        self.inner.poll_ready(cx)
77    }
78
79    fn call(&mut self, mut req: Request) -> Self::Future {
80        let ext = Extension::from(self.pool.clone());
81        req.extensions_mut().insert(ext.clone());
82
83        let res = self.inner.call(req);
84
85        Box::pin(async move {
86            let res = res.await?;
87
88            if !res.status().is_server_error()
89                && !res.status().is_client_error()
90                && let Err(err) = ext.commit().await
91            {
92                return Ok(E::from(err).into_response());
93            }
94
95            Ok(res)
96        })
97    }
98}