Skip to main content

fraiseql_db/postgres/adapter/
relay.rs

1//! `RelayDatabaseAdapter` implementation for `PostgresAdapter`.
2
3use fraiseql_error::{FraiseQLError, Result};
4
5use super::{PostgresAdapter, escape_jsonb_key};
6use crate::{
7    dialect::PostgresDialect,
8    identifier::quote_postgres_identifier,
9    postgres::where_generator::PostgresWhereGenerator,
10    traits::{CursorValue, RelayDatabaseAdapter, RelayPageResult},
11    types::{
12        QueryParam,
13        sql_hints::{OrderByClause, OrderDirection},
14    },
15    where_clause::WhereClause,
16};
17
18impl RelayDatabaseAdapter for PostgresAdapter {
19    /// Execute keyset (cursor-based) pagination against a JSONB view.
20    ///
21    /// # `totalCount` semantics
22    ///
23    /// When `include_total_count` is `true`, **two queries** are issued on the same
24    /// connection:
25    ///
26    /// 1. A count query — `SELECT COUNT(*) FROM {view} WHERE {user_filter}` — that reflects the
27    ///    **full connection** size, ignoring cursor position. This is required by the Relay Cursor
28    ///    Connections spec, which defines `totalCount` as the count of all objects in the
29    ///    connection, regardless of `after`/`before`.
30    ///
31    /// 2. A page query — the cursor-filtered, limited result set.
32    ///
33    /// The two-query approach fixes a previous bug where `COUNT(*) OVER()` ran
34    /// inside the cursor-filtered subquery, causing `totalCount` to shrink as the
35    /// cursor advanced.  It also handles the edge case where the current page is
36    /// empty but the total count is non-zero (e.g., cursor past the last row).
37    ///
38    /// When `include_total_count` is `false`, only the page query is issued.
39    ///
40    /// # Performance note
41    ///
42    /// The count query scans all rows matching the user filter without LIMIT. On
43    /// large unfiltered tables this may be slow. Mitigations:
44    /// - Only enable `totalCount` when the client explicitly requests it (enforced by the executor
45    ///   via `include_total_count`).
46    /// - Add a `statement_timeout` on the connection for relay queries on very large datasets.
47    /// - Maintain a denormalised count table or materialised view for hot paths.
48    async fn execute_relay_page(
49        &self,
50        view: &str,
51        cursor_column: &str,
52        after: Option<CursorValue>,
53        before: Option<CursorValue>,
54        limit: u32,
55        forward: bool,
56        where_clause: Option<&WhereClause>,
57        order_by: Option<&[OrderByClause]>,
58        include_total_count: bool,
59        session_vars: &[(&str, &str)],
60    ) -> Result<RelayPageResult> {
61        let quoted_view = quote_postgres_identifier(view);
62        let quoted_col = quote_postgres_identifier(cursor_column);
63
64        // ── Cursor condition (page query only, NOT the count query) ────────────
65        //
66        // Per the Relay spec, totalCount ignores cursor position. The cursor
67        // condition is therefore excluded from the count query.
68        //
69        // The cursor occupies at most one parameter slot ($1) at the front of the
70        // page query's parameter list.
71        //
72        // UUID cursors use `$1::uuid` cast; BIGINT cursors use plain `$1`.
73        let cursor_param: Option<QueryParam>;
74        let cursor_where_part: Option<String>;
75        let active_cursor = if forward { after } else { before };
76        match active_cursor {
77            None => {
78                cursor_param = None;
79                cursor_where_part = None;
80            },
81            Some(CursorValue::Int64(pk)) => {
82                let op = if forward { ">" } else { "<" };
83                cursor_param = Some(QueryParam::BigInt(pk));
84                cursor_where_part = Some(format!("{quoted_col} {op} $1"));
85            },
86            Some(CursorValue::Uuid(uuid)) => {
87                let op = if forward { ">" } else { "<" };
88                cursor_param = Some(QueryParam::Text(uuid));
89                cursor_where_part = Some(format!("{quoted_col} {op} $1::uuid"));
90            },
91        }
92        let cursor_param_count: usize = usize::from(cursor_param.is_some());
93
94        // ── User WHERE clause ──────────────────────────────────────────────────
95        //
96        // Used in BOTH the count query (offset 0) and the page query (offset by
97        // cursor_param_count so parameter indices don't collide).
98        let mut user_where_json_params: Vec<serde_json::Value> = Vec::new();
99        let page_user_where_sql: Option<String> = if let Some(clause) = where_clause {
100            let generator = PostgresWhereGenerator::new(PostgresDialect);
101            let (sql, params) = generator.generate_with_param_offset(clause, cursor_param_count)?;
102            user_where_json_params = params;
103            Some(sql)
104        } else {
105            None
106        };
107        let user_param_count = user_where_json_params.len();
108
109        // ── ORDER BY clause ────────────────────────────────────────────────────
110        //
111        // Custom sort columns first, then cursor column as tiebreaker for stable
112        // keyset pagination.
113        let order_sql = if let Some(clauses) = order_by {
114            let mut parts: Vec<String> = clauses
115                .iter()
116                .map(|c| {
117                    let dir = match c.direction {
118                        OrderDirection::Asc => "ASC",
119                        OrderDirection::Desc => "DESC",
120                    };
121                    // escape_jsonb_key is defense-in-depth: field names are already
122                    // validated as GraphQL identifiers (which cannot contain `'`).
123                    format!("data->>'{field}' {dir}", field = escape_jsonb_key(&c.field))
124                })
125                .collect();
126            let primary_dir = if forward { "ASC" } else { "DESC" };
127            parts.push(format!("{quoted_col} {primary_dir}"));
128            format!(" ORDER BY {}", parts.join(", "))
129        } else {
130            let dir = if forward { "ASC" } else { "DESC" };
131            format!(" ORDER BY {quoted_col} {dir}")
132        };
133
134        // ── Page WHERE SQL ─────────────────────────────────────────────────────
135        //
136        // Combines cursor condition AND user filter with offset parameter indices.
137        let cursor_part = cursor_where_part.as_deref().unwrap_or("");
138        let user_part =
139            page_user_where_sql.as_deref().map(|s| format!("({s})")).unwrap_or_default();
140        let page_where_sql = if cursor_part.is_empty() && user_part.is_empty() {
141            String::new()
142        } else if cursor_part.is_empty() {
143            format!(" WHERE {user_part}")
144        } else if user_part.is_empty() {
145            format!(" WHERE {cursor_part}")
146        } else {
147            format!(" WHERE {cursor_part} AND {user_part}")
148        };
149
150        // ── LIMIT parameter index ──────────────────────────────────────────────
151        let limit_idx = cursor_param_count + user_param_count + 1;
152
153        // ── Page SQL ───────────────────────────────────────────────────────────
154        //
155        // Backward pagination wraps the inner query in a subquery to re-sort
156        // the descending page back to ascending order.
157        let page_sql = if forward {
158            format!("SELECT data FROM {quoted_view}{page_where_sql}{order_sql} LIMIT ${limit_idx}")
159        } else {
160            let inner = format!(
161                "SELECT data, {quoted_col} AS _relay_cursor \
162                 FROM {quoted_view}{page_where_sql}{order_sql} LIMIT ${limit_idx}"
163            );
164            format!("SELECT data FROM ({inner}) _relay_page ORDER BY _relay_cursor ASC")
165        };
166
167        // ── Page params: [cursor?, user_where_params..., limit] ────────────────
168        let mut page_typed_params: Vec<QueryParam> = Vec::new();
169        if let Some(cp) = cursor_param {
170            page_typed_params.push(cp);
171        }
172        for v in &user_where_json_params {
173            page_typed_params.push(QueryParam::from(v.clone()));
174        }
175        page_typed_params.push(QueryParam::BigInt(i64::from(limit)));
176
177        let mut client = self.acquire_connection_with_retry().await?;
178
179        let (rows, total_count) = if !session_vars.is_empty() {
180            let txn = client.build_transaction().start().await.map_err(|e| FraiseQLError::Database {
181                message: format!("Failed to start transaction: {e}"),
182                sql_state: e.code().map(|c| c.code().to_string()),
183            })?;
184
185            // Set all session variables
186            for (name, value) in session_vars {
187                txn.execute("SELECT set_config($1, $2, true)", &[name, value]).await.map_err(|e| FraiseQLError::Database {
188                    message: format!("Failed to set session variable {name}: {e}"),
189                    sql_state: e.code().map(|c| c.code().to_string()),
190                })?;
191            }
192
193            // ── Execute page query ─────────────────────────────────────────────────
194            let page_param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = page_typed_params
195                .iter()
196                .map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
197                .collect();
198
199            let page_rows = txn.query(&page_sql, &page_param_refs).await.map_err(|e| FraiseQLError::Database {
200                message: format!("Relay page query failed: {e}"),
201                sql_state: e.code().map(|c| c.code().to_string()),
202            })?;
203
204            let rows: Vec<crate::types::JsonbValue> = page_rows
205                .iter()
206                .map(|row| {
207                    let data: serde_json::Value = row.get("data");
208                    crate::types::JsonbValue::new(data)
209                })
210                .collect();
211
212            // ── Count query (Relay spec: totalCount ignores cursor position) ────────
213            //
214            // The WHERE clause is regenerated with offset 0 (no cursor parameter prefix)
215            // because this is a standalone query. Using the same connection avoids an
216            // extra pool acquisition.
217            let total_count = if include_total_count {
218                let (count_sql, count_typed_params) = if let Some(clause) = where_clause {
219                    let generator = PostgresWhereGenerator::new(PostgresDialect);
220                    let (where_sql, params) = generator.generate_with_param_offset(clause, 0)?;
221                    let sql = format!("SELECT COUNT(*) FROM {quoted_view} WHERE ({where_sql})");
222                    let typed: Vec<QueryParam> = params.into_iter().map(QueryParam::from).collect();
223                    (sql, typed)
224                } else {
225                    (format!("SELECT COUNT(*) FROM {quoted_view}"), Vec::<QueryParam>::new())
226                };
227
228                let count_param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
229                    count_typed_params
230                        .iter()
231                        .map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
232                        .collect();
233
234                let count_row = txn.query_one(&count_sql, &count_param_refs).await.map_err(|e| FraiseQLError::Database {
235                    message: format!("Relay count query failed: {e}"),
236                    sql_state: e.code().map(|c| c.code().to_string()),
237                })?;
238
239                let total: i64 = count_row.get(0);
240                // cast_unsigned() is the clippy-recommended alternative to `as u64` for i64;
241                // it has the same bit-pattern semantics but makes the sign-loss intent explicit.
242                // Row counts from COUNT(*) are always non-negative so sign loss is impossible.
243                Some(total.cast_unsigned())
244            } else {
245                None
246            };
247
248            txn.commit().await.map_err(|e| FraiseQLError::Database {
249                message: format!("Failed to commit relay transaction: {e}"),
250                sql_state: e.code().map(|c| c.code().to_string()),
251            })?;
252
253            (rows, total_count)
254        } else {
255            // ── Execute page query ─────────────────────────────────────────────────
256            let page_param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = page_typed_params
257                .iter()
258                .map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
259                .collect();
260
261            let page_rows = client.query(&page_sql, &page_param_refs).await.map_err(|e| FraiseQLError::Database {
262                message: format!("Relay page query failed: {e}"),
263                sql_state: e.code().map(|c| c.code().to_string()),
264            })?;
265
266            let rows: Vec<crate::types::JsonbValue> = page_rows
267                .iter()
268                .map(|row| {
269                    let data: serde_json::Value = row.get("data");
270                    crate::types::JsonbValue::new(data)
271                })
272                .collect();
273
274            // ── Count query (Relay spec: totalCount ignores cursor position) ────────
275            //
276            // The WHERE clause is regenerated with offset 0 (no cursor parameter prefix)
277            // because this is a standalone query. Using the same connection avoids an
278            // extra pool acquisition.
279            let total_count = if include_total_count {
280                let (count_sql, count_typed_params) = if let Some(clause) = where_clause {
281                    let generator = PostgresWhereGenerator::new(PostgresDialect);
282                    let (where_sql, params) = generator.generate_with_param_offset(clause, 0)?;
283                    let sql = format!("SELECT COUNT(*) FROM {quoted_view} WHERE ({where_sql})");
284                    let typed: Vec<QueryParam> = params.into_iter().map(QueryParam::from).collect();
285                    (sql, typed)
286                } else {
287                    (format!("SELECT COUNT(*) FROM {quoted_view}"), Vec::<QueryParam>::new())
288                };
289
290                let count_param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
291                    count_typed_params
292                        .iter()
293                        .map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
294                        .collect();
295
296                let count_row = client.query_one(&count_sql, &count_param_refs).await.map_err(|e| FraiseQLError::Database {
297                    message: format!("Relay count query failed: {e}"),
298                    sql_state: e.code().map(|c| c.code().to_string()),
299                })?;
300
301                let total: i64 = count_row.get(0);
302                // cast_unsigned() is the clippy-recommended alternative to `as u64` for i64;
303                // it has the same bit-pattern semantics but makes the sign-loss intent explicit.
304                // Row counts from COUNT(*) are always non-negative so sign loss is impossible.
305                Some(total.cast_unsigned())
306            } else {
307                None
308            };
309
310            (rows, total_count)
311        };
312
313        Ok(RelayPageResult { rows, total_count })
314    }
315}