Skip to main content

prax_postgres/
engine.rs

1//! PostgreSQL query engine implementation.
2
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6use prax_query::QueryResult;
7use prax_query::filter::FilterValue;
8use prax_query::traits::{BoxFuture, Model, QueryEngine};
9use tracing::trace;
10
11use crate::pool::PgPool;
12use crate::types::filter_value_to_sql;
13
/// PostgreSQL query engine that implements the Prax `QueryEngine`
/// trait.
///
/// Two modes, controlled by the `tx_conn` field:
///
/// - **Pool mode** (`tx_conn == None`, the default): each query
///   acquires a fresh connection from [`PgPool`] and drops it after
///   the call.
/// - **Transaction mode** (`tx_conn == Some(conn)`): each query routes
///   through the single pinned [`deadpool_postgres::Object`]. The
///   tx-bound engine is built by [`PgEngine::transaction`], which
///   issues a raw `BEGIN`; the outer future then runs `COMMIT` or
///   `ROLLBACK` on the same connection based on the closure's
///   `Ok` / `Err` result.
///
/// Raw `BEGIN` / `COMMIT` / `ROLLBACK` statements are used instead of
/// `tokio_postgres::Transaction<'_>` because `Transaction<'_>` borrows
/// from its owning `Client`. Storing both in one heap cell would need
/// `unsafe` lifetime laundering (e.g. `mem::transmute` to `'static`).
/// The `Object` derefs to the `Client`, and `Client::query` /
/// `execute` take `&self`, so an `Arc<Object>` is enough. Every engine
/// clone can share it, and dropping the last clone drops the `Object`,
/// which returns it to the pool.
///
/// Cloning an engine clones the `PgPool` handle and, in transaction
/// mode, the `Arc`, so the clone shares the same pinned connection.
#[derive(Clone)]
pub struct PgEngine {
    /// Connection pool. Every query uses it in pool mode, and
    /// `transaction` checks out the connection it pins from it.
    pool: PgPool,
    /// Present when this engine is bound to an in-flight transaction.
    /// `None` in the normal pool-backed case.
    tx_conn: Option<Arc<deadpool_postgres::Object>>,
}
46
47impl PgEngine {
48    /// Create a new PostgreSQL engine with the given connection pool.
49    pub fn new(pool: PgPool) -> Self {
50        Self {
51            pool,
52            tx_conn: None,
53        }
54    }
55
56    /// Get a reference to the connection pool.
57    pub fn pool(&self) -> &PgPool {
58        &self.pool
59    }
60
61    /// Convert filter values to PostgreSQL parameters.
62    #[allow(clippy::result_large_err)]
63    fn to_params(
64        values: &[FilterValue],
65    ) -> Result<Vec<Box<dyn tokio_postgres::types::ToSql + Sync + Send>>, prax_query::QueryError>
66    {
67        values
68            .iter()
69            .map(|v| {
70                filter_value_to_sql(v).map_err(|e| {
71                    let msg = e.to_string();
72                    prax_query::QueryError::database(msg).with_source(e)
73                })
74            })
75            .collect()
76    }
77}
78
79impl QueryEngine for PgEngine {
80    fn dialect(&self) -> &dyn prax_query::dialect::SqlDialect {
81        &prax_query::dialect::Postgres
82    }
83
84    fn query_many<T: Model + prax_query::row::FromRow + Send + 'static>(
85        &self,
86        sql: &str,
87        params: Vec<FilterValue>,
88    ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
89        let sql = sql.to_string();
90        Box::pin(async move {
91            trace!(sql = %sql, "Executing query_many");
92
93            let pg_params = Self::to_params(&params)?;
94            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
95                pg_params.iter().map(|p| p.as_ref() as _).collect();
96
97            let rows = if let Some(tx) = &self.tx_conn {
98                // Tx mode: drive the pinned connection directly so the
99                // query lands inside the same BEGIN…COMMIT block as
100                // every sibling call.
101                tx.query(&sql, &param_refs)
102                    .await
103                    .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
104            } else {
105                let conn = self.pool.get().await.map_err(|e| {
106                    prax_query::QueryError::connection(e.to_string()).with_source(e)
107                })?;
108                conn.query(&sql, &param_refs)
109                    .await
110                    .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
111            };
112
113            crate::deserialize::rows_into::<T>(rows)
114        })
115    }
116
117    fn query_one<T: Model + prax_query::row::FromRow + Send + 'static>(
118        &self,
119        sql: &str,
120        params: Vec<FilterValue>,
121    ) -> BoxFuture<'_, QueryResult<T>> {
122        let sql = sql.to_string();
123        Box::pin(async move {
124            trace!(sql = %sql, "Executing query_one");
125
126            let pg_params = Self::to_params(&params)?;
127            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
128                pg_params.iter().map(|p| p.as_ref() as _).collect();
129
130            // Shared `no rows` → `NotFound` translation, factored out
131            // so both dispatch arms convert the same error text the
132            // same way.
133            let map_err = |e: String| -> prax_query::QueryError {
134                if e.contains("no rows") {
135                    prax_query::QueryError::not_found(T::MODEL_NAME)
136                } else {
137                    prax_query::QueryError::database(e)
138                }
139            };
140
141            let row = if let Some(tx) = &self.tx_conn {
142                tx.query_one(&sql, &param_refs)
143                    .await
144                    .map_err(|e| map_err(e.to_string()).with_source(e))?
145            } else {
146                let conn = self.pool.get().await.map_err(|e| {
147                    prax_query::QueryError::connection(e.to_string()).with_source(e)
148                })?;
149                conn.query_one(&sql, &param_refs)
150                    .await
151                    .map_err(|e| map_err(e.to_string()).with_source(e))?
152            };
153
154            crate::deserialize::row_into::<T>(row)
155        })
156    }
157
158    fn query_optional<T: Model + prax_query::row::FromRow + Send + 'static>(
159        &self,
160        sql: &str,
161        params: Vec<FilterValue>,
162    ) -> BoxFuture<'_, QueryResult<Option<T>>> {
163        let sql = sql.to_string();
164        Box::pin(async move {
165            trace!(sql = %sql, "Executing query_optional");
166
167            let pg_params = Self::to_params(&params)?;
168            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
169                pg_params.iter().map(|p| p.as_ref() as _).collect();
170
171            let row = if let Some(tx) = &self.tx_conn {
172                tx.query_opt(&sql, &param_refs)
173                    .await
174                    .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
175            } else {
176                let conn = self.pool.get().await.map_err(|e| {
177                    prax_query::QueryError::connection(e.to_string()).with_source(e)
178                })?;
179                conn.query_opt(&sql, &param_refs)
180                    .await
181                    .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
182            };
183
184            row.map(crate::deserialize::row_into::<T>).transpose()
185        })
186    }
187
188    fn execute_insert<T: Model + prax_query::row::FromRow + Send + 'static>(
189        &self,
190        sql: &str,
191        params: Vec<FilterValue>,
192    ) -> BoxFuture<'_, QueryResult<T>> {
193        let sql = sql.to_string();
194        Box::pin(async move {
195            trace!(sql = %sql, "Executing insert");
196
197            let pg_params = Self::to_params(&params)?;
198            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
199                pg_params.iter().map(|p| p.as_ref() as _).collect();
200
201            let row = if let Some(tx) = &self.tx_conn {
202                tx.query_one(&sql, &param_refs)
203                    .await
204                    .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
205            } else {
206                let conn = self.pool.get().await.map_err(|e| {
207                    prax_query::QueryError::connection(e.to_string()).with_source(e)
208                })?;
209                conn.query_one(&sql, &param_refs)
210                    .await
211                    .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
212            };
213
214            crate::deserialize::row_into::<T>(row)
215        })
216    }
217
218    fn execute_update<T: Model + prax_query::row::FromRow + Send + 'static>(
219        &self,
220        sql: &str,
221        params: Vec<FilterValue>,
222    ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
223        let sql = sql.to_string();
224        Box::pin(async move {
225            trace!(sql = %sql, "Executing update");
226
227            let pg_params = Self::to_params(&params)?;
228            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
229                pg_params.iter().map(|p| p.as_ref() as _).collect();
230
231            let rows = if let Some(tx) = &self.tx_conn {
232                tx.query(&sql, &param_refs)
233                    .await
234                    .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
235            } else {
236                let conn = self.pool.get().await.map_err(|e| {
237                    prax_query::QueryError::connection(e.to_string()).with_source(e)
238                })?;
239                conn.query(&sql, &param_refs)
240                    .await
241                    .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
242            };
243
244            crate::deserialize::rows_into::<T>(rows)
245        })
246    }
247
248    fn execute_delete(
249        &self,
250        sql: &str,
251        params: Vec<FilterValue>,
252    ) -> BoxFuture<'_, QueryResult<u64>> {
253        let sql = sql.to_string();
254        Box::pin(async move {
255            trace!(sql = %sql, "Executing delete");
256
257            let pg_params = Self::to_params(&params)?;
258            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
259                pg_params.iter().map(|p| p.as_ref() as _).collect();
260
261            if let Some(tx) = &self.tx_conn {
262                tx.execute(&sql, &param_refs)
263                    .await
264                    .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))
265            } else {
266                let conn = self.pool.get().await.map_err(|e| {
267                    prax_query::QueryError::connection(e.to_string()).with_source(e)
268                })?;
269                conn.execute(&sql, &param_refs)
270                    .await
271                    .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))
272            }
273        })
274    }
275
276    fn execute_raw(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
277        let sql = sql.to_string();
278        Box::pin(async move {
279            trace!(sql = %sql, "Executing raw SQL");
280
281            let pg_params = Self::to_params(&params)?;
282            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
283                pg_params.iter().map(|p| p.as_ref() as _).collect();
284
285            if let Some(tx) = &self.tx_conn {
286                tx.execute(&sql, &param_refs)
287                    .await
288                    .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))
289            } else {
290                let conn = self.pool.get().await.map_err(|e| {
291                    prax_query::QueryError::connection(e.to_string()).with_source(e)
292                })?;
293                conn.execute(&sql, &param_refs)
294                    .await
295                    .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))
296            }
297        })
298    }
299
300    fn count(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
301        let sql = sql.to_string();
302        Box::pin(async move {
303            trace!(sql = %sql, "Executing count");
304
305            let pg_params = Self::to_params(&params)?;
306            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
307                pg_params.iter().map(|p| p.as_ref() as _).collect();
308
309            let row = if let Some(tx) = &self.tx_conn {
310                tx.query_one(&sql, &param_refs)
311                    .await
312                    .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
313            } else {
314                let conn = self.pool.get().await.map_err(|e| {
315                    prax_query::QueryError::connection(e.to_string()).with_source(e)
316                })?;
317                conn.query_one(&sql, &param_refs)
318                    .await
319                    .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
320            };
321
322            let count: i64 = row.get(0);
323            Ok(count as u64)
324        })
325    }
326
327    fn aggregate_query(
328        &self,
329        sql: &str,
330        params: Vec<FilterValue>,
331    ) -> BoxFuture<'_, QueryResult<Vec<std::collections::HashMap<String, FilterValue>>>> {
332        let sql = sql.to_string();
333        Box::pin(async move {
334            trace!(sql = %sql, "Executing aggregate_query");
335
336            let pg_params = Self::to_params(&params)?;
337            let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
338                pg_params.iter().map(|p| p.as_ref() as _).collect();
339
340            let rows = if let Some(tx) = &self.tx_conn {
341                tx.query(&sql, &param_refs)
342                    .await
343                    .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
344            } else {
345                let conn = self.pool.get().await.map_err(|e| {
346                    prax_query::QueryError::connection(e.to_string()).with_source(e)
347                })?;
348                conn.query(&sql, &param_refs)
349                    .await
350                    .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?
351            };
352
353            Ok(rows
354                .into_iter()
355                .map(|row| {
356                    let mut map = std::collections::HashMap::new();
357                    for (i, col) in row.columns().iter().enumerate() {
358                        let name = col.name().to_string();
359                        let value = decode_aggregate_cell(&row, i, col.type_());
360                        map.insert(name, value);
361                    }
362                    map
363                })
364                .collect())
365        })
366    }
367
368    fn transaction<'a, R, Fut, F>(&'a self, f: F) -> BoxFuture<'a, QueryResult<R>>
369    where
370        F: FnOnce(Self) -> Fut + Send + 'a,
371        Fut: std::future::Future<Output = QueryResult<R>> + Send + 'a,
372        R: Send + 'a,
373        Self: Clone,
374    {
375        Box::pin(async move {
376            // Refuse nested transactions until dialect-aware SAVEPOINT
377            // support lands. Users can still run SAVEPOINT / RELEASE
378            // manually via `execute_raw` if they need it.
379            if self.tx_conn.is_some() {
380                return Err(prax_query::QueryError::internal(
381                    "nested transactions not yet implemented \
382                     (call .transaction() on the outer engine only, or \
383                     issue SAVEPOINT via execute_raw)",
384                ));
385            }
386
387            // Acquire a dedicated raw `deadpool_postgres::Object`.
388            // Going through `PgPool::inner()` keeps the connection
389            // pinned to this future — every query the closure emits
390            // will run on the same physical connection.
391            let conn =
392                self.pool.inner().get().await.map_err(|e| {
393                    prax_query::QueryError::connection(e.to_string()).with_source(e)
394                })?;
395
396            // Issue `BEGIN` directly as a batch_execute on the raw
397            // connection. Using `tokio_postgres::Transaction<'_>`
398            // would bundle a borrow back into `conn`; instead we rely
399            // on the connection's session state (postgres tracks the
400            // BEGIN/COMMIT/ROLLBACK on the connection itself, so every
401            // subsequent query on the same `Object` sees the same
402            // transaction). This is the approach sanctioned by the
403            // task plan's fallback guardrail.
404            conn.batch_execute("BEGIN")
405                .await
406                .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?;
407
408            let tx_conn = Arc::new(conn);
409            let tx_engine = PgEngine {
410                pool: self.pool.clone(),
411                tx_conn: Some(tx_conn.clone()),
412            };
413
414            // Run the caller's closure on the tx-bound engine clone.
415            // When the future resolves the closure's engine clone has
416            // dropped, so `tx_conn` is the only remaining `Arc` (plus
417            // the clone we handed to the engine itself).
418            let result = f(tx_engine).await;
419
420            // Finalise: COMMIT on success, best-effort ROLLBACK on
421            // failure. Preserve the caller's error if rollback fails —
422            // the connection drops in a moment either way and the
423            // server aborts the transaction on session close.
424            match result {
425                Ok(v) => {
426                    tx_conn.batch_execute("COMMIT").await.map_err(|e| {
427                        prax_query::QueryError::database(e.to_string()).with_source(e)
428                    })?;
429                    Ok(v)
430                }
431                Err(e) => {
432                    let _ = tx_conn.batch_execute("ROLLBACK").await;
433                    Err(e)
434                }
435            }
436        })
437    }
438}
439
/// A typed query builder that uses the PostgreSQL engine.
///
/// The model type `T` exists only at the type level. The builder stores
/// a [`PgEngine`] plus a zero-sized `PhantomData<T>` marker.
pub struct PgQueryBuilder<T: Model> {
    /// Engine this builder wraps.
    engine: PgEngine,
    /// Ties the builder to model `T` without storing a `T`.
    _marker: PhantomData<T>,
}
445
446impl<T: Model> PgQueryBuilder<T> {
447    /// Create a new query builder.
448    pub fn new(engine: PgEngine) -> Self {
449        Self {
450            engine,
451            _marker: PhantomData,
452        }
453    }
454
455    /// Get the underlying engine.
456    pub fn engine(&self) -> &PgEngine {
457        &self.engine
458    }
459}
460
461/// Decode a single aggregate result cell by its Postgres column type.
462///
463/// Aggregate result sets don't have a fixed schema — SUM over an
464/// INT4 column comes back as BIGINT, AVG returns NUMERIC, MIN/MAX
465/// preserves the source column's type, and COUNT is always BIGINT.
466/// Rather than route these through the `FromRow` machinery (which
467/// needs a model whose columns are known at compile time), we
468/// type-dispatch at runtime on `Column::type_()` and project into a
469/// [`FilterValue`].
470///
471/// NULL maps to `FilterValue::Null`. NUMERIC is returned as
472/// `FilterValue::String` because the workspace's tokio-postgres
473/// feature set doesn't enable `with-rust_decimal-*`; the aggregate
474/// result folder's numeric parser reads the text form back into a
475/// float for sum/avg accessors.
476///
477/// Unknown types fall through to `try_get::<String>` so a novel
478/// column type doesn't silently drop. Decoding failures record
479/// `FilterValue::Null` rather than aborting the whole query.
480fn decode_aggregate_cell(
481    row: &tokio_postgres::Row,
482    idx: usize,
483    ty: &tokio_postgres::types::Type,
484) -> FilterValue {
485    use tokio_postgres::types::Type;
486    match *ty {
487        Type::BOOL => row
488            .try_get::<_, Option<bool>>(idx)
489            .ok()
490            .flatten()
491            .map(FilterValue::Bool)
492            .unwrap_or(FilterValue::Null),
493        Type::INT2 => row
494            .try_get::<_, Option<i16>>(idx)
495            .ok()
496            .flatten()
497            .map(|n| FilterValue::Int(n as i64))
498            .unwrap_or(FilterValue::Null),
499        Type::INT4 => row
500            .try_get::<_, Option<i32>>(idx)
501            .ok()
502            .flatten()
503            .map(|n| FilterValue::Int(n as i64))
504            .unwrap_or(FilterValue::Null),
505        Type::INT8 => row
506            .try_get::<_, Option<i64>>(idx)
507            .ok()
508            .flatten()
509            .map(FilterValue::Int)
510            .unwrap_or(FilterValue::Null),
511        Type::FLOAT4 => row
512            .try_get::<_, Option<f32>>(idx)
513            .ok()
514            .flatten()
515            .map(|f| FilterValue::Float(f as f64))
516            .unwrap_or(FilterValue::Null),
517        Type::FLOAT8 => row
518            .try_get::<_, Option<f64>>(idx)
519            .ok()
520            .flatten()
521            .map(FilterValue::Float)
522            .unwrap_or(FilterValue::Null),
523        Type::TEXT | Type::VARCHAR | Type::CHAR | Type::NAME | Type::BPCHAR | Type::NUMERIC => row
524            .try_get::<_, Option<String>>(idx)
525            .ok()
526            .flatten()
527            .map(FilterValue::String)
528            .unwrap_or(FilterValue::Null),
529        Type::JSON | Type::JSONB => row
530            .try_get::<_, Option<serde_json::Value>>(idx)
531            .ok()
532            .flatten()
533            .map(FilterValue::Json)
534            .unwrap_or(FilterValue::Null),
535        _ => row
536            .try_get::<_, Option<String>>(idx)
537            .ok()
538            .flatten()
539            .map(FilterValue::String)
540            .unwrap_or(FilterValue::Null),
541    }
542}
543
#[cfg(test)]
mod tests {
    // Placeholder: exercising `PgEngine` end to end requires a real
    // PostgreSQL database, so no tests live here yet.
}