1use std::ops::DerefMut;
2
3use tokio_postgres::{Client, Statement, Transaction as PgTransaction};
4
5use crate::adapters::params::convert_params;
6use crate::middleware::{ConversionMode, CustomDbRow, ResultSet, RowValues, SqlMiddlewareDbError};
7use crate::tx_outcome::TxOutcome;
8
9use super::{Params, build_result_set};
10use crate::postgres::query::{build_result_set_from_rows, convert_affected_rows};
11
12pub struct Tx<'a> {
14 tx: PgTransaction<'a>,
15}
16
17pub struct Prepared {
19 stmt: Statement,
20}
21
22pub async fn begin_transaction<C>(conn: &mut C) -> Result<Tx<'_>, SqlMiddlewareDbError>
27where
28 C: DerefMut<Target = Client>,
29{
30 let tx = conn.deref_mut().transaction().await?;
31 Ok(Tx { tx })
32}
33
34impl<'conn> Tx<'conn> {
35 pub async fn prepare(&self, sql: &str) -> Result<Prepared, SqlMiddlewareDbError> {
40 let stmt = self.tx.prepare(sql).await?;
41 Ok(Prepared { stmt })
42 }
43
44 #[must_use]
46 pub fn select<'tx, 'prepared>(
47 &'tx self,
48 prepared: &'prepared Prepared,
49 ) -> PreparedSelect<'tx, 'prepared, 'static, 'conn> {
50 PreparedSelect {
51 tx: self,
52 prepared,
53 params: &[],
54 }
55 }
56
57 #[must_use]
59 pub fn execute<'tx, 'prepared>(
60 &'tx self,
61 prepared: &'prepared Prepared,
62 ) -> PreparedExecute<'tx, 'prepared, 'static, 'conn> {
63 PreparedExecute {
64 tx: self,
65 prepared,
66 params: &[],
67 }
68 }
69
70 pub(crate) async fn execute_prepared(
75 &self,
76 prepared: &Prepared,
77 params: &[RowValues],
78 ) -> Result<usize, SqlMiddlewareDbError> {
79 let converted = convert_params::<Params>(params, ConversionMode::Execute)?;
80
81 let rows = self.tx.execute(&prepared.stmt, converted.as_refs()).await?;
82
83 convert_affected_rows(rows, "Invalid rows affected count")
84 }
85
86 pub async fn execute_dml(
91 &self,
92 query: &str,
93 params: &[RowValues],
94 ) -> Result<usize, SqlMiddlewareDbError> {
95 let converted = convert_params::<Params>(params, ConversionMode::Execute)?;
96 let rows = self.tx.execute(query, converted.as_refs()).await?;
97 convert_affected_rows(rows, "Invalid rows affected count")
98 }
99
100 pub(crate) async fn query_prepared(
105 &self,
106 prepared: &Prepared,
107 params: &[RowValues],
108 ) -> Result<ResultSet, SqlMiddlewareDbError> {
109 let converted = convert_params::<Params>(params, ConversionMode::Query)?;
110 build_result_set(&prepared.stmt, converted.as_refs(), &self.tx).await
111 }
112
113 pub(crate) async fn query_prepared_optional(
118 &self,
119 prepared: &Prepared,
120 params: &[RowValues],
121 ) -> Result<Option<CustomDbRow>, SqlMiddlewareDbError> {
122 self.query_prepared(prepared, params)
123 .await
124 .map(ResultSet::into_optional)
125 }
126
127 pub(crate) async fn query_prepared_one(
132 &self,
133 prepared: &Prepared,
134 params: &[RowValues],
135 ) -> Result<CustomDbRow, SqlMiddlewareDbError> {
136 self.query_prepared(prepared, params).await?.into_one()
137 }
138
139 pub(crate) async fn query_prepared_map_one<T, F>(
147 &self,
148 prepared: &Prepared,
149 params: &[RowValues],
150 mapper: F,
151 ) -> Result<T, SqlMiddlewareDbError>
152 where
153 F: FnOnce(&tokio_postgres::Row) -> Result<T, SqlMiddlewareDbError>,
154 {
155 self.query_prepared_map_optional(prepared, params, mapper)
156 .await?
157 .ok_or_else(|| SqlMiddlewareDbError::ExecutionError("query returned no rows".into()))
158 }
159
160 pub(crate) async fn query_prepared_map_optional<T, F>(
166 &self,
167 prepared: &Prepared,
168 params: &[RowValues],
169 mapper: F,
170 ) -> Result<Option<T>, SqlMiddlewareDbError>
171 where
172 F: FnOnce(&tokio_postgres::Row) -> Result<T, SqlMiddlewareDbError>,
173 {
174 let converted = convert_params::<Params>(params, ConversionMode::Query)?;
175 let row = self
176 .tx
177 .query_opt(&prepared.stmt, converted.as_refs())
178 .await?;
179 row.as_ref().map(mapper).transpose()
180 }
181
182 pub async fn query(
187 &self,
188 query: &str,
189 params: &[RowValues],
190 ) -> Result<ResultSet, SqlMiddlewareDbError> {
191 let converted = convert_params::<Params>(params, ConversionMode::Query)?;
192 let rows = self.tx.query(query, converted.as_refs()).await?;
193 build_result_set_from_rows(&rows)
194 }
195
196 pub async fn execute_batch(&self, sql: &str) -> Result<(), SqlMiddlewareDbError> {
201 self.tx.batch_execute(sql).await?;
202 Ok(())
203 }
204
205 pub async fn commit(self) -> Result<TxOutcome, SqlMiddlewareDbError> {
210 self.tx.commit().await?;
211 Ok(TxOutcome::without_restored_connection())
212 }
213
214 pub async fn rollback(self) -> Result<TxOutcome, SqlMiddlewareDbError> {
219 self.tx.rollback().await?;
220 Ok(TxOutcome::without_restored_connection())
221 }
222}
223
224pub struct PreparedExecute<'tx, 'prepared, 'params, 'conn> {
226 tx: &'tx Tx<'conn>,
227 prepared: &'prepared Prepared,
228 params: &'params [RowValues],
229}
230
231impl<'tx, 'prepared, 'params, 'conn> PreparedExecute<'tx, 'prepared, 'params, 'conn> {
232 #[must_use]
234 pub fn params<'next>(
235 self,
236 params: &'next [RowValues],
237 ) -> PreparedExecute<'tx, 'prepared, 'next, 'conn> {
238 PreparedExecute {
239 tx: self.tx,
240 prepared: self.prepared,
241 params,
242 }
243 }
244
245 pub async fn run(self) -> Result<usize, SqlMiddlewareDbError> {
250 self.tx.execute_prepared(self.prepared, self.params).await
251 }
252}
253
254pub struct PreparedSelect<'tx, 'prepared, 'params, 'conn> {
256 tx: &'tx Tx<'conn>,
257 prepared: &'prepared Prepared,
258 params: &'params [RowValues],
259}
260
261impl<'tx, 'prepared, 'params, 'conn> PreparedSelect<'tx, 'prepared, 'params, 'conn> {
262 #[must_use]
264 pub fn params<'next>(
265 self,
266 params: &'next [RowValues],
267 ) -> PreparedSelect<'tx, 'prepared, 'next, 'conn> {
268 PreparedSelect {
269 tx: self.tx,
270 prepared: self.prepared,
271 params,
272 }
273 }
274
275 pub async fn all(self) -> Result<ResultSet, SqlMiddlewareDbError> {
280 self.tx.query_prepared(self.prepared, self.params).await
281 }
282
283 pub async fn optional(self) -> Result<Option<CustomDbRow>, SqlMiddlewareDbError> {
288 self.tx
289 .query_prepared_optional(self.prepared, self.params)
290 .await
291 }
292
293 pub async fn one(self) -> Result<CustomDbRow, SqlMiddlewareDbError> {
298 self.tx.query_prepared_one(self.prepared, self.params).await
299 }
300
301 pub async fn map_one<T, F>(self, mapper: F) -> Result<T, SqlMiddlewareDbError>
306 where
307 F: FnOnce(&tokio_postgres::Row) -> Result<T, SqlMiddlewareDbError>,
308 {
309 self.tx
310 .query_prepared_map_one(self.prepared, self.params, mapper)
311 .await
312 }
313
314 pub async fn map_optional<T, F>(self, mapper: F) -> Result<Option<T>, SqlMiddlewareDbError>
319 where
320 F: FnOnce(&tokio_postgres::Row) -> Result<T, SqlMiddlewareDbError>,
321 {
322 self.tx
323 .query_prepared_map_optional(self.prepared, self.params, mapper)
324 .await
325 }
326}