Skip to main content

sql_middleware/postgres/
query.rs

1use crate::adapters::params::convert_params;
2use crate::adapters::result_set::{column_count, init_result_set};
3use crate::middleware::{ResultSet, RowValues, SqlMiddlewareDbError};
4use crate::query_utils::extract_column_names;
5use crate::types::ConversionMode;
6use chrono::NaiveDateTime;
7use serde_json::Value;
8use tokio_postgres::{Client, Statement, Transaction, types::ToSql};
9
10use super::params::Params as PgParams;
11
12/// Build a result set from a Postgres query execution
13///
14/// # Errors
15/// Returns errors from query execution or result processing.
16pub async fn build_result_set(
17    stmt: &Statement,
18    params: &[&(dyn ToSql + Sync)],
19    transaction: &Transaction<'_>,
20) -> Result<ResultSet, SqlMiddlewareDbError> {
21    // Execute the query
22    let rows = transaction.query(stmt, params).await?;
23
24    let column_names = extract_column_names(stmt.columns().iter(), |col| col.name());
25
26    // Preallocate capacity if we can estimate the number of rows
27    let capacity = rows.len();
28    let mut result_set = init_result_set(column_names, capacity);
29
30    for row in rows {
31        let mut row_values = Vec::new();
32
33        let col_count = column_count(&result_set)?;
34
35        for i in 0..col_count {
36            let value = postgres_extract_value(&row, i)?;
37            row_values.push(value);
38        }
39
40        result_set.add_row_values(row_values);
41    }
42
43    Ok(result_set)
44}
45
46/// Extracts a `RowValues` from a `tokio_postgres` Row at the given index.
47///
48/// # Errors
49/// Returns `SqlMiddlewareDbError` if the column cannot be retrieved.
50pub fn postgres_extract_value(
51    row: &tokio_postgres::Row,
52    idx: usize,
53) -> Result<RowValues, SqlMiddlewareDbError> {
54    // Determine the type of the column and extract accordingly
55    let type_info = row.columns()[idx].type_();
56
57    // Match on the type based on PostgreSQL type OIDs or names
58    // For simplicity, we'll handle common types. You may need to expand this.
59    if type_info.name() == "int2" {
60        let val: Option<i16> = row.try_get(idx)?;
61        Ok(val.map_or(RowValues::Null, |v| RowValues::Int(i64::from(v))))
62    } else if type_info.name() == "int4" {
63        let val: Option<i32> = row.try_get(idx)?;
64        Ok(val.map_or(RowValues::Null, |v| RowValues::Int(i64::from(v))))
65    } else if type_info.name() == "int8" {
66        let val: Option<i64> = row.try_get(idx)?;
67        Ok(val.map_or(RowValues::Null, RowValues::Int))
68    } else if type_info.name() == "float4" || type_info.name() == "float8" {
69        let val: Option<f64> = row.try_get(idx)?;
70        Ok(val.map_or(RowValues::Null, RowValues::Float))
71    } else if type_info.name() == "bool" {
72        let val: Option<bool> = row.try_get(idx)?;
73        Ok(val.map_or(RowValues::Null, RowValues::Bool))
74    } else if type_info.name() == "timestamp" || type_info.name() == "timestamptz" {
75        let val: Option<NaiveDateTime> = row.try_get(idx)?;
76        Ok(val.map_or(RowValues::Null, RowValues::Timestamp))
77    } else if type_info.name() == "json" || type_info.name() == "jsonb" {
78        let val: Option<Value> = row.try_get(idx)?;
79        Ok(val.map_or(RowValues::Null, RowValues::JSON))
80    } else if type_info.name() == "bytea" {
81        let val: Option<Vec<u8>> = row.try_get(idx)?;
82        Ok(val.map_or(RowValues::Null, RowValues::Blob))
83    } else if type_info.name() == "text"
84        || type_info.name() == "varchar"
85        || type_info.name() == "char"
86    {
87        let val: Option<String> = row.try_get(idx)?;
88        Ok(val.map_or(RowValues::Null, RowValues::Text))
89    } else {
90        // For other types, attempt to get as string
91        let val: Option<String> = row.try_get(idx)?;
92        Ok(val.map_or(RowValues::Null, RowValues::Text))
93    }
94}
95
96/// Build a result set from raw Postgres rows (without a Transaction)
97///
98/// # Errors
99/// Returns errors from result processing.
100pub(crate) fn build_result_set_from_rows(
101    rows: &[tokio_postgres::Row],
102) -> Result<ResultSet, SqlMiddlewareDbError> {
103    let mut result_set = ResultSet::with_capacity(rows.len());
104    if let Some(row) = rows.first() {
105        let cols = extract_column_names(row.columns().iter(), |col| col.name());
106        result_set.set_column_names(std::sync::Arc::new(cols));
107    }
108
109    for row in rows {
110        let col_count = row.columns().len();
111        let mut row_values = Vec::with_capacity(col_count);
112        for idx in 0..col_count {
113            row_values.push(postgres_extract_value(row, idx)?);
114        }
115        result_set.add_row_values(row_values);
116    }
117
118    Ok(result_set)
119}
120
121/// Build a result set using statement metadata for column names.
122///
123/// # Errors
124/// Returns errors from row value extraction.
125pub(crate) fn build_result_set_from_statement(
126    stmt: &Statement,
127    rows: &[tokio_postgres::Row],
128) -> Result<ResultSet, SqlMiddlewareDbError> {
129    let column_names = extract_column_names(stmt.columns().iter(), |col| col.name());
130    let column_count = column_names.len();
131
132    let mut result_set = init_result_set(column_names, rows.len());
133
134    for row in rows {
135        let mut row_values = Vec::with_capacity(column_count);
136        for idx in 0..column_count {
137            row_values.push(postgres_extract_value(row, idx)?);
138        }
139        result_set.add_row_values(row_values);
140    }
141
142    Ok(result_set)
143}
144
145/// Execute a SELECT query on a client without managing transactions
146///
147/// # Errors
148/// Returns errors from parameter conversion or query execution.
149pub async fn execute_query_on_client(
150    client: &Client,
151    query: &str,
152    params: &[RowValues],
153) -> Result<ResultSet, SqlMiddlewareDbError> {
154    let rows =
155        query_rows_on_client(client, query, None, params, "postgres select error").await?;
156    build_result_set_from_rows(&rows)
157}
158
159/// Execute a prepared SELECT query on a client without managing transactions.
160///
161/// # Errors
162/// Returns errors from parameter conversion, preparation, or query execution.
163pub(crate) async fn execute_query_prepared_on_client(
164    client: &Client,
165    query: &str,
166    params: &[RowValues],
167) -> Result<ResultSet, SqlMiddlewareDbError> {
168    let stmt = client.prepare(query).await.map_err(|e| {
169        SqlMiddlewareDbError::ExecutionError(format!("postgres prepare error: {e}"))
170    })?;
171    let rows =
172        query_rows_on_client(client, query, Some(&stmt), params, "postgres select error").await?;
173    build_result_set_from_statement(&stmt, &rows)
174}
175
176/// Execute a DML query on a client without managing transactions
177///
178/// # Errors
179/// Returns errors from parameter conversion or query execution.
180pub async fn execute_dml_on_client(
181    client: &Client,
182    query: &str,
183    params: &[RowValues],
184    err_label: &str,
185) -> Result<usize, SqlMiddlewareDbError> {
186    let rows = execute_rows_on_client(client, query, None, params, err_label).await?;
187    convert_affected_rows(rows, "postgres affected rows conversion error")
188}
189
190/// Execute a prepared DML query on a client without managing transactions.
191///
192/// # Errors
193/// Returns errors from parameter conversion, preparation, or query execution.
194pub(crate) async fn execute_dml_prepared_on_client(
195    client: &Client,
196    query: &str,
197    params: &[RowValues],
198) -> Result<usize, SqlMiddlewareDbError> {
199    let stmt = client.prepare(query).await.map_err(|e| {
200        SqlMiddlewareDbError::ExecutionError(format!("postgres prepare error: {e}"))
201    })?;
202    let rows =
203        execute_rows_on_client(client, query, Some(&stmt), params, "postgres execute error")
204            .await?;
205    convert_affected_rows(rows, "postgres affected rows conversion error")
206}
207
208pub(crate) fn convert_affected_rows(
209    rows: u64,
210    label: &str,
211) -> Result<usize, SqlMiddlewareDbError> {
212    usize::try_from(rows).map_err(|e| {
213        SqlMiddlewareDbError::ExecutionError(format!("{label}: {e}"))
214    })
215}
216
217async fn query_rows_on_client(
218    client: &Client,
219    query: &str,
220    stmt: Option<&Statement>,
221    params: &[RowValues],
222    err_label: &str,
223) -> Result<Vec<tokio_postgres::Row>, SqlMiddlewareDbError>
224{
225    let converted = convert_params::<PgParams>(params, ConversionMode::Query)?;
226    let refs = converted.as_refs();
227    let rows = match stmt {
228        Some(stmt) => client.query(stmt, refs).await,
229        None => client.query(query, refs).await,
230    };
231    rows.map_err(|e| SqlMiddlewareDbError::ExecutionError(format!("{err_label}: {e}")))
232}
233
234async fn execute_rows_on_client(
235    client: &Client,
236    query: &str,
237    stmt: Option<&Statement>,
238    params: &[RowValues],
239    err_label: &str,
240) -> Result<u64, SqlMiddlewareDbError>
241{
242    let converted = convert_params::<PgParams>(params, ConversionMode::Execute)?;
243    let refs = converted.as_refs();
244    let rows = match stmt {
245        Some(stmt) => client.execute(stmt, refs).await,
246        None => client.execute(query, refs).await,
247    };
248    rows.map_err(|e| SqlMiddlewareDbError::ExecutionError(format!("{err_label}: {e}")))
249}