Skip to main content

fraiseql_db/postgres/adapter/
database.rs

1//! `DatabaseAdapter` and `SupportsMutations` implementations for `PostgresAdapter`.
2
3use async_trait::async_trait;
4use bytes::BufMut as _;
5use fraiseql_error::{FraiseQLError, Result};
6use tokio_postgres::Row;
7
8use super::{PostgresAdapter, build_where_select_sql, build_where_select_sql_ordered};
9use crate::{
10    identifier::quote_postgres_identifier,
11    traits::{DatabaseAdapter, SupportsMutations},
12    types::{
13        DatabaseType, JsonbValue, PoolMetrics, QueryParam,
14        sql_hints::{OrderByClause, SqlProjectionHint},
15    },
16    where_clause::WhereClause,
17};
18
/// SQLSTATE code that PostgreSQL reports when a statement references a column
/// that does not exist (`undefined_column`, class 42 — syntax error or access rule
/// violation).
const PG_UNDEFINED_COLUMN: &str = "42703";
21
22/// A flexible SQL parameter that binds to any PostgreSQL type.
23///
24/// Solves the impedance mismatch between `serde_json::Value` (only accepts JSON/JSONB)
25/// and `Option<String>` (only accepts text-family types) when binding function-call
26/// arguments whose types are resolved at runtime from the function signature.
27///
28/// Serialisation strategy (binary wire format):
29/// - `JSONB`: 1-byte version header (1) + UTF-8 JSON bytes
30/// - `JSON`: UTF-8 JSON bytes
31/// - `UUID`: 16-byte big-endian UUID
32/// - `INT4`: 4-byte big-endian i32
33/// - `INT8`: 8-byte big-endian i64
34/// - `BOOL`: 1-byte (0 or 1)
35/// - All other types: UTF-8 bytes (PostgreSQL text binary = raw UTF-8)
36#[derive(Debug)]
37enum FlexParam {
38    /// SQL NULL — accepted by any PostgreSQL type.
39    Null,
40    /// A text-encoded value; binary-serialised according to the server-resolved type.
41    Text(String),
42}
43
44impl tokio_postgres::types::ToSql for FlexParam {
45    fn to_sql(
46        &self,
47        ty: &tokio_postgres::types::Type,
48        out: &mut bytes::BytesMut,
49    ) -> std::result::Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>>
50    {
51        use tokio_postgres::types::{IsNull, Type};
52        match self {
53            Self::Null => Ok(IsNull::Yes),
54            Self::Text(s) => {
55                if *ty == Type::JSONB {
56                    // JSONB binary wire format: 1-byte version (1) + JSON bytes
57                    out.put_u8(1);
58                    out.extend_from_slice(s.as_bytes());
59                } else if *ty == Type::JSON {
60                    out.extend_from_slice(s.as_bytes());
61                } else if *ty == Type::UUID {
62                    let uuid = uuid::Uuid::parse_str(s)?;
63                    out.extend_from_slice(uuid.as_bytes());
64                } else if *ty == Type::INT4 {
65                    let n: i32 = s.parse()?;
66                    out.put_i32(n);
67                } else if *ty == Type::INT8 {
68                    let n: i64 = s.parse()?;
69                    out.put_i64(n);
70                } else if *ty == Type::BOOL {
71                    let b: bool = s.parse()?;
72                    out.put_u8(u8::from(b));
73                } else {
74                    // TEXT, VARCHAR, BPCHAR, NAME, UNKNOWN, and any user-defined type:
75                    // UTF-8 bytes are the binary wire representation for text-family types.
76                    out.extend_from_slice(s.as_bytes());
77                }
78                Ok(IsNull::No)
79            },
80        }
81    }
82
83    fn accepts(_ty: &tokio_postgres::types::Type) -> bool {
84        // Accepts all types; per-type serialisation is handled in `to_sql`.
85        true
86    }
87
88    fn to_sql_checked(
89        &self,
90        ty: &tokio_postgres::types::Type,
91        out: &mut bytes::BytesMut,
92    ) -> std::result::Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>>
93    {
94        // `accepts()` returns true for all types, so the standard WrongType check is
95        // unnecessary.  Delegate directly to `to_sql`.
96        self.to_sql(ty, out)
97    }
98}
99
100/// Enrich a `FraiseQLError::Database` error for PostgreSQL SQLSTATE 42703 (undefined column)
101/// when the WHERE clause contains `NativeField` conditions.
102///
103/// Native columns may be inferred automatically at compile time from `ID`/`UUID`-typed
104/// arguments.  If the column does not exist on the target table at runtime, the raw
105/// PostgreSQL error is replaced with a diagnostic message that names the native columns
106/// involved and explains how to fix the schema.
107fn enrich_undefined_column_error(
108    err: FraiseQLError,
109    view: &str,
110    where_clause: Option<&WhereClause>,
111) -> FraiseQLError {
112    let FraiseQLError::Database { ref sql_state, .. } = err else {
113        return err;
114    };
115    if sql_state.as_deref() != Some(PG_UNDEFINED_COLUMN) {
116        return err;
117    }
118    let native_cols: Vec<&str> =
119        where_clause.map(|wc| wc.native_column_names()).unwrap_or_default();
120    if native_cols.is_empty() {
121        return err;
122    }
123    FraiseQLError::Database {
124        message:   format!(
125            "Column(s) {:?} referenced as native column(s) on `{view}` do not exist. \
126             These columns were auto-inferred from ID/UUID-typed query arguments. \
127             Either add the column(s) to the table/view, or set \
128             `native_columns = {{}}` explicitly in your schema to disable inference.",
129            native_cols,
130        ),
131        sql_state: Some(PG_UNDEFINED_COLUMN.to_string()),
132    }
133}
134
135/// Convert a single `tokio_postgres::Row` into a `HashMap<String, serde_json::Value>`.
136///
137/// Tries each PostgreSQL type in priority order; falls back to `Null` for
138/// types that cannot be represented as JSON.
139fn row_to_map(row: &Row) -> std::collections::HashMap<String, serde_json::Value> {
140    let mut map = std::collections::HashMap::new();
141    for (idx, column) in row.columns().iter().enumerate() {
142        let column_name = column.name().to_string();
143        let value: serde_json::Value = if let Ok(v) = row.try_get::<_, i32>(idx) {
144            serde_json::json!(v)
145        } else if let Ok(v) = row.try_get::<_, i64>(idx) {
146            serde_json::json!(v)
147        } else if let Ok(v) = row.try_get::<_, f64>(idx) {
148            serde_json::json!(v)
149        } else if let Ok(v) = row.try_get::<_, String>(idx) {
150            serde_json::json!(v)
151        } else if let Ok(v) = row.try_get::<_, bool>(idx) {
152            serde_json::json!(v)
153        } else if let Ok(v) = row.try_get::<_, serde_json::Value>(idx) {
154            v
155        } else {
156            serde_json::Value::Null
157        };
158        map.insert(column_name, value);
159    }
160    map
161}
162
163// Reason: DatabaseAdapter is defined with #[async_trait]; all implementations must match
164// its transformed method signatures to satisfy the trait contract
165// async_trait: dyn-dispatch required; remove when RTN + Send is stable (RFC 3425)
166#[async_trait]
167impl DatabaseAdapter for PostgresAdapter {
168    async fn execute_with_projection(
169        &self,
170        view: &str,
171        projection: Option<&SqlProjectionHint>,
172        where_clause: Option<&WhereClause>,
173        limit: Option<u32>,
174        offset: Option<u32>,
175        order_by: Option<&[OrderByClause]>,
176    ) -> Result<Vec<JsonbValue>> {
177        self.execute_with_projection_impl(view, projection, where_clause, limit, offset, order_by)
178            .await
179    }
180
181    async fn execute_where_query(
182        &self,
183        view: &str,
184        where_clause: Option<&WhereClause>,
185        limit: Option<u32>,
186        offset: Option<u32>,
187        order_by: Option<&[OrderByClause]>,
188    ) -> Result<Vec<JsonbValue>> {
189        let (sql, typed_params) =
190            build_where_select_sql_ordered(view, where_clause, limit, offset, order_by)?;
191
192        let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = typed_params
193            .iter()
194            .map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
195            .collect();
196
197        self.execute_raw(&sql, &param_refs)
198            .await
199            .map_err(|e| enrich_undefined_column_error(e, view, where_clause))
200    }
201
202    async fn explain_where_query(
203        &self,
204        view: &str,
205        where_clause: Option<&WhereClause>,
206        limit: Option<u32>,
207        offset: Option<u32>,
208    ) -> Result<serde_json::Value> {
209        let (select_sql, typed_params) = build_where_select_sql(view, where_clause, limit, offset)?;
210        // Defense-in-depth: compiler-generated SQL should never contain a
211        // semicolon, but guard against it to prevent statement injection.
212        if select_sql.contains(';') {
213            return Err(FraiseQLError::Validation {
214                message: "EXPLAIN SQL must be a single statement".into(),
215                path:    None,
216            });
217        }
218        let explain_sql = format!("EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) {select_sql}");
219
220        let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = typed_params
221            .iter()
222            .map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
223            .collect();
224
225        let client = self.acquire_connection_with_retry().await?;
226        let rows = client.query(explain_sql.as_str(), &param_refs).await.map_err(|e| {
227            FraiseQLError::Database {
228                message:   format!("EXPLAIN ANALYZE failed: {e}"),
229                sql_state: e.code().map(|c| c.code().to_string()),
230            }
231        })?;
232
233        if let Some(row) = rows.first() {
234            let plan: serde_json::Value = row.try_get(0).map_err(|e| FraiseQLError::Database {
235                message:   format!("Failed to parse EXPLAIN output: {e}"),
236                sql_state: None,
237            })?;
238            Ok(plan)
239        } else {
240            Ok(serde_json::Value::Null)
241        }
242    }
243
244    fn database_type(&self) -> DatabaseType {
245        DatabaseType::PostgreSQL
246    }
247
248    async fn health_check(&self) -> Result<()> {
249        // Use retry logic for health check to avoid false negatives during pool exhaustion
250        let client = self.acquire_connection_with_retry().await?;
251
252        client.query("SELECT 1", &[]).await.map_err(|e| FraiseQLError::Database {
253            message:   format!("Health check failed: {e}"),
254            sql_state: e.code().map(|c| c.code().to_string()),
255        })?;
256
257        Ok(())
258    }
259
260    #[allow(clippy::cast_possible_truncation)] // Reason: value is bounded; truncation cannot occur in practice
261    fn pool_metrics(&self) -> PoolMetrics {
262        let status = self.pool.status();
263
264        PoolMetrics {
265            total_connections:  status.size as u32,
266            idle_connections:   status.available as u32,
267            active_connections: (status.size - status.available) as u32,
268            waiting_requests:   status.waiting as u32,
269        }
270    }
271
272    /// # Security
273    ///
274    /// `sql` **must** be compiler-generated. Never pass user-supplied strings
275    /// directly — doing so would open SQL-injection vulnerabilities.
276    async fn execute_raw_query(
277        &self,
278        sql: &str,
279    ) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
280        // Use retry logic for connection acquisition
281        let client = self.acquire_connection_with_retry().await?;
282
283        let rows: Vec<Row> = client.query(sql, &[]).await.map_err(|e| FraiseQLError::Database {
284            message:   format!("Query execution failed: {e}"),
285            sql_state: e.code().map(|c| c.code().to_string()),
286        })?;
287
288        // Convert each row to HashMap<String, Value>
289        let results: Vec<std::collections::HashMap<String, serde_json::Value>> =
290            rows.iter().map(row_to_map).collect();
291
292        Ok(results)
293    }
294
295    async fn execute_parameterized_aggregate(
296        &self,
297        sql: &str,
298        params: &[serde_json::Value],
299    ) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
300        // Convert serde_json::Value params to QueryParam so that strings are bound
301        // as TEXT (not JSONB), which is required for correct WHERE comparisons against
302        // data->>'field' expressions that return TEXT.
303        let typed: Vec<QueryParam> = params.iter().cloned().map(QueryParam::from).collect();
304        let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
305            typed.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
306
307        let client = self.acquire_connection_with_retry().await?;
308        let rows: Vec<Row> =
309            client.query(sql, &param_refs).await.map_err(|e| FraiseQLError::Database {
310                message:   format!("Parameterized aggregate query failed: {e}"),
311                sql_state: e.code().map(|c| c.code().to_string()),
312            })?;
313
314        let results: Vec<std::collections::HashMap<String, serde_json::Value>> =
315            rows.iter().map(row_to_map).collect();
316
317        Ok(results)
318    }
319
320    async fn execute_function_call(
321        &self,
322        function_name: &str,
323        args: &[serde_json::Value],
324    ) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
325        // Build: SELECT * FROM "fn_name"($1, $2, ...)
326        // Use the standard identifier quoting utility so that schema-qualified
327        // names like "benchmark.fn_update_user" are correctly split into
328        // "benchmark"."fn_update_user" instead of being wrapped as a single
329        // identifier.
330        let quoted_fn = quote_postgres_identifier(function_name);
331        let placeholders: Vec<String> = (1..=args.len()).map(|i| format!("${i}")).collect();
332        let sql = format!("SELECT * FROM {quoted_fn}({})", placeholders.join(", "));
333
334        let mut client = self.acquire_connection_with_retry().await?;
335
336        // Convert serde_json::Value arguments to FlexParam for binding.
337        //
338        // serde_json::Value only accepts JSON/JSONB types; Option<String> only accepts
339        // text-family types.  Neither works universally when the function signature
340        // contains a mix of JSONB, UUID, INT4, and TEXT parameters.  FlexParam accepts
341        // all PostgreSQL types and serialises each value in the correct binary wire
342        // format for the server-resolved parameter type.
343        let flex_args: Vec<FlexParam> = args
344            .iter()
345            .map(|v| match v {
346                serde_json::Value::Null => FlexParam::Null,
347                serde_json::Value::String(s) => FlexParam::Text(s.clone()),
348                _ => FlexParam::Text(v.to_string()),
349            })
350            .collect();
351        let params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = flex_args
352            .iter()
353            .map(|v| v as &(dyn tokio_postgres::types::ToSql + Sync))
354            .collect();
355
356        if self.mutation_timing_enabled {
357            // Wrap in a transaction so SET LOCAL scopes the variable to this call only.
358            // `set_config(name, value, is_local)` with is_local=true is equivalent to
359            // SET LOCAL and is parameterized to avoid SQL injection.
360            let txn =
361                client.build_transaction().start().await.map_err(|e| FraiseQLError::Database {
362                    message:   format!("Failed to start mutation timing transaction: {e}"),
363                    sql_state: e.code().map(|c| c.code().to_string()),
364                })?;
365
366            txn.execute(
367                "SELECT set_config($1, clock_timestamp()::text, true)",
368                &[&self.timing_variable_name],
369            )
370            .await
371            .map_err(|e| FraiseQLError::Database {
372                message:   format!("Failed to set mutation timing variable: {e}"),
373                sql_state: e.code().map(|c| c.code().to_string()),
374            })?;
375
376            let rows: Vec<Row> = txn.query(sql.as_str(), params.as_slice()).await.map_err(|e| {
377                let detail = e.as_db_error().map_or("", |d| d.message());
378                FraiseQLError::Database {
379                    message:   format!("Function call {function_name} failed: {e}: {detail}"),
380                    sql_state: e.code().map(|c| c.code().to_string()),
381                }
382            })?;
383
384            txn.commit().await.map_err(|e| FraiseQLError::Database {
385                message:   format!("Failed to commit mutation timing transaction: {e}"),
386                sql_state: e.code().map(|c| c.code().to_string()),
387            })?;
388
389            let results: Vec<std::collections::HashMap<String, serde_json::Value>> =
390                rows.iter().map(row_to_map).collect();
391
392            Ok(results)
393        } else {
394            let rows: Vec<Row> =
395                client.query(sql.as_str(), params.as_slice()).await.map_err(|e| {
396                    let detail = e.as_db_error().map_or("", |d| d.message());
397                    FraiseQLError::Database {
398                        message:   format!("Function call {function_name} failed: {e}: {detail}"),
399                        sql_state: e.code().map(|c| c.code().to_string()),
400                    }
401                })?;
402
403            let results: Vec<std::collections::HashMap<String, serde_json::Value>> =
404                rows.iter().map(row_to_map).collect();
405
406            Ok(results)
407        }
408    }
409
410    async fn set_session_variables(&self, variables: &[(&str, &str)]) -> Result<()> {
411        if variables.is_empty() {
412            return Ok(());
413        }
414        let client = self.acquire_connection_with_retry().await?;
415        for (name, value) in variables {
416            client
417                .execute("SELECT set_config($1, $2, true)", &[name, value])
418                .await
419                .map_err(|e| FraiseQLError::Database {
420                    message:   format!("set_config({name:?}) failed: {e}"),
421                    sql_state: e.code().map(|c| c.code().to_string()),
422                })?;
423        }
424        Ok(())
425    }
426
427    async fn explain_query(
428        &self,
429        sql: &str,
430        _params: &[serde_json::Value],
431    ) -> Result<serde_json::Value> {
432        // Defense-in-depth: reject multi-statement input even though this SQL is
433        // compiler-generated. A semicolon would allow a second statement to be
434        // appended to the EXPLAIN prefix.
435        if sql.contains(';') {
436            return Err(FraiseQLError::Validation {
437                message: "EXPLAIN SQL must be a single statement".into(),
438                path:    None,
439            });
440        }
441        let explain_sql = format!("EXPLAIN (ANALYZE false, FORMAT JSON) {sql}");
442        let client = self.acquire_connection_with_retry().await?;
443        let rows: Vec<Row> =
444            client
445                .query(explain_sql.as_str(), &[])
446                .await
447                .map_err(|e| FraiseQLError::Database {
448                    message:   format!("EXPLAIN failed: {e}"),
449                    sql_state: e.code().map(|c| c.code().to_string()),
450                })?;
451
452        if let Some(row) = rows.first() {
453            let plan: serde_json::Value = row.try_get(0).map_err(|e| FraiseQLError::Database {
454                message:   format!("Failed to parse EXPLAIN output: {e}"),
455                sql_state: None,
456            })?;
457            Ok(plan)
458        } else {
459            Ok(serde_json::Value::Null)
460        }
461    }
462}
463
464impl SupportsMutations for PostgresAdapter {}