use diesel::{
QueryResult,
query_builder::{AsQuery, QueryFragment, QueryId},
};
use futures_util::{StreamExt, TryStreamExt, stream::BoxStream};
use tokio_postgres::{
Error, Row, RowStream, Statement, ToStatement, Transaction,
types::{BorrowToSql, ToSql, Type},
};
use super::{
cache::PgCache, error_helper::ErrorHelper, prepared_client::PreparedClient, row::PgRow,
};
use crate::{AsyncExecute, AsyncTransaction, AsyncTransactional};
/// Wrapper around an open `tokio_postgres` transaction used by this crate's
/// async execution traits.
///
/// The wrapper keeps the borrowed transaction together with the connection's
/// `PgCache`. Query execution (`load`, `execute_returning_count`) delegates to
/// that cache.
pub struct AsyncPgTransaction<'a> {
    /// The wrapped transaction. For nested transactions this is the handle
    /// returned by `Transaction::transaction`.
    transaction: tokio_postgres::Transaction<'a>,
    /// Cache borrowed mutably for as long as the transaction is alive.
    cache: &'a mut PgCache,
}

impl<'a> AsyncPgTransaction<'a> {
    /// Builds a wrapper from an already-started transaction and the cache
    /// that statements issued through it should use.
    pub(super) fn new(
        transaction: tokio_postgres::Transaction<'a>,
        cache: &'a mut PgCache,
    ) -> Self {
        Self { cache, transaction }
    }
}
impl AsyncTransaction for AsyncPgTransaction<'_> {
async fn commit(self) -> QueryResult<()> {
Ok(self.transaction.commit().await.map_err(ErrorHelper)?)
}
async fn rollback(self) -> QueryResult<()> {
Ok(self.transaction.rollback().await.map_err(ErrorHelper)?)
}
}
impl AsyncTransactional for AsyncPgTransaction<'_> {
type Transaction<'a>
= AsyncPgTransaction<'a>
where
Self: 'a;
async fn begin_transaction(&mut self) -> QueryResult<Self::Transaction<'_>> {
let transaction = self.transaction.transaction().await.map_err(ErrorHelper)?;
let transaction = AsyncPgTransaction::new(transaction, self.cache);
Ok(transaction)
}
}
impl AsyncExecute for AsyncPgTransaction<'_> {
type Stream<'conn>
= BoxStream<'conn, QueryResult<PgRow>>
where
Self: 'conn;
type Row<'conn>
= PgRow
where
Self: 'conn;
type Backend = diesel::pg::Pg;
async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
Ok(self
.transaction
.batch_execute(query)
.await
.map_err(ErrorHelper)?)
}
async fn load<T>(&mut self, source: T) -> QueryResult<Self::Stream<'_>>
where
T: AsQuery,
T::Query: QueryFragment<Self::Backend> + QueryId,
{
let res = self
.cache
.load_cached(&mut self.transaction, source)
.await?;
let res = res
.map_err(|e| diesel::result::Error::from(ErrorHelper(e)))
.map_ok(PgRow::new);
Ok(res.boxed())
}
async fn execute_returning_count<T>(&mut self, source: T) -> QueryResult<usize>
where
T: QueryFragment<Self::Backend> + QueryId + Send,
{
self.cache
.execute_returning_count_cached(&mut self.transaction, source)
.await
}
}
impl PreparedClient for Transaction<'_> {
async fn query_one<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Row, Error>
where
T: ?Sized + Send + Sync + ToStatement,
{
(self as &Transaction<'_>)
.query_one(statement, params)
.await
}
async fn prepare_typed(
&self,
query: &str,
parameter_types: &[Type],
) -> Result<Statement, Error> {
(self as &Transaction)
.prepare_typed(query, parameter_types)
.await
}
async fn query_raw<T, P, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
where
T: ?Sized + Send + Sync + ToStatement,
P: BorrowToSql,
I: IntoIterator<Item = P> + Send + Sync,
I::IntoIter: ExactSizeIterator,
{
(self as &Transaction<'_>)
.query_raw(statement, params)
.await
}
async fn execute<T>(&self, statement: &T, params: &[&(dyn ToSql + Sync)]) -> Result<u64, Error>
where
T: ?Sized + ToStatement,
{
(self as &Transaction<'_>).execute(statement, params).await
}
}