Skip to main content

fraiseql_db/postgres/adapter/
database.rs

1//! `DatabaseAdapter` and `SupportsMutations` implementations for `PostgresAdapter`.
2
3use async_trait::async_trait;
4use fraiseql_error::{FraiseQLError, Result};
5use tokio_postgres::Row;
6
7use super::{PostgresAdapter, build_where_select_sql};
8use crate::{
9    traits::{DatabaseAdapter, SupportsMutations},
10    types::{
11        DatabaseType, JsonbValue, PoolMetrics, QueryParam,
12        sql_hints::{OrderByClause, SqlProjectionHint},
13    },
14    where_clause::WhereClause,
15};
16
17/// Convert a single `tokio_postgres::Row` into a `HashMap<String, serde_json::Value>`.
18///
19/// Tries each PostgreSQL type in priority order; falls back to `Null` for
20/// types that cannot be represented as JSON.
21fn row_to_map(row: &Row) -> std::collections::HashMap<String, serde_json::Value> {
22    let mut map = std::collections::HashMap::new();
23    for (idx, column) in row.columns().iter().enumerate() {
24        let column_name = column.name().to_string();
25        let value: serde_json::Value = if let Ok(v) = row.try_get::<_, i32>(idx) {
26            serde_json::json!(v)
27        } else if let Ok(v) = row.try_get::<_, i64>(idx) {
28            serde_json::json!(v)
29        } else if let Ok(v) = row.try_get::<_, f64>(idx) {
30            serde_json::json!(v)
31        } else if let Ok(v) = row.try_get::<_, String>(idx) {
32            serde_json::json!(v)
33        } else if let Ok(v) = row.try_get::<_, bool>(idx) {
34            serde_json::json!(v)
35        } else if let Ok(v) = row.try_get::<_, serde_json::Value>(idx) {
36            v
37        } else {
38            serde_json::Value::Null
39        };
40        map.insert(column_name, value);
41    }
42    map
43}
44
45// Reason: DatabaseAdapter is defined with #[async_trait]; all implementations must match
46// its transformed method signatures to satisfy the trait contract
47// async_trait: dyn-dispatch required; remove when RTN + Send is stable (RFC 3425)
48#[async_trait]
49impl DatabaseAdapter for PostgresAdapter {
50    async fn execute_with_projection(
51        &self,
52        view: &str,
53        projection: Option<&SqlProjectionHint>,
54        where_clause: Option<&WhereClause>,
55        limit: Option<u32>,
56        offset: Option<u32>,
57        _order_by: Option<&[OrderByClause]>,
58    ) -> Result<Vec<JsonbValue>> {
59        self.execute_with_projection(view, projection, where_clause, limit, offset)
60            .await
61    }
62
63    async fn execute_where_query(
64        &self,
65        view: &str,
66        where_clause: Option<&WhereClause>,
67        limit: Option<u32>,
68        offset: Option<u32>,
69        _order_by: Option<&[OrderByClause]>,
70    ) -> Result<Vec<JsonbValue>> {
71        let (sql, typed_params) = build_where_select_sql(view, where_clause, limit, offset)?;
72
73        let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = typed_params
74            .iter()
75            .map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
76            .collect();
77
78        self.execute_raw(&sql, &param_refs).await
79    }
80
81    async fn explain_where_query(
82        &self,
83        view: &str,
84        where_clause: Option<&WhereClause>,
85        limit: Option<u32>,
86        offset: Option<u32>,
87    ) -> Result<serde_json::Value> {
88        let (select_sql, typed_params) = build_where_select_sql(view, where_clause, limit, offset)?;
89        // Defense-in-depth: compiler-generated SQL should never contain a
90        // semicolon, but guard against it to prevent statement injection.
91        if select_sql.contains(';') {
92            return Err(FraiseQLError::Validation {
93                message: "EXPLAIN SQL must be a single statement".into(),
94                path:    None,
95            });
96        }
97        let explain_sql = format!("EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) {select_sql}");
98
99        let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = typed_params
100            .iter()
101            .map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
102            .collect();
103
104        let client = self.acquire_connection_with_retry().await?;
105        let rows = client.query(explain_sql.as_str(), &param_refs).await.map_err(|e| {
106            FraiseQLError::Database {
107                message:   format!("EXPLAIN ANALYZE failed: {e}"),
108                sql_state: e.code().map(|c| c.code().to_string()),
109            }
110        })?;
111
112        if let Some(row) = rows.first() {
113            let plan: serde_json::Value = row.try_get(0).map_err(|e| FraiseQLError::Database {
114                message:   format!("Failed to parse EXPLAIN output: {e}"),
115                sql_state: None,
116            })?;
117            Ok(plan)
118        } else {
119            Ok(serde_json::Value::Null)
120        }
121    }
122
123    fn database_type(&self) -> DatabaseType {
124        DatabaseType::PostgreSQL
125    }
126
127    async fn health_check(&self) -> Result<()> {
128        // Use retry logic for health check to avoid false negatives during pool exhaustion
129        let client = self.acquire_connection_with_retry().await?;
130
131        client.query("SELECT 1", &[]).await.map_err(|e| FraiseQLError::Database {
132            message:   format!("Health check failed: {e}"),
133            sql_state: e.code().map(|c| c.code().to_string()),
134        })?;
135
136        Ok(())
137    }
138
139    #[allow(clippy::cast_possible_truncation)] // Reason: value is bounded; truncation cannot occur in practice
140    fn pool_metrics(&self) -> PoolMetrics {
141        let status = self.pool.status();
142
143        PoolMetrics {
144            total_connections:  status.size as u32,
145            idle_connections:   status.available as u32,
146            active_connections: (status.size - status.available) as u32,
147            waiting_requests:   status.waiting as u32,
148        }
149    }
150
151    /// # Security
152    ///
153    /// `sql` **must** be compiler-generated. Never pass user-supplied strings
154    /// directly — doing so would open SQL-injection vulnerabilities.
155    async fn execute_raw_query(
156        &self,
157        sql: &str,
158    ) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
159        // Use retry logic for connection acquisition
160        let client = self.acquire_connection_with_retry().await?;
161
162        let rows: Vec<Row> = client.query(sql, &[]).await.map_err(|e| FraiseQLError::Database {
163            message:   format!("Query execution failed: {e}"),
164            sql_state: e.code().map(|c| c.code().to_string()),
165        })?;
166
167        // Convert each row to HashMap<String, Value>
168        let results: Vec<std::collections::HashMap<String, serde_json::Value>> =
169            rows.iter().map(row_to_map).collect();
170
171        Ok(results)
172    }
173
174    async fn execute_parameterized_aggregate(
175        &self,
176        sql: &str,
177        params: &[serde_json::Value],
178    ) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
179        // Convert serde_json::Value params to QueryParam so that strings are bound
180        // as TEXT (not JSONB), which is required for correct WHERE comparisons against
181        // data->>'field' expressions that return TEXT.
182        let typed: Vec<QueryParam> = params.iter().cloned().map(QueryParam::from).collect();
183        let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
184            typed.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
185
186        let client = self.acquire_connection_with_retry().await?;
187        let rows: Vec<Row> =
188            client.query(sql, &param_refs).await.map_err(|e| FraiseQLError::Database {
189                message:   format!("Parameterized aggregate query failed: {e}"),
190                sql_state: e.code().map(|c| c.code().to_string()),
191            })?;
192
193        let results: Vec<std::collections::HashMap<String, serde_json::Value>> =
194            rows.iter().map(row_to_map).collect();
195
196        Ok(results)
197    }
198
199    async fn execute_function_call(
200        &self,
201        function_name: &str,
202        args: &[serde_json::Value],
203    ) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
204        // Build: SELECT * FROM "fn_name"($1, $2, ...)
205        // The function name is double-quoted so that reserved words, mixed-case
206        // names, and names with special characters are handled correctly.
207        // Any embedded double quotes are escaped by doubling them ("").
208        let quoted_fn = format!("\"{}\"", function_name.replace('"', "\"\""));
209        let placeholders: Vec<String> = (1..=args.len()).map(|i| format!("${i}")).collect();
210        let sql = format!("SELECT * FROM {quoted_fn}({})", placeholders.join(", "));
211
212        let mut client = self.acquire_connection_with_retry().await?;
213
214        // Bind each JSON argument as a text parameter (PostgreSQL can cast text→jsonb)
215        let params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
216            args.iter().map(|v| v as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
217
218        if self.mutation_timing_enabled {
219            // Wrap in a transaction so SET LOCAL scopes the variable to this call only.
220            // `set_config(name, value, is_local)` with is_local=true is equivalent to
221            // SET LOCAL and is parameterized to avoid SQL injection.
222            let txn =
223                client.build_transaction().start().await.map_err(|e| FraiseQLError::Database {
224                    message:   format!("Failed to start mutation timing transaction: {e}"),
225                    sql_state: e.code().map(|c| c.code().to_string()),
226                })?;
227
228            txn.execute(
229                "SELECT set_config($1, clock_timestamp()::text, true)",
230                &[&self.timing_variable_name],
231            )
232            .await
233            .map_err(|e| FraiseQLError::Database {
234                message:   format!("Failed to set mutation timing variable: {e}"),
235                sql_state: e.code().map(|c| c.code().to_string()),
236            })?;
237
238            let rows: Vec<Row> = txn.query(sql.as_str(), params.as_slice()).await.map_err(|e| {
239                FraiseQLError::Database {
240                    message:   format!("Function call {function_name} failed: {e}"),
241                    sql_state: e.code().map(|c| c.code().to_string()),
242                }
243            })?;
244
245            txn.commit().await.map_err(|e| FraiseQLError::Database {
246                message:   format!("Failed to commit mutation timing transaction: {e}"),
247                sql_state: e.code().map(|c| c.code().to_string()),
248            })?;
249
250            let results: Vec<std::collections::HashMap<String, serde_json::Value>> =
251                rows.iter().map(row_to_map).collect();
252
253            Ok(results)
254        } else {
255            let rows: Vec<Row> =
256                client.query(sql.as_str(), params.as_slice()).await.map_err(|e| {
257                    FraiseQLError::Database {
258                        message:   format!("Function call {function_name} failed: {e}"),
259                        sql_state: e.code().map(|c| c.code().to_string()),
260                    }
261                })?;
262
263            let results: Vec<std::collections::HashMap<String, serde_json::Value>> =
264                rows.iter().map(row_to_map).collect();
265
266            Ok(results)
267        }
268    }
269
270    async fn explain_query(
271        &self,
272        sql: &str,
273        _params: &[serde_json::Value],
274    ) -> Result<serde_json::Value> {
275        // Defense-in-depth: reject multi-statement input even though this SQL is
276        // compiler-generated. A semicolon would allow a second statement to be
277        // appended to the EXPLAIN prefix.
278        if sql.contains(';') {
279            return Err(FraiseQLError::Validation {
280                message: "EXPLAIN SQL must be a single statement".into(),
281                path:    None,
282            });
283        }
284        let explain_sql = format!("EXPLAIN (ANALYZE false, FORMAT JSON) {sql}");
285        let client = self.acquire_connection_with_retry().await?;
286        let rows: Vec<Row> =
287            client
288                .query(explain_sql.as_str(), &[])
289                .await
290                .map_err(|e| FraiseQLError::Database {
291                    message:   format!("EXPLAIN failed: {e}"),
292                    sql_state: e.code().map(|c| c.code().to_string()),
293                })?;
294
295        if let Some(row) = rows.first() {
296            let plan: serde_json::Value = row.try_get(0).map_err(|e| FraiseQLError::Database {
297                message:   format!("Failed to parse EXPLAIN output: {e}"),
298                sql_state: None,
299            })?;
300            Ok(plan)
301        } else {
302            Ok(serde_json::Value::Null)
303        }
304    }
305}
306
307impl SupportsMutations for PostgresAdapter {}