fraiseql_db/postgres/adapter/
relay.rs1use fraiseql_error::{FraiseQLError, Result};
4
5use super::{PostgresAdapter, escape_jsonb_key};
6use crate::{
7 dialect::PostgresDialect,
8 identifier::quote_postgres_identifier,
9 postgres::where_generator::PostgresWhereGenerator,
10 traits::{CursorValue, RelayDatabaseAdapter, RelayPageResult},
11 types::{
12 QueryParam,
13 sql_hints::{OrderByClause, OrderDirection},
14 },
15 where_clause::WhereClause,
16};
17
18impl RelayDatabaseAdapter for PostgresAdapter {
19 async fn execute_relay_page(
49 &self,
50 view: &str,
51 cursor_column: &str,
52 after: Option<CursorValue>,
53 before: Option<CursorValue>,
54 limit: u32,
55 forward: bool,
56 where_clause: Option<&WhereClause>,
57 order_by: Option<&[OrderByClause]>,
58 include_total_count: bool,
59 session_vars: &[(&str, &str)],
60 ) -> Result<RelayPageResult> {
61 let quoted_view = quote_postgres_identifier(view);
62 let quoted_col = quote_postgres_identifier(cursor_column);
63
64 let cursor_param: Option<QueryParam>;
74 let cursor_where_part: Option<String>;
75 let active_cursor = if forward { after } else { before };
76 match active_cursor {
77 None => {
78 cursor_param = None;
79 cursor_where_part = None;
80 },
81 Some(CursorValue::Int64(pk)) => {
82 let op = if forward { ">" } else { "<" };
83 cursor_param = Some(QueryParam::BigInt(pk));
84 cursor_where_part = Some(format!("{quoted_col} {op} $1"));
85 },
86 Some(CursorValue::Uuid(uuid)) => {
87 let op = if forward { ">" } else { "<" };
88 cursor_param = Some(QueryParam::Text(uuid));
89 cursor_where_part = Some(format!("{quoted_col} {op} $1::uuid"));
90 },
91 }
92 let cursor_param_count: usize = usize::from(cursor_param.is_some());
93
94 let mut user_where_json_params: Vec<serde_json::Value> = Vec::new();
99 let page_user_where_sql: Option<String> = if let Some(clause) = where_clause {
100 let generator = PostgresWhereGenerator::new(PostgresDialect);
101 let (sql, params) = generator.generate_with_param_offset(clause, cursor_param_count)?;
102 user_where_json_params = params;
103 Some(sql)
104 } else {
105 None
106 };
107 let user_param_count = user_where_json_params.len();
108
109 let order_sql = if let Some(clauses) = order_by {
114 let mut parts: Vec<String> = clauses
115 .iter()
116 .map(|c| {
117 let dir = match c.direction {
118 OrderDirection::Asc => "ASC",
119 OrderDirection::Desc => "DESC",
120 };
121 format!("data->>'{field}' {dir}", field = escape_jsonb_key(&c.field))
124 })
125 .collect();
126 let primary_dir = if forward { "ASC" } else { "DESC" };
127 parts.push(format!("{quoted_col} {primary_dir}"));
128 format!(" ORDER BY {}", parts.join(", "))
129 } else {
130 let dir = if forward { "ASC" } else { "DESC" };
131 format!(" ORDER BY {quoted_col} {dir}")
132 };
133
134 let cursor_part = cursor_where_part.as_deref().unwrap_or("");
138 let user_part =
139 page_user_where_sql.as_deref().map(|s| format!("({s})")).unwrap_or_default();
140 let page_where_sql = if cursor_part.is_empty() && user_part.is_empty() {
141 String::new()
142 } else if cursor_part.is_empty() {
143 format!(" WHERE {user_part}")
144 } else if user_part.is_empty() {
145 format!(" WHERE {cursor_part}")
146 } else {
147 format!(" WHERE {cursor_part} AND {user_part}")
148 };
149
150 let limit_idx = cursor_param_count + user_param_count + 1;
152
153 let page_sql = if forward {
158 format!("SELECT data FROM {quoted_view}{page_where_sql}{order_sql} LIMIT ${limit_idx}")
159 } else {
160 let inner = format!(
161 "SELECT data, {quoted_col} AS _relay_cursor \
162 FROM {quoted_view}{page_where_sql}{order_sql} LIMIT ${limit_idx}"
163 );
164 format!("SELECT data FROM ({inner}) _relay_page ORDER BY _relay_cursor ASC")
165 };
166
167 let mut page_typed_params: Vec<QueryParam> = Vec::new();
169 if let Some(cp) = cursor_param {
170 page_typed_params.push(cp);
171 }
172 for v in &user_where_json_params {
173 page_typed_params.push(QueryParam::from(v.clone()));
174 }
175 page_typed_params.push(QueryParam::BigInt(i64::from(limit)));
176
177 let mut client = self.acquire_connection_with_retry().await?;
178
179 let (rows, total_count) = if !session_vars.is_empty() {
180 let txn = client.build_transaction().start().await.map_err(|e| FraiseQLError::Database {
181 message: format!("Failed to start transaction: {e}"),
182 sql_state: e.code().map(|c| c.code().to_string()),
183 })?;
184
185 for (name, value) in session_vars {
187 txn.execute("SELECT set_config($1, $2, true)", &[name, value]).await.map_err(|e| FraiseQLError::Database {
188 message: format!("Failed to set session variable {name}: {e}"),
189 sql_state: e.code().map(|c| c.code().to_string()),
190 })?;
191 }
192
193 let page_param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = page_typed_params
195 .iter()
196 .map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
197 .collect();
198
199 let page_rows = txn.query(&page_sql, &page_param_refs).await.map_err(|e| FraiseQLError::Database {
200 message: format!("Relay page query failed: {e}"),
201 sql_state: e.code().map(|c| c.code().to_string()),
202 })?;
203
204 let rows: Vec<crate::types::JsonbValue> = page_rows
205 .iter()
206 .map(|row| {
207 let data: serde_json::Value = row.get("data");
208 crate::types::JsonbValue::new(data)
209 })
210 .collect();
211
212 let total_count = if include_total_count {
218 let (count_sql, count_typed_params) = if let Some(clause) = where_clause {
219 let generator = PostgresWhereGenerator::new(PostgresDialect);
220 let (where_sql, params) = generator.generate_with_param_offset(clause, 0)?;
221 let sql = format!("SELECT COUNT(*) FROM {quoted_view} WHERE ({where_sql})");
222 let typed: Vec<QueryParam> = params.into_iter().map(QueryParam::from).collect();
223 (sql, typed)
224 } else {
225 (format!("SELECT COUNT(*) FROM {quoted_view}"), Vec::<QueryParam>::new())
226 };
227
228 let count_param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
229 count_typed_params
230 .iter()
231 .map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
232 .collect();
233
234 let count_row = txn.query_one(&count_sql, &count_param_refs).await.map_err(|e| FraiseQLError::Database {
235 message: format!("Relay count query failed: {e}"),
236 sql_state: e.code().map(|c| c.code().to_string()),
237 })?;
238
239 let total: i64 = count_row.get(0);
240 Some(total.cast_unsigned())
244 } else {
245 None
246 };
247
248 txn.commit().await.map_err(|e| FraiseQLError::Database {
249 message: format!("Failed to commit relay transaction: {e}"),
250 sql_state: e.code().map(|c| c.code().to_string()),
251 })?;
252
253 (rows, total_count)
254 } else {
255 let page_param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = page_typed_params
257 .iter()
258 .map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
259 .collect();
260
261 let page_rows = client.query(&page_sql, &page_param_refs).await.map_err(|e| FraiseQLError::Database {
262 message: format!("Relay page query failed: {e}"),
263 sql_state: e.code().map(|c| c.code().to_string()),
264 })?;
265
266 let rows: Vec<crate::types::JsonbValue> = page_rows
267 .iter()
268 .map(|row| {
269 let data: serde_json::Value = row.get("data");
270 crate::types::JsonbValue::new(data)
271 })
272 .collect();
273
274 let total_count = if include_total_count {
280 let (count_sql, count_typed_params) = if let Some(clause) = where_clause {
281 let generator = PostgresWhereGenerator::new(PostgresDialect);
282 let (where_sql, params) = generator.generate_with_param_offset(clause, 0)?;
283 let sql = format!("SELECT COUNT(*) FROM {quoted_view} WHERE ({where_sql})");
284 let typed: Vec<QueryParam> = params.into_iter().map(QueryParam::from).collect();
285 (sql, typed)
286 } else {
287 (format!("SELECT COUNT(*) FROM {quoted_view}"), Vec::<QueryParam>::new())
288 };
289
290 let count_param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
291 count_typed_params
292 .iter()
293 .map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
294 .collect();
295
296 let count_row = client.query_one(&count_sql, &count_param_refs).await.map_err(|e| FraiseQLError::Database {
297 message: format!("Relay count query failed: {e}"),
298 sql_state: e.code().map(|c| c.code().to_string()),
299 })?;
300
301 let total: i64 = count_row.get(0);
302 Some(total.cast_unsigned())
306 } else {
307 None
308 };
309
310 (rows, total_count)
311 };
312
313 Ok(RelayPageResult { rows, total_count })
314 }
315}