use std::ops::DerefMut;
use tokio_postgres::{Client, Statement, Transaction as PgTransaction};
use crate::adapters::params::convert_params;
use crate::middleware::{ConversionMode, ResultSet, RowValues, SqlMiddlewareDbError};
use crate::tx_outcome::TxOutcome;
use super::{Params, build_result_set};
use crate::postgres::query::{build_result_set_from_rows, convert_affected_rows};
/// An open Postgres transaction that mutably borrows a client.
///
/// Obtained from [`begin_transaction`]. Finish it explicitly with `commit` or
/// `rollback`. Per the tokio-postgres documentation, a driver transaction that
/// is dropped without being committed is rolled back.
pub struct Tx<'conn> {
    /// The driver-level transaction that every statement on this handle runs through.
    tx: PgTransaction<'conn>,
}
/// A statement prepared inside a [`Tx`], to be run with `Tx::execute_prepared`
/// or `Tx::query_prepared`.
///
/// Cloning is cheap. `tokio_postgres::Statement` is reference-counted, so every
/// clone refers to the same server-side prepared statement. Per tokio-postgres,
/// a prepared statement is only usable on the connection that created it.
#[derive(Clone)]
pub struct Prepared {
    /// Driver handle for the prepared statement.
    stmt: Statement,
}
pub async fn begin_transaction<C>(conn: &mut C) -> Result<Tx<'_>, SqlMiddlewareDbError>
where
C: DerefMut<Target = Client>,
{
let tx = conn.deref_mut().transaction().await?;
Ok(Tx { tx })
}
impl Tx<'_> {
    /// Context text handed to `convert_affected_rows` by both execute paths.
    const AFFECTED_ROWS_CONTEXT: &'static str = "Invalid rows affected count";

    /// Prepares `sql` on this transaction's connection.
    ///
    /// # Errors
    ///
    /// Returns [`SqlMiddlewareDbError`] if the driver fails to prepare `sql`.
    pub async fn prepare(&self, sql: &str) -> Result<Prepared, SqlMiddlewareDbError> {
        self.tx
            .prepare(sql)
            .await
            .map(|stmt| Prepared { stmt })
            .map_err(SqlMiddlewareDbError::from)
    }

    /// Runs a previously prepared statement and returns how many rows it affected.
    ///
    /// `params` are converted with [`ConversionMode::Execute`] before binding.
    ///
    /// # Errors
    ///
    /// Fails if parameter conversion fails, if the driver reports an error, or
    /// if `convert_affected_rows` rejects the driver's row count.
    pub async fn execute_prepared(
        &self,
        prepared: &Prepared,
        params: &[RowValues],
    ) -> Result<usize, SqlMiddlewareDbError> {
        let bound = convert_params::<Params>(params, ConversionMode::Execute)?;
        let affected = self.tx.execute(&prepared.stmt, bound.as_refs()).await?;
        convert_affected_rows(affected, Self::AFFECTED_ROWS_CONTEXT)
    }

    /// Runs the SQL text `query` (no prior prepare) and returns how many rows it affected.
    ///
    /// `params` are converted with [`ConversionMode::Execute`] before binding.
    ///
    /// # Errors
    ///
    /// Fails if parameter conversion fails, if the driver reports an error, or
    /// if `convert_affected_rows` rejects the driver's row count.
    pub async fn execute_dml(
        &self,
        query: &str,
        params: &[RowValues],
    ) -> Result<usize, SqlMiddlewareDbError> {
        let bound = convert_params::<Params>(params, ConversionMode::Execute)?;
        let affected = self.tx.execute(query, bound.as_refs()).await?;
        convert_affected_rows(affected, Self::AFFECTED_ROWS_CONTEXT)
    }

    /// Runs a previously prepared statement and collects its rows into a [`ResultSet`].
    ///
    /// `params` are converted with [`ConversionMode::Query`]. Row collection is
    /// delegated to `build_result_set`.
    ///
    /// # Errors
    ///
    /// Fails if parameter conversion fails or if `build_result_set` returns an error.
    pub async fn query_prepared(
        &self,
        prepared: &Prepared,
        params: &[RowValues],
    ) -> Result<ResultSet, SqlMiddlewareDbError> {
        let bound = convert_params::<Params>(params, ConversionMode::Query)?;
        build_result_set(&prepared.stmt, bound.as_refs(), &self.tx).await
    }

    /// Runs the SQL text `query` and converts the returned rows into a [`ResultSet`].
    ///
    /// `params` are converted with [`ConversionMode::Query`]. Rows are fetched
    /// in full before `build_result_set_from_rows` converts them.
    ///
    /// # Errors
    ///
    /// Fails if parameter conversion fails, if the driver reports an error, or
    /// if `build_result_set_from_rows` fails.
    pub async fn query(
        &self,
        query: &str,
        params: &[RowValues],
    ) -> Result<ResultSet, SqlMiddlewareDbError> {
        let bound = convert_params::<Params>(params, ConversionMode::Query)?;
        let fetched = self.tx.query(query, bound.as_refs()).await?;
        build_result_set_from_rows(&fetched)
    }

    /// Sends `sql` through the driver's `batch_execute`.
    ///
    /// Per the tokio-postgres docs, this may contain several `;`-separated
    /// statements and takes no bind parameters.
    ///
    /// # Errors
    ///
    /// Returns [`SqlMiddlewareDbError`] if the driver reports an error.
    pub async fn execute_batch(&self, sql: &str) -> Result<(), SqlMiddlewareDbError> {
        self.tx
            .batch_execute(sql)
            .await
            .map_err(SqlMiddlewareDbError::from)
    }

    /// Commits the transaction, consuming this handle.
    ///
    /// On success, returns `TxOutcome::without_restored_connection()`.
    ///
    /// # Errors
    ///
    /// Returns [`SqlMiddlewareDbError`] if the driver fails to commit.
    pub async fn commit(self) -> Result<TxOutcome, SqlMiddlewareDbError> {
        self.tx
            .commit()
            .await
            .map(|()| TxOutcome::without_restored_connection())
            .map_err(SqlMiddlewareDbError::from)
    }

    /// Rolls the transaction back, consuming this handle.
    ///
    /// On success, returns `TxOutcome::without_restored_connection()`.
    ///
    /// # Errors
    ///
    /// Returns [`SqlMiddlewareDbError`] if the driver fails to roll back.
    pub async fn rollback(self) -> Result<TxOutcome, SqlMiddlewareDbError> {
        self.tx
            .rollback()
            .await
            .map(|()| TxOutcome::without_restored_connection())
            .map_err(SqlMiddlewareDbError::from)
    }
}