Skip to main content

fraiseql_db/postgres/adapter/
database.rs

1//! `DatabaseAdapter` and `SupportsMutations` implementations for `PostgresAdapter`.
2
3use async_trait::async_trait;
4use fraiseql_error::{FraiseQLError, Result};
5use tokio_postgres::Row;
6
7use super::{PostgresAdapter, build_where_select_sql};
8use crate::{
9    identifier::quote_postgres_identifier,
10    traits::{DatabaseAdapter, SupportsMutations},
11    types::{
12        DatabaseType, JsonbValue, PoolMetrics, QueryParam,
13        sql_hints::{OrderByClause, SqlProjectionHint},
14    },
15    where_clause::WhereClause,
16};
17
18/// PostgreSQL SQLSTATE 42703: undefined column.
19const PG_UNDEFINED_COLUMN: &str = "42703";
20
21/// Enrich a `FraiseQLError::Database` error for PostgreSQL SQLSTATE 42703 (undefined column)
22/// when the WHERE clause contains `NativeField` conditions.
23///
24/// Native columns may be inferred automatically at compile time from `ID`/`UUID`-typed
25/// arguments.  If the column does not exist on the target table at runtime, the raw
26/// PostgreSQL error is replaced with a diagnostic message that names the native columns
27/// involved and explains how to fix the schema.
28fn enrich_undefined_column_error(
29    err: FraiseQLError,
30    view: &str,
31    where_clause: Option<&WhereClause>,
32) -> FraiseQLError {
33    let FraiseQLError::Database { ref sql_state, .. } = err else {
34        return err;
35    };
36    if sql_state.as_deref() != Some(PG_UNDEFINED_COLUMN) {
37        return err;
38    }
39    let native_cols: Vec<&str> = where_clause
40        .map(|wc| wc.native_column_names())
41        .unwrap_or_default();
42    if native_cols.is_empty() {
43        return err;
44    }
45    FraiseQLError::Database {
46        message: format!(
47            "Column(s) {:?} referenced as native column(s) on `{view}` do not exist. \
48             These columns were auto-inferred from ID/UUID-typed query arguments. \
49             Either add the column(s) to the table/view, or set \
50             `native_columns = {{}}` explicitly in your schema to disable inference.",
51            native_cols,
52        ),
53        sql_state: Some(PG_UNDEFINED_COLUMN.to_string()),
54    }
55}
56
57/// Convert a single `tokio_postgres::Row` into a `HashMap<String, serde_json::Value>`.
58///
59/// Tries each PostgreSQL type in priority order; falls back to `Null` for
60/// types that cannot be represented as JSON.
61fn row_to_map(row: &Row) -> std::collections::HashMap<String, serde_json::Value> {
62    let mut map = std::collections::HashMap::new();
63    for (idx, column) in row.columns().iter().enumerate() {
64        let column_name = column.name().to_string();
65        let value: serde_json::Value = if let Ok(v) = row.try_get::<_, i32>(idx) {
66            serde_json::json!(v)
67        } else if let Ok(v) = row.try_get::<_, i64>(idx) {
68            serde_json::json!(v)
69        } else if let Ok(v) = row.try_get::<_, f64>(idx) {
70            serde_json::json!(v)
71        } else if let Ok(v) = row.try_get::<_, String>(idx) {
72            serde_json::json!(v)
73        } else if let Ok(v) = row.try_get::<_, bool>(idx) {
74            serde_json::json!(v)
75        } else if let Ok(v) = row.try_get::<_, serde_json::Value>(idx) {
76            v
77        } else {
78            serde_json::Value::Null
79        };
80        map.insert(column_name, value);
81    }
82    map
83}
84
85// Reason: DatabaseAdapter is defined with #[async_trait]; all implementations must match
86// its transformed method signatures to satisfy the trait contract
87// async_trait: dyn-dispatch required; remove when RTN + Send is stable (RFC 3425)
88#[async_trait]
89impl DatabaseAdapter for PostgresAdapter {
90    async fn execute_with_projection(
91        &self,
92        view: &str,
93        projection: Option<&SqlProjectionHint>,
94        where_clause: Option<&WhereClause>,
95        limit: Option<u32>,
96        offset: Option<u32>,
97        _order_by: Option<&[OrderByClause]>,
98    ) -> Result<Vec<JsonbValue>> {
99        self.execute_with_projection(view, projection, where_clause, limit, offset)
100            .await
101    }
102
103    async fn execute_where_query(
104        &self,
105        view: &str,
106        where_clause: Option<&WhereClause>,
107        limit: Option<u32>,
108        offset: Option<u32>,
109        _order_by: Option<&[OrderByClause]>,
110    ) -> Result<Vec<JsonbValue>> {
111        let (sql, typed_params) = build_where_select_sql(view, where_clause, limit, offset)?;
112
113        let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = typed_params
114            .iter()
115            .map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
116            .collect();
117
118        self.execute_raw(&sql, &param_refs).await.map_err(|e| {
119            enrich_undefined_column_error(e, view, where_clause)
120        })
121    }
122
123    async fn explain_where_query(
124        &self,
125        view: &str,
126        where_clause: Option<&WhereClause>,
127        limit: Option<u32>,
128        offset: Option<u32>,
129    ) -> Result<serde_json::Value> {
130        let (select_sql, typed_params) = build_where_select_sql(view, where_clause, limit, offset)?;
131        // Defense-in-depth: compiler-generated SQL should never contain a
132        // semicolon, but guard against it to prevent statement injection.
133        if select_sql.contains(';') {
134            return Err(FraiseQLError::Validation {
135                message: "EXPLAIN SQL must be a single statement".into(),
136                path:    None,
137            });
138        }
139        let explain_sql = format!("EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) {select_sql}");
140
141        let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = typed_params
142            .iter()
143            .map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
144            .collect();
145
146        let client = self.acquire_connection_with_retry().await?;
147        let rows = client.query(explain_sql.as_str(), &param_refs).await.map_err(|e| {
148            FraiseQLError::Database {
149                message:   format!("EXPLAIN ANALYZE failed: {e}"),
150                sql_state: e.code().map(|c| c.code().to_string()),
151            }
152        })?;
153
154        if let Some(row) = rows.first() {
155            let plan: serde_json::Value = row.try_get(0).map_err(|e| FraiseQLError::Database {
156                message:   format!("Failed to parse EXPLAIN output: {e}"),
157                sql_state: None,
158            })?;
159            Ok(plan)
160        } else {
161            Ok(serde_json::Value::Null)
162        }
163    }
164
165    fn database_type(&self) -> DatabaseType {
166        DatabaseType::PostgreSQL
167    }
168
169    async fn health_check(&self) -> Result<()> {
170        // Use retry logic for health check to avoid false negatives during pool exhaustion
171        let client = self.acquire_connection_with_retry().await?;
172
173        client.query("SELECT 1", &[]).await.map_err(|e| FraiseQLError::Database {
174            message:   format!("Health check failed: {e}"),
175            sql_state: e.code().map(|c| c.code().to_string()),
176        })?;
177
178        Ok(())
179    }
180
181    #[allow(clippy::cast_possible_truncation)] // Reason: value is bounded; truncation cannot occur in practice
182    fn pool_metrics(&self) -> PoolMetrics {
183        let status = self.pool.status();
184
185        PoolMetrics {
186            total_connections:  status.size as u32,
187            idle_connections:   status.available as u32,
188            active_connections: (status.size - status.available) as u32,
189            waiting_requests:   status.waiting as u32,
190        }
191    }
192
193    /// # Security
194    ///
195    /// `sql` **must** be compiler-generated. Never pass user-supplied strings
196    /// directly — doing so would open SQL-injection vulnerabilities.
197    async fn execute_raw_query(
198        &self,
199        sql: &str,
200    ) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
201        // Use retry logic for connection acquisition
202        let client = self.acquire_connection_with_retry().await?;
203
204        let rows: Vec<Row> = client.query(sql, &[]).await.map_err(|e| FraiseQLError::Database {
205            message:   format!("Query execution failed: {e}"),
206            sql_state: e.code().map(|c| c.code().to_string()),
207        })?;
208
209        // Convert each row to HashMap<String, Value>
210        let results: Vec<std::collections::HashMap<String, serde_json::Value>> =
211            rows.iter().map(row_to_map).collect();
212
213        Ok(results)
214    }
215
216    async fn execute_parameterized_aggregate(
217        &self,
218        sql: &str,
219        params: &[serde_json::Value],
220    ) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
221        // Convert serde_json::Value params to QueryParam so that strings are bound
222        // as TEXT (not JSONB), which is required for correct WHERE comparisons against
223        // data->>'field' expressions that return TEXT.
224        let typed: Vec<QueryParam> = params.iter().cloned().map(QueryParam::from).collect();
225        let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
226            typed.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
227
228        let client = self.acquire_connection_with_retry().await?;
229        let rows: Vec<Row> =
230            client.query(sql, &param_refs).await.map_err(|e| FraiseQLError::Database {
231                message:   format!("Parameterized aggregate query failed: {e}"),
232                sql_state: e.code().map(|c| c.code().to_string()),
233            })?;
234
235        let results: Vec<std::collections::HashMap<String, serde_json::Value>> =
236            rows.iter().map(row_to_map).collect();
237
238        Ok(results)
239    }
240
241    async fn execute_function_call(
242        &self,
243        function_name: &str,
244        args: &[serde_json::Value],
245    ) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
246        // Build: SELECT * FROM "fn_name"($1, $2, ...)
247        // Use the standard identifier quoting utility so that schema-qualified
248        // names like "benchmark.fn_update_user" are correctly split into
249        // "benchmark"."fn_update_user" instead of being wrapped as a single
250        // identifier.
251        let quoted_fn = quote_postgres_identifier(function_name);
252        let placeholders: Vec<String> = (1..=args.len()).map(|i| format!("${i}")).collect();
253        let sql = format!("SELECT * FROM {quoted_fn}({})", placeholders.join(", "));
254
255        let mut client = self.acquire_connection_with_retry().await?;
256
257        // Bind each JSON argument as a text parameter (PostgreSQL can cast text→jsonb)
258        let params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
259            args.iter().map(|v| v as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
260
261        if self.mutation_timing_enabled {
262            // Wrap in a transaction so SET LOCAL scopes the variable to this call only.
263            // `set_config(name, value, is_local)` with is_local=true is equivalent to
264            // SET LOCAL and is parameterized to avoid SQL injection.
265            let txn =
266                client.build_transaction().start().await.map_err(|e| FraiseQLError::Database {
267                    message:   format!("Failed to start mutation timing transaction: {e}"),
268                    sql_state: e.code().map(|c| c.code().to_string()),
269                })?;
270
271            txn.execute(
272                "SELECT set_config($1, clock_timestamp()::text, true)",
273                &[&self.timing_variable_name],
274            )
275            .await
276            .map_err(|e| FraiseQLError::Database {
277                message:   format!("Failed to set mutation timing variable: {e}"),
278                sql_state: e.code().map(|c| c.code().to_string()),
279            })?;
280
281            let rows: Vec<Row> = txn.query(sql.as_str(), params.as_slice()).await.map_err(|e| {
282                FraiseQLError::Database {
283                    message:   format!("Function call {function_name} failed: {e}"),
284                    sql_state: e.code().map(|c| c.code().to_string()),
285                }
286            })?;
287
288            txn.commit().await.map_err(|e| FraiseQLError::Database {
289                message:   format!("Failed to commit mutation timing transaction: {e}"),
290                sql_state: e.code().map(|c| c.code().to_string()),
291            })?;
292
293            let results: Vec<std::collections::HashMap<String, serde_json::Value>> =
294                rows.iter().map(row_to_map).collect();
295
296            Ok(results)
297        } else {
298            let rows: Vec<Row> =
299                client.query(sql.as_str(), params.as_slice()).await.map_err(|e| {
300                    FraiseQLError::Database {
301                        message:   format!("Function call {function_name} failed: {e}"),
302                        sql_state: e.code().map(|c| c.code().to_string()),
303                    }
304                })?;
305
306            let results: Vec<std::collections::HashMap<String, serde_json::Value>> =
307                rows.iter().map(row_to_map).collect();
308
309            Ok(results)
310        }
311    }
312
313    async fn explain_query(
314        &self,
315        sql: &str,
316        _params: &[serde_json::Value],
317    ) -> Result<serde_json::Value> {
318        // Defense-in-depth: reject multi-statement input even though this SQL is
319        // compiler-generated. A semicolon would allow a second statement to be
320        // appended to the EXPLAIN prefix.
321        if sql.contains(';') {
322            return Err(FraiseQLError::Validation {
323                message: "EXPLAIN SQL must be a single statement".into(),
324                path:    None,
325            });
326        }
327        let explain_sql = format!("EXPLAIN (ANALYZE false, FORMAT JSON) {sql}");
328        let client = self.acquire_connection_with_retry().await?;
329        let rows: Vec<Row> =
330            client
331                .query(explain_sql.as_str(), &[])
332                .await
333                .map_err(|e| FraiseQLError::Database {
334                    message:   format!("EXPLAIN failed: {e}"),
335                    sql_state: e.code().map(|c| c.code().to_string()),
336                })?;
337
338        if let Some(row) = rows.first() {
339            let plan: serde_json::Value = row.try_get(0).map_err(|e| FraiseQLError::Database {
340                message:   format!("Failed to parse EXPLAIN output: {e}"),
341                sql_state: None,
342            })?;
343            Ok(plan)
344        } else {
345            Ok(serde_json::Value::Null)
346        }
347    }
348}
349
// `PostgresAdapter` implements `SupportsMutations` with an empty body, so it relies entirely
// on the trait's default items, if any. NOTE(review): this is presumably a capability marker
// enabling mutation execution (via `execute_function_call`); confirm against the trait
// definition in `crate::traits`.
impl SupportsMutations for PostgresAdapter {}