1use std::{marker::PhantomData, ops::Deref};
2
3use axum_core::{RequestPartsExt, extract::FromRequestParts, response::IntoResponse};
4use bb8_postgres::tokio_postgres::Transaction;
5use http::{Method, request::Parts};
6use parking_lot::{ArcMutexGuard, RawMutex};
7
8use super::extension::{Extension, LazyTx};
9
10pub struct Tx<E = super::Error> {
11 tx: ArcMutexGuard<RawMutex, LazyTx>,
12 _error: PhantomData<E>,
13}
14
15impl<E> AsRef<Transaction<'static>> for Tx<E> {
16 fn as_ref(&self) -> &Transaction<'static> {
17 self.tx.as_ref()
18 }
19}
20
21impl<E> Deref for Tx<E> {
22 type Target = Transaction<'static>;
23 fn deref(&self) -> &Self::Target {
24 self.tx.deref()
25 }
26}
27
28impl<S, E> FromRequestParts<S> for Tx<E>
29where
30 S: Sync,
31 E: From<super::Error> + IntoResponse,
32{
33 type Rejection = E;
34 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
35 let Ok(method) = parts.extract::<Method>().await;
36 let ext: &Extension = parts
37 .extensions
38 .get()
39 .ok_or(super::Error::MissingExtension)?;
40 let tx = ext
41 .get(method == Method::GET || method == Method::HEAD || method == Method::OPTIONS)
42 .await?;
43 Ok(Self {
44 tx,
45 _error: PhantomData,
46 })
47 }
48}