Skip to main content

sql_middleware/postgres/
query.rs

1use super::client_exec::{execute_rows_on_client, query_rows_on_client};
2use crate::adapters::result_set::{column_count, init_result_set};
3use crate::middleware::{ResultSet, RowValues, SqlMiddlewareDbError};
4use crate::query_utils::extract_column_names;
5use chrono::NaiveDateTime;
6use serde_json::Value;
7use tokio_postgres::{Client, Statement, Transaction, types::ToSql};
8
9/// Build a result set from a Postgres query execution
10///
11/// # Errors
12/// Returns errors from query execution or result processing.
13pub async fn build_result_set(
14    stmt: &Statement,
15    params: &[&(dyn ToSql + Sync)],
16    transaction: &Transaction<'_>,
17) -> Result<ResultSet, SqlMiddlewareDbError> {
18    // Execute the query
19    let rows = transaction.query(stmt, params).await?;
20
21    let column_names = extract_column_names(stmt.columns().iter(), |col| col.name());
22
23    // Preallocate capacity if we can estimate the number of rows
24    let capacity = rows.len();
25    let mut result_set = init_result_set(column_names, capacity);
26
27    for row in rows {
28        let mut row_values = Vec::new();
29
30        let col_count = column_count(&result_set)?;
31
32        for i in 0..col_count {
33            let value = postgres_extract_value(&row, i)?;
34            row_values.push(value);
35        }
36
37        result_set.add_row_values(row_values);
38    }
39
40    Ok(result_set)
41}
42
43/// Extracts a `RowValues` from a `tokio_postgres` Row at the given index.
44///
45/// # Errors
46/// Returns `SqlMiddlewareDbError` if the column cannot be retrieved.
47pub fn postgres_extract_value(
48    row: &tokio_postgres::Row,
49    idx: usize,
50) -> Result<RowValues, SqlMiddlewareDbError> {
51    // Determine the type of the column and extract accordingly
52    let type_info = row.columns()[idx].type_();
53
54    // Match on the type based on PostgreSQL type OIDs or names
55    // For simplicity, we'll handle common types. You may need to expand this.
56    if type_info.name() == "int2" {
57        let val: Option<i16> = row.try_get(idx)?;
58        Ok(val.map_or(RowValues::Null, |v| RowValues::Int(i64::from(v))))
59    } else if type_info.name() == "int4" {
60        let val: Option<i32> = row.try_get(idx)?;
61        Ok(val.map_or(RowValues::Null, |v| RowValues::Int(i64::from(v))))
62    } else if type_info.name() == "int8" {
63        let val: Option<i64> = row.try_get(idx)?;
64        Ok(val.map_or(RowValues::Null, RowValues::Int))
65    } else if type_info.name() == "float4" || type_info.name() == "float8" {
66        let val: Option<f64> = row.try_get(idx)?;
67        Ok(val.map_or(RowValues::Null, RowValues::Float))
68    } else if type_info.name() == "bool" {
69        let val: Option<bool> = row.try_get(idx)?;
70        Ok(val.map_or(RowValues::Null, RowValues::Bool))
71    } else if type_info.name() == "timestamp" || type_info.name() == "timestamptz" {
72        let val: Option<NaiveDateTime> = row.try_get(idx)?;
73        Ok(val.map_or(RowValues::Null, RowValues::Timestamp))
74    } else if type_info.name() == "json" || type_info.name() == "jsonb" {
75        let val: Option<Value> = row.try_get(idx)?;
76        Ok(val.map_or(RowValues::Null, RowValues::JSON))
77    } else if type_info.name() == "bytea" {
78        let val: Option<Vec<u8>> = row.try_get(idx)?;
79        Ok(val.map_or(RowValues::Null, RowValues::Blob))
80    } else if type_info.name() == "text"
81        || type_info.name() == "varchar"
82        || type_info.name() == "char"
83    {
84        let val: Option<String> = row.try_get(idx)?;
85        Ok(val.map_or(RowValues::Null, RowValues::Text))
86    } else {
87        // For other types, attempt to get as string
88        let val: Option<String> = row.try_get(idx)?;
89        Ok(val.map_or(RowValues::Null, RowValues::Text))
90    }
91}
92
93/// Build a result set from raw Postgres rows (without a Transaction)
94///
95/// # Errors
96/// Returns errors from result processing.
97pub(crate) fn build_result_set_from_rows(
98    rows: &[tokio_postgres::Row],
99) -> Result<ResultSet, SqlMiddlewareDbError> {
100    let mut result_set = ResultSet::with_capacity(rows.len());
101    if let Some(row) = rows.first() {
102        let cols = extract_column_names(row.columns().iter(), |col| col.name());
103        result_set.set_column_names(std::sync::Arc::new(cols));
104    }
105
106    for row in rows {
107        let col_count = row.columns().len();
108        let mut row_values = Vec::with_capacity(col_count);
109        for idx in 0..col_count {
110            row_values.push(postgres_extract_value(row, idx)?);
111        }
112        result_set.add_row_values(row_values);
113    }
114
115    Ok(result_set)
116}
117
118/// Build a result set using statement metadata for column names.
119///
120/// # Errors
121/// Returns errors from row value extraction.
122pub(crate) fn build_result_set_from_statement(
123    stmt: &Statement,
124    rows: &[tokio_postgres::Row],
125) -> Result<ResultSet, SqlMiddlewareDbError> {
126    let column_names = extract_column_names(stmt.columns().iter(), |col| col.name());
127    let column_count = column_names.len();
128
129    let mut result_set = init_result_set(column_names, rows.len());
130
131    for row in rows {
132        let mut row_values = Vec::with_capacity(column_count);
133        for idx in 0..column_count {
134            row_values.push(postgres_extract_value(row, idx)?);
135        }
136        result_set.add_row_values(row_values);
137    }
138
139    Ok(result_set)
140}
141
142/// Execute a SELECT query on a client without managing transactions
143///
144/// # Errors
145/// Returns errors from parameter conversion or query execution.
146pub async fn execute_query_on_client(
147    client: &Client,
148    query: &str,
149    params: &[RowValues],
150) -> Result<ResultSet, SqlMiddlewareDbError> {
151    let rows = query_rows_on_client(client, query, None, params, "postgres select error").await?;
152    build_result_set_from_rows(&rows)
153}
154
155/// Execute a prepared SELECT query on a client without managing transactions.
156///
157/// # Errors
158/// Returns errors from parameter conversion, preparation, or query execution.
159pub(crate) async fn execute_query_prepared_on_client(
160    client: &Client,
161    query: &str,
162    params: &[RowValues],
163) -> Result<ResultSet, SqlMiddlewareDbError> {
164    let stmt = client.prepare(query).await.map_err(|e| {
165        SqlMiddlewareDbError::ExecutionError(format!("postgres prepare error: {e}"))
166    })?;
167    let rows =
168        query_rows_on_client(client, query, Some(&stmt), params, "postgres select error").await?;
169    build_result_set_from_statement(&stmt, &rows)
170}
171
172/// Execute a prepared SELECT query with an already-prepared statement.
173///
174/// # Errors
175/// Returns errors from parameter conversion, query execution, or result building.
176pub(crate) async fn execute_query_prepared_statement_on_client(
177    client: &Client,
178    stmt: &Statement,
179    params: &[RowValues],
180) -> Result<ResultSet, SqlMiddlewareDbError> {
181    let rows =
182        query_rows_on_client(client, "", Some(stmt), params, "postgres select error").await?;
183    build_result_set_from_statement(stmt, &rows)
184}
185
186/// Execute a DML query on a client without managing transactions
187///
188/// # Errors
189/// Returns errors from parameter conversion or query execution.
190pub async fn execute_dml_on_client(
191    client: &Client,
192    query: &str,
193    params: &[RowValues],
194    err_label: &str,
195) -> Result<usize, SqlMiddlewareDbError> {
196    let rows = execute_rows_on_client(client, query, None, params, err_label).await?;
197    convert_affected_rows(rows, "postgres affected rows conversion error")
198}
199
200/// Execute a prepared DML query on a client without managing transactions.
201///
202/// # Errors
203/// Returns errors from parameter conversion, preparation, or query execution.
204pub(crate) async fn execute_dml_prepared_on_client(
205    client: &Client,
206    query: &str,
207    params: &[RowValues],
208) -> Result<usize, SqlMiddlewareDbError> {
209    let stmt = client.prepare(query).await.map_err(|e| {
210        SqlMiddlewareDbError::ExecutionError(format!("postgres prepare error: {e}"))
211    })?;
212    let rows = execute_rows_on_client(client, query, Some(&stmt), params, "postgres execute error")
213        .await?;
214    convert_affected_rows(rows, "postgres affected rows conversion error")
215}
216
217/// Execute a prepared DML query with an already-prepared statement.
218///
219/// # Errors
220/// Returns errors from parameter conversion, query execution, or row-count conversion.
221pub(crate) async fn execute_dml_prepared_statement_on_client(
222    client: &Client,
223    stmt: &Statement,
224    params: &[RowValues],
225) -> Result<usize, SqlMiddlewareDbError> {
226    let rows =
227        execute_rows_on_client(client, "", Some(stmt), params, "postgres execute error").await?;
228    convert_affected_rows(rows, "postgres affected rows conversion error")
229}
230
231pub(crate) fn convert_affected_rows(rows: u64, label: &str) -> Result<usize, SqlMiddlewareDbError> {
232    usize::try_from(rows).map_err(|e| SqlMiddlewareDbError::ExecutionError(format!("{label}: {e}")))
233}