1use super::client_exec::{execute_rows_on_client, query_rows_on_client};
2use crate::adapters::result_set::{column_count, init_result_set};
3use crate::middleware::{ResultSet, RowValues, SqlMiddlewareDbError};
4use crate::query_utils::extract_column_names;
5use chrono::NaiveDateTime;
6use serde_json::Value;
7use tokio_postgres::{Client, Statement, Transaction, types::ToSql};
8
9pub async fn build_result_set(
14 stmt: &Statement,
15 params: &[&(dyn ToSql + Sync)],
16 transaction: &Transaction<'_>,
17) -> Result<ResultSet, SqlMiddlewareDbError> {
18 let rows = transaction.query(stmt, params).await?;
20
21 let column_names = extract_column_names(stmt.columns().iter(), |col| col.name());
22
23 let capacity = rows.len();
25 let mut result_set = init_result_set(column_names, capacity);
26
27 for row in rows {
28 let mut row_values = Vec::new();
29
30 let col_count = column_count(&result_set)?;
31
32 for i in 0..col_count {
33 let value = postgres_extract_value(&row, i)?;
34 row_values.push(value);
35 }
36
37 result_set.add_row_values(row_values);
38 }
39
40 Ok(result_set)
41}
42
43pub fn postgres_extract_value(
48 row: &tokio_postgres::Row,
49 idx: usize,
50) -> Result<RowValues, SqlMiddlewareDbError> {
51 let type_info = row.columns()[idx].type_();
53
54 if type_info.name() == "int2" {
57 let val: Option<i16> = row.try_get(idx)?;
58 Ok(val.map_or(RowValues::Null, |v| RowValues::Int(i64::from(v))))
59 } else if type_info.name() == "int4" {
60 let val: Option<i32> = row.try_get(idx)?;
61 Ok(val.map_or(RowValues::Null, |v| RowValues::Int(i64::from(v))))
62 } else if type_info.name() == "int8" {
63 let val: Option<i64> = row.try_get(idx)?;
64 Ok(val.map_or(RowValues::Null, RowValues::Int))
65 } else if type_info.name() == "float4" || type_info.name() == "float8" {
66 let val: Option<f64> = row.try_get(idx)?;
67 Ok(val.map_or(RowValues::Null, RowValues::Float))
68 } else if type_info.name() == "bool" {
69 let val: Option<bool> = row.try_get(idx)?;
70 Ok(val.map_or(RowValues::Null, RowValues::Bool))
71 } else if type_info.name() == "timestamp" || type_info.name() == "timestamptz" {
72 let val: Option<NaiveDateTime> = row.try_get(idx)?;
73 Ok(val.map_or(RowValues::Null, RowValues::Timestamp))
74 } else if type_info.name() == "json" || type_info.name() == "jsonb" {
75 let val: Option<Value> = row.try_get(idx)?;
76 Ok(val.map_or(RowValues::Null, RowValues::JSON))
77 } else if type_info.name() == "bytea" {
78 let val: Option<Vec<u8>> = row.try_get(idx)?;
79 Ok(val.map_or(RowValues::Null, RowValues::Blob))
80 } else if type_info.name() == "text"
81 || type_info.name() == "varchar"
82 || type_info.name() == "char"
83 {
84 let val: Option<String> = row.try_get(idx)?;
85 Ok(val.map_or(RowValues::Null, RowValues::Text))
86 } else {
87 let val: Option<String> = row.try_get(idx)?;
89 Ok(val.map_or(RowValues::Null, RowValues::Text))
90 }
91}
92
93pub(crate) fn build_result_set_from_rows(
98 rows: &[tokio_postgres::Row],
99) -> Result<ResultSet, SqlMiddlewareDbError> {
100 let mut result_set = ResultSet::with_capacity(rows.len());
101 if let Some(row) = rows.first() {
102 let cols = extract_column_names(row.columns().iter(), |col| col.name());
103 result_set.set_column_names(std::sync::Arc::new(cols));
104 }
105
106 for row in rows {
107 let col_count = row.columns().len();
108 let mut row_values = Vec::with_capacity(col_count);
109 for idx in 0..col_count {
110 row_values.push(postgres_extract_value(row, idx)?);
111 }
112 result_set.add_row_values(row_values);
113 }
114
115 Ok(result_set)
116}
117
118pub(crate) fn build_result_set_from_statement(
123 stmt: &Statement,
124 rows: &[tokio_postgres::Row],
125) -> Result<ResultSet, SqlMiddlewareDbError> {
126 let column_names = extract_column_names(stmt.columns().iter(), |col| col.name());
127 let column_count = column_names.len();
128
129 let mut result_set = init_result_set(column_names, rows.len());
130
131 for row in rows {
132 let mut row_values = Vec::with_capacity(column_count);
133 for idx in 0..column_count {
134 row_values.push(postgres_extract_value(row, idx)?);
135 }
136 result_set.add_row_values(row_values);
137 }
138
139 Ok(result_set)
140}
141
142pub async fn execute_query_on_client(
147 client: &Client,
148 query: &str,
149 params: &[RowValues],
150) -> Result<ResultSet, SqlMiddlewareDbError> {
151 let rows = query_rows_on_client(client, query, None, params, "postgres select error").await?;
152 build_result_set_from_rows(&rows)
153}
154
155pub(crate) async fn execute_query_prepared_on_client(
160 client: &Client,
161 query: &str,
162 params: &[RowValues],
163) -> Result<ResultSet, SqlMiddlewareDbError> {
164 let stmt = client.prepare(query).await.map_err(|e| {
165 SqlMiddlewareDbError::ExecutionError(format!("postgres prepare error: {e}"))
166 })?;
167 let rows =
168 query_rows_on_client(client, query, Some(&stmt), params, "postgres select error").await?;
169 build_result_set_from_statement(&stmt, &rows)
170}
171
172pub(crate) async fn execute_query_prepared_statement_on_client(
177 client: &Client,
178 stmt: &Statement,
179 params: &[RowValues],
180) -> Result<ResultSet, SqlMiddlewareDbError> {
181 let rows =
182 query_rows_on_client(client, "", Some(stmt), params, "postgres select error").await?;
183 build_result_set_from_statement(stmt, &rows)
184}
185
186pub async fn execute_dml_on_client(
191 client: &Client,
192 query: &str,
193 params: &[RowValues],
194 err_label: &str,
195) -> Result<usize, SqlMiddlewareDbError> {
196 let rows = execute_rows_on_client(client, query, None, params, err_label).await?;
197 convert_affected_rows(rows, "postgres affected rows conversion error")
198}
199
200pub(crate) async fn execute_dml_prepared_on_client(
205 client: &Client,
206 query: &str,
207 params: &[RowValues],
208) -> Result<usize, SqlMiddlewareDbError> {
209 let stmt = client.prepare(query).await.map_err(|e| {
210 SqlMiddlewareDbError::ExecutionError(format!("postgres prepare error: {e}"))
211 })?;
212 let rows = execute_rows_on_client(client, query, Some(&stmt), params, "postgres execute error")
213 .await?;
214 convert_affected_rows(rows, "postgres affected rows conversion error")
215}
216
217pub(crate) async fn execute_dml_prepared_statement_on_client(
222 client: &Client,
223 stmt: &Statement,
224 params: &[RowValues],
225) -> Result<usize, SqlMiddlewareDbError> {
226 let rows =
227 execute_rows_on_client(client, "", Some(stmt), params, "postgres execute error").await?;
228 convert_affected_rows(rows, "postgres affected rows conversion error")
229}
230
231pub(crate) fn convert_affected_rows(rows: u64, label: &str) -> Result<usize, SqlMiddlewareDbError> {
232 usize::try_from(rows).map_err(|e| SqlMiddlewareDbError::ExecutionError(format!("{label}: {e}")))
233}