parsql_deadpool_postgres/
transaction_extensions.rs

1use crate::traits::{FromRow, SqlCommand, SqlParams, SqlQuery, TransactionOps, UpdateParams};
2use deadpool_postgres::{GenericClient, Transaction};
3use std::fmt::Debug;
4use std::sync::OnceLock;
5use tokio_postgres::Row;
6use tokio_postgres::{types::FromSql, Error};
7
8/// Transaction extension trait for additional query operations
9#[async_trait::async_trait]
10pub trait TransactionExtensions {
11    /// Inserts a new record into the database within a transaction
12    async fn insert<T, P: for<'a> FromSql<'a> + Send + Sync>(&self, entity: T) -> Result<P, Error>
13    where
14        T: SqlCommand + SqlParams + Send + Sync + 'static;
15
16    /// Updates an existing record in the database within a transaction
17    async fn update<T>(&self, entity: T) -> Result<bool, Error>
18    where
19        T: SqlCommand + UpdateParams + Send + Sync + 'static;
20
21    /// Deletes a record from the database within a transaction
22    async fn delete<T>(&self, entity: T) -> Result<u64, Error>
23    where
24        T: SqlCommand + SqlParams + Send + Sync + 'static;
25
26    /// Retrieves a single record from the database within a transaction
27    async fn fetch<P, R>(&self, params: P) -> Result<R, Error>
28    where
29        P: SqlQuery<R> + SqlParams + Send + Sync + 'static,
30        R: FromRow + Send + Sync + 'static;
31
32    /// Retrieves multiple records from the database within a transaction
33    async fn fetch_all<P, R>(&self, params: P) -> Result<Vec<R>, Error>
34    where
35        P: SqlQuery<R> + SqlParams + Send + Sync + 'static,
36        R: FromRow + Send + Sync + 'static;
37}
38
39#[async_trait::async_trait]
40impl<'a> TransactionOps for Transaction<'a> {
41    async fn tx_insert<T, P>(&self, entity: T) -> Result<P, Error>
42    where
43        T: SqlCommand + SqlParams + Debug + Send + 'static,
44        P: for<'b> tokio_postgres::types::FromSql<'b> + Send + Sync,
45    {
46        let sql = T::query();
47
48        if std::env::var("PARSQL_TRACE").unwrap_or_default() == "1" {
49            println!("[PARSQL-TOKIO-POSTGRES-POOL] Execute SQL: {}", sql);
50        }
51
52        let query_params = entity.params();
53        let row = self.query_one(&sql, &query_params).await?;
54        row.try_get::<_, P>(0)
55    }
56
57    async fn tx_update<T>(&self, entity: T) -> Result<bool, Error>
58    where
59        T: SqlCommand + UpdateParams + SqlParams + Debug + Send + 'static,
60    {
61        let sql = T::query();
62
63        if std::env::var("PARSQL_TRACE").unwrap_or_default() == "1" {
64            println!("[PARSQL-TOKIO-POSTGRES-POOL] Execute SQL: {}", sql);
65        }
66
67        let query_params = <T as UpdateParams>::params(&entity);
68        let result = self.execute(&sql, &query_params).await?;
69        Ok(result > 0)
70    }
71
72    async fn tx_delete<T>(&self, entity: T) -> Result<u64, Error>
73    where
74        T: SqlCommand + SqlParams + Debug + Send + 'static,
75    {
76        let sql = T::query();
77
78        if std::env::var("PARSQL_TRACE").unwrap_or_default() == "1" {
79            println!("[PARSQL-TOKIO-POSTGRES-POOL] Execute SQL: {}", sql);
80        }
81
82        let query_params = entity.params();
83        self.execute(&sql, &query_params).await
84    }
85
86    async fn tx_fetch<P, R>(&self, params: &P) -> Result<R, Error>
87    where
88        P: SqlQuery<R> + SqlParams + Debug + Send + Sync + Clone + 'static,
89        R: FromRow + Debug + Send + Sync + Clone + 'static,
90    {
91        let sql = P::query();
92
93        if std::env::var("PARSQL_TRACE").unwrap_or_default() == "1" {
94            println!("[PARSQL-TOKIO-POSTGRES-POOL] Execute SQL: {}", sql);
95        }
96
97        let query_params = params.params();
98        let row = self.query_one(&sql, &query_params).await?;
99        R::from_row(&row)
100    }
101
102    async fn tx_fetch_all<P, R>(&self, params: &P) -> Result<Vec<R>, Error>
103    where
104        P: SqlQuery<R> + SqlParams + Debug + Send + Sync + Clone + 'static,
105        R: FromRow + Debug + Send + Sync + Clone + 'static,
106    {
107        let sql = P::query();
108
109        if std::env::var("PARSQL_TRACE").unwrap_or_default() == "1" {
110            println!("[PARSQL-TOKIO-POSTGRES-POOL] Execute SQL: {}", sql);
111        }
112
113        let query_params = params.params();
114        let rows = self.query(&sql, &query_params).await?;
115
116        let mut results = Vec::with_capacity(rows.len());
117        for row in rows {
118            results.push(R::from_row(&row)?);
119        }
120
121        Ok(results)
122    }
123
124    async fn tx_select<T, F, R>(&self, entity: T, to_model: F) -> Result<R, Error>
125    where
126        T: SqlQuery<T> + SqlParams + Debug + Send + 'static,
127        F: Fn(&Row) -> Result<R, Error> + Send + Sync + 'static,
128        R: Send + 'static,
129    {
130        let sql = T::query();
131
132        if std::env::var("PARSQL_TRACE").unwrap_or_default() == "1" {
133            println!("[PARSQL-TOKIO-POSTGRES-POOL] Execute SQL: {}", sql);
134        }
135
136        let query_params = entity.params();
137        let row = self.query_one(&sql, &query_params).await?;
138        to_model(&row)
139    }
140
141    async fn tx_select_all<T, F, R>(&self, entity: T, to_model: F) -> Result<Vec<R>, Error>
142    where
143        T: SqlQuery<T> + SqlParams + Debug + Send + 'static,
144        F: Fn(&Row) -> R + Send + Sync + 'static,
145        R: Send + 'static,
146    {
147        let sql = T::query();
148
149        if std::env::var("PARSQL_TRACE").unwrap_or_default() == "1" {
150            println!("[PARSQL-TOKIO-POSTGRES-POOL] Execute SQL: {}", sql);
151        }
152
153        let query_params = entity.params();
154        let rows = self.query(&sql, &query_params).await?;
155
156        let mut results = Vec::with_capacity(rows.len());
157        for row in rows {
158            results.push(to_model(&row));
159        }
160
161        Ok(results)
162    }
163
164    // Deprecated methods for backward compatibility
165    async fn insert<T>(&self, entity: T) -> Result<u64, Error>
166    where
167        T: SqlCommand + SqlParams + Debug + Send + 'static,
168    {
169        let sql = T::query();
170
171        if std::env::var("PARSQL_TRACE").unwrap_or_default() == "1" {
172            println!("[PARSQL-TOKIO-POSTGRES-POOL] Execute SQL: {}", sql);
173        }
174
175        let query_params = entity.params();
176        self.execute(&sql, &query_params).await
177    }
178
179    async fn update<T>(&self, entity: T) -> Result<u64, Error>
180    where
181        T: SqlCommand + UpdateParams + SqlParams + Debug + Send + 'static,
182    {
183        let sql = T::query();
184
185        if std::env::var("PARSQL_TRACE").unwrap_or_default() == "1" {
186            println!("[PARSQL-TOKIO-POSTGRES-POOL] Execute SQL: {}", sql);
187        }
188
189        let query_params = <T as UpdateParams>::params(&entity);
190        self.execute(&sql, &query_params).await
191    }
192
193    async fn delete<T>(&self, entity: T) -> Result<u64, Error>
194    where
195        T: SqlCommand + SqlParams + Debug + Send + 'static,
196    {
197        let sql = T::query();
198
199        if std::env::var("PARSQL_TRACE").unwrap_or_default() == "1" {
200            println!("[PARSQL-TOKIO-POSTGRES-POOL] Execute SQL: {}", sql);
201        }
202
203        let query_params = entity.params();
204        self.execute(&sql, &query_params).await
205    }
206
207    async fn get<T>(&self, params: &T) -> Result<T, Error>
208    where
209        T: SqlQuery<T> + FromRow + SqlParams + Debug + Send + Sync + Clone + 'static,
210    {
211        self.tx_fetch(params).await
212    }
213
214    async fn get_all<T>(&self, params: &T) -> Result<Vec<T>, Error>
215    where
216        T: SqlQuery<T> + FromRow + SqlParams + Debug + Send + Sync + Clone + 'static,
217    {
218        self.tx_fetch_all(params).await
219    }
220
221    async fn select<T, R, F>(&self, entity: T, to_model: F) -> Result<R, Error>
222    where
223        T: SqlQuery<T> + SqlParams + Debug + Send + 'static,
224        F: Fn(&Row) -> Result<R, Error> + Send + Sync + 'static,
225        R: Send + 'static,
226    {
227        self.tx_select(entity, to_model).await
228    }
229
230    async fn select_all<T, R, F>(&self, entity: T, to_model: F) -> Result<Vec<R>, Error>
231    where
232        T: SqlQuery<T> + SqlParams + Debug + Send + 'static,
233        F: Fn(&Row) -> R + Send + Sync + 'static,
234        R: Send + 'static,
235    {
236        self.tx_select_all(entity, to_model).await
237    }
238}
239
240#[async_trait::async_trait]
241impl TransactionExtensions for Transaction<'_> {
242    async fn insert<T, P: for<'a> FromSql<'a> + Send + Sync>(&self, entity: T) -> Result<P, Error>
243    where
244        T: SqlCommand + SqlParams + Send + Sync + 'static,
245    {
246        let sql = T::query();
247
248        static TRACE_ENABLED: OnceLock<bool> = OnceLock::new();
249        let is_trace_enabled =
250            *TRACE_ENABLED.get_or_init(|| std::env::var("PARSQL_TRACE").unwrap_or_default() == "1");
251
252        if is_trace_enabled {
253            println!("[PARSQL-DEADPOOL-POSTGRES-TX] Execute SQL: {}", sql);
254        }
255
256        let params = entity.params();
257        let row = self.query_one(&sql, &params).await?;
258        row.try_get::<_, P>(0)
259    }
260
261    async fn update<T>(&self, entity: T) -> Result<bool, Error>
262    where
263        T: SqlCommand + UpdateParams + Send + Sync + 'static,
264    {
265        let sql = T::query();
266
267        static TRACE_ENABLED: OnceLock<bool> = OnceLock::new();
268        let is_trace_enabled =
269            *TRACE_ENABLED.get_or_init(|| std::env::var("PARSQL_TRACE").unwrap_or_default() == "1");
270
271        if is_trace_enabled {
272            println!("[PARSQL-DEADPOOL-POSTGRES-TX] Execute SQL: {}", sql);
273        }
274
275        let params = entity.params();
276        let result = self.execute(&sql, &params).await?;
277        Ok(result > 0)
278    }
279
280    async fn delete<T>(&self, entity: T) -> Result<u64, Error>
281    where
282        T: SqlCommand + SqlParams + Send + Sync + 'static,
283    {
284        let sql = T::query();
285
286        static TRACE_ENABLED: OnceLock<bool> = OnceLock::new();
287        let is_trace_enabled =
288            *TRACE_ENABLED.get_or_init(|| std::env::var("PARSQL_TRACE").unwrap_or_default() == "1");
289
290        if is_trace_enabled {
291            println!("[PARSQL-DEADPOOL-POSTGRES-TX] Execute SQL: {}", sql);
292        }
293
294        let params = entity.params();
295        self.execute(&sql, &params).await
296    }
297
298    async fn fetch<P, R>(&self, params: P) -> Result<R, Error>
299    where
300        P: SqlQuery<R> + SqlParams + Send + Sync + 'static,
301        R: FromRow + Send + Sync + 'static,
302    {
303        let sql = P::query();
304
305        static TRACE_ENABLED: OnceLock<bool> = OnceLock::new();
306        let is_trace_enabled =
307            *TRACE_ENABLED.get_or_init(|| std::env::var("PARSQL_TRACE").unwrap_or_default() == "1");
308
309        if is_trace_enabled {
310            println!("[PARSQL-DEADPOOL-POSTGRES-TX] Execute SQL: {}", sql);
311        }
312
313        let query_params = params.params();
314        let row = self.query_one(&sql, &query_params).await?;
315        R::from_row(&row)
316    }
317
318    async fn fetch_all<P, R>(&self, params: P) -> Result<Vec<R>, Error>
319    where
320        P: SqlQuery<R> + SqlParams + Send + Sync + 'static,
321        R: FromRow + Send + Sync + 'static,
322    {
323        let sql = P::query();
324
325        static TRACE_ENABLED: OnceLock<bool> = OnceLock::new();
326        let is_trace_enabled =
327            *TRACE_ENABLED.get_or_init(|| std::env::var("PARSQL_TRACE").unwrap_or_default() == "1");
328
329        if is_trace_enabled {
330            println!("[PARSQL-DEADPOOL-POSTGRES-TX] Execute SQL: {}", sql);
331        }
332
333        let query_params = params.params();
334        let rows = self.query(&sql, &query_params).await?;
335
336        let mut results = Vec::with_capacity(rows.len());
337        for row in rows {
338            results.push(R::from_row(&row)?);
339        }
340
341        Ok(results)
342    }
343}