1use async_trait::async_trait;
4use bytes::BufMut as _;
5use fraiseql_error::{FraiseQLError, Result};
6use tokio_postgres::Row;
7
8use super::{PostgresAdapter, build_where_select_sql, build_where_select_sql_ordered};
9use crate::{
10 identifier::quote_postgres_identifier,
11 traits::{DatabaseAdapter, SupportsMutations},
12 types::{
13 DatabaseType, JsonbValue, PoolMetrics, QueryParam,
14 sql_hints::{OrderByClause, SqlProjectionHint},
15 },
16 where_clause::WhereClause,
17};
18
/// SQLSTATE code PostgreSQL reports when a statement references a column that does not
/// exist (`undefined_column`).
// Only referenced by `enrich_undefined_column_error`, which is itself marked
// `#[allow(dead_code)]`, so this constant needs the same allowance.
#[allow(dead_code)]
const PG_UNDEFINED_COLUMN: &str = "42703";
22
/// Loosely-typed bind parameter for SQL function calls.
///
/// Non-null arguments are carried as text. The concrete wire encoding is chosen when the
/// value is bound, based on the parameter type the server reports for its placeholder.
#[derive(Debug)]
enum FlexParam {
    /// Textual form of the argument, re-encoded for the target parameter type.
    Text(String),
    /// Binds SQL `NULL` regardless of the parameter type.
    Null,
}
44
45impl tokio_postgres::types::ToSql for FlexParam {
46 fn to_sql(
47 &self,
48 ty: &tokio_postgres::types::Type,
49 out: &mut bytes::BytesMut,
50 ) -> std::result::Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>>
51 {
52 use tokio_postgres::types::{IsNull, Type};
53 match self {
54 Self::Null => Ok(IsNull::Yes),
55 Self::Text(s) => {
56 if *ty == Type::JSONB {
57 out.put_u8(1);
59 out.extend_from_slice(s.as_bytes());
60 } else if *ty == Type::JSON {
61 out.extend_from_slice(s.as_bytes());
62 } else if *ty == Type::UUID {
63 let uuid = uuid::Uuid::parse_str(s)?;
64 out.extend_from_slice(uuid.as_bytes());
65 } else if *ty == Type::INT4 {
66 let n: i32 = s.parse()?;
67 out.put_i32(n);
68 } else if *ty == Type::INT8 {
69 let n: i64 = s.parse()?;
70 out.put_i64(n);
71 } else if *ty == Type::BOOL {
72 let b: bool = s.parse()?;
73 out.put_u8(u8::from(b));
74 } else {
75 out.extend_from_slice(s.as_bytes());
78 }
79 Ok(IsNull::No)
80 },
81 }
82 }
83
84 fn accepts(_ty: &tokio_postgres::types::Type) -> bool {
85 true
87 }
88
89 fn to_sql_checked(
90 &self,
91 ty: &tokio_postgres::types::Type,
92 out: &mut bytes::BytesMut,
93 ) -> std::result::Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>>
94 {
95 self.to_sql(ty, out)
98 }
99}
100
101#[allow(dead_code)]
109fn enrich_undefined_column_error(
110 err: FraiseQLError,
111 view: &str,
112 where_clause: Option<&WhereClause>,
113) -> FraiseQLError {
114 let FraiseQLError::Database { ref sql_state, .. } = err else {
115 return err;
116 };
117 if sql_state.as_deref() != Some(PG_UNDEFINED_COLUMN) {
118 return err;
119 }
120 let native_cols: Vec<&str> =
121 where_clause.map(|wc| wc.native_column_names()).unwrap_or_default();
122 if native_cols.is_empty() {
123 return err;
124 }
125 FraiseQLError::Database {
126 message: format!(
127 "Column(s) {:?} referenced as native column(s) on `{view}` do not exist. \
128 These columns were auto-inferred from ID/UUID-typed query arguments. \
129 Either add the column(s) to the table/view, or set \
130 `native_columns = {{}}` explicitly in your schema to disable inference.",
131 native_cols,
132 ),
133 sql_state: Some(PG_UNDEFINED_COLUMN.to_string()),
134 }
135}
136
137struct EnumText(String);
144
145impl<'a> tokio_postgres::types::FromSql<'a> for EnumText {
146 fn from_sql(
147 _ty: &tokio_postgres::types::Type,
148 raw: &'a [u8],
149 ) -> std::result::Result<EnumText, Box<dyn std::error::Error + Sync + Send>> {
150 std::str::from_utf8(raw)
151 .map(|s| EnumText(s.to_owned()))
152 .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Sync + Send>)
153 }
154
155 fn accepts(ty: &tokio_postgres::types::Type) -> bool {
156 matches!(ty.kind(), tokio_postgres::types::Kind::Enum(_))
157 }
158}
159
160fn row_to_map(row: &Row) -> std::collections::HashMap<String, serde_json::Value> {
192 let mut map = std::collections::HashMap::new();
193 for (idx, column) in row.columns().iter().enumerate() {
194 let column_name = column.name().to_string();
195 let value: serde_json::Value = if let Ok(v) = row.try_get::<_, i32>(idx) {
196 serde_json::json!(v)
197 } else if let Ok(v) = row.try_get::<_, i64>(idx) {
198 serde_json::json!(v)
199 } else if let Ok(v) = row.try_get::<_, f64>(idx) {
200 serde_json::json!(v)
201 } else if let Ok(v) = row.try_get::<_, String>(idx) {
202 serde_json::json!(v)
203 } else if let Ok(v) = row.try_get::<_, bool>(idx) {
204 serde_json::json!(v)
205 } else if let Ok(v) = row.try_get::<_, serde_json::Value>(idx) {
206 v
207 } else if let Ok(v) = row.try_get::<_, Option<Vec<String>>>(idx) {
208 match v {
213 Some(arr) => serde_json::Value::Array(
214 arr.into_iter().map(serde_json::Value::String).collect(),
215 ),
216 None => serde_json::Value::Null,
217 }
218 } else if let Ok(EnumText(v)) = row.try_get::<_, EnumText>(idx) {
219 serde_json::json!(v)
224 } else {
225 serde_json::Value::Null
226 };
227 map.insert(column_name, value);
228 }
229 map
230}
231
232#[async_trait]
236impl DatabaseAdapter for PostgresAdapter {
237 async fn execute_with_projection(
238 &self,
239 view: &str,
240 projection: Option<&SqlProjectionHint>,
241 where_clause: Option<&WhereClause>,
242 limit: Option<u32>,
243 offset: Option<u32>,
244 order_by: Option<&[OrderByClause]>,
245 session_vars: &[(&str, &str)],
246 ) -> Result<Vec<JsonbValue>> {
247 self.execute_with_projection_impl(view, projection, where_clause, limit, offset, order_by, session_vars)
248 .await
249 }
250
251 async fn execute_where_query(
252 &self,
253 view: &str,
254 where_clause: Option<&WhereClause>,
255 limit: Option<u32>,
256 offset: Option<u32>,
257 order_by: Option<&[OrderByClause]>,
258 session_vars: &[(&str, &str)],
259 ) -> Result<Vec<JsonbValue>> {
260 let (sql, typed_params) =
261 build_where_select_sql_ordered(view, where_clause, limit, offset, order_by)?;
262
263 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = typed_params
264 .iter()
265 .map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
266 .collect();
267
268 let mut client = self.acquire_connection_with_retry().await?;
269
270 if !session_vars.is_empty() {
271 let txn = client.build_transaction().start().await.map_err(|e| FraiseQLError::Database {
272 message: format!("Failed to start transaction: {e}"),
273 sql_state: e.code().map(|c| c.code().to_string()),
274 })?;
275
276 for (name, value) in session_vars {
278 txn.execute("SELECT set_config($1, $2, true)", &[name, value]).await.map_err(|e| FraiseQLError::Database {
279 message: format!("Failed to set session variable {name}: {e}"),
280 sql_state: e.code().map(|c| c.code().to_string()),
281 })?;
282 }
283
284 let rows = txn.query(&sql, ¶m_refs).await.map_err(|e| FraiseQLError::Database {
286 message: format!("Query execution failed: {e}"),
287 sql_state: e.code().map(|c| c.code().to_string()),
288 })?;
289 txn.commit().await.map_err(|e| FraiseQLError::Database {
290 message: format!("Failed to commit transaction: {e}"),
291 sql_state: e.code().map(|c| c.code().to_string()),
292 })?;
293
294 Ok(rows.iter().map(|row| {
295 let data: serde_json::Value = row.get(0);
296 JsonbValue::new(data)
297 }).collect())
298 } else {
299 let rows = client.query(&sql, ¶m_refs).await.map_err(|e| FraiseQLError::Database {
301 message: format!("Query execution failed: {e}"),
302 sql_state: e.code().map(|c| c.code().to_string()),
303 })?;
304 Ok(rows.iter().map(|row| {
305 let data: serde_json::Value = row.get(0);
306 JsonbValue::new(data)
307 }).collect())
308 }
309 }
310
311 async fn explain_where_query(
312 &self,
313 view: &str,
314 where_clause: Option<&WhereClause>,
315 limit: Option<u32>,
316 offset: Option<u32>,
317 ) -> Result<serde_json::Value> {
318 let (select_sql, typed_params) = build_where_select_sql(view, where_clause, limit, offset)?;
319 if select_sql.contains(';') {
322 return Err(FraiseQLError::Validation {
323 message: "EXPLAIN SQL must be a single statement".into(),
324 path: None,
325 });
326 }
327 let explain_sql = format!("EXPLAIN (ANALYZE, BUFFERS, FORMAT JSON) {select_sql}");
328
329 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = typed_params
330 .iter()
331 .map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync))
332 .collect();
333
334 let client = self.acquire_connection_with_retry().await?;
335 let rows = client.query(explain_sql.as_str(), ¶m_refs).await.map_err(|e| {
336 FraiseQLError::Database {
337 message: format!("EXPLAIN ANALYZE failed: {e}"),
338 sql_state: e.code().map(|c| c.code().to_string()),
339 }
340 })?;
341
342 if let Some(row) = rows.first() {
343 let plan: serde_json::Value = row.try_get(0).map_err(|e| FraiseQLError::Database {
344 message: format!("Failed to parse EXPLAIN output: {e}"),
345 sql_state: None,
346 })?;
347 Ok(plan)
348 } else {
349 Ok(serde_json::Value::Null)
350 }
351 }
352
353 fn database_type(&self) -> DatabaseType {
354 DatabaseType::PostgreSQL
355 }
356
357 async fn health_check(&self) -> Result<()> {
358 let client = self.acquire_connection_with_retry().await?;
360
361 client.query("SELECT 1", &[]).await.map_err(|e| FraiseQLError::Database {
362 message: format!("Health check failed: {e}"),
363 sql_state: e.code().map(|c| c.code().to_string()),
364 })?;
365
366 Ok(())
367 }
368
369 #[allow(clippy::cast_possible_truncation)] fn pool_metrics(&self) -> PoolMetrics {
371 let status = self.pool.status();
372
373 PoolMetrics {
374 total_connections: status.size as u32,
375 idle_connections: status.available as u32,
376 active_connections: (status.size - status.available) as u32,
377 waiting_requests: status.waiting as u32,
378 }
379 }
380
381 async fn execute_raw_query(
386 &self,
387 sql: &str,
388 ) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
389 let client = self.acquire_connection_with_retry().await?;
391
392 let rows: Vec<Row> = client.query(sql, &[]).await.map_err(|e| FraiseQLError::Database {
393 message: format!("Query execution failed: {e}"),
394 sql_state: e.code().map(|c| c.code().to_string()),
395 })?;
396
397 let results: Vec<std::collections::HashMap<String, serde_json::Value>> =
399 rows.iter().map(row_to_map).collect();
400
401 Ok(results)
402 }
403
404 async fn execute_parameterized_aggregate(
405 &self,
406 sql: &str,
407 params: &[serde_json::Value],
408 session_vars: &[(&str, &str)],
409 ) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
410 let typed: Vec<QueryParam> = params.iter().cloned().map(QueryParam::from).collect();
414 let param_refs: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
415 typed.iter().map(|p| p as &(dyn tokio_postgres::types::ToSql + Sync)).collect();
416
417 let mut client = self.acquire_connection_with_retry().await?;
418
419 if !session_vars.is_empty() {
420 let txn = client.build_transaction().start().await.map_err(|e| FraiseQLError::Database {
421 message: format!("Failed to start transaction: {e}"),
422 sql_state: e.code().map(|c| c.code().to_string()),
423 })?;
424
425 for (name, value) in session_vars {
427 txn.execute("SELECT set_config($1, $2, true)", &[name, value]).await.map_err(|e| FraiseQLError::Database {
428 message: format!("Failed to set session variable {name}: {e}"),
429 sql_state: e.code().map(|c| c.code().to_string()),
430 })?;
431 }
432
433 let rows: Vec<Row> = txn.query(sql, ¶m_refs).await.map_err(|e| FraiseQLError::Database {
435 message: format!("Parameterized aggregate query failed: {e}"),
436 sql_state: e.code().map(|c| c.code().to_string()),
437 })?;
438 txn.commit().await.map_err(|e| FraiseQLError::Database {
439 message: format!("Failed to commit transaction: {e}"),
440 sql_state: e.code().map(|c| c.code().to_string()),
441 })?;
442
443 let results: Vec<std::collections::HashMap<String, serde_json::Value>> =
444 rows.iter().map(row_to_map).collect();
445
446 Ok(results)
447 } else {
448 let rows: Vec<Row> = client.query(sql, ¶m_refs).await.map_err(|e| FraiseQLError::Database {
450 message: format!("Parameterized aggregate query failed: {e}"),
451 sql_state: e.code().map(|c| c.code().to_string()),
452 })?;
453 let results: Vec<std::collections::HashMap<String, serde_json::Value>> =
454 rows.iter().map(row_to_map).collect();
455
456 Ok(results)
457 }
458 }
459
460 async fn execute_function_call(
461 &self,
462 function_name: &str,
463 args: &[serde_json::Value],
464 session_vars: &[(&str, &str)],
465 ) -> Result<Vec<std::collections::HashMap<String, serde_json::Value>>> {
466 let quoted_fn = quote_postgres_identifier(function_name);
472 let placeholders: Vec<String> = (1..=args.len()).map(|i| format!("${i}")).collect();
473 let sql = format!("SELECT * FROM {quoted_fn}({})", placeholders.join(", "));
474
475 let mut client = self.acquire_connection_with_retry().await?;
476
477 let flex_args: Vec<FlexParam> = args
485 .iter()
486 .map(|v| match v {
487 serde_json::Value::Null => FlexParam::Null,
488 serde_json::Value::String(s) => FlexParam::Text(s.clone()),
489 _ => FlexParam::Text(v.to_string()),
490 })
491 .collect();
492 let params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = flex_args
493 .iter()
494 .map(|v| v as &(dyn tokio_postgres::types::ToSql + Sync))
495 .collect();
496
497 if !session_vars.is_empty() || self.mutation_timing_enabled {
503 let txn =
504 client.build_transaction().start().await.map_err(|e| FraiseQLError::Database {
505 message: format!("Failed to start mutation transaction: {e}"),
506 sql_state: e.code().map(|c| c.code().to_string()),
507 })?;
508
509 for (name, value) in session_vars {
511 txn.execute("SELECT set_config($1, $2, true)", &[name, value])
512 .await
513 .map_err(|e| FraiseQLError::Database {
514 message: format!("Failed to set session variable {name}: {e}"),
515 sql_state: e.code().map(|c| c.code().to_string()),
516 })?;
517 }
518
519 if self.mutation_timing_enabled {
521 txn.execute(
522 "SELECT set_config($1, clock_timestamp()::text, true)",
523 &[&self.timing_variable_name],
524 )
525 .await
526 .map_err(|e| FraiseQLError::Database {
527 message: format!("Failed to set mutation timing variable: {e}"),
528 sql_state: e.code().map(|c| c.code().to_string()),
529 })?;
530 }
531
532 let rows: Vec<Row> = txn.query(sql.as_str(), params.as_slice()).await.map_err(|e| {
534 let detail = e.as_db_error().map_or("", |d| d.message());
535 FraiseQLError::Database {
536 message: format!("Function call {function_name} failed: {e}: {detail}"),
537 sql_state: e.code().map(|c| c.code().to_string()),
538 }
539 })?;
540
541 txn.commit().await.map_err(|e| FraiseQLError::Database {
542 message: format!("Failed to commit mutation transaction: {e}"),
543 sql_state: e.code().map(|c| c.code().to_string()),
544 })?;
545
546 let results: Vec<std::collections::HashMap<String, serde_json::Value>> =
547 rows.iter().map(row_to_map).collect();
548
549 Ok(results)
550 } else {
551 let rows: Vec<Row> =
553 client.query(sql.as_str(), params.as_slice()).await.map_err(|e| {
554 let detail = e.as_db_error().map_or("", |d| d.message());
555 FraiseQLError::Database {
556 message: format!("Function call {function_name} failed: {e}: {detail}"),
557 sql_state: e.code().map(|c| c.code().to_string()),
558 }
559 })?;
560
561 let results: Vec<std::collections::HashMap<String, serde_json::Value>> =
562 rows.iter().map(row_to_map).collect();
563
564 Ok(results)
565 }
566 }
567
568 async fn set_session_variables(&self, variables: &[(&str, &str)]) -> Result<()> {
569 if variables.is_empty() {
570 return Ok(());
571 }
572 let client = self.acquire_connection_with_retry().await?;
573 for (name, value) in variables {
574 client
575 .execute("SELECT set_config($1, $2, true)", &[name, value])
576 .await
577 .map_err(|e| FraiseQLError::Database {
578 message: format!("set_config({name:?}) failed: {e}"),
579 sql_state: e.code().map(|c| c.code().to_string()),
580 })?;
581 }
582 Ok(())
583 }
584
585 async fn explain_query(
586 &self,
587 sql: &str,
588 _params: &[serde_json::Value],
589 ) -> Result<serde_json::Value> {
590 if sql.contains(';') {
594 return Err(FraiseQLError::Validation {
595 message: "EXPLAIN SQL must be a single statement".into(),
596 path: None,
597 });
598 }
599 let explain_sql = format!("EXPLAIN (ANALYZE false, FORMAT JSON) {sql}");
600 let client = self.acquire_connection_with_retry().await?;
601 let rows: Vec<Row> =
602 client
603 .query(explain_sql.as_str(), &[])
604 .await
605 .map_err(|e| FraiseQLError::Database {
606 message: format!("EXPLAIN failed: {e}"),
607 sql_state: e.code().map(|c| c.code().to_string()),
608 })?;
609
610 if let Some(row) = rows.first() {
611 let plan: serde_json::Value = row.try_get(0).map_err(|e| FraiseQLError::Database {
612 message: format!("Failed to parse EXPLAIN output: {e}"),
613 sql_state: None,
614 })?;
615 Ok(plan)
616 } else {
617 Ok(serde_json::Value::Null)
618 }
619 }
620}
621
// Opts `PostgresAdapter` into `SupportsMutations`. The empty body overrides nothing, so any
// trait items use their default implementations. Presumably a marker trait — confirm in
// `crate::traits`.
impl SupportsMutations for PostgresAdapter {}