1use std::marker::PhantomData;
4use std::sync::Arc;
5
6use prax_query::QueryResult;
7use prax_query::filter::FilterValue;
8use prax_query::traits::{BoxFuture, Model, QueryEngine};
9use tracing::trace;
10
11use crate::pool::PgPool;
12use crate::types::filter_value_to_sql;
13
14#[derive(Clone)]
40pub struct PgEngine {
41 pool: PgPool,
42 tx_conn: Option<Arc<deadpool_postgres::Object>>,
45}
46
47impl PgEngine {
48 pub fn new(pool: PgPool) -> Self {
50 Self {
51 pool,
52 tx_conn: None,
53 }
54 }
55
56 pub fn pool(&self) -> &PgPool {
58 &self.pool
59 }
60
61 #[allow(clippy::result_large_err)]
63 fn to_params(
64 values: &[FilterValue],
65 ) -> Result<Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>>, prax_query::QueryError>
66 {
67 values
68 .iter()
69 .map(|v| {
70 filter_value_to_sql(v).map_err(|e| {
71 let msg = e.to_string();
72 prax_query::QueryError::database(msg).with_source(e)
73 })
74 })
75 .collect()
76 }
77}
78
79impl QueryEngine for PgEngine {
80 fn dialect(&self) -> &dyn prax_query::dialect::SqlDialect {
81 &prax_query::dialect::Postgres
82 }
83
84 fn query_many<T: Model + prax_query::row::FromRow + Send + 'static>(
85 &self,
86 sql: &str,
87 params: Vec<FilterValue>,
88 ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
89 let sql = sql.to_string();
90 Box::pin(async move {
91 trace!(sql = %sql, "Executing query_many");
92
93 let pg_params = Self::to_params(¶ms)?;
94 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
95 pg_params.iter().map(|p| p.as_ref() as _).collect();
96
97 let rows = if let Some(tx) = &self.tx_conn {
98 tx.query(&sql, ¶m_refs)
102 .await
103 .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
104 } else {
105 let conn = self.pool.get().await.map_err(|e| {
106 prax_query::QueryError::connection(e.to_string()).with_source(e)
107 })?;
108 conn.query(&sql, ¶m_refs)
109 .await
110 .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
111 };
112
113 crate::deserialize::rows_into::<T>(rows)
114 })
115 }
116
117 fn query_one<T: Model + prax_query::row::FromRow + Send + 'static>(
118 &self,
119 sql: &str,
120 params: Vec<FilterValue>,
121 ) -> BoxFuture<'_, QueryResult<T>> {
122 let sql = sql.to_string();
123 Box::pin(async move {
124 trace!(sql = %sql, "Executing query_one");
125
126 let pg_params = Self::to_params(¶ms)?;
127 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
128 pg_params.iter().map(|p| p.as_ref() as _).collect();
129
130 let map_err = |e: String| -> prax_query::QueryError {
134 if e.contains("no rows") {
135 prax_query::QueryError::not_found(T::MODEL_NAME)
136 } else {
137 prax_query::QueryError::database(e)
138 }
139 };
140
141 let row = if let Some(tx) = &self.tx_conn {
142 tx.query_one(&sql, ¶m_refs)
143 .await
144 .map_err(|e| map_err(e.to_string()).with_source(e))?
145 } else {
146 let conn = self.pool.get().await.map_err(|e| {
147 prax_query::QueryError::connection(e.to_string()).with_source(e)
148 })?;
149 conn.query_one(&sql, ¶m_refs)
150 .await
151 .map_err(|e| map_err(e.to_string()).with_source(e))?
152 };
153
154 crate::deserialize::row_into::<T>(row)
155 })
156 }
157
158 fn query_optional<T: Model + prax_query::row::FromRow + Send + 'static>(
159 &self,
160 sql: &str,
161 params: Vec<FilterValue>,
162 ) -> BoxFuture<'_, QueryResult<Option<T>>> {
163 let sql = sql.to_string();
164 Box::pin(async move {
165 trace!(sql = %sql, "Executing query_optional");
166
167 let pg_params = Self::to_params(¶ms)?;
168 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
169 pg_params.iter().map(|p| p.as_ref() as _).collect();
170
171 let row = if let Some(tx) = &self.tx_conn {
172 tx.query_opt(&sql, ¶m_refs)
173 .await
174 .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
175 } else {
176 let conn = self.pool.get().await.map_err(|e| {
177 prax_query::QueryError::connection(e.to_string()).with_source(e)
178 })?;
179 conn.query_opt(&sql, ¶m_refs)
180 .await
181 .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
182 };
183
184 row.map(crate::deserialize::row_into::<T>).transpose()
185 })
186 }
187
188 fn execute_insert<T: Model + prax_query::row::FromRow + Send + 'static>(
189 &self,
190 sql: &str,
191 params: Vec<FilterValue>,
192 ) -> BoxFuture<'_, QueryResult<T>> {
193 let sql = sql.to_string();
194 Box::pin(async move {
195 trace!(sql = %sql, "Executing insert");
196
197 let pg_params = Self::to_params(¶ms)?;
198 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
199 pg_params.iter().map(|p| p.as_ref() as _).collect();
200
201 let row = if let Some(tx) = &self.tx_conn {
202 tx.query_one(&sql, ¶m_refs)
203 .await
204 .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
205 } else {
206 let conn = self.pool.get().await.map_err(|e| {
207 prax_query::QueryError::connection(e.to_string()).with_source(e)
208 })?;
209 conn.query_one(&sql, ¶m_refs)
210 .await
211 .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
212 };
213
214 crate::deserialize::row_into::<T>(row)
215 })
216 }
217
218 fn execute_update<T: Model + prax_query::row::FromRow + Send + 'static>(
219 &self,
220 sql: &str,
221 params: Vec<FilterValue>,
222 ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
223 let sql = sql.to_string();
224 Box::pin(async move {
225 trace!(sql = %sql, "Executing update");
226
227 let pg_params = Self::to_params(¶ms)?;
228 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
229 pg_params.iter().map(|p| p.as_ref() as _).collect();
230
231 let rows = if let Some(tx) = &self.tx_conn {
232 tx.query(&sql, ¶m_refs)
233 .await
234 .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
235 } else {
236 let conn = self.pool.get().await.map_err(|e| {
237 prax_query::QueryError::connection(e.to_string()).with_source(e)
238 })?;
239 conn.query(&sql, ¶m_refs)
240 .await
241 .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
242 };
243
244 crate::deserialize::rows_into::<T>(rows)
245 })
246 }
247
248 fn execute_delete(
249 &self,
250 sql: &str,
251 params: Vec<FilterValue>,
252 ) -> BoxFuture<'_, QueryResult<u64>> {
253 let sql = sql.to_string();
254 Box::pin(async move {
255 trace!(sql = %sql, "Executing delete");
256
257 let pg_params = Self::to_params(¶ms)?;
258 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
259 pg_params.iter().map(|p| p.as_ref() as _).collect();
260
261 if let Some(tx) = &self.tx_conn {
262 tx.execute(&sql, ¶m_refs)
263 .await
264 .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))
265 } else {
266 let conn = self.pool.get().await.map_err(|e| {
267 prax_query::QueryError::connection(e.to_string()).with_source(e)
268 })?;
269 conn.execute(&sql, ¶m_refs)
270 .await
271 .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))
272 }
273 })
274 }
275
276 fn execute_raw(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
277 let sql = sql.to_string();
278 Box::pin(async move {
279 trace!(sql = %sql, "Executing raw SQL");
280
281 let pg_params = Self::to_params(¶ms)?;
282 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
283 pg_params.iter().map(|p| p.as_ref() as _).collect();
284
285 if let Some(tx) = &self.tx_conn {
286 tx.execute(&sql, ¶m_refs)
287 .await
288 .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))
289 } else {
290 let conn = self.pool.get().await.map_err(|e| {
291 prax_query::QueryError::connection(e.to_string()).with_source(e)
292 })?;
293 conn.execute(&sql, ¶m_refs)
294 .await
295 .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))
296 }
297 })
298 }
299
300 fn count(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
301 let sql = sql.to_string();
302 Box::pin(async move {
303 trace!(sql = %sql, "Executing count");
304
305 let pg_params = Self::to_params(¶ms)?;
306 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
307 pg_params.iter().map(|p| p.as_ref() as _).collect();
308
309 let row = if let Some(tx) = &self.tx_conn {
310 tx.query_one(&sql, ¶m_refs)
311 .await
312 .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
313 } else {
314 let conn = self.pool.get().await.map_err(|e| {
315 prax_query::QueryError::connection(e.to_string()).with_source(e)
316 })?;
317 conn.query_one(&sql, ¶m_refs)
318 .await
319 .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
320 };
321
322 let count: i64 = row.get(0);
323 Ok(count as u64)
324 })
325 }
326
327 fn aggregate_query(
328 &self,
329 sql: &str,
330 params: Vec<FilterValue>,
331 ) -> BoxFuture<'_, QueryResult<Vec<std::collections::HashMap<String, FilterValue>>>> {
332 let sql = sql.to_string();
333 Box::pin(async move {
334 trace!(sql = %sql, "Executing aggregate_query");
335
336 let pg_params = Self::to_params(¶ms)?;
337 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
338 pg_params.iter().map(|p| p.as_ref() as _).collect();
339
340 let rows = if let Some(tx) = &self.tx_conn {
341 tx.query(&sql, ¶m_refs)
342 .await
343 .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
344 } else {
345 let conn = self.pool.get().await.map_err(|e| {
346 prax_query::QueryError::connection(e.to_string()).with_source(e)
347 })?;
348 conn.query(&sql, ¶m_refs)
349 .await
350 .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
351 };
352
353 Ok(rows
354 .into_iter()
355 .map(|row| {
356 let mut map = std::collections::HashMap::new();
357 for (i, col) in row.columns().iter().enumerate() {
358 let name = col.name().to_string();
359 let value = decode_aggregate_cell(&row, i, col.type_());
360 map.insert(name, value);
361 }
362 map
363 })
364 .collect())
365 })
366 }
367
368 fn transaction<'a, R, Fut, F>(&'a self, f: F) -> BoxFuture<'a, QueryResult<R>>
369 where
370 F: FnOnce(Self) -> Fut + Send + 'a,
371 Fut: std::future::Future<Output = QueryResult<R>> + Send + 'a,
372 R: Send + 'a,
373 Self: Clone,
374 {
375 Box::pin(async move {
376 if self.tx_conn.is_some() {
380 return Err(prax_query::QueryError::internal(
381 "nested transactions not yet implemented \
382 (call .transaction() on the outer engine only, or \
383 issue SAVEPOINT via execute_raw)",
384 ));
385 }
386
387 let conn =
392 self.pool.inner().get().await.map_err(|e| {
393 prax_query::QueryError::connection(e.to_string()).with_source(e)
394 })?;
395
396 conn.batch_execute("BEGIN")
405 .await
406 .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?;
407
408 let tx_conn = Arc::new(conn);
409 let tx_engine = PgEngine {
410 pool: self.pool.clone(),
411 tx_conn: Some(tx_conn.clone()),
412 };
413
414 let result = f(tx_engine).await;
419
420 match result {
425 Ok(v) => {
426 tx_conn.batch_execute("COMMIT").await.map_err(|e| {
427 prax_query::QueryError::database(e.to_string()).with_source(e)
428 })?;
429 Ok(v)
430 }
431 Err(e) => {
432 let _ = tx_conn.batch_execute("ROLLBACK").await;
433 Err(e)
434 }
435 }
436 })
437 }
438}
439
440pub struct PgQueryBuilder<T: Model> {
442 engine: PgEngine,
443 _marker: PhantomData<T>,
444}
445
446impl<T: Model> PgQueryBuilder<T> {
447 pub fn new(engine: PgEngine) -> Self {
449 Self {
450 engine,
451 _marker: PhantomData,
452 }
453 }
454
455 pub fn engine(&self) -> &PgEngine {
457 &self.engine
458 }
459}
460
461fn decode_aggregate_cell(
481 row: &tokio_postgres::Row,
482 idx: usize,
483 ty: &tokio_postgres::types::Type,
484) -> FilterValue {
485 use tokio_postgres::types::Type;
486 match *ty {
487 Type::BOOL => row
488 .try_get::<_, Option<bool>>(idx)
489 .ok()
490 .flatten()
491 .map(FilterValue::Bool)
492 .unwrap_or(FilterValue::Null),
493 Type::INT2 => row
494 .try_get::<_, Option<i16>>(idx)
495 .ok()
496 .flatten()
497 .map(|n| FilterValue::Int(n as i64))
498 .unwrap_or(FilterValue::Null),
499 Type::INT4 => row
500 .try_get::<_, Option<i32>>(idx)
501 .ok()
502 .flatten()
503 .map(|n| FilterValue::Int(n as i64))
504 .unwrap_or(FilterValue::Null),
505 Type::INT8 => row
506 .try_get::<_, Option<i64>>(idx)
507 .ok()
508 .flatten()
509 .map(FilterValue::Int)
510 .unwrap_or(FilterValue::Null),
511 Type::FLOAT4 => row
512 .try_get::<_, Option<f32>>(idx)
513 .ok()
514 .flatten()
515 .map(|f| FilterValue::Float(f as f64))
516 .unwrap_or(FilterValue::Null),
517 Type::FLOAT8 => row
518 .try_get::<_, Option<f64>>(idx)
519 .ok()
520 .flatten()
521 .map(FilterValue::Float)
522 .unwrap_or(FilterValue::Null),
523 Type::TEXT | Type::VARCHAR | Type::CHAR | Type::NAME | Type::BPCHAR | Type::NUMERIC => row
524 .try_get::<_, Option<String>>(idx)
525 .ok()
526 .flatten()
527 .map(FilterValue::String)
528 .unwrap_or(FilterValue::Null),
529 Type::JSON | Type::JSONB => row
530 .try_get::<_, Option<serde_json::Value>>(idx)
531 .ok()
532 .flatten()
533 .map(FilterValue::Json)
534 .unwrap_or(FilterValue::Null),
535 _ => row
536 .try_get::<_, Option<String>>(idx)
537 .ok()
538 .flatten()
539 .map(FilterValue::String)
540 .unwrap_or(FilterValue::Null),
541 }
542}
543
#[cfg(test)]
mod tests {
    // No unit tests are defined in this module yet.
}