1use crate::traits::{CrudOps, FromRow, SqlParams, SqlQuery, UpdateParams};
2use postgres::types::FromSql;
3use std::sync::OnceLock;
4use tokio_postgres::{Client, Error, Row, Transaction};
5
6#[async_trait::async_trait]
7impl CrudOps for Client {
8 async fn insert<T, P: for<'a> FromSql<'a> + Send + Sync>(&self, entity: T) -> Result<P, Error>
9 where
10 T: SqlQuery + SqlParams + Send + Sync + 'static,
11 {
12 let sql = T::query();
13
14 static TRACE_ENABLED: OnceLock<bool> = OnceLock::new();
15 let is_trace_enabled =
16 *TRACE_ENABLED.get_or_init(|| std::env::var("PARSQL_TRACE").unwrap_or_default() == "1");
17
18 if is_trace_enabled {
19 println!("[PARSQL-TOKIO-POSTGRES] Execute SQL: {}", sql);
20 }
21
22 let params = entity.params();
23 let row = self.query_one(&sql, ¶ms).await?;
24 row.try_get::<_, P>(0)
25 }
26
27 async fn update<T>(&self, entity: T) -> Result<bool, Error>
28 where
29 T: SqlQuery + UpdateParams + Send + Sync + 'static,
30 {
31 let sql = T::query();
32
33 static TRACE_ENABLED: OnceLock<bool> = OnceLock::new();
34 let is_trace_enabled =
35 *TRACE_ENABLED.get_or_init(|| std::env::var("PARSQL_TRACE").unwrap_or_default() == "1");
36
37 if is_trace_enabled {
38 println!("[PARSQL-TOKIO-POSTGRES] Execute SQL: {}", sql);
39 }
40
41 let params = entity.params();
42 let result = self.execute(&sql, ¶ms).await?;
43 Ok(result > 0)
44 }
45
46 async fn delete<T>(&self, entity: T) -> Result<u64, Error>
47 where
48 T: SqlQuery + SqlParams + Send + Sync + 'static,
49 {
50 let sql = T::query();
51
52 static TRACE_ENABLED: OnceLock<bool> = OnceLock::new();
53 let is_trace_enabled =
54 *TRACE_ENABLED.get_or_init(|| std::env::var("PARSQL_TRACE").unwrap_or_default() == "1");
55
56 if is_trace_enabled {
57 println!("[PARSQL-TOKIO-POSTGRES] Execute SQL: {}", sql);
58 }
59
60 let params = entity.params();
61 self.execute(&sql, ¶ms).await
62 }
63
64 async fn fetch<T>(&self, params: T) -> Result<T, Error>
65 where
66 T: SqlQuery + FromRow + SqlParams + Send + Sync + 'static,
67 {
68 let sql = T::query();
69
70 static TRACE_ENABLED: OnceLock<bool> = OnceLock::new();
71 let is_trace_enabled =
72 *TRACE_ENABLED.get_or_init(|| std::env::var("PARSQL_TRACE").unwrap_or_default() == "1");
73
74 if is_trace_enabled {
75 println!("[PARSQL-TOKIO-POSTGRES] Execute SQL: {}", sql);
76 }
77
78 let query_params = params.params();
79 let row = self.query_one(&sql, &query_params).await?;
80 T::from_row(&row)
81 }
82
83 async fn fetch_all<T>(&self, params: T) -> Result<Vec<T>, Error>
84 where
85 T: SqlQuery + FromRow + SqlParams + Send + Sync + 'static,
86 {
87 let sql = T::query();
88
89 static TRACE_ENABLED: OnceLock<bool> = OnceLock::new();
90 let is_trace_enabled =
91 *TRACE_ENABLED.get_or_init(|| std::env::var("PARSQL_TRACE").unwrap_or_default() == "1");
92
93 if is_trace_enabled {
94 println!("[PARSQL-TOKIO-POSTGRES] Execute SQL: {}", sql);
95 }
96
97 let query_params = params.params();
98 let rows = self.query(&sql, &query_params).await?;
99
100 let mut results = Vec::with_capacity(rows.len());
101 for row in rows {
102 results.push(T::from_row(&row)?);
103 }
104
105 Ok(results)
106 }
107
108 async fn select<T, F, R>(&self, entity: T, to_model: F) -> Result<R, Error>
109 where
110 T: SqlQuery + SqlParams + Send + Sync + 'static,
111 F: Fn(&Row) -> Result<R, Error> + Send + Sync + 'static,
112 R: Send + 'static,
113 {
114 let sql = T::query();
115
116 static TRACE_ENABLED: OnceLock<bool> = OnceLock::new();
117 let is_trace_enabled =
118 *TRACE_ENABLED.get_or_init(|| std::env::var("PARSQL_TRACE").unwrap_or_default() == "1");
119
120 if is_trace_enabled {
121 println!("[PARSQL-TOKIO-POSTGRES] Execute SQL: {}", sql);
122 }
123
124 let params = entity.params();
125 let row = self.query_one(&sql, ¶ms).await?;
126 to_model(&row)
127 }
128
129 async fn select_all<T, F, R>(&self, entity: T, to_model: F) -> Result<Vec<R>, Error>
130 where
131 T: SqlQuery + SqlParams + Send + Sync + 'static,
132 F: Fn(&Row) -> R + Send + Sync + 'static,
133 R: Send + 'static,
134 {
135 let sql = T::query();
136
137 static TRACE_ENABLED: OnceLock<bool> = OnceLock::new();
138 let is_trace_enabled =
139 *TRACE_ENABLED.get_or_init(|| std::env::var("PARSQL_TRACE").unwrap_or_default() == "1");
140
141 if is_trace_enabled {
142 println!("[PARSQL-TOKIO-POSTGRES] Execute SQL: {}", sql);
143 }
144
145 let params = entity.params();
146 let rows = self.query(&sql, ¶ms).await?;
147
148 let mut results = Vec::with_capacity(rows.len());
149 for row in rows {
150 results.push(to_model(&row));
151 }
152
153 Ok(results)
154 }
155}
156
157pub async fn insert<T, P: for<'a> FromSql<'a> + Send + Sync>(
168 client: &Client,
169 entity: T,
170) -> Result<P, Error>
171where
172 T: SqlQuery + SqlParams + Send + Sync + 'static,
173{
174 client.insert::<T, P>(entity).await
175}
176
177pub async fn update<T>(client: &Client, entity: T) -> Result<bool, Error>
188where
189 T: SqlQuery + UpdateParams + Send + Sync + 'static,
190{
191 client.update(entity).await
192}
193
194pub async fn delete<T>(client: &Client, entity: T) -> Result<u64, Error>
205where
206 T: SqlQuery + SqlParams + Send + Sync + 'static,
207{
208 client.delete(entity).await
209}
210
211pub async fn fetch<T>(client: &Client, params: T) -> Result<T, Error>
222where
223 T: SqlQuery + FromRow + SqlParams + Send + Sync + 'static,
224{
225 client.fetch(params).await
226}
227
228pub async fn fetch_all<T>(client: &Client, params: T) -> Result<Vec<T>, Error>
239where
240 T: SqlQuery + FromRow + SqlParams + Send + Sync + 'static,
241{
242 client.fetch_all(params).await
243}
244
245pub async fn select<T, F, R>(client: &Client, entity: T, to_model: F) -> Result<R, Error>
258where
259 T: SqlQuery + SqlParams + Send + Sync + 'static,
260 F: Fn(&Row) -> Result<R, Error> + Send + Sync + 'static,
261 R: Send + 'static,
262{
263 client.select(entity, to_model).await
264}
265
266pub async fn select_all<T, F, R>(client: &Client, entity: T, to_model: F) -> Result<Vec<R>, Error>
279where
280 T: SqlQuery + SqlParams + Send + Sync + 'static,
281 F: Fn(&Row) -> R + Send + Sync + 'static,
282 R: Send + 'static,
283{
284 client.select_all(entity, to_model).await
285}
286
287#[deprecated(
303 since = "0.2.0",
304 note = "Renamed to `fetch`. Please use `fetch` function instead."
305)]
306pub async fn get<T>(client: &Client, params: T) -> Result<T, Error>
307where
308 T: SqlQuery + FromRow + SqlParams + Send + Sync + 'static,
309{
310 fetch(client, params).await
311}
312
313#[deprecated(
327 since = "0.2.0",
328 note = "Renamed to `fetch_all`. Please use `fetch_all` function instead."
329)]
330pub async fn get_all<T>(client: &Client, params: T) -> Result<Vec<T>, Error>
331where
332 T: SqlQuery + FromRow + SqlParams + Send + Sync + 'static,
333{
334 fetch_all(client, params).await
335}