use std::{fmt, marker::PhantomData};
use axum_core::{
extract::{FromRef, FromRequestParts},
response::IntoResponse,
};
use futures_core::{future::BoxFuture, stream::BoxStream};
use http::request::Parts;
use parking_lot::{lock_api::ArcMutexGuard, RawMutex};
use crate::{
extension::{Extension, LazyTransaction},
Config, Error, Marker, State,
};
pub struct Tx<DB: Marker, E = Error> {
tx: ArcMutexGuard<RawMutex, LazyTransaction<DB>>,
_error: PhantomData<E>,
}
impl<DB, E> Tx<DB, E>
where
    DB: Marker,
{
    /// Creates the [`State`] and [`Layer`](crate::Layer) needed to enable the extractor,
    /// using the default [`Config`] for `pool`.
    ///
    /// Equivalent to `Tx::config(pool).setup()`. The returned layer is always parameterised
    /// by the crate [`Error`] type, independent of this `Tx`'s `E` parameter.
    pub fn setup(pool: sqlx::Pool<DB::Driver>) -> (State<DB>, crate::Layer<DB, Error>) {
        Self::config(pool).setup()
    }

    /// Returns a [`Config`] for `pool` so extractor behaviour can be customised before
    /// calling its `setup`.
    pub fn config(pool: sqlx::Pool<DB::Driver>) -> Config<DB, Error> {
        Config::new(pool)
    }

    /// Consumes the extractor and commits the underlying transaction.
    ///
    /// # Errors
    ///
    /// Returns any [`sqlx::Error`] reported by the commit.
    pub async fn commit(mut self) -> Result<(), sqlx::Error> {
        // Committing needs mutable access through the guard; `self` (and with it the lock
        // guard) is dropped once the commit finishes or the future is dropped.
        self.tx.commit().await
    }
}
impl<DB, E> fmt::Debug for Tx<DB, E>
where
    DB: Marker,
{
    /// Formats as the opaque `Tx { .. }`, without exposing the transaction or requiring
    /// `DB` or `E` to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Tx");
        builder.finish_non_exhaustive()
    }
}
impl<DB, E> AsRef<sqlx::Transaction<'static, DB::Driver>> for Tx<DB, E>
where
    DB: Marker,
{
    /// Borrows the transaction held by the locked `LazyTransaction`.
    fn as_ref(&self) -> &sqlx::Transaction<'static, DB::Driver> {
        // Deref the mutex guard explicitly to reach the `LazyTransaction` it protects.
        (*self.tx).as_ref()
    }
}
impl<DB, E> AsMut<sqlx::Transaction<'static, DB::Driver>> for Tx<DB, E>
where
    DB: Marker,
{
    /// Mutably borrows the transaction held by the locked `LazyTransaction`.
    fn as_mut(&mut self) -> &mut sqlx::Transaction<'static, DB::Driver> {
        // Deref the mutex guard explicitly (mutably) to reach the protected `LazyTransaction`.
        (*self.tx).as_mut()
    }
}
impl<DB, E> std::ops::Deref for Tx<DB, E>
where
    DB: Marker,
{
    type Target = sqlx::Transaction<'static, DB::Driver>;

    /// Exposes the transaction, so `Tx` can be used wherever a `&Transaction` is expected.
    fn deref(&self) -> &Self::Target {
        // Same path as the `AsRef` impl: guard -> `LazyTransaction` -> transaction.
        AsRef::as_ref(self)
    }
}
impl<DB, E> std::ops::DerefMut for Tx<DB, E>
where
    DB: Marker,
{
    /// Exposes the transaction mutably, e.g. for `&mut **tx` reborrows.
    fn deref_mut(&mut self) -> &mut Self::Target {
        // Same path as the `AsMut` impl: guard -> `LazyTransaction` -> transaction.
        AsMut::as_mut(self)
    }
}
impl<DB: Marker, S, E> FromRequestParts<S> for Tx<DB, E>
where
S: Sync,
E: From<Error> + IntoResponse + Send,
State<DB>: FromRef<S>,
{
type Rejection = E;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let ext: &Extension<DB> = parts.extensions.get().ok_or(Error::MissingExtension)?;
let tx = ext.acquire().await?;
Ok(Self {
tx,
_error: PhantomData,
})
}
}
/// Lets `&mut Tx` be passed directly to sqlx query methods (e.g. `.execute(&mut tx)`).
///
/// Every required method forwards to the transaction's underlying connection. The reborrow
/// `&mut ***self` walks `&mut Tx` -> `Tx` -> `sqlx::Transaction` (via `Tx`'s `DerefMut`) ->
/// the driver's `Connection` (via `Transaction`'s `DerefMut`). The remaining `Executor`
/// methods use the trait's provided default implementations.
impl<'c, DB, E> sqlx::Executor<'c> for &'c mut Tx<DB, E>
where
    DB: Marker,
    for<'t> &'t mut <DB::Driver as sqlx::Database>::Connection:
        sqlx::Executor<'t, Database = DB::Driver>,
    // `Executor` requires `Self: Send + Debug`. `Send` needs `E: Send` because `Tx` stores a
    // `PhantomData<E>`. `Debug` is already provided for every `E` by the unconditional `Debug`
    // impl on `Tx`, so an `E: Debug` bound would only over-constrain callers.
    E: Send,
{
    type Database = DB::Driver;

    // The signature is dictated by `sqlx::Executor::fetch_many` and cannot be simplified here.
    #[allow(clippy::type_complexity)]
    fn fetch_many<'e, 'q: 'e, Q>(
        self,
        query: Q,
    ) -> BoxStream<
        'e,
        Result<
            sqlx::Either<
                <Self::Database as sqlx::Database>::QueryResult,
                <Self::Database as sqlx::Database>::Row,
            >,
            sqlx::Error,
        >,
    >
    where
        'c: 'e,
        Q: sqlx::Execute<'q, Self::Database> + 'q,
    {
        (&mut ***self).fetch_many(query)
    }

    fn fetch_optional<'e, 'q: 'e, Q>(
        self,
        query: Q,
    ) -> BoxFuture<'e, Result<Option<<Self::Database as sqlx::Database>::Row>, sqlx::Error>>
    where
        'c: 'e,
        Q: sqlx::Execute<'q, Self::Database> + 'q,
    {
        (&mut ***self).fetch_optional(query)
    }

    fn prepare_with<'e, 'q: 'e>(
        self,
        sql: &'q str,
        parameters: &'e [<Self::Database as sqlx::Database>::TypeInfo],
    ) -> BoxFuture<'e, Result<<Self::Database as sqlx::Database>::Statement<'q>, sqlx::Error>>
    where
        'c: 'e,
    {
        (&mut ***self).prepare_with(sql, parameters)
    }

    fn describe<'e, 'q: 'e>(
        self,
        sql: &'q str,
    ) -> BoxFuture<'e, Result<sqlx::Describe<Self::Database>, sqlx::Error>>
    where
        'c: 'e,
    {
        (&mut ***self).describe(sql)
    }
}