sql_middleware/postgres/
transaction.rs1use std::ops::DerefMut;
2
3use tokio_postgres::{Client, Statement, Transaction as PgTransaction};
4
5use crate::middleware::{
6 ConversionMode, ParamConverter, ResultSet, RowValues, SqlMiddlewareDbError,
7};
8use crate::tx_outcome::TxOutcome;
9
10use super::{Params, build_result_set};
11
12pub struct Tx<'a> {
14 tx: PgTransaction<'a>,
15}
16
17pub struct Prepared {
19 stmt: Statement,
20}
21
22pub async fn begin_transaction<C>(conn: &mut C) -> Result<Tx<'_>, SqlMiddlewareDbError>
27where
28 C: DerefMut<Target = Client>,
29{
30 let tx = conn.deref_mut().transaction().await?;
31 Ok(Tx { tx })
32}
33
34impl Tx<'_> {
35 pub async fn prepare(&self, sql: &str) -> Result<Prepared, SqlMiddlewareDbError> {
40 let stmt = self.tx.prepare(sql).await?;
41 Ok(Prepared { stmt })
42 }
43
44 pub async fn execute_prepared(
49 &self,
50 prepared: &Prepared,
51 params: &[RowValues],
52 ) -> Result<usize, SqlMiddlewareDbError> {
53 let converted =
54 <Params as ParamConverter>::convert_sql_params(params, ConversionMode::Execute)?;
55
56 let rows = self.tx.execute(&prepared.stmt, converted.as_refs()).await?;
57
58 usize::try_from(rows).map_err(|e| {
59 SqlMiddlewareDbError::ExecutionError(format!("Invalid rows affected count: {e}"))
60 })
61 }
62
63 pub async fn query_prepared(
68 &self,
69 prepared: &Prepared,
70 params: &[RowValues],
71 ) -> Result<ResultSet, SqlMiddlewareDbError> {
72 let converted =
73 <Params as ParamConverter>::convert_sql_params(params, ConversionMode::Query)?;
74 build_result_set(&prepared.stmt, converted.as_refs(), &self.tx).await
75 }
76
77 pub async fn execute_batch(&self, sql: &str) -> Result<(), SqlMiddlewareDbError> {
82 self.tx.batch_execute(sql).await?;
83 Ok(())
84 }
85
86 pub async fn commit(self) -> Result<TxOutcome, SqlMiddlewareDbError> {
91 self.tx.commit().await?;
92 Ok(TxOutcome::without_restored_connection())
93 }
94
95 pub async fn rollback(self) -> Result<TxOutcome, SqlMiddlewareDbError> {
100 self.tx.rollback().await?;
101 Ok(TxOutcome::without_restored_connection())
102 }
103}